feat: 完善分割工作区导入导出与管理流程
- 新增基于 JWT 当前用户的登录恢复、角色权限、用户管理、审计日志和演示出厂重置后台接口与前端管理页。 - 重串 GT_label 导出和 GT Mask 导入逻辑:导出保留类别真实 maskid,导入仅接受灰度或 RGB 等通道 maskid 图,支持未知 maskid 策略、尺寸最近邻拉伸和导入预览。 - 统一分割结果导出体验:默认当前帧,按项目抽帧顺序和 XhXXmXXsXXXms 时间戳命名 ZIP 与图片,补齐 GT/Pro/Mix/分开 Mask 输出和映射 JSON。 - 调整工作区左侧工具栏:移除创建点/线段入口,新增画笔、橡皮擦及尺寸控制,并按绘制、布尔、导入/AI 工具分组分隔。 - 扩展 Canvas 编辑能力:画笔按语义分类绘制并可自动并入连通选中 mask,橡皮擦对选中区域扣除,优化布尔操作、选区、撤销重做和保存状态联动。 - 优化自动传播时间轴显示:同一蓝色系按传播新旧递进变暗,老传播记录达到阈值后统一旧记录色,并维护范围选择与清空后的历史显示。 - 将 AI 智能分割入口替换为更明确的 AI 元素图标,并同步侧栏、工作区和 AI 页面入口表现。 - 完善模板分类、maskid 工具函数、分类树联动、遮罩透明度、边缘平滑和传播链同步相关前端状态。 - 扩展后端项目、媒体、任务、Dashboard、模板和传播 runner 的用户隔离、任务控制、进度事件与兼容处理。 - 补充前后端测试,覆盖用户管理、GT_label 往返导入导出、GT Mask 校验和预览、画笔/橡皮擦、时间轴传播历史、导出范围、WebSocket 与 API 封装。 - 更新 AGENTS、README 和 doc 文档,记录当前接口契约、实现状态、测试计划、安装说明和 maskid/GT_label 规则。
This commit is contained in:
55
AGENTS.md
55
AGENTS.md
@@ -65,7 +65,7 @@ Seg_Server/
|
||||
│ ├── main.py # 应用入口、lifespan、CORS、路由注册、WebSocket
|
||||
│ ├── config.py # Pydantic Settings;读取 backend/.env
|
||||
│ ├── database.py # SQLAlchemy Engine / Session
|
||||
│ ├── models.py # Project/Frame/Template/Annotation/Mask/ProcessingTask ORM
|
||||
│ ├── models.py # User/Project/Frame/Template/Annotation/Mask/AuditLog/ProcessingTask ORM
|
||||
│ ├── schemas.py # Pydantic 请求/响应模型
|
||||
│ ├── minio_client.py # MinIO 上传、下载、预签名 URL
|
||||
│ ├── redis_client.py # Redis 连接封装
|
||||
@@ -75,7 +75,8 @@ Seg_Server/
|
||||
│ ├── setup_sam3_env.sh # 历史保留的 SAM 3 独立 Python 3.12 环境安装脚本;当前产品入口禁用
|
||||
│ ├── requirements.txt # Python 依赖
|
||||
│ ├── routers/
|
||||
│ │ ├── auth.py # /api/auth/login
|
||||
│ │ ├── auth.py # /api/auth/login、/api/auth/me、鉴权依赖
|
||||
│ │ ├── admin.py # /api/admin/users、/api/admin/audit-logs、/api/admin/demo-factory-reset
|
||||
│ │ ├── projects.py # /api/projects 与 /api/projects/{id}/frames
|
||||
│ │ ├── templates.py # /api/templates
|
||||
│ │ ├── media.py # /api/media/upload、/upload/dicom、/parse
|
||||
@@ -181,8 +182,8 @@ uvicorn main:app --host 0.0.0.0 --port 8000 --reload
|
||||
- 模块值包括:`dashboard`、`projects`、`ai`、`workspace`、`templates`。
|
||||
- 默认模块是 `workspace`。
|
||||
- 未登录时渲染 `Login`。
|
||||
- 登录成功后 token 写入 `localStorage`,Axios request interceptor 会附加 `Authorization: Bearer <token>`。
|
||||
- `App.tsx` 在登录后调用 `getProjects()` 初始化项目列表。
|
||||
- 登录成功后后端返回签名 JWT,token 写入 `localStorage`,Axios request interceptor 会附加 `Authorization: Bearer <token>`。
|
||||
- `App.tsx` 在已有 token 或登录成功后调用 `/api/auth/me` 恢复当前用户,再调用 `getProjects()` 初始化当前用户项目列表。
|
||||
|
||||
### 后端
|
||||
|
||||
@@ -195,6 +196,10 @@ uvicorn main:app --host 0.0.0.0 --port 8000 --reload
|
||||
- 如果本地存在 `Data_MyVideo_1.mp4`,后台 seed 默认演示项目并拆前 100 帧。
|
||||
- API 路由包括:
|
||||
- `POST /api/auth/login`
|
||||
- `GET /api/auth/me`
|
||||
- `GET/POST/PATCH/DELETE /api/admin/users`
|
||||
- `GET /api/admin/audit-logs`
|
||||
- `POST /api/admin/demo-factory-reset`
|
||||
- `GET/POST/PATCH/DELETE /api/projects`
|
||||
- `GET/POST /api/projects/{project_id}/frames`
|
||||
- `GET/POST/PATCH/DELETE /api/templates`
|
||||
@@ -231,18 +236,19 @@ uvicorn main:app --host 0.0.0.0 --port 8000 --reload
|
||||
|
||||
## 主要业务流程
|
||||
|
||||
1. 登录:`Login.tsx` 调用 `POST /api/auth/login`,默认开发凭证为 `admin / 123456`。
|
||||
2. 项目管理:`ProjectLibrary.tsx` 调用项目 API 创建项目、拉取列表、删除项目;删除当前项目后会清空工作区当前项目、帧、mask 和选区。
|
||||
3. 上传资源:视频走 `/api/media/upload`,只上传源文件并关联项目,不自动拆帧;DICOM 批量走 `/api/media/upload/dicom`。
|
||||
4. 生成帧入队:用户在项目库点击“生成帧”,选择目标 FPS 后前端调用 `/api/media/parse`;后端创建 `ProcessingTask` 并投递 Celery,接口支持 `parse_fps`、`max_frames` 和 `target_width` 标准帧序列参数;项目库和模板库的成功/失败短反馈使用非阻塞 `TransientNotice`,会自动消失。
|
||||
5. worker 执行:Celery worker 用 FFmpeg 优先拆视频帧,失败后用 OpenCV fallback,DICOM 使用 pydicom;视频帧按 `frame_%06d.jpg` 连续命名并记录 `timestamp_ms`、`source_frame_number` 和任务 `frame_sequence` 元数据。
|
||||
6. 帧展示:`VideoWorkspace.tsx` 调用 `/api/projects/{id}/frames`,`CanvasArea.tsx` 和 `FrameTimeline.tsx` 显示当前帧与时间轴缩略图;`CanvasArea` 会按容器和帧尺寸默认居中放大底图并保留边距;`FrameTimeline` 会根据已保存标注回显到 `Mask.metadata` 的传播来源,把自动传播生成的帧在视频处理进度条显示为蓝色区段,人工/AI 标注帧显示红色竖线;每次自动传播成功处理帧后,`VideoWorkspace` 会把本次传播范围作为当前会话历史片段传给 `FrameTimeline`,在视频处理进度条上叠加不同色系的深到浅渐变条;视频处理进度条和红/蓝标识可点击跳转到对应帧;底部缩略图中人工/AI 标注帧用红色边框、自动传播/推理帧用蓝色边框,同一帧同时具备两种状态时红色标注边框优先保留,蓝色传播状态以内描边表达;当前帧仍以青色外框高亮优先;若当前帧同时是人工/AI 标注帧,则在青色外框内增加红色内描边,固定为外层当前帧、内层人工/AI 标注;只有进入自动传播范围选择模式时,播放进度条和视频处理进度条才显示黄色范围框,并可点击/拖拽选择传播起止帧;前端 `Frame` 会保留后端返回的帧序列时间戳和源帧号。
|
||||
7. 手工标注:`CanvasArea.tsx` 支持多边形、矩形、圆、点区域和线段生成 polygon mask;多边形可按 Enter 或点击首节点闭合;绘制工具可在已有 mask 上继续落点;Canvas 左上角工具上下文提示会在切换工具或操作状态变化时短暂显示,数秒后自动隐藏,避免长期遮挡底图;工具栏有“调整多边形”入口,左侧 `ToolsPalette` 使用紧凑垂直布局并在高度不足时自身滚动;点击 mask 后可按住顶点直接拖动并实时更新 polygon,顶点/seed point 拖拽结束不会触发 Stage 平移或重置 Canvas 视口;也可删除 polygon 顶点、通过边中点或双击边界插入新顶点,并能选择编辑多 polygon mask 的单个子区域;选中整块 mask 可用 Delete/Backspace 删除,已保存 mask 会同步后端删除;区域合并/去除会隐藏编辑手柄并显示已选数量,第一个选中的主区域用黄色实线轮廓,后续参与合并/扣除的区域用红色虚线轮廓,使用 `polygon-clipping` 做 union/difference,内含去除结果用 even-odd 规则渲染 hole;Zustand 维护 `maskHistory/maskFuture` 支持撤销/重做。
|
||||
8. AI 分割:前端工具包括 SAM 2.1 变体选择、正向点、反向点和框选;AI 画布会按容器和当前帧尺寸默认居中放大底图并保留边距;工作区和 AI 页面都可点击已有提示点删除单点,AI 页面也可删除最近锚点、删除选中候选或清空本页锚点;这些删除入口会限制在当前提示点/本页 AI 候选范围内,避免误删工作区已有 mask。SAM 2.1 框选会建立候选 mask,后续正/反点通过 `interactive` prompt 携带原始框和累计点细化同一个候选 mask;AI 页面框选会先固化 `promptBox`,执行分割时只框选发送 `box` prompt,框选后继续加正/反点发送 `interactive` prompt;重复执行高精度分割会替换上一次 AI 页候选,只保留最新一个候选。包含反向点时工作区会传 `options.auto_filter_background=true` 和 `min_score=0.05`,如果后端过滤为空则移除旧候选 mask。后端 `ai.py` 期望按 `image_id`、`prompt_type`、`prompt_data`、`model` 和可选 `options` 调用 SAM registry。当前 registry 暴露 `sam2.1_hiera_tiny`、`sam2.1_hiera_small`、`sam2.1_hiera_base_plus`、`sam2.1_hiera_large`,并兼容 `sam2` 作为 tiny 别名;`model=sam3` 会被拒绝,`semantic` 文本提示也被禁用。SAM 2.1 支持点/框/interactive/自动分割和 video predictor 传播;多候选默认只采用最高分区域,避免重叠候选同时显示;AI 页面只渲染本页最新生成的候选 mask,不会把工作区已有 mask 带入 AI 画布;AI 页面生成的 mask 会写入全局 `masks` 并自动选中,右侧分类树可直接改标签,推送到工作区会切到“调整多边形”并保留选择和当前帧视角。`options.crop_to_prompt` 可对点/框/interactive prompt 做局部裁剪推理并回映射,`options.auto_filter_background` 可按分数和负向点过滤结果。
|
||||
9. 视频片段传播:工作区以当前打开帧作为参考帧,使用该帧全部 mask 作为 seed,并用传播起始帧和传播结束帧指定追踪范围;用户可直接修改数字框,也可点击“自动传播”进入时间轴范围选择模式,在播放进度条或视频处理进度条上点击/拖拽选择范围,再点击“开始传播”。工作区顶栏有独立“传播权重”选择器,可为本次传播二次选择 SAM 2.1 tiny/small/base+/large 权重,不提供 SAM2/SAM3 家族切换,也不影响 AI 单帧分割权重;前端提交传播前会先保存当前项目中的 draft/dirty mask,使 seed 优先带稳定的后端 `source_annotation_id`,再按传播权重 id、seed mask、seed 来源 id、边缘平滑参数和前/后方向组装 `steps` 并调用 `POST /api/ai/propagate/task` 创建 `propagate_masks` 后台任务;后端入队时会规范化/校验权重 id 并把规范化后的 id 写入任务 payload/result;Celery worker 顺序执行各 step,避免多个视频 tracker 并发抢占 GPU;每个 step 会根据 seed 来源 id、方向、平滑参数和 seed 签名做幂等判断,同权重且未改变的 seed 直接跳过,已改变、改动平滑参数或换用其他权重的 seed 会先删除同源旧自动传播标注再重传;旧版本用前端临时 `source_mask_id` 生成的传播标注会按同一参考帧、方向和语义信息兼容清理;中间帧人工新增/修改同一物体后重新传播时,后端会在写入目标帧新结果前按语义和空间重叠清理旧传播结果,且写入前清理不受旧结果传播方向限制;后端按项目帧序列下载片段帧,当前使用所选 SAM 2.1 权重变体的 `SAM2VideoPredictor.add_new_mask()` + `propagate_in_video()`,并把后续帧结果保存为 `Annotation`;如果 seed 带 `geometry_smoothing`,forward/backward 两个方向的传播结果保存前都会应用同一参数。工作区轮询 `GET /api/tasks/{task_id}` 展示进度并刷新标注,Dashboard 也能显示/取消/重试传播任务。
|
||||
10. GT 导入:工作区“导入 GT Mask”调用 `/api/ai/import-gt-mask`;后端按非零像素值和连通域生成 polygon 标注,并用 distance transform 生成 seed point;前端回显 seed point,拖动后可归档更新。
|
||||
11. 模板管理:`TemplateRegistry.tsx` 管理分类、颜色和 z-index;`OntologyInspector.tsx` 在工作区显示当前模板分类树。
|
||||
12. 导出:后端支持 COCO JSON 和 PNG mask ZIP 导出;PNG ZIP 包含单标注 mask、按 zIndex 融合的语义 mask 和 `semantic_classes.json`。
|
||||
1. 登录:`Login.tsx` 调用 `POST /api/auth/login`,后端用 `users` 表和密码哈希校验凭证,默认启动时会种子化开发管理员 `admin / 123456`;成功后返回签名 JWT,`GET /api/auth/me` 可读取当前用户;角色包括 `admin`、`annotator`、`viewer`,写入类业务接口要求 `admin/annotator`,用户管理后台要求 `admin`。
|
||||
2. 用户管理:`Sidebar` 仅对 `admin` 显示“用户管理”,`UserAdmin.tsx` 调用 `/api/admin/users` 新增、停用/启用、改角色、改密码和删除无项目用户,并调用 `/api/admin/audit-logs` 展示登录和管理操作审计;演示部署可通过“恢复演示出厂设置”二次确认后调用 `/api/admin/demo-factory-reset`,清空演示数据,只保留默认 admin 和一个尚未生成帧的演示视频项目。
|
||||
3. 项目管理:`ProjectLibrary.tsx` 调用项目 API 创建项目、拉取列表、删除项目;删除当前项目后会清空工作区当前项目、帧、mask 和选区。
|
||||
4. 上传资源:视频走 `/api/media/upload`,只上传源文件并关联项目,不自动拆帧;DICOM 批量走 `/api/media/upload/dicom`。
|
||||
5. 生成帧入队:用户在项目库点击“生成帧”,选择目标 FPS 后前端调用 `/api/media/parse`;后端创建 `ProcessingTask` 并投递 Celery,接口支持 `parse_fps`、`max_frames` 和 `target_width` 标准帧序列参数;项目库和模板库的成功/失败短反馈使用非阻塞 `TransientNotice`,会自动消失。
|
||||
6. worker 执行:Celery worker 用 FFmpeg 优先拆视频帧,失败后用 OpenCV fallback,DICOM 使用 pydicom;视频帧按 `frame_%06d.jpg` 连续命名并记录 `timestamp_ms`、`source_frame_number` 和任务 `frame_sequence` 元数据。
|
||||
7. 帧展示:`VideoWorkspace.tsx` 调用 `/api/projects/{id}/frames`,`CanvasArea.tsx` 和 `FrameTimeline.tsx` 显示当前帧与时间轴缩略图;`CanvasArea` 会按容器和帧尺寸默认居中放大底图并保留边距;`FrameTimeline` 会根据已保存标注回显到 `Mask.metadata` 的传播来源,把自动传播生成的帧在视频处理进度条显示为蓝色区段,人工/AI 标注帧显示红色竖线;每次自动传播成功处理帧后,`VideoWorkspace` 会把本次传播范围作为当前会话历史片段传给 `FrameTimeline`,在视频处理进度条上叠加同一蓝色系、最新传播最亮、旧传播逐次变暗且第 5 次及更早统一为阈值旧记录色的纯色条;视频处理进度条和红/蓝标识可点击跳转到对应帧;底部缩略图中人工/AI 标注帧用红色边框、自动传播/推理帧用蓝色边框,同一帧同时具备两种状态时红色标注边框优先保留,蓝色传播状态以内描边表达;当前帧仍以青色外框高亮优先;若当前帧同时是人工/AI 标注帧,则在青色外框内增加红色内描边,固定为外层当前帧、内层人工/AI 标注;进入自动传播、清空遮罩或特定范围帧导出选择模式时,播放进度条和视频处理进度条会显示黄色范围框,并可点击/拖拽选择起止帧;前端 `Frame` 会保留后端返回的帧序列时间戳和源帧号。
|
||||
8. 手工标注:`CanvasArea.tsx` 支持多边形、矩形、圆、画笔和橡皮擦生成/编辑 polygon mask;多边形可按 Enter 或点击首节点闭合;画笔/橡皮擦可在左侧工具栏调整大小,画笔要求右侧语义分类树已有选中类别,画出的圆形连续笔触会在鼠标松开时一次性 union 成 mask,若与当前选中 mask 连通则自动合并到该 mask,橡皮擦要求已选中 mask 并在松开时从该 mask 中 difference 扣除;未选中特定 mask 时,Canvas 会按右侧语义分类树拖拽得到的内部覆盖优先级从低到高渲染 mask,使高优先级类别显示在上层;Canvas 左上角工具上下文提示会在切换工具或操作状态变化时短暂显示,数秒后自动隐藏,避免长期遮挡底图;工具栏有“调整多边形”入口,左侧 `ToolsPalette` 使用紧凑垂直布局并在高度不足时自身滚动,且在“重叠区域去除”之后提供紫色“导入 GT Mask”入口;工作区左侧工具栏不展示 AI 页的正向选点、反向选点和边界框选,也不重复放置撤销/重做;点击 mask 后可按住顶点直接拖动并实时更新 polygon,顶点/seed point 拖拽结束不会触发 Stage 平移或重置 Canvas 视口;也可删除 polygon 顶点、通过边中点或双击边界插入新顶点,并能选择编辑多 polygon mask 的单个子区域;选中整块 mask 可用 Delete/Backspace 删除,已保存 mask 会同步后端删除;区域合并/去除会隐藏编辑手柄并显示已选数量,第一个选中的主区域用黄色实线轮廓,后续参与合并/扣除的区域用红色虚线轮廓,使用 `polygon-clipping` 做 union/difference,内含去除结果用 even-odd 规则渲染 hole;Zustand 维护 `maskHistory/maskFuture` 支持撤销/重做。
|
||||
9. AI 分割:侧栏和工作区工具栏的 AI 智能分割入口使用 Bot + Sparkles 组合图标强化 AI 识别;前端工具包括 SAM 2.1 变体选择、正向点、反向点和框选;AI 画布会按容器和当前帧尺寸默认居中放大底图并保留边距;工作区和 AI 页面都可点击已有提示点删除单点,AI 页面也可删除最近锚点、删除选中候选或清空本页锚点;这些删除入口会限制在当前提示点/本页 AI 候选范围内,避免误删工作区已有 mask。SAM 2.1 框选会建立候选 mask,后续正/反点通过 `interactive` prompt 携带原始框和累计点细化同一个候选 mask;AI 页面框选会先固化 `promptBox`,执行分割时只框选发送 `box` prompt,框选后继续加正/反点发送 `interactive` prompt;重复执行高精度分割会替换上一次 AI 页候选,只保留最新一个候选。包含反向点时工作区会传 `options.auto_filter_background=true` 和 `min_score=0.05`,如果后端过滤为空则移除旧候选 mask。后端 `ai.py` 期望按 `image_id`、`prompt_type`、`prompt_data`、`model` 和可选 `options` 调用 SAM registry。当前 registry 暴露 `sam2.1_hiera_tiny`、`sam2.1_hiera_small`、`sam2.1_hiera_base_plus`、`sam2.1_hiera_large`,并兼容 `sam2` 作为 tiny 别名;`model=sam3` 会被拒绝,`semantic` 文本提示也被禁用。SAM 2.1 支持点/框/interactive/自动分割和 video predictor 传播;多候选默认只采用最高分区域,避免重叠候选同时显示;AI 页面只渲染本页最新生成的候选 mask,不会把工作区已有 mask 带入 AI 画布;AI 页面生成的 mask 会写入全局 `masks` 并自动选中,右侧分类树可直接改标签,推送到工作区会切到“调整多边形”并保留选择和当前帧视角。`options.crop_to_prompt` 可对点/框/interactive prompt 做局部裁剪推理并回映射,`options.auto_filter_background` 可按分数和负向点过滤结果。
|
||||
10. 视频片段传播:工作区以当前打开帧作为参考帧,使用该帧全部 mask 作为 seed,并用传播起始帧和传播结束帧指定追踪范围;用户可直接修改数字框,也可点击“自动传播”进入时间轴范围选择模式,在播放进度条或视频处理进度条上点击/拖拽选择范围,再点击“开始传播”。工作区顶栏有独立“传播权重”选择器,可为本次传播二次选择 SAM 2.1 tiny/small/base+/large 权重,不提供 SAM2/SAM3 家族切换,也不影响 AI 单帧分割权重;前端提交传播前会先保存当前项目中的 draft/dirty mask,使 seed 优先带稳定的后端 `source_annotation_id`,再按传播权重 id、seed mask、seed 来源 id 和前/后方向组装 `steps` 并调用 `POST /api/ai/propagate/task` 创建 `propagate_masks` 后台任务;后端入队时会规范化/校验权重 id 并把规范化后的 id 写入任务 payload/result;Celery worker 顺序执行各 step,避免多个视频 tracker 并发抢占 GPU;每个 step 会根据 seed 来源 id、方向和 seed 签名做幂等判断,同权重且未改变的 seed 直接跳过,已改变或换用其他权重的 seed 会先删除同源旧自动传播标注再重传;旧版本用前端临时 `source_mask_id` 生成的传播标注会按同一参考帧、方向和语义信息兼容清理;中间帧人工新增/修改同一物体后重新传播时,后端会在写入目标帧新结果前按语义和空间重叠清理旧传播结果,且写入前清理不受旧结果传播方向限制;后端按项目帧序列下载片段帧,当前使用所选 SAM 2.1 权重变体的 `SAM2VideoPredictor.add_new_mask()` + `propagate_in_video()`,并把后续帧结果保存为 `Annotation`;若历史或外部 seed 仍带 `geometry_smoothing`,forward/backward 两个方向的传播结果保存前仍会应用同一参数;当前工作区平滑按钮应用后会直接改写实际 polygon,后续传播以新几何参与签名和追踪。工作区轮询 `GET /api/tasks/{task_id}` 展示进度并刷新标注,Dashboard 也能显示/取消/重试传播任务。
|
||||
11. GT 导入:工作区左侧工具栏“导入 GT Mask”调用 `/api/ai/import-gt-mask`;选择文件后前端会显示导入结果预览,并让用户决定未知 maskid 处理方式,可舍弃未知类别,也可导入为“未定义类别”等待重新命名。后端用 `cv2.IMREAD_UNCHANGED` 读取 mask,避免导出的低数值/16-bit `GT_label图` 被压成 0;GT 图片必须是灰度 maskid 图,或 RGB 三通道完全相同的 `[X,X,X]` maskid 图,0 为背景、X 为 maskid,不符合时返回明确错误;灰度/RGB 等通道图按模板 `maskId` 匹配类别,超出现有类别时按 `unknown_color_policy` 处理;如果 mask 图片尺寸和当前帧不同,会按当前帧长宽最近邻拉伸后再提取区域;每个连通域生成 polygon 标注,并用 distance transform 生成 seed point;前端回显 seed point,拖动后可归档更新。
|
||||
12. 模板管理:`TemplateRegistry.tsx` 管理分类、颜色、maskid 和内部覆盖顺序;`OntologyInspector.tsx` 在工作区显示当前模板分类树,也支持拖拽调整内部覆盖顺序。maskid 只作为 GT_label/类别 ID,不参与排序。
|
||||
13. 导出:工作区使用统一“分割结果导出”入口,导出前先保存待归档 mask;用户可选择整体视频、特定范围帧或当前图片,默认导出范围为当前图片,并勾选分开二值 mask、GT_label 黑白图、Pro_label 彩色图和 Mix_label 原图叠加图。选择特定范围帧时,可直接修改起止帧输入框,也可在播放进度条或视频处理进度条上点击/拖拽选择导出范围;选择 Mix_label 时可调透明度,默认 0.3,并显示当前/待导出第一帧预览。下载 ZIP 文件名使用 `{项目库项目名}_seg_T_{起始时间戳}-{结束时间戳}_P_{起始项目帧序号}-{结束项目帧序号}.zip`,项目名来自 `Project.name` 并替换文件系统不安全字符,时间戳格式为 `0h00m00s000ms`,帧号使用项目抽帧后的 1-based 顺序而非原视频帧号。后端保留兼容的 COCO JSON 和 PNG mask ZIP 接口,同时新增统一结果 ZIP;统一 ZIP 固定包含 `annotations_coco.json`、`maskid_GT像素值_类别映射.json` 和 `原始图片/`;导出时 GT_label 像素值使用类别真实 `maskid`,缺失 `maskid` 的旧标注才补下一个可用值,保证导出的 GT_label 可按同一模板再导入;选择分开 mask 时输出 `分开Mask分割结果/{视频名称_时间戳_项目帧序号}_分别导出/{视频名称_时间戳_项目帧序号}_{类别名称}_maskid{maskid}.png`,同一帧同一类别合并为一张图;选择 GT_label/Pro_label/Mix_label 时分别输出 `GT_label图/{视频名称_时间戳_项目帧序号}.png`、`Pro_label彩色分割结果/{视频名称_时间戳_项目帧序号}.png`、`Mix_label重叠覆盖彩色分割结果/{视频名称_时间戳_项目帧序号}.png`。maskid 不参与覆盖排序,GT_label/Pro_label/Mix_label 重叠区域覆盖顺序由内部拖拽排序字段决定,并与未选中状态下的 Canvas 显示顺序一致。
|
||||
|
||||
---
|
||||
|
||||
@@ -253,10 +259,11 @@ uvicorn main:app --host 0.0.0.0 --port 8000 --reload
|
||||
- 手工绘制工具会生成可保存的 `Mask.segmentation`;撤销/重做通过 `maskHistory/maskFuture` 工作。
|
||||
- Polygon 顶点编辑和新增顶点会重算 `pathData/segmentation/bbox/area`;已保存 mask 进入 dirty 状态后复用归档 PATCH 链路。
|
||||
- 区域合并/去除会重算主 mask 的几何;合并已保存的次级 mask 时会通过工作区回调删除对应后端标注。
|
||||
- 前端 `importGtMask()` 已对齐后端 `/api/ai/import-gt-mask`;工作区“导入 GT Mask”会导入后端生成的多类别标注和 seed point 并回显。
|
||||
- 前端 `exportCoco()` 已对齐后端 `/api/export/{project_id}/coco`;前端 `exportMasks()` 已对齐后端 `/api/export/{project_id}/masks`;工作区导出按钮会先保存当前待归档 mask。
|
||||
- 工作区“结构化归档保存”按钮已接入 `POST /api/ai/annotate` 和 `PATCH /api/ai/annotations/{id}`;加载工作区时会通过 `GET /api/ai/annotations` 回显已保存标注。
|
||||
- 右侧实例属性面板“边缘平滑强度/应用边缘平滑”已接入 `POST /api/ai/smooth-mask`;后端用 Chaikin smoothing 返回新 polygon、bbox、area 和拓扑锚点,前端将当前 mask 标记为 dirty/draft,保存后 `geometry_smoothing` 会随标注和传播 seed 保留。
|
||||
- 前端 `importGtMask()` 已对齐后端 `/api/ai/import-gt-mask`;工作区左侧工具栏“导入 GT Mask”会在上传前显示导入结果预览并选择未知 maskid 策略,后端支持二值 mask、低数值/16-bit GT_label 灰度图和 RGB 三通道完全相同的 `[X,X,X]` maskid 图,不再按彩色 RGB 类别图做颜色匹配;尺寸不同的 mask 会最近邻拉伸到当前帧,导入后回显多类别标注和 seed point。
|
||||
- 前端 `exportCoco()` 已对齐后端 `/api/export/{project_id}/coco`;前端 `exportMasks()` 已对齐后端 `/api/export/{project_id}/masks`;前端 `exportSegmentationResults()` 已对齐后端 `/api/export/{project_id}/results`;工作区“分割结果导出”按钮会先保存当前待归档 mask,再按所选范围、outputs 和 Mix_label 透明度下载统一 ZIP;特定范围帧导出可用帧号输入框或时间轴拖拽选择范围;下载文件名按项目库项目名、导出范围首尾时间戳和首尾项目帧序号生成;统一 ZIP 包含 maskid/GT 像素值映射 JSON、原始图片文件夹、按帧/类别合并的分开 Mask 文件夹、GT_label 图文件夹、Pro_label 彩色图文件夹和 Mix_label 叠加图文件夹;GT_label 像素值使用类别真实 `maskid` 并跨图一致。
|
||||
- 右侧语义分类树和 Canvas “应用分类”都会把分类变更同步到同一传播链前后帧对应 mask;识别依据为 `source_annotation_id`、`source_mask_id` 和 `propagation_seed_key`,被同步更新的已保存 mask 会进入 dirty 状态,等待工作区归档保存 PATCH 到后端。
|
||||
- 工作区保存状态按钮会按当前项目待保存数量显示“保存 X 个改动”或“已全部保存”,并已接入 `POST /api/ai/annotate` 和 `PATCH /api/ai/annotations/{id}`;加载工作区时会通过 `GET /api/ai/annotations` 回显已保存标注。
|
||||
- 右侧实例属性面板“边缘平滑强度/应用边缘平滑”已接入 `POST /api/ai/smooth-mask`;滑杆会即时更新数值,但后端预览请求有短防抖,避免拖动时连续请求卡顿;预览不写入撤销历史也不标 dirty;点击应用后会把返回 polygon 作为新的实际 mask 几何写入当前 mask 和同传播链前后对应 mask,整次应用作为一个撤销/重做历史步骤,相关 mask 标记为 dirty/draft,平滑强度重置为 0,用户可继续用 polygon 编辑工具调整新多边形。
|
||||
- 工作区“自动传播”按钮已接入 `POST /api/ai/propagate/task`;若用户尚未显式设置范围,第一次点击会进入时间轴范围选择模式,第二次点击“开始传播”才提交后台任务;当前启用所选 SAM 2.1 变体的视频 predictor 后台任务,运行中轮询任务进度,完成后刷新后端已保存标注;工作区顶栏模型状态用紧凑 GPU/CPU 徽标,具体 SAM 2.1 传播权重由旁边下拉选择;同步 `POST /api/ai/propagate` 仍作为单 seed 兼容接口保留。
|
||||
- 工作区顶栏短状态会自动消失;保存、导出、导入 GT、传播进行中和无帧项目提示会保留到状态变化。
|
||||
- 工作区“清空遮罩”会调用 `DELETE /api/ai/annotations/{id}` 删除当前帧已保存标注,并清空当前帧本地 mask。
|
||||
@@ -307,9 +314,11 @@ uvicorn main:app --host 0.0.0.0 --port 8000 --reload
|
||||
|
||||
## 安全注意事项
|
||||
|
||||
- FastAPI 登录是开发用硬编码凭证:`admin / 123456`。
|
||||
- 登录成功返回固定 token:`fake-jwt-token-for-admin`,没有真实 JWT 签名校验。
|
||||
- Axios 会附加 Bearer token,但后端大多数业务路由当前没有鉴权依赖。
|
||||
- FastAPI 已有真实 `users` 表、密码哈希和签名 JWT;默认 `admin / 123456` 只是开发种子用户,生产部署应通过环境变量或数据库改密。
|
||||
- 业务路由会校验 Bearer token;项目、帧、标注、任务、Dashboard 和导出按当前用户拥有的项目过滤,模板支持系统模板(`owner_user_id IS NULL`)和用户模板。
|
||||
- 角色分为 `admin`、`annotator`、`viewer`:`admin/annotator` 可调用写入类业务接口,`viewer` 只能调用读接口;`/api/admin/*` 仅允许 `admin`。
|
||||
- 管理员后台支持用户新增、停用/启用、改角色、改密码、删除无项目用户、查看登录/用户管理审计日志,以及二次确认后恢复演示出厂设置;禁止管理员删除、停用或降级自己。
|
||||
- JWT 默认开发密钥在 `backend/config.py`,生产部署必须通过 `backend/.env` 覆盖 `JWT_SECRET_KEY`。
|
||||
- `backend/.env` 被 `.gitignore` 忽略;不要提交真实数据库、MinIO、Redis、模型路径等敏感配置。
|
||||
- `start_services.sh` 中包含本机路径和 sudo 启动逻辑,迁移机器时要审查。
|
||||
- Express `server.ts` 只负责前端开发/静态服务,不承担业务 API 或鉴权。
|
||||
|
||||
42
README.md
42
README.md
@@ -13,12 +13,12 @@
|
||||
## 核心功能
|
||||
|
||||
- **多媒体资产管理** — 支持视频(MP4/AVI/MOV)和 DICOM 医学影像上传;视频导入与生成帧分离,生成帧时选择目标 FPS,项目卡片可删除项目及其关联帧、标注和任务记录
|
||||
- **AI 智能分割引擎** — 当前产品入口启用 SAM 2.1 四个变体(tiny/small/base+/large)选择;支持点分割(point)、框分割(box)、交互式正/反点细化、提示点单点删除、AI 候选单独删除、自动分割(auto)和 Celery 后台 video predictor 传播,前端默认只采用最高分候选避免重叠备选同时显示
|
||||
- **交互式画布标注** — 基于 Konva 的高性能 Canvas,工作区和 AI 画布会默认居中放大底图并保留边距;支持缩放/平移/手工多边形/矩形/圆/点/线、polygon 顶点直接拖动/删除、边中点插点、双击边界插点、区域合并/去除、选点/框选、撤销/重做,实时渲染 Mask 遮罩
|
||||
- **GT Mask 导入** — 工作区可导入 GT mask 图片,后端按非零像素值和连通域生成 polygon 标注并用 distance transform 写入 seed point;前端可回显和拖动 seed point
|
||||
- **本体字典管理** — 可配置的分类体系、颜色映射、图层优先级(z-index)
|
||||
- **AI 智能分割引擎** — 当前产品入口启用 SAM 2.1 四个变体(tiny/small/base+/large)选择;侧栏和工作区跳转入口使用 Bot + Sparkles 组合图标强化 AI 识别;支持点分割(point)、框分割(box)、交互式正/反点细化、提示点单点删除、AI 候选单独删除、自动分割(auto)和 Celery 后台 video predictor 传播,前端默认只采用最高分候选避免重叠备选同时显示
|
||||
- **交互式画布标注** — 基于 Konva 的高性能 Canvas,工作区和 AI 画布会默认居中放大底图并保留边距;工作区支持缩放/平移/手工多边形/矩形/圆/画笔/橡皮擦、polygon 顶点直接拖动/删除、边中点插点、双击边界插点、区域合并/去除、撤销/重做;画笔和橡皮擦支持尺寸调节,画笔可按当前语义分类生成连续区域并自动合并连通的选中 mask,橡皮擦可从选中 mask 扣除区域;未选中特定 mask 时按右侧语义分类树的内部优先级叠放显示,AI 智能分割页单独提供正/反点和框选,实时渲染 Mask 遮罩
|
||||
- **GT Mask 导入** — 工作区可导入二值 mask、GT_label 灰度图或 RGB 三通道完全相同的 `[X,X,X]` maskid 图;导入前会显示本地预览,不符合灰度/maskid 图要求时反馈错误,尺寸不同会按当前帧长宽最近邻拉伸;后端按 maskid 匹配当前模板类别并生成 polygon 标注和 seed point,超出现有类别的 maskid 可由用户选择舍弃或导入为“未定义类别”等待重新命名
|
||||
- **本体字典管理** — 可配置的分类体系、颜色映射、稳定且跨图一致的 maskid;右侧分类树可拖拽调整内部图层覆盖顺序,maskid 不参与排序
|
||||
- **项目工作区** — 项目创建、帧浏览、多图层标注、自动传播帧提示、进度追踪
|
||||
- **数据导出** — 支持 COCO JSON 格式和 PNG Mask 批量导出;PNG ZIP 包含单标注 mask、按 z-index 融合的语义 mask 和类别映射
|
||||
- **数据导出** — 工作区使用统一“分割结果导出”入口,可选择整体视频、特定范围帧或当前图片;特定范围帧支持输入帧号或在时间轴进度条上拖拽选择,并导出 COCO JSON、maskid/GT 像素值映射、原始图片、分开二值 mask、GT_label 黑白图、Pro_label 彩色图和 Mix_label 原图叠加图;GT/Pro/Mix 的重叠覆盖顺序和右侧语义分类树内部优先级一致,GT_label 背景为 0,类别值使用模板中的真实 maskid,缺失 maskid 的旧标注才补下一个可用正整数
|
||||
|
||||
---
|
||||
|
||||
@@ -36,12 +36,13 @@
|
||||
│ 业务逻辑层 (FastAPI + Python 3.11) │
|
||||
│ localhost:8000 │
|
||||
│ ├── /api/auth 登录认证 │
|
||||
│ ├── /api/admin 管理员用户管理、审计日志、演示出厂重置 │
|
||||
│ ├── /api/projects 项目 & 视频帧 CRUD │
|
||||
│ ├── /api/templates 本体字典(分类/颜色/z-index) │
|
||||
│ ├── /api/media 文件上传 & 异步拆帧任务创建 │
|
||||
│ ├── /api/tasks Celery 后台任务状态/取消/重试/详情 │
|
||||
│ ├── /api/ai SAM 2 推理与模型状态 │
|
||||
│ └── /api/export COCO JSON / PNG Masks 导出 │
|
||||
│ └── /api/export COCO JSON / PNG / 统一分割结果导出 │
|
||||
└──────────────────────────┬──────────────────────────────────┘
|
||||
│ SQLAlchemy 2.0
|
||||
┌──────────────────────────▼──────────────────────────────────┐
|
||||
@@ -85,7 +86,7 @@ Seg_Server/
|
||||
│ ├── main.py # 应用入口(CORS/生命周期/路由注册/WebSocket)
|
||||
│ ├── config.py # 环境变量配置(Pydantic Settings)
|
||||
│ ├── database.py # SQLAlchemy 引擎 + Session
|
||||
│ ├── models.py # ORM 模型(Project/Frame/Template/Annotation/Mask/ProcessingTask)
|
||||
│ ├── models.py # ORM 模型(User/Project/Frame/Template/Annotation/Mask/AuditLog/ProcessingTask)
|
||||
│ ├── schemas.py # Pydantic 请求/响应校验模型
|
||||
│ ├── minio_client.py # MinIO 上传/下载/预签名URL封装
|
||||
│ ├── redis_client.py # Redis 连接封装
|
||||
@@ -97,7 +98,8 @@ Seg_Server/
|
||||
│ ├── setup_sam3_env.sh # 历史保留的 SAM 3 独立 Python 3.12 环境安装脚本;当前产品入口禁用
|
||||
│ ├── requirements.txt # Python 依赖
|
||||
│ ├── routers/ # API 路由
|
||||
│ │ ├── auth.py # 登录认证
|
||||
│ │ ├── auth.py # 用户表、密码哈希、JWT 登录和 /api/auth/me
|
||||
│ │ ├── admin.py # 管理员用户 CRUD 和审计日志
|
||||
│ │ ├── projects.py # 项目 & 帧 CRUD
|
||||
│ │ ├── templates.py # 本体字典管理
|
||||
│ │ ├── media.py # 上传 & 解析
|
||||
@@ -122,6 +124,7 @@ Seg_Server/
|
||||
│ └── components/ # 组件(扁平化目录)
|
||||
│ ├── Login.tsx # 登录页
|
||||
│ ├── Sidebar.tsx # 左侧导航栏
|
||||
│ ├── UserAdmin.tsx # 管理员用户管理后台
|
||||
│ ├── Dashboard.tsx # 总体概况仪表盘(任务进度/任务控制)
|
||||
│ ├── ProjectLibrary.tsx # 项目库列表
|
||||
│ ├── VideoWorkspace.tsx # 核心分割工作区布局
|
||||
@@ -289,6 +292,10 @@ sam_default_model=sam2.1_hiera_tiny
|
||||
# sam3_external_python=/home/wkmgc/miniconda3/envs/sam3/bin/python
|
||||
# sam3_timeout_seconds=300
|
||||
cors_origins=["http://localhost:3000","http://192.168.3.11:3000"]
|
||||
jwt_secret_key=change-this-to-a-long-random-production-secret
|
||||
access_token_expire_minutes=1440
|
||||
default_admin_username=admin
|
||||
default_admin_password=123456
|
||||
```
|
||||
|
||||
前端根目录的 `.env.example` 包含 AI Studio 注入变量和前端 API 配置:
|
||||
@@ -318,7 +325,7 @@ nohup uvicorn main:app --host 0.0.0.0 --port 8000 > /tmp/fastapi.log 2>&1 &
|
||||
- 测试 Redis 连接
|
||||
- 懒加载所选 SAM 2.1 模型;`GET /api/ai/models/status` 会返回 tiny/small/base+/large 和 GPU 的真实可用状态,`selected_model=sam3` 会返回不支持
|
||||
- `/api/ai/predict` 支持 AI 参数 `crop_to_prompt`、`auto_filter_background` 和 `min_score`,用于点/框 prompt 的局部裁剪推理、回映射和背景过滤
|
||||
- `/api/ai/smooth-mask` 支持对当前 mask polygon 做后端边缘平滑,返回新的 polygon、bbox、面积和拓扑锚点;前端应用后仍走正常标注保存链路
|
||||
- `/api/ai/smooth-mask` 支持对当前 mask polygon 做后端边缘平滑,返回新的 polygon、bbox、面积和拓扑锚点;前端应用时会把结果写成新的实际 mask 几何,并走正常标注保存链路
|
||||
- `/api/ai/propagate/task` 支持从当前帧 seed 区域向视频片段创建后台传播任务:当前使用所选 SAM 2.1 变体的 `SAM2VideoPredictor.add_new_mask()` + `propagate_in_video()`;同步 `/api/ai/propagate` 仍作为单 seed 兼容接口保留
|
||||
|
||||
### 步骤 6.1: 启动 Celery Worker
|
||||
@@ -408,6 +415,10 @@ cd ~/Desktop/Seg_Server
|
||||
| MinIO 控制台 | http://localhost:9001 | minioadmin / minioadmin |
|
||||
| PostgreSQL | localhost:5432 | seguser / segpass123 |
|
||||
|
||||
后端启动时会自动种子化默认管理员 `admin / 123456`,密码以哈希形式存入 `users` 表。登录成功返回签名 JWT,前端会把 token 写入 `localStorage` 并通过 `Authorization: Bearer <token>` 调用业务接口;页面刷新后会用 `/api/auth/me` 恢复当前用户。
|
||||
|
||||
当前项目、帧、标注、任务、Dashboard 和导出接口已经按当前 JWT 用户拥有的项目隔离;模板支持系统模板(`owner_user_id IS NULL`)和用户模板。角色分为 `admin`、`annotator`、`viewer`:`admin/annotator` 可调用写入类业务接口,`viewer` 只能读取;管理员会在侧栏看到“用户管理”,可通过 `/api/admin/users` 新增、停用/启用、改角色、改密码和删除无项目用户,并通过 `/api/admin/audit-logs` 查看登录与用户管理审计。演示部署还提供“恢复演示出厂设置”,二次确认后调用 `/api/admin/demo-factory-reset`,只保留默认 admin 与一个尚未生成帧的演示视频项目。生产部署时必须在 `backend/.env` 覆盖 `JWT_SECRET_KEY` 并修改默认管理员密码。
|
||||
|
||||
---
|
||||
|
||||
## 可用命令
|
||||
@@ -498,12 +509,13 @@ pip install -e . --no-build-isolation
|
||||
- 前端 `predictMask()` 已发送后端需要的 `image_id`、`prompt_type`、`prompt_data`,并把后端 `polygons` 转成 Konva `pathData`。
|
||||
- 工作区点选/框选会使用当前帧的数据库 `frame.id` 调用 `/api/ai/predict`。
|
||||
- 工作区 SAM 2.1 交互式细化包含反向点时会启用后端背景过滤;若反向点排除了当前候选区域并返回空结果,前端会移除旧候选 mask。
|
||||
- AI 页面只显示本页最新生成的 SAM 2.1 候选,不会把工作区已有 mask 带入 AI 画布;重复执行高精度分割会替换上一次 AI 页候选;新生成 mask 会写入全局 `masks` 并自动选中,右侧分类树可直接给生成结果换标签,“推送至工作区编辑”会切回工作区的多边形调整工具并保留选择和当前帧视角,不会因工作区重新加载而跳回第一帧。
|
||||
- 工作区传播功能会使用当前打开参考帧的全部 mask 作为 seed,按用户设置的传播起始帧和传播结束帧向前/向后追踪;用户可直接修改数字框,也可先点击“自动传播”进入时间轴范围选择模式,在播放进度条或视频处理进度条上点击/拖拽选择范围,再点击“开始传播”。工作区顶栏可单独选择本次传播使用的 SAM 2.1 tiny/small/base+/large 权重,不提供 SAM2/SAM3 家族切换;前端提交传播前会先保存当前项目中的 draft/dirty mask,使 seed 优先携带稳定的后端 `source_annotation_id`,再把传播权重 id、seed、seed 来源 id、边缘平滑参数和方向组装为 `/api/ai/propagate/task` 后台任务。后端入队时会规范化/校验权重 id,并把规范化后的 id 写入任务 payload/result;worker 会按 seed 来源、方向、平滑参数和 seed 签名去重,同权重且未改变的 mask 二次传播时直接跳过,已改变、修改平滑参数或换用其他权重的 mask 会先删除同源旧自动传播标注再重传;旧版本使用前端临时 `source_mask_id` 生成的传播结果会按同一参考帧、方向和语义信息兼容清理,中间帧人工新增或修改同一物体后重新传播时,也会在写入目标帧新结果前按语义和空间重叠清理旧传播结果,且写入前清理不受旧结果传播方向限制,避免向前传播时与早先向后生成的旧 mask 叠加。带 `geometry_smoothing` 的 seed 在 forward/backward 两个方向保存前都会应用同一平滑参数。任务进度写入 `processing_tasks` 并可在 Dashboard 查看/取消/重试,工作区轮询任务状态并刷新已保存标注。传播结果回显后,视频处理进度条会把自动传播生成的帧区段标为蓝色,人工/AI 标注帧显示为红色竖线;每次自动传播成功处理过的范围会在当前会话中额外叠加不同色系的深到浅渐变片段,用于辨认最近处理过哪一段视频;普通状态下点击视频处理进度条或红/蓝帧标识可跳转到对应帧,底部缩略图也会用红色边框标识人工/AI 标注帧、蓝色边框标识传播/推理帧;如果同一帧同时有人工作业和传播结果,红色人工/AI 标注框优先保留,蓝色传播状态以内描边表达;当前帧如果同时是人工/AI 标注帧,会显示青色外框加红色内描边,固定为外层当前帧、内层标注框。
|
||||
- 前端 `exportCoco()` 已对齐到 `/api/export/{projectId}/coco`。
|
||||
- 工作区“导出 JSON 标注集”和“导出 PNG Mask ZIP”按钮已绑定下载流程;导出前会先保存当前待归档的前端 mask。
|
||||
- 工作区“导入 GT Mask”按钮已绑定 `/api/ai/import-gt-mask`,导入后会刷新并回显已保存标注和 seed point。
|
||||
- 工作区“结构化归档保存”按钮会把当前项目未保存 mask 写入 `POST /api/ai/annotate`,并把 dirty mask 写入 `PATCH /api/ai/annotations/{id}`。
|
||||
- AI 页面只显示本页最新生成的 SAM 2.1 候选,不会把工作区已有 mask 带入 AI 画布;重复执行高精度分割会替换上一次 AI 页候选;新生成 mask 会写入全局 `masks` 并自动选中,右侧分类树可直接给生成结果换标签;如果换标签的 mask 属于传播链,系统会同步更新前后帧对应传播 mask 的分类元数据;“推送至工作区编辑”会切回工作区的多边形调整工具并保留选择和当前帧视角,不会因工作区重新加载而跳回第一帧。
|
||||
- 工作区传播功能会使用当前打开参考帧的全部 mask 作为 seed,按用户设置的传播起始帧和传播结束帧向前/向后追踪;用户可直接修改数字框,也可先点击“自动传播”进入时间轴范围选择模式,在播放进度条或视频处理进度条上点击/拖拽选择范围,再点击“开始传播”。工作区顶栏可单独选择本次传播使用的 SAM 2.1 tiny/small/base+/large 权重,不提供 SAM2/SAM3 家族切换;前端提交传播前会先保存当前项目中的 draft/dirty mask,使 seed 优先携带稳定的后端 `source_annotation_id`,再把传播权重 id、seed、seed 来源 id 和方向组装为 `/api/ai/propagate/task` 后台任务。后端入队时会规范化/校验权重 id,并把规范化后的 id 写入任务 payload/result;worker 会按 seed 来源、方向和 seed 签名去重,同权重且未改变的 mask 二次传播时直接跳过,已改变或换用其他权重的 mask 会先删除同源旧自动传播标注再重传;旧版本使用前端临时 `source_mask_id` 生成的传播结果会按同一参考帧、方向和语义信息兼容清理,中间帧人工新增或修改同一物体后重新传播时,也会在写入目标帧新结果前按语义和空间重叠清理旧传播结果,且写入前清理不受旧结果传播方向限制,避免向前传播时与早先向后生成的旧 mask 叠加。若历史或外部 seed metadata 仍带 `geometry_smoothing`,后端仍会在 forward/backward 两个方向保存前应用同一平滑参数;当前工作区平滑按钮应用后会直接改写实际 polygon,因此后续传播以新几何参与签名和追踪。任务进度写入 `processing_tasks` 并可在 Dashboard 查看/取消/重试,工作区轮询任务状态并刷新已保存标注。传播结果回显后,视频处理进度条会把自动传播生成的帧区段标为蓝色,人工/AI 标注帧显示为红色竖线;每次自动传播成功处理过的范围会在当前会话中额外叠加同一蓝色系、最新传播最亮、旧传播逐次变暗且第 5 次及更早统一为阈值旧记录色的纯色片段,用于辨认第一次、第二次、第 N 次传播;普通状态下点击视频处理进度条或红/蓝帧标识可跳转到对应帧,底部缩略图也会用红色边框标识人工/AI 标注帧、蓝色边框标识传播/推理帧;如果同一帧同时有人工作业和传播结果,红色人工/AI 标注框优先保留,蓝色传播状态以内描边表达;当前帧如果同时是人工/AI 标注帧,会显示青色外框加红色内描边,固定为外层当前帧、内层标注框。
|
||||
- 右侧实例属性面板的“边缘平滑强度”滑杆会先防抖预览;点击“应用边缘平滑”后会把平滑结果作为新的实际 polygon 写入当前 mask,并同步写入同一传播链前后对应 mask,整次操作进入同一个撤销/重做历史步骤,应用后强度重置为 0,用户可继续用“调整多边形”编辑新多边形。
|
||||
- 前端 `exportCoco()` 已对齐到 `/api/export/{projectId}/coco`,`exportMasks()` 已对齐到 `/api/export/{projectId}/masks`,统一导出 `exportSegmentationResults()` 已对齐到 `/api/export/{projectId}/results`。
|
||||
- 工作区“分割结果导出”按钮已绑定下载流程;点击后可在下拉栏选择整体视频、特定范围帧或当前图片,默认选择当前图片,并勾选分开 mask、GT_label 黑白图、Pro_label 彩色图和 Mix_label 原图叠加图。选择特定范围帧时,可直接修改帧号,也可在播放进度条或视频处理进度条上拖拽选择范围;选择 Mix_label 时可调遮罩透明度,默认 0.3,并显示当前/待导出第一帧预览。导出前会先保存当前待归档的前端 mask。下载 ZIP 命名为 `{项目库项目名}_seg_T_{起始时间戳}-{结束时间戳}_P_{起始项目帧序号}-{结束项目帧序号}.zip`,项目名会替换文件系统不安全字符;时间戳来自帧 `timestampMs` 并格式化为 `0h00m00s000ms`,帧号使用项目抽帧后的 1-based 帧顺序,不使用原视频帧号。统一导出 ZIP 固定包含 `annotations_coco.json`、`maskid_GT像素值_类别映射.json` 和 `原始图片/`;GT_label 像素值使用类别真实 maskid,缺失 maskid 的旧标注才补下一个可用正整数并写入映射 JSON,因此同一模板导出的 GT_label 可直接再导入;选择分开 mask 时输出 `分开Mask分割结果/`,按帧子目录和类别 maskid 合并命名;选择 GT_label、Pro_label、Mix_label 时分别输出 `GT_label图/`、`Pro_label彩色分割结果/`、`Mix_label重叠覆盖彩色分割结果/`,重叠区域按右侧语义分类树内部优先级从低到高覆盖。
|
||||
- 工作区左侧工具栏“导入 GT Mask”按钮已绑定 `/api/ai/import-gt-mask`,入口位于“重叠区域去除”之后;选择文件后会先显示导入结果预览,并让用户决定未知 maskid 的处理方式,可“舍弃未知类别”或“导入为未定义”。后端使用 `cv2.IMREAD_UNCHANGED` 读取 mask,因此导出的低数值/16-bit `GT_label图` 能按 maskid 重新识别;导入格式限定为灰度 maskid 图或 RGB 三通道完全相同的 `[X,X,X]` maskid 图,0 为背景,X 为 maskid;尺寸不同会自动按当前帧长宽最近邻拉伸,导入后刷新并回显已保存标注和 seed point。
|
||||
- 工作区保存状态按钮会按当前项目待保存数量显示“保存 X 个改动”或“已全部保存”;点击后会把未保存 mask 写入 `POST /api/ai/annotate`,并把 dirty mask 写入 `PATCH /api/ai/annotations/{id}`。
|
||||
- 工作区“清空遮罩”会通过 `DELETE /api/ai/annotations/{id}` 删除当前帧已保存标注,并清空当前帧本地 mask。
|
||||
|
||||
**验证**:
|
||||
|
||||
@@ -33,6 +33,12 @@ class Settings(BaseSettings):
|
||||
# App
|
||||
app_env: str = "development"
|
||||
cors_origins: list[str] = ["http://localhost:3000", "http://192.168.3.11:3000"]
|
||||
jwt_secret_key: str = "seg-server-dev-secret-change-me"
|
||||
jwt_algorithm: str = "HS256"
|
||||
access_token_expire_minutes: int = 60 * 24
|
||||
default_admin_username: str = "admin"
|
||||
default_admin_password: str = "123456"
|
||||
demo_video_path: str = "/home/wkmgc/Desktop/Seg_Server/Data_MyVideo_1.mp4"
|
||||
|
||||
class Config:
|
||||
env_file = ".env"
|
||||
|
||||
@@ -20,7 +20,7 @@ from progress_events import PROGRESS_CHANNEL
|
||||
from redis_client import get_redis_client, ping as redis_ping
|
||||
from statuses import PROJECT_STATUS_PENDING, PROJECT_STATUS_READY
|
||||
|
||||
from routers import projects, templates, media, ai, export, auth, dashboard, tasks
|
||||
from routers import projects, templates, media, ai, export, auth, dashboard, tasks, admin
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
@@ -28,7 +28,7 @@ logging.basicConfig(
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DEFAULT_VIDEO_PATH = "/home/wkmgc/Desktop/Seg_Server/Data_MyVideo_1.mp4"
|
||||
DEFAULT_VIDEO_PATH = settings.demo_video_path
|
||||
|
||||
|
||||
def _ensure_runtime_schema_columns() -> None:
|
||||
@@ -36,25 +36,56 @@ def _ensure_runtime_schema_columns() -> None:
|
||||
try:
|
||||
inspector = inspect(engine)
|
||||
frame_columns = {column["name"] for column in inspector.get_columns("frames")}
|
||||
project_columns = {column["name"] for column in inspector.get_columns("projects")}
|
||||
template_columns = {column["name"] for column in inspector.get_columns("templates")}
|
||||
with engine.begin() as connection:
|
||||
if "timestamp_ms" not in frame_columns:
|
||||
connection.execute(text("ALTER TABLE frames ADD COLUMN timestamp_ms FLOAT"))
|
||||
if "source_frame_number" not in frame_columns:
|
||||
connection.execute(text("ALTER TABLE frames ADD COLUMN source_frame_number INTEGER"))
|
||||
if "owner_user_id" not in project_columns:
|
||||
connection.execute(text("ALTER TABLE projects ADD COLUMN owner_user_id INTEGER"))
|
||||
if "owner_user_id" not in template_columns:
|
||||
connection.execute(text("ALTER TABLE templates ADD COLUMN owner_user_id INTEGER"))
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.warning("Runtime schema column check failed: %s", exc)
|
||||
|
||||
|
||||
def _seed_default_admin_and_ownership_sync() -> None:
|
||||
"""Ensure the default admin exists and owns legacy unassigned projects."""
|
||||
from models import Project
|
||||
from routers.auth import ensure_default_admin
|
||||
|
||||
db = SessionLocal()
|
||||
try:
|
||||
admin = ensure_default_admin(db)
|
||||
db.query(Project).filter(Project.owner_user_id.is_(None)).update(
|
||||
{"owner_user_id": admin.id},
|
||||
synchronize_session=False,
|
||||
)
|
||||
db.commit()
|
||||
logger.info("Default admin ready; legacy projects assigned to user id=%s", admin.id)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.error("Failed to seed default admin or ownership: %s", exc)
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
def _seed_default_project_sync() -> None:
|
||||
"""Synchronously seed the default video project on first startup."""
|
||||
import cv2
|
||||
from models import Project, Frame
|
||||
from routers.auth import ensure_default_admin
|
||||
from services.frame_parser import parse_video, upload_frames_to_minio, extract_thumbnail
|
||||
|
||||
db = SessionLocal()
|
||||
try:
|
||||
admin = ensure_default_admin(db)
|
||||
existing = db.query(Project).filter(Project.name == "Data_MyVideo_1").first()
|
||||
if existing is not None:
|
||||
if existing.owner_user_id is None:
|
||||
existing.owner_user_id = admin.id
|
||||
db.commit()
|
||||
return
|
||||
|
||||
if not os.path.exists(DEFAULT_VIDEO_PATH):
|
||||
@@ -67,6 +98,7 @@ def _seed_default_project_sync() -> None:
|
||||
status=PROJECT_STATUS_PENDING,
|
||||
source_type="video",
|
||||
parse_fps=30.0,
|
||||
owner_user_id=admin.id,
|
||||
)
|
||||
db.add(project)
|
||||
db.commit()
|
||||
@@ -196,6 +228,7 @@ async def lifespan(app: FastAPI):
|
||||
try:
|
||||
Base.metadata.create_all(bind=engine)
|
||||
_ensure_runtime_schema_columns()
|
||||
_seed_default_admin_and_ownership_sync()
|
||||
logger.info("Database tables initialized.")
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.error("Database initialization failed: %s", exc)
|
||||
@@ -265,6 +298,7 @@ app.include_router(ai.router)
|
||||
app.include_router(export.router)
|
||||
app.include_router(dashboard.router)
|
||||
app.include_router(tasks.router)
|
||||
app.include_router(admin.router)
|
||||
|
||||
|
||||
@app.get("/health", tags=["Health"])
|
||||
|
||||
@@ -17,6 +17,25 @@ from database import Base
|
||||
from statuses import PROJECT_STATUS_PENDING
|
||||
|
||||
|
||||
class User(Base):
|
||||
"""Application user used for authentication and data ownership."""
|
||||
|
||||
__tablename__ = "users"
|
||||
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
username = Column(String(150), unique=True, index=True, nullable=False)
|
||||
password_hash = Column(String(255), nullable=False)
|
||||
role = Column(String(50), default="admin", nullable=False)
|
||||
is_active = Column(Integer, default=1, nullable=False)
|
||||
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
updated_at = Column(
|
||||
DateTime(timezone=True), server_default=func.now(), onupdate=func.now()
|
||||
)
|
||||
|
||||
projects = relationship("Project", back_populates="owner")
|
||||
templates = relationship("Template", back_populates="owner")
|
||||
|
||||
|
||||
class Project(Base):
|
||||
"""Project model representing a segmentation project."""
|
||||
|
||||
@@ -31,11 +50,13 @@ class Project(Base):
|
||||
source_type = Column(String(20), default="video", nullable=False) # video | dicom
|
||||
original_fps = Column(Float, nullable=True)
|
||||
parse_fps = Column(Float, default=30.0, nullable=False)
|
||||
owner_user_id = Column(Integer, ForeignKey("users.id", ondelete="SET NULL"), nullable=True)
|
||||
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
updated_at = Column(
|
||||
DateTime(timezone=True), server_default=func.now(), onupdate=func.now()
|
||||
)
|
||||
|
||||
owner = relationship("User", back_populates="projects")
|
||||
frames = relationship("Frame", back_populates="project", cascade="all, delete-orphan")
|
||||
annotations = relationship(
|
||||
"Annotation", back_populates="project", cascade="all, delete-orphan"
|
||||
@@ -77,8 +98,10 @@ class Template(Base):
|
||||
color = Column(String(50), nullable=False)
|
||||
z_index = Column(Integer, default=0, nullable=False)
|
||||
mapping_rules = Column(JSON, nullable=True)
|
||||
owner_user_id = Column(Integer, ForeignKey("users.id", ondelete="SET NULL"), nullable=True)
|
||||
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
|
||||
owner = relationship("User", back_populates="templates")
|
||||
annotations = relationship(
|
||||
"Annotation", back_populates="template", cascade="all, delete-orphan"
|
||||
)
|
||||
@@ -129,6 +152,22 @@ class Mask(Base):
|
||||
annotation = relationship("Annotation", back_populates="masks")
|
||||
|
||||
|
||||
class AuditLog(Base):
|
||||
"""Audit trail for security and administrative actions."""
|
||||
|
||||
__tablename__ = "audit_logs"
|
||||
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
actor_user_id = Column(Integer, ForeignKey("users.id", ondelete="SET NULL"), nullable=True)
|
||||
action = Column(String(120), nullable=False)
|
||||
target_type = Column(String(80), nullable=True)
|
||||
target_id = Column(String(120), nullable=True)
|
||||
detail = Column(JSON, nullable=True)
|
||||
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
|
||||
actor = relationship("User")
|
||||
|
||||
|
||||
class ProcessingTask(Base):
|
||||
"""Background task state persisted for dashboard and polling."""
|
||||
|
||||
|
||||
270
backend/routers/admin.py
Normal file
270
backend/routers/admin.py
Normal file
@@ -0,0 +1,270 @@
|
||||
"""Administrator-only user and audit management endpoints."""
|
||||
|
||||
import os
|
||||
from typing import List
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from config import settings
|
||||
from database import get_db
|
||||
from minio_client import upload_file
|
||||
from models import Annotation, AuditLog, Frame, Mask, ProcessingTask, Project, Template, User
|
||||
from routers.auth import ensure_default_admin, hash_password, require_admin, write_audit_log
|
||||
from schemas import (
|
||||
AdminUserCreate,
|
||||
AdminUserUpdate,
|
||||
AuditLogOut,
|
||||
DemoFactoryResetOut,
|
||||
DemoFactoryResetRequest,
|
||||
UserOut,
|
||||
)
|
||||
from statuses import PROJECT_STATUS_PENDING
|
||||
|
||||
router = APIRouter(prefix="/api/admin", tags=["Admin"])
|
||||
|
||||
VALID_ROLES = {"admin", "annotator", "viewer"}
|
||||
DEMO_RESET_CONFIRMATION = "RESET_DEMO_FACTORY"
|
||||
DEMO_PROJECT_NAME = "Data_MyVideo_1"
|
||||
|
||||
|
||||
def _normalize_role(role: str | None) -> str:
|
||||
normalized = (role or "annotator").strip().lower()
|
||||
if normalized not in VALID_ROLES:
|
||||
raise HTTPException(status_code=400, detail=f"Unsupported role: {role}")
|
||||
return normalized
|
||||
|
||||
|
||||
@router.get("/users", response_model=List[UserOut], summary="List users")
|
||||
def list_users(
|
||||
db: Session = Depends(get_db),
|
||||
admin_user: User = Depends(require_admin),
|
||||
) -> List[User]:
|
||||
"""Return all users for the administrator console."""
|
||||
_ = admin_user
|
||||
return db.query(User).order_by(User.id).all()
|
||||
|
||||
|
||||
@router.post(
|
||||
"/users",
|
||||
response_model=UserOut,
|
||||
status_code=status.HTTP_201_CREATED,
|
||||
summary="Create user",
|
||||
)
|
||||
def create_user(
|
||||
payload: AdminUserCreate,
|
||||
db: Session = Depends(get_db),
|
||||
admin_user: User = Depends(require_admin),
|
||||
) -> User:
|
||||
"""Create a user with an initial password and role."""
|
||||
username = payload.username.strip()
|
||||
if not username:
|
||||
raise HTTPException(status_code=400, detail="Username is required")
|
||||
if len(payload.password) < 6:
|
||||
raise HTTPException(status_code=400, detail="Password must be at least 6 characters")
|
||||
user = User(
|
||||
username=username,
|
||||
password_hash=hash_password(payload.password),
|
||||
role=_normalize_role(payload.role),
|
||||
is_active=1 if payload.is_active else 0,
|
||||
)
|
||||
db.add(user)
|
||||
try:
|
||||
db.commit()
|
||||
except IntegrityError as exc:
|
||||
db.rollback()
|
||||
raise HTTPException(status_code=409, detail="Username already exists") from exc
|
||||
db.refresh(user)
|
||||
write_audit_log(
|
||||
db,
|
||||
actor=admin_user,
|
||||
action="admin.user_created",
|
||||
target_type="user",
|
||||
target_id=user.id,
|
||||
detail={"username": user.username, "role": user.role, "is_active": bool(user.is_active)},
|
||||
)
|
||||
return user
|
||||
|
||||
|
||||
@router.patch("/users/{user_id}", response_model=UserOut, summary="Update user")
|
||||
def update_user(
|
||||
user_id: int,
|
||||
payload: AdminUserUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
admin_user: User = Depends(require_admin),
|
||||
) -> User:
|
||||
"""Update username, password, role or active state."""
|
||||
user = db.query(User).filter(User.id == user_id).first()
|
||||
if not user:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
|
||||
updates = payload.model_dump(exclude_unset=True)
|
||||
audit_detail: dict = {"before": {"username": user.username, "role": user.role, "is_active": bool(user.is_active)}}
|
||||
if "username" in updates:
|
||||
username = (updates["username"] or "").strip()
|
||||
if not username:
|
||||
raise HTTPException(status_code=400, detail="Username is required")
|
||||
user.username = username
|
||||
if "password" in updates:
|
||||
password = updates["password"] or ""
|
||||
if len(password) < 6:
|
||||
raise HTTPException(status_code=400, detail="Password must be at least 6 characters")
|
||||
user.password_hash = hash_password(password)
|
||||
if "role" in updates:
|
||||
next_role = _normalize_role(updates["role"])
|
||||
if user.id == admin_user.id and next_role != "admin":
|
||||
raise HTTPException(status_code=400, detail="Cannot remove your own admin role")
|
||||
user.role = next_role
|
||||
if "is_active" in updates:
|
||||
if user.id == admin_user.id and not updates["is_active"]:
|
||||
raise HTTPException(status_code=400, detail="Cannot deactivate yourself")
|
||||
user.is_active = 1 if updates["is_active"] else 0
|
||||
|
||||
try:
|
||||
db.commit()
|
||||
except IntegrityError as exc:
|
||||
db.rollback()
|
||||
raise HTTPException(status_code=409, detail="Username already exists") from exc
|
||||
db.refresh(user)
|
||||
audit_detail["after"] = {"username": user.username, "role": user.role, "is_active": bool(user.is_active)}
|
||||
audit_detail["password_changed"] = "password" in updates
|
||||
write_audit_log(
|
||||
db,
|
||||
actor=admin_user,
|
||||
action="admin.user_updated",
|
||||
target_type="user",
|
||||
target_id=user.id,
|
||||
detail=audit_detail,
|
||||
)
|
||||
return user
|
||||
|
||||
|
||||
@router.delete("/users/{user_id}", status_code=status.HTTP_204_NO_CONTENT, summary="Delete user")
|
||||
def delete_user(
|
||||
user_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
admin_user: User = Depends(require_admin),
|
||||
) -> None:
|
||||
"""Delete a user when it is safe to remove the account."""
|
||||
if user_id == admin_user.id:
|
||||
raise HTTPException(status_code=400, detail="Cannot delete yourself")
|
||||
user = db.query(User).filter(User.id == user_id).first()
|
||||
if not user:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
owned_project_count = db.query(Project).filter(Project.owner_user_id == user_id).count()
|
||||
if owned_project_count:
|
||||
raise HTTPException(status_code=409, detail="User owns projects; deactivate or migrate projects first")
|
||||
username = user.username
|
||||
db.delete(user)
|
||||
db.commit()
|
||||
write_audit_log(
|
||||
db,
|
||||
actor=admin_user,
|
||||
action="admin.user_deleted",
|
||||
target_type="user",
|
||||
target_id=user_id,
|
||||
detail={"username": username},
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
@router.get("/audit-logs", response_model=List[AuditLogOut], summary="List audit logs")
|
||||
def list_audit_logs(
|
||||
limit: int = 100,
|
||||
db: Session = Depends(get_db),
|
||||
admin_user: User = Depends(require_admin),
|
||||
) -> List[AuditLog]:
|
||||
"""Return recent audit events for administrators."""
|
||||
_ = admin_user
|
||||
safe_limit = min(max(int(limit or 100), 1), 500)
|
||||
return db.query(AuditLog).order_by(AuditLog.created_at.desc(), AuditLog.id.desc()).limit(safe_limit).all()
|
||||
|
||||
|
||||
@router.post(
|
||||
"/demo-factory-reset",
|
||||
response_model=DemoFactoryResetOut,
|
||||
summary="Reset demo data to factory defaults",
|
||||
)
|
||||
def reset_demo_factory(
|
||||
payload: DemoFactoryResetRequest,
|
||||
db: Session = Depends(get_db),
|
||||
admin_user: User = Depends(require_admin),
|
||||
) -> dict:
|
||||
"""Reset a demo deployment to one admin account and one unparsed demo video project."""
|
||||
if payload.confirmation != DEMO_RESET_CONFIRMATION:
|
||||
raise HTTPException(status_code=400, detail="Invalid reset confirmation")
|
||||
|
||||
if not os.path.exists(settings.demo_video_path):
|
||||
raise HTTPException(
|
||||
status_code=409,
|
||||
detail=f"Demo video not found: {settings.demo_video_path}",
|
||||
)
|
||||
|
||||
requested_by = admin_user.username
|
||||
preserved_admin = ensure_default_admin(db)
|
||||
preserved_admin.username = settings.default_admin_username
|
||||
preserved_admin.password_hash = hash_password(settings.default_admin_password)
|
||||
preserved_admin.role = "admin"
|
||||
preserved_admin.is_active = 1
|
||||
db.flush()
|
||||
|
||||
deleted_counts = {
|
||||
"masks": db.query(Mask).delete(synchronize_session=False),
|
||||
"annotations": db.query(Annotation).delete(synchronize_session=False),
|
||||
"frames": db.query(Frame).delete(synchronize_session=False),
|
||||
"tasks": db.query(ProcessingTask).delete(synchronize_session=False),
|
||||
"projects": db.query(Project).delete(synchronize_session=False),
|
||||
"user_templates": db.query(Template).filter(Template.owner_user_id.is_not(None)).delete(synchronize_session=False),
|
||||
"audit_logs": db.query(AuditLog).delete(synchronize_session=False),
|
||||
"users": db.query(User).filter(User.id != preserved_admin.id).delete(synchronize_session=False),
|
||||
}
|
||||
db.flush()
|
||||
db.expunge_all()
|
||||
|
||||
preserved_admin = db.query(User).filter(User.username == settings.default_admin_username).first()
|
||||
if not preserved_admin:
|
||||
raise HTTPException(status_code=500, detail="Default admin was not preserved")
|
||||
|
||||
project = Project(
|
||||
name=DEMO_PROJECT_NAME,
|
||||
description="默认演示视频,尚未生成帧",
|
||||
status=PROJECT_STATUS_PENDING,
|
||||
source_type="video",
|
||||
parse_fps=30.0,
|
||||
owner_user_id=preserved_admin.id,
|
||||
)
|
||||
db.add(project)
|
||||
db.flush()
|
||||
|
||||
with open(settings.demo_video_path, "rb") as file_obj:
|
||||
data = file_obj.read()
|
||||
object_name = f"uploads/{project.id}/{os.path.basename(settings.demo_video_path)}"
|
||||
upload_file(object_name, data, content_type="video/mp4", length=len(data))
|
||||
project.video_path = object_name
|
||||
project.thumbnail_url = None
|
||||
project.original_fps = None
|
||||
db.commit()
|
||||
db.refresh(preserved_admin)
|
||||
db.refresh(project)
|
||||
|
||||
write_audit_log(
|
||||
db,
|
||||
actor=preserved_admin,
|
||||
action="admin.demo_factory_reset",
|
||||
target_type="project",
|
||||
target_id=project.id,
|
||||
detail={
|
||||
"project_name": project.name,
|
||||
"video_path": project.video_path,
|
||||
"deleted_counts": deleted_counts,
|
||||
"requested_by": requested_by,
|
||||
},
|
||||
)
|
||||
|
||||
return {
|
||||
"admin_user": preserved_admin,
|
||||
"project": project,
|
||||
"deleted_counts": deleted_counts,
|
||||
"message": "演示环境已恢复出厂设置",
|
||||
}
|
||||
@@ -1,6 +1,7 @@
|
||||
"""AI inference endpoints using selectable SAM runtimes."""
|
||||
|
||||
import logging
|
||||
import math
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import Any, List
|
||||
@@ -8,11 +9,13 @@ from typing import Any, List
|
||||
import cv2
|
||||
import numpy as np
|
||||
from fastapi import APIRouter, Depends, File, Form, HTTPException, Response, UploadFile, status
|
||||
from sqlalchemy import or_
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from database import get_db
|
||||
from minio_client import download_file
|
||||
from models import Project, Frame, Template, Annotation, ProcessingTask
|
||||
from models import Project, Frame, Template, Annotation, ProcessingTask, User
|
||||
from routers.auth import get_current_user, require_editor
|
||||
from schemas import (
|
||||
AiRuntimeStatus,
|
||||
MaskAnalysisRequest,
|
||||
@@ -38,6 +41,102 @@ logger = logging.getLogger(__name__)
|
||||
router = APIRouter(prefix="/api/ai", tags=["AI"])
|
||||
|
||||
|
||||
def _owned_project_or_404(project_id: int, db: Session, current_user: User) -> Project:
|
||||
project = db.query(Project).filter(
|
||||
Project.id == project_id,
|
||||
Project.owner_user_id == current_user.id,
|
||||
).first()
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
return project
|
||||
|
||||
|
||||
def _owned_frame_or_404(frame_id: int, db: Session, current_user: User, project_id: int | None = None) -> Frame:
|
||||
query = (
|
||||
db.query(Frame)
|
||||
.join(Project, Project.id == Frame.project_id)
|
||||
.filter(Frame.id == frame_id, Project.owner_user_id == current_user.id)
|
||||
)
|
||||
if project_id is not None:
|
||||
query = query.filter(Frame.project_id == project_id)
|
||||
frame = query.first()
|
||||
if not frame:
|
||||
raise HTTPException(status_code=404, detail="Frame not found")
|
||||
return frame
|
||||
|
||||
|
||||
def _visible_template_or_404(template_id: int, db: Session, current_user: User) -> Template:
|
||||
template = db.query(Template).filter(
|
||||
Template.id == template_id,
|
||||
or_(Template.owner_user_id == current_user.id, Template.owner_user_id.is_(None)),
|
||||
).first()
|
||||
if not template:
|
||||
raise HTTPException(status_code=404, detail="Template not found")
|
||||
return template
|
||||
|
||||
|
||||
def _normalize_hex_color(value: Any) -> str | None:
|
||||
if not isinstance(value, str):
|
||||
return None
|
||||
text = value.strip().lower()
|
||||
if not text:
|
||||
return None
|
||||
if not text.startswith("#"):
|
||||
text = f"#{text}"
|
||||
if len(text) == 4:
|
||||
text = "#" + "".join(char * 2 for char in text[1:])
|
||||
if len(text) != 7:
|
||||
return None
|
||||
try:
|
||||
int(text[1:], 16)
|
||||
except ValueError:
|
||||
return None
|
||||
return text
|
||||
|
||||
|
||||
def _rgb_tuple_to_hex(rgb: tuple[int, int, int]) -> str:
|
||||
values = []
|
||||
for channel in rgb:
|
||||
value = int(channel)
|
||||
if value > 255:
|
||||
value = int(round(value / 257))
|
||||
values.append(min(max(value, 0), 255))
|
||||
return f"#{values[0]:02x}{values[1]:02x}{values[2]:02x}"
|
||||
|
||||
|
||||
def _template_class_maps(template: Template | None) -> tuple[dict[int, dict[str, Any]], dict[str, dict[str, Any]]]:
|
||||
classes = ((template.mapping_rules or {}).get("classes") if template else None) or []
|
||||
by_maskid: dict[int, dict[str, Any]] = {}
|
||||
by_color: dict[str, dict[str, Any]] = {}
|
||||
for index, item in enumerate(classes):
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
maskid_value = item.get("maskId", item.get("maskid", item.get("mask_id")))
|
||||
try:
|
||||
maskid = int(maskid_value)
|
||||
except (TypeError, ValueError):
|
||||
maskid = index + 1
|
||||
color = _normalize_hex_color(item.get("color")) or "#22c55e"
|
||||
class_meta = {
|
||||
"id": str(item.get("id") or f"maskid-{maskid}"),
|
||||
"name": str(item.get("name") or f"类别 {maskid}"),
|
||||
"color": color,
|
||||
"zIndex": int(item.get("zIndex", item.get("z_index", index * 10))),
|
||||
"maskId": maskid,
|
||||
**({"category": item.get("category")} if item.get("category") else {}),
|
||||
}
|
||||
if maskid > 0:
|
||||
by_maskid[maskid] = class_meta
|
||||
by_color[color] = class_meta
|
||||
return by_maskid, by_color
|
||||
|
||||
|
||||
def _gt_unknown_label(token: int | str) -> str:
|
||||
if isinstance(token, int):
|
||||
return f"未定义类别 {token}"
|
||||
return f"未定义颜色 {token}"
|
||||
|
||||
|
||||
def _load_frame_image(frame: Frame) -> np.ndarray:
|
||||
"""Download a frame from MinIO and decode it to an RGB numpy array."""
|
||||
try:
|
||||
@@ -106,16 +205,20 @@ def _normalize_polygons(polygons: list[list[list[float]]]) -> list[list[list[flo
|
||||
return [polygon for polygon in (_normalize_polygon(polygon) for polygon in polygons) if len(polygon) >= 3]
|
||||
|
||||
|
||||
def _analysis_anchors(polygons: list[list[list[float]]], points: list[list[float]] | None) -> list[list[float]]:
|
||||
if points:
|
||||
return [[_clamp01(point[0]), _clamp01(point[1])] for point in points if len(point) >= 2]
|
||||
def _sample_anchor_points(anchors: list[list[float]], limit: int = 64) -> list[list[float]]:
|
||||
if len(anchors) <= limit:
|
||||
return anchors
|
||||
step = max(1, math.ceil(len(anchors) / limit))
|
||||
return anchors[::step][:limit]
|
||||
|
||||
|
||||
def _analysis_anchor_summary(polygons: list[list[list[float]]]) -> tuple[int, list[list[float]]]:
|
||||
anchors: list[list[float]] = []
|
||||
for polygon in polygons:
|
||||
if not polygon:
|
||||
continue
|
||||
step = max(1, len(polygon) // 12)
|
||||
anchors.extend([[_clamp01(point[0]), _clamp01(point[1])] for point in polygon[::step]])
|
||||
return anchors[:32]
|
||||
anchors.extend([[_clamp01(point[0]), _clamp01(point[1])] for point in polygon])
|
||||
return len(anchors), _sample_anchor_points(anchors)
|
||||
|
||||
|
||||
def _normalize_smoothing_options(strength: float | int | None, method: str | None = None) -> dict[str, Any]:
|
||||
@@ -129,8 +232,14 @@ def _normalize_smoothing_options(strength: float | int | None, method: str | Non
|
||||
}
|
||||
|
||||
|
||||
def _chaikin_smooth_polygon(polygon: list[list[float]], iterations: int) -> list[list[float]]:
|
||||
def _smoothing_ratio(strength: float, curve: float = 1.65) -> float:
|
||||
normalized = max(0.0, min(float(strength or 0.0), 100.0)) / 100.0
|
||||
return normalized ** curve
|
||||
|
||||
|
||||
def _chaikin_smooth_polygon(polygon: list[list[float]], iterations: int, corner_cut: float = 0.25) -> list[list[float]]:
|
||||
points = polygon
|
||||
q = max(0.02, min(float(corner_cut), 0.25))
|
||||
for _ in range(max(0, iterations)):
|
||||
if len(points) < 3:
|
||||
break
|
||||
@@ -138,12 +247,12 @@ def _chaikin_smooth_polygon(polygon: list[list[float]], iterations: int) -> list
|
||||
for index, current in enumerate(points):
|
||||
following = points[(index + 1) % len(points)]
|
||||
next_points.append([
|
||||
_clamp01(0.75 * current[0] + 0.25 * following[0]),
|
||||
_clamp01(0.75 * current[1] + 0.25 * following[1]),
|
||||
_clamp01((1.0 - q) * current[0] + q * following[0]),
|
||||
_clamp01((1.0 - q) * current[1] + q * following[1]),
|
||||
])
|
||||
next_points.append([
|
||||
_clamp01(0.25 * current[0] + 0.75 * following[0]),
|
||||
_clamp01(0.25 * current[1] + 0.75 * following[1]),
|
||||
_clamp01(q * current[0] + (1.0 - q) * following[0]),
|
||||
_clamp01(q * current[1] + (1.0 - q) * following[1]),
|
||||
])
|
||||
points = next_points
|
||||
return points
|
||||
@@ -154,7 +263,7 @@ def _simplify_polygon(polygon: list[list[float]], strength: float) -> list[list[
|
||||
return polygon
|
||||
contour = np.array([[[point[0], point[1]]] for point in polygon], dtype=np.float32)
|
||||
arc_length = cv2.arcLength(contour, True)
|
||||
epsilon = arc_length * (0.001 + (strength / 100.0) * 0.006)
|
||||
epsilon = arc_length * (0.00015 + _smoothing_ratio(strength) * 0.00735)
|
||||
approx = cv2.approxPolyDP(contour, epsilon, True).reshape(-1, 2)
|
||||
if len(approx) < 3:
|
||||
return polygon
|
||||
@@ -165,9 +274,25 @@ def _smooth_polygon(polygon: list[list[float]], smoothing: dict[str, Any]) -> li
|
||||
strength = float(smoothing.get("strength") or 0.0)
|
||||
if strength <= 0:
|
||||
return _normalize_polygon(polygon)
|
||||
iterations = max(1, min(3, int(strength // 35) + 1))
|
||||
smoothed = _chaikin_smooth_polygon(_normalize_polygon(polygon), iterations)
|
||||
simplified = _simplify_polygon(smoothed, strength)
|
||||
effective_strength = _smoothing_ratio(strength, curve=1.45) * 100.0
|
||||
if effective_strength >= 85:
|
||||
iterations = 4
|
||||
elif effective_strength >= 55:
|
||||
iterations = 3
|
||||
elif effective_strength >= 25:
|
||||
iterations = 2
|
||||
else:
|
||||
iterations = 1
|
||||
corner_cut = 0.03 + _smoothing_ratio(strength, curve=1.35) * 0.22
|
||||
normalized = _normalize_polygon(polygon)
|
||||
pre_simplified = _simplify_polygon(normalized, effective_strength * 0.25)
|
||||
smoothed = _chaikin_smooth_polygon(pre_simplified, iterations, corner_cut)
|
||||
simplified = _simplify_polygon(smoothed, effective_strength)
|
||||
if len(simplified) > len(normalized):
|
||||
for fallback_strength in (25.0, 35.0, 50.0, 70.0, 90.0, 100.0):
|
||||
simplified = _simplify_polygon(simplified, max(effective_strength, fallback_strength))
|
||||
if len(simplified) <= len(normalized):
|
||||
break
|
||||
return simplified if len(simplified) >= 3 else _normalize_polygon(polygon)
|
||||
|
||||
|
||||
@@ -321,7 +446,11 @@ def _filter_predictions(
|
||||
response_model=PredictResponse,
|
||||
summary="Run SAM inference with a prompt",
|
||||
)
|
||||
def predict(payload: PredictRequest, db: Session = Depends(get_db)) -> dict:
|
||||
def predict(
|
||||
payload: PredictRequest,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_editor),
|
||||
) -> dict:
|
||||
"""Execute selected SAM segmentation given an image and a prompt.
|
||||
|
||||
- **point**: `prompt_data` is either a list of `[[x, y], ...]` normalized
|
||||
@@ -330,9 +459,7 @@ def predict(payload: PredictRequest, db: Session = Depends(get_db)) -> dict:
|
||||
- **interactive**: `prompt_data` is `{ "box": [...], "points": [[x, y]], "labels": [1, 0] }`.
|
||||
- **semantic**: disabled in the current SAM 2.1 point/box product flow.
|
||||
"""
|
||||
frame = db.query(Frame).filter(Frame.id == payload.image_id).first()
|
||||
if not frame:
|
||||
raise HTTPException(status_code=404, detail="Frame not found")
|
||||
frame = _owned_frame_or_404(payload.image_id, db, current_user)
|
||||
|
||||
image = _load_frame_image(frame)
|
||||
prompt_type = payload.prompt_type.lower()
|
||||
@@ -478,7 +605,10 @@ def predict(payload: PredictRequest, db: Session = Depends(get_db)) -> dict:
|
||||
response_model=AiRuntimeStatus,
|
||||
summary="Get SAM model and GPU runtime status",
|
||||
)
|
||||
def model_status(selected_model: str | None = None) -> dict:
|
||||
def model_status(
|
||||
selected_model: str | None = None,
|
||||
_current_user: User = Depends(get_current_user),
|
||||
) -> dict:
|
||||
"""Return real runtime availability for GPU and the currently enabled SAM model."""
|
||||
try:
|
||||
return sam_registry.runtime_status(selected_model)
|
||||
@@ -491,12 +621,14 @@ def model_status(selected_model: str | None = None) -> dict:
|
||||
response_model=MaskAnalysisResponse,
|
||||
summary="Analyze mask geometry and prompt anchors",
|
||||
)
|
||||
def analyze_mask(payload: MaskAnalysisRequest, db: Session = Depends(get_db)) -> dict:
|
||||
def analyze_mask(
|
||||
payload: MaskAnalysisRequest,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> dict:
|
||||
"""Return backend-computed mask properties for the frontend inspector."""
|
||||
if payload.frame_id is not None:
|
||||
frame = db.query(Frame).filter(Frame.id == payload.frame_id).first()
|
||||
if not frame:
|
||||
raise HTTPException(status_code=404, detail="Frame not found")
|
||||
_owned_frame_or_404(payload.frame_id, db, current_user)
|
||||
|
||||
mask_data = payload.mask_data or {}
|
||||
polygons = mask_data.get("polygons") or []
|
||||
@@ -521,13 +653,13 @@ def analyze_mask(payload: MaskAnalysisRequest, db: Session = Depends(get_db)) ->
|
||||
else:
|
||||
confidence_source = "manual_or_imported"
|
||||
|
||||
anchors = _analysis_anchors(valid_polygons, payload.points)
|
||||
anchor_count, anchors = _analysis_anchor_summary(valid_polygons)
|
||||
message = "已从后端重新提取几何拓扑锚点" if payload.extract_skeleton else "已读取后端几何属性"
|
||||
|
||||
return {
|
||||
"confidence": confidence,
|
||||
"confidence_source": confidence_source,
|
||||
"topology_anchor_count": len(anchors),
|
||||
"topology_anchor_count": anchor_count,
|
||||
"topology_anchors": anchors,
|
||||
"area": area,
|
||||
"bbox": bbox,
|
||||
@@ -541,16 +673,18 @@ def analyze_mask(payload: MaskAnalysisRequest, db: Session = Depends(get_db)) ->
|
||||
response_model=SmoothMaskResponse,
|
||||
summary="Smooth editable mask polygons with backend geometry rules",
|
||||
)
|
||||
def smooth_mask(payload: SmoothMaskRequest, db: Session = Depends(get_db)) -> dict:
|
||||
def smooth_mask(
|
||||
payload: SmoothMaskRequest,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_editor),
|
||||
) -> dict:
|
||||
"""Return a smoothed polygon mask without persisting it.
|
||||
|
||||
The frontend keeps this as an explicit edit operation: users preview/apply it
|
||||
to the current mask, then save through the normal annotation endpoint.
|
||||
"""
|
||||
if payload.frame_id is not None:
|
||||
frame = db.query(Frame).filter(Frame.id == payload.frame_id).first()
|
||||
if not frame:
|
||||
raise HTTPException(status_code=404, detail="Frame not found")
|
||||
_owned_frame_or_404(payload.frame_id, db, current_user)
|
||||
|
||||
polygons = payload.mask_data.get("polygons") or []
|
||||
valid_polygons = _normalize_polygons(polygons)
|
||||
@@ -564,10 +698,10 @@ def smooth_mask(payload: SmoothMaskRequest, db: Session = Depends(get_db)) -> di
|
||||
|
||||
area = sum(_polygon_area(polygon) for polygon in smoothed_polygons)
|
||||
bbox = _polygon_bbox(smoothed_polygons[0])
|
||||
anchors = _analysis_anchors(smoothed_polygons, payload.points)
|
||||
anchor_count, anchors = _analysis_anchor_summary(smoothed_polygons)
|
||||
return {
|
||||
"polygons": smoothed_polygons,
|
||||
"topology_anchor_count": len(anchors),
|
||||
"topology_anchor_count": anchor_count,
|
||||
"topology_anchors": anchors,
|
||||
"area": area,
|
||||
"bbox": bbox,
|
||||
@@ -581,7 +715,11 @@ def smooth_mask(payload: SmoothMaskRequest, db: Session = Depends(get_db)) -> di
|
||||
response_model=PropagateResponse,
|
||||
summary="Propagate one current-frame region across a video frame segment",
|
||||
)
|
||||
def propagate(payload: PropagateRequest, db: Session = Depends(get_db)) -> dict:
|
||||
def propagate(
|
||||
payload: PropagateRequest,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_editor),
|
||||
) -> dict:
|
||||
"""Track one selected region from the current frame across nearby frames.
|
||||
|
||||
SAM 2 uses the official video predictor with the selected mask as the seed.
|
||||
@@ -592,16 +730,8 @@ def propagate(payload: PropagateRequest, db: Session = Depends(get_db)) -> dict:
|
||||
raise HTTPException(status_code=400, detail="direction must be forward, backward, or both")
|
||||
max_frames = max(1, min(int(payload.max_frames or 30), 500))
|
||||
|
||||
project = db.query(Project).filter(Project.id == payload.project_id).first()
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
|
||||
source_frame = db.query(Frame).filter(
|
||||
Frame.id == payload.frame_id,
|
||||
Frame.project_id == payload.project_id,
|
||||
).first()
|
||||
if not source_frame:
|
||||
raise HTTPException(status_code=404, detail="Frame not found")
|
||||
_owned_project_or_404(payload.project_id, db, current_user)
|
||||
source_frame = _owned_frame_or_404(payload.frame_id, db, current_user, payload.project_id)
|
||||
|
||||
seed = payload.seed.model_dump(exclude_none=True)
|
||||
polygons = seed.get("polygons") or []
|
||||
@@ -709,18 +839,14 @@ def propagate(payload: PropagateRequest, db: Session = Depends(get_db)) -> dict:
|
||||
response_model=ProcessingTaskOut,
|
||||
summary="Queue a background video propagation task",
|
||||
)
|
||||
def queue_propagate_task(payload: PropagateTaskRequest, db: Session = Depends(get_db)) -> ProcessingTaskOut:
|
||||
def queue_propagate_task(
|
||||
payload: PropagateTaskRequest,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_editor),
|
||||
) -> ProcessingTaskOut:
|
||||
"""Queue multiple seed/direction propagation steps as one background task."""
|
||||
project = db.query(Project).filter(Project.id == payload.project_id).first()
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
|
||||
source_frame = db.query(Frame).filter(
|
||||
Frame.id == payload.frame_id,
|
||||
Frame.project_id == payload.project_id,
|
||||
).first()
|
||||
if not source_frame:
|
||||
raise HTTPException(status_code=404, detail="Frame not found")
|
||||
_owned_project_or_404(payload.project_id, db, current_user)
|
||||
source_frame = _owned_frame_or_404(payload.frame_id, db, current_user, payload.project_id)
|
||||
|
||||
if not payload.steps:
|
||||
raise HTTPException(status_code=400, detail="Propagation task requires at least one step")
|
||||
@@ -768,11 +894,13 @@ def queue_propagate_task(payload: PropagateTaskRequest, db: Session = Depends(ge
|
||||
response_model=PredictResponse,
|
||||
summary="Run automatic segmentation",
|
||||
)
|
||||
def auto_segment(image_id: int, db: Session = Depends(get_db)) -> dict:
|
||||
def auto_segment(
|
||||
image_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_editor),
|
||||
) -> dict:
|
||||
"""Run automatic mask generation on a frame using a grid of point prompts."""
|
||||
frame = db.query(Frame).filter(Frame.id == image_id).first()
|
||||
if not frame:
|
||||
raise HTTPException(status_code=404, detail="Frame not found")
|
||||
frame = _owned_frame_or_404(image_id, db, current_user)
|
||||
|
||||
image = _load_frame_image(frame)
|
||||
try:
|
||||
@@ -792,16 +920,15 @@ def auto_segment(image_id: int, db: Session = Depends(get_db)) -> dict:
|
||||
def save_annotation(
|
||||
payload: AnnotationCreate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_editor),
|
||||
) -> Annotation:
|
||||
"""Persist an annotation (mask, points, bbox) into the database."""
|
||||
project = db.query(Project).filter(Project.id == payload.project_id).first()
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
_owned_project_or_404(payload.project_id, db, current_user)
|
||||
|
||||
if payload.frame_id:
|
||||
frame = db.query(Frame).filter(Frame.id == payload.frame_id).first()
|
||||
if not frame:
|
||||
raise HTTPException(status_code=404, detail="Frame not found")
|
||||
_owned_frame_or_404(payload.frame_id, db, current_user, payload.project_id)
|
||||
if payload.template_id:
|
||||
_visible_template_or_404(payload.template_id, db, current_user)
|
||||
|
||||
annotation = Annotation(**payload.model_dump())
|
||||
db.add(annotation)
|
||||
@@ -823,8 +950,10 @@ async def import_gt_mask(
|
||||
template_id: int | None = Form(None),
|
||||
label: str = Form("GT Mask"),
|
||||
color: str = Form("#22c55e"),
|
||||
unknown_color_policy: str = Form("undefined"),
|
||||
file: UploadFile = File(...),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_editor),
|
||||
) -> List[Annotation]:
|
||||
"""Convert a binary/label mask image into persisted polygon annotations.
|
||||
|
||||
@@ -833,62 +962,122 @@ async def import_gt_mask(
|
||||
the frontend an editable point-region representation instead of a static
|
||||
bitmap layer.
|
||||
"""
|
||||
project = db.query(Project).filter(Project.id == project_id).first()
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
_owned_project_or_404(project_id, db, current_user)
|
||||
frame = _owned_frame_or_404(frame_id, db, current_user, project_id)
|
||||
|
||||
frame = db.query(Frame).filter(Frame.id == frame_id, Frame.project_id == project_id).first()
|
||||
if not frame:
|
||||
raise HTTPException(status_code=404, detail="Frame not found")
|
||||
if unknown_color_policy not in {"discard", "undefined"}:
|
||||
raise HTTPException(status_code=400, detail="unknown_color_policy must be discard or undefined")
|
||||
|
||||
template: Template | None = None
|
||||
if template_id is not None:
|
||||
template = db.query(Template).filter(Template.id == template_id).first()
|
||||
if not template:
|
||||
raise HTTPException(status_code=404, detail="Template not found")
|
||||
template = _visible_template_or_404(template_id, db, current_user)
|
||||
|
||||
data = await file.read()
|
||||
image = cv2.imdecode(np.frombuffer(data, dtype=np.uint8), cv2.IMREAD_GRAYSCALE)
|
||||
image = cv2.imdecode(np.frombuffer(data, dtype=np.uint8), cv2.IMREAD_UNCHANGED)
|
||||
if image is None:
|
||||
raise HTTPException(status_code=400, detail="Invalid mask image")
|
||||
|
||||
if image.ndim == 2:
|
||||
label_image = image
|
||||
elif image.ndim == 3 and image.shape[2] >= 3:
|
||||
channels = image[:, :, :3]
|
||||
# GT label images are maskid maps: either grayscale or RGB/BGR where
|
||||
# all three color channels contain the same maskid value [X, X, X].
|
||||
if not (np.array_equal(channels[:, :, 0], channels[:, :, 1]) and np.array_equal(channels[:, :, 1], channels[:, :, 2])):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="GT Mask 图片不符合要求:请上传灰度图,或 RGB 三通道完全相同的 maskid 图(背景 0,像素值为 maskid)。",
|
||||
)
|
||||
label_image = channels[:, :, 0]
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="GT Mask 图片不符合要求:请上传灰度图,或 RGB 三通道完全相同的 maskid 图(背景 0,像素值为 maskid)。",
|
||||
)
|
||||
|
||||
width = int(frame.width or image.shape[1])
|
||||
height = int(frame.height or image.shape[0])
|
||||
label_values = [int(value) for value in np.unique(image) if int(value) > 0]
|
||||
if not label_values:
|
||||
original_height, original_width = int(label_image.shape[0]), int(label_image.shape[1])
|
||||
resized_to_frame = original_width != width or original_height != height
|
||||
if resized_to_frame:
|
||||
label_image = cv2.resize(label_image, (width, height), interpolation=cv2.INTER_NEAREST)
|
||||
|
||||
by_maskid, _by_color = _template_class_maps(template)
|
||||
has_template_classes = bool(by_maskid)
|
||||
fallback_color = _normalize_hex_color(color) or "#22c55e"
|
||||
|
||||
import_items: list[dict[str, Any]] = []
|
||||
skipped_unknown = 0
|
||||
label_values = [int(value) for value in np.unique(label_image) if int(value) > 0]
|
||||
for label_value in label_values:
|
||||
class_meta = by_maskid.get(label_value)
|
||||
is_unknown = has_template_classes and class_meta is None
|
||||
if is_unknown and unknown_color_policy == "discard":
|
||||
skipped_unknown += 1
|
||||
continue
|
||||
if class_meta:
|
||||
annotation_label = class_meta["name"]
|
||||
annotation_color = class_meta["color"]
|
||||
elif is_unknown:
|
||||
annotation_label = _gt_unknown_label(label_value)
|
||||
annotation_color = fallback_color
|
||||
else:
|
||||
annotation_label = f"{label} {label_value}" if len(label_values) > 1 else label
|
||||
annotation_color = fallback_color
|
||||
import_items.append({
|
||||
"token": label_value,
|
||||
"binary": np.where(label_image == label_value, 255, 0).astype(np.uint8),
|
||||
"label": annotation_label,
|
||||
"color": annotation_color,
|
||||
"class": class_meta,
|
||||
"unknown": is_unknown,
|
||||
"metadata": {
|
||||
"gt_label_value": label_value,
|
||||
"gt_original_size": {"width": original_width, "height": original_height},
|
||||
"gt_resized_to_frame": resized_to_frame,
|
||||
},
|
||||
})
|
||||
|
||||
if not import_items:
|
||||
if skipped_unknown > 0:
|
||||
raise HTTPException(status_code=400, detail="No matching GT mask classes found")
|
||||
raise HTTPException(status_code=400, detail="No foreground mask regions found")
|
||||
has_multiple_labels = len(label_values) > 1
|
||||
|
||||
annotations: list[Annotation] = []
|
||||
for label_value in label_values:
|
||||
binary = np.where(image == label_value, 255, 0).astype(np.uint8)
|
||||
for item in import_items:
|
||||
binary = item["binary"]
|
||||
contours, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
||||
annotation_label = f"{label} {label_value}" if has_multiple_labels else label
|
||||
|
||||
for contour in contours:
|
||||
if cv2.contourArea(contour) < 1:
|
||||
continue
|
||||
|
||||
polygon = _normalized_contour(contour, image.shape[1], image.shape[0])
|
||||
polygon = _normalized_contour(contour, binary.shape[1], binary.shape[0])
|
||||
if len(polygon) < 3:
|
||||
continue
|
||||
|
||||
component = np.zeros_like(binary, dtype=np.uint8)
|
||||
cv2.drawContours(component, [contour], -1, 1, thickness=-1)
|
||||
seed_point = _component_seed_point(component, image.shape[1], image.shape[0])
|
||||
bbox = _contour_bbox(contour, image.shape[1], image.shape[0])
|
||||
seed_point = _component_seed_point(component, binary.shape[1], binary.shape[0])
|
||||
bbox = _contour_bbox(contour, binary.shape[1], binary.shape[0])
|
||||
mask_data = {
|
||||
"polygons": [polygon],
|
||||
"label": item["label"],
|
||||
"color": item["color"],
|
||||
"source": "gt_mask",
|
||||
"image_size": {"width": width, "height": height},
|
||||
**item["metadata"],
|
||||
}
|
||||
if item["class"]:
|
||||
mask_data["class"] = item["class"]
|
||||
if item["unknown"]:
|
||||
mask_data["gt_unknown_class"] = True
|
||||
|
||||
annotation = Annotation(
|
||||
project_id=project_id,
|
||||
frame_id=frame_id,
|
||||
template_id=template_id,
|
||||
mask_data={
|
||||
"polygons": [polygon],
|
||||
"label": annotation_label,
|
||||
"color": color,
|
||||
"source": "gt_mask",
|
||||
"gt_label_value": label_value,
|
||||
"image_size": {"width": width, "height": height},
|
||||
},
|
||||
mask_data=mask_data,
|
||||
points=[seed_point],
|
||||
bbox=bbox,
|
||||
)
|
||||
@@ -914,14 +1103,14 @@ def list_annotations(
|
||||
project_id: int,
|
||||
frame_id: int | None = None,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> List[Annotation]:
|
||||
"""Return persisted annotations for a project, optionally scoped to one frame."""
|
||||
project = db.query(Project).filter(Project.id == project_id).first()
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
_owned_project_or_404(project_id, db, current_user)
|
||||
|
||||
query = db.query(Annotation).filter(Annotation.project_id == project_id)
|
||||
if frame_id is not None:
|
||||
_owned_frame_or_404(frame_id, db, current_user, project_id)
|
||||
query = query.filter(Annotation.frame_id == frame_id)
|
||||
return query.order_by(Annotation.id).all()
|
||||
|
||||
@@ -935,17 +1124,21 @@ def update_annotation(
|
||||
annotation_id: int,
|
||||
payload: AnnotationUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_editor),
|
||||
) -> Annotation:
|
||||
"""Update mutable annotation fields persisted in the database."""
|
||||
annotation = db.query(Annotation).filter(Annotation.id == annotation_id).first()
|
||||
annotation = (
|
||||
db.query(Annotation)
|
||||
.join(Project, Project.id == Annotation.project_id)
|
||||
.filter(Annotation.id == annotation_id, Project.owner_user_id == current_user.id)
|
||||
.first()
|
||||
)
|
||||
if not annotation:
|
||||
raise HTTPException(status_code=404, detail="Annotation not found")
|
||||
|
||||
updates = payload.model_dump(exclude_unset=True)
|
||||
if "template_id" in updates and updates["template_id"] is not None:
|
||||
template = db.query(Template).filter(Template.id == updates["template_id"]).first()
|
||||
if not template:
|
||||
raise HTTPException(status_code=404, detail="Template not found")
|
||||
_visible_template_or_404(updates["template_id"], db, current_user)
|
||||
|
||||
for field, value in updates.items():
|
||||
setattr(annotation, field, value)
|
||||
@@ -964,9 +1157,15 @@ def update_annotation(
|
||||
def delete_annotation(
|
||||
annotation_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_editor),
|
||||
) -> Response:
|
||||
"""Delete an annotation and its derived mask rows through ORM cascade."""
|
||||
annotation = db.query(Annotation).filter(Annotation.id == annotation_id).first()
|
||||
annotation = (
|
||||
db.query(Annotation)
|
||||
.join(Project, Project.id == Annotation.project_id)
|
||||
.filter(Annotation.id == annotation_id, Project.owner_user_id == current_user.id)
|
||||
.first()
|
||||
)
|
||||
if not annotation:
|
||||
raise HTTPException(status_code=404, detail="Annotation not found")
|
||||
|
||||
|
||||
@@ -1,9 +1,23 @@
|
||||
"""Authentication endpoints."""
|
||||
"""Authentication endpoints and dependencies."""
|
||||
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Any
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
||||
from jose import JWTError, jwt
|
||||
from passlib.context import CryptContext
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from config import settings
|
||||
from database import get_db
|
||||
from models import AuditLog, User
|
||||
from schemas import LoginResponse, UserOut
|
||||
|
||||
router = APIRouter(prefix="/api/auth", tags=["Auth"])
|
||||
password_context = CryptContext(schemes=["pbkdf2_sha256"], deprecated="auto")
|
||||
bearer_scheme = HTTPBearer(auto_error=False)
|
||||
|
||||
|
||||
class LoginRequest(BaseModel):
|
||||
@@ -11,14 +25,151 @@ class LoginRequest(BaseModel):
|
||||
password: str
|
||||
|
||||
|
||||
class LoginResponse(BaseModel):
|
||||
token: str
|
||||
username: str
|
||||
def hash_password(password: str) -> str:
|
||||
"""Hash a plain password for storage."""
|
||||
return password_context.hash(password)
|
||||
|
||||
|
||||
def verify_password(password: str, password_hash: str) -> bool:
|
||||
"""Verify a plain password against a stored hash."""
|
||||
return password_context.verify(password, password_hash)
|
||||
|
||||
|
||||
def create_access_token(user: User, expires_delta: timedelta | None = None) -> str:
|
||||
"""Create a signed JWT access token for a user."""
|
||||
expire = datetime.now(timezone.utc) + (
|
||||
expires_delta or timedelta(minutes=settings.access_token_expire_minutes)
|
||||
)
|
||||
payload: dict[str, Any] = {
|
||||
"sub": str(user.id),
|
||||
"username": user.username,
|
||||
"role": user.role,
|
||||
"exp": expire,
|
||||
}
|
||||
return jwt.encode(payload, settings.jwt_secret_key, algorithm=settings.jwt_algorithm)
|
||||
|
||||
|
||||
def ensure_default_admin(db: Session) -> User:
|
||||
"""Create the default development admin if the user table is empty."""
|
||||
existing = db.query(User).filter(User.username == settings.default_admin_username).first()
|
||||
if existing:
|
||||
return existing
|
||||
user = User(
|
||||
username=settings.default_admin_username,
|
||||
password_hash=hash_password(settings.default_admin_password),
|
||||
role="admin",
|
||||
is_active=1,
|
||||
)
|
||||
db.add(user)
|
||||
db.commit()
|
||||
db.refresh(user)
|
||||
return user
|
||||
|
||||
|
||||
def get_current_user(
|
||||
credentials: HTTPAuthorizationCredentials | None = Depends(bearer_scheme),
|
||||
db: Session = Depends(get_db),
|
||||
) -> User:
|
||||
"""Resolve and validate the current user from the Bearer token."""
|
||||
if credentials is None or credentials.scheme.lower() != "bearer":
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Not authenticated",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
try:
|
||||
payload = jwt.decode(
|
||||
credentials.credentials,
|
||||
settings.jwt_secret_key,
|
||||
algorithms=[settings.jwt_algorithm],
|
||||
)
|
||||
user_id = int(payload.get("sub"))
|
||||
except (JWTError, TypeError, ValueError) as exc:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid or expired token",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
) from exc
|
||||
|
||||
user = db.query(User).filter(User.id == user_id).first()
|
||||
if not user or not user.is_active:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Inactive or missing user",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
return user
|
||||
|
||||
|
||||
def require_admin(current_user: User = Depends(get_current_user)) -> User:
|
||||
"""Require the current user to have the administrator role."""
|
||||
if current_user.role != "admin":
|
||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Admin permission required")
|
||||
return current_user
|
||||
|
||||
|
||||
def require_editor(current_user: User = Depends(get_current_user)) -> User:
|
||||
"""Require a user role that can modify segmentation data."""
|
||||
if current_user.role not in {"admin", "annotator"}:
|
||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Edit permission required")
|
||||
return current_user
|
||||
|
||||
|
||||
def write_audit_log(
|
||||
db: Session,
|
||||
*,
|
||||
actor: User | None,
|
||||
action: str,
|
||||
target_type: str | None = None,
|
||||
target_id: str | int | None = None,
|
||||
detail: dict[str, Any] | None = None,
|
||||
) -> AuditLog:
|
||||
"""Persist a compact audit event."""
|
||||
log = AuditLog(
|
||||
actor_user_id=actor.id if actor else None,
|
||||
action=action,
|
||||
target_type=target_type,
|
||||
target_id=str(target_id) if target_id is not None else None,
|
||||
detail=detail or {},
|
||||
)
|
||||
db.add(log)
|
||||
db.commit()
|
||||
db.refresh(log)
|
||||
return log
|
||||
|
||||
|
||||
@router.post("/login", response_model=LoginResponse)
|
||||
def login(payload: LoginRequest) -> dict:
|
||||
"""Simple login for development."""
|
||||
if payload.username == "admin" and payload.password == "123456":
|
||||
return {"token": "fake-jwt-token-for-admin", "username": payload.username}
|
||||
raise HTTPException(status_code=401, detail="Invalid credentials")
|
||||
def login(payload: LoginRequest, db: Session = Depends(get_db)) -> dict:
|
||||
"""Authenticate a user and return a signed JWT."""
|
||||
ensure_default_admin(db)
|
||||
user = db.query(User).filter(User.username == payload.username).first()
|
||||
if not user or not user.is_active or not verify_password(payload.password, user.password_hash):
|
||||
write_audit_log(
|
||||
db,
|
||||
actor=None,
|
||||
action="auth.login_failed",
|
||||
target_type="user",
|
||||
target_id=payload.username,
|
||||
detail={"username": payload.username},
|
||||
)
|
||||
raise HTTPException(status_code=401, detail="Invalid credentials")
|
||||
write_audit_log(
|
||||
db,
|
||||
actor=user,
|
||||
action="auth.login_success",
|
||||
target_type="user",
|
||||
target_id=user.id,
|
||||
detail={"username": user.username},
|
||||
)
|
||||
return {
|
||||
"token": create_access_token(user),
|
||||
"token_type": "bearer",
|
||||
"username": user.username,
|
||||
"user": user,
|
||||
}
|
||||
|
||||
|
||||
@router.get("/me", response_model=UserOut)
|
||||
def read_current_user(current_user: User = Depends(get_current_user)) -> User:
|
||||
"""Return the authenticated user profile."""
|
||||
return current_user
|
||||
|
||||
@@ -5,11 +5,12 @@ from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy import func, or_
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from database import get_db
|
||||
from models import Annotation, Frame, ProcessingTask, Project, Template
|
||||
from models import Annotation, Frame, ProcessingTask, Project, Template, User
|
||||
from routers.auth import get_current_user
|
||||
|
||||
router = APIRouter(prefix="/api/dashboard", tags=["Dashboard"])
|
||||
|
||||
@@ -52,22 +53,45 @@ def _task_payload(task: ProcessingTask) -> dict[str, Any]:
|
||||
|
||||
|
||||
@router.get("/overview", summary="Get dashboard overview")
|
||||
def get_dashboard_overview(db: Session = Depends(get_db)) -> dict[str, Any]:
|
||||
def get_dashboard_overview(
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> dict[str, Any]:
|
||||
"""Return live dashboard data derived from persisted backend records."""
|
||||
project_count = db.query(func.count(Project.id)).scalar() or 0
|
||||
frame_count = db.query(func.count(Frame.id)).scalar() or 0
|
||||
annotation_count = db.query(func.count(Annotation.id)).scalar() or 0
|
||||
template_count = db.query(func.count(Template.id)).scalar() or 0
|
||||
owned_project_ids_query = db.query(Project.id).filter(Project.owner_user_id == current_user.id)
|
||||
project_count = db.query(func.count(Project.id)).filter(Project.owner_user_id == current_user.id).scalar() or 0
|
||||
frame_count = db.query(func.count(Frame.id)).filter(Frame.project_id.in_(owned_project_ids_query)).scalar() or 0
|
||||
annotation_count = (
|
||||
db.query(func.count(Annotation.id))
|
||||
.filter(Annotation.project_id.in_(owned_project_ids_query))
|
||||
.scalar()
|
||||
or 0
|
||||
)
|
||||
template_count = (
|
||||
db.query(func.count(Template.id))
|
||||
.filter(or_(Template.owner_user_id == current_user.id, Template.owner_user_id.is_(None)))
|
||||
.scalar()
|
||||
or 0
|
||||
)
|
||||
active_task_count = (
|
||||
db.query(func.count(ProcessingTask.id))
|
||||
.outerjoin(Project, Project.id == ProcessingTask.project_id)
|
||||
.filter((ProcessingTask.project_id.is_(None)) | (Project.owner_user_id == current_user.id))
|
||||
.filter(ProcessingTask.status.in_(ACTIVE_TASK_STATUSES))
|
||||
.scalar()
|
||||
or 0
|
||||
)
|
||||
|
||||
projects = db.query(Project).order_by(Project.updated_at.desc()).all()
|
||||
projects = (
|
||||
db.query(Project)
|
||||
.filter(Project.owner_user_id == current_user.id)
|
||||
.order_by(Project.updated_at.desc())
|
||||
.all()
|
||||
)
|
||||
recent_tasks = (
|
||||
db.query(ProcessingTask)
|
||||
.outerjoin(Project, Project.id == ProcessingTask.project_id)
|
||||
.filter((ProcessingTask.project_id.is_(None)) | (Project.owner_user_id == current_user.id))
|
||||
.order_by(ProcessingTask.created_at.desc())
|
||||
.limit(50)
|
||||
.all()
|
||||
@@ -96,6 +120,7 @@ def get_dashboard_overview(db: Session = Depends(get_db)) -> dict[str, Any]:
|
||||
|
||||
recent_annotations = (
|
||||
db.query(Annotation)
|
||||
.filter(Annotation.project_id.in_(owned_project_ids_query))
|
||||
.order_by(Annotation.updated_at.desc())
|
||||
.limit(10)
|
||||
.all()
|
||||
@@ -112,6 +137,7 @@ def get_dashboard_overview(db: Session = Depends(get_db)) -> dict[str, Any]:
|
||||
|
||||
recent_templates = (
|
||||
db.query(Template)
|
||||
.filter(or_(Template.owner_user_id == current_user.id, Template.owner_user_id.is_(None)))
|
||||
.order_by(Template.created_at.desc())
|
||||
.limit(10)
|
||||
.all()
|
||||
|
||||
@@ -4,17 +4,22 @@ import io
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import zipfile
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List
|
||||
from urllib.parse import quote
|
||||
|
||||
import numpy as np
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
||||
from fastapi.responses import StreamingResponse
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from database import get_db
|
||||
from models import Project, Annotation, Frame, Template
|
||||
from minio_client import download_file
|
||||
from models import Project, Annotation, Frame, Template, User
|
||||
from routers.auth import get_current_user
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter(prefix="/api/export", tags=["Export"])
|
||||
@@ -49,6 +54,30 @@ def _annotation_z_index(annotation: Annotation) -> int:
|
||||
return 0
|
||||
|
||||
|
||||
def _annotation_mask_id(annotation: Annotation) -> int | None:
|
||||
class_meta = (annotation.mask_data or {}).get("class") or {}
|
||||
if isinstance(class_meta, dict):
|
||||
for key in ("maskId", "maskid", "mask_id"):
|
||||
if class_meta.get(key) is None:
|
||||
continue
|
||||
try:
|
||||
value = int(class_meta[key])
|
||||
except (TypeError, ValueError):
|
||||
continue
|
||||
if value > 0:
|
||||
return value
|
||||
return None
|
||||
|
||||
|
||||
def _annotation_category_name(annotation: Annotation) -> str:
|
||||
class_meta = (annotation.mask_data or {}).get("class") or {}
|
||||
if isinstance(class_meta, dict) and class_meta.get("category"):
|
||||
return str(class_meta["category"])
|
||||
if annotation.template and annotation.template.name:
|
||||
return str(annotation.template.name)
|
||||
return ""
|
||||
|
||||
|
||||
def _annotation_class_key(annotation: Annotation) -> str:
|
||||
class_meta = (annotation.mask_data or {}).get("class") or {}
|
||||
if isinstance(class_meta, dict):
|
||||
@@ -85,38 +114,162 @@ def _annotation_color(annotation: Annotation) -> str:
|
||||
return "#ffffff"
|
||||
|
||||
|
||||
@router.get(
|
||||
"/{project_id}/coco",
|
||||
summary="Export annotations in COCO format",
|
||||
)
|
||||
def export_coco(project_id: int, db: Session = Depends(get_db)) -> StreamingResponse:
|
||||
"""Export all annotations for a project as a COCO-format JSON file."""
|
||||
project = db.query(Project).filter(Project.id == project_id).first()
|
||||
def _hex_to_rgb(color: str) -> list[int]:
|
||||
value = str(color or "").strip()
|
||||
if value.startswith("#"):
|
||||
value = value[1:]
|
||||
if len(value) == 3:
|
||||
value = "".join(part * 2 for part in value)
|
||||
if len(value) != 6:
|
||||
return [255, 255, 255]
|
||||
try:
|
||||
return [int(value[i:i + 2], 16) for i in (0, 2, 4)]
|
||||
except ValueError:
|
||||
return [255, 255, 255]
|
||||
|
||||
|
||||
def _safe_filename_part(value: Any, fallback: str = "unknown") -> str:
|
||||
text = str(value or "").strip()
|
||||
if not text:
|
||||
text = fallback
|
||||
text = re.sub(r"[\\/:*?\"<>|\s]+", "_", text)
|
||||
text = re.sub(r"_+", "_", text).strip("._")
|
||||
return text or fallback
|
||||
|
||||
|
||||
def _project_video_name(project: Project) -> str:
|
||||
if project.video_path:
|
||||
stem = Path(project.video_path).name
|
||||
if "." in stem:
|
||||
stem = ".".join(stem.split(".")[:-1])
|
||||
if stem:
|
||||
return _safe_filename_part(stem, f"project_{project.id}")
|
||||
return _safe_filename_part(project.name, f"project_{project.id}")
|
||||
|
||||
|
||||
def _project_export_name(project: Project) -> str:
|
||||
return _safe_filename_part(project.name, f"project_{project.id}")
|
||||
|
||||
|
||||
def _frame_timestamp_ms(frame: Frame, project: Project) -> float:
|
||||
if frame.timestamp_ms is not None:
|
||||
return float(frame.timestamp_ms)
|
||||
fps = project.parse_fps or project.original_fps or 30.0
|
||||
return float(frame.frame_index) * 1000.0 / max(float(fps), 1.0)
|
||||
|
||||
|
||||
def _project_frame_number(frame: Frame) -> int:
|
||||
return int(frame.frame_index) + 1
|
||||
|
||||
|
||||
def _format_timestamp_ms(value: float) -> str:
|
||||
total_ms = max(0, int(round(float(value))))
|
||||
hours = total_ms // 3_600_000
|
||||
minutes = (total_ms % 3_600_000) // 60_000
|
||||
seconds = (total_ms % 60_000) // 1_000
|
||||
milliseconds = total_ms % 1_000
|
||||
return f"{hours}h{minutes:02d}m{seconds:02d}s{milliseconds:03d}ms"
|
||||
|
||||
|
||||
def _frame_export_stem(project: Project, frame: Frame) -> str:
|
||||
return "_".join([
|
||||
_project_video_name(project),
|
||||
_format_timestamp_ms(_frame_timestamp_ms(frame, project)),
|
||||
f"frame{_project_frame_number(frame):06d}",
|
||||
])
|
||||
|
||||
|
||||
def _segmentation_results_filename(project: Project, frames: list[Frame]) -> str:
|
||||
if not frames:
|
||||
return f"{_project_export_name(project)}_seg_T_0h00m00s000ms-0h00m00s000ms_P_0-0.zip"
|
||||
first_frame = frames[0]
|
||||
last_frame = frames[-1]
|
||||
return (
|
||||
f"{_project_export_name(project)}"
|
||||
f"_seg_T_{_format_timestamp_ms(_frame_timestamp_ms(first_frame, project))}"
|
||||
f"-{_format_timestamp_ms(_frame_timestamp_ms(last_frame, project))}"
|
||||
f"_P_{_project_frame_number(first_frame)}-{_project_frame_number(last_frame)}.zip"
|
||||
)
|
||||
|
||||
|
||||
def _download_content_disposition(filename: str) -> str:
|
||||
ascii_fallback = filename.encode("ascii", "ignore").decode("ascii") or "segmentation_results.zip"
|
||||
ascii_fallback = _safe_filename_part(ascii_fallback, "segmentation_results.zip")
|
||||
if not ascii_fallback.endswith(".zip") and filename.endswith(".zip"):
|
||||
ascii_fallback = f"{ascii_fallback}.zip"
|
||||
return f"attachment; filename=\"{ascii_fallback}\"; filename*=UTF-8''{quote(filename)}"
|
||||
|
||||
|
||||
def _frame_image_extension(frame: Frame) -> str:
|
||||
suffix = Path(frame.image_url or "").suffix.lower()
|
||||
return suffix if suffix in {".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff"} else ".jpg"
|
||||
|
||||
|
||||
def _project_or_404(project_id: int, db: Session, current_user: User) -> Project:
|
||||
project = db.query(Project).filter(
|
||||
Project.id == project_id,
|
||||
Project.owner_user_id == current_user.id,
|
||||
).first()
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
return project
|
||||
|
||||
annotations = (
|
||||
db.query(Annotation)
|
||||
.filter(Annotation.project_id == project_id)
|
||||
.all()
|
||||
)
|
||||
frames = (
|
||||
|
||||
def _project_frames(project_id: int, db: Session) -> list[Frame]:
|
||||
return (
|
||||
db.query(Frame)
|
||||
.filter(Frame.project_id == project_id)
|
||||
.order_by(Frame.frame_index)
|
||||
.all()
|
||||
)
|
||||
templates = db.query(Template).all()
|
||||
|
||||
# Build COCO structure
|
||||
|
||||
def _filter_frames(
|
||||
frames: list[Frame],
|
||||
*,
|
||||
scope: str = "all",
|
||||
start_frame: int | None = None,
|
||||
end_frame: int | None = None,
|
||||
frame_id: int | None = None,
|
||||
) -> list[Frame]:
|
||||
if scope == "current":
|
||||
if frame_id is None:
|
||||
raise HTTPException(status_code=400, detail="frame_id is required for current-frame export")
|
||||
selected = [frame for frame in frames if frame.id == frame_id]
|
||||
if not selected:
|
||||
raise HTTPException(status_code=404, detail="Frame not found")
|
||||
return selected
|
||||
|
||||
if scope == "range":
|
||||
if start_frame is None or end_frame is None:
|
||||
raise HTTPException(status_code=400, detail="start_frame and end_frame are required for range export")
|
||||
start = max(1, min(int(start_frame), int(end_frame)))
|
||||
end = max(1, max(int(start_frame), int(end_frame)))
|
||||
return frames[start - 1:end]
|
||||
|
||||
return frames
|
||||
|
||||
|
||||
def _filtered_annotations(project_id: int, frame_ids: set[int], db: Session) -> list[Annotation]:
|
||||
if not frame_ids:
|
||||
return []
|
||||
return (
|
||||
db.query(Annotation)
|
||||
.filter(Annotation.project_id == project_id)
|
||||
.filter(Annotation.frame_id.in_(frame_ids))
|
||||
.all()
|
||||
)
|
||||
|
||||
|
||||
def _build_coco(project: Project, frames: list[Frame], annotations: list[Annotation], templates: list[Template]) -> dict[str, Any]:
|
||||
images = []
|
||||
for idx, frame in enumerate(frames):
|
||||
for frame in frames:
|
||||
images.append({
|
||||
"id": frame.id,
|
||||
"file_name": frame.image_url,
|
||||
"width": frame.width or 1920,
|
||||
"height": frame.height or 1080,
|
||||
"frame_index": idx,
|
||||
"frame_index": frame.frame_index,
|
||||
})
|
||||
|
||||
categories = []
|
||||
@@ -131,14 +284,14 @@ def export_coco(project_id: int, db: Session = Depends(get_db)) -> StreamingResp
|
||||
|
||||
coco_annotations = []
|
||||
ann_id = 1
|
||||
selected_frame_ids = {frame.id for frame in frames}
|
||||
for ann in annotations:
|
||||
if not ann.mask_data:
|
||||
if ann.frame_id not in selected_frame_ids or not ann.mask_data:
|
||||
continue
|
||||
polygons = ann.mask_data.get("polygons", [])
|
||||
if not polygons:
|
||||
continue
|
||||
|
||||
# Use first polygon for bbox / area approximation
|
||||
first_poly = polygons[0]
|
||||
xs = [p[0] for p in first_poly]
|
||||
ys = [p[1] for p in first_poly]
|
||||
@@ -171,7 +324,7 @@ def export_coco(project_id: int, db: Session = Depends(get_db)) -> StreamingResp
|
||||
})
|
||||
ann_id += 1
|
||||
|
||||
coco = {
|
||||
return {
|
||||
"info": {
|
||||
"description": f"Annotations for {project.name}",
|
||||
"version": "1.0",
|
||||
@@ -183,39 +336,235 @@ def export_coco(project_id: int, db: Session = Depends(get_db)) -> StreamingResp
|
||||
"categories": categories,
|
||||
}
|
||||
|
||||
data = json.dumps(coco, ensure_ascii=False, indent=2).encode("utf-8")
|
||||
filename = f"project_{project_id}_coco.json"
|
||||
|
||||
return StreamingResponse(
|
||||
io.BytesIO(data),
|
||||
media_type="application/json",
|
||||
headers={"Content-Disposition": f'attachment; filename="{filename}"'},
|
||||
def _class_mapping_entry(annotation: Annotation) -> dict[str, Any]:
|
||||
return {
|
||||
"key": _annotation_class_key(annotation),
|
||||
"className": _annotation_label(annotation),
|
||||
"chineseName": _annotation_label(annotation),
|
||||
"categoryName": _annotation_category_name(annotation),
|
||||
"color": _annotation_color(annotation),
|
||||
"internalPriority": _annotation_z_index(annotation),
|
||||
"maskidHint": _annotation_mask_id(annotation),
|
||||
"template_id": annotation.template_id,
|
||||
}
|
||||
|
||||
|
||||
def _build_gt_class_mapping(annotations: list[Annotation]) -> tuple[dict[str, int], list[dict[str, Any]]]:
|
||||
entries_by_key: dict[str, dict[str, Any]] = {}
|
||||
for annotation in annotations:
|
||||
if not annotation.mask_data or not annotation.mask_data.get("polygons"):
|
||||
continue
|
||||
entry = _class_mapping_entry(annotation)
|
||||
entries_by_key.setdefault(entry["key"], entry)
|
||||
|
||||
ordered = sorted(
|
||||
entries_by_key.values(),
|
||||
key=lambda item: (
|
||||
item["maskidHint"] if isinstance(item.get("maskidHint"), int) and item["maskidHint"] > 0 else 10_000_000,
|
||||
str(item["className"]),
|
||||
str(item["key"]),
|
||||
),
|
||||
)
|
||||
key_to_value: dict[str, int] = {}
|
||||
classes: list[dict[str, Any]] = []
|
||||
used_maskids: set[int] = set()
|
||||
next_maskid = 1
|
||||
|
||||
def next_available_maskid() -> int:
|
||||
nonlocal next_maskid
|
||||
while next_maskid in used_maskids:
|
||||
next_maskid += 1
|
||||
value = next_maskid
|
||||
used_maskids.add(value)
|
||||
next_maskid += 1
|
||||
return value
|
||||
|
||||
for entry in ordered:
|
||||
hinted_maskid = entry.get("maskidHint")
|
||||
if isinstance(hinted_maskid, int) and hinted_maskid > 0 and hinted_maskid not in used_maskids:
|
||||
maskid = hinted_maskid
|
||||
used_maskids.add(maskid)
|
||||
else:
|
||||
maskid = next_available_maskid()
|
||||
key_to_value[entry["key"]] = maskid
|
||||
classes.append({
|
||||
"gt_pixel_value": maskid,
|
||||
"maskid": maskid,
|
||||
"chineseName": entry["chineseName"],
|
||||
"className": entry["className"],
|
||||
"categoryName": entry["categoryName"],
|
||||
"rgb": _hex_to_rgb(entry["color"]),
|
||||
"color": entry["color"],
|
||||
"key": entry["key"],
|
||||
"template_id": entry["template_id"],
|
||||
})
|
||||
return key_to_value, classes
|
||||
|
||||
|
||||
@router.get(
|
||||
"/{project_id}/masks",
|
||||
summary="Export PNG masks as a ZIP archive",
|
||||
)
|
||||
def export_masks(project_id: int, db: Session = Depends(get_db)) -> StreamingResponse:
|
||||
"""Export individual masks plus z-index fused semantic masks inside a ZIP."""
|
||||
project = db.query(Project).filter(Project.id == project_id).first()
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
def _parse_result_outputs(mask_type: str, outputs: str | None) -> set[str]:
|
||||
allowed = {"separate", "gt_label", "pro_label", "mix_label"}
|
||||
if outputs:
|
||||
parsed = {item.strip() for item in outputs.split(",") if item.strip()}
|
||||
invalid = parsed - allowed
|
||||
if invalid:
|
||||
raise HTTPException(status_code=400, detail=f"Invalid outputs: {', '.join(sorted(invalid))}")
|
||||
return parsed or allowed
|
||||
|
||||
if mask_type == "separate":
|
||||
return {"separate"}
|
||||
if mask_type == "gt_label":
|
||||
return {"gt_label"}
|
||||
if mask_type == "pro_label":
|
||||
return {"pro_label"}
|
||||
if mask_type == "mix_label":
|
||||
return {"mix_label"}
|
||||
return allowed
|
||||
|
||||
|
||||
def _write_original_frames(
|
||||
zf: zipfile.ZipFile,
|
||||
project: Project,
|
||||
frames: list[Frame],
|
||||
) -> dict[int, bytes]:
|
||||
image_bytes_by_frame: dict[int, bytes] = {}
|
||||
for frame in frames:
|
||||
image_bytes = download_file(frame.image_url)
|
||||
image_bytes_by_frame[frame.id] = image_bytes
|
||||
zf.writestr(
|
||||
f"原始图片/{_frame_export_stem(project, frame)}{_frame_image_extension(frame)}",
|
||||
image_bytes,
|
||||
)
|
||||
return image_bytes_by_frame
|
||||
|
||||
|
||||
def _decode_original_image(image_bytes: bytes | None, width: int, height: int) -> np.ndarray:
|
||||
import cv2
|
||||
|
||||
annotations = (
|
||||
db.query(Annotation)
|
||||
.filter(Annotation.project_id == project_id)
|
||||
.all()
|
||||
)
|
||||
frames = (
|
||||
db.query(Frame)
|
||||
.filter(Frame.project_id == project_id)
|
||||
.order_by(Frame.frame_index)
|
||||
.all()
|
||||
)
|
||||
if image_bytes:
|
||||
decoded = cv2.imdecode(np.frombuffer(image_bytes, dtype=np.uint8), cv2.IMREAD_COLOR)
|
||||
if decoded is not None:
|
||||
if decoded.shape[1] != width or decoded.shape[0] != height:
|
||||
decoded = cv2.resize(decoded, (width, height), interpolation=cv2.INTER_AREA)
|
||||
return decoded
|
||||
return np.zeros((height, width, 3), dtype=np.uint8)
|
||||
|
||||
|
||||
def _write_result_mask_outputs(
|
||||
zf: zipfile.ZipFile,
|
||||
project: Project,
|
||||
frames: list[Frame],
|
||||
annotations: list[Annotation],
|
||||
*,
|
||||
outputs: set[str],
|
||||
class_values: dict[str, int],
|
||||
class_mapping: list[dict[str, Any]],
|
||||
original_images: dict[int, bytes],
|
||||
mix_opacity: float,
|
||||
) -> None:
|
||||
import cv2
|
||||
|
||||
include_individual = "separate" in outputs
|
||||
include_semantic = "gt_label" in outputs
|
||||
include_pro_label = "pro_label" in outputs
|
||||
include_mix_label = "mix_label" in outputs
|
||||
class_rgb_by_key = {
|
||||
item["key"]: item.get("rgb") or _hex_to_rgb(item.get("color", "#ffffff"))
|
||||
for item in class_mapping
|
||||
}
|
||||
annotations_by_frame: dict[int, list[Annotation]] = {}
|
||||
selected_frame_ids = {frame.id for frame in frames}
|
||||
for annotation in annotations:
|
||||
if annotation.frame_id not in selected_frame_ids or not annotation.mask_data:
|
||||
continue
|
||||
if not annotation.mask_data.get("polygons"):
|
||||
continue
|
||||
annotations_by_frame.setdefault(annotation.frame_id, []).append(annotation)
|
||||
|
||||
for frame in frames:
|
||||
frame_annotations = annotations_by_frame.get(frame.id, [])
|
||||
if not frame_annotations:
|
||||
continue
|
||||
width = frame.width or 1920
|
||||
height = frame.height or 1080
|
||||
frame_stem = _frame_export_stem(project, frame)
|
||||
|
||||
if include_individual:
|
||||
class_masks: dict[str, np.ndarray] = {}
|
||||
class_meta: dict[str, dict[str, Any]] = {}
|
||||
for annotation in frame_annotations:
|
||||
key = _annotation_class_key(annotation)
|
||||
combined = class_masks.setdefault(key, np.zeros((height, width), dtype=np.uint8))
|
||||
for poly in (annotation.mask_data or {}).get("polygons", []):
|
||||
combined[:] = np.maximum(combined, _mask_from_polygon(poly, width, height))
|
||||
class_meta.setdefault(key, _class_mapping_entry(annotation))
|
||||
|
||||
folder = f"分开Mask分割结果/{frame_stem}_分别导出"
|
||||
for key, mask in sorted(class_masks.items(), key=lambda item: int(class_meta[item[0]]["internalPriority"])):
|
||||
meta = class_meta[key]
|
||||
maskid = class_values.get(key)
|
||||
if maskid is None:
|
||||
continue
|
||||
_, encoded = cv2.imencode(".png", mask)
|
||||
class_name = _safe_filename_part(meta["className"], "class")
|
||||
zf.writestr(
|
||||
f"{folder}/{frame_stem}_{class_name}_maskid{maskid}.png",
|
||||
encoded.tobytes(),
|
||||
)
|
||||
|
||||
needs_fused_output = include_semantic or include_pro_label or include_mix_label
|
||||
semantic = np.zeros((height, width), dtype=np.uint16) if needs_fused_output else None
|
||||
pro_label = np.zeros((height, width, 3), dtype=np.uint8) if (include_pro_label or include_mix_label) else None
|
||||
|
||||
if needs_fused_output:
|
||||
for annotation in sorted(frame_annotations, key=_annotation_z_index):
|
||||
key = _annotation_class_key(annotation)
|
||||
value = class_values.get(key)
|
||||
if value is None:
|
||||
continue
|
||||
combined = np.zeros((height, width), dtype=np.uint8)
|
||||
for poly in (annotation.mask_data or {}).get("polygons", []):
|
||||
combined = np.maximum(combined, _mask_from_polygon(poly, width, height))
|
||||
if semantic is not None:
|
||||
semantic[combined > 0] = value
|
||||
if pro_label is not None:
|
||||
rgb = class_rgb_by_key.get(key, [255, 255, 255])
|
||||
bgr = np.array([rgb[2], rgb[1], rgb[0]], dtype=np.uint8)
|
||||
pro_label[combined > 0] = bgr
|
||||
|
||||
if include_semantic and semantic is not None:
|
||||
_, encoded = cv2.imencode(".png", semantic)
|
||||
zf.writestr(f"GT_label图/{frame_stem}.png", encoded.tobytes())
|
||||
|
||||
if include_pro_label and pro_label is not None:
|
||||
_, encoded = cv2.imencode(".png", pro_label)
|
||||
zf.writestr(f"Pro_label彩色分割结果/{frame_stem}.png", encoded.tobytes())
|
||||
|
||||
if include_mix_label and pro_label is not None:
|
||||
original = _decode_original_image(original_images.get(frame.id), width, height)
|
||||
mask_pixels = np.any(pro_label > 0, axis=2)
|
||||
mixed = original.copy()
|
||||
opacity = min(max(float(mix_opacity), 0.0), 1.0)
|
||||
mixed[mask_pixels] = (
|
||||
original[mask_pixels].astype(np.float32) * (1.0 - opacity)
|
||||
+ pro_label[mask_pixels].astype(np.float32) * opacity
|
||||
).clip(0, 255).astype(np.uint8)
|
||||
_, encoded = cv2.imencode(".png", mixed)
|
||||
zf.writestr(f"Mix_label重叠覆盖彩色分割结果/{frame_stem}.png", encoded.tobytes())
|
||||
|
||||
|
||||
def _write_mask_pngs(
|
||||
zf: zipfile.ZipFile,
|
||||
frames: list[Frame],
|
||||
annotations: list[Annotation],
|
||||
*,
|
||||
mask_type: str,
|
||||
individual_prefix: str = "",
|
||||
semantic_prefix: str = "",
|
||||
semantic_file_stem: str = "semantic_frame",
|
||||
semantic_dtype: Any = np.uint8,
|
||||
) -> list[dict[str, Any]]:
|
||||
import cv2
|
||||
|
||||
class_values: dict[str, int] = {}
|
||||
semantic_classes: list[dict[str, Any]] = []
|
||||
@@ -235,46 +584,102 @@ def export_masks(project_id: int, db: Session = Depends(get_db)) -> StreamingRes
|
||||
})
|
||||
return class_values[key]
|
||||
|
||||
zip_buffer = io.BytesIO()
|
||||
with zipfile.ZipFile(zip_buffer, "w", zipfile.ZIP_DEFLATED) as zf:
|
||||
frame_masks: dict[int, list[tuple[Annotation, np.ndarray]]] = {}
|
||||
for ann in annotations:
|
||||
if not ann.mask_data:
|
||||
continue
|
||||
polygons = ann.mask_data.get("polygons", [])
|
||||
if not polygons:
|
||||
continue
|
||||
include_individual = mask_type in {"separate", "both"}
|
||||
include_semantic = mask_type in {"gt_label", "both"}
|
||||
frame_masks: dict[int, list[tuple[Annotation, np.ndarray]]] = {}
|
||||
selected_frame_ids = {frame.id for frame in frames}
|
||||
|
||||
width = ann.frame.width if ann.frame else 1920
|
||||
height = ann.frame.height if ann.frame else 1080
|
||||
combined = np.zeros((height, width), dtype=np.uint8)
|
||||
for ann in annotations:
|
||||
if ann.frame_id not in selected_frame_ids or not ann.mask_data:
|
||||
continue
|
||||
polygons = ann.mask_data.get("polygons", [])
|
||||
if not polygons:
|
||||
continue
|
||||
|
||||
for poly in polygons:
|
||||
mask = _mask_from_polygon(poly, width, height)
|
||||
combined = np.maximum(combined, mask)
|
||||
width = ann.frame.width if ann.frame else 1920
|
||||
height = ann.frame.height if ann.frame else 1080
|
||||
combined = np.zeros((height, width), dtype=np.uint8)
|
||||
for poly in polygons:
|
||||
mask = _mask_from_polygon(poly, width, height)
|
||||
combined = np.maximum(combined, mask)
|
||||
|
||||
if include_individual:
|
||||
_, encoded = cv2.imencode(".png", combined)
|
||||
fname = f"mask_{ann.id:06d}.png"
|
||||
zf.writestr(fname, encoded.tobytes())
|
||||
if ann.frame_id is not None:
|
||||
frame_masks.setdefault(ann.frame_id, []).append((ann, combined))
|
||||
zf.writestr(f"{individual_prefix}mask_{ann.id:06d}.png", encoded.tobytes())
|
||||
if include_semantic and ann.frame_id is not None:
|
||||
frame_masks.setdefault(ann.frame_id, []).append((ann, combined))
|
||||
|
||||
if include_semantic:
|
||||
for frame in frames:
|
||||
entries = frame_masks.get(frame.id, [])
|
||||
if not entries:
|
||||
continue
|
||||
width = frame.width or 1920
|
||||
height = frame.height or 1080
|
||||
semantic = np.zeros((height, width), dtype=np.uint8)
|
||||
semantic = np.zeros((height, width), dtype=semantic_dtype)
|
||||
for ann, mask in sorted(entries, key=lambda item: _annotation_z_index(item[0])):
|
||||
semantic[mask > 0] = class_value(ann)
|
||||
_, encoded = cv2.imencode(".png", semantic)
|
||||
zf.writestr(f"semantic_frame_{frame.frame_index:06d}.png", encoded.tobytes())
|
||||
zf.writestr(f"{semantic_prefix}{semantic_file_stem}_{frame.frame_index:06d}.png", encoded.tobytes())
|
||||
|
||||
if include_semantic:
|
||||
zf.writestr(
|
||||
"semantic_classes.json",
|
||||
f"{semantic_prefix}semantic_classes.json",
|
||||
json.dumps({"classes": semantic_classes}, ensure_ascii=False, indent=2).encode("utf-8"),
|
||||
)
|
||||
return semantic_classes
|
||||
|
||||
|
||||
@router.get(
|
||||
"/{project_id}/coco",
|
||||
summary="Export annotations in COCO format",
|
||||
)
|
||||
def export_coco(
|
||||
project_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> StreamingResponse:
|
||||
"""Export all annotations for a project as a COCO-format JSON file."""
|
||||
project = _project_or_404(project_id, db, current_user)
|
||||
frames = _project_frames(project_id, db)
|
||||
annotations = _filtered_annotations(project_id, {frame.id for frame in frames}, db)
|
||||
templates = db.query(Template).all()
|
||||
coco = _build_coco(project, frames, annotations, templates)
|
||||
|
||||
data = json.dumps(coco, ensure_ascii=False, indent=2).encode("utf-8")
|
||||
filename = f"project_{project_id}_coco.json"
|
||||
|
||||
return StreamingResponse(
|
||||
io.BytesIO(data),
|
||||
media_type="application/json",
|
||||
headers={"Content-Disposition": f'attachment; filename="{filename}"'},
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/{project_id}/masks",
|
||||
summary="Export PNG masks as a ZIP archive",
|
||||
)
|
||||
def export_masks(
|
||||
project_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> StreamingResponse:
|
||||
"""Export individual masks plus z-index fused semantic masks inside a ZIP."""
|
||||
_project_or_404(project_id, db, current_user)
|
||||
frames = _project_frames(project_id, db)
|
||||
annotations = _filtered_annotations(project_id, {frame.id for frame in frames}, db)
|
||||
|
||||
zip_buffer = io.BytesIO()
|
||||
with zipfile.ZipFile(zip_buffer, "w", zipfile.ZIP_DEFLATED) as zf:
|
||||
_write_mask_pngs(
|
||||
zf,
|
||||
frames,
|
||||
annotations,
|
||||
mask_type="both",
|
||||
semantic_prefix="",
|
||||
individual_prefix="",
|
||||
)
|
||||
|
||||
zip_buffer.seek(0)
|
||||
filename = f"project_{project_id}_masks.zip"
|
||||
@@ -284,3 +689,71 @@ def export_masks(project_id: int, db: Session = Depends(get_db)) -> StreamingRes
|
||||
media_type="application/zip",
|
||||
headers={"Content-Disposition": f'attachment; filename="{filename}"'},
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/{project_id}/results",
|
||||
summary="Export segmentation results as a ZIP archive",
|
||||
)
|
||||
def export_results(
|
||||
project_id: int,
|
||||
scope: str = Query("all", pattern="^(all|range|current)$"),
|
||||
mask_type: str = Query("both", pattern="^(separate|gt_label|pro_label|mix_label|both|all)$"),
|
||||
outputs: str | None = Query(None),
|
||||
mix_opacity: float = Query(0.3, ge=0.0, le=1.0),
|
||||
start_frame: int | None = Query(None, ge=1),
|
||||
end_frame: int | None = Query(None, ge=1),
|
||||
frame_id: int | None = Query(None, ge=1),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> StreamingResponse:
|
||||
"""Export JSON annotations plus selected PNG mask outputs inside one ZIP.
|
||||
|
||||
`scope=all` exports the whole video. `scope=range` uses 1-based frame
|
||||
numbers from the sorted project frame sequence. `scope=current` uses the
|
||||
concrete backend `frame_id`.
|
||||
"""
|
||||
project = _project_or_404(project_id, db, current_user)
|
||||
frames = _filter_frames(
|
||||
_project_frames(project_id, db),
|
||||
scope=scope,
|
||||
start_frame=start_frame,
|
||||
end_frame=end_frame,
|
||||
frame_id=frame_id,
|
||||
)
|
||||
annotations = _filtered_annotations(project_id, {frame.id for frame in frames}, db)
|
||||
templates = db.query(Template).all()
|
||||
coco = _build_coco(project, frames, annotations, templates)
|
||||
class_values, class_mapping = _build_gt_class_mapping(annotations)
|
||||
selected_outputs = _parse_result_outputs(mask_type, outputs)
|
||||
|
||||
zip_buffer = io.BytesIO()
|
||||
with zipfile.ZipFile(zip_buffer, "w", zipfile.ZIP_DEFLATED) as zf:
|
||||
zf.writestr(
|
||||
"annotations_coco.json",
|
||||
json.dumps(coco, ensure_ascii=False, indent=2).encode("utf-8"),
|
||||
)
|
||||
zf.writestr(
|
||||
"maskid_GT像素值_类别映射.json",
|
||||
json.dumps({"classes": class_mapping}, ensure_ascii=False, indent=2).encode("utf-8"),
|
||||
)
|
||||
original_images = _write_original_frames(zf, project, frames)
|
||||
_write_result_mask_outputs(
|
||||
zf,
|
||||
project,
|
||||
frames,
|
||||
annotations,
|
||||
outputs=selected_outputs,
|
||||
class_values=class_values,
|
||||
class_mapping=class_mapping,
|
||||
original_images=original_images,
|
||||
mix_opacity=mix_opacity,
|
||||
)
|
||||
|
||||
zip_buffer.seek(0)
|
||||
filename = _segmentation_results_filename(project, frames)
|
||||
return StreamingResponse(
|
||||
zip_buffer,
|
||||
media_type="application/zip",
|
||||
headers={"Content-Disposition": _download_content_disposition(filename)},
|
||||
)
|
||||
|
||||
@@ -9,8 +9,9 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from database import get_db
|
||||
from minio_client import upload_file, get_presigned_url
|
||||
from models import ProcessingTask, Project
|
||||
from models import ProcessingTask, Project, User
|
||||
from progress_events import publish_task_progress_event
|
||||
from routers.auth import require_editor
|
||||
from schemas import ProcessingTaskOut
|
||||
from statuses import PROJECT_STATUS_PARSING, PROJECT_STATUS_PENDING, TASK_STATUS_QUEUED
|
||||
from worker_tasks import parse_project_media
|
||||
@@ -34,6 +35,7 @@ async def upload_media(
|
||||
file: UploadFile = File(...),
|
||||
project_id: Optional[int] = Form(None),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_editor),
|
||||
) -> dict:
|
||||
"""Accept a video, image, or DICOM file and store it in MinIO.
|
||||
|
||||
@@ -62,13 +64,15 @@ async def upload_media(
|
||||
file_url = get_presigned_url(object_name, expires=3600)
|
||||
|
||||
if project_id:
|
||||
project = db.query(Project).filter(Project.id == project_id).first()
|
||||
if project:
|
||||
project.video_path = object_name
|
||||
db.commit()
|
||||
logger.info("Linked upload to project_id=%s", project_id)
|
||||
else:
|
||||
logger.warning("Project id=%s not found for upload linkage", project_id)
|
||||
project = db.query(Project).filter(
|
||||
Project.id == project_id,
|
||||
Project.owner_user_id == current_user.id,
|
||||
).first()
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
project.video_path = object_name
|
||||
db.commit()
|
||||
logger.info("Linked upload to project_id=%s", project_id)
|
||||
else:
|
||||
# Auto-create a project named after the file
|
||||
project = Project(
|
||||
@@ -77,6 +81,7 @@ async def upload_media(
|
||||
status=PROJECT_STATUS_PENDING,
|
||||
video_path=object_name,
|
||||
source_type="video",
|
||||
owner_user_id=current_user.id,
|
||||
)
|
||||
db.add(project)
|
||||
db.commit()
|
||||
@@ -109,6 +114,7 @@ async def upload_dicom_batch(
|
||||
files: List[UploadFile] = File(...),
|
||||
project_id: Optional[int] = Form(None),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_editor),
|
||||
) -> dict:
|
||||
"""Upload multiple .dcm files for a DICOM series.
|
||||
|
||||
@@ -121,7 +127,10 @@ async def upload_dicom_batch(
|
||||
uploaded = []
|
||||
|
||||
if project_id:
|
||||
project = db.query(Project).filter(Project.id == project_id).first()
|
||||
project = db.query(Project).filter(
|
||||
Project.id == project_id,
|
||||
Project.owner_user_id == current_user.id,
|
||||
).first()
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
else:
|
||||
@@ -132,6 +141,7 @@ async def upload_dicom_batch(
|
||||
description=f"DICOM series with {len(files)} files",
|
||||
status=PROJECT_STATUS_PENDING,
|
||||
source_type="dicom",
|
||||
owner_user_id=current_user.id,
|
||||
)
|
||||
db.add(project)
|
||||
db.commit()
|
||||
@@ -173,13 +183,17 @@ def parse_media(
|
||||
max_frames: Optional[int] = Query(None, gt=0),
|
||||
target_width: int = Query(640, ge=64, le=4096),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_editor),
|
||||
) -> ProcessingTask:
|
||||
"""Create a background task for media frame extraction.
|
||||
|
||||
The Celery worker performs the heavy FFmpeg/OpenCV/pydicom work and
|
||||
updates the persisted task record as it progresses.
|
||||
"""
|
||||
project = db.query(Project).filter(Project.id == project_id).first()
|
||||
project = db.query(Project).filter(
|
||||
Project.id == project_id,
|
||||
Project.owner_user_id == current_user.id,
|
||||
).first()
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
|
||||
|
||||
@@ -7,7 +7,8 @@ from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from database import get_db
|
||||
from models import Project, Frame
|
||||
from models import Project, Frame, User
|
||||
from routers.auth import get_current_user, require_editor
|
||||
from schemas import ProjectCreate, ProjectOut, ProjectUpdate, FrameCreate, FrameOut
|
||||
from minio_client import get_presigned_url
|
||||
|
||||
@@ -24,9 +25,13 @@ router = APIRouter(prefix="/api/projects", tags=["Projects"])
|
||||
status_code=status.HTTP_201_CREATED,
|
||||
summary="Create a new project",
|
||||
)
|
||||
def create_project(payload: ProjectCreate, db: Session = Depends(get_db)) -> Project:
|
||||
def create_project(
|
||||
payload: ProjectCreate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_editor),
|
||||
) -> Project:
|
||||
"""Create a new segmentation project."""
|
||||
project = Project(**payload.model_dump())
|
||||
project = Project(**payload.model_dump(), owner_user_id=current_user.id)
|
||||
db.add(project)
|
||||
db.commit()
|
||||
db.refresh(project)
|
||||
@@ -39,9 +44,20 @@ def create_project(payload: ProjectCreate, db: Session = Depends(get_db)) -> Pro
|
||||
response_model=List[ProjectOut],
|
||||
summary="List all projects",
|
||||
)
|
||||
def list_projects(skip: int = 0, limit: int = 100, db: Session = Depends(get_db)) -> List[Project]:
|
||||
def list_projects(
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> List[Project]:
|
||||
"""Retrieve a paginated list of projects."""
|
||||
projects = db.query(Project).offset(skip).limit(limit).all()
|
||||
projects = (
|
||||
db.query(Project)
|
||||
.filter(Project.owner_user_id == current_user.id)
|
||||
.offset(skip)
|
||||
.limit(limit)
|
||||
.all()
|
||||
)
|
||||
for p in projects:
|
||||
p.frame_count = len(p.frames)
|
||||
if p.thumbnail_url:
|
||||
@@ -54,9 +70,16 @@ def list_projects(skip: int = 0, limit: int = 100, db: Session = Depends(get_db)
|
||||
response_model=ProjectOut,
|
||||
summary="Get a single project",
|
||||
)
|
||||
def get_project(project_id: int, db: Session = Depends(get_db)) -> Project:
|
||||
def get_project(
|
||||
project_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> Project:
|
||||
"""Retrieve a project by its ID."""
|
||||
project = db.query(Project).filter(Project.id == project_id).first()
|
||||
project = db.query(Project).filter(
|
||||
Project.id == project_id,
|
||||
Project.owner_user_id == current_user.id,
|
||||
).first()
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
project.frame_count = len(project.frames)
|
||||
@@ -74,9 +97,13 @@ def update_project(
|
||||
project_id: int,
|
||||
payload: ProjectUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_editor),
|
||||
) -> Project:
|
||||
"""Update project fields partially."""
|
||||
project = db.query(Project).filter(Project.id == project_id).first()
|
||||
project = db.query(Project).filter(
|
||||
Project.id == project_id,
|
||||
Project.owner_user_id == current_user.id,
|
||||
).first()
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
|
||||
@@ -94,9 +121,16 @@ def update_project(
|
||||
status_code=status.HTTP_204_NO_CONTENT,
|
||||
summary="Delete a project",
|
||||
)
|
||||
def delete_project(project_id: int, db: Session = Depends(get_db)) -> None:
|
||||
def delete_project(
|
||||
project_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_editor),
|
||||
) -> None:
|
||||
"""Delete a project and all related frames and annotations."""
|
||||
project = db.query(Project).filter(Project.id == project_id).first()
|
||||
project = db.query(Project).filter(
|
||||
Project.id == project_id,
|
||||
Project.owner_user_id == current_user.id,
|
||||
).first()
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
|
||||
@@ -118,9 +152,13 @@ def create_frame(
|
||||
project_id: int,
|
||||
payload: FrameCreate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_editor),
|
||||
) -> Frame:
|
||||
"""Register a new frame under a project."""
|
||||
project = db.query(Project).filter(Project.id == project_id).first()
|
||||
project = db.query(Project).filter(
|
||||
Project.id == project_id,
|
||||
Project.owner_user_id == current_user.id,
|
||||
).first()
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
|
||||
@@ -141,9 +179,13 @@ def list_frames(
|
||||
skip: int = 0,
|
||||
limit: int = 1000,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> List[Frame]:
|
||||
"""Retrieve all frames belonging to a project."""
|
||||
project = db.query(Project).filter(Project.id == project_id).first()
|
||||
project = db.query(Project).filter(
|
||||
Project.id == project_id,
|
||||
Project.owner_user_id == current_user.id,
|
||||
).first()
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
|
||||
@@ -165,11 +207,21 @@ def list_frames(
|
||||
response_model=FrameOut,
|
||||
summary="Get a single frame",
|
||||
)
|
||||
def get_frame(project_id: int, frame_id: int, db: Session = Depends(get_db)) -> Frame:
|
||||
def get_frame(
|
||||
project_id: int,
|
||||
frame_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> Frame:
|
||||
"""Retrieve a specific frame by ID."""
|
||||
frame = (
|
||||
db.query(Frame)
|
||||
.filter(Frame.project_id == project_id, Frame.id == frame_id)
|
||||
.join(Project, Project.id == Frame.project_id)
|
||||
.filter(
|
||||
Frame.project_id == project_id,
|
||||
Frame.id == frame_id,
|
||||
Project.owner_user_id == current_user.id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if not frame:
|
||||
|
||||
@@ -9,8 +9,9 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from celery_app import celery_app
|
||||
from database import get_db
|
||||
from models import ProcessingTask, Project
|
||||
from models import ProcessingTask, Project, User
|
||||
from progress_events import publish_task_progress_event
|
||||
from routers.auth import get_current_user, require_editor
|
||||
from schemas import ProcessingTaskOut
|
||||
from statuses import (
|
||||
PROJECT_STATUS_PARSING,
|
||||
@@ -31,8 +32,16 @@ def _now() -> datetime:
|
||||
return datetime.now(timezone.utc)
|
||||
|
||||
|
||||
def _get_task_or_404(task_id: int, db: Session) -> ProcessingTask:
|
||||
task = db.query(ProcessingTask).filter(ProcessingTask.id == task_id).first()
|
||||
def _get_task_or_404(task_id: int, db: Session, current_user: User) -> ProcessingTask:
|
||||
task = (
|
||||
db.query(ProcessingTask)
|
||||
.outerjoin(Project, Project.id == ProcessingTask.project_id)
|
||||
.filter(
|
||||
ProcessingTask.id == task_id,
|
||||
(ProcessingTask.project_id.is_(None)) | (Project.owner_user_id == current_user.id),
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if not task:
|
||||
raise HTTPException(status_code=404, detail="Task not found")
|
||||
return task
|
||||
@@ -48,9 +57,12 @@ def list_tasks(
|
||||
status: str | None = None,
|
||||
limit: int = 50,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> List[ProcessingTask]:
|
||||
"""Return recent background processing tasks."""
|
||||
query = db.query(ProcessingTask)
|
||||
query = db.query(ProcessingTask).outerjoin(Project, Project.id == ProcessingTask.project_id).filter(
|
||||
(ProcessingTask.project_id.is_(None)) | (Project.owner_user_id == current_user.id)
|
||||
)
|
||||
if project_id is not None:
|
||||
query = query.filter(ProcessingTask.project_id == project_id)
|
||||
if status is not None:
|
||||
@@ -59,15 +71,23 @@ def list_tasks(
|
||||
|
||||
|
||||
@router.get("/{task_id}", response_model=ProcessingTaskOut, summary="Get processing task")
|
||||
def get_task(task_id: int, db: Session = Depends(get_db)) -> ProcessingTask:
|
||||
def get_task(
|
||||
task_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> ProcessingTask:
|
||||
"""Return one background task by id."""
|
||||
return _get_task_or_404(task_id, db)
|
||||
return _get_task_or_404(task_id, db, current_user)
|
||||
|
||||
|
||||
@router.post("/{task_id}/cancel", response_model=ProcessingTaskOut, summary="Cancel processing task")
|
||||
def cancel_task(task_id: int, db: Session = Depends(get_db)) -> ProcessingTask:
|
||||
def cancel_task(
|
||||
task_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_editor),
|
||||
) -> ProcessingTask:
|
||||
"""Cancel a queued/running background task and revoke the Celery job when possible."""
|
||||
task = _get_task_or_404(task_id, db)
|
||||
task = _get_task_or_404(task_id, db, current_user)
|
||||
if task.status not in TASK_ACTIVE_STATUSES:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_409_CONFLICT,
|
||||
@@ -95,9 +115,13 @@ def cancel_task(task_id: int, db: Session = Depends(get_db)) -> ProcessingTask:
|
||||
|
||||
|
||||
@router.post("/{task_id}/retry", response_model=ProcessingTaskOut, status_code=status.HTTP_202_ACCEPTED, summary="Retry processing task")
|
||||
def retry_task(task_id: int, db: Session = Depends(get_db)) -> ProcessingTask:
|
||||
def retry_task(
|
||||
task_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_editor),
|
||||
) -> ProcessingTask:
|
||||
"""Create a fresh queued task from a failed or cancelled task."""
|
||||
previous = _get_task_or_404(task_id, db)
|
||||
previous = _get_task_or_404(task_id, db, current_user)
|
||||
if previous.status not in {TASK_STATUS_FAILED, TASK_STATUS_CANCELLED}:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_409_CONFLICT,
|
||||
@@ -106,7 +130,10 @@ def retry_task(task_id: int, db: Session = Depends(get_db)) -> ProcessingTask:
|
||||
if previous.project_id is None:
|
||||
raise HTTPException(status_code=400, detail="Task has no project_id")
|
||||
|
||||
project = db.query(Project).filter(Project.id == previous.project_id).first()
|
||||
project = db.query(Project).filter(
|
||||
Project.id == previous.project_id,
|
||||
Project.owner_user_id == current_user.id,
|
||||
).first()
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
is_propagation_task = previous.task_type == "propagate_masks"
|
||||
|
||||
@@ -4,10 +4,12 @@ import logging
|
||||
from typing import List
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from sqlalchemy import or_
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from database import get_db
|
||||
from models import Template
|
||||
from models import Template, User
|
||||
from routers.auth import get_current_user, require_editor
|
||||
from schemas import TemplateCreate, TemplateOut, TemplateUpdate
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -40,11 +42,15 @@ def _unpack_template(template: Template) -> Template:
|
||||
status_code=status.HTTP_201_CREATED,
|
||||
summary="Create a new template",
|
||||
)
|
||||
def create_template(payload: TemplateCreate, db: Session = Depends(get_db)) -> Template:
|
||||
def create_template(
|
||||
payload: TemplateCreate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_editor),
|
||||
) -> Template:
|
||||
"""Create a new ontology template / segmentation class."""
|
||||
data = payload.model_dump()
|
||||
data = _pack_mapping_rules(data)
|
||||
template = Template(**data)
|
||||
template = Template(**data, owner_user_id=current_user.id)
|
||||
db.add(template)
|
||||
db.commit()
|
||||
db.refresh(template)
|
||||
@@ -62,9 +68,16 @@ def list_templates(
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> List[Template]:
|
||||
"""Retrieve all ontology templates."""
|
||||
templates = db.query(Template).offset(skip).limit(limit).all()
|
||||
templates = (
|
||||
db.query(Template)
|
||||
.filter(or_(Template.owner_user_id == current_user.id, Template.owner_user_id.is_(None)))
|
||||
.offset(skip)
|
||||
.limit(limit)
|
||||
.all()
|
||||
)
|
||||
for t in templates:
|
||||
_unpack_template(t)
|
||||
return templates
|
||||
@@ -75,9 +88,16 @@ def list_templates(
|
||||
response_model=TemplateOut,
|
||||
summary="Get a single template",
|
||||
)
|
||||
def get_template(template_id: int, db: Session = Depends(get_db)) -> Template:
|
||||
def get_template(
|
||||
template_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> Template:
|
||||
"""Retrieve a template by its ID."""
|
||||
template = db.query(Template).filter(Template.id == template_id).first()
|
||||
template = db.query(Template).filter(
|
||||
Template.id == template_id,
|
||||
or_(Template.owner_user_id == current_user.id, Template.owner_user_id.is_(None)),
|
||||
).first()
|
||||
if not template:
|
||||
raise HTTPException(status_code=404, detail="Template not found")
|
||||
_unpack_template(template)
|
||||
@@ -93,9 +113,13 @@ def update_template(
|
||||
template_id: int,
|
||||
payload: TemplateUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_editor),
|
||||
) -> Template:
|
||||
"""Update template fields partially."""
|
||||
template = db.query(Template).filter(Template.id == template_id).first()
|
||||
template = db.query(Template).filter(
|
||||
Template.id == template_id,
|
||||
or_(Template.owner_user_id == current_user.id, Template.owner_user_id.is_(None)),
|
||||
).first()
|
||||
if not template:
|
||||
raise HTTPException(status_code=404, detail="Template not found")
|
||||
|
||||
@@ -118,9 +142,16 @@ def update_template(
|
||||
status_code=status.HTTP_204_NO_CONTENT,
|
||||
summary="Delete a template",
|
||||
)
|
||||
def delete_template(template_id: int, db: Session = Depends(get_db)) -> None:
|
||||
def delete_template(
|
||||
template_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_editor),
|
||||
) -> None:
|
||||
"""Delete a template. Associated annotations will have template_id set to NULL."""
|
||||
template = db.query(Template).filter(Template.id == template_id).first()
|
||||
template = db.query(Template).filter(
|
||||
Template.id == template_id,
|
||||
or_(Template.owner_user_id == current_user.id, Template.owner_user_id.is_(None)),
|
||||
).first()
|
||||
if not template:
|
||||
raise HTTPException(status_code=404, detail="Template not found")
|
||||
|
||||
|
||||
@@ -5,6 +5,55 @@ from typing import Optional, Any
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Auth / user schemas
|
||||
# ---------------------------------------------------------------------------
|
||||
class UserOut(BaseModel):
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
id: int
|
||||
username: str
|
||||
role: str
|
||||
is_active: int
|
||||
|
||||
|
||||
class LoginResponse(BaseModel):
|
||||
token: str
|
||||
token_type: str = "bearer"
|
||||
username: str
|
||||
user: UserOut
|
||||
|
||||
|
||||
class AdminUserCreate(BaseModel):
|
||||
username: str
|
||||
password: str
|
||||
role: str = "annotator"
|
||||
is_active: bool = True
|
||||
|
||||
|
||||
class AdminUserUpdate(BaseModel):
|
||||
username: Optional[str] = None
|
||||
password: Optional[str] = None
|
||||
role: Optional[str] = None
|
||||
is_active: Optional[bool] = None
|
||||
|
||||
|
||||
class AuditLogOut(BaseModel):
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
id: int
|
||||
actor_user_id: Optional[int] = None
|
||||
action: str
|
||||
target_type: Optional[str] = None
|
||||
target_id: Optional[str] = None
|
||||
detail: Optional[dict[str, Any]] = None
|
||||
created_at: datetime
|
||||
|
||||
|
||||
class DemoFactoryResetRequest(BaseModel):
|
||||
confirmation: str
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Project schemas
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -38,11 +87,19 @@ class ProjectOut(ProjectBase):
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
id: int
|
||||
owner_user_id: Optional[int] = None
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
frame_count: int = 0
|
||||
|
||||
|
||||
class DemoFactoryResetOut(BaseModel):
|
||||
admin_user: UserOut
|
||||
project: ProjectOut
|
||||
deleted_counts: dict[str, int]
|
||||
message: str
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Frame schemas
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -98,6 +155,7 @@ class TemplateOut(TemplateBase):
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
id: int
|
||||
owner_user_id: Optional[int] = None
|
||||
created_at: datetime
|
||||
|
||||
|
||||
|
||||
@@ -102,8 +102,14 @@ def _normalize_smoothing_options(value: Any) -> dict[str, Any] | None:
|
||||
return {"strength": round(strength, 2), "method": method}
|
||||
|
||||
|
||||
def _chaikin_smooth_polygon(polygon: list[list[float]], iterations: int) -> list[list[float]]:
|
||||
def _smoothing_ratio(strength: float, curve: float = 1.65) -> float:
|
||||
normalized = max(0.0, min(float(strength or 0.0), 100.0)) / 100.0
|
||||
return normalized ** curve
|
||||
|
||||
|
||||
def _chaikin_smooth_polygon(polygon: list[list[float]], iterations: int, corner_cut: float = 0.25) -> list[list[float]]:
|
||||
points = _normalize_polygon(polygon)
|
||||
q = max(0.02, min(float(corner_cut), 0.25))
|
||||
for _ in range(max(0, iterations)):
|
||||
if len(points) < 3:
|
||||
break
|
||||
@@ -111,12 +117,12 @@ def _chaikin_smooth_polygon(polygon: list[list[float]], iterations: int) -> list
|
||||
for index, current in enumerate(points):
|
||||
following = points[(index + 1) % len(points)]
|
||||
next_points.append([
|
||||
_clamp01(0.75 * current[0] + 0.25 * following[0]),
|
||||
_clamp01(0.75 * current[1] + 0.25 * following[1]),
|
||||
_clamp01((1.0 - q) * current[0] + q * following[0]),
|
||||
_clamp01((1.0 - q) * current[1] + q * following[1]),
|
||||
])
|
||||
next_points.append([
|
||||
_clamp01(0.25 * current[0] + 0.75 * following[0]),
|
||||
_clamp01(0.25 * current[1] + 0.75 * following[1]),
|
||||
_clamp01(q * current[0] + (1.0 - q) * following[0]),
|
||||
_clamp01(q * current[1] + (1.0 - q) * following[1]),
|
||||
])
|
||||
points = next_points
|
||||
return points
|
||||
@@ -127,7 +133,7 @@ def _simplify_polygon(polygon: list[list[float]], strength: float) -> list[list[
|
||||
return polygon
|
||||
contour = np.array([[[point[0], point[1]]] for point in polygon], dtype=np.float32)
|
||||
arc_length = cv2.arcLength(contour, True)
|
||||
epsilon = arc_length * (0.001 + (strength / 100.0) * 0.006)
|
||||
epsilon = arc_length * (0.00015 + _smoothing_ratio(strength) * 0.00735)
|
||||
approx = cv2.approxPolyDP(contour, epsilon, True).reshape(-1, 2)
|
||||
if len(approx) < 3:
|
||||
return polygon
|
||||
@@ -140,9 +146,25 @@ def _smooth_polygon(polygon: list[list[float]], smoothing: dict[str, Any] | None
|
||||
strength = float(smoothing.get("strength") or 0.0)
|
||||
if strength <= 0:
|
||||
return _normalize_polygon(polygon)
|
||||
iterations = max(1, min(3, int(strength // 35) + 1))
|
||||
smoothed = _chaikin_smooth_polygon(polygon, iterations)
|
||||
simplified = _simplify_polygon(smoothed, strength)
|
||||
effective_strength = _smoothing_ratio(strength, curve=1.45) * 100.0
|
||||
if effective_strength >= 85:
|
||||
iterations = 4
|
||||
elif effective_strength >= 55:
|
||||
iterations = 3
|
||||
elif effective_strength >= 25:
|
||||
iterations = 2
|
||||
else:
|
||||
iterations = 1
|
||||
corner_cut = 0.03 + _smoothing_ratio(strength, curve=1.35) * 0.22
|
||||
normalized = _normalize_polygon(polygon)
|
||||
pre_simplified = _simplify_polygon(normalized, effective_strength * 0.25)
|
||||
smoothed = _chaikin_smooth_polygon(pre_simplified, iterations, corner_cut)
|
||||
simplified = _simplify_polygon(smoothed, effective_strength)
|
||||
if len(simplified) > len(normalized):
|
||||
for fallback_strength in (25.0, 35.0, 50.0, 70.0, 90.0, 100.0):
|
||||
simplified = _simplify_polygon(simplified, max(effective_strength, fallback_strength))
|
||||
if len(simplified) <= len(normalized):
|
||||
break
|
||||
return simplified if len(simplified) >= 3 else _normalize_polygon(polygon)
|
||||
|
||||
|
||||
|
||||
@@ -19,7 +19,7 @@ if str(BACKEND_DIR) not in sys.path:
|
||||
|
||||
from database import Base, get_db # noqa: E402
|
||||
from main import websocket_progress # noqa: E402
|
||||
from routers import ai, auth, dashboard, export, media, projects, tasks, templates # noqa: E402
|
||||
from routers import admin, ai, auth, dashboard, export, media, projects, tasks, templates # noqa: E402
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
@@ -32,6 +32,7 @@ def db_session() -> Iterator[Session]:
|
||||
TestingSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||
Base.metadata.create_all(bind=engine)
|
||||
session = TestingSessionLocal()
|
||||
auth.ensure_default_admin(session)
|
||||
try:
|
||||
yield session
|
||||
finally:
|
||||
@@ -56,6 +57,7 @@ def app(db_session: Session) -> FastAPI:
|
||||
test_app.include_router(export.router)
|
||||
test_app.include_router(dashboard.router)
|
||||
test_app.include_router(tasks.router)
|
||||
test_app.include_router(admin.router)
|
||||
|
||||
@test_app.get("/health")
|
||||
def health_check() -> dict[str, str]:
|
||||
@@ -67,6 +69,10 @@ def app(db_session: Session) -> FastAPI:
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def client(app: FastAPI) -> Iterator[TestClient]:
|
||||
def client(app: FastAPI, db_session: Session) -> Iterator[TestClient]:
|
||||
with TestClient(app) as test_client:
|
||||
admin = auth.ensure_default_admin(db_session)
|
||||
test_client.headers.update({
|
||||
"Authorization": f"Bearer {auth.create_access_token(admin)}"
|
||||
})
|
||||
yield test_client
|
||||
|
||||
158
backend/tests/test_admin.py
Normal file
158
backend/tests/test_admin.py
Normal file
@@ -0,0 +1,158 @@
|
||||
from models import Annotation, AuditLog, Frame, Mask, ProcessingTask, Project, Template, User
|
||||
from routers.auth import create_access_token, hash_password
|
||||
from statuses import PROJECT_STATUS_PENDING
|
||||
|
||||
|
||||
def test_admin_user_management_and_audit_logs(client, db_session):
|
||||
created = client.post("/api/admin/users", json={
|
||||
"username": "doctor",
|
||||
"password": "secret123",
|
||||
"role": "annotator",
|
||||
"is_active": True,
|
||||
})
|
||||
assert created.status_code == 201
|
||||
user_id = created.json()["id"]
|
||||
|
||||
updated = client.patch(f"/api/admin/users/{user_id}", json={
|
||||
"role": "viewer",
|
||||
"password": "newsecret",
|
||||
"is_active": False,
|
||||
})
|
||||
assert updated.status_code == 200
|
||||
assert updated.json()["role"] == "viewer"
|
||||
assert updated.json()["is_active"] == 0
|
||||
|
||||
users = client.get("/api/admin/users")
|
||||
assert users.status_code == 200
|
||||
assert any(user["username"] == "doctor" for user in users.json())
|
||||
|
||||
deleted = client.delete(f"/api/admin/users/{user_id}")
|
||||
assert deleted.status_code == 204
|
||||
|
||||
logs = client.get("/api/admin/audit-logs")
|
||||
assert logs.status_code == 200
|
||||
actions = [log["action"] for log in logs.json()]
|
||||
assert "admin.user_created" in actions
|
||||
assert "admin.user_updated" in actions
|
||||
assert "admin.user_deleted" in actions
|
||||
|
||||
|
||||
def test_admin_routes_require_admin_role(client, db_session):
|
||||
user = User(username="viewer", password_hash=hash_password("secret123"), role="viewer", is_active=1)
|
||||
db_session.add(user)
|
||||
db_session.commit()
|
||||
db_session.refresh(user)
|
||||
original_auth = client.headers["Authorization"]
|
||||
client.headers.update({"Authorization": f"Bearer {create_access_token(user)}"})
|
||||
try:
|
||||
response = client.get("/api/admin/users")
|
||||
assert response.status_code == 403
|
||||
finally:
|
||||
client.headers.update({"Authorization": original_auth})
|
||||
|
||||
|
||||
def test_viewer_role_is_read_only_for_business_mutations(client, db_session):
|
||||
project = client.post("/api/projects", json={"name": "Readonly Check"}).json()
|
||||
user = User(username="readonly", password_hash=hash_password("secret123"), role="viewer", is_active=1)
|
||||
db_session.add(user)
|
||||
db_session.commit()
|
||||
db_session.refresh(user)
|
||||
original_auth = client.headers["Authorization"]
|
||||
client.headers.update({"Authorization": f"Bearer {create_access_token(user)}"})
|
||||
try:
|
||||
assert client.get("/api/projects").status_code == 200
|
||||
assert client.post("/api/projects", json={"name": "Nope"}).status_code == 403
|
||||
assert client.patch(f"/api/projects/{project['id']}", json={"name": "Nope"}).status_code == 403
|
||||
assert client.post("/api/ai/annotate", json={"project_id": project["id"]}).status_code == 403
|
||||
finally:
|
||||
client.headers.update({"Authorization": original_auth})
|
||||
|
||||
|
||||
def test_admin_cannot_delete_self_or_user_with_projects(client, db_session):
|
||||
me = client.get("/api/auth/me").json()
|
||||
assert client.delete(f"/api/admin/users/{me['id']}").status_code == 400
|
||||
|
||||
user = User(username="owner", password_hash=hash_password("secret123"), role="annotator", is_active=1)
|
||||
db_session.add(user)
|
||||
db_session.commit()
|
||||
db_session.refresh(user)
|
||||
db_session.add(Project(name="Owned", owner_user_id=user.id))
|
||||
db_session.commit()
|
||||
|
||||
response = client.delete(f"/api/admin/users/{user.id}")
|
||||
assert response.status_code == 409
|
||||
|
||||
|
||||
def test_demo_factory_reset_leaves_admin_and_unparsed_demo_video(client, db_session, monkeypatch, tmp_path):
|
||||
video_path = tmp_path / "Data_MyVideo_1.mp4"
|
||||
video_path.write_bytes(b"demo-video")
|
||||
monkeypatch.setattr("routers.admin.settings.demo_video_path", str(video_path))
|
||||
uploaded = []
|
||||
monkeypatch.setattr("routers.admin.upload_file", lambda object_name, data, content_type, length: uploaded.append({
|
||||
"object_name": object_name,
|
||||
"data": data,
|
||||
"content_type": content_type,
|
||||
"length": length,
|
||||
}))
|
||||
|
||||
extra_user = User(username="doctor", password_hash=hash_password("secret123"), role="annotator", is_active=1)
|
||||
db_session.add(extra_user)
|
||||
db_session.commit()
|
||||
db_session.refresh(extra_user)
|
||||
old_project = Project(name="Old", owner_user_id=extra_user.id, video_path="uploads/old.mp4")
|
||||
db_session.add(old_project)
|
||||
db_session.commit()
|
||||
db_session.refresh(old_project)
|
||||
frame = Frame(project_id=old_project.id, frame_index=0, image_url="frames/old.jpg")
|
||||
db_session.add(frame)
|
||||
task = ProcessingTask(task_type="parse_video", project_id=old_project.id)
|
||||
private_template = Template(
|
||||
name="Private",
|
||||
description="private",
|
||||
color="#fff",
|
||||
z_index=1,
|
||||
owner_user_id=extra_user.id,
|
||||
)
|
||||
db_session.add_all([task, private_template])
|
||||
db_session.commit()
|
||||
db_session.refresh(frame)
|
||||
annotation = Annotation(project_id=old_project.id, frame_id=frame.id, mask_data={"label": "old"})
|
||||
db_session.add(annotation)
|
||||
db_session.commit()
|
||||
db_session.refresh(annotation)
|
||||
db_session.add(Mask(annotation_id=annotation.id, mask_url="masks/old.png"))
|
||||
db_session.add(AuditLog(actor_user_id=extra_user.id, action="old.audit"))
|
||||
db_session.commit()
|
||||
|
||||
response = client.post("/api/admin/demo-factory-reset", json={"confirmation": "RESET_DEMO_FACTORY"})
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["message"] == "演示环境已恢复出厂设置"
|
||||
assert data["admin_user"]["username"] == "admin"
|
||||
assert data["project"]["name"] == "Data_MyVideo_1"
|
||||
assert data["project"]["status"] == PROJECT_STATUS_PENDING
|
||||
assert data["project"]["frame_count"] == 0
|
||||
assert data["project"]["video_path"] == f"uploads/{data['project']['id']}/Data_MyVideo_1.mp4"
|
||||
assert uploaded == [{
|
||||
"object_name": data["project"]["video_path"],
|
||||
"data": b"demo-video",
|
||||
"content_type": "video/mp4",
|
||||
"length": len(b"demo-video"),
|
||||
}]
|
||||
|
||||
assert [user.username for user in db_session.query(User).all()] == ["admin"]
|
||||
assert db_session.query(Project).count() == 1
|
||||
assert db_session.query(Frame).count() == 0
|
||||
assert db_session.query(Annotation).count() == 0
|
||||
assert db_session.query(Mask).count() == 0
|
||||
assert db_session.query(ProcessingTask).count() == 0
|
||||
assert db_session.query(Template).filter(Template.owner_user_id.is_not(None)).count() == 0
|
||||
assert db_session.query(AuditLog).count() == 1
|
||||
assert db_session.query(AuditLog).first().action == "admin.demo_factory_reset"
|
||||
|
||||
|
||||
def test_demo_factory_reset_requires_exact_confirmation(client):
|
||||
response = client.post("/api/admin/demo-factory-reset", json={"confirmation": "reset"})
|
||||
|
||||
assert response.status_code == 400
|
||||
@@ -223,6 +223,88 @@ def test_analyze_mask_returns_backend_geometry_properties(client):
|
||||
assert body["message"] == "已从后端重新提取几何拓扑锚点"
|
||||
|
||||
|
||||
def test_analyze_mask_reports_actual_polygon_anchor_count(client):
|
||||
_, frame, _ = _create_project_and_frame(client)
|
||||
polygon = [[0.1 + index * 0.005, 0.1 + (0.01 if index % 2 else 0)] for index in range(80)]
|
||||
|
||||
response = client.post("/api/ai/analyze-mask", json={
|
||||
"frame_id": frame["id"],
|
||||
"mask_data": {
|
||||
"polygons": [polygon],
|
||||
"label": "AI Mask",
|
||||
"color": "#06b6d4",
|
||||
},
|
||||
"points": [[0.2, 0.2]],
|
||||
})
|
||||
|
||||
assert response.status_code == 200
|
||||
body = response.json()
|
||||
assert body["topology_anchor_count"] == len(polygon)
|
||||
assert len(body["topology_anchors"]) <= 64
|
||||
|
||||
|
||||
def test_smooth_mask_simplifies_noisy_ai_polygon(client):
|
||||
_, frame, _ = _create_project_and_frame(client)
|
||||
polygon = []
|
||||
for index in range(20):
|
||||
polygon.append([0.1 + index * 0.02, 0.1 + (0.01 if index % 2 else 0)])
|
||||
for index in range(20):
|
||||
polygon.append([0.5 + (0.01 if index % 2 else 0), 0.1 + index * 0.02])
|
||||
for index in range(20):
|
||||
polygon.append([0.5 - index * 0.02, 0.5 + (0.01 if index % 2 else 0)])
|
||||
for index in range(20):
|
||||
polygon.append([0.1 + (0.01 if index % 2 else 0), 0.5 - index * 0.02])
|
||||
|
||||
response = client.post("/api/ai/smooth-mask", json={
|
||||
"frame_id": frame["id"],
|
||||
"mask_data": {
|
||||
"polygons": [polygon],
|
||||
"label": "AI Mask",
|
||||
"color": "#06b6d4",
|
||||
},
|
||||
"strength": 80,
|
||||
})
|
||||
|
||||
assert response.status_code == 200
|
||||
body = response.json()
|
||||
assert body["topology_anchor_count"] == len(body["polygons"][0])
|
||||
assert len(body["polygons"][0]) < len(polygon)
|
||||
|
||||
|
||||
def test_smooth_mask_uses_eased_strength_curve(client):
|
||||
_, frame, _ = _create_project_and_frame(client)
|
||||
polygon = []
|
||||
for index in range(20):
|
||||
polygon.append([0.1 + index * 0.02, 0.1 + (0.01 if index % 2 else 0)])
|
||||
for index in range(20):
|
||||
polygon.append([0.5 + (0.01 if index % 2 else 0), 0.1 + index * 0.02])
|
||||
for index in range(20):
|
||||
polygon.append([0.5 - index * 0.02, 0.5 + (0.01 if index % 2 else 0)])
|
||||
for index in range(20):
|
||||
polygon.append([0.1 + (0.01 if index % 2 else 0), 0.5 - index * 0.02])
|
||||
|
||||
def smoothed_count(strength: int) -> int:
|
||||
response = client.post("/api/ai/smooth-mask", json={
|
||||
"frame_id": frame["id"],
|
||||
"mask_data": {
|
||||
"polygons": [polygon],
|
||||
"label": "AI Mask",
|
||||
"color": "#06b6d4",
|
||||
},
|
||||
"strength": strength,
|
||||
})
|
||||
assert response.status_code == 200
|
||||
return len(response.json()["polygons"][0])
|
||||
|
||||
low_count = smoothed_count(20)
|
||||
mid_count = smoothed_count(70)
|
||||
high_count = smoothed_count(95)
|
||||
|
||||
assert low_count <= len(polygon)
|
||||
assert mid_count < low_count
|
||||
assert high_count < mid_count
|
||||
|
||||
|
||||
def test_smooth_mask_returns_backend_smoothed_geometry(client):
|
||||
_, frame, _ = _create_project_and_frame(client)
|
||||
|
||||
@@ -311,6 +393,7 @@ def test_propagate_saves_tracked_annotations(client, monkeypatch):
|
||||
"color": "#ff0000",
|
||||
"class_metadata": {"id": "c1", "name": "胆囊", "color": "#ff0000", "zIndex": 20},
|
||||
"template_id": None,
|
||||
"smoothing": {"strength": 45, "method": "chaikin"},
|
||||
},
|
||||
})
|
||||
|
||||
@@ -327,6 +410,9 @@ def test_propagate_saves_tracked_annotations(client, monkeypatch):
|
||||
assert saved["mask_data"]["source"] == "sam2.1_hiera_tiny_propagation"
|
||||
assert saved["mask_data"]["class"]["name"] == "胆囊"
|
||||
assert saved["mask_data"]["score"] == 0.8
|
||||
assert saved["mask_data"]["geometry_smoothing"] == {"strength": 45.0, "method": "chaikin"}
|
||||
assert saved["mask_data"]["polygons"][0] != [[0.15, 0.15], [0.25, 0.15], [0.25, 0.25]]
|
||||
assert len(saved["mask_data"]["polygons"][0]) > 3
|
||||
|
||||
listing = client.get(f"/api/ai/annotations?project_id={project['id']}")
|
||||
assert len(listing.json()) == 1
|
||||
@@ -490,8 +576,10 @@ def test_propagation_task_runner_saves_annotations_and_progress(client, db_sessi
|
||||
listing = client.get(f"/api/ai/annotations?project_id={project['id']}")
|
||||
assert listing.json()[0]["frame_id"] == frames[1]["id"]
|
||||
assert listing.json()[0]["mask_data"]["source"] == "sam2.1_hiera_tiny_propagation"
|
||||
stored_polygon = listing.json()[0]["mask_data"]["polygons"][0]
|
||||
assert listing.json()[0]["mask_data"]["geometry_smoothing"] == {"strength": 40.0, "method": "chaikin"}
|
||||
assert len(listing.json()[0]["mask_data"]["polygons"][0]) > 3
|
||||
assert stored_polygon != [[0.15, 0.15], [0.25, 0.15], [0.25, 0.25]]
|
||||
assert len(stored_polygon) > 3
|
||||
|
||||
|
||||
def test_propagation_task_runner_skips_unchanged_seed_and_replaces_changed_seed(client, db_session, monkeypatch):
|
||||
@@ -1084,3 +1172,156 @@ def test_import_gt_mask_splits_label_values(client):
|
||||
assert [item["mask_data"]["gt_label_value"] for item in body] == [1, 2]
|
||||
assert [item["mask_data"]["label"] for item in body] == ["GT Class 1", "GT Class 2"]
|
||||
assert all(len(item["points"]) == 1 for item in body)
|
||||
|
||||
|
||||
def test_import_gt_mask_preserves_low_value_gtlabel_png(client):
|
||||
project, frame, _ = _create_project_and_frame(client)
|
||||
template = client.post("/api/templates", json={
|
||||
"name": "GTLabel Template",
|
||||
"color": "#06b6d4",
|
||||
"z_index": 0,
|
||||
"classes": [
|
||||
{"id": "tumor", "name": "肿瘤", "color": "#ff0000", "zIndex": 10, "maskId": 1},
|
||||
],
|
||||
"rules": [],
|
||||
}).json()
|
||||
mask = np.zeros((360, 640), dtype=np.uint16)
|
||||
cv2.rectangle(mask, (40, 40), (140, 140), 1, thickness=-1)
|
||||
ok, encoded = cv2.imencode(".png", mask)
|
||||
assert ok
|
||||
|
||||
response = client.post(
|
||||
"/api/ai/import-gt-mask",
|
||||
data={
|
||||
"project_id": str(project["id"]),
|
||||
"frame_id": str(frame["id"]),
|
||||
"template_id": str(template["id"]),
|
||||
"unknown_color_policy": "discard",
|
||||
},
|
||||
files={"file": ("GT_label.png", encoded.tobytes(), "image/png")},
|
||||
)
|
||||
|
||||
assert response.status_code == 201
|
||||
body = response.json()
|
||||
assert len(body) == 1
|
||||
assert body[0]["mask_data"]["gt_label_value"] == 1
|
||||
assert body[0]["mask_data"]["class"]["name"] == "肿瘤"
|
||||
assert body[0]["mask_data"]["class"]["maskId"] == 1
|
||||
|
||||
|
||||
def test_import_gt_mask_rejects_rgb_color_masks(client):
|
||||
project, frame, _ = _create_project_and_frame(client)
|
||||
template = client.post("/api/templates", json={
|
||||
"name": "Color Template",
|
||||
"color": "#06b6d4",
|
||||
"z_index": 0,
|
||||
"classes": [
|
||||
{"id": "known", "name": "已知类别", "color": "#ff0000", "zIndex": 10, "maskId": 1},
|
||||
],
|
||||
"rules": [],
|
||||
}).json()
|
||||
mask = np.zeros((80, 120, 3), dtype=np.uint8)
|
||||
mask[10:40, 10:40] = [0, 0, 255] # BGR red -> #ff0000
|
||||
mask[40:70, 70:110] = [0, 255, 0] # BGR green -> unknown #00ff00
|
||||
ok, encoded = cv2.imencode(".png", mask)
|
||||
assert ok
|
||||
|
||||
response = client.post(
|
||||
"/api/ai/import-gt-mask",
|
||||
data={
|
||||
"project_id": str(project["id"]),
|
||||
"frame_id": str(frame["id"]),
|
||||
"template_id": str(template["id"]),
|
||||
"unknown_color_policy": "discard",
|
||||
},
|
||||
files={"file": ("color-mask.png", encoded.tobytes(), "image/png")},
|
||||
)
|
||||
|
||||
assert response.status_code == 400
|
||||
assert "RGB 三通道完全相同" in response.json()["detail"]
|
||||
|
||||
|
||||
def test_import_gt_mask_reads_uint16_gt_label_and_maps_maskid_class(client):
|
||||
project, frame, _ = _create_project_and_frame(client)
|
||||
template = client.post("/api/templates", json={
|
||||
"name": "Label Template",
|
||||
"color": "#06b6d4",
|
||||
"z_index": 0,
|
||||
"classes": [{"id": "tumor", "name": "肿瘤", "color": "#ff0000", "zIndex": 10, "maskId": 1}],
|
||||
"rules": [],
|
||||
}).json()
|
||||
mask = np.zeros((360, 640), dtype=np.uint16)
|
||||
cv2.rectangle(mask, (20, 20), (120, 120), 1, thickness=-1)
|
||||
ok, encoded = cv2.imencode(".png", mask)
|
||||
assert ok
|
||||
|
||||
response = client.post(
|
||||
"/api/ai/import-gt-mask",
|
||||
data={
|
||||
"project_id": str(project["id"]),
|
||||
"frame_id": str(frame["id"]),
|
||||
"template_id": str(template["id"]),
|
||||
"unknown_color_policy": "discard",
|
||||
},
|
||||
files={"file": ("gt_label.png", encoded.tobytes(), "image/png")},
|
||||
)
|
||||
|
||||
assert response.status_code == 201
|
||||
body = response.json()
|
||||
assert len(body) == 1
|
||||
assert body[0]["mask_data"]["gt_label_value"] == 1
|
||||
assert body[0]["mask_data"]["label"] == "肿瘤"
|
||||
assert body[0]["mask_data"]["class"]["maskId"] == 1
|
||||
assert body[0]["mask_data"]["class"]["color"] == "#ff0000"
|
||||
|
||||
|
||||
def test_import_gt_mask_handles_unknown_maskid_policy_and_resizes_to_frame(client):
|
||||
project, frame, _ = _create_project_and_frame(client)
|
||||
template = client.post("/api/templates", json={
|
||||
"name": "Color Template",
|
||||
"color": "#06b6d4",
|
||||
"z_index": 0,
|
||||
"classes": [{"id": "known", "name": "已定义", "color": "#ff0000", "zIndex": 10, "maskId": 1}],
|
||||
"rules": [],
|
||||
}).json()
|
||||
mask = np.zeros((90, 160, 3), dtype=np.uint8)
|
||||
cv2.rectangle(mask, (5, 5), (40, 40), (1, 1, 1), thickness=-1)
|
||||
cv2.rectangle(mask, (80, 5), (120, 40), (2, 2, 2), thickness=-1)
|
||||
ok, encoded = cv2.imencode(".png", mask)
|
||||
assert ok
|
||||
|
||||
discard_response = client.post(
|
||||
"/api/ai/import-gt-mask",
|
||||
data={
|
||||
"project_id": str(project["id"]),
|
||||
"frame_id": str(frame["id"]),
|
||||
"template_id": str(template["id"]),
|
||||
"unknown_color_policy": "discard",
|
||||
},
|
||||
files={"file": ("colors.png", encoded.tobytes(), "image/png")},
|
||||
)
|
||||
|
||||
assert discard_response.status_code == 201
|
||||
assert [item["mask_data"]["label"] for item in discard_response.json()] == ["已定义"]
|
||||
assert discard_response.json()[0]["mask_data"]["gt_original_size"] == {"width": 160, "height": 90}
|
||||
assert discard_response.json()[0]["mask_data"]["gt_resized_to_frame"] is True
|
||||
assert discard_response.json()[0]["mask_data"]["image_size"] == {"width": 640, "height": 360}
|
||||
|
||||
undefined_response = client.post(
|
||||
"/api/ai/import-gt-mask",
|
||||
data={
|
||||
"project_id": str(project["id"]),
|
||||
"frame_id": str(frame["id"]),
|
||||
"template_id": str(template["id"]),
|
||||
"unknown_color_policy": "undefined",
|
||||
},
|
||||
files={"file": ("colors.png", encoded.tobytes(), "image/png")},
|
||||
)
|
||||
|
||||
assert undefined_response.status_code == 201
|
||||
labels = {item["mask_data"]["label"] for item in undefined_response.json()}
|
||||
assert labels == {"已定义", "未定义类别 2"}
|
||||
unknown = next(item for item in undefined_response.json() if item["mask_data"]["label"].startswith("未定义"))
|
||||
assert unknown["mask_data"]["gt_unknown_class"] is True
|
||||
assert unknown["mask_data"]["gt_label_value"] == 2
|
||||
assert unknown["mask_data"]["gt_resized_to_frame"] is True
|
||||
|
||||
@@ -2,10 +2,11 @@ def test_login_success(client):
|
||||
response = client.post("/api/auth/login", json={"username": "admin", "password": "123456"})
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {
|
||||
"token": "fake-jwt-token-for-admin",
|
||||
"username": "admin",
|
||||
}
|
||||
body = response.json()
|
||||
assert body["token"]
|
||||
assert body["token_type"] == "bearer"
|
||||
assert body["username"] == "admin"
|
||||
assert body["user"]["username"] == "admin"
|
||||
|
||||
|
||||
def test_login_rejects_invalid_credentials(client):
|
||||
@@ -13,3 +14,19 @@ def test_login_rejects_invalid_credentials(client):
|
||||
|
||||
assert response.status_code == 401
|
||||
assert response.json()["detail"] == "Invalid credentials"
|
||||
|
||||
|
||||
def test_me_returns_current_user(client):
|
||||
response = client.get("/api/auth/me")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json()["username"] == "admin"
|
||||
|
||||
|
||||
def test_business_routes_require_auth(app):
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
with TestClient(app) as unauthenticated:
|
||||
response = unauthenticated.get("/api/projects")
|
||||
|
||||
assert response.status_code == 401
|
||||
|
||||
@@ -1,19 +1,31 @@
|
||||
import zipfile
|
||||
import json
|
||||
from io import BytesIO
|
||||
from urllib.parse import unquote
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
|
||||
def _fake_image_bytes(width=100, height=50, color=(255, 255, 255)):
|
||||
image = np.full((height, width, 3), color, dtype=np.uint8)
|
||||
_, encoded = cv2.imencode(".jpg", image)
|
||||
return encoded.tobytes()
|
||||
|
||||
|
||||
def _seed_export_data(client):
|
||||
project = client.post("/api/projects", json={"name": "Export Project"}).json()
|
||||
project = client.post("/api/projects", json={
|
||||
"name": "Export Project",
|
||||
"video_path": "uploads/1/clip.mp4",
|
||||
}).json()
|
||||
frame = client.post(f"/api/projects/{project['id']}/frames", json={
|
||||
"project_id": project["id"],
|
||||
"frame_index": 0,
|
||||
"image_url": "frames/0.jpg",
|
||||
"width": 100,
|
||||
"height": 50,
|
||||
"timestamp_ms": 1250.0,
|
||||
"source_frame_number": 37,
|
||||
}).json()
|
||||
template = client.post("/api/templates", json={
|
||||
"name": "Category",
|
||||
@@ -113,6 +125,328 @@ def test_export_masks_uses_z_index_for_semantic_fusion(client):
|
||||
assert semantic[10, 10] == high_value
|
||||
|
||||
|
||||
def test_export_results_zip_contains_coco_original_images_and_selected_mask_outputs(client, monkeypatch):
|
||||
project, _, _, annotation = _seed_export_data(client)
|
||||
monkeypatch.setattr("routers.export.download_file", lambda object_name: _fake_image_bytes())
|
||||
|
||||
response = client.get(f"/api/export/{project['id']}/results?scope=all&mask_type=both")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.headers["content-type"].startswith("application/zip")
|
||||
with zipfile.ZipFile(BytesIO(response.content)) as archive:
|
||||
names = archive.namelist()
|
||||
frame_stem = "clip_0h00m01s250ms_frame000001"
|
||||
assert "annotations_coco.json" in names
|
||||
assert "maskid_GT像素值_类别映射.json" in names
|
||||
assert f"原始图片/{frame_stem}.jpg" in names
|
||||
assert f"分开Mask分割结果/{frame_stem}_分别导出/{frame_stem}_Category_maskid1.png" in names
|
||||
assert f"GT_label图/{frame_stem}.png" in names
|
||||
assert f"Pro_label彩色分割结果/{frame_stem}.png" in names
|
||||
assert f"Mix_label重叠覆盖彩色分割结果/{frame_stem}.png" in names
|
||||
coco = json.loads(archive.read("annotations_coco.json"))
|
||||
mapping = json.loads(archive.read("maskid_GT像素值_类别映射.json"))
|
||||
label_bytes = np.frombuffer(archive.read(f"GT_label图/{frame_stem}.png"), dtype=np.uint8)
|
||||
gt_label = cv2.imdecode(label_bytes, cv2.IMREAD_UNCHANGED)
|
||||
pro_label = cv2.imdecode(
|
||||
np.frombuffer(archive.read(f"Pro_label彩色分割结果/{frame_stem}.png"), dtype=np.uint8),
|
||||
cv2.IMREAD_COLOR,
|
||||
)
|
||||
mix_label = cv2.imdecode(
|
||||
np.frombuffer(archive.read(f"Mix_label重叠覆盖彩色分割结果/{frame_stem}.png"), dtype=np.uint8),
|
||||
cv2.IMREAD_COLOR,
|
||||
)
|
||||
|
||||
assert coco["images"][0]["frame_index"] == 0
|
||||
assert coco["annotations"][0]["image_id"] == annotation["frame_id"]
|
||||
assert mapping["classes"] == [{
|
||||
"gt_pixel_value": 1,
|
||||
"maskid": 1,
|
||||
"chineseName": "Category",
|
||||
"className": "Category",
|
||||
"categoryName": "Category",
|
||||
"rgb": [6, 182, 212],
|
||||
"color": "#06b6d4",
|
||||
"key": f"template:{annotation['template_id']}",
|
||||
"template_id": annotation["template_id"],
|
||||
}]
|
||||
assert gt_label[0, 0] == 0
|
||||
assert gt_label[20, 50] == 1
|
||||
assert pro_label[20, 50].tolist() == [212, 182, 6]
|
||||
assert pro_label[0, 0].tolist() == [0, 0, 0]
|
||||
assert mix_label[20, 50].tolist() != [255, 255, 255]
|
||||
|
||||
|
||||
def test_export_results_uses_internal_layer_order_for_gt_pro_and_mix_outputs(client, monkeypatch):
|
||||
monkeypatch.setattr("routers.export.download_file", lambda object_name: _fake_image_bytes(20, 20))
|
||||
project = client.post("/api/projects", json={
|
||||
"name": "Layered Export Project",
|
||||
"video_path": "uploads/2/layered.mp4",
|
||||
}).json()
|
||||
frame = client.post(f"/api/projects/{project['id']}/frames", json={
|
||||
"project_id": project["id"],
|
||||
"frame_index": 0,
|
||||
"image_url": "frames/layered.jpg",
|
||||
"width": 20,
|
||||
"height": 20,
|
||||
"timestamp_ms": 0,
|
||||
"source_frame_number": 0,
|
||||
}).json()
|
||||
client.post("/api/ai/annotate", json={
|
||||
"project_id": project["id"],
|
||||
"frame_id": frame["id"],
|
||||
"mask_data": {
|
||||
"polygons": [[[0.1, 0.1], [0.8, 0.1], [0.8, 0.8], [0.1, 0.8]]],
|
||||
"label": "Low",
|
||||
"color": "#00ff00",
|
||||
"class": {"id": "low", "name": "Low", "color": "#00ff00", "zIndex": 10},
|
||||
},
|
||||
})
|
||||
client.post("/api/ai/annotate", json={
|
||||
"project_id": project["id"],
|
||||
"frame_id": frame["id"],
|
||||
"mask_data": {
|
||||
"polygons": [[[0.4, 0.4], [0.9, 0.4], [0.9, 0.9], [0.4, 0.9]]],
|
||||
"label": "High",
|
||||
"color": "#ff0000",
|
||||
"class": {"id": "high", "name": "High", "color": "#ff0000", "zIndex": 20},
|
||||
},
|
||||
})
|
||||
|
||||
response = client.get(
|
||||
f"/api/export/{project['id']}/results?scope=all&outputs=gt_label,pro_label,mix_label&mix_opacity=0.5",
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
with zipfile.ZipFile(BytesIO(response.content)) as archive:
|
||||
mapping = json.loads(archive.read("maskid_GT像素值_类别映射.json"))
|
||||
high_value = next(item["maskid"] for item in mapping["classes"] if item["key"] == "class:high")
|
||||
stem = "layered_0h00m00s000ms_frame000001"
|
||||
gt_label = cv2.imdecode(
|
||||
np.frombuffer(archive.read(f"GT_label图/{stem}.png"), dtype=np.uint8),
|
||||
cv2.IMREAD_UNCHANGED,
|
||||
)
|
||||
pro_label = cv2.imdecode(
|
||||
np.frombuffer(archive.read(f"Pro_label彩色分割结果/{stem}.png"), dtype=np.uint8),
|
||||
cv2.IMREAD_COLOR,
|
||||
)
|
||||
mix_label = cv2.imdecode(
|
||||
np.frombuffer(archive.read(f"Mix_label重叠覆盖彩色分割结果/{stem}.png"), dtype=np.uint8),
|
||||
cv2.IMREAD_COLOR,
|
||||
)
|
||||
|
||||
assert gt_label[10, 10] == high_value
|
||||
assert pro_label[10, 10].tolist() == [0, 0, 255]
|
||||
assert mix_label[10, 10].tolist() == [127, 127, 255]
|
||||
|
||||
|
||||
def test_export_results_supports_range_and_current_scope(client, monkeypatch):
|
||||
monkeypatch.setattr("routers.export.download_file", lambda object_name: _fake_image_bytes(20, 20))
|
||||
project = client.post("/api/projects", json={
|
||||
"name": "Scoped Export Project",
|
||||
"video_path": "uploads/9/scope.mp4",
|
||||
"parse_fps": 2,
|
||||
}).json()
|
||||
template = client.post("/api/templates", json={
|
||||
"name": "Scoped Category",
|
||||
"color": "#06b6d4",
|
||||
"z_index": 0,
|
||||
"classes": [],
|
||||
"rules": [],
|
||||
}).json()
|
||||
frames = []
|
||||
annotations = []
|
||||
for idx in range(3):
|
||||
frame = client.post(f"/api/projects/{project['id']}/frames", json={
|
||||
"project_id": project["id"],
|
||||
"frame_index": idx,
|
||||
"image_url": f"frames/{idx}.jpg",
|
||||
"width": 20,
|
||||
"height": 20,
|
||||
"timestamp_ms": idx * 500.0,
|
||||
"source_frame_number": idx * 10,
|
||||
}).json()
|
||||
frames.append(frame)
|
||||
annotations.append(client.post("/api/ai/annotate", json={
|
||||
"project_id": project["id"],
|
||||
"frame_id": frame["id"],
|
||||
"template_id": template["id"],
|
||||
"mask_data": {"polygons": [[[0.1, 0.1], [0.8, 0.1], [0.8, 0.8], [0.1, 0.8]]]},
|
||||
}).json())
|
||||
|
||||
range_response = client.get(
|
||||
f"/api/export/{project['id']}/results?scope=range&start_frame=2&end_frame=3&mask_type=gt_label",
|
||||
)
|
||||
current_response = client.get(
|
||||
f"/api/export/{project['id']}/results?scope=current&frame_id={frames[1]['id']}&mask_type=separate",
|
||||
)
|
||||
|
||||
assert range_response.status_code == 200
|
||||
assert "Scoped_Export_Project_seg_T_0h00m00s500ms-0h00m01s000ms_P_2-3.zip" in unquote(
|
||||
range_response.headers["content-disposition"],
|
||||
)
|
||||
with zipfile.ZipFile(BytesIO(range_response.content)) as archive:
|
||||
names = archive.namelist()
|
||||
coco = json.loads(archive.read("annotations_coco.json"))
|
||||
assert "原始图片/scope_0h00m00s500ms_frame000002.jpg" in names
|
||||
assert "原始图片/scope_0h00m01s000ms_frame000003.jpg" in names
|
||||
assert "原始图片/scope_0h00m00s000ms_frame000001.jpg" not in names
|
||||
assert "GT_label图/scope_0h00m00s500ms_frame000002.png" in names
|
||||
assert "GT_label图/scope_0h00m01s000ms_frame000003.png" in names
|
||||
assert "GT_label图/scope_0h00m00s000ms_frame000001.png" not in names
|
||||
assert not any(name.startswith("分开Mask分割结果/") for name in names)
|
||||
assert not any(name.startswith("Pro_label彩色分割结果/") for name in names)
|
||||
assert not any(name.startswith("Mix_label重叠覆盖彩色分割结果/") for name in names)
|
||||
assert [image["frame_index"] for image in coco["images"]] == [1, 2]
|
||||
|
||||
assert current_response.status_code == 200
|
||||
with zipfile.ZipFile(BytesIO(current_response.content)) as archive:
|
||||
names = archive.namelist()
|
||||
coco = json.loads(archive.read("annotations_coco.json"))
|
||||
current_stem = "scope_0h00m00s500ms_frame000002"
|
||||
assert f"原始图片/{current_stem}.jpg" in names
|
||||
assert f"分开Mask分割结果/{current_stem}_分别导出/{current_stem}_Scoped_Category_maskid1.png" in names
|
||||
assert f"分开Mask分割结果/scope_0h00m00s000ms_frame000001_分别导出/scope_0h00m00s000ms_frame000001_Scoped_Category_maskid1.png" not in names
|
||||
assert not any(name.startswith("GT_label图/") for name in names)
|
||||
assert not any(name.startswith("Pro_label彩色分割结果/") for name in names)
|
||||
assert not any(name.startswith("Mix_label重叠覆盖彩色分割结果/") for name in names)
|
||||
assert [image["id"] for image in coco["images"]] == [frames[1]["id"]]
|
||||
|
||||
|
||||
def test_export_results_preserves_template_maskid_consistently_across_frames(client, monkeypatch):
|
||||
monkeypatch.setattr("routers.export.download_file", lambda object_name: _fake_image_bytes(20, 20))
|
||||
project = client.post("/api/projects", json={
|
||||
"name": "MaskId Export Project",
|
||||
"video_path": "uploads/8/maskid-demo.mp4",
|
||||
"parse_fps": 1,
|
||||
}).json()
|
||||
frames = []
|
||||
for idx in range(2):
|
||||
frames.append(client.post(f"/api/projects/{project['id']}/frames", json={
|
||||
"project_id": project["id"],
|
||||
"frame_index": idx,
|
||||
"image_url": f"frames/{idx}.jpg",
|
||||
"width": 20,
|
||||
"height": 20,
|
||||
"timestamp_ms": idx * 1000.0,
|
||||
"source_frame_number": idx,
|
||||
}).json())
|
||||
client.post("/api/ai/annotate", json={
|
||||
"project_id": project["id"],
|
||||
"frame_id": frames[-1]["id"],
|
||||
"mask_data": {
|
||||
"polygons": [[[0.1, 0.1], [0.8, 0.1], [0.8, 0.8], [0.1, 0.8]]],
|
||||
"label": "Tumor",
|
||||
"color": "#ff0000",
|
||||
"class": {"id": "tumor", "name": "Tumor", "color": "#ff0000", "maskId": 7, "zIndex": 30},
|
||||
},
|
||||
})
|
||||
|
||||
response = client.get(f"/api/export/{project['id']}/results?scope=all&mask_type=both")
|
||||
|
||||
assert response.status_code == 200
|
||||
with zipfile.ZipFile(BytesIO(response.content)) as archive:
|
||||
names = archive.namelist()
|
||||
mapping = json.loads(archive.read("maskid_GT像素值_类别映射.json"))
|
||||
first_stem = "maskid-demo_0h00m00s000ms_frame000001"
|
||||
second_stem = "maskid-demo_0h00m01s000ms_frame000002"
|
||||
assert f"分开Mask分割结果/{first_stem}_分别导出/{first_stem}_Tumor_maskid7.png" in names
|
||||
assert f"分开Mask分割结果/{second_stem}_分别导出/{second_stem}_Tumor_maskid7.png" in names
|
||||
first_label = cv2.imdecode(np.frombuffer(archive.read(f"GT_label图/{first_stem}.png"), dtype=np.uint8), cv2.IMREAD_UNCHANGED)
|
||||
second_label = cv2.imdecode(np.frombuffer(archive.read(f"GT_label图/{second_stem}.png"), dtype=np.uint8), cv2.IMREAD_UNCHANGED)
|
||||
|
||||
assert mapping["classes"] == [{
|
||||
"gt_pixel_value": 7,
|
||||
"maskid": 7,
|
||||
"chineseName": "Tumor",
|
||||
"className": "Tumor",
|
||||
"categoryName": "",
|
||||
"rgb": [255, 0, 0],
|
||||
"color": "#ff0000",
|
||||
"key": "class:tumor",
|
||||
"template_id": None,
|
||||
}]
|
||||
assert first_label[5, 5] == 7
|
||||
assert second_label[5, 5] == 7
|
||||
|
||||
|
||||
def test_exported_gtlabel_round_trips_through_gt_mask_import_with_template_maskid(client, monkeypatch):
|
||||
monkeypatch.setattr("routers.export.download_file", lambda object_name: _fake_image_bytes(20, 20))
|
||||
project = client.post("/api/projects", json={
|
||||
"name": "GT Roundtrip Project",
|
||||
"video_path": "uploads/8/roundtrip.mp4",
|
||||
}).json()
|
||||
template = client.post("/api/templates", json={
|
||||
"name": "Roundtrip Template",
|
||||
"color": "#06b6d4",
|
||||
"z_index": 0,
|
||||
"classes": [
|
||||
{"id": "tumor", "name": "Tumor", "color": "#ff0000", "zIndex": 30, "maskId": 7},
|
||||
],
|
||||
"rules": [],
|
||||
}).json()
|
||||
source_frame = client.post(f"/api/projects/{project['id']}/frames", json={
|
||||
"project_id": project["id"],
|
||||
"frame_index": 0,
|
||||
"image_url": "frames/source.jpg",
|
||||
"width": 20,
|
||||
"height": 20,
|
||||
"timestamp_ms": 0,
|
||||
}).json()
|
||||
target_frame = client.post(f"/api/projects/{project['id']}/frames", json={
|
||||
"project_id": project["id"],
|
||||
"frame_index": 1,
|
||||
"image_url": "frames/target.jpg",
|
||||
"width": 20,
|
||||
"height": 20,
|
||||
"timestamp_ms": 1000,
|
||||
}).json()
|
||||
client.post("/api/ai/annotate", json={
|
||||
"project_id": project["id"],
|
||||
"frame_id": source_frame["id"],
|
||||
"template_id": template["id"],
|
||||
"mask_data": {
|
||||
"polygons": [[[0.1, 0.1], [0.8, 0.1], [0.8, 0.8], [0.1, 0.8]]],
|
||||
"label": "Tumor",
|
||||
"color": "#ff0000",
|
||||
"class": {"id": "tumor", "name": "Tumor", "color": "#ff0000", "maskId": 7, "zIndex": 30},
|
||||
},
|
||||
})
|
||||
|
||||
export_response = client.get(
|
||||
f"/api/export/{project['id']}/results?scope=current&frame_id={source_frame['id']}&outputs=gt_label",
|
||||
)
|
||||
|
||||
assert export_response.status_code == 200
|
||||
with zipfile.ZipFile(BytesIO(export_response.content)) as archive:
|
||||
stem = "roundtrip_0h00m00s000ms_frame000001"
|
||||
exported_gt_label = archive.read(f"GT_label图/{stem}.png")
|
||||
gt_label = cv2.imdecode(np.frombuffer(exported_gt_label, dtype=np.uint8), cv2.IMREAD_UNCHANGED)
|
||||
mapping = json.loads(archive.read("maskid_GT像素值_类别映射.json"))
|
||||
|
||||
assert gt_label[5, 5] == 7
|
||||
assert mapping["classes"][0]["maskid"] == 7
|
||||
|
||||
import_response = client.post(
|
||||
"/api/ai/import-gt-mask",
|
||||
data={
|
||||
"project_id": str(project["id"]),
|
||||
"frame_id": str(target_frame["id"]),
|
||||
"template_id": str(template["id"]),
|
||||
"unknown_color_policy": "discard",
|
||||
},
|
||||
files={"file": ("exported_gt_label.png", exported_gt_label, "image/png")},
|
||||
)
|
||||
|
||||
assert import_response.status_code == 201
|
||||
imported = import_response.json()
|
||||
assert len(imported) == 1
|
||||
assert imported[0]["frame_id"] == target_frame["id"]
|
||||
assert imported[0]["mask_data"]["gt_label_value"] == 7
|
||||
assert imported[0]["mask_data"]["label"] == "Tumor"
|
||||
assert imported[0]["mask_data"]["class"]["maskId"] == 7
|
||||
|
||||
|
||||
def test_export_missing_project_returns_404(client):
|
||||
assert client.get("/api/export/999/coco").status_code == 404
|
||||
assert client.get("/api/export/999/masks").status_code == 404
|
||||
assert client.get("/api/export/999/results").status_code == 404
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from models import Annotation, Frame, Mask, ProcessingTask, Project
|
||||
from models import Annotation, Frame, Mask, ProcessingTask, Project, User
|
||||
from routers.auth import create_access_token, hash_password
|
||||
|
||||
|
||||
def test_project_crud_and_frames(client, monkeypatch):
|
||||
@@ -93,3 +94,33 @@ def test_project_and_frame_404s(client):
|
||||
}).status_code == 404
|
||||
assert client.get("/api/projects/999/frames").status_code == 404
|
||||
assert client.get("/api/projects/999/frames/1").status_code == 404
|
||||
|
||||
|
||||
def test_projects_are_scoped_to_authenticated_owner(client, db_session):
|
||||
owner_project = client.post("/api/projects", json={"name": "Owner Project"}).json()
|
||||
other_user = User(
|
||||
username="other",
|
||||
password_hash=hash_password("pass"),
|
||||
role="annotator",
|
||||
is_active=1,
|
||||
)
|
||||
db_session.add(other_user)
|
||||
db_session.commit()
|
||||
db_session.refresh(other_user)
|
||||
other_project = Project(name="Other Project", owner_user_id=other_user.id)
|
||||
db_session.add(other_project)
|
||||
db_session.commit()
|
||||
db_session.refresh(other_project)
|
||||
|
||||
listing = client.get("/api/projects")
|
||||
assert [project["id"] for project in listing.json()] == [owner_project["id"]]
|
||||
assert client.get(f"/api/projects/{other_project.id}").status_code == 404
|
||||
|
||||
original_auth = client.headers["Authorization"]
|
||||
client.headers.update({"Authorization": f"Bearer {create_access_token(other_user)}"})
|
||||
try:
|
||||
other_listing = client.get("/api/projects")
|
||||
assert [project["id"] for project in other_listing.json()] == [other_project.id]
|
||||
assert client.get(f"/api/projects/{owner_project['id']}").status_code == 404
|
||||
finally:
|
||||
client.headers.update({"Authorization": original_auth})
|
||||
|
||||
@@ -39,19 +39,19 @@ Word 方案描述的理想系统包含:
|
||||
| DICOM 批量导入 | 部分落地 | 上传和解析存在,项目级体验还需完善 |
|
||||
| WebSocket 进度 | 已落地 | 拆帧进度写入任务表后发布到 Redis `seg:progress`,FastAPI 广播到 `/ws/progress` |
|
||||
| SAM 推理 | 部分落地 | 当前产品入口启用 SAM 2.1 tiny/small/base+/large 和真实 GPU/SAM2.1 状态接口;SAM 2.1 已接 point/box/interactive 和 video predictor 片段传播。SAM 3 桥接源码保留,但前端入口和后端 registry 已禁用 |
|
||||
| 模板库 | 部分落地 | 分类、颜色、z-index 能存储和编辑;PNG mask 导出时会按 zIndex 做语义融合裁决,前端预览裁决尚未落地 |
|
||||
| 标注持久化 | 部分落地 | 后端有 `Annotation` 表,前端已接入新增、回显、分类更新、当前帧删除、手工绘制、GT mask 导入、seed point 编辑、polygon 顶点拖动/删除、边中点插点和多 polygon 子区域编辑;复杂洞结构编辑未落地 |
|
||||
| COCO / Mask 导出 | 已落地基础能力 | `backend/routers/export.py`;COCO JSON 和 PNG mask ZIP 前端按钮均已接入,ZIP 包含单标注 mask、语义融合 mask 和类别映射 |
|
||||
| 模板库 | 部分落地 | 分类、颜色、maskid 和拖拽排序能存储和编辑;右侧语义分类树也可拖拽调整内部覆盖顺序;PNG mask 导出时会按内部优先级做语义融合裁决,前端预览裁决尚未落地 |
|
||||
| 标注持久化 | 部分落地 | 后端有 `Annotation` 表,前端已接入新增、回显、分类更新、传播链前后帧同目标同步换类、当前帧删除、手工绘制、GT mask 导入、seed point 编辑、polygon 顶点拖动/删除、边中点插点和多 polygon 子区域编辑;复杂洞结构编辑未落地 |
|
||||
| COCO / Mask 导出 | 已落地基础能力 | `backend/routers/export.py`;COCO JSON、兼容 PNG mask ZIP 和统一分割结果 ZIP 均已接入;统一 ZIP 包含 maskid/GT 像素值映射、原始图片、按帧/类别合并的分开 mask、GT_label 黑白图、Pro_label 彩色图和 Mix_label 原图叠加图;GT_label 像素值使用类别真实 maskid,缺失 maskid 的旧标注才补下一个可用正整数 |
|
||||
|
||||
## 当前代码尚未落地的目标
|
||||
|
||||
- SAM 3:`sam3_engine.py`、`sam3_external_worker.py` 和 `setup_sam3_env.sh` 作为历史实现保留;由于当前系统不给文本提示,前端不再展示 SAM 3,后端 registry 也不暴露 `sam3`。官方没有 SAM 3 tiny/small 权重,当前可选最小真实 SAM 权重仍是 SAM 2.1 tiny。
|
||||
- Celery 异步任务队列:已注册 Celery app 和拆帧 worker task,`/api/media/parse` 会创建任务表记录并入队。
|
||||
- GT mask 导入:当前已支持二值/多类别 mask 导入,后端会按非零像素值拆分区域,生成 polygon 标注和距离变换 seed point;骨架提取、HDBSCAN 和模板自动映射尚未实现。
|
||||
- GT mask 导入:当前已支持二值 mask、灰度/16-bit GT_label 图和 RGB 三通道完全相同的 `[X,X,X]` maskid 图导入,后端会按 maskid 拆分区域,生成 polygon 标注和距离变换 seed point;超出现有类别的 maskid 可舍弃或导入为未定义类别;普通彩色类别图会被拒绝,尺寸不一致会自动最近邻拉伸到当前帧;骨架提取、HDBSCAN 和更复杂的模板自动映射尚未实现。
|
||||
- Mask 到点区域的拓扑降维:当前完成 distance transform seed point 和前端 seed point 拖拽编辑;骨架提取、HDBSCAN 等增强尚未实现。
|
||||
- 类别优先级融合:PNG mask 导出时已按 zIndex 生成语义融合 mask;前端裁决预览尚未实现。
|
||||
- 类别优先级融合:PNG mask 导出时已按内部优先级生成语义融合 mask;前端裁决预览尚未实现。
|
||||
- 撤销/重做:当前已有全局 mask 历史栈。
|
||||
- 结构化归档保存:工作区按钮已调用 `POST /api/ai/annotate` 保存当前未归档 mask,并通过 `PATCH /api/ai/annotations/{id}` 更新 dirty mask。
|
||||
- 保存状态按钮:工作区按钮按待保存数量显示“保存 X 个改动”或“已全部保存”,并调用 `POST /api/ai/annotate` 保存当前未归档 mask,通过 `PATCH /api/ai/annotations/{id}` 更新 dirty mask。
|
||||
|
||||
## 结论
|
||||
|
||||
|
||||
@@ -37,6 +37,7 @@
|
||||
| `workspace` | `VideoWorkspace` | 分割工作区 |
|
||||
| `ai` | `AISegmentation` | AI 智能分割页 |
|
||||
| `templates` | `TemplateRegistry` | 模板库 |
|
||||
| `admin` | `UserAdmin` | 管理员用户后台,仅 `role=admin` 可见 |
|
||||
|
||||
未登录时,`App.tsx` 直接渲染 `Login`。
|
||||
|
||||
@@ -44,14 +45,14 @@
|
||||
|
||||
全局状态在 `src/store/useStore.ts` 中,主要包括:
|
||||
|
||||
- 登录状态:`isAuthenticated`、`token`
|
||||
- 登录状态:`isAuthenticated`、`token`、`currentUser`
|
||||
- 项目:`projects`、`currentProject`
|
||||
- 工作区:`activeModule`、`activeTool`、`frames`、`currentFrameIndex`
|
||||
- 标注与 mask:`annotations`、`masks`
|
||||
- 模板:`templates`、`activeTemplateId`
|
||||
- UI:`isLoading`、`error`
|
||||
|
||||
当前状态管理是前端内存状态,没有持久化到 localStorage,除了登录 token。
|
||||
当前状态管理主要是前端内存状态;登录 token 会持久化到 `localStorage`,刷新后再通过 `/api/auth/me` 恢复当前用户。
|
||||
|
||||
## 数据流
|
||||
|
||||
@@ -59,8 +60,18 @@
|
||||
|
||||
1. `Login.tsx` 调用 `login()`。
|
||||
2. `src/lib/api.ts` 请求 `POST /api/auth/login`。
|
||||
3. FastAPI `backend/routers/auth.py` 校验 `admin / 123456`。
|
||||
4. 前端把返回 token 写入 localStorage。
|
||||
3. FastAPI `backend/routers/auth.py` 查询 `users` 表并校验密码哈希。
|
||||
4. 前端把返回 JWT 写入 localStorage,并把用户资料写入 store。
|
||||
5. 后续业务请求带 `Authorization: Bearer <token>`,后端按当前用户过滤项目资源。
|
||||
6. `admin/annotator` 可调用写入类业务接口,`viewer` 只能读取;`/api/admin/*` 仅允许 `admin`。
|
||||
|
||||
### 管理员用户管理
|
||||
|
||||
1. `Sidebar.tsx` 仅对 `currentUser.role === 'admin'` 显示“用户管理”。
|
||||
2. `UserAdmin.tsx` 调用 `GET/POST/PATCH/DELETE /api/admin/users` 完成用户新增、停用/启用、角色修改、改密码和删除无项目用户。
|
||||
3. `UserAdmin.tsx` 调用 `GET /api/admin/audit-logs` 展示登录成功/失败以及用户管理操作审计。
|
||||
4. `UserAdmin.tsx` 危险区“恢复演示出厂设置”需要浏览器确认和输入 `RESET_DEMO_FACTORY`,随后调用 `POST /api/admin/demo-factory-reset`。
|
||||
5. 后端 `backend/routers/admin.py` 会阻止管理员删除、停用或降级自己,并阻止删除仍拥有项目的用户;演示出厂重置会清空其它用户、项目帧、标注、任务和私有模板,重新创建一个尚未生成帧的 `Data_MyVideo_1` 视频项目。
|
||||
|
||||
### 项目与拆帧
|
||||
|
||||
@@ -100,6 +111,6 @@
|
||||
|
||||
- 前端 API/WS 地址虽然已支持环境变量和 hostname 推导,但部署时仍需要确认浏览器可访问 `:8000` 后端。
|
||||
- AI 当前启用 SAM 2.1 tiny/small/base+/large 点/框/interactive 路径;语义文本提示和 SAM 3 产品入口已禁用,`model=sam3` 会被后端拒绝。SAM 3 源码保留但不计入当前可用功能。
|
||||
- 工作区顶部“导出 JSON 标注集”“导出 PNG Mask ZIP”“导入 GT Mask”和“结构化归档保存”已接入导出、GT 多类别导入、seed point 回显/编辑、标注新增和 dirty 标注更新;清空当前帧遮罩会删除对应后端标注。手工绘制、polygon 顶点拖动/删除、区域合并/去除和撤销重做已经落到前端 mask 数据结构。
|
||||
- 工作区顶部“分割结果导出”和保存状态按钮、左侧工具栏“导入 GT Mask”已接入统一导出、GT 多类别导入、seed point 回显/编辑、标注新增和 dirty 标注更新;导入 GT Mask 支持二值 mask、低数值/16-bit GT_label 图和 RGB 三通道完全相同的 `[X,X,X]` maskid 图,未知 maskid 可由用户选择舍弃或导入为未定义类别,普通彩色类别图会被拒绝,尺寸不同会自动最近邻拉伸到当前帧。保存状态按钮会按待保存数量显示“保存 X 个改动”或“已全部保存”;统一导出可选择整体视频、特定范围帧或当前图片,并勾选分开 mask、GT_label 黑白图、Pro_label 彩色图和 Mix_label 原图叠加图;特定范围帧导出支持直接输入起止帧,也支持在播放进度条或视频处理进度条上点击/拖拽选择范围;Mix_label 支持默认 0.3 的透明度调节和首帧预览;后端统一导出 ZIP 固定包含 maskid/GT 像素值映射 JSON 与原始图片文件夹,GT_label 像素值使用类别真实 maskid,缺失 maskid 的旧标注才补下一个可用正整数,并按客户命名规则输出分开 Mask、GT_label、Pro_label 和 Mix_label 文件夹;清空当前帧遮罩会删除对应后端标注。手工绘制、polygon 顶点拖动/删除、区域合并/去除和撤销重做已经落到前端 mask 数据结构。
|
||||
- Dashboard 初始统计、队列和活动日志来自后端聚合接口;解析队列来自 `processing_tasks`,worker 进度通过 Redis `seg:progress` 转发到 WebSocket。任务取消、重试和失败详情已接入前后端。
|
||||
- 后端路由大多未做真实鉴权。
|
||||
- 后端已接入 Bearer JWT 鉴权、当前用户项目隔离和角色权限;写入类业务接口要求 `admin/annotator`,管理员用户后台要求 `admin`。当前审计覆盖登录和用户管理操作,全业务级审计仍可继续扩展。
|
||||
|
||||
@@ -12,7 +12,7 @@
|
||||
| 元素 | 位置 | 状态 | 说明 |
|
||||
|------|------|------|------|
|
||||
| 登录拦截 | `App.tsx` | 真实可用 | 未登录显示 `Login`,登录后显示主界面 |
|
||||
| 模块切换 | `Sidebar.tsx` + `App.tsx` | 真实可用 | 切换 `dashboard/projects/workspace/ai/templates` |
|
||||
| 模块切换 | `Sidebar.tsx` + `App.tsx` | 真实可用 | 切换 `dashboard/projects/workspace/ai/templates`;“AI智能分割”入口使用 Bot + Sparkles 组合图标,强化 AI 语义 |
|
||||
| Logo | `Sidebar.tsx` | 真实可用 | 使用 `/logo.png`,文件存在于 `public/logo.png` |
|
||||
| GPU 状态圆标 | `Sidebar.tsx` | 真实可用 | 通过 `GET /api/ai/models/status` 显示 GPU/CPU 和当前模型可用性 |
|
||||
|
||||
@@ -20,16 +20,29 @@
|
||||
|
||||
| 元素 | 状态 | 说明 |
|
||||
|------|------|------|
|
||||
| 用户名/密码输入 | 真实可用 | 默认填入 `admin / 123456` |
|
||||
| 安全登录按钮 | 真实可用 | 调用 `POST /api/auth/login` |
|
||||
| 用户名/密码输入 | 真实可用 | 默认填入 `admin / 123456`,用户名使用 `autocomplete=username`,密码使用 `autocomplete=current-password` |
|
||||
| 安全登录按钮 | 真实可用 | 调用 `POST /api/auth/login`,后端校验 `users` 表密码哈希并返回签名 JWT |
|
||||
| 错误提示 | 真实可用 | 捕获后端错误并显示 |
|
||||
| 安全审计说明文字 | Mock / UI-only | UI 文案,没有真实审计功能 |
|
||||
| 登录态恢复 / 退出 | 真实可用 | 页面刷新后用 `/api/auth/me` 恢复当前用户;侧栏底部显示当前用户名并可退出登录 |
|
||||
| 安全审计说明文字 | 部分可用 | 登录和用户管理操作已有 `audit_logs` 记录;登录页“端到端加密”等安全文案仍是展示性说明,不代表已接入完整企业级安全审计 |
|
||||
|
||||
## 管理员用户后台
|
||||
|
||||
| 元素 | 状态 | 说明 |
|
||||
|------|------|------|
|
||||
| 侧栏“用户管理”入口 | 真实可用 | 仅当前用户 `role=admin` 时显示;非管理员无法看到入口,后端 `/api/admin/*` 也会返回 403 |
|
||||
| 用户列表 | 真实可用 | 调用 `GET /api/admin/users`,展示用户 id、用户名、角色、启停用状态和创建时间 |
|
||||
| 新增用户 | 真实可用 | 调用 `POST /api/admin/users`,支持设置用户名、初始密码和 `admin/annotator/viewer` 角色;后端校验用户名唯一和密码长度 |
|
||||
| 修改角色 / 启停用 / 改密码 | 真实可用 | 调用 `PATCH /api/admin/users/{id}`;后端禁止管理员把自己降级或停用,避免锁死后台 |
|
||||
| 删除用户 | 真实可用 | 调用 `DELETE /api/admin/users/{id}`;后端禁止删除自己,且用户名下仍有项目时返回 409,避免悬空项目数据 |
|
||||
| 审计日志 | 真实可用 | 调用 `GET /api/admin/audit-logs`,展示登录成功/失败、用户新增、修改和删除等管理操作 |
|
||||
| 恢复演示出厂设置 | 真实可用 | 管理员点击危险区按钮后先浏览器确认,再输入 `RESET_DEMO_FACTORY`;前端调用 `POST /api/admin/demo-factory-reset`,后端只保留默认 admin 与一个尚未生成帧的演示视频项目,并清空用户、项目帧、标注、任务和私有模板等演示数据 |
|
||||
|
||||
## Dashboard 系统概况
|
||||
|
||||
| 元素 | 状态 | 说明 |
|
||||
|------|------|------|
|
||||
| WebSocket 连接状态 | 真实可用 | 前端通过 `src/lib/config.ts` 推导或读取 `VITE_WS_PROGRESS_URL`,后端有 `/ws/progress` |
|
||||
| WebSocket 连接状态 | 真实可用 | 前端通过 `src/lib/config.ts` 推导或读取 `VITE_WS_PROGRESS_URL`,后端有 `/ws/progress`;Dashboard 卸载或切页导致的主动断开不会触发自动重连,也不会继续输出“Connection closed”噪音 |
|
||||
| 任务进度 | 真实可用 | 初始数据来自 `GET /api/dashboard/overview`,按 `processing_tasks` queued/running/success/failed/cancelled 任务生成;统计卡片中的处理中任务数只计算 queued/running |
|
||||
| 任务取消 | 真实可用 | queued/running 任务显示取消按钮,调用 `POST /api/tasks/{task_id}/cancel` |
|
||||
| 任务重试 | 真实可用 | failed/cancelled 任务显示重试按钮,调用 `POST /api/tasks/{task_id}/retry` 创建新任务 |
|
||||
@@ -65,12 +78,11 @@
|
||||
| 无帧项目提示 | 真实可用 | 如果 `video_path` 存在但无帧,只提示回到项目库生成帧,不自动创建拆帧任务 |
|
||||
| SAM 模型状态徽标 | 真实可用 | 工作区顶栏使用紧凑 GPU/CPU 状态徽标,避免和旁边的“传播权重”下拉重复显示 SAM 2.1 变体名称;悬停仍可查看模型状态说明 |
|
||||
| 已保存标注回显 | 真实可用 | 加载工作区帧后调用 `GET /api/ai/annotations` 并渲染已保存 mask;回显时保留当前项目帧里尚未保存的 AI/手工 draft mask,避免从 AI 页推送的候选被覆盖 |
|
||||
| “导出 JSON 标注集”按钮 | 真实可用 | 导出前会保存未归档 mask,然后调用 `exportCoco()` 下载 JSON |
|
||||
| “导出 PNG Mask ZIP”按钮 | 真实可用 | 导出前会保存未归档 mask,然后调用 `GET /api/export/{project_id}/masks` 下载 ZIP;后端同时包含单标注 mask、每帧语义融合 mask 和 `semantic_classes.json` |
|
||||
| “导入 GT Mask”按钮 | 真实可用 | 选择图片后调用 `POST /api/ai/import-gt-mask`,后端按非零像素值和连通域生成 polygon 标注与距离变换 seed point,再回显到工作区 |
|
||||
| 参考帧/起止帧/传播权重/自动传播 | 真实可用 | 当前打开帧即参考帧,前端会使用该帧全部 mask 作为 seed;工作区顶栏有独立“传播权重”下拉,可在传播前二次选择 SAM 2.1 tiny/small/base+/large 权重,不提供 SAM2/SAM3 家族切换,不影响 AI 智能分割页的单帧推理权重选择;传播权重下拉使用深色背景和青色文字,避免默认灰底白字不可读;如果用户尚未显式设置范围,点击“自动传播”会先进入时间轴范围选择模式,播放进度条和视频处理进度条都可点击/拖拽回填传播起始帧和传播结束帧,再点击“开始传播”提交;用户也可直接改数字框后点击按钮传播。提交后前端把传播权重 id、seed mask、seed 来源 id、未编辑传播结果的原始 seed 签名、mask 边缘平滑参数和前/后方向步骤提交到 `POST /api/ai/propagate/task`,后端先规范化/校验权重 id,再创建 `processing_tasks` 并由 Celery 执行对应 SAM 2.1 video predictor;worker 会在本次目标帧段内按 seed 来源和几何/语义/平滑参数签名做幂等判断,未改变且目标帧已有结果的 seed 直接跳过,已改变、目标帧只部分覆盖、换权重或平滑参数变化时会先删除本次目标帧段内同源旧自动传播标注再重新传播;中间帧人工新增/修改同一物体后重新传播时,后端会按语义和目标帧空间重叠清理旧传播结果,写入前清理不受旧结果 `propagation_direction` 限制,避免 backward 重传时与旧 forward mask 重叠;传播中顶栏显示任务进度、已处理帧次、删除旧区域数和已保存区域数,前端轮询 `GET /api/tasks/{task_id}` 并刷新已保存标注;任务可取消,若完成后 0 个新区域会明确提示没有生成新 mask 或已跳过未改变 mask |
|
||||
| 清空片段遮罩 | 真实可用 | 点击“清空片段遮罩”后会进入和自动传播一致的时间轴范围选择模式,用户可在播放进度条或视频处理进度条上点击/拖拽选择起止帧,再点“确认清空”;执行后删除该帧段内所有本地 draft mask,并对已保存 mask 调用 `DELETE /api/ai/annotations/{annotation_id}` 删除后端标注;不在范围内的 mask 和选区会保留 |
|
||||
| “结构化归档保存”按钮 | 真实可用 | 未保存 mask 写入 `POST /api/ai/annotate`;dirty mask 写入 `PATCH /api/ai/annotations/{id}`;保存成功后会重新拉取后端标注,并用 saved annotation 替换本次提交的 draft mask,避免仍显示未保存 |
|
||||
| “分割结果导出”按钮 | 真实可用 | 原“导出 JSON 标注集”和“导出 PNG Mask ZIP”已合并为一个入口;点击后可选择整体视频、特定范围帧或当前图片,默认导出范围为当前图片,并勾选导出分开二值 mask、GT_label 黑白图、Pro_label 彩色图和 Mix_label 原图叠加图;选择“特定范围帧”后会进入和自动传播、清空遮罩一致的时间轴范围选择模式,可在播放进度条或视频处理进度条上点击/拖拽选择导出起止帧,也可直接修改起止帧输入框;选择 Mix_label 时可调透明度,默认 0.3,并显示当前/待导出第一帧预览;提交前会保存未归档 mask,然后调用 `GET /api/export/{project_id}/results` 下载 ZIP;浏览器下载名和后端 `Content-Disposition` 均使用 `{项目库项目名}_seg_T_{起始时间戳}-{结束时间戳}_P_{起始项目帧序号}-{结束项目帧序号}.zip`;时间戳格式为 `0h00m00s000ms`,帧序号来自项目抽帧后的 1-based 顺序,不使用原视频帧号;包内固定包含 `annotations_coco.json`、`maskid_GT像素值_类别映射.json` 和 `原始图片/`;选择分开 mask 时包含按帧子目录组织且同类合并的 `分开Mask分割结果/`,选择 GT_label/Pro_label/Mix_label 时分别包含 `GT_label图/`、`Pro_label彩色分割结果/`、`Mix_label重叠覆盖彩色分割结果/`。GT_label 图中背景为 0,语义类别值使用类别真实 maskid,缺失 maskid 的旧标注才补下一个可用正整数 |
|
||||
| “导入 GT Mask”按钮 | 真实可用 | 入口已从工作区顶栏移动到左侧工具栏“重叠区域去除”之后,使用紫色图标底色;选择图片后先弹出导入结果预览和未知 maskid 策略选择,可舍弃未知类别或导入为未定义类别;随后调用 `POST /api/ai/import-gt-mask`,后端支持二值 mask、低数值/16-bit GT_label 图和 RGB 三通道完全相同的 `[X,X,X]` maskid 图,不符合灰度/maskid 图要求时返回错误;尺寸不同会自动最近邻拉伸到当前帧,再按类别/连通域生成 polygon 标注与距离变换 seed point,最后回显到工作区 |
|
||||
| 参考帧/起止帧/传播权重/自动传播 | 真实可用 | 当前打开帧即参考帧,前端会使用该帧全部 mask 作为 seed;工作区顶栏有独立“传播权重”下拉,可在传播前二次选择 SAM 2.1 tiny/small/base+/large 权重,不提供 SAM2/SAM3 家族切换,不影响 AI 智能分割页的单帧推理权重选择;传播权重下拉使用深色背景和青色文字,避免默认灰底白字不可读;如果用户尚未显式设置范围,点击“自动传播”会先进入时间轴范围选择模式,播放进度条和视频处理进度条都可点击/拖拽回填传播起始帧和传播结束帧,再点击“开始传播”提交;用户也可直接改数字框后点击按钮传播。提交后前端把传播权重 id、seed mask、seed 来源 id、未编辑传播结果的原始 seed 签名和前/后方向步骤提交到 `POST /api/ai/propagate/task`,后端先规范化/校验权重 id,再创建 `processing_tasks` 并由 Celery 执行对应 SAM 2.1 video predictor;worker 会在本次目标帧段内按 seed 来源和几何/语义签名做幂等判断,未改变且目标帧已有结果的 seed 直接跳过,已改变、目标帧只部分覆盖或换权重时会先删除本次目标帧段内同源旧自动传播标注再重新传播;历史或外部 seed 若仍带边缘平滑参数,后端仍按完整签名兼容处理;当前前端平滑应用会直接改写 polygon,因此传播以新几何参与签名;中间帧人工新增/修改同一物体后重新传播时,后端会按语义和目标帧空间重叠清理旧传播结果,写入前清理不受旧结果 `propagation_direction` 限制,避免 backward 重传时与旧 forward mask 重叠;传播中顶栏显示任务进度、已处理帧次、删除旧区域数和已保存区域数,前端轮询 `GET /api/tasks/{task_id}` 并刷新已保存标注;任务可取消,若完成后 0 个新区域会明确提示没有生成新 mask 或已跳过未改变 mask |
|
||||
| 清空片段遮罩 | 真实可用 | 点击“清空片段遮罩”后会进入和自动传播一致的时间轴范围选择模式,用户可在播放进度条或视频处理进度条上点击/拖拽选择起止帧;顶栏提供“清空全部”和“保留人工/AI”两种模式,默认清空全部以保持旧行为;“清空全部”会删除该帧段内所有本地 draft mask,并对已保存 mask 调用 `DELETE /api/ai/annotations/{annotation_id}`,若范围内存在人工绘制或 AI 智能分割生成的红色“人工/AI 标注帧”会先弹出确认;“保留人工/AI”只删除自动传播/推理 mask,不弹出人工帧确认,人工/AI 标注帧、范围外 mask 和未被清空的选区会保留;同时按清空范围裁剪当前会话的自动传播历史条,避免已清空片段仍显示最近传播进度 |
|
||||
| 保存状态按钮 | 真实可用 | 顶栏按钮按当前项目待保存数量显示为“保存 X 个改动”或“已全部保存”;未保存 mask 写入 `POST /api/ai/annotate`,dirty mask 写入 `PATCH /api/ai/annotations/{id}`;保存成功后会重新拉取后端标注,并用 saved annotation 替换本次提交的 draft mask,避免仍显示未保存 |
|
||||
|
||||
## CanvasArea 画布
|
||||
|
||||
@@ -83,14 +95,14 @@
|
||||
| 正向/反向选点 | 真实可用 | UI 能加点,并按当前帧 `frame.id` 调用 `/api/ai/predict`;结果需点击归档保存才持久化 |
|
||||
| 框选 | 真实可用 | UI 能画框,并把框坐标归一化后调用后端推理;结果需点击归档保存才持久化 |
|
||||
| AI 推理中提示 | 真实可用 | 请求期间会显示 |
|
||||
| 手工多边形/矩形/圆/点/线 | 真实可用 | 多边形点击取点后可按 Enter 完成,也可在三点后点击首节点闭合;矩形/圆/线拖拽生成 polygon;点工具生成小区域;绘制工具可在已有 mask 上继续落点;均写入 `Mask.segmentation`,可归档保存 |
|
||||
| 画布上下文提示 | 真实可用 | 切换到多边形、矩形、圆、线、点、正/反向选点、框选、区域合并/去除、调整多边形等隐性操作工具时,画布左上角显示当前工具的完成/取消/选择顺序提示;提示会在数秒后自动隐藏,避免长期遮挡待编辑图像,工具或操作状态变化时会重新出现 |
|
||||
| Mask 渲染 | 真实可用 | 前端会把推理、手工绘制、GT 导入和已保存标注转成 Konva `pathData` 渲染 |
|
||||
| Mask 透明度 | 真实可用 | 右侧语义分类树上方的“遮罩透明度”滑杆写入全局 `maskPreviewOpacity`,Canvas 使用该值调整所有工作区 mask 预览透明度,选中 mask 会在该基础上略微加亮 |
|
||||
| 手工多边形/矩形/圆/画笔/橡皮擦 | 真实可用 | 多边形点击取点后可按 Enter 完成,也可在三点后点击首节点闭合;矩形/圆拖拽生成 polygon;画笔按当前语义分类生成连续圆形笔触并在松开时 union 成 mask,若与选中 mask 连通则自动合并;橡皮擦从选中 mask 中扣除笔触区域;均写入 `Mask.segmentation`,可归档保存 |
|
||||
| 画布上下文提示 | 真实可用 | 切换到多边形、矩形、圆、画笔、橡皮擦、区域合并/去除、调整多边形等隐性操作工具时,画布左上角显示当前工具的完成/取消/选择顺序提示;提示会在数秒后自动隐藏,避免长期遮挡待编辑图像,工具或操作状态变化时会重新出现 |
|
||||
| Mask 渲染 | 真实可用 | 前端会把推理、手工绘制、GT 导入和已保存标注转成 Konva `pathData` 渲染;未选中特定 mask 时,当前帧 mask 会按右侧“语义分类树”拖拽得到的内部覆盖优先级从低到高渲染,使高优先级类别显示在上层;有选中 mask 时保留编辑态置顶行为,方便操作 |
|
||||
| Mask 透明度 | 真实可用 | 右侧语义分类树上方的“遮罩透明度”滑杆写入全局 `maskPreviewOpacity`,工作区 Canvas 和 AI 智能分割页都会使用该值调整 mask 预览透明度,选中 mask 会在该基础上略微加亮 |
|
||||
| 传播链跨帧选区跟随 | 真实可用 | 用户选中某个 mask 后切到同一自动传播结果覆盖的其他帧时,`CanvasArea` 会根据 `source_annotation_id`、`source_mask_id` 和 `propagation_seed_key` 查找目标帧对应传播 mask 并自动选中;找不到同链结果时才清空选区 |
|
||||
| Polygon 逐点编辑 / 删除 | 真实可用 | 点击 mask 后显示 polygon 顶点;按住顶点即可直接拖动并实时重算 `pathData/segmentation/bbox/area`,不需要先单击选中顶点,已保存 mask 标为 dirty;顶点拖拽结束不会触发 Stage 平移,Canvas 当前缩放和位置保持不变;选中顶点后 Delete/Backspace 可删点但保留至少三点;选中 mask 但未选中顶点时 Delete/Backspace 删除整个 mask,已保存 mask 会同步调用后端删除 |
|
||||
| GT seed point 回显/编辑 | 真实可用 | 已保存标注的 `points` 会显示为黄色 seed 点;拖动后标记为 dirty,归档保存会更新后端 |
|
||||
| 应用分类 | 真实可用 | Canvas 右下角按钮可将当前选择的模板分类应用到本帧 mask;右侧语义分类树点击分类时会优先改当前已选 mask,并把已选 mask 移到前端渲染最上层方便继续编辑;已保存 mask 会标为 dirty,归档保存时更新后端 |
|
||||
| 应用分类 | 真实可用 | Canvas 右下角按钮可将当前选择的模板分类应用到本帧 mask,并同步同一传播链前后帧的对应 mask;右侧语义分类树点击分类时会优先改当前已选 mask,并通过 `source_annotation_id`、`source_mask_id` 和 `propagation_seed_key` 同步更新同一传播链上的前后传播 mask,同时把已选 mask 移到前端渲染最上层方便继续编辑;已保存 mask 会标为 dirty,归档保存时更新后端 |
|
||||
| 清空遮罩 | 真实可用 | 工作区中会删除当前帧已保存标注并清空当前帧本地 mask |
|
||||
| 保存状态计数 | 真实可用 | 底部显示已保存、未保存、待更新数量 |
|
||||
| 当前图层信息 | 真实可用 | 根据当前选中 mask 显示真实标签/后端 annotation id;未保存 mask 显示“未保存”,未选中时显示“未选择” |
|
||||
@@ -99,13 +111,16 @@
|
||||
|
||||
| 元素 | 状态 | 说明 |
|
||||
|------|------|------|
|
||||
| 工具分组分隔线 | 真实可用 | 拖拽/选择到创建圆为绘制/基础编辑组,画笔/橡皮擦/区域合并/重叠区域去除为局部修补与布尔编辑组,导入 GT Mask 和 AI 智能分割为外部动作组;组间使用浅灰横线分隔,便于快速扫视 |
|
||||
| 拖拽/选择 | 真实可用 | 控制 Canvas 是否可拖拽 |
|
||||
| 调整多边形 | 真实可用 | 选中 polygon mask 后显示顶点和边中点;支持按住顶点直接拖动、点击边中点插点、双击边界按位置插点 |
|
||||
| 多边形/矩形/圆/点/线 | 真实可用 | 切换 activeTool 后由 `CanvasArea` 生成可保存的 polygon mask |
|
||||
| 多边形/矩形/圆/画笔/橡皮擦 | 真实可用 | 切换 activeTool 后由 `CanvasArea` 生成或编辑可保存的 polygon mask;画笔/橡皮擦在工具栏显示尺寸滑杆 |
|
||||
| 区域合并/去除 | 真实可用 | 选择工具后点击多个 mask,右下角显示已选数量和操作按钮;合并/去除模式会隐藏 polygon 编辑手柄,避免手柄抢占多选点击;布尔选择态中第一个选中的主区域用黄色实线轮廓,后续参与合并/扣除的区域用红色虚线轮廓,避免主区域和扣除区域看起来像随机阴影差异;使用 `polygon-clipping` 做 union / difference;合并会保留主 mask 并移除被合并 mask,去除会从主 mask 扣除后续选中 mask;内含扣除会保留 hole ring 并用 even-odd 规则渲染 |
|
||||
| 正向选点/反向选点/框选 | 部分可用 | 会影响 Canvas 交互,并能触发已对齐的 AI 推理接口;点击工作区内已有 SAM 提示点会优先删除该提示点并重新推理,不会冒泡成新增提示点或 mask 选择 |
|
||||
| 导入 GT Mask | 真实可用 | 位于“重叠区域去除”之后,点击后打开文件选择器,并在上传前选择未知类别处理策略;该入口不切换 activeTool |
|
||||
| 魔法棒 SAM 触发 | 部分可用 | 切到 AI 页面;不是直接执行推理 |
|
||||
| 撤销/重做 | 真实可用 | 绑定 Zustand `maskHistory/maskFuture`,支持工作区顶栏按钮、工具栏按钮、AI 页按钮和快捷键 `Ctrl/Cmd+Z`、`Ctrl/Cmd+Shift+Z`、`Ctrl/Cmd+Y`;输入框聚焦时不拦截快捷键 |
|
||||
| AI 正向选点/反向选点/框选 | 不在工作区工具栏显示 | 这些是 AI 智能分割页功能,工作区左侧工具栏不再提供正向选点、反向选点和边界框选按钮 |
|
||||
| AI 智能分割入口 | 真实可用 | 位于工作区工具栏底部,使用和侧栏一致的 Bot + Sparkles 组合图标;点击后切到 AI 智能分割页 |
|
||||
| 撤销/重做 | 真实可用 | 绑定 Zustand `maskHistory/maskFuture`,工作区只保留顶栏按钮和快捷键 `Ctrl/Cmd+Z`、`Ctrl/Cmd+Shift+Z`、`Ctrl/Cmd+Y`,AI 页保留自己的按钮;左侧工具栏不再重复放置撤销/重做;输入框聚焦时不拦截快捷键;工作区顶栏撤销图标使用琥珀色、重做图标使用蓝紫色,提高深色顶栏里的识别度 |
|
||||
| 紧凑/滚动布局 | 真实可用 | 工具按钮使用较紧凑的垂直间距;左侧高度不足时工具栏自身出现纵向滚动,不挤压画布;外层工具栏扩展到 56px,按钮列仍固定 48px,滚动条占用右侧外扩空间,不挤占图标位置;滚动条使用 `seg-scrollbar`,默认低对比融入深色工具区,hover/focus 时才增强为青色提示 |
|
||||
|
||||
## FrameTimeline 时间轴
|
||||
@@ -116,7 +131,7 @@
|
||||
| 点击缩略图跳帧 | 真实可用 | 调用 `setCurrentFrame(idx)`;非当前帧中,人工/AI 标注帧使用红色边框,自动传播/推理帧使用蓝色边框;同一帧同时有人工/AI 标注和自动传播结果时,红色标注边框优先保留,蓝色传播状态以内描边表达;当前帧仍用青色外框高亮优先,若当前帧同时是人工/AI 标注帧,则在青色外框内增加红色内描边,固定为外层当前帧、内层人工/AI 标注,避免状态颜色互相覆盖 |
|
||||
| 顶部 range 拖动 | 真实可用 | 改变当前帧 |
|
||||
| 具体时间显示 | 真实可用 | 根据项目 `parse_fps/original_fps` 显示当前时间和总时长,格式为 `mm:ss.cc` |
|
||||
| 播放进度条 / 视频处理进度条 | 真实可用 | 播放进度条位于上方,视频处理进度条位于下方;视频处理进度条普通状态下可点击跳转到对应帧;根据已保存标注回显的 `mask_data.source` / `propagated_from_frame_id` 识别自动传播生成的帧并显示蓝色区段,人工绘制或 AI 智能分割生成的帧显示红色竖线,红/蓝标识也可点击跳转到对应帧;每次自动传播成功处理帧后,工作区会在当前会话记录最近传播范围,并在视频处理进度条上叠加不同色系的深到浅渐变片段,辅助识别最近处理过的视频区间;未处理背景使用中性灰以和红/蓝/渐变标记区分;工作区进入自动传播或清空片段遮罩的范围选择模式时,两条进度条显示 amber 选区,并可点击/拖拽选择起止帧 |
|
||||
| 播放进度条 / 视频处理进度条 | 真实可用 | 播放进度条位于上方,视频处理进度条位于下方;当前帧位置用一条白色竖线贯穿两条进度条,避免和青色播放进度、红/蓝处理状态混淆;视频处理进度条普通状态下可点击跳转到对应帧;根据已保存标注回显的 `mask_data.source`、`propagated_from_frame_id`、`source_annotation_id`、`source_mask_id` 或 `propagation_seed_key` 识别自动传播生成的帧并显示蓝色区段,人工绘制或 AI 智能分割生成的帧显示红色竖线,红/蓝标识也可点击跳转到对应帧;每次自动传播成功处理帧后,工作区会在当前会话记录最近传播范围,并在视频处理进度条上叠加同一蓝色系的纯色片段,按距最新传播的时间顺序逐次变暗,且第 5 次及更早统一为阈值旧记录色,辅助识别第一次、第二次、第 N 次传播;清空片段遮罩会同步移除或裁剪与清空范围重叠的传播历史片段;未处理背景使用中性灰以和红/蓝/传播历史标记区分;工作区进入自动传播或清空片段遮罩的范围选择模式时,两条进度条显示 amber 选区,并额外用洋红色起始线和黄绿色结束线贯穿两条进度条,表示待处理起止帧,颜色避开附近的青色、红色、蓝色和 amber 元素 |
|
||||
| 播放/暂停 | 真实可用 | 当前代码按 `parse_fps/original_fps` 推进帧,最多 30fps |
|
||||
| 方向键切帧 | 真实可用 | 全局监听左右方向键切到上一帧/下一帧;焦点在 input、textarea、select 或 contentEditable 内时不会拦截 |
|
||||
|
||||
@@ -127,12 +142,12 @@
|
||||
| 模板选择 | 部分可用 | 读取全局 templates,可切换 activeTemplateId |
|
||||
| 面板滚动条 | 真实可用 | 右侧本体/语义分类面板内容过长时自身滚动;滚动条使用 `seg-scrollbar`,默认低对比融入深色侧栏,hover/focus 时才增强显示 |
|
||||
| 面板标题 | 已简化 | 原“本体论与属性分类管理树”固定说明栏已移除,右侧面板直接展示模板、透明度和语义分类树 |
|
||||
| 分类树展示 / 换标签 | 真实可用 | 显示当前模板 classes;点击分类会设为后续新 mask 的 activeClass,如果 Canvas 已选 mask,则同步更新已选 mask 的标签、颜色和 class 元数据,并把已选 mask 移到前端渲染最上层;当用户在 Canvas 点击已有 mask 时,本面板会按 mask 的 class id / 名称自动切换模板、设置 active class,并滚动/聚焦到对应分类按钮 |
|
||||
| 分类树展示 / 换标签 | 真实可用 | 显示当前模板 classes;点击分类会设为后续新 mask 的 activeClass,如果 Canvas 已选 mask,则同步更新已选 mask 及同一传播链前后帧对应 mask 的标签、颜色和 class 元数据,并把已选 mask 移到前端渲染最上层;当用户在 Canvas 点击已有 mask 时,本面板会按 mask 的 class id / 名称自动切换模板、设置 active class,并滚动/聚焦到对应分类按钮 |
|
||||
| 添加自定义分类 | 真实可用 | 需要先选择模板;新增分类通过 `PATCH /api/templates/{id}` 写入后端模板 `mapping_rules.classes`,并同步全局模板 store |
|
||||
| 目标实例属性标题 | 真实可用 | “特定目标实例属性追踪”下方显示当前选中 mask 的 `className/label`,不再跟随全局 active class,避免点过其他分类后标题固定成旧分类 |
|
||||
| 后端拓扑锚点数量 | 真实可用 | 选中 mask 后调用 `POST /api/ai/analyze-mask`,由后端根据 seed points 或 polygon 顶点采样返回锚点数量 |
|
||||
| 重新提取拓扑锚点按钮 | 真实可用 | 调用 `POST /api/ai/analyze-mask` 并带 `extract_skeleton=true`,刷新后端几何锚点统计 |
|
||||
| 边缘平滑强度 / 应用边缘平滑 | 真实可用 | 选中 mask 后可调整 0-100 的平滑强度并调用 `POST /api/ai/smooth-mask`;后端用 Chaikin polygon smoothing 返回新 polygon、bbox、面积和拓扑锚点,前端把当前 mask 标记为 dirty/draft,保存后 `geometry_smoothing` 写入标注 metadata;自动传播 seed 会携带同一参数,前/后传播结果保存前应用一致平滑 |
|
||||
| 当前选中区域计数 | 已移除 | 当前交互以单选 mask 为主,计数长期为 1,属于低价值信息,已从实例属性面板删除 |
|
||||
| 后端拓扑锚点数量 | 真实可用 | 选中 mask 后调用 `POST /api/ai/analyze-mask`,后端按 polygon 的真实顶点数量返回 `topology_anchor_count`;`topology_anchors` 列表只保留最多 64 个抽样点用于调试展示,避免把真实数量误压成十几个;前端会忽略被浏览器中止或已过期的分析请求,避免切换 mask、拖动平滑预览或卸载组件时出现误报 |
|
||||
| 边缘平滑强度 / 应用边缘平滑 | 真实可用 | 选中 mask 后调整 0-100 平滑强度会先即时更新滑杆数值,再在用户停止拖动约 220ms 后调用 `POST /api/ai/smooth-mask` 生成预览 polygon,避免拖动时连续请求导致卡顿;预览会临时替换当前 mask 显示但不标 dirty;点“应用边缘平滑”后会把平滑 polygon 作为新的实际 mask 几何写入当前 mask 和同传播链前/后对应 mask,整次应用进入同一个撤销/重做历史步骤,并把相关 mask 标记为 dirty/draft;传播链上的 mask 保存时会保留原传播 lineage metadata,不会因为平滑几何同步而在时间轴上变成人工/AI 红色标注帧;应用后平滑强度重置为 0,后续可继续用“调整多边形”编辑新的 polygon;后端平滑使用缓入强度曲线,低强度只做温和切角和轻量去噪,高强度才逐步增加 Chaikin 迭代、切角比例和简化阈值,避免 20% 前后已经过度平滑 |
|
||||
|
||||
## AISegmentation 独立 AI 页
|
||||
|
||||
@@ -145,9 +160,9 @@
|
||||
| SAM 3 入口 | 当前禁用 | 因当前系统不提供文本提示,前端不再显示 SAM 3 模型选择、文本输入或 SAM 3 框选入口;后端 `model=sam3` 返回不支持 |
|
||||
| 语义文本输入 | 当前禁用 | AI 页不再提供文本语义输入;后端收到 `semantic` prompt 会返回 400 |
|
||||
| 参数开关 | 真实可用 | UI 展示为“局部专注模式(自动裁剪无锚区域)”和“严格除杂模式(自动清理干涉点)”,只是为了让用户更容易理解,不重命名内部字段;`cropMode` 会随 `/api/ai/predict` 发送 `crop_to_prompt`,后端对点/框 prompt 裁剪推理区域并回映射 polygon;`autoDeleteBg` 会发送 `auto_filter_background` 和 `min_score`,后端过滤低分结果和覆盖负向点的结果 |
|
||||
| 遮罩清晰度 | 真实可用 | 调节 AI 页候选 mask 的预览透明度,只影响本页显示,不改变 mask 几何、分类或保存数据 |
|
||||
| AI 遮罩透明度 | 真实可用 | 调节共享的 `maskPreviewOpacity`,AI 页候选 mask 和右侧“遮罩透明度”滑杆联动,只影响预览显示,不改变 mask 几何、分类或保存数据 |
|
||||
| 执行高精度语义分割 | 真实可用 | 使用当前项目帧和所选 SAM 2.1 变体调用 `/api/ai/predict`;SAM 2.1 需要点/框提示且只采用最高分候选;AI 页只渲染本页最新候选,不显示工作区已有 mask,重复执行会替换上一次 AI 页候选而不是叠加;生成结果写入全局 masks 并自动选中,右侧分类树可立即换标签 |
|
||||
| 推送至工作区编辑 | 真实可用 | 切回工作区并把工具切到“调整多边形”,保留 AI 页选中的未保存 mask 和当前帧视角;工作区回显后端标注时不会覆盖这类 draft mask,也不会强制跳回第一帧 |
|
||||
| 推送至工作区编辑 | 真实可用 | 切回工作区并把工具切到“调整多边形”,保留 AI 页选中的未保存 mask 和当前帧视角;推送前会校验当前 AI 候选 mask 必须已有 `classId` 或 `className`,未选择语义分类时会用右上角 error toast 提示用户先点右侧语义分类树,不允许进入工作区;如果用户直接离开 AI 页,未分类 AI 候选会被清理,避免无语义 mask 进入工作区;工作区回显后端标注时不会覆盖这类 draft mask,也不会强制跳回第一帧 |
|
||||
| 撤销/重做 | 真实可用 | 绑定全局 mask 历史栈 |
|
||||
| 删除最近锚点 | 真实可用 | 删除 AI 页最近一次放置的正/反向提示点,不影响已生成候选 mask 或工作区 mask |
|
||||
| 删除选中候选 | 真实可用 | 删除 AI 页当前选中的本页候选 mask;不会删除工作区已有 mask,Delete/Backspace 也遵循同一范围 |
|
||||
@@ -164,13 +179,13 @@
|
||||
| 编辑模板 | 真实可用 | 调用 `PATCH /api/templates/{id}` |
|
||||
| 删除模板 | 真实可用 | 调用 `DELETE /api/templates/{id}` |
|
||||
| 添加/删除分类 | 真实可用 | 保存在模板 `mapping_rules.classes` |
|
||||
| 拖拽排序 | 真实可用 | 重算 zIndex,保存时写后端 |
|
||||
| 拖拽排序 | 真实可用 | 模板库和工作区右侧语义分类树都可拖拽调整内部覆盖优先级,保存时写后端;工作区拖拽会同步当前同类 mask 的 `classZIndex` 并标记待保存;界面只显示类别稳定 maskid,maskid 不作为排序规范 |
|
||||
| JSON 批量导入 | 部分可用 | 前端解析 JSON 并加入编辑态,保存后才落库 |
|
||||
| 载入腹腔镜 35 分类 | 真实可用 | 前端内置数据;后端也 seed 默认模板 |
|
||||
| mapping rules | 部分可用 | 可存 `rules`,但当前没有运行时映射执行引擎;适合后续用于导入外部标签、别名归一化或跨数据集类别映射 |
|
||||
|
||||
## 总体结论
|
||||
|
||||
当前前端真实可用的主链路是:登录、Dashboard 后端概览、项目列表、新建项目、上传视频/DICOM、显式生成帧、浏览帧、播放帧、工作区手工绘制、点/框 AI 推理、视频片段传播、GT mask 导入、标注保存/回显、COCO 导出、PNG mask ZIP 导出、模板 CRUD。
|
||||
当前前端真实可用的主链路是:JWT 登录、刷新恢复用户、退出登录、Dashboard 当前用户概览、当前用户项目列表、新建项目、上传视频/DICOM、显式生成帧、浏览帧、播放帧、工作区手工绘制、点/框 AI 推理、视频片段传播、GT mask 导入、标注保存/回显、COCO 导出、PNG mask ZIP 导出、模板 CRUD。
|
||||
|
||||
当前最主要的 Mock 或未打通链路是:真正的文本语义分割已因无文本提示入口而暂时禁用;复杂洞结构编辑、骨架/HDBSCAN 级别的 mask 降维增强、任务历史筛选、项目更多菜单和 mapping rules 运行时映射执行引擎仍未落地。登录页“安全审计说明文字”仍只是 UI 文案。
|
||||
当前最主要的 Mock 或未打通链路是:真正的文本语义分割已因无文本提示入口而暂时禁用;复杂洞结构编辑、骨架/HDBSCAN 级别的 mask 降维增强、任务历史筛选、项目更多菜单、全业务操作审计和 mapping rules 运行时映射执行引擎仍未落地。登录页“端到端加密”等安全文案仍只是 UI 文案;登录和用户管理操作审计已落库并可在管理员后台查看。
|
||||
|
||||
@@ -15,13 +15,14 @@ timeout: 30000
|
||||
Authorization: Bearer <token>
|
||||
```
|
||||
|
||||
当前后端多数接口没有鉴权依赖,所以这个 header 主要是前端侧行为。
|
||||
当前后端业务接口会校验该 header。缺失、过期或无效 token 返回 401;项目、帧、标注、任务、Dashboard 和导出会按当前用户拥有的项目过滤。
|
||||
|
||||
## 前端封装的 API
|
||||
|
||||
| 函数 | 方法与路径 | 状态 | 说明 |
|
||||
|------|------------|------|------|
|
||||
| `login(username, password)` | `POST /api/auth/login` | 对齐 | 后端返回 `{ token, username }`,前端只使用 token |
|
||||
| `login(username, password)` | `POST /api/auth/login` | 对齐 | 后端返回 `{ token, token_type, username, user }`,前端保存 token 和当前用户 |
|
||||
| `getCurrentUser()` | `GET /api/auth/me` | 对齐 | 用已有 Bearer token 恢复当前登录用户 |
|
||||
| `getProjects()` | `GET /api/projects` | 对齐 | 前端映射 `frame_count`、`thumbnail_url` 等字段 |
|
||||
| `createProject(payload)` | `POST /api/projects` | 对齐 | 支持 `name`、`description`、`parse_fps` |
|
||||
| `updateProject(id, payload)` | `PATCH /api/projects/{id}` | 对齐 | 后端是 `PATCH /api/projects/{id}` |
|
||||
@@ -46,10 +47,11 @@ Authorization: Bearer <token>
|
||||
| `saveAnnotation(payload)` | `POST /api/ai/annotate` | 对齐 | 工作区归档保存当前项目未保存 mask |
|
||||
| `updateAnnotation(annotationId, payload)` | `PATCH /api/ai/annotations/{annotation_id}` | 对齐 | 工作区归档保存 dirty mask |
|
||||
| `deleteAnnotation(annotationId)` | `DELETE /api/ai/annotations/{annotation_id}` | 对齐 | 工作区清空当前帧已保存标注 |
|
||||
| `importGtMask(file, projectId, frameId, templateId?)` | `POST /api/ai/import-gt-mask` | 对齐 | multipart 上传 GT mask,后端按非零像素值/连通域生成 polygon 标注和 seed point |
|
||||
| `importGtMask(file, projectId, frameId, templateId?, options?)` | `POST /api/ai/import-gt-mask` | 对齐 | multipart 上传 GT mask;支持 `unknown_color_policy=discard/undefined`;后端仅接受灰度 maskid 图或 RGB 三通道完全相同的 `[X,X,X]` maskid 图,0 为背景、X 为 maskid;按模板 `maskId` 匹配类别,未知 maskid 可舍弃或导入为未定义类别;尺寸不同会最近邻拉伸到当前帧,连通域会生成 polygon 标注和 seed point |
|
||||
| `getDashboardOverview()` | `GET /api/dashboard/overview` | 对齐 | Dashboard 初始统计、队列和活动日志 |
|
||||
| `exportCoco(projectId)` | `GET /api/export/{projectId}/coco` | 对齐 | 后端实际是 `GET /api/export/{project_id}/coco` |
|
||||
| `exportMasks(projectId)` | `GET /api/export/{projectId}/masks` | 对齐 | 下载单标注 mask、语义融合 mask 和类别映射 ZIP |
|
||||
| `exportSegmentationResults(projectId, options)` | `GET /api/export/{projectId}/results` | 对齐 | 新的统一导出入口;支持 `scope=all/range/current`、`outputs=separate,gt_label,pro_label,mix_label`、`mix_opacity`、`start_frame/end_frame` 和 `frame_id` 参数,返回包含 COCO JSON、maskid/GT 像素值映射、原始帧图片和所选 mask PNG 的 ZIP;`mask_type=separate/gt_label/pro_label/mix_label/both` 仍兼容 |
|
||||
|
||||
## 后端 FastAPI 接口
|
||||
|
||||
@@ -58,6 +60,10 @@ Authorization: Bearer <token>
|
||||
| 方法 | 路径 | 用途 |
|
||||
|------|------|------|
|
||||
| POST | `/api/auth/login` | 登录 |
|
||||
| GET | `/api/auth/me` | 当前用户 |
|
||||
| GET/POST/PATCH/DELETE | `/api/admin/users` | 管理员用户管理 |
|
||||
| GET | `/api/admin/audit-logs` | 管理员审计日志 |
|
||||
| POST | `/api/admin/demo-factory-reset` | 演示部署恢复出厂设置;请求体需 `confirmation=RESET_DEMO_FACTORY` |
|
||||
| POST | `/api/projects` | 创建项目 |
|
||||
| GET | `/api/projects` | 项目列表 |
|
||||
| GET | `/api/projects/{project_id}` | 项目详情 |
|
||||
@@ -92,6 +98,7 @@ Authorization: Bearer <token>
|
||||
| GET | `/api/dashboard/overview` | Dashboard 聚合快照 |
|
||||
| GET | `/api/export/{project_id}/coco` | 导出 COCO JSON |
|
||||
| GET | `/api/export/{project_id}/masks` | 导出 PNG mask ZIP |
|
||||
| GET | `/api/export/{project_id}/results` | 统一导出分割结果 ZIP,包含 `annotations_coco.json`、`maskid_GT像素值_类别映射.json`、`原始图片/` 和按参数选择的 `分开Mask分割结果/`、`GT_label图/`、`Pro_label彩色分割结果/`、`Mix_label重叠覆盖彩色分割结果/`;GT_label 背景为 0,类别值使用模板中的真实 maskid,缺失 maskid 的旧标注才补下一个可用正整数 |
|
||||
| GET | `/health` | 健康检查 |
|
||||
| WS | `/ws/progress` | WebSocket 进度通道,未出现在 OpenAPI paths 中 |
|
||||
|
||||
@@ -163,6 +170,7 @@ POST /api/media/parse?project_id=1&parse_fps=15&max_frames=120&target_width=960
|
||||
"name": "胆囊",
|
||||
"color": "#ffae00",
|
||||
"zIndex": 280,
|
||||
"maskId": 1,
|
||||
"category": "腹腔镜胆囊切除术"
|
||||
}
|
||||
],
|
||||
@@ -250,7 +258,7 @@ SAM 2 点提示和 auto fallback 当前只采用最高分候选 mask,避免同
|
||||
"bbox": [0.1, 0.1, 0.2, 0.2],
|
||||
"label": "胆囊",
|
||||
"color": "#ff0000",
|
||||
"class_metadata": {"id": "c1", "name": "胆囊", "color": "#ff0000", "zIndex": 20},
|
||||
"class_metadata": {"id": "c1", "name": "胆囊", "color": "#ff0000", "zIndex": 20, "maskId": 1},
|
||||
"template_id": 2
|
||||
}
|
||||
}
|
||||
@@ -293,7 +301,7 @@ SAM 2.1 变体使用对应 video predictor 的 mask seed 传播;`model=sam2`
|
||||
- `getProjectAnnotations()` 已接入 `GET /api/ai/annotations`。
|
||||
- `updateAnnotation()` 已接入 `PATCH /api/ai/annotations/{annotationId}`。
|
||||
- `deleteAnnotation()` 已接入 `DELETE /api/ai/annotations/{annotationId}`。
|
||||
- `importGtMask()` 已接入 `POST /api/ai/import-gt-mask`,导入后端生成的 polygon 标注、原始 `gt_label_value` 和 seed point。
|
||||
- `importGtMask()` 已接入 `POST /api/ai/import-gt-mask`,导入后端生成的 polygon 标注、原始 `gt_label_value`、原图尺寸/是否拉伸信息和 seed point。导入端使用 `cv2.IMREAD_UNCHANGED` 保留低数值/16-bit GT_label 图的像素值;灰度图和 RGB 三通道相等图均按模板 `maskId` 匹配类别,不再按彩色图 RGB 颜色匹配类别;超出现有类别时由 `unknown_color_policy` 决定舍弃或写为 `gt_unknown_class` 未定义类别。
|
||||
- `exportMasks()` 已接入 `GET /api/export/{projectId}/masks`。
|
||||
- `parseMedia()` 已改为创建 Celery 后台任务,并返回 `ProcessingTask`。
|
||||
- `queuePropagationTask()` 已接入 `/api/ai/propagate/task`,自动传播不再依赖长时间同步 HTTP 请求。
|
||||
@@ -302,8 +310,9 @@ SAM 2.1 变体使用对应 video predictor 的 mask seed 传播;`model=sam2`
|
||||
- `retryTask()` 已接入 `POST /api/tasks/{taskId}/retry`。
|
||||
- `getDashboardOverview()` 已从 `processing_tasks` 聚合解析队列。
|
||||
- Dashboard 任务列表已展示 queued/running/success/failed/cancelled 任务,并可通过 `getTask()` 查看失败详情;`summary.parsing_task_count` 仍只统计 queued/running。
|
||||
- 工作区导出按钮已调用 `exportCoco()` / `exportMasks()`,并会先保存未归档 mask。
|
||||
- 工作区“分割结果导出”已调用 `exportSegmentationResults()`,并会先保存未归档 mask;旧的 `exportCoco()` / `exportMasks()` 仍保留为兼容接口。
|
||||
- PNG mask ZIP 已包含每帧 `semantic_frame_*.png` 和 `semantic_classes.json`,重叠区域按 zIndex 裁决。
|
||||
- 统一导出 ZIP 下载文件名为 `{项目库项目名}_seg_T_{起始时间戳}-{结束时间戳}_P_{起始项目帧序号}-{结束项目帧序号}.zip`;项目名来自 `Project.name` 并会替换文件系统不安全字符,时间戳来自帧 `timestamp_ms` 并格式化为 `0h00m00s000ms`,帧号使用项目抽帧后的 1-based `frame_index + 1`,不使用原视频 `source_frame_number`。ZIP 内包含 `annotations_coco.json`、`maskid_GT像素值_类别映射.json` 和 `原始图片/`。原始图片按 `视频名称_时间戳_项目帧序号` 命名;选择分开 mask 时写入 `分开Mask分割结果/{视频名称_时间戳_项目帧序号}_分别导出/{视频名称_时间戳_项目帧序号}_{类别名称}_maskid{maskid}.png`,同一帧同一类别会合并为一张二值 mask;选择 GT_label 图时写入 `GT_label图/{视频名称_时间戳_项目帧序号}.png`;选择 Pro_label 彩色图时写入 `Pro_label彩色分割结果/{视频名称_时间戳_项目帧序号}.png`;选择 Mix_label 叠加图时写入 `Mix_label重叠覆盖彩色分割结果/{视频名称_时间戳_项目帧序号}.png`,透明度由 `mix_opacity` 控制,默认 0.3。导出时 maskid 与 GT_label 像素值相同;有模板 maskid 的类别保留真实 maskid,缺失 maskid 的旧标注补下一个可用正整数并写入映射 JSON,跨图一致;maskid 不参与覆盖排序,覆盖顺序仍使用内部拖拽排序字段。
|
||||
|
||||
## 仍需处理的接口问题
|
||||
|
||||
|
||||
@@ -26,7 +26,7 @@
|
||||
已完成:
|
||||
|
||||
1. 前端根据 `Mask.segmentation` 构造后端需要的 normalized `mask_data.polygons`。
|
||||
2. 用户点击“结构化归档保存”后,未保存 mask 调用 `POST /api/ai/annotate`,dirty mask 调用 `PATCH /api/ai/annotations/{annotation_id}`。
|
||||
2. 用户点击顶栏保存状态按钮后,未保存 mask 调用 `POST /api/ai/annotate`,dirty mask 调用 `PATCH /api/ai/annotations/{annotation_id}`;按钮文案会按待保存数量显示“保存 X 个改动”或“已全部保存”。
|
||||
3. 后端保存或更新 `project_id`、`frame_id`、`template_id`、`mask_data`、`bbox`;具体分类写入 `mask_data.class`。
|
||||
4. 工作区加载帧后调用 `GET /api/ai/annotations` 回显已保存标注。
|
||||
5. 工作区“清空遮罩”调用 `DELETE /api/ai/annotations/{annotation_id}` 删除当前帧已保存标注。
|
||||
@@ -37,15 +37,16 @@
|
||||
2. 逐点几何编辑器已支持拖动/删除顶点、边中点插入新点和多 polygon 子区域编辑;后续增强为复杂洞结构编辑。
|
||||
3. 区域合并/去除已支持基础 union/difference;后续增强为更明确的多选列表、操作预览和冲突确认。
|
||||
|
||||
## 阶段 3:接入导出按钮(已完成 COCO JSON 和 PNG Mask ZIP)
|
||||
## 阶段 3:接入导出按钮(已完成统一分割结果导出)
|
||||
|
||||
当前工作区“导出 JSON 标注集”和“导出 PNG Mask ZIP”都会先保存未归档 mask,再调用后端导出接口。
|
||||
当前工作区“分割结果导出”会先保存未归档 mask,再调用后端统一结果导出接口。旧 COCO JSON 和 PNG Mask ZIP 接口保留为兼容路径。
|
||||
|
||||
已完成:
|
||||
|
||||
1. COCO JSON 调用 `/api/export/{projectId}/coco`。
|
||||
2. PNG Mask ZIP 调用 `/api/export/{projectId}/masks`。
|
||||
3. ZIP 内保留单标注二值 `mask_*.png`,同时输出 `semantic_frame_*.png` 和 `semantic_classes.json`。
|
||||
3. 兼容 PNG Mask ZIP 仍保留单标注二值 `mask_*.png`,同时输出 `semantic_frame_*.png` 和 `semantic_classes.json`。
|
||||
4. 统一导出调用 `/api/export/{projectId}/results`,支持整体视频、特定范围帧、当前图片三种范围,以及分开 mask、GT_label 黑白图、Pro_label 彩色图和 Mix_label 原图叠加图;ZIP 固定包含 maskid/GT 像素值映射 JSON 和原始图片文件夹,各输出文件夹按客户指定的 `视频名称_0h00m00s000ms_项目帧序号` 规则命名;GT_label 图背景为 0,类别值优先使用模板中的真实 maskid,缺失 maskid 的旧标注才补下一个可用正整数。
|
||||
|
||||
剩余建议:
|
||||
|
||||
@@ -98,7 +99,7 @@ Word 方案中的完整版本包含距离变换、骨架提取和聚类。当前
|
||||
|
||||
已完成:
|
||||
|
||||
1. 工作区提供“导入 GT Mask”入口。
|
||||
1. 工作区左侧工具栏提供“导入 GT Mask”入口,位置在“重叠区域去除”之后。
|
||||
2. 前端调用 `POST /api/ai/import-gt-mask` multipart 接口。
|
||||
3. 后端按非零像素值拆分多类别 mask。
|
||||
4. 后端使用 OpenCV contour 提取每个类别下的连通域。
|
||||
@@ -117,10 +118,10 @@ Word 方案中的完整版本包含距离变换、骨架提取和聚类。当前
|
||||
|
||||
已完成:
|
||||
|
||||
1. 标注保存时记录 template class id / name / zIndex。
|
||||
2. 导出 mask 时按 zIndex 从低到高覆盖。
|
||||
1. 标注保存时记录 template class id / name / maskid,并保留内部覆盖优先级。
|
||||
2. 导出 mask 时按内部优先级从低到高覆盖。
|
||||
3. 同类语义值在融合图中共享同一个 class value。
|
||||
4. 跨类重叠由高 zIndex 覆盖低 zIndex。
|
||||
4. 跨类重叠由更高内部优先级覆盖更低内部优先级;maskid 不作为排序规范。
|
||||
|
||||
剩余建议:
|
||||
|
||||
|
||||
@@ -7,15 +7,21 @@
|
||||
## R1 登录与会话
|
||||
|
||||
- 系统提供登录页。
|
||||
- 默认开发凭证为 `admin / 123456`。
|
||||
- 登录成功后前端保存 token,并进入主应用。
|
||||
- 默认开发管理员为启动时种子化的 `admin / 123456`,密码以哈希形式存入 `users` 表。
|
||||
- 登录成功后前端保存签名 JWT,并进入主应用。
|
||||
- 页面刷新后前端会用已有 token 调用 `/api/auth/me` 恢复当前用户。
|
||||
- 登录失败时显示错误信息。
|
||||
- 当前 token 是开发用固定 token,不做真实 JWT 校验。
|
||||
- 业务接口必须校验 Bearer token;缺失或无效 token 返回 401。
|
||||
- 项目、帧、标注、任务、Dashboard 和导出必须按当前用户的项目隔离;用户不能读取、修改或删除其他用户项目资源。
|
||||
- 角色包括 `admin`、`annotator` 和 `viewer`;`admin/annotator` 可写入业务数据和触发 AI/传播,`viewer` 只能访问读接口,用户管理后台仅 `admin` 可用。
|
||||
- 管理员侧栏显示“用户管理”入口;管理员可以新增用户、修改角色、停用/启用、修改密码、删除无项目用户。
|
||||
- 系统记录登录成功/失败和用户管理操作到 `audit_logs`,管理员后台可查看最近审计日志。
|
||||
- 管理员后台提供“恢复演示出厂设置”危险操作;前端必须二次确认,后端也必须校验 `confirmation=RESET_DEMO_FACTORY`,执行后只保留默认 admin 账号、系统模板和一个尚未生成帧的演示视频项目,清空其它用户、项目、帧、标注、任务、用户模板和旧审计记录,并写入本次重置审计。
|
||||
|
||||
## R2 项目管理
|
||||
|
||||
- 前端展示项目库,并从 `GET /api/projects` 获取项目列表。
|
||||
- 用户可以新建项目,前端调用 `POST /api/projects`。
|
||||
- 用户可以新建项目,前端调用 `POST /api/projects`;后端把项目归属到当前登录用户。
|
||||
- 用户可以选择项目,进入工作区。
|
||||
- 用户可以导入视频文件,前端创建项目、上传文件并刷新项目列表;导入视频不自动拆帧。
|
||||
- 用户可以对已导入且尚未生成帧的视频项目点击“生成帧”,在弹窗中选择目标 FPS 后创建拆帧任务。
|
||||
@@ -48,28 +54,34 @@
|
||||
- 若项目有媒体但无帧,工作区只提示需要先在项目库生成帧,不再自动触发拆帧。
|
||||
- Canvas 显示当前帧图片。
|
||||
- Canvas 支持滚轮缩放、移动工具拖拽、鼠标坐标显示。
|
||||
- Canvas 未选中特定 mask 时,mask 显示顺序必须遵循右侧“语义分类树”拖拽得到的内部覆盖优先级:低优先级先渲染,高优先级后渲染并显示在上层;选中 mask 后可以为了编辑交互临时置顶。
|
||||
- 时间轴支持缩略图点击切帧、range 拖动切帧、视频处理进度条点击切帧、人工/AI 标注帧和自动传播帧标识点击切帧、键盘左右方向键切帧、播放/暂停顺序推进帧。
|
||||
- 清空片段遮罩进入范围选择后必须提供两种模式:`清空全部` 会清空范围内所有 mask,若包含人工绘制或 AI 智能分割生成的“人工/AI 标注帧”必须弹出“是否清除“人工/AI标注帧”?”确认;`保留人工/AI` 只清空范围内自动传播/推理 mask,人工/AI 标注帧必须保留且不弹出人工帧确认;用户取消确认时不能删除本地 mask、后端标注或传播历史条。
|
||||
- 用户在某帧选中 mask 后,如果切换到同一自动传播结果覆盖的其他帧,工作区应自动识别并选中目标帧中对应的传播 mask;匹配依据为传播结果回显到 mask metadata 的 seed 来源和传播链字段,而不是仅凭标签或颜色。
|
||||
- 播放帧率使用项目 `parse_fps` 或 `original_fps`,限制在 1 到 30 FPS。
|
||||
- 时间轴显示当前帧时间和总时长,时间基准使用项目 `parse_fps` 或 `original_fps`,格式为 `mm:ss.cc`。
|
||||
- 时间轴顶部播放进度条只表达当前播放位置;其下方的视频处理进度条表达处理状态:人工绘制或 AI 智能分割生成的帧显示红色竖线,自动传播生成的帧显示蓝色区段,最近自动传播处理过的片段叠加不同色系的横向渐变条,片段内部随时间从深到浅,帮助识别最近处理范围;未处理背景使用中性灰以和标记保持明显区分。底部帧可视化栏中,人工/AI 标注帧缩略图边框为红色,自动传播/推理帧缩略图边框为蓝色,当前帧仍用青色外框高亮优先;如果同一帧既有人工/AI 标注又有自动传播结果,红色人工/AI 标注框优先保留,自动传播状态只作为蓝色内描边或次级提示;如果当前帧同时是人工/AI 标注帧,则显示青色外框加红色内描边,外层选中框和内层标注框顺序不能交换。
|
||||
- 时间轴顶部播放进度条只表达当前播放位置;其下方的视频处理进度条表达处理状态:当前帧位置用白色竖线贯穿播放进度条和视频处理进度条;人工绘制或 AI 智能分割生成的帧显示红色竖线,自动传播生成的帧显示蓝色区段,最近自动传播处理过的片段叠加同一蓝色系纯色条,按距最新传播的时间顺序逐次变暗,且第 5 次及更早统一为阈值旧记录色,帮助识别第一次、第二次、第 N 次传播;清空片段遮罩后,与清空范围重叠的最近传播历史条必须同步移除或裁剪,不应继续显示已经清空的传播范围;未处理背景使用中性灰以和标记保持明显区分。进入自动传播或清空遮罩范围选择时,起始帧和结束帧必须额外显示两条贯穿两条进度条的高对比边界线,颜色避开青色播放进度、红色标注、蓝色传播、amber 选区和深色背景。底部帧可视化栏中,人工/AI 标注帧缩略图边框为红色,自动传播/推理帧缩略图边框为蓝色,当前帧仍用青色外框高亮优先;如果同一帧既有人工/AI 标注又有自动传播结果,红色人工/AI 标注框优先保留,自动传播状态只作为蓝色内描边或次级提示;如果当前帧同时是人工/AI 标注帧,则显示青色外框加红色内描边,外层选中框和内层标注框顺序不能交换。
|
||||
- 自动传播提交前支持独立选择传播权重,范围限定为 SAM 2.1 tiny/small/base+/large 四个权重变体;该选择只影响传播任务,不提供 SAM2/SAM3 家族切换,也不改变 AI 智能分割页的单帧推理权重。
|
||||
|
||||
## R5 工具栏
|
||||
|
||||
- 工具栏可以切换当前 active tool。
|
||||
- 正向点、反向点、框选工具会影响 Canvas 交互。
|
||||
- 工作区左侧工具栏不展示正向点、反向点、框选工具;这些入口只属于 AI 智能分割页。
|
||||
- 侧栏“AI智能分割”和工作区工具栏 AI 跳转入口必须使用带明确 AI 语义的图标,而不是普通魔法棒等泛化工具图标。
|
||||
- 魔法棒按钮切换到 AI 页面。
|
||||
- 多边形、矩形、圆、点、线工具会在 Canvas 上生成可保存的 polygon mask。
|
||||
- 多边形通过点击取点并按 Enter 完成,也支持三点后点击首节点闭合;矩形、圆、线通过拖拽生成;点工具生成小点区域。
|
||||
- 多边形、矩形、圆、画笔、橡皮擦工具会在 Canvas 上生成或编辑可保存的 polygon mask;左侧工具栏不再提供创建点和创建线段入口。
|
||||
- 多边形通过点击取点并按 Enter 完成,也支持三点后点击首节点闭合;矩形、圆通过拖拽生成;画笔和橡皮擦支持调整大小。
|
||||
- 画笔工具只在语义分类树有选中类别时可用,按住拖动时以圆形笔触采样,鼠标松开后一次性 union 成连续区域;如果笔触与当前选中 mask 连通,默认合并到该 mask,否则生成新的当前类别 mask。
|
||||
- 橡皮擦工具只在当前帧已选中 mask 时可用,按住拖动时以圆形笔触采样,鼠标松开后从选中 mask 中 difference 扣除;扣空时删除该 mask,已保存 mask 仍需同步后端删除。
|
||||
- 创建多边形、创建矩形、区域合并/去除、调整多边形等 Canvas 左上角上下文提示只作为短提示,切换工具或操作状态变化时显示,数秒后自动隐藏,避免长期遮挡待编辑图像;再次切换工具或操作状态变化会重新显示。
|
||||
- 绘制工具点击已有 mask 时应继续执行当前绘制动作,不应被 mask 选择逻辑吞掉。
|
||||
- 工具栏提供“调整多边形”工具,用户可以点击 mask 进入 polygon 顶点编辑态;按住顶点即可直接拖动并实时更新 mask 几何,不需要先单击选中顶点,已保存 mask 会标记为 dirty;顶点和 seed point 等子节点拖拽不能冒泡成画布拖拽,编辑结束后 Canvas 当前缩放和平移视口必须保持不变。
|
||||
- 工具栏在“重叠区域去除”之后提供“导入 GT Mask”入口;该入口使用区别于普通编辑工具的紫色底色,不切换 activeTool。
|
||||
- 顶点编辑态显示边中点插入手柄;点击边中点会在该边中间新增顶点。
|
||||
- “调整多边形”工具下双击 polygon 边界时,会在最接近的线段上按双击位置新增顶点。
|
||||
- 顶点编辑态下选中顶点后可用 Delete/Backspace 删除顶点,但不会让 polygon 少于三点。
|
||||
- 选中整个 mask 且未选中具体顶点时,Delete/Backspace 删除该 mask;已保存 mask 同步调用后端删除接口。
|
||||
- 撤销、重做绑定全局 `maskHistory/maskFuture`,支持工具栏按钮、AI 页按钮和 Canvas 快捷键。
|
||||
- 撤销、重做绑定全局 `maskHistory/maskFuture`,工作区支持顶栏按钮和 Canvas 快捷键,AI 页支持自己的按钮;左侧工具栏不重复放置撤销/重做入口。
|
||||
- 区域合并工具支持多选当前帧 mask,并使用 polygon union 生成合并后的主 mask。
|
||||
- 区域去除工具支持多选当前帧 mask,并从第一个选中的主 mask 中扣除后续选中 mask。
|
||||
- 区域合并/去除模式显示已选数量,并隐藏 polygon 编辑手柄以避免手柄抢占多选点击;第一个选中的主区域使用黄色实线轮廓,后续参与合并/扣除的区域使用红色虚线轮廓。
|
||||
@@ -93,11 +105,13 @@
|
||||
- SAM 2.1 不支持文本语义提示;当前 AI 页面不提供文本语义输入,必须使用点/框提示。
|
||||
- SAM 2.1 点提示和 auto fallback 默认只采用一个最高分候选 mask,避免多个候选 mask 作为同一结果重叠显示。
|
||||
- AI 页面只渲染本页最新生成的候选 mask;重复执行高精度分割会替换上一次 AI 页候选,工作区已有手工、保存、传播或 GT 导入 mask 不会自动进入 AI 画布,也不会被替换。
|
||||
- AI 页面提供“遮罩清晰度”滑杆,调节本页候选 mask 的预览透明度,不改变 mask 几何、分类或保存数据。
|
||||
- AI 页面提供“AI 遮罩透明度”滑杆,并与右侧“遮罩透明度”共享 `maskPreviewOpacity`;调节任一入口都会改变 AI 候选 mask 预览透明度,不改变 mask 几何、分类或保存数据。
|
||||
- AI 页面参数开关展示文案使用“局部专注模式(自动裁剪无锚区域)”和“严格除杂模式(自动清理干涉点)”;这是 UI 可读性文案,不改变 `cropMode`、`autoDeleteBg` 或后端 `options` 字段。
|
||||
- AI 页面生成的 SAM 2.1 mask 会写入全局 `masks`,自动同步到当前项目帧,并写入全局 `selectedMaskIds`;右侧语义分类树可以直接给新生成 mask 换标签。
|
||||
- AI 页“清空全体锚点”只清空本页提示点和本页生成的候选 mask,不删除工作区已有 mask。
|
||||
- AI 页面“推送至工作区编辑”会切回工作区并把工具切到“调整多边形”,保留当前选中的 AI mask 和当前帧视角,以便继续编辑轮廓和归档保存;如果 AI 操作发生在非第一帧,回到工作区后不得强制跳回第一帧。
|
||||
- AI 页面“推送至工作区编辑”必须先校验待推送 AI 候选 mask 已有语义分类;没有 `classId` 或 `className` 时用右上角 error toast 明确提示并停留在 AI 页,不允许进入工作区,确保工作区内 mask 都有语义。
|
||||
- 如果用户不通过推送按钮而是直接离开 AI 页面,仍未选择语义分类的 AI 候选 mask 必须从全局 `masks` 和 `selectedMaskIds` 中清理,避免无语义候选通过侧栏切换进入工作区。
|
||||
- AI 页面“推送至工作区编辑”在语义校验通过后会切回工作区并把工具切到“调整多边形”,保留当前选中的 AI mask 和当前帧视角,以便继续编辑轮廓和归档保存;如果 AI 操作发生在非第一帧,回到工作区后不得强制跳回第一帧。
|
||||
- 工作区加载后端已保存标注时,必须保留当前项目帧里尚未保存的 AI/手工 draft mask,避免 AI 页推送到工作区的候选 mask 被异步回显流程覆盖。
|
||||
- 语义文本提示 `semantic` 当前被后端禁用并返回 400。
|
||||
- SAM 3 源码和历史测试保留,但不属于当前产品可用功能;前端不再展示 SAM 3 入口,后端 registry 不暴露 `sam3`。
|
||||
@@ -108,13 +122,13 @@
|
||||
- 前端会把多个 seed 或双向范围拆成 `steps`,通过 `POST /api/ai/propagate/task` 创建 `propagate_masks` 后台任务,避免长 HTTP 请求卡在浏览器侧,同时避免并发抢占 GPU。
|
||||
- `POST /api/ai/propagate` 作为单 seed 同步兼容接口保留;`POST /api/ai/propagate/task` 是工作区自动传播使用的任务接口。两者当前支持四个 SAM 2.1 变体;兼容 `model=sam2` 并归一化为 tiny。SAM 2.1 使用官方 `SAM2VideoPredictor.add_new_mask()` 和 `propagate_in_video()`。
|
||||
- 自动传播任务写入 `processing_tasks`,前端轮询 `GET /api/tasks/{task_id}` 显示进度并刷新标注;Dashboard 也能看到该任务,任务可取消和重试。
|
||||
- 传播结果会写入后续帧 `annotations`,`mask_data.source` 标记为 `<model_id>_propagation`,并保留 label、color、class 元数据、seed 来源 id、seed 签名、传播方向和 `geometry_smoothing` 边缘平滑参数。
|
||||
- 自动传播任务必须避免重复叠加:同一目标帧段内,同一参考 seed、同一权重、同一方向、同一平滑参数且所有目标帧已有未变化结果时,worker 直接跳过;同一参考 seed 已变化、目标帧段只部分覆盖、用户改用其他 SAM 2.1 权重或修改平滑参数时,worker 先删除本次目标帧段内对应旧自动传播标注,再保存新传播结果;对早期只记录前端临时 `source_mask_id` 的旧传播结果,worker 会按传播方向和语义信息做兼容清理。用户在自动传播链中间帧人工新增或修改同一物体 mask 后重新向前/向后传播时,即使新 seed 缺少旧传播链 source id,也要按语义信息和目标帧空间重叠清理旧传播结果后再写入新结果;写入前清理不受旧结果 `propagation_direction` 限制,因此当前帧向前传播时也会替换原先由更早帧向后传播出来的旧 mask,避免同一物体新旧 mask 堆叠。未编辑的自动传播结果再次作为参考 seed 时,会继承原始 `propagation_seed_signature` 以避免重复传播;被编辑后的传播结果只保留 lineage,不继承旧签名,以便触发删除旧结果并重新传播。带 `geometry_smoothing` 的 seed 在 forward/backward 两个方向都会用同一参数平滑保存结果。
|
||||
- 传播结果会写入后续帧 `annotations`,`mask_data.source` 标记为 `<model_id>_propagation`,并保留 label、color、class 元数据、seed 来源 id、seed 签名和传播方向;如果历史或外部 seed 带 `geometry_smoothing` 平滑参数,worker 保存前仍必须对传播返回的 polygon 实际应用同一平滑几何,不能只更新拓扑锚点或 metadata。当前工作区平滑按钮应用后会直接改写实际 polygon 并清除平滑参数,后续传播以新几何本身参与签名。
|
||||
- 自动传播任务必须避免重复叠加:同一目标帧段内,同一参考 seed、同一权重、同一方向且所有目标帧已有未变化结果时,worker 直接跳过;同一参考 seed 已变化、目标帧段只部分覆盖或用户改用其他 SAM 2.1 权重时,worker 先删除本次目标帧段内对应旧自动传播标注,再保存新传播结果;对早期只记录前端临时 `source_mask_id` 的旧传播结果,worker 会按传播方向和语义信息做兼容清理。用户在自动传播链中间帧人工新增或修改同一物体 mask 后重新向前/向后传播时,即使新 seed 缺少旧传播链 source id,也要按语义信息和目标帧空间重叠清理旧传播结果后再写入新结果;写入前清理不受旧结果 `propagation_direction` 限制,因此当前帧向前传播时也会替换原先由更早帧向后传播出来的旧 mask,避免同一物体新旧 mask 堆叠。未编辑的自动传播结果再次作为参考 seed 时,会继承原始 `propagation_seed_signature` 以避免重复传播;被编辑后的传播结果只保留 lineage,不继承旧签名,以便触发删除旧结果并重新传播。历史带 `geometry_smoothing` 的 seed 在 forward/backward 两个方向都会用同一参数平滑保存结果。
|
||||
- AI 页面会对未放置点提示、后端错误和返回 0 个 mask 的情况显示明确反馈。
|
||||
- AI 参数支持 `crop_to_prompt`、`auto_filter_background` 和 `min_score`;点/框 prompt 可以裁剪局部区域推理并回映射结果,背景过滤会移除低分结果和包含负向点的 polygon。
|
||||
- 后端返回 `polygons` 和 `scores`。
|
||||
- 前端把后端 `polygons` 转成 Konva `pathData`、`segmentation`、`bbox`、`area`。
|
||||
- AI 推理结果先存放在前端 store 的 `masks` 中,点击“结构化归档保存”后持久化到后端标注表。
|
||||
- AI 推理结果先存放在前端 store 的 `masks` 中,顶栏保存状态按钮会按待保存数量显示“保存 X 个改动”或“已全部保存”;点击保存后持久化到后端标注表。
|
||||
|
||||
## R7 标注保存
|
||||
|
||||
@@ -123,12 +137,15 @@
|
||||
- 后端提供 `GET /api/ai/annotations` 查询项目标注,可选按 `frame_id` 过滤。
|
||||
- 后端提供 `PATCH /api/ai/annotations/{annotation_id}` 更新已保存标注的 `mask_data`、`points`、`bbox` 和 `template_id`。
|
||||
- 后端提供 `DELETE /api/ai/annotations/{annotation_id}` 删除已保存标注。
|
||||
- 当前前端“结构化归档保存”会保存当前项目未保存 mask,并会更新已标记为 dirty 的已保存 mask。
|
||||
- 当前前端保存状态按钮会保存当前项目未保存 mask,并会更新已标记为 dirty 的已保存 mask。
|
||||
- 保存成功后,前端会重新拉取后端标注,并用后端 saved annotation 替换本次提交的 draft mask;未提交的其他 draft mask 仍保留。
|
||||
- 工作区“清空遮罩”会删除当前帧已保存标注,并清空当前帧未保存 mask。
|
||||
- 工作区加载项目帧后会查询已保存标注并回显。
|
||||
- 工作区支持导入 GT mask 图片,前端调用 `POST /api/ai/import-gt-mask`。
|
||||
- 后端导入 GT mask 时按非零像素值拆分多类别区域,再按连通域生成 polygon 标注,并通过距离变换写入 seed point。
|
||||
- 导入 GT Mask 时,前端必须让用户选择未知 maskid 处理策略:舍弃未知类别,或导入为“未定义类别”等待后续重新命名。
|
||||
- 后端导入 GT mask 时必须支持二值 mask、灰度/16-bit `GT_label图`,以及 RGB 三通道完全相同的 `[X,X,X]` maskid 图;0 是背景,X 是 maskid。灰度/RGB 等通道图按当前模板 `maskId` 匹配类别,超出现有类别时按用户选择的策略处理;普通彩色 RGB 类别图不再视为合法 GT mask,必须返回图片不符合要求的明确错误。
|
||||
- 导入 GT mask 前端必须提供导入结果预览,显示检测到的 maskid、未知 maskid 和尺寸适配提示;如果 mask 图片尺寸与当前帧不同,后端导入前必须按当前帧长宽用最近邻插值拉伸,使 mask 可适配当前图片。
|
||||
- 后端导入 GT mask 时按非背景像素值或颜色拆分多类别区域,再按连通域生成 polygon 标注,并通过距离变换写入 seed point。
|
||||
- 前端会回显导入标注的 seed point;拖动 seed point 后,已保存标注会变为 dirty,归档保存时会更新后端 `points`。
|
||||
|
||||
## R8 模板库
|
||||
@@ -136,7 +153,7 @@
|
||||
- 前端展示模板列表,调用 `GET /api/templates`。
|
||||
- 用户可以新建、编辑、删除模板。
|
||||
- 模板分类存放在 `mapping_rules.classes`,规则存放在 `mapping_rules.rules`。
|
||||
- 前端支持添加/删除分类、拖拽排序后重算 `zIndex`、JSON 批量导入、加载腹腔镜默认分类。
|
||||
- 前端支持添加/删除分类、拖拽排序后更新内部覆盖优先级、JSON 批量导入、加载腹腔镜默认分类。界面不展示内部优先级数值,只展示每个类别稳定的 `maskid`。
|
||||
- 后端支持模板创建、列表、详情、局部更新和删除。
|
||||
|
||||
## R9 本体检查面板
|
||||
@@ -144,11 +161,12 @@
|
||||
- 工作区右侧可以选择模板。
|
||||
- 面板显示模板分类;新增自定义分类会写入当前激活模板的后端 `mapping_rules.classes`。
|
||||
- 用户可以选择具体分类;新 AI mask 会记录 `classId`、`className`、`classZIndex`,并在保存时写入 `mask_data.class`。
|
||||
- 如果 Canvas 当前已经选中一个或多个 mask,点击语义分类树会把这些 mask 的 `label`、`color` 和 class 元数据改为该分类;已保存 mask 会进入 `dirty` 状态,归档保存时更新后端。
|
||||
- 如果 Canvas 当前已经选中一个或多个 mask,点击语义分类树会把这些 mask 的 `label`、`color` 和 class 元数据改为该分类;如果这些 mask 属于自动传播链,还必须通过 `source_annotation_id`、`source_mask_id` 和 `propagation_seed_key` 同步更新同一传播链前后帧的对应 mask;已保存 mask 会进入 `dirty` 状态,归档保存时更新后端。
|
||||
- 添加自定义分类需要先选择模板,保存时调用 `PATCH /api/templates/{id}` 并同步全局模板 store。
|
||||
- “特定目标实例属性追踪”下方显示当前选中 mask 的 `className/label`,不显示全局 active class 的旧值。
|
||||
- 选中 mask 后,拓扑锚点和重新提取拓扑锚点按钮调用 `POST /api/ai/analyze-mask`,不再显示固定占位值;前端不再展示“后端模型置信度”条目。
|
||||
- 选中 mask 后,右侧实例属性面板提供“边缘平滑强度”和“应用边缘平滑”;应用时调用 `POST /api/ai/smooth-mask`,后端返回平滑后的 polygon、bbox、area 和拓扑锚点,前端将 mask 标记为 dirty/draft,用户仍需通过结构化归档保存落库。
|
||||
- 当前实例属性面板不展示“当前选中区域”计数;当前 mask 交互以单选为主,计数长期为 1,不作为有效业务信息展示。
|
||||
- 选中 mask 后,拓扑锚点调用 `POST /api/ai/analyze-mask` 自动读取,不再显示固定占位值;后端 `topology_anchor_count` 必须表示 polygon 的真实顶点数量,不能用抽样后的展示点数代替;前端必须静默忽略 abort/cancel 或过期的分析请求,避免快速切换 mask、拖动平滑预览或卸载组件时误显示“后端属性读取失败”;前端不再展示“后端模型置信度”条目,也不再提供“重新提取拓扑锚点”调试按钮。
|
||||
- 选中 mask 后,右侧实例属性面板提供“边缘平滑强度”和“应用边缘平滑”;调整滑杆时必须立即更新数值,但后端预览请求需要做短防抖,用户停止拖动约 220ms 后再调用 `POST /api/ai/smooth-mask` 并用返回 polygon 临时预览当前 mask 边缘,避免连续拖动时请求过密造成卡顿;预览阶段不标 dirty;点击“应用边缘平滑”后确认当前预览结果,前端必须把平滑 polygon 作为新的实际 mask 几何写入当前 mask,并同步写入同一传播链前后对应 mask;整次平滑应用必须作为一个撤销/重做历史步骤,撤销/重做要同时作用于当前 mask 和传播链对应 mask;应用后相关 mask 标记为 dirty/draft,平滑强度重置为 0,用户仍可继续用 polygon 编辑工具编辑平滑后的新多边形,并通过顶栏保存状态按钮落库。后端平滑必须对 AI/SAM 密集轮廓执行去噪简化、Chaikin 平滑和二次简化,使结果 polygon 的密集边缘点实际减少;强度映射必须低段温和、高段继续递进,避免 20% 左右已经过度平滑且后续档位无明显变化。
|
||||
|
||||
## R10 Dashboard 与 WebSocket
|
||||
|
||||
@@ -168,11 +186,21 @@
|
||||
|
||||
- 后端支持 `GET /api/export/{project_id}/coco` 导出 COCO JSON。
|
||||
- 后端支持 `GET /api/export/{project_id}/masks` 导出 PNG mask ZIP。
|
||||
- 后端支持 `GET /api/export/{project_id}/results` 统一导出分割结果 ZIP,参数支持整体视频、特定范围帧和当前图片三种范围,并支持分开二值 mask、GT_label 黑白图、Pro_label 彩色图和 Mix_label 原图叠加图;Mix_label 透明度默认 0.3。
|
||||
- 统一导出 ZIP 必须固定包含 `maskid_GT像素值_类别映射.json`,记录当前导出中每个类别的 `maskid`、GT_label 像素值、中文名、类别名、RGB 值、颜色和类别 key;GT_label 背景值固定为 0,语义类别值使用类别真实 maskid,缺失 maskid 的旧标注才补下一个可用正整数,且同一类别跨图片保持一致。
|
||||
- 统一导出 ZIP 必须固定包含 `原始图片/` 文件夹,导出范围内每帧的原始图片命名为 `视频名称_时间戳_项目帧序号` 加原图片扩展名;视频名称来自项目视频文件名,时间戳来自帧 `timestamp_ms` 并格式化为 `0h00m00s000ms`,帧号使用项目抽帧后的 1-based `frame_index + 1`,不使用原视频帧号。
|
||||
- 选择“分开 Mask”时,统一导出 ZIP 必须包含 `分开Mask分割结果/`;每帧建立 `{视频名称_时间戳_项目帧序号}_分别导出` 子文件夹,同一帧同一类别的所有 annotation 合并为一张二值 PNG,文件名包含 `视频名称_时间戳_项目帧序号_{类别名称}_maskid{maskid}`。
|
||||
- 选择“GT_label 黑白图”时,统一导出 ZIP 必须包含 `GT_label图/`;每帧输出一张融合后的 GT_label PNG,文件名为 `视频名称_时间戳_项目帧序号`,重叠区域仍按内部拖拽排序从低到高覆盖;maskid 不构成排序规范。
|
||||
- 选择“Pro_label 彩色图”时,统一导出 ZIP 必须包含 `Pro_label彩色分割结果/`;每帧输出一张按类别 RGB 上色的 PNG,背景为 `[0,0,0]`。
|
||||
- 选择“Mix_label 叠加图”时,统一导出 ZIP 必须包含 `Mix_label重叠覆盖彩色分割结果/`;每帧输出一张彩色 label 叠加原始图片的 PNG,透明度可选且默认为 0.3。
|
||||
- GT_label、Pro_label 和 Mix_label 的重叠区域覆盖顺序必须和右侧“语义分类树”的内部覆盖优先级一致,低优先级先写入,高优先级后写入。
|
||||
- 分割结果导出 ZIP 文件名必须使用 `{项目库项目名}_seg_T_{起始时间戳}-{结束时间戳}_P_{起始项目帧序号}-{结束项目帧序号}.zip`;项目名来自项目库中的 `Project.name`,时间戳来自导出范围首尾帧 `timestamp_ms` 并格式化为 `0h00m00s000ms`,帧号使用项目抽帧后的 1-based `frame_index + 1`。
|
||||
- 当前前端 `exportCoco()` API 封装已对齐后端路径。
|
||||
- 当前前端 `exportMasks()` API 封装已对齐后端路径。
|
||||
- 工作区“导出 JSON 标注集”按钮已绑定下载事件;导出前会先保存当前未归档 mask。
|
||||
- 工作区“导出 PNG Mask ZIP”按钮已绑定下载事件;导出前会先保存当前未归档 mask。
|
||||
- 当前前端 `exportSegmentationResults()` API 封装已对齐统一导出路径。
|
||||
- 工作区“分割结果导出”按钮已替代原 JSON/PNG 两个按钮;点击后在下拉栏内选择导出范围、勾选导出内容,并在选择 Mix_label 时调节遮罩透明度和查看当前/待导出第一帧预览;导出范围默认选中“当前图片”,导出前会先保存当前未归档 mask。选择“特定范围帧”时,用户既可以直接修改起止帧输入框,也可以像自动传播、清空遮罩一样在播放进度条或视频处理进度条上点击/拖拽选择导出范围。
|
||||
- PNG mask ZIP 包含单标注二值 mask、按 zIndex 融合后的每帧语义 mask 和 `semantic_classes.json`。
|
||||
- 统一导出的 GT_label 图背景值固定为 0,所有语义类别值优先保留模板类别真实 maskid,缺失 maskid 的旧标注才按下一个可用正整数补值。
|
||||
|
||||
## R12 配置
|
||||
|
||||
|
||||
@@ -29,9 +29,9 @@
|
||||
| 项目库 | `src/components/ProjectLibrary.tsx` | 项目列表、新建、删除、导入视频/DICOM、显式生成帧 |
|
||||
| 工作区 | `src/components/VideoWorkspace.tsx` | 加载帧和模板,组织工具栏、Canvas、本体面板、时间轴 |
|
||||
| Canvas | `src/components/CanvasArea.tsx` | 显示帧、缩放平移、点/框提示、渲染 mask |
|
||||
| 工具栏 | `src/components/ToolsPalette.tsx` | 切换工具、跳转 AI 页面、触发 mask 撤销/重做;紧凑垂直布局,高度不足时自身滚动;外层宽 56px,按钮列固定 48px,滚动条使用右侧外扩空间和低对比 `seg-scrollbar` |
|
||||
| 工作区顶栏 | `src/components/VideoWorkspace.tsx` | 保存/导出/传播/按起止帧批量清空遮罩/导入 GT、显式撤销/重做按钮和工作区快捷键 |
|
||||
| 时间轴 | `src/components/FrameTimeline.tsx` | 帧导航、播放进度、视频处理进度条、自动传播历史片段、自动传播范围选择、左右方向键切帧、播放和当前/总时长显示 |
|
||||
| 工具栏 | `src/components/ToolsPalette.tsx` | 切换工作区编辑工具、在“重叠区域去除”后触发 GT Mask 导入、跳转 AI 页面;AI 跳转入口复用 Bot + Sparkles 组合图标以明确表达 AI 智能分割;不再放置 AI 正/反点和框选工具,也不重复放置撤销/重做;拖拽/选择到创建圆、画笔/橡皮擦/区域合并/重叠区域去除、导入 GT Mask/AI 智能分割三类工具之间用浅灰横线分隔;紧凑垂直布局,高度不足时自身滚动;外层宽 56px,按钮列固定 48px,滚动条使用右侧外扩空间和低对比 `seg-scrollbar` |
|
||||
| 工作区顶栏 | `src/components/VideoWorkspace.tsx` | 保存状态按钮(“保存 X 个改动”/“已全部保存”)、导出/传播/按起止帧批量清空遮罩、显式撤销/重做按钮和工作区快捷键 |
|
||||
| 时间轴 | `src/components/FrameTimeline.tsx` | 帧导航、播放进度、视频处理进度条、自动传播历史片段、自动传播/清空遮罩/导出范围选择、左右方向键切帧、播放和当前/总时长显示 |
|
||||
| 本体面板 | `src/components/OntologyInspector.tsx` | 模板选择、工作区 mask 透明度、分类树、后端自定义分类、mask 后端属性分析;内容过长时自身滚动,滚动条使用低对比 `seg-scrollbar` |
|
||||
| AI 页面 | `src/components/AISegmentation.tsx` | 独立 AI 推理视图,使用当前项目帧 |
|
||||
| 模板库 | `src/components/TemplateRegistry.tsx` | 模板 CRUD、分类编辑、导入、排序 |
|
||||
@@ -44,9 +44,10 @@
|
||||
| 应用入口 | `backend/main.py` | FastAPI app、CORS、路由注册、健康检查、WebSocket |
|
||||
| 配置 | `backend/config.py` | Pydantic settings |
|
||||
| 数据库 | `backend/database.py` | SQLAlchemy engine、session、Base |
|
||||
| 模型 | `backend/models.py` | Project、Frame、Template、Annotation、Mask、ProcessingTask |
|
||||
| 模型 | `backend/models.py` | User、Project、Frame、Template、Annotation、Mask、AuditLog、ProcessingTask |
|
||||
| Schema | `backend/schemas.py` | Pydantic 请求/响应模型 |
|
||||
| Auth | `backend/routers/auth.py` | 开发登录 |
|
||||
| Auth | `backend/routers/auth.py` | 用户表、密码哈希、JWT 登录和 `/api/auth/me` |
|
||||
| Admin | `backend/routers/admin.py` | 管理员用户 CRUD、角色/密码/启停用和审计日志 |
|
||||
| Projects | `backend/routers/projects.py` | 项目与帧 CRUD |
|
||||
| Templates | `backend/routers/templates.py` | 模板 CRUD 和 mapping_rules 打包/解包 |
|
||||
| Media | `backend/routers/media.py` | 上传媒体和拆帧 |
|
||||
@@ -77,7 +78,19 @@
|
||||
|
||||
1. `Login` 收集用户名和密码。
|
||||
2. `login()` 调用 `POST /api/auth/login`。
|
||||
3. 成功后 store 写入 token,App 渲染主界面。
|
||||
3. 后端用 `users` 表中的密码哈希校验用户,成功后返回签名 JWT 和用户资料。
|
||||
4. 前端把 token 写入 `localStorage` 和 Zustand;刷新页面时 `useStore` 会从 `localStorage` 恢复 token。
|
||||
5. `App` 在已登录状态调用 `/api/auth/me` 恢复当前用户,再拉取当前用户项目列表。
|
||||
|
||||
### 用户隔离
|
||||
|
||||
1. `Project.owner_user_id` 指向 `users.id`;启动时默认 admin 用户会被创建,历史 `owner_user_id IS NULL` 的项目会迁移归属到 admin。
|
||||
2. 项目、帧、媒体上传/拆帧、AI 标注、传播任务、任务列表、Dashboard 和导出接口都通过当前 JWT 用户过滤项目资源。
|
||||
3. `Template.owner_user_id` 支持用户模板;`owner_user_id IS NULL` 的模板视为系统模板,可作为默认分类体系对用户可见。
|
||||
4. 角色分为 `admin`、`annotator`、`viewer`:`admin/annotator` 可调用写入类业务接口,`viewer` 只能调用读接口;`/api/admin/*` 仅允许 `admin`。
|
||||
5. `UserAdmin.tsx` 仅在当前用户角色为 `admin` 时从 `Sidebar` 展示,调用 `/api/admin/users` 完成新增、角色修改、停用/启用、密码修改和删除无项目用户,调用 `/api/admin/audit-logs` 展示登录和管理操作审计;危险区“恢复演示出厂设置”先用浏览器确认,再要求输入 `RESET_DEMO_FACTORY`,随后调用 `/api/admin/demo-factory-reset`。
|
||||
6. `POST /api/admin/demo-factory-reset` 仅允许 `admin`,会重置默认 admin 密码/角色/启用状态,删除其它用户、项目、帧、标注、mask、任务、用户模板和旧审计,重新创建 `Data_MyVideo_1` 项目并上传 `settings.demo_video_path` 指向的视频作为未生成帧的源视频项目;系统模板保留以保证重置后仍可标注。
|
||||
7. 缺失、过期或伪造的 Bearer token 会在业务路由返回 401,权限不足返回 403,其他用户项目资源对当前用户表现为 404。
|
||||
|
||||
### 项目导入与生成帧
|
||||
|
||||
@@ -109,14 +122,15 @@
|
||||
5. `VideoWorkspace` 加载项目帧时会优先按当前选中 mask 的 `frameId` 和当前打开帧 id 恢复 `currentFrameIndex`;只有没有可恢复帧时才回到第一帧,避免 AI 页在非第一帧推送回工作区时视角被重置。
|
||||
6. `CanvasArea` 会把全局 `selectedMaskIds` 中仍存在于当前帧的 id 同步回本地选区,避免帧初始化时的临时清空覆盖 AI 页推送过来的选中态;如果切换到另一帧时原 id 不存在,但目标帧存在同一自动传播链的结果,前端会用 `source_annotation_id`、`source_mask_id` 和 `propagation_seed_key` 匹配对应传播 mask 并自动选中。
|
||||
7. `CanvasArea` 根据容器和帧尺寸按 86% 适配比例计算初始 scale/position,使底图默认居中且尽量大,但保留画布边距;滚轮缩放和拖拽平移仍由用户后续控制。
|
||||
8. `FrameTimeline` 顶部播放进度条显示当前播放位置;其下方视频处理进度条根据 `Mask.metadata.source` / `propagated_from_frame_id` 计算自动传播帧并显示蓝色区段,对人工绘制或 AI 智能分割等非传播 mask 帧显示红色竖线。普通状态下,视频处理进度条可点击跳转到对应帧,红色人工/AI 标注帧和蓝色自动传播帧标识本身也可点击跳转。处理条未处理背景使用中性灰,和红色/蓝色标记保持明显区分。`VideoWorkspace` 会记录当前会话最近 8 次成功处理过的自动传播范围,并通过 `propagationHistory` 传给 `FrameTimeline`;时间轴会把这些片段叠加为不同色系的横向渐变条,片段内按视频时间从深到浅,较早片段降低透明度。底部缩略图导航轴对非当前帧使用红色边框标识人工/AI 标注帧,使用蓝色边框标识自动传播/推理帧;如果同一帧同时存在人工/AI 标注和自动传播结果,红色人工/AI 标注边框优先保留,自动传播状态只作为蓝色内描边。当前帧使用青色外框高亮优先,若当前帧同时是人工/AI 标注帧,则以青色外框加红色内描边同时表达两个状态,外层当前帧框和内层人工/AI 框的顺序固定。工作区进入自动传播或清空片段遮罩范围选择模式时,播放进度条和视频处理进度条显示 amber 覆盖层,并可点击/拖拽设置处理起止帧。
|
||||
9. 当前帧传入 `CanvasArea`。
|
||||
10. 工作区顶栏短状态文本会在空闲状态下自动消失;保存、导出、导入 GT 和传播任务运行中仍保留进度状态,无帧项目提示也会保留。
|
||||
11. 左侧工具栏和右侧本体/语义分类面板使用 `seg-scrollbar` 定制纵向滚动条;默认滚动条 thumb 低透明度融入深色背景,hover/focus 时增强为青色提示,避免系统默认滚动条在工具区中过于突兀。左侧工具栏额外保留右侧滚动条槽位,按钮列仍按原 48px 布局,避免滚动条和图标抢空间。
|
||||
8. `CanvasArea` 未选中特定 mask 时,会按 `classZIndex` 从低到高渲染当前帧 mask;该值来自右侧“语义分类树”的拖拽排序,因此高优先级类别会后渲染并覆盖低优先级类别。有选中 mask 时,编辑态可保留选中区域置顶,方便拖点、换类和布尔操作。
|
||||
9. `FrameTimeline` 顶部播放进度条显示当前播放位置;其下方视频处理进度条根据 `Mask.metadata.source` / `propagated_from_frame_id` 计算自动传播帧并显示蓝色区段,对人工绘制或 AI 智能分割等非传播 mask 帧显示红色竖线。当前帧另用白色竖线贯穿播放进度条和视频处理进度条,和青色播放进度、红色标注、蓝色传播状态区分。普通状态下,视频处理进度条可点击跳转到对应帧,红色人工/AI 标注帧和蓝色自动传播帧标识本身也可点击跳转。处理条未处理背景使用中性灰,和红色/蓝色标记保持明显区分。`VideoWorkspace` 会记录当前会话最近 8 次成功处理过的自动传播范围,并通过 `propagationHistory` 传给 `FrameTimeline`;时间轴会把这些片段叠加为同一蓝色系的纯色条,按距最新传播的时间顺序逐次变暗,且第 5 次及更早统一为阈值旧记录色,不再在单个片段内部使用渐变。清空片段遮罩时,`VideoWorkspace` 会按清空范围移除或裁剪本地传播历史片段,避免已清空的处理范围仍显示最近传播条。底部缩略图导航轴对非当前帧使用红色边框标识人工/AI 标注帧,使用蓝色边框标识自动传播/推理帧;如果同一帧同时存在人工/AI 标注和自动传播结果,红色人工/AI 标注边框优先保留,自动传播状态只作为蓝色内描边。当前帧使用青色外框高亮优先,若当前帧同时是人工/AI 标注帧,则以青色外框加红色内描边同时表达两个状态,外层当前帧框和内层人工/AI 框的顺序固定。工作区进入自动传播、清空片段遮罩或特定范围帧导出选择模式时,播放进度条和视频处理进度条显示 amber 覆盖层,并额外用洋红色起始线和黄绿色结束线贯穿两条进度条,表达待处理或待导出范围边界,可点击/拖拽设置起止帧。
|
||||
10. 当前帧传入 `CanvasArea`。
|
||||
11. 工作区顶栏短状态文本会在空闲状态下自动消失;保存、导出、导入 GT 和传播任务运行中仍保留进度状态,无帧项目提示也会保留。
|
||||
12. 左侧工具栏和右侧本体/语义分类面板使用 `seg-scrollbar` 定制纵向滚动条;默认滚动条 thumb 低透明度融入深色背景,hover/focus 时增强为青色提示,避免系统默认滚动条在工具区中过于突兀。左侧工具栏额外保留右侧滚动条槽位,按钮列仍按原 48px 布局,避免滚动条和图标抢空间。
|
||||
12. 右侧面板不再显示“本体论与属性分类管理树”固定说明栏,直接展示实际可操作内容。
|
||||
13. 右侧“遮罩透明度”滑杆写入 Zustand `maskPreviewOpacity`,`CanvasArea` 用该值计算 mask group opacity;选中 mask 在基础透明度上加亮,方便保留选中反馈。
|
||||
13. 右侧“遮罩透明度”滑杆写入 Zustand `maskPreviewOpacity`,`CanvasArea` 和 `AISegmentation` 都用该值计算 mask group opacity;选中 mask 在基础透明度上加亮或按基础透明度显示,方便保留选中反馈。
|
||||
14. Canvas 点击 mask 后,全局 `selectedMaskIds` 会同步到 `OntologyInspector`;本体面板按选中 mask 的 `classId`、`className/label` 和颜色匹配模板分类,自动设置 active class,并把分类按钮滚动/聚焦到可见区域。
|
||||
15. 工作区顶栏“清空片段遮罩”和“自动传播”共用时间轴范围选择交互;第一次点击“清空片段遮罩”会进入范围选择模式,按钮变为“确认清空”,用户可在播放进度条或视频处理进度条上点击/拖拽选择起止帧;确认执行时对范围内已保存 mask 调用 `DELETE /api/ai/annotations/{id}`,同时移除范围内本地 draft mask 和被清空的选区,范围外 mask 保持不变。
|
||||
15. 工作区顶栏“清空片段遮罩”和“自动传播”共用时间轴范围选择交互;第一次点击“清空片段遮罩”会进入范围选择模式,按钮变为“确认清空”,用户可在播放进度条或视频处理进度条上点击/拖拽选择起止帧;进入清空模式后顶栏显示“清空全部 / 保留人工/AI”两段式模式选择,默认“清空全部”。“清空全部”会对范围内已保存 mask 调用 `DELETE /api/ai/annotations/{id}`,同时移除范围内本地 draft mask、被清空的选区和与清空范围重叠的本地传播历史条;若范围内存在非自动传播来源的 mask,也就是时间轴红色“人工/AI 标注帧”,执行前会弹出“是否清除“人工/AI标注帧”?”确认,取消则不删除任何 mask。“保留人工/AI”只删除范围内自动传播/推理 mask,不删除人工绘制或 AI 智能分割生成的红色标注帧,不弹出人工帧确认;范围外 mask 和传播历史片段保持不变。
|
||||
|
||||
### AI 点/框推理
|
||||
|
||||
@@ -136,12 +150,12 @@
|
||||
14. AI 页面提示点由本地 `points` 状态维护;点击已渲染提示点会按 index 删除对应点,“删除最近锚点”会删除数组最后一个点,不改动候选 mask 列表。
|
||||
15. AI 页面候选 mask 删除只接受当前 `aiMaskIds` 范围内的已选 id;“删除选中候选”和 Delete/Backspace 都复用该范围过滤,避免删除工作区已有 mask。
|
||||
16. AI 页面参数开关文案只做展示增强:“局部专注模式(自动裁剪无锚区域)”仍控制 `cropMode/crop_to_prompt`,“严格除杂模式(自动清理干涉点)”仍控制 `autoDeleteBg/auto_filter_background/min_score`。
|
||||
17. AI 页面“遮罩清晰度”滑杆只调节候选 mask 的 Konva preview opacity,不写入 `Mask.segmentation`、分类元数据或后端 payload。
|
||||
17. AI 页面“AI 遮罩透明度”滑杆复用 Zustand `maskPreviewOpacity`,和右侧“遮罩透明度”联动,只调节候选 mask 的 Konva preview opacity,不写入 `Mask.segmentation`、分类元数据或后端 payload。
|
||||
18. AI 画布左上角根据正向点、反向点、边界框选和视口控制显示上下文提示,说明点击/拖拽、删除提示点和执行推理的操作方式。
|
||||
19. AI 画布根据容器和当前帧尺寸按 86% 适配比例计算初始 scale/position,使底图默认居中且尽量大,但保留画布边距。
|
||||
20. Canvas 按当前帧过滤并渲染 mask。
|
||||
21. 新 mask 会带上当前选择的模板分类元数据,包括 `classId`、`className`、`classZIndex`、`metadata.source=ai_segmentation` 和保存状态 `draft`。
|
||||
20. 用户点击“结构化归档保存”后,前端将像素 `segmentation` 转成 normalized `mask_data.polygons`;未保存 mask 调用 `POST /api/ai/annotate`,dirty mask 调用 `PATCH /api/ai/annotations/{annotation_id}`;保存成功后本次提交的 draft mask id 会从本地保留列表中排除,并由后端 saved annotation 回显替换。
|
||||
20. 顶栏保存状态按钮按当前项目待保存数量显示为“保存 X 个改动”或“已全部保存”;用户点击保存后,前端将像素 `segmentation` 转成 normalized `mask_data.polygons`;未保存 mask 调用 `POST /api/ai/annotate`,dirty mask 调用 `PATCH /api/ai/annotations/{annotation_id}`;保存成功后本次提交的 draft mask id 会从本地保留列表中排除,并由后端 saved annotation 回显替换。
|
||||
21. 工作区加载项目帧后通过 `GET /api/ai/annotations` 取回已保存标注并转成前端 mask。
|
||||
22. 工作区“清空遮罩”删除当前帧已保存标注,并清除当前帧本地 mask。
|
||||
|
||||
@@ -151,26 +165,27 @@
|
||||
2. 用户可以直接修改传播起始帧/结束帧数字框,并可通过工作区顶栏“传播权重”下拉独立选择本次传播使用的 SAM 2.1 tiny/small/base+/large 权重;该入口不提供 SAM2/SAM3 家族切换,默认跟随全局 AI 权重,用户手动选择后不再被 AI 页权重切换覆盖。
|
||||
3. `VideoWorkspace` 以当前参考帧为 seed,将起止帧拆成 `backward` 和/或 `forward` 两段;只包含当前帧时不传播。
|
||||
4. `VideoWorkspace` 在提交传播前会先调用现有归档保存链路保存当前项目中的 draft/dirty mask,并重新读取 store 中的回显结果;参考帧 seed 因此优先携带稳定的后端 `source_annotation_id`,避免用前端临时 mask id 生成传播结果后,二次传播无法找到旧结果。
|
||||
5. `VideoWorkspace` 用 `buildAnnotationPayload()` 把每个 seed mask 转成 normalized polygon、bbox、label、color、class 元数据、`geometry_smoothing`、`source_mask_id` 和可用时的 `source_annotation_id`;如果 seed mask 是未编辑的自动传播结果,会沿用其原始 `source_annotation_id/source_mask_id/propagation_seed_signature`,让后端把它识别为原传播链的同一个 seed;如果该传播结果被编辑并保存,更新 payload 只保留 lineage,不保留旧签名,使后端按“已修改”路径清理旧结果并重传。
|
||||
5. `VideoWorkspace` 用 `buildAnnotationPayload()` 把每个 seed mask 转成 normalized polygon、bbox、label、color、class 元数据、`source_mask_id` 和可用时的 `source_annotation_id`;如果 seed mask 是未编辑的自动传播结果,会沿用其原始 `source_annotation_id/source_mask_id/propagation_seed_signature`,让后端把它识别为原传播链的同一个 seed;如果该传播结果被编辑并保存,更新 payload 只保留 lineage,不保留旧签名,使后端按“已修改”路径清理旧结果并重传。对历史或外部写入的 `geometry_smoothing` metadata,payload 仍可透传给后端兼容处理;当前前端平滑应用会直接改写 polygon 几何并移除该参数。
|
||||
6. 前端把传播权重 id、每个 seed、每个方向组装成 `steps`,一次调用 `POST /api/ai/propagate/task`,`include_source=false`、`save_annotations=true`;接口先规范化/校验 `model` 字段中的权重 id,再创建 `processing_tasks.task_type=propagate_masks` 并投递 Celery,避免长 HTTP 请求阻塞前端等待。
|
||||
7. `VideoWorkspace` 记录返回的 `task_id`,轮询 `GET /api/tasks/{task_id}` 显示任务 message、步骤进度、已处理帧次和已保存区域数;任务运行期间提供取消传播按钮,调用通用 `POST /api/tasks/{task_id}/cancel`。
|
||||
8. Celery worker 逐 step 顺序执行传播,避免多个视频 tracker 并发抢占 GPU;每个 step 开始/完成都会写入 `processing_tasks.progress/result/message` 并发布 Redis `seg:progress`,Dashboard 可同步显示。每个 step 开始前,worker 会在本次目标帧段内用 seed 来源 id、传播方向和包含 `geometry_smoothing` 的 seed 签名查找旧传播标注:同权重、签名相同且目标帧都已有结果时跳过该 seed;签名不同、目标帧只部分覆盖、本次使用了其他 SAM 2.1 权重或平滑参数变化则先删除本次目标帧段内对应方向的旧自动传播标注,再执行新的 video predictor 传播。对旧版本只记录前端临时 `source_mask_id` 的传播标注,worker 会按 label/color/class 做兼容匹配,确保可被后续稳定 `source_annotation_id` 的传播替换;对中间帧人工新增的替代 seed,若缺少旧 source id,worker 仍会用语义信息识别候选旧传播结果,并在写入目标帧新 polygon 前用目标帧 bbox 重叠做二次确认和清理。写入前这层清理不限制旧结果方向,确保 backward 传播可覆盖早先 forward 传播留下的同物体旧 mask。
|
||||
8. Celery worker 逐 step 顺序执行传播,避免多个视频 tracker 并发抢占 GPU;每个 step 开始/完成都会写入 `processing_tasks.progress/result/message` 并发布 Redis `seg:progress`,Dashboard 可同步显示。每个 step 开始前,worker 会在本次目标帧段内用 seed 来源 id、传播方向和 seed 签名查找旧传播标注:同权重、签名相同且目标帧都已有结果时跳过该 seed;签名不同、目标帧只部分覆盖或本次使用了其他 SAM 2.1 权重则先删除本次目标帧段内对应方向的旧自动传播标注,再执行新的 video predictor 传播;若历史 seed 签名中包含 `geometry_smoothing`,仍按完整签名参与兼容去重。对旧版本只记录前端临时 `source_mask_id` 的传播标注,worker 会按 label/color/class 做兼容匹配,确保可被后续稳定 `source_annotation_id` 的传播替换;对中间帧人工新增的替代 seed,若缺少旧 source id,worker 仍会用语义信息识别候选旧传播结果,并在写入目标帧新 polygon 前用目标帧 bbox 重叠做二次确认和清理。写入前这层清理不限制旧结果方向,确保 backward 传播可覆盖早先 forward 传播留下的同物体旧 mask。
|
||||
9. 后端按项目帧序列截取片段,下载对应帧到临时目录,并写成 `000000.jpg` 这类纯数字文件名;这是 `SAM2VideoPredictor` 对视频帧排序的要求,和项目库中持久化的 `frame_%06d.jpg` 对象名无关。
|
||||
10. `model` 为任一 SAM 2.1 权重变体时,`sam2_engine` 使用对应 checkpoint/config 加载 `SAM2VideoPredictor.add_new_mask()` 注入 seed mask,再用 `propagate_in_video()` 传播;`model=sam2` 会在入队时规范化为 tiny,任务 payload/result 会保留规范化后的权重 id;单个 SAM2 video predictor 调用内部暂不提供逐帧流式进度。
|
||||
11. `model=sam3` 当前不支持;SAM 3 video tracker 代码保留但没有接入产品路径。
|
||||
12. 后端把传播返回的 normalized polygon 保存为后续帧 `Annotation`,跳过源帧;如果 seed 带 `geometry_smoothing`,保存前会用同一 Chaikin 平滑参数处理 forward/backward 两个方向的结果。`mask_data.source` 记录权重传播来源,同时写入 `propagation_seed_key`、`propagation_seed_signature`、`propagation_direction`、`source_annotation_id`、`source_mask_id` 和 `geometry_smoothing` 供后续幂等传播判断。
|
||||
12. 后端把传播返回的 normalized polygon 保存为后续帧 `Annotation`,跳过源帧;如果历史或外部 seed 带 `geometry_smoothing`,保存前仍会用同一平滑参数处理 forward/backward 两个方向的结果:强度先经过缓入曲线映射,低强度使用较小 Chaikin 切角比例和简化阈值,高强度再逐步增加迭代、切角和简化力度;随后按强度对 SAM 密集轮廓做 `approxPolyDP` 去噪简化,再做 Chaikin 平滑,最后二次简化并以平滑后的 polygon 计算 bbox 后落库。当前工作区“应用边缘平滑”会在前端把同传播链对应 mask 直接改写为新的 polygon 并移除 `geometry_smoothing` 参数,因此后续传播通常按新几何本身参与 seed 签名。`mask_data.source` 记录权重传播来源,同时写入 `propagation_seed_key`、`propagation_seed_signature`、`propagation_direction`、`source_annotation_id` 和 `source_mask_id` 供后续幂等传播判断;历史 `geometry_smoothing` 仅在存在时保留用于兼容判断。
|
||||
13. 前端轮询到已创建区域后刷新 `GET /api/ai/annotations` 并回显新标注;任务结束后如果后端返回 0 个新区域,工作区会明确提示没有生成新的 mask,若是未改变 seed 被跳过则提示未改变 mask 已跳过。处理过帧次大于 0 的成功任务会追加一条本地传播历史片段,用于视频处理进度条显示最近传播范围;`annotationToMask()` 会保留传播来源 metadata,供时间轴视频处理进度条显示蓝色传播区段。
|
||||
|
||||
### 手工绘制与历史栈
|
||||
|
||||
1. 用户在 `ToolsPalette` 选择多边形、矩形、圆、点或线工具。
|
||||
1. 用户在 `ToolsPalette` 选择多边形、矩形、圆、画笔或橡皮擦工具;创建点和创建线段入口不在工作区左侧工具栏中提供。
|
||||
2. `CanvasArea` 将交互坐标转换成像素 polygon。
|
||||
3. 多边形工具逐次记录节点,三点后点击首节点或按 Enter 时生成闭合 polygon。
|
||||
4. Canvas 左上角根据当前工具和操作阶段显示上下文短提示;多边形提示会随已放置点数切换,明确 Enter 完成、Esc 取消和点击首节点闭合。提示会在工具或操作状态变化时出现,并在数秒后自动隐藏,避免长期遮挡底图。
|
||||
5. mask path 只在 `move`、`edit_polygon`、`area_merge` 和 `area_remove` 工具下拦截点击;绘制和 AI prompt 工具点击已有 mask 时继续冒泡给 Stage。
|
||||
6. 新 mask 写入 `pathData`、像素 `segmentation`、`bbox`、`area` 和当前模板分类元数据。
|
||||
7. `addMask()`、`setMasks()`、`updateMask()`、`clearMasks()` 会维护 `maskHistory/maskFuture`。
|
||||
8. 工具栏按钮、工作区顶栏按钮和 AI 页按钮调用 `undoMasks()` / `redoMasks()`;工作区由 `VideoWorkspace` 统一处理 `Ctrl/Cmd+Z`、`Ctrl/Cmd+Shift+Z` 和 `Ctrl/Cmd+Y`,并在输入框、下拉框和可编辑文本聚焦时跳过快捷键,避免影响帧范围输入。
|
||||
5. mask path 只在 `move`、`edit_polygon`、`area_merge` 和 `area_remove` 工具下拦截点击;绘制、画笔、橡皮擦和 AI prompt 工具点击已有 mask 时继续冒泡给 Stage。
|
||||
6. 画笔/橡皮擦尺寸保存在 Zustand 中;拖动期间只保留采样后的圆形笔触预览,鼠标松开后再用 `polygon-clipping` 执行 union/difference,避免拖动中反复重算复杂 polygon。
|
||||
7. 新 mask 写入 `pathData`、像素 `segmentation`、`bbox`、`area` 和当前模板分类元数据。
|
||||
8. `addMask()`、`setMasks()`、`updateMask()`、`clearMasks()` 会维护 `maskHistory/maskFuture`。
|
||||
9. 工作区撤销/重做只保留顶栏按钮和快捷键入口,AI 页保留自己的撤销/重做按钮;工作区由 `VideoWorkspace` 统一处理 `Ctrl/Cmd+Z`、`Ctrl/Cmd+Shift+Z` 和 `Ctrl/Cmd+Y`,并在输入框、下拉框和可编辑文本聚焦时跳过快捷键,避免影响帧范围输入。
|
||||
|
||||
### Polygon 逐点编辑
|
||||
|
||||
@@ -196,7 +211,7 @@
|
||||
|
||||
### GT Mask 导入
|
||||
|
||||
1. 工作区“导入 GT Mask”选择图片文件。
|
||||
1. 工作区左侧工具栏“导入 GT Mask”选择图片文件;入口位于“重叠区域去除”之后。
|
||||
2. 前端 `importGtMask()` 以 multipart form-data 调用 `POST /api/ai/import-gt-mask`,携带 `project_id` 和 `frame_id`。
|
||||
3. 后端验证项目、帧、模板后使用 OpenCV 读取灰度 mask。
|
||||
4. 后端按非零像素值拆分多类别标签。
|
||||
@@ -214,13 +229,16 @@
|
||||
5. 返回时再解包给前端。
|
||||
6. `CanvasArea` 把当前选中的 mask id 同步到全局 `selectedMaskIds`;切换工具、切换帧或卸载 Canvas 时会清空选择。
|
||||
7. `AISegmentation` 生成 mask 后会写入全局 `masks` 并把生成的 mask id 写入 `selectedMaskIds`;点击 AI 页预览 mask 也会更新 `selectedMaskIds`。
|
||||
8. AI 页“推送至工作区编辑”会切换到工作区并把 `activeTool` 设为 `edit_polygon`;`CanvasArea` 初始读取全局 `selectedMaskIds`,让 AI 页选中的 mask 在工作区继续保持选中。
|
||||
9. 工作区帧/标注异步加载完成后,`hydrateSavedAnnotations()` 会合并本地未保存 draft mask 和后端已保存 mask,不会用后端回显结果直接覆盖整个 `masks` store。
|
||||
10. `OntologyInspector` 可以选择具体分类;选择结果进入全局 store,供 `CanvasArea` 和 `AISegmentation` 新建/更新 mask 时使用。
|
||||
11. 如果 `selectedMaskIds` 中存在当前 store 的 mask,点击分类时会立即更新这些 mask 的 `templateId`、`classId`、`className`、`classZIndex`、`label` 和 `color`。
|
||||
12. 同一次点击会把这些已选 mask 移动到前端 `masks` 数组末尾;`CanvasArea` 按数组顺序渲染,后渲染的 Path 显示在最上层,方便用户继续编辑刚换标签的区域。该显示置顶不改变模板 `zIndex` 或后端导出语义覆盖规则。
|
||||
13. 已保存 mask 被重新分类后进入 `dirty` 且 `saved=false`,继续复用工作区归档保存的 PATCH 链路。
|
||||
14. 模板保存、删除和 JSON 导入失败使用 `TransientNotice` 非阻塞提示,默认数秒后自动消失。
|
||||
8. AI 页“推送至工作区编辑”会先检查待推送 AI 候选 mask 是否具备 `classId` 或 `className`;缺少语义分类时清空普通推理反馈,并通过 `TransientNotice` 右上角 error toast 提示用户先点右侧语义分类树,不切换模块、不修改工具状态。
|
||||
9. `AISegmentation` 卸载时会清理仍缺少 `classId/className` 的本页 AI 候选,并同步移除对应 `selectedMaskIds`,避免用户绕过推送按钮从侧栏切到工作区时带入无语义 mask。
|
||||
10. AI 页语义校验通过后会切换到工作区并把 `activeTool` 设为 `edit_polygon`;`CanvasArea` 初始读取全局 `selectedMaskIds`,让 AI 页选中的 mask 在工作区继续保持选中。
|
||||
11. 工作区帧/标注异步加载完成后,`hydrateSavedAnnotations()` 会合并本地未保存 draft mask 和后端已保存 mask,不会用后端回显结果直接覆盖整个 `masks` store。
|
||||
12. `OntologyInspector` 可以选择具体分类;选择结果进入全局 store,供 `CanvasArea` 和 `AISegmentation` 新建/更新 mask 时使用。
|
||||
13. 如果 `selectedMaskIds` 中存在当前 store 的 mask,点击分类时会立即更新这些 mask 的 `templateId`、`classId`、`className`、`classZIndex`、`label` 和 `color`。
|
||||
14. 对属于自动传播链的 mask,分类更新会复用 `source_annotation_id`、`source_mask_id` 和 `propagation_seed_key` 查找同一目标实例在前后帧中的传播结果,并同步更新这些传播 mask 的分类元数据,避免同一物体跨帧语义不一致。
|
||||
15. 同一次点击会把这些已选 mask 移动到前端 `masks` 数组末尾;`CanvasArea` 按数组顺序渲染,后渲染的 Path 显示在最上层,方便用户继续编辑刚换标签的区域。该显示置顶不改变模板 `zIndex` 或后端导出语义覆盖规则。
|
||||
16. 已保存 mask 被重新分类后进入 `dirty` 且 `saved=false`,同传播链被同步更新的已保存 mask 也进入 `dirty`,继续复用工作区归档保存的 PATCH 链路。
|
||||
16. 模板保存、删除和 JSON 导入失败使用 `TransientNotice` 非阻塞提示,默认数秒后自动消失。
|
||||
|
||||
### 导出
|
||||
|
||||
@@ -228,7 +246,9 @@
|
||||
2. PNG mask 导出会把 normalized polygon 渲染为单标注二值 mask。
|
||||
3. PNG mask 导出还会按 `mask_data.class.zIndex` 或模板 `z_index` 从低到高覆盖,生成每帧语义融合 mask。
|
||||
4. ZIP 内写入 `semantic_classes.json`,记录语义值到类别、颜色和 zIndex 的映射。
|
||||
5. 前端“导出 JSON 标注集”和“导出 PNG Mask ZIP”按钮都会在导出前保存待归档标注,然后下载对应文件。
|
||||
5. 前端使用“分割结果导出”统一入口替代原 JSON/PNG 两个按钮;点击后在下拉栏选择整体视频、特定范围帧或当前图片,默认选中当前图片,并勾选分开二值 mask、GT_label 黑白图、Pro_label 彩色图和 Mix_label 原图叠加图。选择“特定范围帧”时,导出起止帧输入框和 `FrameTimeline` 的范围拖拽选择共用同一组导出范围状态;选择 Mix_label 时显示透明度滑杆,默认 0.3,并用当前/待导出第一帧做遮罩预览。提交前会保存待归档标注,然后下载统一 ZIP。下载文件名使用 `{项目库项目名}_seg_T_{起始时间戳}-{结束时间戳}_P_{起始项目帧序号}-{结束项目帧序号}.zip`;项目名来自 `currentProject.name`,起止帧按当前导出范围取首尾帧,时间戳格式为 `0h00m00s000ms`,帧号使用项目抽帧后的 1-based 顺序,项目名中的文件系统不安全字符会替换为 `_`。
|
||||
6. 统一导出 ZIP 固定包含 `annotations_coco.json`、`maskid_GT像素值_类别映射.json` 和 `原始图片/`;原始图片文件名使用 `视频名称_时间戳_项目帧序号`。导出会保留类别真实 maskid,GT_label 像素值与 maskid 相同并跨图一致;缺失 maskid 的旧标注才补下一个可用正整数并写入映射 JSON。选择分开 mask 时包含 `分开Mask分割结果/`,每帧建立 `{视频名称_时间戳_项目帧序号}_分别导出` 子文件夹,并按“同一帧同一类别合并一张图”的方式输出 `{视频名称_时间戳_项目帧序号}_{类别名称}_maskid{maskid}.png`。选择 GT_label 图时包含 `GT_label图/{视频名称_时间戳_项目帧序号}.png`;选择 Pro_label 图时包含 `Pro_label彩色分割结果/{视频名称_时间戳_项目帧序号}.png`;选择 Mix_label 图时包含 `Mix_label重叠覆盖彩色分割结果/{视频名称_时间戳_项目帧序号}.png`。GT_label、Pro_label 和 Mix_label 的重叠区域按内部拖拽排序从低到高覆盖,和未选中状态下的画布显示顺序一致;maskid 不参与排序。后端直接下载接口的 `Content-Disposition` 使用同一 ZIP 命名规则,并用 `filename*` 支持中文项目名。
|
||||
7. 右侧 `OntologyInspector` 的语义分类树支持拖拽调整内部覆盖顺序;拖拽后保存到模板并同步当前工作区同类 mask 的 `classZIndex`,但保留类别 maskid 不变。
|
||||
|
||||
## 接口契约
|
||||
|
||||
@@ -237,13 +257,14 @@
|
||||
- `updateProject()` 使用 `PATCH /api/projects/{id}`。
|
||||
- `exportCoco()` 使用 `GET /api/export/{projectId}/coco`。
|
||||
- `exportMasks()` 使用 `GET /api/export/{projectId}/masks`。
|
||||
- `exportSegmentationResults()` 使用 `GET /api/export/{projectId}/results`,通过 query 参数选择范围和 mask 类型。
|
||||
- `cancelTask()` 使用 `POST /api/tasks/{taskId}/cancel`。
|
||||
- `retryTask()` 使用 `POST /api/tasks/{taskId}/retry`。
|
||||
- `predictMask()` 使用 `POST /api/ai/predict`,请求体为 `image_id`、`prompt_type`、`prompt_data`、`model`。
|
||||
- `propagateMasks()` 使用 `POST /api/ai/propagate`,请求体为 `project_id`、`frame_id`、`model`、`seed`、`direction`、`max_frames`,作为单 seed 同步兼容接口保留。
|
||||
- `queuePropagationTask()` 使用 `POST /api/ai/propagate/task`,请求体为 `project_id`、`frame_id`、`model`、`steps`、`include_source`、`save_annotations`,返回 `ProcessingTask`。
|
||||
- `saveAnnotation()` 使用 `POST /api/ai/annotate`。
|
||||
- `importGtMask()` 使用 `POST /api/ai/import-gt-mask` multipart form-data。
|
||||
- `importGtMask()` 使用 `POST /api/ai/import-gt-mask` multipart form-data,并传入 `unknown_color_policy=discard|undefined`。前端上传前弹出导入结果预览和未知 maskid 策略选择;后端使用 `cv2.IMREAD_UNCHANGED` 保留低数值/16-bit GT_label 像素值。合法 GT mask 限定为灰度图或 RGB 三通道完全相同的 `[X,X,X]` maskid 图,0 为背景、X 为 maskid;灰度/RGB 等通道图按模板 `maskId` 匹配类别,普通彩色 RGB 类别图不再按颜色匹配并会返回格式错误;未知类别按策略舍弃或保存为 `gt_unknown_class` 未定义类别。若 GT mask 尺寸和当前帧不同,后端用最近邻插值拉伸到当前帧尺寸后再生成 polygon。
|
||||
- `getProjectAnnotations()` 使用 `GET /api/ai/annotations`。
|
||||
- `updateAnnotation()` 使用 `PATCH /api/ai/annotations/{annotationId}`。
|
||||
- `deleteAnnotation()` 使用 `DELETE /api/ai/annotations/{annotationId}`。
|
||||
@@ -276,5 +297,5 @@
|
||||
- 已保存标注支持通过“应用分类”、polygon 顶点拖动/删除、边中点插入、多 polygon 子区域编辑和区域合并/去除进入 dirty 状态并归档更新;选中整块 mask 可用 Delete/Backspace 删除并同步后端;复杂洞结构编辑尚未实现。
|
||||
- SAM 3 文本语义分割已从当前产品路径中禁用;相关源码保留,恢复时需要重新接入前端入口、registry、状态接口和测试。
|
||||
- 自定义分类通过 `PATCH /api/templates/{id}` 写入当前激活模板的 `mapping_rules.classes`。
|
||||
- 选中 mask 后,本体面板的“特定目标实例属性追踪”标题值来自当前 mask 的 `className/label`,不使用全局 active class;面板调用 `POST /api/ai/analyze-mask` 显示拓扑锚点数量等属性,“重新提取拓扑锚点”会带 `extract_skeleton=true` 重新请求后端分析;“边缘平滑强度/应用边缘平滑”调用 `POST /api/ai/smooth-mask`,由后端按 Chaikin smoothing 返回新 polygon 并把 `geometry_smoothing` 写回 mask metadata。前端不再展示“后端模型置信度”。
|
||||
- 选中 mask 后,本体面板的“特定目标实例属性追踪”标题值来自当前 mask 的 `className/label`,不使用全局 active class;面板不再展示长期为 1 的“当前选中区域”计数;面板调用 `POST /api/ai/analyze-mask` 自动显示拓扑锚点数量等属性,`topology_anchor_count` 是真实 polygon 顶点数量,`topology_anchors` 只保留最多 64 个抽样点用于调试展示;`OntologyInspector` 会为分析请求维护递增序号,旧请求返回时不再回写状态,并静默忽略 Axios abort/cancel 错误,避免快速切换、平滑预览或组件卸载时把正常中止误报成失败;不再提供“重新提取拓扑锚点”调试按钮;“边缘平滑强度”滑杆会即时更新数值,但 `POST /api/ai/smooth-mask` 预览请求经过约 220ms 防抖后才发送,返回 polygon 作为临时预览写入当前 mask 显示,预览不改变保存状态;点击“应用边缘平滑”后,前端把平滑 polygon 作为新的实际几何写入当前 mask,并按传播 lineage 同步写入传播链前后对应 mask,相关 mask 标记为 dirty/draft,整次操作通过一次 `setMasks()` 进入撤销/重做历史;应用后不保留 `geometry_smoothing` 参数,平滑强度重置为 0。前端不再展示“后端模型置信度”。
|
||||
- GT mask 导入已完成多类别像素值拆分、contour、distance transform seed point 和前端 seed point 拖拽编辑;骨架提取、HDBSCAN 聚类和模板自动映射尚未实现。
|
||||
|
||||
@@ -14,17 +14,17 @@
|
||||
|
||||
| 需求 | 测试文件 | 覆盖点 |
|
||||
|------|----------|--------|
|
||||
| R1 登录与会话 | `src/components/Login.test.tsx`, `backend/tests/test_auth.py` | 成功登录、失败提示、后端 401 |
|
||||
| R2 项目管理 | `src/lib/api.test.ts`, `src/components/ProjectLibrary.test.tsx`, `backend/tests/test_projects.py` | 前端字段映射、PATCH 更新、项目卡片删除、DELETE 契约、后端 CRUD、删除级联、帧列表 |
|
||||
| R1 登录与会话 | `src/components/Login.test.tsx`, `src/components/Sidebar.test.tsx`, `src/components/UserAdmin.test.tsx`, `src/store/useStore.test.ts`, `backend/tests/test_auth.py`, `backend/tests/test_admin.py` | 成功登录、JWT/token 写入、当前用户写入、刷新恢复基础状态、失败提示、登录输入 autocomplete、后端 401、`/api/auth/me`、管理员入口、用户 CRUD、角色权限、审计日志、viewer 读写权限边界、演示出厂设置二次确认和重置结果 |
|
||||
| R2 项目管理 | `src/lib/api.test.ts`, `src/components/ProjectLibrary.test.tsx`, `backend/tests/test_projects.py` | 前端字段映射、PATCH 更新、项目卡片删除、DELETE 契约、后端 CRUD、删除级联、帧列表、项目按当前 JWT 用户隔离 |
|
||||
| R3 媒体上传与拆帧 | `src/components/ProjectLibrary.test.tsx`, `src/components/TransientNotice.test.tsx`, `backend/tests/test_media.py`, `backend/tests/test_tasks.py` | 视频导入不自动拆帧、显式生成帧 FPS 选择、项目卡片显示目标 parse_fps 而非原视频 FPS、扩展名校验、自动建项目、关联项目、创建异步任务、非阻塞自动消失操作提示、标准帧序列参数、帧时间戳/源帧号、任务序列元数据、worker 注册帧、取消任务、重试任务、取消后 worker 停止 |
|
||||
| R4 工作区与帧浏览 | `src/components/VideoWorkspace.test.tsx`, `src/components/FrameTimeline.test.tsx` | 加载帧、无帧项目不自动解析并提示生成帧、工作区短状态自动消失、工作区/AI 画布底图默认居中且保留边距、工作区 mask 透明度、回显已保存标注时保留本地未保存 draft mask、选中 mask 后跨帧自动跟随同一传播链结果、清空片段遮罩进入时间轴范围选择并按选区批量清空、传播权重下拉深色可读配色、缩略图/range/视频处理进度条、视频处理进度条点击跳帧、人工/AI 标注帧红色竖线和标识点击跳帧、自动传播帧蓝色区段和标识点击跳帧、最近自动传播历史片段不同色系渐变显示、缩略图红/蓝边框、人工/AI 标注帧叠加传播状态时红框优先保留并显示蓝色内描边、当前人工/AI 标注帧青色外框加红色内描边、普通状态不显示传播范围黄色选区、播放进度条和视频处理进度条选择传播/清空范围、当前帧由播放进度条末端和缩略图青色高亮表达/左右方向键切帧、播放、按项目 FPS 显示当前/总时长 |
|
||||
| R5 工具栏 | `src/components/ToolsPalette.test.tsx`, `src/components/CanvasArea.test.tsx`, `src/components/VideoWorkspace.test.tsx`, `src/store/useStore.test.ts` | 工具切换、工具栏紧凑垂直布局和高度不足时滚动、工具栏低对比滚动条、工具栏外扩滚动条槽位不挤占按钮列、调整多边形工具、AI 跳转、矩形/圆/线/点/多边形手工 mask 绘制、点工具在已有 mask 上落点、多边形 Enter/首节点闭合、上下文提示提示 Enter/Esc/首节点闭合且数秒后自动隐藏、polygon 顶点直接拖动/删除、顶点拖拽结束不改变 Canvas 视口、边中点插点、双击边界按位置插点、整块 mask 删除、区域合并/去除、布尔选择主区域/扣除区域视觉区分和选择顺序提示、内含去除 hole 渲染、合并模式隐藏编辑手柄、工作区 SAM 提示点点击删除且不冒泡新增点、工作区顶栏撤销/重做按钮、撤销/重做快捷键和输入框快捷键跳过、撤销/重做历史栈 |
|
||||
| R6 AI 推理 | `src/lib/api.test.ts`, `src/components/CanvasArea.test.tsx`, `src/components/AISegmentation.test.tsx`, `src/components/VideoWorkspace.test.tsx`, `src/components/ModelStatusBadge.test.tsx`, `backend/tests/test_ai.py`, `backend/tests/test_sam2_engine.py` | SAM 2.1 变体选择、点/框/interactive 契约、semantic 禁用、SAM 3 入口隐藏和后端拒绝、SAM 2.1 最高分候选去重、SAM 2.1 框选后正负点细化同一候选 mask、AI 页框选发送 box prompt、AI 页框选后加点发送 interactive prompt、AI 页提示工具上下文提示、AI 页重复执行替换旧候选、SAM 2.1 反向点启用背景过滤且空结果移除旧候选、AI 页不渲染工作区已有 mask、AI 页可在候选 mask 上继续添加正/反点、AI 页可单点删除提示点并删除最近锚点、AI 页可删除选中候选且不删除工作区 mask、AI 页清空只移除本页候选、AI 页参数开关可读性文案且 options 字段不变、AI 页遮罩清晰度只改预览 opacity、AI 页生成 mask 自动选中并可通过分类树换标签、AI 页推送到工作区编辑保留选择和当前帧、SAM 2.1 视频以当前参考帧全部 mask 和起止帧范围自动传播、传播前自动保存 draft/dirty seed mask、传播前独立选择 SAM 2.1 tiny/small/base+/large 权重、自动传播创建 Celery 任务、传播入队权重 id 规范化/拒绝不支持 id、传播 seed 来源 id/签名/边缘平滑 metadata、未编辑传播结果作为 seed 时继承原始签名并跳过重复传播、已编辑传播结果保留 lineage 但重算签名并清理旧结果、中间帧人工新增替代 seed 时清理下游同物体旧传播结果、中间帧 backward 传播清理旧 forward 结果、换权重或换边缘平滑参数传播先清理旧结果、旧临时 seed id 传播结果兼容清理、传播中轮询任务进度、传播任务取消/重试、传播来源 metadata 回显、空提示/空结果反馈、GPU/SAM2.1 状态、AI 参数 options、局部裁剪推理、背景过滤、状态徽标、坐标归一化、正负点 labels、polygons 转 path、后端 fake registry |
|
||||
| R7 标注保存 | `src/components/VideoWorkspace.test.tsx`, `src/components/CanvasArea.test.tsx`, `src/lib/api.test.ts`, `backend/tests/test_ai.py` | 保存标注、保存后用后端 saved annotation 替换已提交 draft、加载回显、更新 dirty 标注、清空删除已保存标注、GT mask 多类别导入、seed point 回显/归一化、项目不存在、帧不存在 |
|
||||
| R4 工作区与帧浏览 | `src/components/VideoWorkspace.test.tsx`, `src/components/FrameTimeline.test.tsx` | 加载帧、无帧项目不自动解析并提示生成帧、工作区短状态自动消失、工作区/AI 画布底图默认居中且保留边距、工作区 mask 透明度、回显已保存标注时保留本地未保存 draft mask、选中 mask 后跨帧自动跟随同一传播链结果、清空片段遮罩进入时间轴范围选择并按选区批量清空、清空全部模式、保留人工/AI 模式只清传播 mask、清空人工/AI 标注帧前二次确认、取消确认不删除、仅自动传播帧不确认、清空后裁剪/移除重叠传播历史条、传播权重下拉深色可读配色、缩略图/range/视频处理进度条、视频处理进度条点击跳帧、人工/AI 标注帧红色竖线和标识点击跳帧、自动传播帧通过 source/lineage metadata 识别为蓝色区段和标识点击跳帧、最近自动传播历史片段同一蓝色系按新旧递进纯色显示,旧记录第 5 次后统一阈值色、当前帧白色贯穿线、传播/清空范围洋红/黄绿色边界贯穿线、缩略图红/蓝边框、人工/AI 标注帧叠加传播状态时红框优先保留并显示蓝色内描边、当前人工/AI 标注帧青色外框加红色内描边、普通状态不显示传播范围黄色选区、播放进度条和视频处理进度条选择传播/清空范围、左右方向键切帧、播放、按项目 FPS 显示当前/总时长 |
|
||||
| R5 工具栏 | `src/components/ToolsPalette.test.tsx`, `src/components/CanvasArea.test.tsx`, `src/components/VideoWorkspace.test.tsx`, `src/store/useStore.test.ts` | 工具切换、工具栏紧凑垂直布局和高度不足时滚动、工具栏低对比滚动条、工具栏外扩滚动条槽位不挤占按钮列、调整多边形工具、AI 跳转、GT Mask 导入位于重叠区域去除之后且使用紫色底色、GT Mask 未知类别导入策略选择、工作区工具栏不展示 AI 正/反点和框选、左侧工具栏不重复撤销/重做、左侧工具栏不展示创建点/创建线段、矩形/圆/多边形手工 mask 绘制、画笔/橡皮擦尺寸控制、画笔新建当前类别 mask、画笔与选中 mask 连通时自动合并、橡皮擦从选中 mask 扣除、未选中 mask 时画布按语义分类树内部优先级渲染、多边形 Enter/首节点闭合、上下文提示提示 Enter/Esc/首节点闭合且数秒后自动隐藏、polygon 顶点直接拖动/删除、顶点拖拽结束不改变 Canvas 视口、边中点插点、双击边界按位置插点、整块 mask 删除、区域合并/去除、布尔选择主区域/扣除区域视觉区分和选择顺序提示、内含去除 hole 渲染、合并模式隐藏编辑手柄、工作区顶栏撤销/重做按钮、顶栏撤销/重做图标强调色、撤销/重做快捷键和输入框快捷键跳过、撤销/重做历史栈 |
|
||||
| R6 AI 推理 | `src/lib/api.test.ts`, `src/components/CanvasArea.test.tsx`, `src/components/AISegmentation.test.tsx`, `src/components/VideoWorkspace.test.tsx`, `src/components/ModelStatusBadge.test.tsx`, `backend/tests/test_ai.py`, `backend/tests/test_sam2_engine.py` | SAM 2.1 变体选择、点/框/interactive 契约、semantic 禁用、SAM 3 入口隐藏和后端拒绝、SAM 2.1 最高分候选去重、SAM 2.1 框选后正负点细化同一候选 mask、AI 页框选发送 box prompt、AI 页框选后加点发送 interactive prompt、AI 页提示工具上下文提示、AI 页重复执行替换旧候选、SAM 2.1 反向点启用背景过滤且空结果移除旧候选、AI 页不渲染工作区已有 mask、AI 页可在候选 mask 上继续添加正/反点、AI 页可单点删除提示点并删除最近锚点、AI 页可删除选中候选且不删除工作区 mask、AI 页清空只移除本页候选、AI 页参数开关可读性文案且 options 字段不变、AI 页/右侧共享遮罩透明度只改预览 opacity、AI 页生成 mask 自动选中并可通过分类树换标签、AI 页无语义候选禁止推送到工作区并用 error toast 提示、离开 AI 页时清理未分类候选、AI 页推送到工作区编辑保留选择和当前帧、SAM 2.1 视频以当前参考帧全部 mask 和起止帧范围自动传播、传播前自动保存 draft/dirty seed mask、传播前独立选择 SAM 2.1 tiny/small/base+/large 权重、自动传播创建 Celery 任务、传播入队权重 id 规范化/拒绝不支持 id、传播 seed 来源 id/签名和历史平滑 metadata 兼容、历史平滑 seed 保存前对 forward/backward polygon 实际应用边缘平滑并减少密集轮廓点、边缘平滑强度缓入递进曲线、未编辑传播结果作为 seed 时继承原始签名并跳过重复传播、已编辑传播结果保留 lineage 但重算签名并清理旧结果、中间帧人工新增替代 seed 时清理下游同物体旧传播结果、中间帧 backward 传播清理旧 forward 结果、换权重传播先清理旧结果、旧临时 seed id 传播结果兼容清理、传播中轮询任务进度、传播任务取消/重试、传播来源 metadata 回显、空提示/空结果反馈、GPU/SAM2.1 状态、AI 参数 options、局部裁剪推理、背景过滤、状态徽标、坐标归一化、正负点 labels、polygons 转 path、后端 fake registry |
|
||||
| R7 标注保存 | `src/components/VideoWorkspace.test.tsx`, `src/components/CanvasArea.test.tsx`, `src/lib/api.test.ts`, `backend/tests/test_ai.py` | 保存状态按钮“保存 X 个改动/已全部保存”、保存标注、保存后用后端 saved annotation 替换已提交 draft、加载回显、更新 dirty 标注、清空删除已保存标注、GT mask 多类别导入、seed point 回显/归一化、项目不存在、帧不存在 |
|
||||
| R8 模板库 | `src/components/TemplateRegistry.test.tsx`, `src/components/TransientNotice.test.tsx`, `src/lib/api.test.ts`, `backend/tests/test_templates.py` | 前端模板加载/新建/编辑/删除、JSON 分类导入、JSON/保存错误非阻塞提示、mapping_rules 解包/打包、后端模板 CRUD |
|
||||
| R9 本体检查面板 | `src/components/OntologyInspector.test.tsx`, `src/components/CanvasArea.test.tsx`, `src/store/useStore.test.ts`, `backend/tests/test_ai.py` | 模板选择、面板标题简化、面板低对比滚动条、工作区遮罩透明度滑杆、分类展示、具体分类选择、Canvas 选区同步、点击 Canvas mask 后自动聚焦对应语义分类、点击分类给已选 mask 换标签并移动到前端渲染最上层、自定义分类 PATCH 后端模板、目标实例标题显示当前 mask label、隐藏后端模型置信度、选中 mask 后端拓扑属性分析、重新提取拓扑锚点、边缘平滑强度调用后端并将 mask 标记为 dirty |
|
||||
| R10 Dashboard 与 WebSocket | `src/lib/api.test.ts`, `src/lib/websocket.test.ts`, `src/components/Dashboard.test.tsx`, `backend/tests/test_dashboard.py`, `backend/tests/test_main.py`, `backend/tests/test_progress_events.py`, `backend/tests/test_tasks.py` | 后端概览接口、任务表驱动进度区、最近完成任务保留显示、任务取消/重试/详情、cancelled 事件、Redis 进度事件 payload/发布、地址推导、消息订阅、连接状态回调、队列更新、heartbeat |
|
||||
| R11 导出 | `src/components/VideoWorkspace.test.tsx`, `src/lib/api.test.ts`, `backend/tests/test_export.py` | COCO/PNG 按钮下载、导出前自动保存、导出路径、JSON 结构、mask ZIP、zIndex 语义融合 |
|
||||
| R9 本体检查面板 | `src/components/OntologyInspector.test.tsx`, `src/components/CanvasArea.test.tsx`, `src/store/useStore.test.ts`, `backend/tests/test_ai.py` | 模板选择、面板标题简化、面板低对比滚动条、工作区遮罩透明度滑杆、分类展示、具体分类选择、Canvas 选区同步、点击 Canvas mask 后自动聚焦对应语义分类、点击分类给已选 mask 换标签并移动到前端渲染最上层、分类变更同步同一传播链前后帧对应 mask、自定义分类 PATCH 后端模板、目标实例标题显示当前 mask label、隐藏当前选中区域计数、隐藏后端模型置信度、选中 mask 后端拓扑属性分析、拓扑锚点数量按真实 polygon 顶点数显示、分析请求 abort/cancel 静默忽略且旧请求不覆盖新状态、边缘平滑强度防抖预览不标 dirty、应用边缘平滑后将 mask 标记为 dirty、平滑作为实际几何编辑、平滑同步传播链对应 mask、平滑保存时保留传播 lineage 而不把传播帧变成人工/AI 标注帧、平滑撤销/重做、平滑应用后强度归零 |
|
||||
| R10 Dashboard 与 WebSocket | `src/lib/api.test.ts`, `src/lib/websocket.test.ts`, `src/components/Dashboard.test.tsx`, `backend/tests/test_dashboard.py`, `backend/tests/test_main.py`, `backend/tests/test_progress_events.py`, `backend/tests/test_tasks.py` | 后端概览接口、任务表驱动进度区、最近完成任务保留显示、任务取消/重试/详情、cancelled 事件、Redis 进度事件 payload/发布、地址推导、消息订阅、连接状态回调、队列更新、heartbeat、主动断开不重连 |
|
||||
| R11 导出 | `src/components/VideoWorkspace.test.tsx`, `src/lib/api.test.ts`, `backend/tests/test_export.py` | 统一分割结果导出下拉、导出前自动保存、整体/范围/当前帧范围参数、特定范围帧可通过播放进度条/视频处理进度条拖拽选择、下载 ZIP 按项目名/`0h00m00s000ms` 起止时间戳/起止项目帧序号命名、导出内容 outputs 参数、Mix_label 透明度参数和预览、兼容 COCO/PNG 路径、JSON 结构、maskid/GT 像素值映射 JSON、原始图片文件夹、按帧/按类别合并的分开 Mask 文件夹、GT_label 黑白图文件夹、Pro_label 彩色图文件夹、Mix_label 原图叠加图文件夹、GT/Pro/Mix 按内部优先级覆盖且和语义分类树顺序一致、GT_label 背景 0、保留类别真实 maskid、导出 GT_label 再导入保持类别一致 |
|
||||
| R12 配置 | `src/lib/config.test.ts` | env 优先、hostname 推导、WS 推导 |
|
||||
| R13 文档与测试 | `doc/09-test-plan.md` | 测试覆盖矩阵 |
|
||||
|
||||
@@ -32,24 +32,24 @@
|
||||
|
||||
| 需求 | 功能点 | 对应测试 | 当前状态 |
|
||||
|------|--------|----------|----------|
|
||||
| R1 | 登录页、默认开发凭证、token 写入、失败提示、后端 401 | `Login.test.tsx`, `test_auth.py` | 已覆盖 |
|
||||
| R2 | 项目列表/创建/选择、视频导入、DICOM 导入、后端项目和帧 CRUD | `ProjectLibrary.test.tsx`, `api.test.ts`, `test_projects.py` | 已覆盖 |
|
||||
| R1 | 登录页、默认开发管理员、JWT 写入、当前用户写入、刷新恢复基础状态、失败提示、后端 401、`/api/auth/me`、管理员用户管理、角色权限、审计日志、演示出厂设置二次确认、重置后只保留 admin 和未生成帧演示视频项目 | `Login.test.tsx`, `Sidebar.test.tsx`, `UserAdmin.test.tsx`, `useStore.test.ts`, `test_auth.py`, `test_admin.py` | 已覆盖 |
|
||||
| R2 | 项目列表/创建/选择、项目按用户隔离、视频导入、DICOM 导入、后端项目和帧 CRUD | `ProjectLibrary.test.tsx`, `api.test.ts`, `test_projects.py` | 已覆盖 |
|
||||
| R3 | 文件类型校验、自动/指定项目上传、视频导入与生成帧分离、显式 FPS 生成帧、项目卡片 FPS 徽标显示 `parse_fps`、视频/DICOM 拆帧任务、非阻塞自动消失操作提示、`parse_fps/max_frames/target_width`、标准帧序列 metadata、任务查询、取消、重试、worker 取消停止 | `ProjectLibrary.test.tsx`, `TransientNotice.test.tsx`, `api.test.ts`, `test_media.py`, `test_tasks.py` | 已覆盖 |
|
||||
| R4 | 工作区加载帧、无帧项目不自动解析、工作区短状态自动消失、后端标注回显保留本地未保存 draft mask、Canvas/AI 底图居中适配且保留边距、工作区 mask 透明度、选中 mask 后跨帧自动跟随同一传播链结果、清空片段遮罩进入时间轴范围选择并按选区批量清空、传播权重下拉深色可读配色、缩略图/range/视频处理进度条、视频处理进度条点击跳帧、人工/AI 标注帧红色竖线和标识点击跳帧、自动传播帧蓝色区段和标识点击跳帧、最近自动传播历史片段不同色系渐变显示、缩略图红/蓝边框、人工/AI 标注帧叠加传播状态时红框优先保留并显示蓝色内描边、当前人工/AI 标注帧青色外框加红色内描边、普通状态不显示传播范围黄色选区、播放进度条/视频处理进度条拖拽选择传播/清空范围、Canvas/AI 画布拖拽平移回写 position state、当前帧由播放进度条末端和缩略图青色高亮表达/左右方向键切帧、播放、按 FPS 显示时间 | `VideoWorkspace.test.tsx`, `FrameTimeline.test.tsx`, `CanvasArea.test.tsx`, `AISegmentation.test.tsx` | 已覆盖 |
|
||||
| R5 | 工具切换、工具栏紧凑滚动布局、低对比滚动条、外扩滚动条槽位、调整多边形入口、AI 跳转、矩形/圆/线/点/多边形绘制、已有 mask 上继续绘制、多边形和布尔工具上下文提示、Canvas 上下文提示数秒后自动隐藏 | `ToolsPalette.test.tsx`, `CanvasArea.test.tsx` | 已覆盖 |
|
||||
| R5 | 顶点直接拖动编辑、顶点拖拽结束不改变 Canvas 视口、边中点插点、双击边界按位置插点、顶点删除、整块删除、工作区 SAM 提示点删除优先级、工作区顶栏撤销/重做按钮、撤销/重做快捷键、区域合并、区域去除、布尔选择主区域黄色实线/扣除区域红色虚线、布尔选择顺序提示、hole even-odd 渲染 | `CanvasArea.test.tsx`, `VideoWorkspace.test.tsx`, `useStore.test.ts` | 已覆盖 |
|
||||
| R6 | SAM 2.1 变体选择、点/框/interactive、semantic 禁用、SAM 3 入口隐藏和后端拒绝、SAM 2.1 最高分候选去重、AI 页框选/框选后加点、AI 页提示工具上下文提示、AI 页重复执行替换旧候选、AI 页不渲染工作区已有 mask、AI 页可在候选 mask 上继续添加正/反点、AI 页可删除提示点、AI 页可删除选中候选、AI 页清空只移除本页候选、AI 页遮罩清晰度只改预览 opacity、AI 页生成 mask 自动选中并可换标签、AI 页推送到工作区编辑保留选择和当前帧、SAM 2.1 视频按参考帧全部 mask 和范围自动传播、传播前自动保存 draft/dirty seed mask、传播前独立选择 SAM 2.1 tiny/small/base+/large 权重、自动传播 Celery 任务入队、传播入队权重 id 规范化/拒绝不支持 id、传播 seed 来源 id/签名/边缘平滑 metadata、未编辑传播结果作为 seed 时继承原始签名并跳过重复传播、已编辑传播结果保留 lineage 但重算签名并清理旧结果、中间帧人工新增替代 seed 时清理下游同物体旧传播结果、中间帧 backward 传播清理旧 forward 结果、换权重或平滑参数传播先清理旧结果、旧临时 seed id 传播结果兼容清理、前端任务轮询进度、传播任务 runner 保存标注和结果权重 id、传播任务重试、传播空结果提示、GPU/模型状态、参数 options、polygons 转 mask | `api.test.ts`, `CanvasArea.test.tsx`, `AISegmentation.test.tsx`, `VideoWorkspace.test.tsx`, `ModelStatusBadge.test.tsx`, `test_ai.py`, `test_tasks.py`, `test_sam2_engine.py` | 已覆盖 |
|
||||
| R7 | 保存、保存后替换已提交 draft、查询、更新、删除标注、工作区回显、清空已保存标注、GT mask 导入和 seed point 回写 | `VideoWorkspace.test.tsx`, `CanvasArea.test.tsx`, `api.test.ts`, `test_ai.py` | 已覆盖 |
|
||||
| R4 | 工作区加载帧、无帧项目不自动解析、工作区短状态自动消失、后端标注回显保留本地未保存 draft mask、Canvas/AI 底图居中适配且保留边距、工作区 mask 透明度、选中 mask 后跨帧自动跟随同一传播链结果、清空片段遮罩进入时间轴范围选择并按选区批量清空、清空全部模式、保留人工/AI 模式只清传播 mask、清空人工/AI 标注帧前二次确认、取消确认不删除、仅自动传播帧不确认、清空后裁剪/移除重叠传播历史条、传播权重下拉深色可读配色、缩略图/range/视频处理进度条、视频处理进度条点击跳帧、人工/AI 标注帧红色竖线和标识点击跳帧、自动传播帧蓝色区段和标识点击跳帧、最近自动传播历史片段同一蓝色系按新旧递进显示,旧记录第 5 次后统一阈值色、当前帧白色贯穿线、传播/清空范围洋红/黄绿色边界贯穿线、缩略图红/蓝边框、人工/AI 标注帧叠加传播状态时红框优先保留并显示蓝色内描边、当前人工/AI 标注帧青色外框加红色内描边、普通状态不显示传播范围黄色选区、播放进度条/视频处理进度条拖拽选择传播/清空范围、Canvas/AI 画布拖拽平移回写 position state、左右方向键切帧、播放、按 FPS 显示时间 | `VideoWorkspace.test.tsx`, `FrameTimeline.test.tsx`, `CanvasArea.test.tsx`, `AISegmentation.test.tsx` | 已覆盖 |
|
||||
| R5 | 工具切换、工具栏紧凑滚动布局、低对比滚动条、外扩滚动条槽位、调整多边形入口、GT Mask 导入入口位置和紫色底色、工作区工具栏隐藏 AI 正/反点和框选、左侧工具栏不重复撤销/重做、AI 跳转、矩形/圆/线/点/多边形绘制、已有 mask 上继续绘制、多边形和布尔工具上下文提示、Canvas 上下文提示数秒后自动隐藏 | `ToolsPalette.test.tsx`, `CanvasArea.test.tsx` | 已覆盖 |
|
||||
| R5 | 顶点直接拖动编辑、顶点拖拽结束不改变 Canvas 视口、边中点插点、双击边界按位置插点、顶点删除、整块删除、工作区顶栏撤销/重做按钮、顶栏撤销/重做图标强调色、撤销/重做快捷键、区域合并、区域去除、布尔选择主区域黄色实线/扣除区域红色虚线、布尔选择顺序提示、hole even-odd 渲染 | `CanvasArea.test.tsx`, `VideoWorkspace.test.tsx`, `useStore.test.ts` | 已覆盖 |
|
||||
| R6 | SAM 2.1 变体选择、点/框/interactive、semantic 禁用、SAM 3 入口隐藏和后端拒绝、SAM 2.1 最高分候选去重、AI 页框选/框选后加点、AI 页提示工具上下文提示、AI 页重复执行替换旧候选、AI 页不渲染工作区已有 mask、AI 页可在候选 mask 上继续添加正/反点、AI 页可删除提示点、AI 页可删除选中候选、AI 页清空只移除本页候选、AI 页/右侧共享遮罩透明度只改预览 opacity、AI 页生成 mask 自动选中并可换标签、AI 页无语义候选禁止推送到工作区并用 error toast 提示、离开 AI 页时清理未分类候选、AI 页推送到工作区编辑保留选择和当前帧、SAM 2.1 视频按参考帧全部 mask 和范围自动传播、传播前自动保存 draft/dirty seed mask、传播前独立选择 SAM 2.1 tiny/small/base+/large 权重、自动传播 Celery 任务入队、传播入队权重 id 规范化/拒绝不支持 id、传播 seed 来源 id/签名和历史平滑 metadata 兼容、历史平滑 seed 保存前对 forward/backward polygon 实际应用边缘平滑并减少密集轮廓点、边缘平滑强度缓入递进曲线、未编辑传播结果作为 seed 时继承原始签名并跳过重复传播、已编辑传播结果保留 lineage 但重算签名并清理旧结果、中间帧人工新增替代 seed 时清理下游同物体旧传播结果、中间帧 backward 传播清理旧 forward 结果、换权重传播先清理旧结果、旧临时 seed id 传播结果兼容清理、前端任务轮询进度、传播任务 runner 保存标注和结果权重 id、传播任务重试、传播空结果提示、GPU/模型状态、参数 options、polygons 转 mask | `api.test.ts`, `CanvasArea.test.tsx`, `AISegmentation.test.tsx`, `VideoWorkspace.test.tsx`, `ModelStatusBadge.test.tsx`, `test_ai.py`, `test_tasks.py`, `test_sam2_engine.py` | 已覆盖 |
|
||||
| R7 | 保存状态按钮“保存 X 个改动/已全部保存”、保存、保存后替换已提交 draft、查询、更新、删除标注、工作区回显、清空已保存标注、GT mask 导入和 seed point 回写、低数值/16-bit GT_label 图导入、RGB 等通道 maskid 图导入、导入预览、未知 maskid 导入策略、非法彩色 GT mask 拒绝、尺寸不一致自动最近邻拉伸 | `VideoWorkspace.test.tsx`, `CanvasArea.test.tsx`, `api.test.ts`, `test_ai.py` | 已覆盖 |
|
||||
| R8 | 模板加载、新建、编辑、删除、JSON 分类导入、JSON/保存错误非阻塞提示、mapping_rules 映射、后端 CRUD | `TemplateRegistry.test.tsx`, `TransientNotice.test.tsx`, `api.test.ts`, `test_templates.py` | 已覆盖 |
|
||||
| R9 | 模板选择、面板标题简化、工作区遮罩透明度滑杆、分类展示、分类选择、点击 mask 自动聚焦对应分类、已选 mask 换标签并置顶显示、自定义分类写入后端模板、目标实例标题显示当前 mask label、隐藏后端模型置信度、后端拓扑属性分析、边缘平滑强度应用、占位状态 | `OntologyInspector.test.tsx`, `CanvasArea.test.tsx`, `useStore.test.ts`, `test_ai.py` | 已覆盖 |
|
||||
| R9 | 模板选择、面板标题简化、工作区遮罩透明度滑杆、分类展示、分类选择、分类树拖拽调整内部覆盖顺序且不改变 maskid、拖拽后同步同类 mask 层级并标记待保存、点击 mask 自动聚焦对应分类、已选 mask 换标签并置顶显示、分类变更同步同一传播链前后帧对应 mask、自定义分类写入后端模板、目标实例标题显示当前 mask label、隐藏当前选中区域计数、隐藏后端模型置信度、后端拓扑属性分析、拓扑锚点真实顶点计数、分析请求 abort/cancel 静默忽略且旧请求不覆盖新状态、边缘平滑强度防抖预览、边缘平滑应用后确认 dirty、平滑作为实际几何编辑、平滑同步传播链对应 mask、平滑撤销/重做、平滑应用后强度归零、占位状态 | `OntologyInspector.test.tsx`, `CanvasArea.test.tsx`, `useStore.test.ts`, `test_ai.py` | 已覆盖 |
|
||||
| R10 | Dashboard 概览、任务进度区、最近完成任务保留显示、活动日志、WebSocket progress/complete/error/status/cancelled、取消/重试/详情、连接状态回调、heartbeat | `Dashboard.test.tsx`, `websocket.test.ts`, `test_dashboard.py`, `test_main.py`, `test_progress_events.py`, `test_tasks.py` | 已覆盖 |
|
||||
| R11 | COCO/PNG ZIP 导出、导出前保存、路径和 JSON/ZIP 结构、zIndex 融合 | `VideoWorkspace.test.tsx`, `api.test.ts`, `test_export.py` | 已覆盖 |
|
||||
| R11 | 统一“分割结果导出”下拉、整体视频/特定范围帧/当前图片导出、特定范围帧时间轴拖拽选择、ZIP 文件名 `{项目库项目名}_seg_T_{起始时间戳}-{结束时间戳}_P_{起始项目帧序号}-{结束项目帧序号}.zip`、时间戳 `0h00m00s000ms` 格式、项目帧序号使用抽帧后 1-based 顺序、分开 Mask/GT_label/Pro_label/Mix_label outputs、Mix_label 透明度、导出前保存、兼容 COCO/PNG ZIP 路径、JSON/ZIP 结构、maskid/GT 像素值映射、原始图片导出、分开 Mask 按帧子目录与同类合并命名、GT_label/Pro_label/Mix_label 命名、GT/Pro/Mix 内部优先级融合且和语义分类树顺序一致、GT_label 背景 0、保留类别真实 maskid、导出的 GT_label 可按同一模板导回 | `VideoWorkspace.test.tsx`, `api.test.ts`, `test_export.py` | 已覆盖 |
|
||||
| R12 | API/WS 地址 env 优先和 hostname 推导 | `config.test.ts` | 已覆盖 |
|
||||
| R13 | 文档测试矩阵与功能点追踪 | `doc/09-test-plan.md` | 已覆盖 |
|
||||
|
||||
## 本轮补齐记录
|
||||
|
||||
- R5:补充 `CanvasArea.test.tsx` 中圆形和线段手工绘制测试,明确验证 metadata、segmentation、bbox/area 和草稿状态。
|
||||
- R5:补充 `CanvasArea.test.tsx` 中圆形、画笔新建、画笔连通合并和橡皮擦扣除测试,明确验证 metadata、segmentation、bbox/area、选中状态和草稿状态;补充 `ToolsPalette.test.tsx` 中画笔/橡皮擦尺寸控制测试,并验证创建点、创建线段入口不再显示。
|
||||
- R6:补充 `AISegmentation.test.tsx` 中 SAM 2.1 变体选择测试,验证前端不展示 SAM 3 入口、选择 small 后请求携带对应模型,且未放置点提示时不发起推理。
|
||||
- R6:补充 SAM 2 纯文本提示拦截、SAM 2 多候选只保留最高分、SAM 2 engine 单候选请求测试,避免多个重叠候选 mask 被同时叠加。
|
||||
- R6:补充 Canvas 工作区 SAM 2 反向点背景过滤测试,覆盖请求 options 和过滤为空时清除旧候选 mask。
|
||||
@@ -59,7 +59,9 @@
|
||||
- R6:补充 `propagateMasks()` 同步兼容接口和 `queuePropagationTask()` 任务接口测试,验证当前参考帧全部 mask 会按范围组装为后台传播 steps。
|
||||
- R6:补充 `VideoWorkspace` 自动传播进度测试,验证传播任务运行中显示进度,后端返回 0 个新区域时给出明确反馈。
|
||||
- R4/R6:补充时间轴传播范围选择测试,验证点击“自动传播”后可在播放进度条或视频处理进度条上拖拽回填起止帧,再提交后台传播任务。
|
||||
- R4/R6:补充视频处理进度条传播历史测试,验证多次自动传播后会按不同色系渐变片段显示最近处理范围。
|
||||
- R4/R6:补充视频处理进度条传播历史测试,验证多次自动传播后会按同一蓝色系显示最近处理范围,最新最亮、旧记录逐次变暗且第 5 次后统一阈值色,单个片段不使用渐变。
|
||||
- R4:补充清空片段遮罩后移除重叠传播历史条测试,避免已清空视频范围继续显示最近传播进度。
|
||||
- R4:补充清空片段遮罩模式测试,覆盖“清空全部”确认删除、“保留人工/AI”只清传播 mask、取消不删除、仅自动传播帧不弹确认。
|
||||
- R6/R10:补充 `queuePropagationTask()`、`POST /api/ai/propagate/task`、传播 Celery runner 和传播任务重试测试,验证工作区自动传播不再依赖长 HTTP 请求,并验证传给 `SAM2VideoPredictor` 的临时帧文件名是纯数字序列。
|
||||
- R6:补充传播去重回归测试,验证前端传播前会先保存 draft seed mask 并用稳定 `source_annotation_id` 入队;后端在 seed 来源由前端临时 id 迁移到后端 annotation id、用户换用其他 SAM 2.1 权重、未编辑传播结果再次作为 seed、已编辑传播结果重新作为 seed、中间帧人工新增替代 seed 时,会分别跳过或清理旧传播标注再保存新结果。
|
||||
- R6:`backend/tests/test_sam3_engine.py` 已标记跳过,仅作为历史保留实现的参考测试,不计入当前产品功能覆盖。
|
||||
@@ -67,6 +69,8 @@
|
||||
- R3:补充 worker 注册标准帧序列测试,验证帧 `timestamp_ms`、`source_frame_number` 和 `result.frame_sequence` 元数据。
|
||||
- R8:补充 `TemplateRegistry.test.tsx` 中模板编辑、删除测试,验证前端调用真实 API 封装并更新全局 store。
|
||||
- R9:补充 Canvas 选中 mask id 全局同步、本体树点击分类给已选 mask 换标签并移到渲染最上层的测试,验证已保存 mask 会进入 dirty 状态。
|
||||
- R9:补充边缘平滑滑杆防抖测试,验证连续拖动只触发最后一次后端预览请求,降低拖动卡顿。
|
||||
- R9:补充边缘平滑应用到传播链并可撤销/重做的测试,验证平滑后成为新的实际 polygon、强度归零且不再只保存平滑参数。
|
||||
|
||||
## 运行命令
|
||||
|
||||
|
||||
@@ -187,6 +187,11 @@ sam_model_config=configs/sam2.1/sam2.1_hiera_t.yaml
|
||||
|
||||
app_env=development
|
||||
cors_origins=["http://localhost:3000","http://127.0.0.1:3000"]
|
||||
jwt_secret_key=change-this-to-a-long-random-production-secret
|
||||
access_token_expire_minutes=1440
|
||||
default_admin_username=admin
|
||||
default_admin_password=123456
|
||||
demo_video_path=/home/wkmgc/Desktop/Seg_Server/Data_MyVideo_1.mp4
|
||||
EOF
|
||||
```
|
||||
|
||||
@@ -306,6 +311,10 @@ curl http://localhost:9000/minio/health/live
|
||||
admin / 123456
|
||||
```
|
||||
|
||||
首次启动会自动创建默认管理员,密码以哈希形式写入 `users` 表;登录返回签名 JWT,业务接口会校验 `Authorization: Bearer <token>`。生产环境必须修改 `jwt_secret_key` 和默认管理员密码。
|
||||
|
||||
默认管理员登录后会看到“用户管理”后台,可新增用户、停用/启用用户、修改角色、重置密码、删除无项目用户并查看登录与用户管理审计日志。角色分为 `admin`、`annotator`、`viewer`:`admin/annotator` 可以执行写入类业务操作,`viewer` 只读。演示部署可在该后台使用“恢复演示出厂设置”,二次确认后只保留默认 admin 和一个尚未生成帧的演示视频项目;该视频来自 `demo_video_path`。
|
||||
|
||||
---
|
||||
|
||||
## 10. 一键启动脚本
|
||||
@@ -471,7 +480,7 @@ python -m pytest backend/tests
|
||||
4. 在项目库点击“生成帧”,选择 FPS。
|
||||
5. Dashboard 中应看到任务进度;Celery 日志应显示拆帧任务。
|
||||
6. 进入分割工作区,能看到帧、时间轴和画布。
|
||||
7. 手工画一个多边形 mask,点击“结构化归档保存”。
|
||||
7. 手工画一个多边形 mask,确认顶栏保存状态按钮显示“保存 1 个改动”,点击保存。
|
||||
8. 刷新工作区后,已保存标注应回显。
|
||||
9. AI 智能分割中选择可用 SAM 2.1 模型,放置点或框,执行分割。
|
||||
10. 导出 JSON 或 PNG Mask ZIP。
|
||||
|
||||
22
src/App.tsx
22
src/App.tsx
@@ -1,6 +1,6 @@
|
||||
import React, { useEffect } from 'react';
|
||||
import { useStore } from './store/useStore';
|
||||
import { getProjects } from './lib/api';
|
||||
import { getCurrentUser, getProjects } from './lib/api';
|
||||
import { Sidebar } from './components/Sidebar';
|
||||
import { Dashboard } from './components/Dashboard';
|
||||
import { ProjectLibrary } from './components/ProjectLibrary';
|
||||
@@ -8,8 +8,9 @@ import { VideoWorkspace } from './components/VideoWorkspace';
|
||||
import { TemplateRegistry } from './components/TemplateRegistry';
|
||||
import { AISegmentation } from './components/AISegmentation';
|
||||
import { Login } from './components/Login';
|
||||
import { UserAdmin } from './components/UserAdmin';
|
||||
|
||||
export type ActiveModule = 'dashboard' | 'projects' | 'ai' | 'workspace' | 'templates';
|
||||
export type ActiveModule = 'dashboard' | 'projects' | 'ai' | 'workspace' | 'templates' | 'admin';
|
||||
|
||||
export default function App() {
|
||||
const isAuthenticated = useStore((state) => state.isAuthenticated);
|
||||
@@ -17,17 +18,27 @@ export default function App() {
|
||||
const setActiveModule = useStore((state) => state.setActiveModule);
|
||||
const setProjects = useStore((state) => state.setProjects);
|
||||
const setError = useStore((state) => state.setError);
|
||||
const setCurrentUser = useStore((state) => state.setCurrentUser);
|
||||
const logout = useStore((state) => state.logout);
|
||||
const currentUser = useStore((state) => state.currentUser);
|
||||
|
||||
useEffect(() => {
|
||||
if (isAuthenticated) {
|
||||
getProjects()
|
||||
.then((data) => setProjects(data))
|
||||
Promise.all([getCurrentUser(), getProjects()])
|
||||
.then(([user, projects]) => {
|
||||
setCurrentUser(user);
|
||||
setProjects(projects);
|
||||
})
|
||||
.catch((err) => {
|
||||
console.error('Failed to fetch projects:', err);
|
||||
if (err?.response?.status === 401) {
|
||||
logout();
|
||||
return;
|
||||
}
|
||||
setError('获取项目列表失败');
|
||||
});
|
||||
}
|
||||
}, [isAuthenticated, setProjects, setError]);
|
||||
}, [isAuthenticated, logout, setCurrentUser, setProjects, setError]);
|
||||
|
||||
if (!isAuthenticated) {
|
||||
return <Login />;
|
||||
@@ -42,6 +53,7 @@ export default function App() {
|
||||
{activeModule === 'ai' && <AISegmentation onSendToWorkspace={() => setActiveModule('workspace')} />}
|
||||
{activeModule === 'workspace' && <VideoWorkspace onNavigateToAI={() => setActiveModule('ai')} />}
|
||||
{activeModule === 'templates' && <TemplateRegistry />}
|
||||
{activeModule === 'admin' && currentUser?.role === 'admin' && <UserAdmin />}
|
||||
</main>
|
||||
</div>
|
||||
);
|
||||
|
||||
@@ -356,13 +356,41 @@ describe('AISegmentation', () => {
|
||||
await waitFor(() => expect(screen.getByTestId('konva-path')).toBeInTheDocument());
|
||||
|
||||
const maskGroup = () => screen.getAllByTestId('konva-group').find((group) => group.getAttribute('data-opacity'));
|
||||
expect(maskGroup()).toHaveAttribute('data-opacity', '0.72');
|
||||
fireEvent.change(screen.getByLabelText('遮罩清晰度'), { target: { value: '35' } });
|
||||
expect(maskGroup()).toHaveAttribute('data-opacity', '0.5');
|
||||
fireEvent.change(screen.getByLabelText('AI 遮罩透明度'), { target: { value: '35' } });
|
||||
|
||||
expect(maskGroup()).toHaveAttribute('data-opacity', '0.35');
|
||||
expect(useStore.getState().maskPreviewOpacity).toBe(35);
|
||||
expect(useStore.getState().masks[0].segmentation).toEqual([[10, 10, 40, 10, 40, 40]]);
|
||||
});
|
||||
|
||||
it('updates AI candidate opacity when the shared ontology opacity slider changes', async () => {
|
||||
apiMock.predictMask.mockResolvedValueOnce({
|
||||
masks: [
|
||||
{
|
||||
id: 'sam2-mask',
|
||||
pathData: 'M 10 10 L 40 10 L 40 40 Z',
|
||||
label: 'AI Mask',
|
||||
color: '#06b6d4',
|
||||
segmentation: [[10, 10, 40, 10, 40, 40]],
|
||||
bbox: [10, 10, 30, 30],
|
||||
area: 900,
|
||||
},
|
||||
],
|
||||
});
|
||||
|
||||
render(<AISegmentation onSendToWorkspace={vi.fn()} />);
|
||||
fireEvent.click(screen.getByText('正向选点'));
|
||||
fireEvent.click(screen.getByTestId('konva-stage'));
|
||||
fireEvent.click(await screen.findByText('执行高精度语义分割'));
|
||||
await waitFor(() => expect(screen.getByTestId('konva-path')).toBeInTheDocument());
|
||||
|
||||
const maskGroup = () => screen.getAllByTestId('konva-group').find((group) => group.getAttribute('data-opacity'));
|
||||
fireEvent.change(screen.getByLabelText('遮罩透明度'), { target: { value: '80' } });
|
||||
|
||||
expect(maskGroup()).toHaveAttribute('data-opacity', '0.8');
|
||||
});
|
||||
|
||||
it('lets positive and negative prompt points be added on top of an AI mask', async () => {
|
||||
apiMock.predictMask
|
||||
.mockResolvedValueOnce({
|
||||
@@ -558,6 +586,11 @@ describe('AISegmentation', () => {
|
||||
|
||||
it('keeps the generated SAM2 mask selected when sending it to the workspace editor', async () => {
|
||||
const onSendToWorkspace = vi.fn();
|
||||
useStore.setState({
|
||||
activeTemplateId: 'template-1',
|
||||
activeClass: { id: 'class-1', name: '胆囊', color: '#ff0000', zIndex: 30 },
|
||||
activeClassId: 'class-1',
|
||||
});
|
||||
apiMock.predictMask.mockResolvedValueOnce({
|
||||
masks: [
|
||||
{
|
||||
@@ -585,4 +618,94 @@ describe('AISegmentation', () => {
|
||||
expect(onSendToWorkspace).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('blocks sending an AI candidate to the workspace until a semantic class is selected', async () => {
|
||||
const onSendToWorkspace = vi.fn();
|
||||
apiMock.predictMask.mockResolvedValueOnce({
|
||||
masks: [
|
||||
{
|
||||
id: 'sam2-mask',
|
||||
pathData: 'M 10 10 L 40 10 L 40 40 Z',
|
||||
label: 'AI Mask',
|
||||
color: '#06b6d4',
|
||||
segmentation: [[10, 10, 40, 10, 40, 40]],
|
||||
bbox: [10, 10, 30, 30],
|
||||
area: 900,
|
||||
},
|
||||
],
|
||||
});
|
||||
|
||||
render(<AISegmentation onSendToWorkspace={onSendToWorkspace} />);
|
||||
fireEvent.click(screen.getByText('正向选点'));
|
||||
fireEvent.click(screen.getByTestId('konva-stage'));
|
||||
fireEvent.click(await screen.findByText('执行高精度语义分割'));
|
||||
await waitFor(() => expect(useStore.getState().selectedMaskIds).toEqual(['sam2-mask']));
|
||||
|
||||
fireEvent.click(screen.getByText('推送至工作区编辑'));
|
||||
|
||||
const toast = screen.getByRole('status');
|
||||
expect(toast).toHaveTextContent('请先在右侧语义分类树为 AI 候选区域选择语义分类,再推送至工作区。');
|
||||
expect(toast.className).toContain('bg-red-950');
|
||||
expect(useStore.getState().activeTool).toBe('point_pos');
|
||||
expect(onSendToWorkspace).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('removes unclassified AI candidates when leaving the AI page', async () => {
|
||||
apiMock.predictMask.mockResolvedValueOnce({
|
||||
masks: [
|
||||
{
|
||||
id: 'sam2-mask',
|
||||
pathData: 'M 10 10 L 40 10 L 40 40 Z',
|
||||
label: 'AI Mask',
|
||||
color: '#06b6d4',
|
||||
segmentation: [[10, 10, 40, 10, 40, 40]],
|
||||
bbox: [10, 10, 30, 30],
|
||||
area: 900,
|
||||
},
|
||||
],
|
||||
});
|
||||
|
||||
const { unmount } = render(<AISegmentation onSendToWorkspace={vi.fn()} />);
|
||||
fireEvent.click(screen.getByText('正向选点'));
|
||||
fireEvent.click(screen.getByTestId('konva-stage'));
|
||||
fireEvent.click(await screen.findByText('执行高精度语义分割'));
|
||||
await waitFor(() => expect(useStore.getState().masks).toHaveLength(1));
|
||||
|
||||
unmount();
|
||||
|
||||
expect(useStore.getState().masks).toEqual([]);
|
||||
expect(useStore.getState().selectedMaskIds).toEqual([]);
|
||||
});
|
||||
|
||||
it('keeps classified AI candidates when leaving the AI page', async () => {
|
||||
useStore.setState({
|
||||
activeTemplateId: 'template-1',
|
||||
activeClass: { id: 'class-1', name: '胆囊', color: '#ff0000', zIndex: 30 },
|
||||
activeClassId: 'class-1',
|
||||
});
|
||||
apiMock.predictMask.mockResolvedValueOnce({
|
||||
masks: [
|
||||
{
|
||||
id: 'sam2-mask',
|
||||
pathData: 'M 10 10 L 40 10 L 40 40 Z',
|
||||
label: 'AI Mask',
|
||||
color: '#06b6d4',
|
||||
segmentation: [[10, 10, 40, 10, 40, 40]],
|
||||
bbox: [10, 10, 30, 30],
|
||||
area: 900,
|
||||
},
|
||||
],
|
||||
});
|
||||
|
||||
const { unmount } = render(<AISegmentation onSendToWorkspace={vi.fn()} />);
|
||||
fireEvent.click(screen.getByText('正向选点'));
|
||||
fireEvent.click(screen.getByTestId('konva-stage'));
|
||||
fireEvent.click(await screen.findByText('执行高精度语义分割'));
|
||||
await waitFor(() => expect(useStore.getState().masks[0]?.classId).toBe('class-1'));
|
||||
|
||||
unmount();
|
||||
|
||||
expect(useStore.getState().masks).toHaveLength(1);
|
||||
expect(useStore.getState().selectedMaskIds).toEqual(['sam2-mask']);
|
||||
});
|
||||
|
||||
});
|
||||
|
||||
@@ -4,6 +4,7 @@ import { cn } from '../lib/utils';
|
||||
import { Stage, Layer, Image as KonvaImage, Circle, Path, Group, Rect } from 'react-konva';
|
||||
import useImage from 'use-image';
|
||||
import { OntologyInspector } from './OntologyInspector';
|
||||
import { TransientNotice, type NoticeState } from './TransientNotice';
|
||||
import { SAM2_MODEL_OPTIONS, useStore, type Mask } from '../store/useStore';
|
||||
import { getAiModelStatus, predictMask, type AiRuntimeStatus } from '../lib/api';
|
||||
|
||||
@@ -34,13 +35,15 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
|
||||
const setAiModel = useStore((state) => state.setAiModel);
|
||||
const activeTemplateId = useStore((state) => state.activeTemplateId);
|
||||
const activeClass = useStore((state) => state.activeClass);
|
||||
const maskPreviewOpacity = useStore((state) => state.maskPreviewOpacity);
|
||||
const setMaskPreviewOpacity = useStore((state) => state.setMaskPreviewOpacity);
|
||||
|
||||
const [modelStatus, setModelStatus] = useState<AiRuntimeStatus | null>(null);
|
||||
const [autoDeleteBg, setAutoDeleteBg] = useState(true);
|
||||
const [cropMode, setCropMode] = useState(false);
|
||||
const [maskOpacity, setMaskOpacity] = useState(72);
|
||||
const [isInferencing, setIsInferencing] = useState(false);
|
||||
const [inferenceMessage, setInferenceMessage] = useState('');
|
||||
const [notice, setNotice] = useState<NoticeState | null>(null);
|
||||
const [aiMaskIds, setAiMaskIds] = useState<string[]>([]);
|
||||
|
||||
// Canvas state
|
||||
@@ -59,11 +62,32 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
|
||||
const frameMasks = currentFrame
|
||||
? masks.filter((mask) => mask.frameId === currentFrame.id && aiMaskIdSet.has(mask.id))
|
||||
: masks.filter((mask) => aiMaskIdSet.has(mask.id));
|
||||
const selectedAiMasks = frameMasks.filter((mask) => selectedMaskIds.includes(mask.id));
|
||||
const aiMasksToSend = selectedAiMasks.length > 0 ? selectedAiMasks : frameMasks;
|
||||
const selectedModelStatus = modelStatus?.models.find((model) => model.id === aiModel);
|
||||
const modelCanInfer = selectedModelStatus?.available ?? true;
|
||||
|
||||
const effectiveTool = storeActiveTool;
|
||||
|
||||
useEffect(() => {
|
||||
return () => {
|
||||
if (aiMaskIds.length === 0) return;
|
||||
const state = useStore.getState();
|
||||
const aiIds = new Set(aiMaskIds);
|
||||
const unclassifiedAiIds = new Set(
|
||||
state.masks
|
||||
.filter((mask) => aiIds.has(mask.id) && !mask.classId && !mask.className)
|
||||
.map((mask) => mask.id),
|
||||
);
|
||||
if (unclassifiedAiIds.size === 0) return;
|
||||
|
||||
useStore.setState({
|
||||
masks: state.masks.filter((mask) => !unclassifiedAiIds.has(mask.id)),
|
||||
selectedMaskIds: state.selectedMaskIds.filter((id) => !unclassifiedAiIds.has(id)),
|
||||
});
|
||||
};
|
||||
}, [aiMaskIds]);
|
||||
|
||||
useEffect(() => {
|
||||
const handleResize = () => {
|
||||
if (!canvasContainerRef.current) return;
|
||||
@@ -266,6 +290,7 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
|
||||
classId: activeClass?.id,
|
||||
className: activeClass?.name,
|
||||
classZIndex: activeClass?.zIndex,
|
||||
classMaskId: activeClass?.maskId,
|
||||
saveStatus: 'draft',
|
||||
saved: false,
|
||||
pathData: m.pathData,
|
||||
@@ -329,6 +354,27 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
|
||||
deleteAiMasksById(selectedMaskIds);
|
||||
}, [deleteAiMasksById, selectedMaskIds]);
|
||||
|
||||
const handleSendToWorkspace = useCallback(() => {
|
||||
if (aiMasksToSend.length === 0) {
|
||||
setInferenceMessage('请先执行分割并选择一个 AI 候选区域。');
|
||||
return;
|
||||
}
|
||||
const hasMissingSemantic = aiMasksToSend.some((mask) => !mask.classId && !mask.className);
|
||||
if (hasMissingSemantic) {
|
||||
setInferenceMessage('');
|
||||
setNotice({
|
||||
id: Date.now(),
|
||||
message: '请先在右侧语义分类树为 AI 候选区域选择语义分类,再推送至工作区。',
|
||||
tone: 'error',
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
setInferenceMessage('');
|
||||
setActiveTool('edit_polygon');
|
||||
onSendToWorkspace();
|
||||
}, [aiMasksToSend, onSendToWorkspace, setActiveTool]);
|
||||
|
||||
const removePromptPoint = useCallback((pointIndex: number) => {
|
||||
setPoints((currentPoints) => currentPoints.filter((_, index) => index !== pointIndex));
|
||||
}, []);
|
||||
@@ -398,6 +444,7 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
|
||||
|
||||
return (
|
||||
<div className="w-full h-full flex bg-[#0a0a0a]">
|
||||
<TransientNotice notice={notice} onDismiss={() => setNotice(null)} />
|
||||
{/* Left AI Controller Panel */}
|
||||
<aside className="w-80 bg-[#0d0d0d] flex flex-col border-r border-white/5 shrink-0 z-10 overflow-hidden">
|
||||
<div className="h-16 border-b border-white/5 flex items-center px-6 shrink-0 justify-between">
|
||||
@@ -506,17 +553,18 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
|
||||
|
||||
<div className="space-y-2">
|
||||
<div className="flex items-center justify-between">
|
||||
<label htmlFor="ai-mask-opacity" className="text-[11px] text-gray-400 uppercase tracking-wider font-medium">遮罩清晰度</label>
|
||||
<span className="text-[10px] font-mono text-cyan-400">{maskOpacity}%</span>
|
||||
<label htmlFor="ai-mask-opacity" className="text-[11px] text-gray-400 uppercase tracking-wider font-medium">AI 遮罩透明度</label>
|
||||
<span className="text-[10px] font-mono text-cyan-400">{maskPreviewOpacity}%</span>
|
||||
</div>
|
||||
<input
|
||||
id="ai-mask-opacity"
|
||||
aria-label="AI 遮罩透明度"
|
||||
type="range"
|
||||
min="20"
|
||||
min="10"
|
||||
max="100"
|
||||
step="5"
|
||||
value={maskOpacity}
|
||||
onChange={(event) => setMaskOpacity(Number(event.target.value))}
|
||||
value={maskPreviewOpacity}
|
||||
onChange={(event) => setMaskPreviewOpacity(Number(event.target.value))}
|
||||
className="w-full accent-cyan-400"
|
||||
/>
|
||||
</div>
|
||||
@@ -544,10 +592,8 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
|
||||
</div>
|
||||
)}
|
||||
<button
|
||||
onClick={() => {
|
||||
setActiveTool('edit_polygon');
|
||||
onSendToWorkspace();
|
||||
}}
|
||||
onClick={handleSendToWorkspace}
|
||||
title="AI 候选区域必须先选择语义分类,才能推送到工作区"
|
||||
className="w-full py-3.5 rounded-lg flex items-center justify-center gap-2 transition-all font-medium tracking-wide text-xs uppercase bg-white/5 hover:bg-white/10 text-gray-300 border border-white/5 hover:border-white/10"
|
||||
>
|
||||
<SendToBack size={16} /> 推送至工作区编辑
|
||||
@@ -659,9 +705,10 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
|
||||
{/* AI Returned Masks */}
|
||||
{frameMasks.map((mask) => {
|
||||
const isSelected = selectedMaskIds.includes(mask.id);
|
||||
const baseOpacity = Math.min(Math.max(maskPreviewOpacity / 100, 0.1), 1);
|
||||
const previewOpacity = isSelected
|
||||
? maskOpacity / 100
|
||||
: Math.max(0.18, (maskOpacity / 100) * 0.62);
|
||||
? baseOpacity
|
||||
: Math.max(0.12, baseOpacity * 0.62);
|
||||
return (
|
||||
<Group key={mask.id} opacity={previewOpacity}>
|
||||
<Path
|
||||
|
||||
25
src/components/AiSegmentationIcon.tsx
Normal file
25
src/components/AiSegmentationIcon.tsx
Normal file
@@ -0,0 +1,25 @@
|
||||
import React from 'react';
|
||||
import { Bot, Sparkles } from 'lucide-react';
|
||||
|
||||
interface AiSegmentationIconProps {
|
||||
size?: number;
|
||||
strokeWidth?: number;
|
||||
}
|
||||
|
||||
export function AiSegmentationIcon({ size = 20, strokeWidth = 2 }: AiSegmentationIconProps) {
|
||||
const sparkleSize = Math.max(9, Math.round(size * 0.48));
|
||||
return (
|
||||
<span
|
||||
data-testid="ai-segmentation-icon"
|
||||
className="relative inline-flex items-center justify-center"
|
||||
style={{ width: size, height: size }}
|
||||
>
|
||||
<Bot size={size} strokeWidth={strokeWidth} />
|
||||
<Sparkles
|
||||
size={sparkleSize}
|
||||
strokeWidth={Math.max(strokeWidth, 2.2)}
|
||||
className="absolute -right-1 -top-1 text-cyan-300 drop-shadow-[0_0_4px_rgba(34,211,238,0.75)]"
|
||||
/>
|
||||
</span>
|
||||
);
|
||||
}
|
||||
@@ -934,69 +934,102 @@ describe('CanvasArea', () => {
|
||||
expect(useStore.getState().masks[0].segmentation?.[0]).toHaveLength(64);
|
||||
});
|
||||
|
||||
it('creates a manual line region from a drag gesture', () => {
|
||||
render(<CanvasArea activeTool="create_line" frame={frame} />);
|
||||
it('creates a brush mask when a semantic class is selected', () => {
|
||||
useStore.setState({
|
||||
activeTemplateId: '2',
|
||||
activeClass: { id: 'c1', name: '胆囊', color: '#ff0000', zIndex: 20, maskId: 1 },
|
||||
activeClassId: 'c1',
|
||||
});
|
||||
|
||||
render(<CanvasArea activeTool="brush" frame={frame} />);
|
||||
const stage = screen.getByTestId('konva-stage');
|
||||
fireEvent.mouseDown(stage, { clientX: 120, clientY: 80 });
|
||||
fireEvent.mouseMove(stage, { clientX: 260, clientY: 200 });
|
||||
fireEvent.mouseMove(stage, { clientX: 180, clientY: 120 });
|
||||
fireEvent.mouseUp(stage, { clientX: 260, clientY: 200 });
|
||||
|
||||
expect(useStore.getState().masks).toHaveLength(1);
|
||||
expect(useStore.getState().masks[0]).toEqual(expect.objectContaining({
|
||||
frameId: 'frame-1',
|
||||
label: '手工线段',
|
||||
color: '#06b6d4',
|
||||
label: '胆囊',
|
||||
color: '#ff0000',
|
||||
classId: 'c1',
|
||||
classMaskId: 1,
|
||||
saveStatus: 'draft',
|
||||
metadata: expect.objectContaining({
|
||||
source: 'manual',
|
||||
shape: '线段',
|
||||
shape: '画笔',
|
||||
}),
|
||||
}));
|
||||
expect(useStore.getState().masks[0].segmentation?.[0]).toHaveLength(8);
|
||||
expect(useStore.getState().masks[0].segmentation?.length).toBeGreaterThan(0);
|
||||
expect(useStore.getState().masks[0].area).toBeGreaterThan(1000);
|
||||
});
|
||||
|
||||
it('creates an editable point region on click', () => {
|
||||
render(<CanvasArea activeTool="create_point" frame={frame} />);
|
||||
fireEvent.click(screen.getByTestId('konva-stage'), { clientX: 120, clientY: 80 });
|
||||
|
||||
expect(useStore.getState().masks).toHaveLength(1);
|
||||
expect(useStore.getState().masks[0]).toEqual(expect.objectContaining({
|
||||
frameId: 'frame-1',
|
||||
label: '手工点区域',
|
||||
color: '#06b6d4',
|
||||
saveStatus: 'draft',
|
||||
points: [[120, 80]],
|
||||
bbox: expect.arrayContaining([115, 75]),
|
||||
metadata: expect.objectContaining({
|
||||
source: 'manual',
|
||||
shape: '点区域',
|
||||
}),
|
||||
}));
|
||||
});
|
||||
|
||||
it('creates a point region when clicking over an existing mask', () => {
|
||||
it('merges a connected brush stroke into the selected mask', () => {
|
||||
useStore.setState({
|
||||
activeTemplateId: '2',
|
||||
activeClass: { id: 'c1', name: '胆囊', color: '#ff0000', zIndex: 20 },
|
||||
activeClassId: 'c1',
|
||||
selectedMaskIds: ['m1'],
|
||||
masks: [
|
||||
{
|
||||
id: 'm1',
|
||||
frameId: 'frame-1',
|
||||
pathData: 'M 10 10 L 200 10 L 200 200 Z',
|
||||
label: 'Existing',
|
||||
color: '#06b6d4',
|
||||
segmentation: [[10, 10, 200, 10, 200, 200]],
|
||||
pathData: 'M 100 70 L 150 70 L 150 120 L 100 120 Z',
|
||||
label: '胆囊',
|
||||
color: '#ff0000',
|
||||
classId: 'c1',
|
||||
segmentation: [[100, 70, 150, 70, 150, 120, 100, 120]],
|
||||
area: 2500,
|
||||
},
|
||||
],
|
||||
});
|
||||
|
||||
render(<CanvasArea activeTool="create_point" frame={frame} />);
|
||||
fireEvent.click(screen.getByTestId('konva-path'), { clientX: 120, clientY: 80 });
|
||||
render(<CanvasArea activeTool="brush" frame={frame} />);
|
||||
const stage = screen.getByTestId('konva-stage');
|
||||
fireEvent.mouseDown(stage, { clientX: 130, clientY: 90 });
|
||||
fireEvent.mouseMove(stage, { clientX: 170, clientY: 100 });
|
||||
fireEvent.mouseUp(stage, { clientX: 210, clientY: 110 });
|
||||
|
||||
expect(useStore.getState().masks).toHaveLength(2);
|
||||
expect(useStore.getState().masks[1]).toEqual(expect.objectContaining({
|
||||
metadata: expect.objectContaining({ shape: '点区域' }),
|
||||
points: [[120, 80]],
|
||||
expect(useStore.getState().masks).toHaveLength(1);
|
||||
expect(useStore.getState().masks[0]).toEqual(expect.objectContaining({
|
||||
id: 'm1',
|
||||
label: '胆囊',
|
||||
color: '#ff0000',
|
||||
saveStatus: 'draft',
|
||||
}));
|
||||
expect(useStore.getState().masks[0].area).toBeGreaterThan(2500);
|
||||
expect(useStore.getState().selectedMaskIds).toEqual(['m1']);
|
||||
});
|
||||
|
||||
it('erases from the selected mask with a sampled stroke', () => {
|
||||
useStore.setState({
|
||||
selectedMaskIds: ['m1'],
|
||||
masks: [
|
||||
{
|
||||
id: 'm1',
|
||||
frameId: 'frame-1',
|
||||
pathData: 'M 10 10 L 300 10 L 300 220 L 10 220 Z',
|
||||
label: 'Existing',
|
||||
color: '#06b6d4',
|
||||
segmentation: [[10, 10, 300, 10, 300, 220, 10, 220]],
|
||||
area: 60900,
|
||||
},
|
||||
],
|
||||
});
|
||||
|
||||
render(<CanvasArea activeTool="eraser" frame={frame} />);
|
||||
const stage = screen.getByTestId('konva-stage');
|
||||
fireEvent.mouseDown(stage, { clientX: 120, clientY: 80 });
|
||||
fireEvent.mouseMove(stage, { clientX: 180, clientY: 120 });
|
||||
fireEvent.mouseUp(stage, { clientX: 260, clientY: 200 });
|
||||
|
||||
expect(useStore.getState().masks).toHaveLength(1);
|
||||
expect(useStore.getState().masks[0]).toEqual(expect.objectContaining({
|
||||
id: 'm1',
|
||||
saveStatus: 'draft',
|
||||
}));
|
||||
expect(useStore.getState().masks[0].area).toBeLessThan(60900);
|
||||
expect(useStore.getState().selectedMaskIds).toEqual(['m1']);
|
||||
});
|
||||
|
||||
it('finalizes a clicked polygon with Enter', () => {
|
||||
@@ -1082,10 +1115,10 @@ describe('CanvasArea', () => {
|
||||
vi.useRealTimers();
|
||||
});
|
||||
|
||||
it('applies the selected class to current-frame masks and marks saved masks dirty', () => {
|
||||
it('applies the selected class to current-frame masks and linked propagation masks', () => {
|
||||
useStore.setState({
|
||||
activeTemplateId: '2',
|
||||
activeClass: { id: 'c1', name: '胆囊', color: '#ff0000', zIndex: 20 },
|
||||
activeClass: { id: 'c1', name: '胆囊', color: '#ff0000', zIndex: 20, maskId: 1 },
|
||||
activeClassId: 'c1',
|
||||
masks: [
|
||||
{
|
||||
@@ -1098,6 +1131,28 @@ describe('CanvasArea', () => {
|
||||
saved: true,
|
||||
saveStatus: 'saved',
|
||||
},
|
||||
{
|
||||
id: 'm2',
|
||||
frameId: 'frame-2',
|
||||
annotationId: '100',
|
||||
pathData: 'M 1 1 Z',
|
||||
label: '旧传播标签',
|
||||
color: '#06b6d4',
|
||||
metadata: { source_annotation_id: 99, source_mask_id: 'annotation-99' },
|
||||
saved: true,
|
||||
saveStatus: 'saved',
|
||||
},
|
||||
{
|
||||
id: 'm3',
|
||||
frameId: 'frame-2',
|
||||
annotationId: '101',
|
||||
pathData: 'M 2 2 Z',
|
||||
label: '无关区域',
|
||||
color: '#ffffff',
|
||||
metadata: { source_annotation_id: 101 },
|
||||
saved: true,
|
||||
saveStatus: 'saved',
|
||||
},
|
||||
],
|
||||
});
|
||||
|
||||
@@ -1109,11 +1164,56 @@ describe('CanvasArea', () => {
|
||||
classId: 'c1',
|
||||
className: '胆囊',
|
||||
classZIndex: 20,
|
||||
classMaskId: 1,
|
||||
label: '胆囊',
|
||||
color: '#ff0000',
|
||||
saveStatus: 'dirty',
|
||||
saved: false,
|
||||
}));
|
||||
expect(useStore.getState().masks[1]).toEqual(expect.objectContaining({
|
||||
classId: 'c1',
|
||||
className: '胆囊',
|
||||
classMaskId: 1,
|
||||
label: '胆囊',
|
||||
color: '#ff0000',
|
||||
saveStatus: 'dirty',
|
||||
saved: false,
|
||||
}));
|
||||
expect(useStore.getState().masks[2]).toEqual(expect.objectContaining({
|
||||
label: '无关区域',
|
||||
color: '#ffffff',
|
||||
saveStatus: 'saved',
|
||||
saved: true,
|
||||
}));
|
||||
});
|
||||
|
||||
it('renders unselected masks by semantic tree layer priority', () => {
|
||||
useStore.setState({
|
||||
selectedMaskIds: [],
|
||||
masks: [
|
||||
{
|
||||
id: 'high',
|
||||
frameId: 'frame-1',
|
||||
pathData: 'M 0 0 Z',
|
||||
label: '高优先级',
|
||||
color: '#ef4444',
|
||||
classZIndex: 30,
|
||||
},
|
||||
{
|
||||
id: 'low',
|
||||
frameId: 'frame-1',
|
||||
pathData: 'M 1 1 Z',
|
||||
label: '低优先级',
|
||||
color: '#22c55e',
|
||||
classZIndex: 10,
|
||||
},
|
||||
],
|
||||
});
|
||||
|
||||
render(<CanvasArea activeTool="move" frame={frame} />);
|
||||
|
||||
const paths = screen.getAllByTestId('konva-path');
|
||||
expect(paths.map((path) => path.getAttribute('data-fill'))).toEqual(['#22c55e', '#ef4444']);
|
||||
});
|
||||
|
||||
it('delegates clear to the workspace handler so saved annotations can be deleted', () => {
|
||||
|
||||
@@ -18,14 +18,18 @@ type PromptPoint = CanvasPoint & { type: 'pos' | 'neg' };
|
||||
type PromptBox = { x1: number; y1: number; x2: number; y2: number };
|
||||
type ToolHint = { title: string; body: string };
|
||||
|
||||
const DRAG_MANUAL_TOOLS = new Set(['create_rectangle', 'create_circle', 'create_line']);
|
||||
const DRAG_MANUAL_TOOLS = new Set(['create_rectangle', 'create_circle']);
|
||||
const POLYGON_TOOL = 'create_polygon';
|
||||
const EDIT_POLYGON_TOOL = 'edit_polygon';
|
||||
const POINT_TOOL = 'create_point';
|
||||
const BRUSH_TOOL = 'brush';
|
||||
const ERASER_TOOL = 'eraser';
|
||||
const PAINT_TOOLS = new Set([BRUSH_TOOL, ERASER_TOOL]);
|
||||
const BOOLEAN_TOOLS = new Set(['area_merge', 'area_remove']);
|
||||
const POLYGON_CLOSE_RADIUS = 8;
|
||||
const DEFAULT_IMAGE_FIT_RATIO = 0.86;
|
||||
const TOOL_HINT_TTL_MS = 3600;
|
||||
const PAINT_STAMP_SEGMENTS = 16;
|
||||
const MAX_PAINT_STROKE_POINTS = 128;
|
||||
|
||||
function clamp(value: number, min: number, max: number): number {
|
||||
return Math.min(Math.max(value, min), max);
|
||||
@@ -97,6 +101,31 @@ function findLinkedMasksOnFrame(selectedIds: string[], allMasks: Mask[], targetF
|
||||
.map((mask) => mask.id);
|
||||
}
|
||||
|
||||
function findPropagationChainMaskIds(selectedIds: string[], allMasks: Mask[]): Set<string> {
|
||||
const selectedMasks = selectedIds
|
||||
.map((id) => allMasks.find((mask) => mask.id === id))
|
||||
.filter((mask): mask is Mask => Boolean(mask));
|
||||
const selectedTokens = new Set<string>();
|
||||
selectedMasks.forEach((mask) => {
|
||||
propagationLineageTokens(mask).forEach((token) => selectedTokens.add(token));
|
||||
});
|
||||
if (selectedTokens.size === 0) return new Set(selectedIds);
|
||||
|
||||
return new Set(
|
||||
allMasks
|
||||
.filter((mask) => {
|
||||
const candidateTokens = propagationLineageTokens(mask);
|
||||
return [...candidateTokens].some((token) => selectedTokens.has(token));
|
||||
})
|
||||
.map((mask) => mask.id),
|
||||
);
|
||||
}
|
||||
|
||||
function maskLayerPriority(mask: Mask): number {
|
||||
const parsed = Number(mask.classZIndex ?? mask.metadata?.classZIndex ?? 0);
|
||||
return Number.isFinite(parsed) ? parsed : 0;
|
||||
}
|
||||
|
||||
function polygonPath(points: CanvasPoint[]): string {
|
||||
if (points.length === 0) return '';
|
||||
return points
|
||||
@@ -165,6 +194,29 @@ function pointDistance(a: CanvasPoint, b: CanvasPoint): number {
|
||||
return Math.hypot(a.x - b.x, a.y - b.y);
|
||||
}
|
||||
|
||||
function extendStrokePoints(
|
||||
current: CanvasPoint[],
|
||||
nextPoint: CanvasPoint,
|
||||
spacing: number,
|
||||
maxPoints = MAX_PAINT_STROKE_POINTS,
|
||||
): CanvasPoint[] {
|
||||
const previous = current[current.length - 1];
|
||||
if (!previous) return [nextPoint];
|
||||
const distance = pointDistance(previous, nextPoint);
|
||||
if (distance < spacing) return current;
|
||||
const steps = Math.max(1, Math.floor(distance / spacing));
|
||||
const additions: CanvasPoint[] = [];
|
||||
for (let step = 1; step <= steps; step += 1) {
|
||||
if (current.length + additions.length >= maxPoints) break;
|
||||
const ratio = step / steps;
|
||||
additions.push({
|
||||
x: previous.x + (nextPoint.x - previous.x) * ratio,
|
||||
y: previous.y + (nextPoint.y - previous.y) * ratio,
|
||||
});
|
||||
}
|
||||
return [...current, ...additions];
|
||||
}
|
||||
|
||||
function distanceToSegmentSquared(point: CanvasPoint, start: CanvasPoint, end: CanvasPoint): number {
|
||||
const dx = end.x - start.x;
|
||||
const dy = end.y - start.y;
|
||||
@@ -218,6 +270,13 @@ function maskToMultiPolygon(mask: Mask): MultiPolygon | null {
|
||||
return polygons.length > 0 ? polygons : null;
|
||||
}
|
||||
|
||||
function polygonsToMultiPolygon(polygons: CanvasPoint[][]): MultiPolygon | null {
|
||||
const geometry = polygons
|
||||
.filter((points) => points.length >= 3)
|
||||
.map((points) => [closeRing(points)]);
|
||||
return geometry.length > 0 ? geometry : null;
|
||||
}
|
||||
|
||||
function openRingPoints(ring: Pair[]): CanvasPoint[] {
|
||||
const openRing = ring.length > 1
|
||||
&& ring[0][0] === ring[ring.length - 1][0]
|
||||
@@ -247,6 +306,27 @@ function multiPolygonHasHoles(geometry: MultiPolygon): boolean {
|
||||
return geometry.some((polygon) => polygon.length > 1);
|
||||
}
|
||||
|
||||
function maskWithSegmentation(
|
||||
mask: Mask,
|
||||
segmentation: number[][],
|
||||
options: { area?: number; hasHoles?: boolean } = {},
|
||||
): Mask {
|
||||
const bbox = segmentationBbox(segmentation);
|
||||
const metadata = { ...(mask.metadata || {}) };
|
||||
if (options.hasHoles === true) metadata.hasHoles = true;
|
||||
if (options.hasHoles === false) delete metadata.hasHoles;
|
||||
return {
|
||||
...mask,
|
||||
pathData: segmentationPath(segmentation),
|
||||
segmentation,
|
||||
bbox,
|
||||
area: options.area ?? segmentationArea(segmentation),
|
||||
metadata,
|
||||
saveStatus: mask.annotationId ? 'dirty' : 'draft',
|
||||
saved: mask.annotationId ? false : mask.saved,
|
||||
};
|
||||
}
|
||||
|
||||
function rectanglePoints(start: CanvasPoint, end: CanvasPoint): CanvasPoint[] {
|
||||
const x1 = Math.min(start.x, end.x);
|
||||
const y1 = Math.min(start.y, end.y);
|
||||
@@ -271,25 +351,26 @@ function circlePoints(start: CanvasPoint, end: CanvasPoint): CanvasPoint[] {
|
||||
});
|
||||
}
|
||||
|
||||
function pointRegion(point: CanvasPoint, radius = 5): CanvasPoint[] {
|
||||
return Array.from({ length: 12 }, (_, index) => {
|
||||
const angle = (Math.PI * 2 * index) / 12;
|
||||
return { x: point.x + Math.cos(angle) * radius, y: point.y + Math.sin(angle) * radius };
|
||||
function circleStampPoints(center: CanvasPoint, radius: number, segments = PAINT_STAMP_SEGMENTS): CanvasPoint[] {
|
||||
return Array.from({ length: segments }, (_, index) => {
|
||||
const angle = (Math.PI * 2 * index) / segments;
|
||||
return { x: center.x + Math.cos(angle) * radius, y: center.y + Math.sin(angle) * radius };
|
||||
});
|
||||
}
|
||||
|
||||
function lineRegion(start: CanvasPoint, end: CanvasPoint, halfWidth = 4): CanvasPoint[] {
|
||||
const dx = end.x - start.x;
|
||||
const dy = end.y - start.y;
|
||||
const length = Math.hypot(dx, dy) || 1;
|
||||
const nx = (-dy / length) * halfWidth;
|
||||
const ny = (dx / length) * halfWidth;
|
||||
return [
|
||||
{ x: start.x + nx, y: start.y + ny },
|
||||
{ x: end.x + nx, y: end.y + ny },
|
||||
{ x: end.x - nx, y: end.y - ny },
|
||||
{ x: start.x - nx, y: start.y - ny },
|
||||
];
|
||||
function paintStrokeToGeometry(strokePoints: CanvasPoint[], radius: number): MultiPolygon | null {
|
||||
const geometries = strokePoints
|
||||
.map((point) => polygonsToMultiPolygon([circleStampPoints(point, radius)]))
|
||||
.filter((geometry): geometry is MultiPolygon => Boolean(geometry));
|
||||
if (geometries.length === 0) return null;
|
||||
const [firstGeometry, ...restGeometries] = geometries;
|
||||
return restGeometries.length === 0
|
||||
? firstGeometry
|
||||
: polygonClipping.union(firstGeometry, ...restGeometries);
|
||||
}
|
||||
|
||||
function geometriesOverlap(first: MultiPolygon, second: MultiPolygon): boolean {
|
||||
return polygonClipping.intersection(first, second).length > 0;
|
||||
}
|
||||
|
||||
export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnotations }: CanvasAreaProps) {
|
||||
@@ -305,6 +386,7 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
|
||||
const [samCandidateMaskId, setSamCandidateMaskId] = useState<string | null>(null);
|
||||
const [manualStart, setManualStart] = useState<CanvasPoint | null>(null);
|
||||
const [manualCurrent, setManualCurrent] = useState<CanvasPoint | null>(null);
|
||||
const [paintStrokePoints, setPaintStrokePointsState] = useState<CanvasPoint[]>([]);
|
||||
const [polygonPoints, setPolygonPoints] = useState<CanvasPoint[]>([]);
|
||||
const [selectedMaskId, setSelectedMaskId] = useState<string | null>(() => useStore.getState().selectedMaskIds[0] || null);
|
||||
const [selectedMaskIds, setSelectedMaskIds] = useState<string[]>(() => useStore.getState().selectedMaskIds);
|
||||
@@ -315,6 +397,9 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
|
||||
const [inferenceMessage, setInferenceMessage] = useState('');
|
||||
const [isToolHintVisible, setIsToolHintVisible] = useState(false);
|
||||
const lastAutoFitKeyRef = useRef('');
|
||||
const paintStrokeRef = useRef<CanvasPoint[]>([]);
|
||||
const paintToolRef = useRef<string | null>(null);
|
||||
const lastPaintPointRef = useRef<CanvasPoint | null>(null);
|
||||
|
||||
const masks = useStore((state) => state.masks);
|
||||
const addMask = useStore((state) => state.addMask);
|
||||
@@ -323,6 +408,8 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
|
||||
const setMasks = useStore((state) => state.setMasks);
|
||||
const setGlobalSelectedMaskIds = useStore((state) => state.setSelectedMaskIds);
|
||||
const maskPreviewOpacity = useStore((state) => state.maskPreviewOpacity);
|
||||
const brushSize = useStore((state) => state.brushSize);
|
||||
const eraserSize = useStore((state) => state.eraserSize);
|
||||
const storeActiveTool = useStore((state) => state.activeTool);
|
||||
const aiModel = useStore((state) => state.aiModel);
|
||||
const activeTemplateId = useStore((state) => state.activeTemplateId);
|
||||
@@ -333,6 +420,16 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
|
||||
// Load the actual frame image
|
||||
const [image] = useImage(frame?.url || '');
|
||||
const frameMasks = masks.filter((mask) => mask.frameId === frame?.id);
|
||||
const displayFrameMasks = React.useMemo(() => {
|
||||
if (selectedMaskIds.length > 0) return frameMasks;
|
||||
return frameMasks
|
||||
.map((mask, index) => ({ mask, index }))
|
||||
.sort((a, b) => {
|
||||
const priorityDiff = maskLayerPriority(a.mask) - maskLayerPriority(b.mask);
|
||||
return priorityDiff === 0 ? a.index - b.index : priorityDiff;
|
||||
})
|
||||
.map((item) => item.mask);
|
||||
}, [frameMasks, selectedMaskIds.length]);
|
||||
const selectedMask = React.useMemo(
|
||||
() => frameMasks.find((mask) => mask.id === selectedMaskId) || null,
|
||||
[frameMasks, selectedMaskId],
|
||||
@@ -351,7 +448,14 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
|
||||
const draftMaskCount = frameMasks.filter((mask) => !mask.annotationId).length;
|
||||
const dirtyMaskCount = frameMasks.filter((mask) => mask.saveStatus === 'dirty').length;
|
||||
const isBooleanTool = BOOLEAN_TOOLS.has(effectiveTool);
|
||||
const isPaintTool = PAINT_TOOLS.has(effectiveTool);
|
||||
const isPolygonEditTool = effectiveTool === 'move' || effectiveTool === EDIT_POLYGON_TOOL;
|
||||
const activePaintSize = effectiveTool === ERASER_TOOL ? eraserSize : brushSize;
|
||||
const activePaintRadius = Math.max(2, activePaintSize / 2);
|
||||
const setPaintStrokePoints = useCallback((nextPoints: CanvasPoint[]) => {
|
||||
paintStrokeRef.current = nextPoints;
|
||||
setPaintStrokePointsState(nextPoints);
|
||||
}, []);
|
||||
const currentLayerLabel = selectedMask
|
||||
? `${selectedMask.className || selectedMask.label}${selectedMask.annotationId ? ` #${selectedMask.annotationId}` : ' (未保存)'}`
|
||||
: '未选择';
|
||||
@@ -381,11 +485,21 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
|
||||
if (effectiveTool === 'create_circle') {
|
||||
return { title: '创建圆形', body: '按住并拖拽确定外接范围,松开鼠标后生成椭圆 mask。' };
|
||||
}
|
||||
if (effectiveTool === 'create_line') {
|
||||
return { title: '创建线段', body: '按住并拖拽画出线段,松开后生成有宽度的线状 mask。' };
|
||||
if (effectiveTool === BRUSH_TOOL) {
|
||||
return {
|
||||
title: '画笔',
|
||||
body: activeClass
|
||||
? '按住并拖动画出连续区域;若与当前选中 mask 连通,会自动合并到该 mask。'
|
||||
: '先在右侧语义分类树选择类别,然后按住并拖动画出连续区域。',
|
||||
};
|
||||
}
|
||||
if (effectiveTool === POINT_TOOL) {
|
||||
return { title: '创建点区域', body: '点击画布创建一个小型点区域;也可以在已有 mask 上继续落点。' };
|
||||
if (effectiveTool === ERASER_TOOL) {
|
||||
return {
|
||||
title: '橡皮擦',
|
||||
body: selectedMask
|
||||
? '按住并拖动,从当前选中 mask 中扣除经过的区域。'
|
||||
: '先选择一个 mask,然后按住并拖动擦除区域。',
|
||||
};
|
||||
}
|
||||
if (effectiveTool === 'box_select') {
|
||||
return {
|
||||
@@ -426,7 +540,7 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
|
||||
};
|
||||
}
|
||||
return null;
|
||||
}, [booleanSelectedMasks.length, effectiveTool, frame, polygonPoints.length, samPromptBox, selectedMask]);
|
||||
}, [activeClass, booleanSelectedMasks.length, effectiveTool, frame, polygonPoints.length, samPromptBox, selectedMask]);
|
||||
|
||||
useEffect(() => {
|
||||
if (!toolHint) {
|
||||
@@ -479,14 +593,17 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
|
||||
useEffect(() => {
|
||||
setManualStart(null);
|
||||
setManualCurrent(null);
|
||||
setPaintStrokePoints([]);
|
||||
paintToolRef.current = null;
|
||||
lastPaintPointRef.current = null;
|
||||
setPolygonPoints([]);
|
||||
setSelectedVertexIndex(null);
|
||||
if (!isPolygonEditTool && !isBooleanTool) {
|
||||
if (!isPolygonEditTool && !isBooleanTool && !isPaintTool) {
|
||||
setSelectedMaskId(null);
|
||||
setSelectedMaskIds([]);
|
||||
setSelectedPolygonIndex(0);
|
||||
}
|
||||
}, [effectiveTool, isBooleanTool, isPolygonEditTool]);
|
||||
}, [effectiveTool, isBooleanTool, isPaintTool, isPolygonEditTool, setPaintStrokePoints]);
|
||||
|
||||
useEffect(() => {
|
||||
if (previousFrameIdRef.current === frame?.id) return;
|
||||
@@ -617,18 +734,13 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
|
||||
classId: activeClass?.id,
|
||||
className: activeClass?.name,
|
||||
classZIndex: activeClass?.zIndex,
|
||||
classMaskId: activeClass?.maskId,
|
||||
saveStatus: 'draft',
|
||||
saved: false,
|
||||
pathData: polygonPath(polygon),
|
||||
label,
|
||||
color,
|
||||
segmentation: polygonSegmentation(polygon),
|
||||
points: shape === '点区域'
|
||||
? [[
|
||||
polygon.reduce((sum, point) => sum + point.x, 0) / polygon.length,
|
||||
polygon.reduce((sum, point) => sum + point.y, 0) / polygon.length,
|
||||
]]
|
||||
: undefined,
|
||||
bbox: polygonBbox(polygon),
|
||||
area,
|
||||
metadata: { source: 'manual', shape },
|
||||
@@ -636,6 +748,38 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
|
||||
addMask(mask);
|
||||
}, [activeClass, activeTemplateId, addMask, frame?.id]);
|
||||
|
||||
const createManualMaskFromGeometry = useCallback((shape: string, geometry: MultiPolygon): Mask | null => {
|
||||
if (!frame?.id || !activeClass) return null;
|
||||
const segmentation = multiPolygonToSegmentation(geometry);
|
||||
if (segmentation.length === 0) return null;
|
||||
const area = multiPolygonArea(geometry);
|
||||
if (area <= 1) return null;
|
||||
const mask: Mask = {
|
||||
id: `manual-${frame.id}-${shape}-${Date.now()}`,
|
||||
frameId: frame.id,
|
||||
templateId: activeTemplateId || undefined,
|
||||
classId: activeClass.id,
|
||||
className: activeClass.name,
|
||||
classZIndex: activeClass.zIndex,
|
||||
classMaskId: activeClass.maskId,
|
||||
saveStatus: 'draft',
|
||||
saved: false,
|
||||
pathData: segmentationPath(segmentation),
|
||||
label: activeClass.name,
|
||||
color: activeClass.color,
|
||||
segmentation,
|
||||
bbox: segmentationBbox(segmentation),
|
||||
area,
|
||||
metadata: {
|
||||
source: 'manual',
|
||||
shape,
|
||||
...(multiPolygonHasHoles(geometry) ? { hasHoles: true } : {}),
|
||||
},
|
||||
};
|
||||
addMask(mask);
|
||||
return mask;
|
||||
}, [activeClass, activeTemplateId, addMask, frame?.id]);
|
||||
|
||||
const finishPolygon = useCallback(() => {
|
||||
if (polygonPoints.length < 3) return;
|
||||
createManualMask('多边形', polygonPoints);
|
||||
@@ -665,6 +809,20 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
|
||||
setManualCurrent({ x: pos.x, y: pos.y });
|
||||
}
|
||||
}
|
||||
|
||||
if (paintToolRef.current && PAINT_TOOLS.has(effectiveTool)) {
|
||||
const pos = stagePoint(e);
|
||||
const previous = lastPaintPointRef.current;
|
||||
if (!pos || !previous) return;
|
||||
const radius = Math.max(2, (paintToolRef.current === ERASER_TOOL ? eraserSize : brushSize) / 2);
|
||||
const minDistance = Math.max(3, radius * 0.55);
|
||||
if (pointDistance(previous, pos) < minDistance) return;
|
||||
const currentStroke = paintStrokeRef.current;
|
||||
if (currentStroke.length >= MAX_PAINT_STROKE_POINTS) return;
|
||||
const nextStroke = extendStrokePoints(currentStroke, pos, minDistance);
|
||||
lastPaintPointRef.current = nextStroke[nextStroke.length - 1] || pos;
|
||||
setPaintStrokePoints(nextStroke);
|
||||
}
|
||||
};
|
||||
|
||||
const runInference = useCallback(async (
|
||||
@@ -721,6 +879,7 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
|
||||
classId: activeClass?.id || existingCandidate?.classId,
|
||||
className: activeClass?.name || existingCandidate?.className,
|
||||
classZIndex: activeClass?.zIndex ?? existingCandidate?.classZIndex,
|
||||
classMaskId: activeClass?.maskId ?? existingCandidate?.classMaskId,
|
||||
saveStatus: existingCandidate?.annotationId ? 'dirty' as const : 'draft' as const,
|
||||
saved: false,
|
||||
pathData: m.pathData,
|
||||
@@ -768,14 +927,19 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
|
||||
|
||||
const handleApplyActiveClass = () => {
|
||||
if (!frame?.id || !activeClass) return;
|
||||
const seedIds = selectedMaskIds.length > 0
|
||||
? selectedMaskIds
|
||||
: frameMasks.map((mask) => mask.id);
|
||||
const targetIds = findPropagationChainMaskIds(seedIds, masks);
|
||||
setMasks(masks.map((mask) => {
|
||||
if (mask.frameId !== frame.id) return mask;
|
||||
if (!targetIds.has(mask.id)) return mask;
|
||||
return {
|
||||
...mask,
|
||||
templateId: activeTemplateId || mask.templateId,
|
||||
classId: activeClass.id,
|
||||
className: activeClass.name,
|
||||
classZIndex: activeClass.zIndex,
|
||||
classMaskId: activeClass.maskId,
|
||||
label: activeClass.name,
|
||||
color: activeClass.color,
|
||||
saveStatus: mask.annotationId ? 'dirty' : 'draft',
|
||||
@@ -815,7 +979,102 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
|
||||
setSelectedVertexIndex(null);
|
||||
}, [masks, onDeleteMaskAnnotations, samCandidateMaskId, setMasks]);
|
||||
|
||||
const applyPaintStroke = useCallback((tool: string | null, strokePoints: CanvasPoint[]) => {
|
||||
if (!frame?.id || strokePoints.length === 0) return;
|
||||
const radius = Math.max(2, (tool === ERASER_TOOL ? eraserSize : brushSize) / 2);
|
||||
const strokeGeometry = paintStrokeToGeometry(strokePoints, radius);
|
||||
if (!strokeGeometry) return;
|
||||
|
||||
if (tool === BRUSH_TOOL) {
|
||||
if (!activeClass) {
|
||||
setInferenceMessage('请先在右侧语义分类树选择类别,再使用画笔。');
|
||||
return;
|
||||
}
|
||||
|
||||
const targetGeometry = selectedMask ? maskToMultiPolygon(selectedMask) : null;
|
||||
const shouldMerge = Boolean(targetGeometry && geometriesOverlap(targetGeometry, strokeGeometry));
|
||||
if (selectedMask && targetGeometry && shouldMerge) {
|
||||
const resultGeometry = polygonClipping.union(targetGeometry, strokeGeometry);
|
||||
const resultSegmentation = multiPolygonToSegmentation(resultGeometry);
|
||||
if (resultSegmentation.length === 0) return;
|
||||
const nextMask = {
|
||||
...maskWithSegmentation(selectedMask, resultSegmentation, {
|
||||
area: multiPolygonArea(resultGeometry),
|
||||
hasHoles: multiPolygonHasHoles(resultGeometry),
|
||||
}),
|
||||
templateId: activeTemplateId || selectedMask.templateId,
|
||||
classId: activeClass.id,
|
||||
className: activeClass.name,
|
||||
classZIndex: activeClass.zIndex,
|
||||
classMaskId: activeClass.maskId,
|
||||
label: activeClass.name,
|
||||
color: activeClass.color,
|
||||
};
|
||||
setMasks(masks.map((mask) => (mask.id === selectedMask.id ? nextMask : mask)));
|
||||
setSelectedMaskId(selectedMask.id);
|
||||
setSelectedMaskIds([selectedMask.id]);
|
||||
setSelectedVertexIndex(null);
|
||||
return;
|
||||
}
|
||||
|
||||
const nextMask = createManualMaskFromGeometry('画笔', strokeGeometry);
|
||||
if (nextMask) {
|
||||
setSelectedMaskId(nextMask.id);
|
||||
setSelectedMaskIds([nextMask.id]);
|
||||
setSelectedPolygonIndex(0);
|
||||
setSelectedVertexIndex(null);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
if (tool === ERASER_TOOL) {
|
||||
if (!selectedMask) {
|
||||
setInferenceMessage('请先选择一个 mask,再使用橡皮擦。');
|
||||
return;
|
||||
}
|
||||
const targetGeometry = maskToMultiPolygon(selectedMask);
|
||||
if (!targetGeometry) return;
|
||||
const resultGeometry = polygonClipping.difference(targetGeometry, strokeGeometry);
|
||||
const resultSegmentation = multiPolygonToSegmentation(resultGeometry);
|
||||
if (resultSegmentation.length === 0) {
|
||||
deleteMasksById([selectedMask.id]);
|
||||
return;
|
||||
}
|
||||
const nextMask = maskWithSegmentation(selectedMask, resultSegmentation, {
|
||||
area: multiPolygonArea(resultGeometry),
|
||||
hasHoles: multiPolygonHasHoles(resultGeometry),
|
||||
});
|
||||
setMasks(masks.map((mask) => (mask.id === selectedMask.id ? nextMask : mask)));
|
||||
setSelectedMaskId(selectedMask.id);
|
||||
setSelectedMaskIds([selectedMask.id]);
|
||||
setSelectedVertexIndex(null);
|
||||
}
|
||||
}, [
|
||||
activeClass,
|
||||
activeTemplateId,
|
||||
brushSize,
|
||||
createManualMaskFromGeometry,
|
||||
deleteMasksById,
|
||||
eraserSize,
|
||||
frame?.id,
|
||||
masks,
|
||||
selectedMask,
|
||||
setMasks,
|
||||
]);
|
||||
|
||||
const handleStageMouseDown = (e: any) => {
|
||||
if (PAINT_TOOLS.has(effectiveTool)) {
|
||||
const canStart = effectiveTool === BRUSH_TOOL ? Boolean(activeClass) : Boolean(selectedMask);
|
||||
if (!canStart) return;
|
||||
const pos = stagePoint(e);
|
||||
if (pos) {
|
||||
paintToolRef.current = effectiveTool;
|
||||
lastPaintPointRef.current = pos;
|
||||
setPaintStrokePoints([pos]);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
if (DRAG_MANUAL_TOOLS.has(effectiveTool)) {
|
||||
const pos = stagePoint(e);
|
||||
if (pos) {
|
||||
@@ -836,11 +1095,28 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
|
||||
};
|
||||
|
||||
const handleStageMouseUp = (e: any) => {
|
||||
if (paintToolRef.current && PAINT_TOOLS.has(effectiveTool)) {
|
||||
const finalPoint = stagePoint(e);
|
||||
const currentStroke = paintStrokeRef.current;
|
||||
const spacing = Math.max(3, activePaintRadius * 0.55);
|
||||
const nextStroke = finalPoint
|
||||
&& currentStroke.length > 0
|
||||
&& pointDistance(currentStroke[currentStroke.length - 1], finalPoint) >= spacing
|
||||
&& currentStroke.length < MAX_PAINT_STROKE_POINTS
|
||||
? extendStrokePoints(currentStroke, finalPoint, spacing)
|
||||
: currentStroke;
|
||||
const tool = paintToolRef.current;
|
||||
setPaintStrokePoints([]);
|
||||
paintToolRef.current = null;
|
||||
lastPaintPointRef.current = null;
|
||||
applyPaintStroke(tool, nextStroke);
|
||||
return;
|
||||
}
|
||||
|
||||
if (DRAG_MANUAL_TOOLS.has(effectiveTool) && manualStart) {
|
||||
const end = stagePoint(e) || manualCurrent || manualStart;
|
||||
const width = Math.abs(end.x - manualStart.x);
|
||||
const height = Math.abs(end.y - manualStart.y);
|
||||
const distance = Math.hypot(width, height);
|
||||
|
||||
if (effectiveTool === 'create_rectangle' && width > 4 && height > 4) {
|
||||
createManualMask('矩形', rectanglePoints(manualStart, end));
|
||||
@@ -848,9 +1124,6 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
|
||||
if (effectiveTool === 'create_circle' && width > 4 && height > 4) {
|
||||
createManualMask('圆形', circlePoints(manualStart, end));
|
||||
}
|
||||
if (effectiveTool === 'create_line' && distance > 4) {
|
||||
createManualMask('线段', lineRegion(manualStart, end));
|
||||
}
|
||||
|
||||
setManualStart(null);
|
||||
setManualCurrent(null);
|
||||
@@ -880,14 +1153,7 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
|
||||
if (isPolygonEditTool) return;
|
||||
if (effectiveTool === 'box_select') return; // handled by mouseup
|
||||
if (DRAG_MANUAL_TOOLS.has(effectiveTool)) return;
|
||||
|
||||
if (effectiveTool === POINT_TOOL) {
|
||||
const pos = stagePoint(e);
|
||||
if (pos) {
|
||||
createManualMask('点区域', pointRegion(pos));
|
||||
}
|
||||
return;
|
||||
}
|
||||
if (PAINT_TOOLS.has(effectiveTool)) return;
|
||||
|
||||
if (effectiveTool === POLYGON_TOOL) {
|
||||
const pos = stagePoint(e);
|
||||
@@ -955,20 +1221,7 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
|
||||
segmentation: number[][],
|
||||
options: { area?: number; hasHoles?: boolean } = {},
|
||||
): Mask => {
|
||||
const bbox = segmentationBbox(segmentation);
|
||||
const metadata = { ...(mask.metadata || {}) };
|
||||
if (options.hasHoles === true) metadata.hasHoles = true;
|
||||
if (options.hasHoles === false) delete metadata.hasHoles;
|
||||
return {
|
||||
...mask,
|
||||
pathData: segmentationPath(segmentation),
|
||||
segmentation,
|
||||
bbox,
|
||||
area: options.area ?? segmentationArea(segmentation),
|
||||
metadata,
|
||||
saveStatus: mask.annotationId ? 'dirty' : 'draft',
|
||||
saved: mask.annotationId ? false : mask.saved,
|
||||
};
|
||||
return maskWithSegmentation(mask, segmentation, options);
|
||||
}, []);
|
||||
|
||||
useEffect(() => {
|
||||
@@ -1017,7 +1270,6 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
|
||||
if (manualStart && manualCurrent) {
|
||||
if (effectiveTool === 'create_rectangle') return polygonPath(rectanglePoints(manualStart, manualCurrent));
|
||||
if (effectiveTool === 'create_circle') return polygonPath(circlePoints(manualStart, manualCurrent));
|
||||
if (effectiveTool === 'create_line') return polygonPath(lineRegion(manualStart, manualCurrent));
|
||||
}
|
||||
if (effectiveTool === POLYGON_TOOL && polygonPoints.length > 0) {
|
||||
const previewPoints = [...polygonPoints, cursorPos];
|
||||
@@ -1217,7 +1469,7 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
|
||||
)}
|
||||
|
||||
{/* AI Returned Masks */}
|
||||
{frameMasks.map((mask) => {
|
||||
{displayFrameMasks.map((mask) => {
|
||||
const selectedIndex = selectedMaskIds.indexOf(mask.id);
|
||||
const isMaskSelected = selectedIndex >= 0;
|
||||
const isBooleanPrimary = isBooleanTool && selectedIndex === 0;
|
||||
@@ -1282,6 +1534,34 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
|
||||
/>
|
||||
)}
|
||||
|
||||
{paintStrokePoints.length > 0 && (
|
||||
<Group opacity={effectiveTool === ERASER_TOOL ? 0.28 : 0.22}>
|
||||
{paintStrokePoints.map((point, index) => (
|
||||
<Circle
|
||||
key={`paint-stroke-${index}`}
|
||||
x={point.x}
|
||||
y={point.y}
|
||||
radius={activePaintRadius}
|
||||
fill={effectiveTool === ERASER_TOOL ? '#ef4444' : activeClass?.color || '#22d3ee'}
|
||||
stroke={effectiveTool === ERASER_TOOL ? '#fecaca' : '#ffffff'}
|
||||
strokeWidth={1 / scale}
|
||||
/>
|
||||
))}
|
||||
</Group>
|
||||
)}
|
||||
|
||||
{isPaintTool && (effectiveTool === BRUSH_TOOL ? activeClass : selectedMask) && paintStrokePoints.length === 0 && (
|
||||
<Circle
|
||||
x={cursorPos.x}
|
||||
y={cursorPos.y}
|
||||
radius={activePaintRadius}
|
||||
fill="rgba(255,255,255,0.02)"
|
||||
stroke={effectiveTool === ERASER_TOOL ? '#f87171' : activeClass?.color || '#22d3ee'}
|
||||
strokeWidth={1.5 / scale}
|
||||
dash={[4 / scale, 4 / scale]}
|
||||
/>
|
||||
)}
|
||||
|
||||
{polygonPoints.map((point, index) => (
|
||||
<Circle
|
||||
key={`poly-point-${index}`}
|
||||
|
||||
@@ -188,6 +188,7 @@ export function Dashboard() {
|
||||
|
||||
return () => {
|
||||
mounted = false;
|
||||
clearTimeout(timer);
|
||||
unsubscribe();
|
||||
unsubscribeStatus();
|
||||
clearInterval(checkConnection);
|
||||
|
||||
@@ -78,7 +78,8 @@ describe('FrameTimeline', () => {
|
||||
|
||||
expect(screen.getByLabelText('视频处理进度条')).toBeInTheDocument();
|
||||
expect(screen.getByText('人工/AI 1 帧 · 自动传播 1 帧')).toBeInTheDocument();
|
||||
expect(screen.queryByTestId('current-frame-line')).not.toBeInTheDocument();
|
||||
expect(screen.getByTestId('current-frame-line')).toHaveStyle({ left: '50%' });
|
||||
expect(screen.getByTestId('current-frame-line').className).toContain('bg-white');
|
||||
expect(screen.getAllByTestId('propagated-frame-segment')).toHaveLength(1);
|
||||
expect(screen.getByTestId('propagated-frame-segment').className).toContain('bg-blue-500');
|
||||
expect(screen.getAllByTestId('annotated-frame-marker')).toHaveLength(1);
|
||||
@@ -86,31 +87,50 @@ describe('FrameTimeline', () => {
|
||||
expect(screen.queryByLabelText('跳转到已编辑帧 3')).not.toBeInTheDocument();
|
||||
});
|
||||
|
||||
it('renders recent propagation history segments with distinct gradient colors', () => {
|
||||
it('renders propagation history with newest bright and old records capped to one blue threshold', () => {
|
||||
useStore.setState({
|
||||
frames: [
|
||||
{ id: 'f1', projectId: 'p1', index: 0, url: '/1.jpg', width: 640, height: 360 },
|
||||
{ id: 'f2', projectId: 'p1', index: 1, url: '/2.jpg', width: 640, height: 360 },
|
||||
{ id: 'f3', projectId: 'p1', index: 2, url: '/3.jpg', width: 640, height: 360 },
|
||||
{ id: 'f4', projectId: 'p1', index: 3, url: '/4.jpg', width: 640, height: 360 },
|
||||
{ id: 'f5', projectId: 'p1', index: 4, url: '/5.jpg', width: 640, height: 360 },
|
||||
{ id: 'f6', projectId: 'p1', index: 5, url: '/6.jpg', width: 640, height: 360 },
|
||||
{ id: 'f7', projectId: 'p1', index: 6, url: '/7.jpg', width: 640, height: 360 },
|
||||
],
|
||||
});
|
||||
|
||||
render(
|
||||
<FrameTimeline
|
||||
propagationHistory={[
|
||||
{ id: 'history-1', startFrame: 1, endFrame: 2, colorIndex: 0, label: '第一次传播' },
|
||||
{ id: 'history-2', startFrame: 3, endFrame: 4, colorIndex: 1, label: '第二次传播' },
|
||||
{ id: 'history-1', startFrame: 1, endFrame: 1, colorIndex: 0, label: '第一次传播' },
|
||||
{ id: 'history-2', startFrame: 2, endFrame: 2, colorIndex: 1, label: '第二次传播' },
|
||||
{ id: 'history-3', startFrame: 3, endFrame: 3, colorIndex: 2, label: '第三次传播' },
|
||||
{ id: 'history-4', startFrame: 4, endFrame: 4, colorIndex: 3, label: '第四次传播' },
|
||||
{ id: 'history-5', startFrame: 5, endFrame: 5, colorIndex: 4, label: '第五次传播' },
|
||||
{ id: 'history-6', startFrame: 6, endFrame: 6, colorIndex: 5, label: '第六次传播' },
|
||||
{ id: 'history-7', startFrame: 7, endFrame: 7, colorIndex: 6, label: '第七次传播' },
|
||||
]}
|
||||
/>,
|
||||
);
|
||||
|
||||
const segments = screen.getAllByTestId('propagation-history-segment');
|
||||
expect(segments).toHaveLength(2);
|
||||
expect(segments).toHaveLength(7);
|
||||
expect(segments[0]).toHaveAttribute('title', '第一次传播');
|
||||
expect(segments[0]).toHaveStyle({ left: '0%', width: '50%' });
|
||||
expect(segments[0].getAttribute('style')).toContain('linear-gradient');
|
||||
expect(segments[1].getAttribute('style')).toContain('124, 58, 237');
|
||||
expect(segments[0]).toHaveStyle({ left: '0%' });
|
||||
expect(segments[0]).toHaveAttribute('data-recency-level', '4');
|
||||
expect(segments[1]).toHaveAttribute('data-recency-level', '4');
|
||||
expect(segments[2]).toHaveAttribute('data-recency-level', '4');
|
||||
expect(segments[3]).toHaveAttribute('data-recency-level', '3');
|
||||
expect(segments[4]).toHaveAttribute('data-recency-level', '2');
|
||||
expect(segments[5]).toHaveAttribute('data-recency-level', '1');
|
||||
expect(segments[6]).toHaveAttribute('data-recency-level', '0');
|
||||
const oldestStyle = segments[0].getAttribute('style') || '';
|
||||
const newestStyle = segments[6].getAttribute('style') || '';
|
||||
expect(oldestStyle).not.toContain('linear-gradient');
|
||||
expect(newestStyle).not.toContain('linear-gradient');
|
||||
expect(segments[0].style.backgroundColor).toBe(segments[1].style.backgroundColor);
|
||||
expect(segments[6].style.backgroundColor).not.toBe(segments[0].style.backgroundColor);
|
||||
});
|
||||
|
||||
it('jumps from the processing progress bar and frame status markers', () => {
|
||||
@@ -180,6 +200,7 @@ describe('FrameTimeline', () => {
|
||||
{ id: 'f2', projectId: 'p1', index: 1, url: '/2.jpg', width: 640, height: 360 },
|
||||
{ id: 'f3', projectId: 'p1', index: 2, url: '/3.jpg', width: 640, height: 360 },
|
||||
{ id: 'f4', projectId: 'p1', index: 3, url: '/4.jpg', width: 640, height: 360 },
|
||||
{ id: 'f5', projectId: 'p1', index: 4, url: '/5.jpg', width: 640, height: 360 },
|
||||
],
|
||||
masks: [
|
||||
{ id: 'm1', frameId: 'f2', pathData: 'M 0 0 Z', label: 'Draft', color: '#ef4444' },
|
||||
@@ -200,6 +221,14 @@ describe('FrameTimeline', () => {
|
||||
color: '#3b82f6',
|
||||
metadata: { source: 'sam2.1_hiera_tiny_propagation' },
|
||||
},
|
||||
{
|
||||
id: 'm5',
|
||||
frameId: 'f5',
|
||||
pathData: 'M 3 3 Z',
|
||||
label: 'Tracked after smoothing',
|
||||
color: '#3b82f6',
|
||||
metadata: { source_annotation_id: 7, source_mask_id: 'annotation-7' },
|
||||
},
|
||||
],
|
||||
});
|
||||
|
||||
@@ -211,6 +240,8 @@ describe('FrameTimeline', () => {
|
||||
const manuallyAdjustedPropagatedTile = screen.getByAltText('frame-3').closest('div');
|
||||
expect(manuallyAdjustedPropagatedTile?.className).toContain('border-red-500');
|
||||
expect(manuallyAdjustedPropagatedTile?.className).toContain('inset_0_0_0_2px_rgba(59,130,246,0.85)');
|
||||
expect(screen.getByAltText('frame-4').closest('div')?.className).toContain('border-blue-500');
|
||||
expect(screen.getByAltText('frame-4').closest('div')?.className).not.toContain('border-red-500');
|
||||
});
|
||||
|
||||
it('keeps the current frame blue border while showing an inner red ring for annotated frames', () => {
|
||||
@@ -278,6 +309,12 @@ describe('FrameTimeline', () => {
|
||||
|
||||
expect(onPropagationRangeChange).toHaveBeenLastCalledWith(2, 4);
|
||||
expect(screen.getAllByTestId('propagation-range-overlay')).toHaveLength(2);
|
||||
const boundaryLines = screen.getAllByTestId('range-boundary-line');
|
||||
expect(boundaryLines).toHaveLength(2);
|
||||
expect(boundaryLines[0]).toHaveStyle({ left: '25%' });
|
||||
expect(boundaryLines[0].className).toContain('bg-fuchsia-400');
|
||||
expect(boundaryLines[1]).toHaveStyle({ left: '75%' });
|
||||
expect(boundaryLines[1].className).toContain('bg-lime-300');
|
||||
});
|
||||
|
||||
it('changes frames with left and right arrow keys without leaving bounds', () => {
|
||||
|
||||
@@ -49,7 +49,11 @@ export function FrameTimeline({
|
||||
const totalSeconds = totalFrames > 0 ? Math.max(totalFrames - 1, 0) / timeBaseFps : 0;
|
||||
const isPropagatedMask = (mask: (typeof masks)[number]) => {
|
||||
const source = typeof mask.metadata?.source === 'string' ? mask.metadata.source : '';
|
||||
return source.includes('_propagation') || mask.metadata?.propagated_from_frame_id !== undefined;
|
||||
return source.includes('_propagation')
|
||||
|| mask.metadata?.propagated_from_frame_id !== undefined
|
||||
|| mask.metadata?.source_annotation_id !== undefined
|
||||
|| mask.metadata?.source_mask_id !== undefined
|
||||
|| mask.metadata?.propagation_seed_key !== undefined;
|
||||
};
|
||||
const propagatedFrameMarkers = useMemo(() => {
|
||||
const frameIds = new Set(frames.map((frame) => frame.id));
|
||||
@@ -105,18 +109,29 @@ export function FrameTimeline({
|
||||
const rangeWidth = visibleSelectedRange && totalFrames > 0
|
||||
? ((visibleSelectedRange.endFrame - visibleSelectedRange.startFrame + 1) / totalFrames) * 100
|
||||
: 0;
|
||||
const propagationHistoryColors = [
|
||||
{ dark: 'rgba(8, 145, 178, 0.68)', light: 'rgba(103, 232, 249, 0.9)', glow: 'rgba(34, 211, 238, 0.38)' },
|
||||
{ dark: 'rgba(124, 58, 237, 0.66)', light: 'rgba(196, 181, 253, 0.9)', glow: 'rgba(167, 139, 250, 0.34)' },
|
||||
{ dark: 'rgba(5, 150, 105, 0.66)', light: 'rgba(110, 231, 183, 0.9)', glow: 'rgba(52, 211, 153, 0.34)' },
|
||||
{ dark: 'rgba(217, 119, 6, 0.66)', light: 'rgba(253, 186, 116, 0.9)', glow: 'rgba(251, 146, 60, 0.34)' },
|
||||
{ dark: 'rgba(219, 39, 119, 0.66)', light: 'rgba(251, 113, 133, 0.9)', glow: 'rgba(244, 114, 182, 0.34)' },
|
||||
];
|
||||
const frameLineLeft = (frame: number) => {
|
||||
if (totalFrames <= 1) return 0;
|
||||
return ((clampFrame(frame) - 1) / (totalFrames - 1)) * 100;
|
||||
};
|
||||
const currentFrameLineLeft = totalFrames > 0 ? frameLineLeft(currentFrame) : 0;
|
||||
const rangeStartLineLeft = visibleSelectedRange ? frameLineLeft(visibleSelectedRange.startFrame) : 0;
|
||||
const rangeEndLineLeft = visibleSelectedRange ? frameLineLeft(visibleSelectedRange.endFrame) : 0;
|
||||
const propagationHistoryColor = (ageFromNewest: number) => {
|
||||
const step = Math.min(Math.max(ageFromNewest, 0), 4);
|
||||
const lightness = 58 - step * 7;
|
||||
const alpha = 0.88 - step * 0.085;
|
||||
return {
|
||||
fill: `hsla(212, 88%, ${lightness}%, ${Math.max(alpha, 0.52)})`,
|
||||
glow: `hsla(212, 88%, ${Math.min(lightness + 10, 76)}%, ${0.38 - step * 0.045})`,
|
||||
border: `hsla(212, 90%, ${Math.min(lightness + 18, 84)}%, ${0.72 - step * 0.045})`,
|
||||
};
|
||||
};
|
||||
const visiblePropagationHistory = useMemo(() => (
|
||||
propagationHistory
|
||||
.map((segment, order) => {
|
||||
const range = normalizeRange(segment.startFrame, segment.endFrame);
|
||||
return { ...segment, ...range, order };
|
||||
const ageFromNewest = Math.min(Math.max(propagationHistory.length - 1 - order, 0), 4);
|
||||
return { ...segment, ...range, order, ageFromNewest };
|
||||
})
|
||||
.filter((segment) => totalFrames > 0 && segment.endFrame >= 1 && segment.startFrame <= totalFrames)
|
||||
), [propagationHistory, totalFrames]);
|
||||
@@ -282,6 +297,32 @@ export function FrameTimeline({
|
||||
{formatTime(currentSeconds)}
|
||||
</div>
|
||||
</div>
|
||||
{totalFrames > 0 && (
|
||||
<div
|
||||
data-testid="current-frame-line"
|
||||
aria-hidden="true"
|
||||
className="pointer-events-none absolute top-[18px] bottom-[8px] z-[60] w-[2px] -translate-x-1/2 rounded-full bg-white shadow-[0_0_10px_rgba(255,255,255,0.85)]"
|
||||
style={{ left: `${currentFrameLineLeft}%` }}
|
||||
/>
|
||||
)}
|
||||
{visibleSelectedRange && (
|
||||
<>
|
||||
<div
|
||||
data-testid="range-boundary-line"
|
||||
aria-hidden="true"
|
||||
title={`范围开始帧 ${visibleSelectedRange.startFrame}`}
|
||||
className="pointer-events-none absolute top-[16px] bottom-[7px] z-[65] w-[2px] -translate-x-1/2 rounded-full bg-fuchsia-400 shadow-[0_0_12px_rgba(244,114,182,0.9)]"
|
||||
style={{ left: `${rangeStartLineLeft}%` }}
|
||||
/>
|
||||
<div
|
||||
data-testid="range-boundary-line"
|
||||
aria-hidden="true"
|
||||
title={`范围结束帧 ${visibleSelectedRange.endFrame}`}
|
||||
className="pointer-events-none absolute top-[16px] bottom-[7px] z-[65] w-[2px] -translate-x-1/2 rounded-full bg-lime-300 shadow-[0_0_12px_rgba(190,242,100,0.9)]"
|
||||
style={{ left: `${rangeEndLineLeft}%` }}
|
||||
/>
|
||||
</>
|
||||
)}
|
||||
<div
|
||||
className={cn(
|
||||
"mt-2 h-2.5 w-full relative bg-zinc-700/80 border-y border-white/10 shadow-inner",
|
||||
@@ -321,21 +362,21 @@ export function FrameTimeline({
|
||||
);
|
||||
})}
|
||||
{visiblePropagationHistory.map((segment) => {
|
||||
const color = propagationHistoryColors[segment.colorIndex % propagationHistoryColors.length];
|
||||
const color = propagationHistoryColor(segment.ageFromNewest);
|
||||
const left = totalFrames > 0 ? ((segment.startFrame - 1) / totalFrames) * 100 : 0;
|
||||
const width = totalFrames > 0 ? ((segment.endFrame - segment.startFrame + 1) / totalFrames) * 100 : 0;
|
||||
const opacity = Math.max(0.48, 0.92 - (visiblePropagationHistory.length - 1 - segment.order) * 0.12);
|
||||
return (
|
||||
<div
|
||||
key={segment.id}
|
||||
data-testid="propagation-history-segment"
|
||||
data-recency-level={segment.ageFromNewest}
|
||||
title={segment.label || `自动传播记录:第 ${segment.startFrame}-${segment.endFrame} 帧`}
|
||||
className="pointer-events-none absolute inset-y-0 z-[15] rounded-[2px] border-x border-white/25"
|
||||
className="pointer-events-none absolute inset-y-0 z-[15] rounded-[2px] border-x"
|
||||
style={{
|
||||
left: `${left}%`,
|
||||
width: `${width}%`,
|
||||
opacity,
|
||||
background: `linear-gradient(to right, ${color.dark}, ${color.light})`,
|
||||
backgroundColor: color.fill,
|
||||
borderColor: color.border,
|
||||
boxShadow: `0 0 10px ${color.glow}`,
|
||||
}}
|
||||
/>
|
||||
|
||||
@@ -19,14 +19,19 @@ describe('Login', () => {
|
||||
});
|
||||
|
||||
it('logs in with the development credentials and stores the token', async () => {
|
||||
apiMock.login.mockResolvedValueOnce({ token: 'fake-jwt-token-for-admin' });
|
||||
apiMock.login.mockResolvedValueOnce({
|
||||
token: 'jwt-token',
|
||||
username: 'admin',
|
||||
user: { id: 1, username: 'admin', role: 'admin' },
|
||||
});
|
||||
|
||||
render(<Login />);
|
||||
fireEvent.click(screen.getByRole('button', { name: '安全登录' }));
|
||||
|
||||
await waitFor(() => expect(apiMock.login).toHaveBeenCalledWith('admin', '123456'));
|
||||
expect(useStore.getState().isAuthenticated).toBe(true);
|
||||
expect(localStorage.getItem('token')).toBe('fake-jwt-token-for-admin');
|
||||
expect(useStore.getState().currentUser?.username).toBe('admin');
|
||||
expect(localStorage.getItem('token')).toBe('jwt-token');
|
||||
});
|
||||
|
||||
it('shows backend login errors', async () => {
|
||||
@@ -39,4 +44,11 @@ describe('Login', () => {
|
||||
expect(await screen.findByText('Invalid credentials')).toBeInTheDocument();
|
||||
expect(useStore.getState().isAuthenticated).toBe(false);
|
||||
});
|
||||
|
||||
it('marks login fields with browser autocomplete hints', () => {
|
||||
render(<Login />);
|
||||
|
||||
expect(screen.getByDisplayValue('admin')).toHaveAttribute('autocomplete', 'username');
|
||||
expect(screen.getByDisplayValue('123456')).toHaveAttribute('autocomplete', 'current-password');
|
||||
});
|
||||
});
|
||||
|
||||
@@ -18,7 +18,7 @@ export function Login() {
|
||||
|
||||
try {
|
||||
const data = await loginApi(username, password);
|
||||
storeLogin(data.token);
|
||||
storeLogin(data.token, data.user);
|
||||
} catch (err: any) {
|
||||
const msg = err?.response?.data?.detail || err?.response?.data?.error || '登录失败,请检查网络或凭证';
|
||||
setError(msg);
|
||||
@@ -47,6 +47,7 @@ export function Login() {
|
||||
type="text"
|
||||
value={username}
|
||||
onChange={(e) => setUsername(e.target.value)}
|
||||
autoComplete="username"
|
||||
className="w-full bg-[#1a1a1a] border border-white/10 rounded-lg px-4 py-3 text-sm focus:outline-none focus:border-cyan-500/50 focus:ring-1 focus:ring-cyan-500/50 transition-all font-mono"
|
||||
placeholder="输入账号"
|
||||
/>
|
||||
@@ -58,6 +59,7 @@ export function Login() {
|
||||
type="password"
|
||||
value={password}
|
||||
onChange={(e) => setPassword(e.target.value)}
|
||||
autoComplete="current-password"
|
||||
className="w-full bg-[#1a1a1a] border border-white/10 rounded-lg px-4 py-3 text-sm focus:outline-none focus:border-cyan-500/50 focus:ring-1 focus:ring-cyan-500/50 transition-all font-mono"
|
||||
placeholder="输入密码"
|
||||
/>
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import { fireEvent, render, screen, waitFor, within } from '@testing-library/react';
|
||||
import { act, fireEvent, render, screen, waitFor, within } from '@testing-library/react';
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest';
|
||||
import { resetStore } from '../test/storeTestUtils';
|
||||
import { useStore } from '../store/useStore';
|
||||
@@ -67,6 +67,9 @@ describe('OntologyInspector', () => {
|
||||
expect(useStore.getState().activeTemplateId).toBe('t1');
|
||||
expect(screen.getByText('胆囊')).toBeInTheDocument();
|
||||
expect(screen.getByText('肝脏')).toBeInTheDocument();
|
||||
expect(screen.getByText('maskid:1')).toBeInTheDocument();
|
||||
expect(screen.getByText('maskid:2')).toBeInTheDocument();
|
||||
expect(screen.queryByText(/z:/)).not.toBeInTheDocument();
|
||||
});
|
||||
|
||||
it('adjusts workspace mask opacity from above the semantic tree', () => {
|
||||
@@ -151,13 +154,86 @@ describe('OntologyInspector', () => {
|
||||
classId: 'c2',
|
||||
className: '肝脏',
|
||||
classZIndex: 10,
|
||||
classMaskId: 2,
|
||||
label: '肝脏',
|
||||
color: '#00ff00',
|
||||
saveStatus: 'dirty',
|
||||
saved: false,
|
||||
}));
|
||||
expect(screen.getByText('当前选中区域:')).toBeInTheDocument();
|
||||
expect(screen.getByText('1')).toBeInTheDocument();
|
||||
expect(screen.queryByText('当前选中区域:')).not.toBeInTheDocument();
|
||||
});
|
||||
|
||||
it('applies class changes to the same propagation chain across frames', () => {
|
||||
useStore.setState({
|
||||
selectedMaskIds: ['annotation-10'],
|
||||
masks: [
|
||||
{
|
||||
id: 'annotation-10',
|
||||
annotationId: '10',
|
||||
frameId: 'frame-1',
|
||||
pathData: 'M 0 0 Z',
|
||||
label: '旧标签',
|
||||
color: '#06b6d4',
|
||||
saveStatus: 'saved',
|
||||
saved: true,
|
||||
},
|
||||
{
|
||||
id: 'annotation-11',
|
||||
annotationId: '11',
|
||||
frameId: 'frame-2',
|
||||
pathData: 'M 1 1 Z',
|
||||
label: '旧传播标签',
|
||||
color: '#06b6d4',
|
||||
metadata: {
|
||||
source_annotation_id: 10,
|
||||
source_mask_id: 'annotation-10',
|
||||
propagation_seed_key: 'annotation:10',
|
||||
},
|
||||
saveStatus: 'saved',
|
||||
saved: true,
|
||||
},
|
||||
{
|
||||
id: 'annotation-99',
|
||||
annotationId: '99',
|
||||
frameId: 'frame-3',
|
||||
pathData: 'M 2 2 Z',
|
||||
label: '无关区域',
|
||||
color: '#ffffff',
|
||||
metadata: { source_annotation_id: 99 },
|
||||
saveStatus: 'saved',
|
||||
saved: true,
|
||||
},
|
||||
],
|
||||
});
|
||||
|
||||
render(<OntologyInspector />);
|
||||
fireEvent.click(screen.getByText('肝脏'));
|
||||
|
||||
const updated = useStore.getState().masks;
|
||||
expect(updated.find((mask) => mask.id === 'annotation-10')).toEqual(expect.objectContaining({
|
||||
classId: 'c2',
|
||||
className: '肝脏',
|
||||
classMaskId: 2,
|
||||
label: '肝脏',
|
||||
color: '#00ff00',
|
||||
saveStatus: 'dirty',
|
||||
saved: false,
|
||||
}));
|
||||
expect(updated.find((mask) => mask.id === 'annotation-11')).toEqual(expect.objectContaining({
|
||||
classId: 'c2',
|
||||
className: '肝脏',
|
||||
classMaskId: 2,
|
||||
label: '肝脏',
|
||||
color: '#00ff00',
|
||||
saveStatus: 'dirty',
|
||||
saved: false,
|
||||
}));
|
||||
expect(updated.find((mask) => mask.id === 'annotation-99')).toEqual(expect.objectContaining({
|
||||
label: '无关区域',
|
||||
color: '#ffffff',
|
||||
saveStatus: 'saved',
|
||||
saved: true,
|
||||
}));
|
||||
});
|
||||
|
||||
it('persists custom classes to the active backend template', async () => {
|
||||
@@ -187,6 +263,59 @@ describe('OntologyInspector', () => {
|
||||
expect(useStore.getState().templates[0].classes).toHaveLength(3);
|
||||
});
|
||||
|
||||
it('persists dragged semantic class order as layer priority without changing maskid', async () => {
|
||||
apiMock.updateTemplate.mockResolvedValueOnce({
|
||||
id: 't1',
|
||||
name: '腹腔镜模板',
|
||||
classes: [
|
||||
{ id: 'c2', name: '肝脏', color: '#00ff00', zIndex: 20, maskId: 2, category: '器官' },
|
||||
{ id: 'c1', name: '胆囊', color: '#ff0000', zIndex: 10, maskId: 1, category: '器官' },
|
||||
],
|
||||
rules: [],
|
||||
});
|
||||
useStore.setState({
|
||||
masks: [{
|
||||
id: 'm-liver',
|
||||
annotationId: '42',
|
||||
frameId: 'frame-1',
|
||||
classId: 'c2',
|
||||
className: '肝脏',
|
||||
classZIndex: 10,
|
||||
pathData: 'M 0 0 Z',
|
||||
label: '肝脏',
|
||||
color: '#00ff00',
|
||||
saveStatus: 'saved',
|
||||
saved: true,
|
||||
}],
|
||||
});
|
||||
|
||||
render(<OntologyInspector />);
|
||||
const liverButton = screen.getByRole('button', { name: /肝脏/ });
|
||||
const gallbladderButton = screen.getByRole('button', { name: /胆囊/ });
|
||||
const dataTransfer = {
|
||||
effectAllowed: '',
|
||||
dropEffect: '',
|
||||
setData: vi.fn(),
|
||||
getData: vi.fn(() => 'c2'),
|
||||
};
|
||||
|
||||
fireEvent.dragStart(liverButton, { dataTransfer });
|
||||
fireEvent.dragOver(gallbladderButton, { dataTransfer });
|
||||
fireEvent.drop(gallbladderButton, { dataTransfer });
|
||||
|
||||
await waitFor(() => expect(apiMock.updateTemplate).toHaveBeenCalledWith('t1', expect.objectContaining({
|
||||
classes: [
|
||||
expect.objectContaining({ id: 'c2', zIndex: 20, maskId: 2 }),
|
||||
expect.objectContaining({ id: 'c1', zIndex: 10, maskId: 1 }),
|
||||
],
|
||||
})));
|
||||
expect(useStore.getState().masks[0]).toEqual(expect.objectContaining({
|
||||
classZIndex: 20,
|
||||
saveStatus: 'dirty',
|
||||
saved: false,
|
||||
}));
|
||||
});
|
||||
|
||||
it('loads selected mask properties from the backend analyzer', async () => {
|
||||
useStore.setState({
|
||||
frames: [{ id: 'frame-1', projectId: 'p1', index: 0, url: '/1.jpg', width: 100, height: 100 }],
|
||||
@@ -214,15 +343,40 @@ describe('OntologyInspector', () => {
|
||||
expect(screen.queryByText('后端模型置信度')).not.toBeInTheDocument();
|
||||
expect(screen.queryByText('0.8200')).not.toBeInTheDocument();
|
||||
expect(screen.getByText('4 节点')).toBeInTheDocument();
|
||||
fireEvent.click(screen.getByRole('button', { name: '重新提取拓扑锚点' }));
|
||||
expect(screen.queryByRole('button', { name: '重新提取拓扑锚点' })).not.toBeInTheDocument();
|
||||
expect(apiMock.analyzeMask).toHaveBeenLastCalledWith(
|
||||
expect.objectContaining({ id: 'm1' }),
|
||||
expect.objectContaining({ id: 'frame-1' }),
|
||||
{ extractSkeleton: true },
|
||||
);
|
||||
});
|
||||
|
||||
it('applies backend edge smoothing to the selected mask and marks it dirty', async () => {
|
||||
it('ignores aborted mask analysis requests without showing an error', async () => {
|
||||
const consoleError = vi.spyOn(console, 'error').mockImplementation(() => {});
|
||||
apiMock.analyzeMask.mockRejectedValueOnce({ code: 'ECONNABORTED', message: 'Request aborted' });
|
||||
useStore.setState({
|
||||
frames: [{ id: 'frame-1', projectId: 'p1', index: 0, url: '/1.jpg', width: 100, height: 100 }],
|
||||
selectedMaskIds: ['m1'],
|
||||
masks: [
|
||||
{
|
||||
id: 'm1',
|
||||
frameId: 'frame-1',
|
||||
pathData: 'M 0 0 Z',
|
||||
label: '胆囊',
|
||||
color: '#ff0000',
|
||||
segmentation: [[10, 10, 20, 10, 20, 20]],
|
||||
},
|
||||
],
|
||||
});
|
||||
|
||||
render(<OntologyInspector />);
|
||||
|
||||
await waitFor(() => expect(apiMock.analyzeMask).toHaveBeenCalled());
|
||||
await waitFor(() => expect(screen.queryByText('后端属性读取失败')).not.toBeInTheDocument());
|
||||
expect(consoleError).not.toHaveBeenCalled();
|
||||
consoleError.mockRestore();
|
||||
});
|
||||
|
||||
it('previews backend edge smoothing while moving the slider without marking the mask dirty', async () => {
|
||||
useStore.setState({
|
||||
frames: [{ id: 'frame-1', projectId: 'p1', index: 0, url: '/1.jpg', width: 100, height: 100 }],
|
||||
selectedMaskIds: ['m1'],
|
||||
@@ -244,13 +398,100 @@ describe('OntologyInspector', () => {
|
||||
render(<OntologyInspector />);
|
||||
|
||||
fireEvent.change(screen.getByLabelText('边缘平滑强度'), { target: { value: '35' } });
|
||||
fireEvent.click(screen.getByRole('button', { name: '应用边缘平滑' }));
|
||||
|
||||
await waitFor(() => expect(apiMock.smoothMaskGeometry).toHaveBeenCalledWith(
|
||||
expect.objectContaining({ id: 'm1' }),
|
||||
expect.objectContaining({ id: 'frame-1' }),
|
||||
35,
|
||||
));
|
||||
await waitFor(() => expect(useStore.getState().masks[0]).toEqual(expect.objectContaining({
|
||||
pathData: 'M 12 12 L 28 12 L 28 28 L 12 28 Z',
|
||||
segmentation: [[12, 12, 28, 12, 28, 28, 12, 28]],
|
||||
bbox: [12, 12, 16, 16],
|
||||
area: 256,
|
||||
saveStatus: 'saved',
|
||||
saved: true,
|
||||
metadata: { geometry_smoothing_preview: { strength: 35, method: 'chaikin' } },
|
||||
})));
|
||||
expect(screen.getByText('已应用边缘平滑强度 35,预览中,点击应用后写入当前 mask。')).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it('debounces backend edge smoothing preview while dragging the slider', async () => {
|
||||
vi.useFakeTimers();
|
||||
try {
|
||||
useStore.setState({
|
||||
frames: [{ id: 'frame-1', projectId: 'p1', index: 0, url: '/1.jpg', width: 100, height: 100 }],
|
||||
selectedMaskIds: ['m1'],
|
||||
masks: [
|
||||
{
|
||||
id: 'm1',
|
||||
annotationId: '10',
|
||||
frameId: 'frame-1',
|
||||
pathData: 'M 10 10 L 30 10 L 30 30 Z',
|
||||
label: '胆囊',
|
||||
color: '#ff0000',
|
||||
segmentation: [[10, 10, 30, 10, 30, 30]],
|
||||
saveStatus: 'saved',
|
||||
saved: true,
|
||||
},
|
||||
],
|
||||
});
|
||||
|
||||
render(<OntologyInspector />);
|
||||
|
||||
fireEvent.change(screen.getByLabelText('边缘平滑强度'), { target: { value: '15' } });
|
||||
fireEvent.change(screen.getByLabelText('边缘平滑强度'), { target: { value: '25' } });
|
||||
fireEvent.change(screen.getByLabelText('边缘平滑强度'), { target: { value: '35' } });
|
||||
|
||||
expect(screen.getByText('正在等待停止拖动后生成边缘平滑预览...')).toBeInTheDocument();
|
||||
expect(apiMock.smoothMaskGeometry).not.toHaveBeenCalled();
|
||||
|
||||
act(() => {
|
||||
vi.advanceTimersByTime(219);
|
||||
});
|
||||
expect(apiMock.smoothMaskGeometry).not.toHaveBeenCalled();
|
||||
|
||||
await act(async () => {
|
||||
vi.advanceTimersByTime(1);
|
||||
await Promise.resolve();
|
||||
});
|
||||
|
||||
expect(apiMock.smoothMaskGeometry).toHaveBeenCalledTimes(1);
|
||||
expect(apiMock.smoothMaskGeometry).toHaveBeenCalledWith(
|
||||
expect.objectContaining({ id: 'm1' }),
|
||||
expect.objectContaining({ id: 'frame-1' }),
|
||||
35,
|
||||
);
|
||||
} finally {
|
||||
vi.useRealTimers();
|
||||
}
|
||||
});
|
||||
|
||||
it('applies a previewed edge smoothing result to the selected mask and marks it dirty', async () => {
|
||||
useStore.setState({
|
||||
frames: [{ id: 'frame-1', projectId: 'p1', index: 0, url: '/1.jpg', width: 100, height: 100 }],
|
||||
selectedMaskIds: ['m1'],
|
||||
masks: [
|
||||
{
|
||||
id: 'm1',
|
||||
annotationId: '10',
|
||||
frameId: 'frame-1',
|
||||
pathData: 'M 10 10 L 30 10 L 30 30 Z',
|
||||
label: '胆囊',
|
||||
color: '#ff0000',
|
||||
segmentation: [[10, 10, 30, 10, 30, 30]],
|
||||
saveStatus: 'saved',
|
||||
saved: true,
|
||||
},
|
||||
],
|
||||
});
|
||||
|
||||
render(<OntologyInspector />);
|
||||
|
||||
fireEvent.change(screen.getByLabelText('边缘平滑强度'), { target: { value: '35' } });
|
||||
await waitFor(() => expect(screen.getByRole('button', { name: '应用边缘平滑' })).not.toBeDisabled());
|
||||
fireEvent.click(screen.getByRole('button', { name: '应用边缘平滑' }));
|
||||
|
||||
await waitFor(() => expect(useStore.getState().masks[0]).toEqual(expect.objectContaining({
|
||||
pathData: 'M 12 12 L 28 12 L 28 28 L 12 28 Z',
|
||||
segmentation: [[12, 12, 28, 12, 28, 28, 12, 28]],
|
||||
@@ -258,8 +499,87 @@ describe('OntologyInspector', () => {
|
||||
area: 256,
|
||||
saveStatus: 'dirty',
|
||||
saved: false,
|
||||
metadata: { geometry_smoothing: { strength: 35, method: 'chaikin' } },
|
||||
})));
|
||||
expect(screen.getByText('已应用边缘平滑强度 35,请保存后生效')).toBeInTheDocument();
|
||||
expect(useStore.getState().masks[0].metadata?.geometry_smoothing).toBeUndefined();
|
||||
expect(apiMock.smoothMaskGeometry).toHaveBeenCalledTimes(1);
|
||||
expect(screen.getByText('0%')).toBeInTheDocument();
|
||||
expect(screen.getByText('已应用边缘平滑强度 35,已变为新的 mask,强度已重置为 0,请保存后生效')).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it('applies smoothing to linked propagation masks as one undoable geometry edit', async () => {
|
||||
useStore.setState({
|
||||
frames: [
|
||||
{ id: 'frame-0', projectId: 'p1', index: 0, url: '/0.jpg', width: 100, height: 100 },
|
||||
{ id: 'frame-1', projectId: 'p1', index: 1, url: '/1.jpg', width: 100, height: 100 },
|
||||
{ id: 'frame-2', projectId: 'p1', index: 2, url: '/2.jpg', width: 100, height: 100 },
|
||||
],
|
||||
selectedMaskIds: ['seed-mask'],
|
||||
masks: [
|
||||
{
|
||||
id: 'seed-mask',
|
||||
annotationId: '10',
|
||||
frameId: 'frame-1',
|
||||
pathData: 'M 10 10 L 30 10 L 30 30 Z',
|
||||
label: '胆囊',
|
||||
color: '#ff0000',
|
||||
segmentation: [[10, 10, 30, 10, 30, 30]],
|
||||
saveStatus: 'saved',
|
||||
saved: true,
|
||||
},
|
||||
{
|
||||
id: 'prop-backward',
|
||||
annotationId: '11',
|
||||
frameId: 'frame-0',
|
||||
pathData: 'M 11 11 L 31 11 L 31 31 Z',
|
||||
label: '胆囊',
|
||||
color: '#ff0000',
|
||||
segmentation: [[11, 11, 31, 11, 31, 31]],
|
||||
saveStatus: 'saved',
|
||||
saved: true,
|
||||
metadata: { source_annotation_id: 10, source_mask_id: 'annotation-10', propagated_from_frame_id: 10 },
|
||||
},
|
||||
{
|
||||
id: 'prop-forward',
|
||||
annotationId: '12',
|
||||
frameId: 'frame-2',
|
||||
pathData: 'M 12 12 L 32 12 L 32 32 Z',
|
||||
label: '胆囊',
|
||||
color: '#ff0000',
|
||||
segmentation: [[12, 12, 32, 12, 32, 32]],
|
||||
saveStatus: 'saved',
|
||||
saved: true,
|
||||
metadata: { source_annotation_id: 10, source_mask_id: 'annotation-10', propagated_from_frame_id: 10 },
|
||||
},
|
||||
],
|
||||
});
|
||||
|
||||
render(<OntologyInspector />);
|
||||
|
||||
fireEvent.change(screen.getByLabelText('边缘平滑强度'), { target: { value: '35' } });
|
||||
await waitFor(() => expect(screen.getByRole('button', { name: '应用边缘平滑' })).not.toBeDisabled());
|
||||
fireEvent.click(screen.getByRole('button', { name: '应用边缘平滑' }));
|
||||
|
||||
await waitFor(() => expect(apiMock.smoothMaskGeometry).toHaveBeenCalledTimes(3));
|
||||
await waitFor(() => expect(useStore.getState().masks).toEqual([
|
||||
expect.objectContaining({ id: 'seed-mask', pathData: 'M 12 12 L 28 12 L 28 28 L 12 28 Z', saveStatus: 'dirty', saved: false }),
|
||||
expect.objectContaining({ id: 'prop-backward', pathData: 'M 12 12 L 28 12 L 28 28 L 12 28 Z', saveStatus: 'dirty', saved: false }),
|
||||
expect.objectContaining({ id: 'prop-forward', pathData: 'M 12 12 L 28 12 L 28 28 L 12 28 Z', saveStatus: 'dirty', saved: false }),
|
||||
]));
|
||||
expect(useStore.getState().masks.every((mask) => !mask.metadata?.geometry_smoothing)).toBe(true);
|
||||
expect(screen.getByText('已应用边缘平滑强度 35,已同步应用到传播链 3 个对应 mask,强度已重置为 0,请保存后生效')).toBeInTheDocument();
|
||||
|
||||
act(() => {
|
||||
useStore.getState().undoMasks();
|
||||
});
|
||||
expect(useStore.getState().masks.map((mask) => mask.pathData)).toEqual([
|
||||
'M 10 10 L 30 10 L 30 30 Z',
|
||||
'M 11 11 L 31 11 L 31 31 Z',
|
||||
'M 12 12 L 32 12 L 32 32 Z',
|
||||
]);
|
||||
|
||||
act(() => {
|
||||
useStore.getState().redoMasks();
|
||||
});
|
||||
expect(useStore.getState().masks.every((mask) => mask.pathData === 'M 12 12 L 28 12 L 28 28 L 12 28 Z')).toBe(true);
|
||||
});
|
||||
});
|
||||
|
||||
@@ -1,10 +1,63 @@
|
||||
import React, { useEffect, useMemo, useRef, useState } from 'react';
|
||||
import { ChevronDown, Tag, Eye, Plus, X, Loader2 } from 'lucide-react';
|
||||
import { ChevronDown, Tag, Eye, Plus, X, Loader2, GripVertical } from 'lucide-react';
|
||||
import { useStore } from '../store/useStore';
|
||||
import type { TemplateClass } from '../store/useStore';
|
||||
import type { Mask, TemplateClass } from '../store/useStore';
|
||||
import { cn } from '../lib/utils';
|
||||
import { getActiveTemplate } from '../lib/templateSelection';
|
||||
import { analyzeMask, smoothMaskGeometry, updateTemplate, type MaskAnalysisResult } from '../lib/api';
|
||||
import { analyzeMask, smoothMaskGeometry, updateTemplate, type MaskAnalysisResult, type SmoothMaskGeometryResult } from '../lib/api';
|
||||
import { nextClassMaskId, normalizeClassMaskIds } from '../lib/maskIds';
|
||||
|
||||
const SMOOTHING_PREVIEW_DEBOUNCE_MS = 220;
|
||||
|
||||
const isRequestAbortError = (err: unknown) => {
|
||||
const error = err as { code?: string; message?: string; name?: string } | null;
|
||||
const message = error?.message || '';
|
||||
return error?.code === 'ERR_CANCELED'
|
||||
|| error?.code === 'ECONNABORTED'
|
||||
|| error?.name === 'AbortError'
|
||||
|| /request aborted|aborted|cancell?ed/i.test(message);
|
||||
};
|
||||
|
||||
function metadataNumber(value: unknown): number | null {
|
||||
const parsed = Number(value);
|
||||
return Number.isFinite(parsed) && parsed > 0 ? parsed : null;
|
||||
}
|
||||
|
||||
function propagationSourceMaskTokens(value: unknown): string[] {
|
||||
if (typeof value !== 'string' || value.length === 0) return [];
|
||||
const tokens = [`mask:${value}`];
|
||||
const annotationMatch = value.match(/^annotation-(\d+)$/);
|
||||
if (annotationMatch) {
|
||||
tokens.push(`annotation:${annotationMatch[1]}`);
|
||||
}
|
||||
return tokens;
|
||||
}
|
||||
|
||||
function propagationLineageTokens(mask: { id: string; annotationId?: string; metadata?: Record<string, unknown> }): Set<string> {
|
||||
const metadata = mask.metadata || {};
|
||||
const tokens = new Set<string>([`mask:${mask.id}`]);
|
||||
if (mask.annotationId) {
|
||||
tokens.add(`annotation:${mask.annotationId}`);
|
||||
}
|
||||
const sourceAnnotationId = metadataNumber(metadata.source_annotation_id);
|
||||
if (sourceAnnotationId !== null) {
|
||||
tokens.add(`annotation:${sourceAnnotationId}`);
|
||||
}
|
||||
propagationSourceMaskTokens(metadata.source_mask_id).forEach((token) => tokens.add(token));
|
||||
if (typeof metadata.propagation_seed_key === 'string' && metadata.propagation_seed_key.length > 0) {
|
||||
tokens.add(`seed-key:${metadata.propagation_seed_key}`);
|
||||
}
|
||||
return tokens;
|
||||
}
|
||||
|
||||
function findPropagationChainMaskIds(selectedMask: Pick<Mask, 'id' | 'annotationId' | 'metadata'>, masks: Mask[]): Set<string> {
|
||||
const selectedTokens = propagationLineageTokens(selectedMask);
|
||||
return new Set(
|
||||
masks
|
||||
.filter((mask) => Array.from(selectedTokens).some((token) => propagationLineageTokens(mask).has(token)))
|
||||
.map((mask) => mask.id),
|
||||
);
|
||||
}
|
||||
|
||||
export function OntologyInspector() {
|
||||
const templates = useStore((state) => state.templates);
|
||||
@@ -27,20 +80,38 @@ export function OntologyInspector() {
|
||||
const [newClassColor, setNewClassColor] = useState('#06b6d4');
|
||||
const [isSavingClass, setIsSavingClass] = useState(false);
|
||||
const [classSaveMessage, setClassSaveMessage] = useState('');
|
||||
const [dragClassId, setDragClassId] = useState<string | null>(null);
|
||||
const [maskAnalysis, setMaskAnalysis] = useState<MaskAnalysisResult | null>(null);
|
||||
const [isAnalyzingMask, setIsAnalyzingMask] = useState(false);
|
||||
const [analysisMessage, setAnalysisMessage] = useState('');
|
||||
const [smoothingStrength, setSmoothingStrength] = useState(0);
|
||||
const [isPreviewingSmoothing, setIsPreviewingSmoothing] = useState(false);
|
||||
const [isSmoothingMask, setIsSmoothingMask] = useState(false);
|
||||
|
||||
const activeTemplate = getActiveTemplate(templates, activeTemplateId);
|
||||
const templateClasses = activeTemplate?.classes || [];
|
||||
const templateClasses = normalizeClassMaskIds(activeTemplate?.classes || []);
|
||||
const allClasses = [...templateClasses].sort((a, b) => b.zIndex - a.zIndex);
|
||||
const selectedMask = masks.find((mask) => selectedMaskIds.includes(mask.id)) || null;
|
||||
const selectedMaskLabel = selectedMask?.className || selectedMask?.label || '未选择';
|
||||
const currentFrame = frames[currentFrameIndex] || null;
|
||||
const classButtonRefs = useRef(new Map<string, HTMLButtonElement>());
|
||||
const skipNextAutoAnalysisRef = useRef(false);
|
||||
const analysisRequestIdRef = useRef(0);
|
||||
const smoothingPreviewRef = useRef<{
|
||||
maskId: string;
|
||||
baseMask: NonNullable<typeof selectedMask>;
|
||||
strength: number;
|
||||
result: SmoothMaskGeometryResult | null;
|
||||
applied: boolean;
|
||||
requestId: number;
|
||||
} | null>(null);
|
||||
const smoothingRequestIdRef = useRef(0);
|
||||
const smoothingPreviewTimerRef = useRef<number | null>(null);
|
||||
|
||||
const clearSmoothingPreviewTimer = React.useCallback(() => {
|
||||
if (smoothingPreviewTimerRef.current === null) return;
|
||||
window.clearTimeout(smoothingPreviewTimerRef.current);
|
||||
smoothingPreviewTimerRef.current = null;
|
||||
}, []);
|
||||
|
||||
const selectedMaskClass = useMemo(() => {
|
||||
if (!selectedMask) return null;
|
||||
@@ -78,14 +149,21 @@ export function OntologyInspector() {
|
||||
if (!hasSelectedMasks) return;
|
||||
|
||||
const templateId = activeTemplate?.id || activeTemplateId || undefined;
|
||||
const targetIdSet = new Set<string>();
|
||||
masks
|
||||
.filter((mask) => selectedIdSet.has(mask.id))
|
||||
.forEach((mask) => {
|
||||
findPropagationChainMaskIds(mask, masks).forEach((maskId) => targetIdSet.add(maskId));
|
||||
});
|
||||
const updatedMasks = masks.map((mask) => {
|
||||
if (!selectedIdSet.has(mask.id)) return mask;
|
||||
if (!targetIdSet.has(mask.id)) return mask;
|
||||
return {
|
||||
...mask,
|
||||
templateId: templateId || mask.templateId,
|
||||
classId: templateClass.id,
|
||||
className: templateClass.name,
|
||||
classZIndex: templateClass.zIndex,
|
||||
classMaskId: templateClass.maskId,
|
||||
label: templateClass.name,
|
||||
color: templateClass.color,
|
||||
saveStatus: mask.annotationId ? 'dirty' as const : 'draft' as const,
|
||||
@@ -101,33 +179,63 @@ export function OntologyInspector() {
|
||||
]);
|
||||
};
|
||||
|
||||
const refreshMaskAnalysis = async (extractSkeleton = false) => {
|
||||
const refreshMaskAnalysis = async () => {
|
||||
const requestId = analysisRequestIdRef.current + 1;
|
||||
analysisRequestIdRef.current = requestId;
|
||||
if (!selectedMask || !currentFrame) {
|
||||
setMaskAnalysis(null);
|
||||
setAnalysisMessage(selectedMask ? '当前帧信息不可用,无法读取后端属性' : '请选择一个 mask 查看后端属性');
|
||||
return;
|
||||
}
|
||||
setIsAnalyzingMask(true);
|
||||
setAnalysisMessage('');
|
||||
try {
|
||||
const result = await analyzeMask(selectedMask, currentFrame, { extractSkeleton });
|
||||
const result = await analyzeMask(selectedMask, currentFrame);
|
||||
if (analysisRequestIdRef.current !== requestId) return;
|
||||
setMaskAnalysis(result);
|
||||
setAnalysisMessage(result.message);
|
||||
} catch (err) {
|
||||
if (analysisRequestIdRef.current !== requestId || isRequestAbortError(err)) return;
|
||||
console.error('Mask analysis failed:', err);
|
||||
setMaskAnalysis(null);
|
||||
setAnalysisMessage('后端属性读取失败');
|
||||
} finally {
|
||||
setIsAnalyzingMask(false);
|
||||
}
|
||||
};
|
||||
|
||||
const restoreSmoothingPreview = React.useCallback(() => {
|
||||
const preview = smoothingPreviewRef.current;
|
||||
if (!preview || preview.applied) {
|
||||
smoothingPreviewRef.current = null;
|
||||
return;
|
||||
}
|
||||
const state = useStore.getState();
|
||||
useStore.setState({
|
||||
masks: state.masks.map((mask) => (mask.id === preview.maskId ? preview.baseMask : mask)),
|
||||
selectedMaskIds: state.selectedMaskIds,
|
||||
});
|
||||
smoothingPreviewRef.current = null;
|
||||
}, []);
|
||||
|
||||
React.useEffect(() => {
|
||||
return () => {
|
||||
analysisRequestIdRef.current += 1;
|
||||
clearSmoothingPreviewTimer();
|
||||
restoreSmoothingPreview();
|
||||
};
|
||||
}, [clearSmoothingPreviewTimer, restoreSmoothingPreview]);
|
||||
|
||||
React.useEffect(() => {
|
||||
const preview = smoothingPreviewRef.current;
|
||||
if (preview && preview.maskId !== selectedMask?.id) {
|
||||
restoreSmoothingPreview();
|
||||
}
|
||||
}, [restoreSmoothingPreview, selectedMask?.id]);
|
||||
|
||||
React.useEffect(() => {
|
||||
if (skipNextAutoAnalysisRef.current) {
|
||||
skipNextAutoAnalysisRef.current = false;
|
||||
return;
|
||||
}
|
||||
void refreshMaskAnalysis(false);
|
||||
void refreshMaskAnalysis();
|
||||
// selectedMask is intentionally tracked by id and geometry fields to avoid
|
||||
// re-running analysis for unrelated store changes.
|
||||
}, [selectedMask?.id, selectedMask?.segmentation, selectedMask?.points, currentFrame?.id]);
|
||||
@@ -140,43 +248,202 @@ export function OntologyInspector() {
|
||||
setSmoothingStrength(Number.isFinite(strength) ? Math.min(Math.max(strength, 0), 100) : 0);
|
||||
}, [selectedMask?.id]);
|
||||
|
||||
const applySmoothingResultToMask = React.useCallback((
|
||||
mask: Mask,
|
||||
result: SmoothMaskGeometryResult,
|
||||
options: { commit: boolean },
|
||||
): Mask => {
|
||||
const metadata = { ...(mask.metadata || {}) };
|
||||
delete metadata.geometry_smoothing_preview;
|
||||
if (options.commit) {
|
||||
delete metadata.geometry_smoothing;
|
||||
} else {
|
||||
metadata.geometry_smoothing_preview = result.smoothing;
|
||||
}
|
||||
return {
|
||||
...mask,
|
||||
pathData: result.pathData,
|
||||
segmentation: result.segmentation,
|
||||
bbox: result.bbox,
|
||||
area: result.area,
|
||||
metadata,
|
||||
...(options.commit
|
||||
? {
|
||||
saveStatus: mask.annotationId ? 'dirty' as const : 'draft' as const,
|
||||
saved: mask.annotationId ? false : mask.saved,
|
||||
}
|
||||
: {}),
|
||||
};
|
||||
}, []);
|
||||
|
||||
const updateMaskWithSmoothingResult = React.useCallback((
|
||||
maskId: string,
|
||||
result: SmoothMaskGeometryResult,
|
||||
options: { commit: boolean },
|
||||
) => {
|
||||
const state = useStore.getState();
|
||||
const nextMasks = state.masks.map((mask) => (
|
||||
mask.id === maskId ? applySmoothingResultToMask(mask, result, options) : mask
|
||||
));
|
||||
if (options.commit) {
|
||||
setMasks(nextMasks);
|
||||
} else {
|
||||
useStore.setState({ masks: nextMasks });
|
||||
}
|
||||
}, [applySmoothingResultToMask, setMasks]);
|
||||
|
||||
const applySmoothingResultToAnalysis = React.useCallback((
|
||||
result: SmoothMaskGeometryResult,
|
||||
sourceMask: NonNullable<typeof selectedMask>,
|
||||
suffix: string,
|
||||
) => {
|
||||
setMaskAnalysis({
|
||||
confidence: null,
|
||||
confidence_source: 'manual_or_imported',
|
||||
topology_anchor_count: result.topology_anchor_count,
|
||||
topology_anchors: result.topology_anchors,
|
||||
area: result.area,
|
||||
bbox: result.bbox,
|
||||
source: sourceMask.metadata?.source as string | undefined,
|
||||
message: result.message,
|
||||
});
|
||||
setAnalysisMessage(`${result.message}${suffix}`);
|
||||
}, []);
|
||||
|
||||
const runSmoothingPreview = React.useCallback(async (nextStrength: number) => {
|
||||
if (!selectedMask || !currentFrame) return;
|
||||
|
||||
const existingPreview = smoothingPreviewRef.current?.maskId === selectedMask.id
|
||||
? smoothingPreviewRef.current
|
||||
: null;
|
||||
const baseMask = existingPreview?.baseMask || selectedMask;
|
||||
const requestId = smoothingRequestIdRef.current + 1;
|
||||
smoothingRequestIdRef.current = requestId;
|
||||
|
||||
if (nextStrength <= 0) {
|
||||
clearSmoothingPreviewTimer();
|
||||
smoothingPreviewRef.current = {
|
||||
maskId: selectedMask.id,
|
||||
baseMask,
|
||||
strength: 0,
|
||||
result: null,
|
||||
applied: false,
|
||||
requestId,
|
||||
};
|
||||
skipNextAutoAnalysisRef.current = true;
|
||||
useStore.setState({
|
||||
masks: useStore.getState().masks.map((mask) => (mask.id === selectedMask.id ? baseMask : mask)),
|
||||
});
|
||||
setAnalysisMessage('已预览恢复原始边缘,点击应用后写入当前 mask。');
|
||||
setIsPreviewingSmoothing(false);
|
||||
return;
|
||||
}
|
||||
|
||||
setAnalysisMessage('正在生成边缘平滑预览...');
|
||||
try {
|
||||
const result = await smoothMaskGeometry(baseMask, currentFrame, nextStrength);
|
||||
if (smoothingRequestIdRef.current !== requestId) return;
|
||||
smoothingPreviewRef.current = {
|
||||
maskId: selectedMask.id,
|
||||
baseMask,
|
||||
strength: nextStrength,
|
||||
result,
|
||||
applied: false,
|
||||
requestId,
|
||||
};
|
||||
skipNextAutoAnalysisRef.current = true;
|
||||
updateMaskWithSmoothingResult(selectedMask.id, result, { commit: false });
|
||||
applySmoothingResultToAnalysis(result, baseMask, ',预览中,点击应用后写入当前 mask。');
|
||||
} catch (err) {
|
||||
if (smoothingRequestIdRef.current !== requestId) return;
|
||||
console.error('Mask smoothing preview failed:', err);
|
||||
setAnalysisMessage('边缘平滑预览失败,请检查后端服务');
|
||||
} finally {
|
||||
if (smoothingRequestIdRef.current === requestId) {
|
||||
setIsPreviewingSmoothing(false);
|
||||
}
|
||||
}
|
||||
}, [applySmoothingResultToAnalysis, clearSmoothingPreviewTimer, currentFrame, selectedMask, updateMaskWithSmoothingResult]);
|
||||
|
||||
const previewSmoothing = React.useCallback((nextStrength: number) => {
|
||||
setSmoothingStrength(nextStrength);
|
||||
clearSmoothingPreviewTimer();
|
||||
if (!selectedMask || !currentFrame) return;
|
||||
if (nextStrength <= 0) {
|
||||
void runSmoothingPreview(nextStrength);
|
||||
return;
|
||||
}
|
||||
setIsPreviewingSmoothing(true);
|
||||
setAnalysisMessage('正在等待停止拖动后生成边缘平滑预览...');
|
||||
smoothingPreviewTimerRef.current = window.setTimeout(() => {
|
||||
smoothingPreviewTimerRef.current = null;
|
||||
void runSmoothingPreview(nextStrength);
|
||||
}, SMOOTHING_PREVIEW_DEBOUNCE_MS);
|
||||
}, [clearSmoothingPreviewTimer, currentFrame, runSmoothingPreview, selectedMask]);
|
||||
|
||||
const handleApplySmoothing = async () => {
|
||||
if (!selectedMask || !currentFrame) {
|
||||
setAnalysisMessage('请选择一个 mask 后再应用边缘平滑');
|
||||
return;
|
||||
}
|
||||
clearSmoothingPreviewTimer();
|
||||
smoothingRequestIdRef.current += 1;
|
||||
setIsSmoothingMask(true);
|
||||
setAnalysisMessage('');
|
||||
try {
|
||||
const result = await smoothMaskGeometry(selectedMask, currentFrame, smoothingStrength);
|
||||
skipNextAutoAnalysisRef.current = true;
|
||||
setMasks(masks.map((mask) => {
|
||||
if (mask.id !== selectedMask.id) return mask;
|
||||
return {
|
||||
...mask,
|
||||
pathData: result.pathData,
|
||||
segmentation: result.segmentation,
|
||||
bbox: result.bbox,
|
||||
area: result.area,
|
||||
metadata: {
|
||||
...(mask.metadata || {}),
|
||||
geometry_smoothing: result.smoothing,
|
||||
},
|
||||
saveStatus: mask.annotationId ? 'dirty' as const : 'draft' as const,
|
||||
saved: mask.annotationId ? false : mask.saved,
|
||||
};
|
||||
}));
|
||||
setMaskAnalysis({
|
||||
confidence: null,
|
||||
confidence_source: 'manual_or_imported',
|
||||
topology_anchor_count: result.topology_anchor_count,
|
||||
topology_anchors: result.topology_anchors,
|
||||
area: result.area,
|
||||
bbox: result.bbox,
|
||||
source: selectedMask.metadata?.source as string | undefined,
|
||||
message: result.message,
|
||||
const existingPreview = smoothingPreviewRef.current?.maskId === selectedMask.id
|
||||
&& smoothingPreviewRef.current.strength === smoothingStrength
|
||||
? smoothingPreviewRef.current
|
||||
: null;
|
||||
const baseMask = existingPreview?.baseMask || selectedMask;
|
||||
if (smoothingStrength <= 0) {
|
||||
smoothingPreviewRef.current = null;
|
||||
setSmoothingStrength(0);
|
||||
setAnalysisMessage('边缘平滑强度为 0,当前 mask 保持原始边缘。');
|
||||
return;
|
||||
}
|
||||
|
||||
const state = useStore.getState();
|
||||
const frameById = new Map(state.frames.map((frame) => [String(frame.id), frame]));
|
||||
const chainMaskIds = findPropagationChainMaskIds(baseMask, state.masks);
|
||||
chainMaskIds.add(selectedMask.id);
|
||||
const selectedResult = existingPreview?.result || await smoothMaskGeometry(baseMask, currentFrame, smoothingStrength);
|
||||
const resultEntries = new Map<string, SmoothMaskGeometryResult>();
|
||||
resultEntries.set(selectedMask.id, selectedResult);
|
||||
|
||||
await Promise.all(
|
||||
Array.from(chainMaskIds)
|
||||
.filter((maskId) => maskId !== selectedMask.id)
|
||||
.map(async (maskId) => {
|
||||
const mask = state.masks.find((item) => item.id === maskId);
|
||||
const frame = mask ? frameById.get(String(mask.frameId)) : null;
|
||||
if (!mask || !frame) return;
|
||||
resultEntries.set(maskId, await smoothMaskGeometry(mask, frame, smoothingStrength));
|
||||
}),
|
||||
);
|
||||
|
||||
const latestMasks = useStore.getState().masks;
|
||||
const historyBaseMasks = latestMasks.map((mask) => (mask.id === selectedMask.id ? baseMask : mask));
|
||||
useStore.setState({ masks: historyBaseMasks });
|
||||
const nextMasks = historyBaseMasks.map((mask) => {
|
||||
const result = resultEntries.get(mask.id);
|
||||
if (!result) return mask;
|
||||
return applySmoothingResultToMask(mask, result, { commit: true });
|
||||
});
|
||||
setAnalysisMessage(`${result.message},请保存后生效`);
|
||||
skipNextAutoAnalysisRef.current = true;
|
||||
setMasks(nextMasks);
|
||||
if (smoothingPreviewRef.current) {
|
||||
smoothingPreviewRef.current.applied = true;
|
||||
}
|
||||
smoothingPreviewRef.current = null;
|
||||
setSmoothingStrength(0);
|
||||
applySmoothingResultToAnalysis(
|
||||
selectedResult,
|
||||
baseMask,
|
||||
resultEntries.size > 1
|
||||
? `,已同步应用到传播链 ${resultEntries.size} 个对应 mask,强度已重置为 0,请保存后生效`
|
||||
: ',已变为新的 mask,强度已重置为 0,请保存后生效',
|
||||
);
|
||||
} catch (err) {
|
||||
console.error('Mask smoothing failed:', err);
|
||||
setAnalysisMessage('边缘平滑失败,请检查后端服务');
|
||||
@@ -197,6 +464,7 @@ export function OntologyInspector() {
|
||||
name: newClassName.trim(),
|
||||
color: newClassColor,
|
||||
zIndex: maxZ + 10,
|
||||
maskId: nextClassMaskId(templateClasses),
|
||||
category: '自定义',
|
||||
};
|
||||
setIsSavingClass(true);
|
||||
@@ -205,7 +473,7 @@ export function OntologyInspector() {
|
||||
const updated = await updateTemplate(activeTemplate.id, {
|
||||
name: activeTemplate.name,
|
||||
description: activeTemplate.description,
|
||||
classes: [...templateClasses, newClass],
|
||||
classes: normalizeClassMaskIds([...templateClasses, newClass]),
|
||||
rules: activeTemplate.rules || [],
|
||||
});
|
||||
updateTemplateStore(updated);
|
||||
@@ -222,6 +490,62 @@ export function OntologyInspector() {
|
||||
}
|
||||
};
|
||||
|
||||
const handleReorderClass = async (sourceClassId: string, targetClassId: string) => {
|
||||
if (!activeTemplate || sourceClassId === targetClassId) {
|
||||
setDragClassId(null);
|
||||
return;
|
||||
}
|
||||
const sourceIndex = allClasses.findIndex((item) => item.id === sourceClassId);
|
||||
const targetIndex = allClasses.findIndex((item) => item.id === targetClassId);
|
||||
if (sourceIndex < 0 || targetIndex < 0) {
|
||||
setDragClassId(null);
|
||||
return;
|
||||
}
|
||||
|
||||
const reordered = [...allClasses];
|
||||
const [source] = reordered.splice(sourceIndex, 1);
|
||||
reordered.splice(targetIndex, 0, source);
|
||||
const nextClasses = normalizeClassMaskIds(
|
||||
reordered.map((item, index) => ({
|
||||
...item,
|
||||
zIndex: (reordered.length - index) * 10,
|
||||
})),
|
||||
);
|
||||
|
||||
setIsSavingClass(true);
|
||||
setClassSaveMessage('正在保存分类覆盖顺序...');
|
||||
try {
|
||||
const updated = await updateTemplate(activeTemplate.id, {
|
||||
name: activeTemplate.name,
|
||||
description: activeTemplate.description,
|
||||
classes: nextClasses,
|
||||
rules: activeTemplate.rules || [],
|
||||
});
|
||||
updateTemplateStore(updated);
|
||||
setActiveTemplateId(updated.id);
|
||||
const zIndexByClassId = new Map(nextClasses.map((item) => [item.id, item.zIndex]));
|
||||
setMasks(useStore.getState().masks.map((mask) => (
|
||||
mask.classId && zIndexByClassId.has(mask.classId)
|
||||
? {
|
||||
...mask,
|
||||
classZIndex: zIndexByClassId.get(mask.classId),
|
||||
saveStatus: mask.annotationId ? 'dirty' as const : mask.saveStatus,
|
||||
saved: mask.annotationId ? false : mask.saved,
|
||||
}
|
||||
: mask
|
||||
)));
|
||||
const nextActiveClass = nextClasses.find((item) => item.id === activeClassId);
|
||||
if (nextActiveClass) setActiveClass(nextActiveClass);
|
||||
setClassSaveMessage('分类覆盖顺序已保存');
|
||||
} catch (err) {
|
||||
console.error('Reorder class failed:', err);
|
||||
setClassSaveMessage('分类覆盖顺序保存失败');
|
||||
} finally {
|
||||
setIsSavingClass(false);
|
||||
setDragClassId(null);
|
||||
}
|
||||
};
|
||||
|
||||
return (
|
||||
<div className="w-60 bg-[#0d0d0d] flex flex-col border-l border-white/5 shrink-0 z-10 overflow-hidden">
|
||||
<div className="flex-1 overflow-y-auto seg-scrollbar p-4 flex flex-col gap-6">
|
||||
@@ -275,13 +599,14 @@ export function OntologyInspector() {
|
||||
{/* Semantic Classification Tree */}
|
||||
<div>
|
||||
<h3 className="text-[10px] font-bold text-gray-500 uppercase tracking-widest mb-3 flex justify-between items-center">
|
||||
<span>语义分类树 (高度/Z-Index)</span>
|
||||
<span>语义分类树(拖拽调层级)</span>
|
||||
</h3>
|
||||
<div className="space-y-2">
|
||||
{allClasses.map(cls => (
|
||||
<div key={cls.id} className="flex flex-col gap-1">
|
||||
<button
|
||||
type="button"
|
||||
draggable={Boolean(activeTemplate) && !isSavingClass}
|
||||
ref={(node) => {
|
||||
if (node) {
|
||||
classButtonRefs.current.set(cls.id, node);
|
||||
@@ -290,18 +615,36 @@ export function OntologyInspector() {
|
||||
}
|
||||
}}
|
||||
onClick={() => handleSelectClass(cls)}
|
||||
onDragStart={(event) => {
|
||||
setDragClassId(cls.id);
|
||||
event.dataTransfer.effectAllowed = 'move';
|
||||
event.dataTransfer.setData('text/plain', cls.id);
|
||||
}}
|
||||
onDragOver={(event) => {
|
||||
if (!dragClassId || dragClassId === cls.id) return;
|
||||
event.preventDefault();
|
||||
event.dataTransfer.dropEffect = 'move';
|
||||
}}
|
||||
onDrop={(event) => {
|
||||
event.preventDefault();
|
||||
const sourceId = event.dataTransfer.getData('text/plain') || dragClassId;
|
||||
if (sourceId) void handleReorderClass(sourceId, cls.id);
|
||||
}}
|
||||
onDragEnd={() => setDragClassId(null)}
|
||||
aria-current={activeClassId === cls.id ? 'true' : undefined}
|
||||
className={cn(
|
||||
'flex items-center justify-between p-2 rounded bg-white/5 hover:bg-white/10 cursor-pointer group transition-colors text-left border',
|
||||
activeClassId === cls.id ? 'border-cyan-500/50 bg-cyan-500/10' : 'border-transparent',
|
||||
dragClassId === cls.id && 'opacity-50',
|
||||
)}
|
||||
>
|
||||
<div className="flex items-center gap-2">
|
||||
<GripVertical size={13} className="text-gray-600 group-hover:text-gray-400" aria-hidden="true" />
|
||||
<span className="w-2.5 h-2.5 rounded-sm" style={{ backgroundColor: cls.color }} />
|
||||
<span className="text-xs font-medium text-gray-200">{cls.name}</span>
|
||||
</div>
|
||||
<div className="flex items-center gap-3">
|
||||
<span className="text-[10px] text-gray-500 font-mono">z:{cls.zIndex}</span>
|
||||
<span className="text-[10px] text-gray-500 font-mono">maskid:{cls.maskId}</span>
|
||||
<Eye size={14} className="text-gray-500 group-hover:text-gray-300" />
|
||||
</div>
|
||||
</button>
|
||||
@@ -366,10 +709,6 @@ export function OntologyInspector() {
|
||||
</span>
|
||||
</div>
|
||||
<div className="space-y-3">
|
||||
<div className="flex items-center justify-between">
|
||||
<span className="text-[10px] text-gray-500 uppercase">当前选中区域:</span>
|
||||
<span className="text-xs font-mono text-gray-300">{selectedMaskIds.length}</span>
|
||||
</div>
|
||||
<div className="flex items-center justify-between">
|
||||
<span className="text-[10px] text-gray-500 uppercase">后端拓扑锚点:</span>
|
||||
<span className="text-xs font-mono text-gray-300">{maskAnalysis?.topology_anchor_count ?? 0} 节点</span>
|
||||
@@ -387,28 +726,21 @@ export function OntologyInspector() {
|
||||
max={100}
|
||||
step={5}
|
||||
value={smoothingStrength}
|
||||
onChange={(event) => setSmoothingStrength(Number(event.target.value))}
|
||||
onChange={(event) => void previewSmoothing(Number(event.target.value))}
|
||||
disabled={!selectedMask || isSmoothingMask}
|
||||
className="w-full accent-cyan-500 disabled:opacity-40"
|
||||
/>
|
||||
<button
|
||||
onClick={handleApplySmoothing}
|
||||
disabled={!selectedMask || !currentFrame || isSmoothingMask}
|
||||
disabled={!selectedMask || !currentFrame || isSmoothingMask || isPreviewingSmoothing}
|
||||
className="mt-2 w-full bg-cyan-500/10 hover:bg-cyan-500/20 border border-cyan-500/20 text-xs text-cyan-100 py-1.5 rounded transition-colors disabled:opacity-40 disabled:cursor-not-allowed"
|
||||
>
|
||||
{isSmoothingMask ? '平滑中...' : '应用边缘平滑'}
|
||||
{isSmoothingMask ? '平滑中...' : isPreviewingSmoothing ? '预览中...' : '应用边缘平滑'}
|
||||
</button>
|
||||
</div>
|
||||
{analysisMessage && (
|
||||
<div className="text-[10px] leading-relaxed text-gray-500">{analysisMessage}</div>
|
||||
)}
|
||||
<button
|
||||
onClick={() => refreshMaskAnalysis(true)}
|
||||
disabled={!selectedMask || isAnalyzingMask}
|
||||
className="w-full mt-2 bg-white/5 hover:bg-white/10 border border-white/10 text-xs text-gray-300 py-1.5 rounded transition-colors disabled:opacity-40 disabled:cursor-not-allowed"
|
||||
>
|
||||
{isAnalyzingMask ? '提取中...' : '重新提取拓扑锚点'}
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
41
src/components/Sidebar.test.tsx
Normal file
41
src/components/Sidebar.test.tsx
Normal file
@@ -0,0 +1,41 @@
|
||||
import { fireEvent, render, screen } from '@testing-library/react';
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest';
|
||||
import { resetStore } from '../test/storeTestUtils';
|
||||
import { useStore } from '../store/useStore';
|
||||
import { Sidebar } from './Sidebar';
|
||||
|
||||
vi.mock('./ModelStatusBadge', () => ({
|
||||
ModelStatusBadge: () => <div>模型状态</div>,
|
||||
}));
|
||||
|
||||
describe('Sidebar', () => {
|
||||
beforeEach(() => {
|
||||
resetStore();
|
||||
});
|
||||
|
||||
it('shows admin user management only for admin users', () => {
|
||||
const setActiveModule = vi.fn();
|
||||
useStore.setState({ currentUser: { id: 1, username: 'admin', role: 'admin' } });
|
||||
|
||||
render(<Sidebar activeModule="dashboard" setActiveModule={setActiveModule} />);
|
||||
|
||||
fireEvent.click(screen.getByTitle('用户管理'));
|
||||
expect(setActiveModule).toHaveBeenCalledWith('admin');
|
||||
});
|
||||
|
||||
it('hides admin user management for non-admin users', () => {
|
||||
useStore.setState({ currentUser: { id: 2, username: 'doctor', role: 'annotator' } });
|
||||
|
||||
render(<Sidebar activeModule="dashboard" setActiveModule={vi.fn()} />);
|
||||
|
||||
expect(screen.queryByTitle('用户管理')).not.toBeInTheDocument();
|
||||
});
|
||||
|
||||
it('uses an explicit AI-styled icon for AI segmentation', () => {
|
||||
useStore.setState({ currentUser: { id: 2, username: 'doctor', role: 'annotator' } });
|
||||
|
||||
render(<Sidebar activeModule="dashboard" setActiveModule={vi.fn()} />);
|
||||
|
||||
expect(screen.getByTitle('AI智能分割').querySelector('[data-testid="ai-segmentation-icon"]')).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
@@ -1,8 +1,10 @@
|
||||
import React from 'react';
|
||||
import { Home, FolderOpen, Edit3, LayoutTemplate, BrainCircuit } from 'lucide-react';
|
||||
import { Home, FolderOpen, Edit3, LayoutTemplate, LogOut, UserCircle, ShieldCheck } from 'lucide-react';
|
||||
import { cn } from '../lib/utils';
|
||||
import type { ActiveModule } from '../App';
|
||||
import { ModelStatusBadge } from './ModelStatusBadge';
|
||||
import { useStore } from '../store/useStore';
|
||||
import { AiSegmentationIcon } from './AiSegmentationIcon';
|
||||
|
||||
interface SidebarProps {
|
||||
activeModule: ActiveModule;
|
||||
@@ -10,12 +12,15 @@ interface SidebarProps {
|
||||
}
|
||||
|
||||
export function Sidebar({ activeModule, setActiveModule }: SidebarProps) {
|
||||
const currentUser = useStore((state) => state.currentUser);
|
||||
const logout = useStore((state) => state.logout);
|
||||
const navItems = [
|
||||
{ id: 'dashboard', icon: Home, label: '总体概况' },
|
||||
{ id: 'projects', icon: FolderOpen, label: '项目库' },
|
||||
{ id: 'workspace', icon: Edit3, label: '分割工作区' },
|
||||
{ id: 'ai', icon: BrainCircuit, label: 'AI智能分割' },
|
||||
{ id: 'ai', icon: AiSegmentationIcon, label: 'AI智能分割' },
|
||||
{ id: 'templates', icon: LayoutTemplate, label: '模板库' },
|
||||
...(currentUser?.role === 'admin' ? [{ id: 'admin', icon: ShieldCheck, label: '用户管理' }] : []),
|
||||
] as const;
|
||||
|
||||
return (
|
||||
@@ -49,6 +54,17 @@ export function Sidebar({ activeModule, setActiveModule }: SidebarProps) {
|
||||
</nav>
|
||||
<div className="mt-auto mb-4 flex flex-col gap-4">
|
||||
<ModelStatusBadge compact />
|
||||
<button
|
||||
type="button"
|
||||
title={currentUser ? `当前用户:${currentUser.username},点击退出` : '退出登录'}
|
||||
onClick={logout}
|
||||
className="group relative flex h-9 w-9 items-center justify-center rounded-lg border border-white/10 bg-white/5 text-gray-400 transition-colors hover:border-red-400/40 hover:bg-red-500/10 hover:text-red-200"
|
||||
>
|
||||
{currentUser ? <UserCircle size={20} /> : <LogOut size={20} />}
|
||||
<span className="absolute left-full ml-2 whitespace-nowrap rounded border border-[#333] bg-[#222] px-2 py-1 text-xs text-gray-200 opacity-0 shadow-xl transition-all group-hover:opacity-100">
|
||||
{currentUser ? `${currentUser.username} / 退出` : '退出登录'}
|
||||
</span>
|
||||
</button>
|
||||
</div>
|
||||
</aside>
|
||||
);
|
||||
|
||||
@@ -39,6 +39,8 @@ describe('TemplateRegistry', () => {
|
||||
|
||||
expect(await screen.findAllByText('腹腔镜胆囊切除术')).toHaveLength(2);
|
||||
expect(screen.getByText('胆囊')).toBeInTheDocument();
|
||||
expect(screen.getAllByText(/maskid: ?1/).length).toBeGreaterThan(0);
|
||||
expect(screen.queryByText(/Z-Level/)).not.toBeInTheDocument();
|
||||
});
|
||||
|
||||
it('creates a template and stores it globally', async () => {
|
||||
|
||||
@@ -3,6 +3,7 @@ import { Settings, Database, Trash2, Edit3, Plus, Loader2, X, GripVertical, Impo
|
||||
import { cn } from '../lib/utils';
|
||||
import { useStore } from '../store/useStore';
|
||||
import { getTemplates, createTemplate, updateTemplate, deleteTemplate } from '../lib/api';
|
||||
import { nextClassMaskId, normalizeClassMaskIds } from '../lib/maskIds';
|
||||
import type { Template, TemplateClass } from '../store/useStore';
|
||||
import { TransientNotice, type NoticeState, type NoticeTone } from './TransientNotice';
|
||||
|
||||
@@ -86,7 +87,7 @@ export function TemplateRegistry() {
|
||||
setSelectedTemplate(template);
|
||||
setEditName(template.name);
|
||||
setEditDesc(template.description || '');
|
||||
setEditClasses(template.classes ? [...template.classes] : []);
|
||||
setEditClasses(normalizeClassMaskIds(template.classes ? [...template.classes] : []));
|
||||
setShowModal(true);
|
||||
};
|
||||
|
||||
@@ -97,7 +98,7 @@ export function TemplateRegistry() {
|
||||
const basePayload = {
|
||||
name: editName.trim(),
|
||||
description: editDesc.trim() || undefined,
|
||||
classes: editClasses,
|
||||
classes: normalizeClassMaskIds(editClasses),
|
||||
rules: [],
|
||||
color: selectedTemplate ? (selectedTemplate as any).color || '#06b6d4' : '#06b6d4',
|
||||
z_index: selectedTemplate ? (selectedTemplate as any).z_index ?? 0 : 0,
|
||||
@@ -138,6 +139,7 @@ export function TemplateRegistry() {
|
||||
name: '新类别',
|
||||
color: generateColor(editClasses.length, Math.max(editClasses.length + 1, 8)),
|
||||
zIndex: editClasses.length > 0 ? Math.max(...editClasses.map((c) => c.zIndex)) + 10 : 10,
|
||||
maskId: nextClassMaskId(editClasses),
|
||||
category: '未分类',
|
||||
};
|
||||
setEditClasses([...editClasses, newClass]);
|
||||
@@ -179,6 +181,7 @@ export function TemplateRegistry() {
|
||||
return;
|
||||
}
|
||||
|
||||
const firstMaskId = nextClassMaskId(editClasses);
|
||||
const imported: TemplateClass[] = names.map((name: string, i: number) => {
|
||||
const rgb = colors[i] || [100, 100, 100];
|
||||
const hex = `#${rgb[0].toString(16).padStart(2, '0')}${rgb[1].toString(16).padStart(2, '0')}${rgb[2].toString(16).padStart(2, '0')}`;
|
||||
@@ -187,6 +190,7 @@ export function TemplateRegistry() {
|
||||
name,
|
||||
color: hex,
|
||||
zIndex: (names.length - i) * 10,
|
||||
maskId: firstMaskId + i,
|
||||
category: '批量导入',
|
||||
};
|
||||
});
|
||||
@@ -208,6 +212,7 @@ export function TemplateRegistry() {
|
||||
name,
|
||||
color: hex,
|
||||
zIndex: (LAPAROSCOPIC_NAMES.length - i) * 10,
|
||||
maskId: i + 1,
|
||||
category: '腹腔镜胆囊切除术',
|
||||
};
|
||||
});
|
||||
@@ -308,13 +313,13 @@ export function TemplateRegistry() {
|
||||
特定领域分类渲染级重叠裁决权重阵列 (Painter's Algorithm Weight)
|
||||
</h3>
|
||||
<div className="space-y-2">
|
||||
{(activeTemplate.classes || []).sort((a, b) => b.zIndex - a.zIndex).map((cls) => (
|
||||
{normalizeClassMaskIds(activeTemplate.classes || []).sort((a, b) => b.zIndex - a.zIndex).map((cls) => (
|
||||
<div key={cls.id} className="grid grid-cols-4 gap-4 p-3 bg-[#0d0d0d] border border-white/5 rounded items-center">
|
||||
<div className="col-span-1 flex items-center gap-2">
|
||||
<div className="w-3 h-3 rounded" style={{ backgroundColor: cls.color }}></div>
|
||||
<span className="font-medium text-sm text-gray-300">{cls.name}</span>
|
||||
</div>
|
||||
<div className="col-span-1 font-mono text-xs text-gray-500">优先级 Z-Level: {cls.zIndex}</div>
|
||||
<div className="col-span-1 font-mono text-xs text-gray-500">maskid: {cls.maskId}</div>
|
||||
<div className="col-span-2 flex justify-end">
|
||||
<span className="bg-white/5 text-gray-400 text-xs px-2 py-1 rounded border border-white/10">{cls.category || '未分类'}</span>
|
||||
</div>
|
||||
@@ -445,7 +450,7 @@ export function TemplateRegistry() {
|
||||
>
|
||||
{cls.name}
|
||||
</span>
|
||||
<span className="w-16 text-sm text-gray-500 font-mono text-right">z:{cls.zIndex}</span>
|
||||
<span className="w-24 text-sm text-gray-500 font-mono text-right">maskid:{cls.maskId}</span>
|
||||
</>
|
||||
)}
|
||||
<button onClick={() => removeClass(cls.id)} className="text-gray-500 hover:text-red-400 transition-colors">
|
||||
|
||||
@@ -1,35 +1,91 @@
|
||||
import { fireEvent, render, screen } from '@testing-library/react';
|
||||
import { describe, expect, it, vi } from 'vitest';
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest';
|
||||
import { useStore } from '../store/useStore';
|
||||
import { resetStore } from '../test/storeTestUtils';
|
||||
import { ToolsPalette } from './ToolsPalette';
|
||||
|
||||
describe('ToolsPalette', () => {
|
||||
it('switches tools and dispatches undo/redo actions when available', () => {
|
||||
beforeEach(() => {
|
||||
resetStore();
|
||||
});
|
||||
|
||||
it('switches workspace editing tools without showing AI prompt or duplicate undo tools', () => {
|
||||
const setActiveTool = vi.fn();
|
||||
const onUndo = vi.fn();
|
||||
const onRedo = vi.fn();
|
||||
|
||||
render(
|
||||
<ToolsPalette
|
||||
activeTool="move"
|
||||
setActiveTool={setActiveTool}
|
||||
onUndo={onUndo}
|
||||
onRedo={onRedo}
|
||||
canUndo
|
||||
canRedo
|
||||
/>,
|
||||
);
|
||||
|
||||
fireEvent.click(screen.getByTitle('创建多边形 (P)'));
|
||||
fireEvent.click(screen.getByTitle('调整多边形 (E)'));
|
||||
fireEvent.click(screen.getByTitle('正向选点 (SAM)'));
|
||||
fireEvent.click(screen.getByTitle('撤销操作 (Ctrl+Z)'));
|
||||
fireEvent.click(screen.getByTitle('重做操作 (Ctrl+Shift+Z)'));
|
||||
fireEvent.click(screen.getByTitle('画笔 (B)'));
|
||||
fireEvent.click(screen.getByTitle('橡皮擦 (X)'));
|
||||
|
||||
expect(setActiveTool).toHaveBeenNthCalledWith(1, 'create_polygon');
|
||||
expect(setActiveTool).toHaveBeenNthCalledWith(2, 'edit_polygon');
|
||||
expect(setActiveTool).toHaveBeenNthCalledWith(3, 'point_pos');
|
||||
expect(onUndo).toHaveBeenCalled();
|
||||
expect(onRedo).toHaveBeenCalled();
|
||||
expect(setActiveTool).toHaveBeenNthCalledWith(3, 'brush');
|
||||
expect(setActiveTool).toHaveBeenNthCalledWith(4, 'eraser');
|
||||
expect(screen.queryByTitle('正向选点 (SAM)')).not.toBeInTheDocument();
|
||||
expect(screen.queryByTitle('反向选点 (SAM)')).not.toBeInTheDocument();
|
||||
expect(screen.queryByTitle('边界框选 (SAM)')).not.toBeInTheDocument();
|
||||
expect(screen.queryByTitle('撤销操作 (Ctrl+Z)')).not.toBeInTheDocument();
|
||||
expect(screen.queryByTitle('重做操作 (Ctrl+Shift+Z)')).not.toBeInTheDocument();
|
||||
expect(screen.queryByTitle('创建点 (C)')).not.toBeInTheDocument();
|
||||
expect(screen.queryByTitle('创建线段 (L)')).not.toBeInTheDocument();
|
||||
});
|
||||
|
||||
it('shows size controls for brush and eraser tools', () => {
|
||||
const { rerender } = render(<ToolsPalette activeTool="brush" setActiveTool={vi.fn()} />);
|
||||
const brushSize = screen.getByLabelText('画笔大小');
|
||||
fireEvent.change(brushSize, { target: { value: '36' } });
|
||||
expect(useStore.getState().brushSize).toBe(36);
|
||||
|
||||
rerender(<ToolsPalette activeTool="eraser" setActiveTool={vi.fn()} />);
|
||||
const eraserSize = screen.getByLabelText('橡皮擦大小');
|
||||
fireEvent.change(eraserSize, { target: { value: '48' } });
|
||||
expect(useStore.getState().eraserSize).toBe(48);
|
||||
});
|
||||
|
||||
it('places GT mask import after overlap removal with a distinct violet style', () => {
|
||||
const onImportGtMask = vi.fn();
|
||||
render(
|
||||
<ToolsPalette
|
||||
activeTool="move"
|
||||
setActiveTool={vi.fn()}
|
||||
onImportGtMask={onImportGtMask}
|
||||
canImportGtMask
|
||||
/>,
|
||||
);
|
||||
|
||||
const overlapButton = screen.getByTitle('重叠区域去除 (-)');
|
||||
const importButton = screen.getByTitle('导入 GT Mask');
|
||||
fireEvent.click(importButton);
|
||||
|
||||
expect(onImportGtMask).toHaveBeenCalled();
|
||||
expect(importButton).toHaveClass('bg-violet-500/10');
|
||||
expect(overlapButton.compareDocumentPosition(importButton) & Node.DOCUMENT_POSITION_FOLLOWING).toBeTruthy();
|
||||
});
|
||||
|
||||
it('separates drawing, editing, and external action tool groups', () => {
|
||||
render(<ToolsPalette activeTool="move" setActiveTool={vi.fn()} canImportGtMask />);
|
||||
|
||||
const separators = screen.getAllByTestId('tool-group-separator');
|
||||
const circleButton = screen.getByTitle('创建圆 (O)');
|
||||
const brushButton = screen.getByTitle('画笔 (B)');
|
||||
const removeButton = screen.getByTitle('重叠区域去除 (-)');
|
||||
const importButton = screen.getByTitle('导入 GT Mask');
|
||||
|
||||
expect(separators).toHaveLength(2);
|
||||
expect(circleButton.compareDocumentPosition(separators[0]) & Node.DOCUMENT_POSITION_FOLLOWING).toBeTruthy();
|
||||
expect(separators[0].compareDocumentPosition(brushButton) & Node.DOCUMENT_POSITION_FOLLOWING).toBeTruthy();
|
||||
expect(removeButton.compareDocumentPosition(separators[1]) & Node.DOCUMENT_POSITION_FOLLOWING).toBeTruthy();
|
||||
expect(separators[1].compareDocumentPosition(importButton) & Node.DOCUMENT_POSITION_FOLLOWING).toBeTruthy();
|
||||
separators.forEach((separator) => {
|
||||
expect(separator).toHaveClass('bg-white/15');
|
||||
});
|
||||
});
|
||||
|
||||
it('switches to SAM trigger and calls the AI navigation hook', () => {
|
||||
@@ -37,7 +93,9 @@ describe('ToolsPalette', () => {
|
||||
const onTriggerAI = vi.fn();
|
||||
|
||||
render(<ToolsPalette activeTool="move" setActiveTool={setActiveTool} onTriggerAI={onTriggerAI} />);
|
||||
fireEvent.click(screen.getByTitle('打开 AI 智能分割'));
|
||||
const aiButton = screen.getByTitle('打开 AI 智能分割');
|
||||
expect(aiButton.querySelector('[data-testid="ai-segmentation-icon"]')).toBeInTheDocument();
|
||||
fireEvent.click(aiButton);
|
||||
|
||||
expect(setActiveTool).toHaveBeenCalledWith('sam_trigger');
|
||||
expect(onTriggerAI).toHaveBeenCalled();
|
||||
|
||||
@@ -1,44 +1,59 @@
|
||||
import React from 'react';
|
||||
import { MousePointer2, PencilLine, Hexagon, Square, Circle, Minus, Combine, Scissors, Wand2, Undo, Redo, Crosshair, PlusCircle, MinusCircle, SquareDashed } from 'lucide-react';
|
||||
import { MousePointer2, PencilLine, Hexagon, Square, Circle, Brush, Eraser, Combine, Scissors, FileUp } from 'lucide-react';
|
||||
import { cn } from '../lib/utils';
|
||||
import { AiSegmentationIcon } from './AiSegmentationIcon';
|
||||
import { useStore } from '../store/useStore';
|
||||
|
||||
interface ToolsPaletteProps {
|
||||
activeTool: string;
|
||||
setActiveTool: (tool: string) => void;
|
||||
onTriggerAI?: () => void;
|
||||
onUndo?: () => void;
|
||||
onRedo?: () => void;
|
||||
canUndo?: boolean;
|
||||
canRedo?: boolean;
|
||||
onImportGtMask?: () => void;
|
||||
canImportGtMask?: boolean;
|
||||
isImportingGtMask?: boolean;
|
||||
}
|
||||
|
||||
export function ToolsPalette({
|
||||
activeTool,
|
||||
setActiveTool,
|
||||
onTriggerAI,
|
||||
onUndo,
|
||||
onRedo,
|
||||
canUndo = false,
|
||||
canRedo = false,
|
||||
onImportGtMask,
|
||||
canImportGtMask = false,
|
||||
isImportingGtMask = false,
|
||||
}: ToolsPaletteProps) {
|
||||
const brushSize = useStore((state) => state.brushSize);
|
||||
const eraserSize = useStore((state) => state.eraserSize);
|
||||
const setBrushSize = useStore((state) => state.setBrushSize);
|
||||
const setEraserSize = useStore((state) => state.setEraserSize);
|
||||
const sizeControl = activeTool === 'brush'
|
||||
? {
|
||||
label: '画笔大小',
|
||||
value: brushSize,
|
||||
min: 4,
|
||||
max: 96,
|
||||
onChange: setBrushSize,
|
||||
}
|
||||
: activeTool === 'eraser'
|
||||
? {
|
||||
label: '橡皮擦大小',
|
||||
value: eraserSize,
|
||||
min: 4,
|
||||
max: 128,
|
||||
onChange: setEraserSize,
|
||||
}
|
||||
: null;
|
||||
const tools = [
|
||||
{ id: 'move', icon: MousePointer2, label: '拖拽 / 选择 (V)' },
|
||||
{ id: 'edit_polygon', icon: PencilLine, label: '调整多边形 (E)' },
|
||||
{ id: 'create_polygon', icon: Hexagon, label: '创建多边形 (P)' },
|
||||
{ id: 'create_rectangle', icon: Square, label: '创建矩形 (R)' },
|
||||
{ id: 'create_circle', icon: Circle, label: '创建圆 (O)' },
|
||||
{ id: 'create_point', icon: Crosshair, label: '创建点 (C)' },
|
||||
{ id: 'create_line', icon: Minus, label: '创建线段 (L)' },
|
||||
{ id: 'brush', icon: Brush, label: '画笔 (B)' },
|
||||
{ id: 'eraser', icon: Eraser, label: '橡皮擦 (X)' },
|
||||
{ id: 'area_merge', icon: Combine, label: '区域合并 (+)' },
|
||||
{ id: 'area_remove', icon: Scissors, label: '重叠区域去除 (-)' },
|
||||
];
|
||||
|
||||
const aiTools = [
|
||||
{ id: 'point_pos', icon: PlusCircle, label: '正向选点 (SAM)', color: 'text-green-400', bg: 'bg-green-500/10', border: 'border-green-500/30' },
|
||||
{ id: 'point_neg', icon: MinusCircle, label: '反向选点 (SAM)', color: 'text-red-400', bg: 'bg-red-500/10', border: 'border-red-500/30' },
|
||||
{ id: 'box_select', icon: SquareDashed, label: '边界框选 (SAM)', color: 'text-blue-400', bg: 'bg-blue-500/10', border: 'border-blue-500/30' },
|
||||
];
|
||||
|
||||
return (
|
||||
<div className="h-full w-14 bg-[#0d0d0d] border-r border-white/5 flex flex-col items-start py-2 shrink-0 z-10 overflow-y-auto overflow-x-hidden overscroll-contain seg-scrollbar">
|
||||
<div className="flex flex-col gap-1.5 w-12 shrink-0 px-1.5">
|
||||
@@ -46,45 +61,52 @@ export function ToolsPalette({
|
||||
const Icon = tool.icon;
|
||||
const isActive = activeTool === tool.id;
|
||||
return (
|
||||
<button
|
||||
key={tool.id}
|
||||
onClick={() => setActiveTool(tool.id)}
|
||||
title={tool.label}
|
||||
className={cn(
|
||||
"w-9 h-9 rounded-md flex items-center justify-center transition-all p-1.5",
|
||||
isActive
|
||||
? (tool.id.includes('remove') ? "bg-red-500/10 text-red-500"
|
||||
: tool.id.includes('merge') ? "bg-green-500/10 text-green-500"
|
||||
: "bg-white/10 text-white")
|
||||
: "text-gray-500 hover:bg-white/5 hover:text-white"
|
||||
<React.Fragment key={tool.id}>
|
||||
<button
|
||||
onClick={() => setActiveTool(tool.id)}
|
||||
title={tool.label}
|
||||
className={cn(
|
||||
"w-9 h-9 rounded-md flex items-center justify-center transition-all p-1.5",
|
||||
isActive
|
||||
? (tool.id.includes('remove') ? "bg-red-500/10 text-red-500"
|
||||
: tool.id.includes('merge') ? "bg-green-500/10 text-green-500"
|
||||
: "bg-white/10 text-white")
|
||||
: "text-gray-500 hover:bg-white/5 hover:text-white"
|
||||
)}
|
||||
>
|
||||
<Icon size={16} strokeWidth={isActive ? 2.5 : 2} />
|
||||
</button>
|
||||
{tool.id === 'eraser' && sizeControl && (
|
||||
<div className="w-9 rounded-md border border-white/10 bg-white/[0.03] px-1 py-2 text-center">
|
||||
<label htmlFor={`${activeTool}-size`} className="sr-only">{sizeControl.label}</label>
|
||||
<input
|
||||
id={`${activeTool}-size`}
|
||||
aria-label={sizeControl.label}
|
||||
type="range"
|
||||
min={sizeControl.min}
|
||||
max={sizeControl.max}
|
||||
value={sizeControl.value}
|
||||
onChange={(event) => sizeControl.onChange(Number(event.target.value))}
|
||||
className="h-20 w-7 accent-cyan-400 [writing-mode:vertical-rl]"
|
||||
/>
|
||||
<div className="mt-1 text-[10px] leading-none text-gray-400">{sizeControl.value}</div>
|
||||
</div>
|
||||
)}
|
||||
>
|
||||
<Icon size={16} strokeWidth={isActive ? 2.5 : 2} />
|
||||
</button>
|
||||
{(tool.id === 'create_circle' || tool.id === 'area_remove') && (
|
||||
<div data-testid="tool-group-separator" className="my-1 h-px w-9 bg-white/15" />
|
||||
)}
|
||||
</React.Fragment>
|
||||
)
|
||||
})}
|
||||
|
||||
<div className="w-full h-px bg-white/10 my-0.5" />
|
||||
|
||||
{aiTools.map(tool => {
|
||||
const Icon = tool.icon;
|
||||
const isActive = activeTool === tool.id;
|
||||
return (
|
||||
<button
|
||||
key={tool.id}
|
||||
onClick={() => setActiveTool(tool.id)}
|
||||
title={tool.label}
|
||||
className={cn(
|
||||
"w-9 h-9 rounded-md flex items-center justify-center transition-all p-1.5 border",
|
||||
isActive
|
||||
? `${tool.bg} ${tool.color} ${tool.border} shadow-[0_0_10px_rgba(255,255,255,0.05)]`
|
||||
: "text-gray-500 hover:bg-white/5 hover:text-white border-transparent"
|
||||
)}
|
||||
>
|
||||
<Icon size={16} strokeWidth={isActive ? 2.5 : 2} />
|
||||
</button>
|
||||
)
|
||||
})}
|
||||
<button
|
||||
onClick={onImportGtMask}
|
||||
disabled={!canImportGtMask || isImportingGtMask}
|
||||
title={isImportingGtMask ? '正在导入 GT Mask' : '导入 GT Mask'}
|
||||
className="w-9 h-9 rounded-md flex items-center justify-center transition-all p-1.5 border border-violet-500/30 bg-violet-500/10 text-violet-200 hover:bg-violet-500/20 hover:text-white disabled:opacity-35 disabled:hover:bg-violet-500/10 disabled:hover:text-violet-200 disabled:cursor-not-allowed"
|
||||
>
|
||||
<FileUp size={16} strokeWidth={2.2} />
|
||||
</button>
|
||||
|
||||
<button
|
||||
onClick={() => {
|
||||
@@ -99,26 +121,7 @@ export function ToolsPalette({
|
||||
: "text-gray-500 hover:bg-white/5"
|
||||
)}
|
||||
>
|
||||
<Wand2 size={16} strokeWidth={2} />
|
||||
</button>
|
||||
|
||||
<div className="w-full h-px bg-white/10 my-0.5" />
|
||||
|
||||
<button
|
||||
onClick={onUndo}
|
||||
disabled={!canUndo}
|
||||
className="w-9 h-9 rounded text-gray-500 hover:bg-white/5 hover:text-white flex items-center justify-center transition-colors disabled:opacity-30 disabled:hover:bg-transparent disabled:hover:text-gray-500 disabled:cursor-not-allowed"
|
||||
title="撤销操作 (Ctrl+Z)"
|
||||
>
|
||||
<Undo size={16} />
|
||||
</button>
|
||||
<button
|
||||
onClick={onRedo}
|
||||
disabled={!canRedo}
|
||||
className="w-9 h-9 rounded text-gray-500 hover:bg-white/5 hover:text-white flex items-center justify-center transition-colors disabled:opacity-30 disabled:hover:bg-transparent disabled:hover:text-gray-500 disabled:cursor-not-allowed"
|
||||
title="重做操作 (Ctrl+Shift+Z)"
|
||||
>
|
||||
<Redo size={16} />
|
||||
<AiSegmentationIcon size={17} strokeWidth={2.1} />
|
||||
</button>
|
||||
|
||||
</div>
|
||||
|
||||
154
src/components/UserAdmin.test.tsx
Normal file
154
src/components/UserAdmin.test.tsx
Normal file
@@ -0,0 +1,154 @@
|
||||
import { fireEvent, render, screen, waitFor } from '@testing-library/react';
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest';
|
||||
import { resetStore } from '../test/storeTestUtils';
|
||||
import { useStore } from '../store/useStore';
|
||||
import { UserAdmin } from './UserAdmin';
|
||||
|
||||
const apiMock = vi.hoisted(() => ({
|
||||
getAdminUsers: vi.fn(),
|
||||
getAuditLogs: vi.fn(),
|
||||
createAdminUser: vi.fn(),
|
||||
updateAdminUser: vi.fn(),
|
||||
deleteAdminUser: vi.fn(),
|
||||
resetDemoFactory: vi.fn(),
|
||||
}));
|
||||
|
||||
vi.mock('../lib/api', () => ({
|
||||
getAdminUsers: apiMock.getAdminUsers,
|
||||
getAuditLogs: apiMock.getAuditLogs,
|
||||
createAdminUser: apiMock.createAdminUser,
|
||||
updateAdminUser: apiMock.updateAdminUser,
|
||||
deleteAdminUser: apiMock.deleteAdminUser,
|
||||
resetDemoFactory: apiMock.resetDemoFactory,
|
||||
}));
|
||||
|
||||
describe('UserAdmin', () => {
|
||||
beforeEach(() => {
|
||||
resetStore();
|
||||
vi.clearAllMocks();
|
||||
useStore.setState({ currentUser: { id: 1, username: 'admin', role: 'admin' } });
|
||||
apiMock.getAdminUsers.mockResolvedValue([
|
||||
{ id: 1, username: 'admin', role: 'admin', is_active: 1 },
|
||||
{ id: 2, username: 'doctor', role: 'annotator', is_active: 1 },
|
||||
]);
|
||||
apiMock.getAuditLogs.mockResolvedValue([
|
||||
{
|
||||
id: 1,
|
||||
actor_user_id: 1,
|
||||
action: 'admin.user_created',
|
||||
target_type: 'user',
|
||||
target_id: '2',
|
||||
detail: { username: 'doctor' },
|
||||
created_at: '2026-05-02T00:00:00Z',
|
||||
},
|
||||
]);
|
||||
});
|
||||
|
||||
it('loads users and audit logs', async () => {
|
||||
render(<UserAdmin />);
|
||||
|
||||
expect(await screen.findByText('doctor')).toBeInTheDocument();
|
||||
expect(screen.getByText('admin.user_created')).toBeInTheDocument();
|
||||
expect(screen.getByText('当前管理员:admin')).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it('creates a user with role and password', async () => {
|
||||
apiMock.createAdminUser.mockResolvedValueOnce({
|
||||
id: 3,
|
||||
username: 'nurse',
|
||||
role: 'viewer',
|
||||
is_active: 1,
|
||||
});
|
||||
|
||||
render(<UserAdmin />);
|
||||
await screen.findByText('doctor');
|
||||
fireEvent.change(screen.getByPlaceholderText('用户名'), { target: { value: 'nurse' } });
|
||||
fireEvent.change(screen.getByPlaceholderText('初始密码'), { target: { value: 'secret123' } });
|
||||
fireEvent.change(screen.getAllByDisplayValue('标注员')[0], { target: { value: 'viewer' } });
|
||||
fireEvent.click(screen.getByRole('button', { name: /新增用户/ }));
|
||||
|
||||
await waitFor(() => expect(apiMock.createAdminUser).toHaveBeenCalledWith({
|
||||
username: 'nurse',
|
||||
password: 'secret123',
|
||||
role: 'viewer',
|
||||
is_active: true,
|
||||
}));
|
||||
expect(await screen.findByText('用户已创建')).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it('updates role, status and password, and deletes users', async () => {
|
||||
apiMock.updateAdminUser.mockResolvedValueOnce({ id: 2, username: 'doctor', role: 'viewer', is_active: 1 });
|
||||
apiMock.updateAdminUser.mockResolvedValueOnce({ id: 2, username: 'doctor', role: 'viewer', is_active: 0 });
|
||||
apiMock.updateAdminUser.mockResolvedValueOnce({ id: 2, username: 'doctor', role: 'viewer', is_active: 0 });
|
||||
apiMock.deleteAdminUser.mockResolvedValueOnce(undefined);
|
||||
vi.spyOn(window, 'prompt').mockReturnValueOnce('nextsecret');
|
||||
vi.spyOn(window, 'confirm').mockReturnValueOnce(true);
|
||||
|
||||
render(<UserAdmin />);
|
||||
await screen.findByText('doctor');
|
||||
|
||||
const roleSelects = screen.getAllByDisplayValue('标注员');
|
||||
fireEvent.change(roleSelects[1], { target: { value: 'viewer' } });
|
||||
await waitFor(() => expect(apiMock.updateAdminUser).toHaveBeenCalledWith(2, { role: 'viewer' }));
|
||||
|
||||
fireEvent.click(screen.getAllByRole('button', { name: '启用' })[1]);
|
||||
await waitFor(() => expect(apiMock.updateAdminUser).toHaveBeenCalledWith(2, { is_active: false }));
|
||||
|
||||
fireEvent.click(screen.getAllByTitle('修改密码')[1]);
|
||||
await waitFor(() => expect(apiMock.updateAdminUser).toHaveBeenCalledWith(2, { password: 'nextsecret' }));
|
||||
|
||||
fireEvent.click(screen.getAllByTitle('删除用户')[1]);
|
||||
await waitFor(() => expect(apiMock.deleteAdminUser).toHaveBeenCalledWith(2));
|
||||
});
|
||||
|
||||
it('requires two confirmations before resetting demo factory data', async () => {
|
||||
apiMock.resetDemoFactory.mockResolvedValueOnce({
|
||||
admin_user: { id: 1, username: 'admin', role: 'admin', is_active: 1 },
|
||||
project: {
|
||||
id: '8',
|
||||
name: 'Data_MyVideo_1',
|
||||
status: 'pending',
|
||||
frames: 0,
|
||||
fps: '30FPS',
|
||||
video_path: 'uploads/8/Data_MyVideo_1.mp4',
|
||||
},
|
||||
deleted_counts: { users: 1 },
|
||||
message: '演示环境已恢复出厂设置',
|
||||
});
|
||||
apiMock.getAuditLogs.mockResolvedValueOnce([
|
||||
{
|
||||
id: 2,
|
||||
actor_user_id: 1,
|
||||
action: 'admin.demo_factory_reset',
|
||||
target_type: 'project',
|
||||
target_id: '8',
|
||||
detail: {},
|
||||
created_at: '2026-05-02T00:00:00Z',
|
||||
},
|
||||
]);
|
||||
vi.spyOn(window, 'confirm').mockReturnValueOnce(true);
|
||||
vi.spyOn(window, 'prompt').mockReturnValueOnce('RESET_DEMO_FACTORY');
|
||||
|
||||
render(<UserAdmin />);
|
||||
await screen.findByText('doctor');
|
||||
fireEvent.click(screen.getByRole('button', { name: '恢复演示出厂设置' }));
|
||||
|
||||
await waitFor(() => expect(apiMock.resetDemoFactory).toHaveBeenCalledWith('RESET_DEMO_FACTORY'));
|
||||
expect(await screen.findByText('演示环境已恢复出厂设置')).toBeInTheDocument();
|
||||
expect(useStore.getState().projects).toEqual([expect.objectContaining({ name: 'Data_MyVideo_1' })]);
|
||||
expect(useStore.getState().frames).toEqual([]);
|
||||
expect(useStore.getState().masks).toEqual([]);
|
||||
});
|
||||
|
||||
it('does not reset demo data when confirmation text does not match', async () => {
|
||||
vi.spyOn(window, 'confirm').mockReturnValueOnce(true);
|
||||
vi.spyOn(window, 'prompt').mockReturnValueOnce('wrong');
|
||||
|
||||
render(<UserAdmin />);
|
||||
await screen.findByText('doctor');
|
||||
fireEvent.click(screen.getByRole('button', { name: '恢复演示出厂设置' }));
|
||||
|
||||
expect(apiMock.resetDemoFactory).not.toHaveBeenCalled();
|
||||
expect(await screen.findByText('确认文本不匹配,未执行恢复出厂设置')).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
352
src/components/UserAdmin.tsx
Normal file
352
src/components/UserAdmin.tsx
Normal file
@@ -0,0 +1,352 @@
|
||||
import React, { useEffect, useMemo, useState } from 'react';
|
||||
import { KeyRound, Loader2, Plus, ShieldCheck, Trash2, UserCog } from 'lucide-react';
|
||||
import {
|
||||
createAdminUser,
|
||||
deleteAdminUser,
|
||||
getAdminUsers,
|
||||
getAuditLogs,
|
||||
resetDemoFactory,
|
||||
updateAdminUser,
|
||||
type AdminUser,
|
||||
type AuditLog,
|
||||
} from '../lib/api';
|
||||
import { cn } from '../lib/utils';
|
||||
import { useStore } from '../store/useStore';
|
||||
import { TransientNotice, type NoticeState, type NoticeTone } from './TransientNotice';
|
||||
|
||||
const roleLabels: Record<string, string> = {
|
||||
admin: '管理员',
|
||||
annotator: '标注员',
|
||||
viewer: '观察员',
|
||||
};
|
||||
|
||||
function formatTime(value: string): string {
|
||||
const date = new Date(value);
|
||||
if (Number.isNaN(date.getTime())) return value;
|
||||
return date.toLocaleString('zh-CN', {
|
||||
month: '2-digit',
|
||||
day: '2-digit',
|
||||
hour: '2-digit',
|
||||
minute: '2-digit',
|
||||
});
|
||||
}
|
||||
|
||||
export function UserAdmin() {
|
||||
const currentUser = useStore((state) => state.currentUser);
|
||||
const setProjects = useStore((state) => state.setProjects);
|
||||
const setCurrentProject = useStore((state) => state.setCurrentProject);
|
||||
const setFrames = useStore((state) => state.setFrames);
|
||||
const setCurrentFrame = useStore((state) => state.setCurrentFrame);
|
||||
const setMasks = useStore((state) => state.setMasks);
|
||||
const setSelectedMaskIds = useStore((state) => state.setSelectedMaskIds);
|
||||
const [users, setUsers] = useState<AdminUser[]>([]);
|
||||
const [auditLogs, setAuditLogs] = useState<AuditLog[]>([]);
|
||||
const [isLoading, setIsLoading] = useState(true);
|
||||
const [isSaving, setIsSaving] = useState(false);
|
||||
const [isResetting, setIsResetting] = useState(false);
|
||||
const [notice, setNotice] = useState<NoticeState | null>(null);
|
||||
const [newUsername, setNewUsername] = useState('');
|
||||
const [newPassword, setNewPassword] = useState('');
|
||||
const [newRole, setNewRole] = useState('annotator');
|
||||
|
||||
const activeCount = useMemo(() => users.filter((user) => user.is_active).length, [users]);
|
||||
const showNotice = (message: string, tone: NoticeTone = 'info') => {
|
||||
setNotice({ id: Date.now(), message, tone });
|
||||
};
|
||||
|
||||
const loadAdminData = async () => {
|
||||
setIsLoading(true);
|
||||
try {
|
||||
const [nextUsers, nextLogs] = await Promise.all([getAdminUsers(), getAuditLogs(100)]);
|
||||
setUsers(nextUsers);
|
||||
setAuditLogs(nextLogs);
|
||||
} catch (err) {
|
||||
console.error('Failed to load admin data:', err);
|
||||
showNotice('用户管理数据加载失败', 'error');
|
||||
} finally {
|
||||
setIsLoading(false);
|
||||
}
|
||||
};
|
||||
|
||||
useEffect(() => {
|
||||
void loadAdminData();
|
||||
}, []);
|
||||
|
||||
const handleCreateUser = async (event: React.FormEvent) => {
|
||||
event.preventDefault();
|
||||
if (!newUsername.trim() || newPassword.length < 6) {
|
||||
showNotice('请输入用户名,并设置至少 6 位密码', 'error');
|
||||
return;
|
||||
}
|
||||
setIsSaving(true);
|
||||
try {
|
||||
const created = await createAdminUser({
|
||||
username: newUsername.trim(),
|
||||
password: newPassword,
|
||||
role: newRole,
|
||||
is_active: true,
|
||||
});
|
||||
setUsers((prev) => [...prev, created]);
|
||||
setNewUsername('');
|
||||
setNewPassword('');
|
||||
setNewRole('annotator');
|
||||
showNotice('用户已创建', 'success');
|
||||
setAuditLogs(await getAuditLogs(100));
|
||||
} catch (err: any) {
|
||||
showNotice(err?.response?.data?.detail || '创建用户失败', 'error');
|
||||
} finally {
|
||||
setIsSaving(false);
|
||||
}
|
||||
};
|
||||
|
||||
const handlePatchUser = async (user: AdminUser, patch: Parameters<typeof updateAdminUser>[1]) => {
|
||||
setIsSaving(true);
|
||||
try {
|
||||
const updated = await updateAdminUser(user.id, patch);
|
||||
setUsers((prev) => prev.map((item) => (item.id === user.id ? updated : item)));
|
||||
showNotice('用户已更新', 'success');
|
||||
setAuditLogs(await getAuditLogs(100));
|
||||
} catch (err: any) {
|
||||
showNotice(err?.response?.data?.detail || '更新用户失败', 'error');
|
||||
} finally {
|
||||
setIsSaving(false);
|
||||
}
|
||||
};
|
||||
|
||||
const handleChangePassword = async (user: AdminUser) => {
|
||||
const password = window.prompt(`为 ${user.username} 设置新密码(至少 6 位)`);
|
||||
if (password === null) return;
|
||||
await handlePatchUser(user, { password });
|
||||
};
|
||||
|
||||
const handleDeleteUser = async (user: AdminUser) => {
|
||||
if (!window.confirm(`确定删除用户 ${user.username} 吗?已有项目的用户建议先停用。`)) return;
|
||||
setIsSaving(true);
|
||||
try {
|
||||
await deleteAdminUser(user.id);
|
||||
setUsers((prev) => prev.filter((item) => item.id !== user.id));
|
||||
showNotice('用户已删除', 'success');
|
||||
setAuditLogs(await getAuditLogs(100));
|
||||
} catch (err: any) {
|
||||
showNotice(err?.response?.data?.detail || '删除用户失败', 'error');
|
||||
} finally {
|
||||
setIsSaving(false);
|
||||
}
|
||||
};
|
||||
|
||||
const handleFactoryReset = async () => {
|
||||
const firstConfirmed = window.confirm(
|
||||
'恢复演示出厂设置会删除除默认 admin 外的所有用户、项目帧、标注、任务和私有模板,只保留一个未生成帧的演示视频项目。确定继续吗?',
|
||||
);
|
||||
if (!firstConfirmed) return;
|
||||
const typed = window.prompt('请输入 RESET_DEMO_FACTORY 以确认恢复演示出厂设置');
|
||||
if (typed === null) return;
|
||||
if (typed !== 'RESET_DEMO_FACTORY') {
|
||||
showNotice('确认文本不匹配,未执行恢复出厂设置', 'error');
|
||||
return;
|
||||
}
|
||||
setIsResetting(true);
|
||||
try {
|
||||
const result = await resetDemoFactory(typed);
|
||||
setUsers([result.admin_user]);
|
||||
setProjects([result.project]);
|
||||
setCurrentProject(null);
|
||||
setFrames([]);
|
||||
setCurrentFrame(0);
|
||||
setMasks([]);
|
||||
setSelectedMaskIds([]);
|
||||
setAuditLogs(await getAuditLogs(100));
|
||||
showNotice(result.message || '演示环境已恢复出厂设置', 'success');
|
||||
} catch (err: any) {
|
||||
showNotice(err?.response?.data?.detail || '恢复演示出厂设置失败', 'error');
|
||||
} finally {
|
||||
setIsResetting(false);
|
||||
}
|
||||
};
|
||||
|
||||
return (
|
||||
<div className="flex h-full flex-col overflow-hidden bg-[#0a0a0a] text-gray-200">
|
||||
<TransientNotice notice={notice} onDismiss={() => setNotice(null)} />
|
||||
<header className="border-b border-white/10 bg-[#0d0d0d] px-6 py-4">
|
||||
<div className="flex items-center justify-between gap-4">
|
||||
<div>
|
||||
<h1 className="text-xl font-semibold text-white">用户管理后台</h1>
|
||||
<p className="mt-1 text-xs text-gray-500">账号、角色、状态和安全审计</p>
|
||||
</div>
|
||||
<div className="flex items-center gap-3 text-xs text-gray-400">
|
||||
<span className="rounded border border-cyan-400/20 bg-cyan-400/10 px-3 py-1 text-cyan-100">
|
||||
当前管理员:{currentUser?.username || 'admin'}
|
||||
</span>
|
||||
<span className="rounded border border-white/10 bg-white/5 px-3 py-1">启用用户 {activeCount}</span>
|
||||
</div>
|
||||
</div>
|
||||
</header>
|
||||
|
||||
<main className="grid min-h-0 flex-1 grid-cols-[minmax(0,1.15fr)_minmax(360px,0.85fr)] gap-4 overflow-hidden p-4">
|
||||
<section className="flex min-h-0 flex-col overflow-hidden rounded-lg border border-white/10 bg-[#111]">
|
||||
<div className="flex items-center justify-between border-b border-white/10 px-4 py-3">
|
||||
<div className="flex items-center gap-2 text-sm font-medium text-white">
|
||||
<UserCog size={18} className="text-cyan-300" />
|
||||
用户与权限
|
||||
</div>
|
||||
{isLoading && <Loader2 size={16} className="animate-spin text-cyan-300" />}
|
||||
</div>
|
||||
|
||||
<form onSubmit={handleCreateUser} className="grid grid-cols-[1fr_1fr_150px_auto] gap-2 border-b border-white/10 p-4">
|
||||
<input
|
||||
value={newUsername}
|
||||
onChange={(event) => setNewUsername(event.target.value)}
|
||||
placeholder="用户名"
|
||||
autoComplete="off"
|
||||
className="rounded border border-white/10 bg-[#181818] px-3 py-2 text-sm text-white outline-none focus:border-cyan-400/50"
|
||||
/>
|
||||
<input
|
||||
value={newPassword}
|
||||
type="password"
|
||||
onChange={(event) => setNewPassword(event.target.value)}
|
||||
placeholder="初始密码"
|
||||
autoComplete="new-password"
|
||||
className="rounded border border-white/10 bg-[#181818] px-3 py-2 text-sm text-white outline-none focus:border-cyan-400/50"
|
||||
/>
|
||||
<select
|
||||
value={newRole}
|
||||
onChange={(event) => setNewRole(event.target.value)}
|
||||
className="rounded border border-white/10 bg-[#181818] px-3 py-2 text-sm text-white outline-none focus:border-cyan-400/50"
|
||||
>
|
||||
<option value="annotator">标注员</option>
|
||||
<option value="viewer">观察员</option>
|
||||
<option value="admin">管理员</option>
|
||||
</select>
|
||||
<button
|
||||
type="submit"
|
||||
disabled={isSaving}
|
||||
className="inline-flex items-center gap-2 rounded bg-cyan-500 px-4 py-2 text-sm font-semibold text-black transition-colors hover:bg-cyan-400 disabled:opacity-50"
|
||||
>
|
||||
<Plus size={16} />
|
||||
新增用户
|
||||
</button>
|
||||
</form>
|
||||
|
||||
<div className="min-h-0 flex-1 overflow-auto">
|
||||
<table className="w-full text-left text-sm">
|
||||
<thead className="sticky top-0 bg-[#151515] text-xs uppercase text-gray-500">
|
||||
<tr>
|
||||
<th className="px-4 py-3">用户</th>
|
||||
<th className="px-4 py-3">角色</th>
|
||||
<th className="px-4 py-3">状态</th>
|
||||
<th className="px-4 py-3 text-right">操作</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody className="divide-y divide-white/5">
|
||||
{users.map((user) => (
|
||||
<tr key={user.id} className="hover:bg-white/[0.03]">
|
||||
<td className="px-4 py-3">
|
||||
<div className="font-medium text-white">{user.username}</div>
|
||||
<div className="text-xs text-gray-500">ID {user.id}</div>
|
||||
</td>
|
||||
<td className="px-4 py-3">
|
||||
<select
|
||||
value={user.role}
|
||||
onChange={(event) => void handlePatchUser(user, { role: event.target.value })}
|
||||
disabled={isSaving}
|
||||
className="rounded border border-white/10 bg-[#181818] px-2 py-1 text-xs text-cyan-100"
|
||||
>
|
||||
<option value="admin">管理员</option>
|
||||
<option value="annotator">标注员</option>
|
||||
<option value="viewer">观察员</option>
|
||||
</select>
|
||||
</td>
|
||||
<td className="px-4 py-3">
|
||||
<button
|
||||
type="button"
|
||||
onClick={() => void handlePatchUser(user, { is_active: !user.is_active })}
|
||||
disabled={isSaving}
|
||||
className={cn(
|
||||
'rounded-full border px-3 py-1 text-xs',
|
||||
user.is_active
|
||||
? 'border-emerald-400/30 bg-emerald-400/10 text-emerald-200'
|
||||
: 'border-gray-500/30 bg-gray-500/10 text-gray-300',
|
||||
)}
|
||||
>
|
||||
{user.is_active ? '启用' : '停用'}
|
||||
</button>
|
||||
</td>
|
||||
<td className="px-4 py-3">
|
||||
<div className="flex justify-end gap-2">
|
||||
<button
|
||||
type="button"
|
||||
onClick={() => void handleChangePassword(user)}
|
||||
className="rounded border border-white/10 p-2 text-gray-300 hover:border-cyan-400/40 hover:text-cyan-200"
|
||||
title="修改密码"
|
||||
>
|
||||
<KeyRound size={15} />
|
||||
</button>
|
||||
<button
|
||||
type="button"
|
||||
onClick={() => void handleDeleteUser(user)}
|
||||
disabled={user.id === currentUser?.id}
|
||||
className="rounded border border-white/10 p-2 text-gray-300 hover:border-red-400/40 hover:text-red-200 disabled:cursor-not-allowed disabled:opacity-40"
|
||||
title="删除用户"
|
||||
>
|
||||
<Trash2 size={15} />
|
||||
</button>
|
||||
</div>
|
||||
</td>
|
||||
</tr>
|
||||
))}
|
||||
</tbody>
|
||||
</table>
|
||||
</div>
|
||||
</section>
|
||||
|
||||
<section className="flex min-h-0 flex-col overflow-hidden rounded-lg border border-white/10 bg-[#111]">
|
||||
<div className="flex items-center gap-2 border-b border-white/10 px-4 py-3 text-sm font-medium text-white">
|
||||
<ShieldCheck size={18} className="text-emerald-300" />
|
||||
审计日志
|
||||
</div>
|
||||
<div className="min-h-0 flex-1 overflow-auto p-3">
|
||||
<div className="space-y-2">
|
||||
{auditLogs.map((log) => (
|
||||
<div key={log.id} className="rounded border border-white/10 bg-black/20 p-3">
|
||||
<div className="flex items-center justify-between gap-2">
|
||||
<span className="text-xs font-medium text-cyan-100">{log.action}</span>
|
||||
<span className="text-[10px] text-gray-500">{formatTime(log.created_at)}</span>
|
||||
</div>
|
||||
<div className="mt-1 text-[11px] text-gray-400">
|
||||
actor #{log.actor_user_id ?? 'system'} {'->'} {log.target_type || 'target'} #{log.target_id || '-'}
|
||||
</div>
|
||||
{log.detail && Object.keys(log.detail).length > 0 && (
|
||||
<pre className="mt-2 max-h-24 overflow-auto rounded bg-black/30 p-2 text-[10px] leading-relaxed text-gray-500">
|
||||
{JSON.stringify(log.detail, null, 2)}
|
||||
</pre>
|
||||
)}
|
||||
</div>
|
||||
))}
|
||||
{!auditLogs.length && !isLoading && (
|
||||
<div className="py-10 text-center text-sm text-gray-500">暂无审计记录</div>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
<div className="border-t border-red-400/20 bg-red-950/10 p-4">
|
||||
<div className="flex items-start justify-between gap-3">
|
||||
<div>
|
||||
<div className="text-sm font-semibold text-red-100">演示环境出厂设置</div>
|
||||
<p className="mt-1 text-xs leading-relaxed text-red-200/70">
|
||||
清空演示过程产生的用户、项目帧、标注、任务和私有模板,只保留默认 admin 与一个尚未生成帧的演示视频项目。
|
||||
</p>
|
||||
</div>
|
||||
<button
|
||||
type="button"
|
||||
onClick={() => void handleFactoryReset()}
|
||||
disabled={isResetting || isSaving}
|
||||
className="shrink-0 rounded border border-red-400/40 bg-red-500/15 px-3 py-2 text-xs font-semibold text-red-100 transition-colors hover:bg-red-500/25 disabled:cursor-wait disabled:opacity-50"
|
||||
>
|
||||
{isResetting ? '恢复中...' : '恢复演示出厂设置'}
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
</section>
|
||||
</main>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -18,6 +18,7 @@ const apiMock = vi.hoisted(() => ({
|
||||
deleteAnnotation: vi.fn(),
|
||||
exportCoco: vi.fn(),
|
||||
exportMasks: vi.fn(),
|
||||
exportSegmentationResults: vi.fn(),
|
||||
importGtMask: vi.fn(),
|
||||
annotationToMask: vi.fn(),
|
||||
buildAnnotationPayload: vi.fn(),
|
||||
@@ -39,6 +40,7 @@ vi.mock('../lib/api', () => ({
|
||||
deleteAnnotation: apiMock.deleteAnnotation,
|
||||
exportCoco: apiMock.exportCoco,
|
||||
exportMasks: apiMock.exportMasks,
|
||||
exportSegmentationResults: apiMock.exportSegmentationResults,
|
||||
importGtMask: apiMock.importGtMask,
|
||||
annotationToMask: apiMock.annotationToMask,
|
||||
buildAnnotationPayload: apiMock.buildAnnotationPayload,
|
||||
@@ -121,11 +123,16 @@ describe('VideoWorkspace', () => {
|
||||
});
|
||||
|
||||
render(<VideoWorkspace />);
|
||||
fireEvent.click(screen.getByRole('button', { name: '撤销操作' }));
|
||||
const undoButton = screen.getByRole('button', { name: '撤销操作' });
|
||||
const redoButton = screen.getByRole('button', { name: '重做操作' });
|
||||
expect(undoButton.querySelector('svg')).toHaveClass('text-amber-300');
|
||||
expect(redoButton.querySelector('svg')).toHaveClass('text-indigo-300');
|
||||
|
||||
fireEvent.click(undoButton);
|
||||
|
||||
expect(useStore.getState().masks).toEqual([]);
|
||||
|
||||
fireEvent.click(screen.getByRole('button', { name: '重做操作' }));
|
||||
fireEvent.click(redoButton);
|
||||
expect(useStore.getState().masks).toEqual([mask]);
|
||||
|
||||
fireEvent.keyDown(window, { key: 'z', ctrlKey: true });
|
||||
@@ -147,7 +154,7 @@ describe('VideoWorkspace', () => {
|
||||
await waitFor(() => expect(useStore.getState().frames).toHaveLength(1));
|
||||
|
||||
vi.useFakeTimers();
|
||||
fireEvent.click(screen.getByRole('button', { name: '结构化归档保存' }));
|
||||
fireEvent.click(screen.getByRole('button', { name: '已全部保存' }));
|
||||
expect(screen.getByText('没有待保存标注')).toBeInTheDocument();
|
||||
|
||||
act(() => {
|
||||
@@ -155,7 +162,7 @@ describe('VideoWorkspace', () => {
|
||||
});
|
||||
|
||||
expect(screen.queryByText('没有待保存标注')).not.toBeInTheDocument();
|
||||
expect(screen.getByRole('button', { name: '结构化归档保存' })).not.toBeDisabled();
|
||||
expect(screen.getByRole('button', { name: '已全部保存' })).not.toBeDisabled();
|
||||
vi.useRealTimers();
|
||||
});
|
||||
|
||||
@@ -305,7 +312,8 @@ describe('VideoWorkspace', () => {
|
||||
});
|
||||
});
|
||||
|
||||
fireEvent.click(screen.getByRole('button', { name: '结构化归档保存' }));
|
||||
expect(screen.getByRole('button', { name: '保存 1 个改动' })).toBeInTheDocument();
|
||||
fireEvent.click(screen.getByRole('button', { name: '保存 1 个改动' }));
|
||||
|
||||
await waitFor(() => expect(apiMock.saveAnnotation).toHaveBeenCalledWith({
|
||||
project_id: 1,
|
||||
@@ -322,6 +330,7 @@ describe('VideoWorkspace', () => {
|
||||
expect.objectContaining({ id: 'annotation-5', saved: true, saveStatus: 'saved' }),
|
||||
]));
|
||||
expect(useStore.getState().masks.some((mask) => mask.id === 'mask-1')).toBe(false);
|
||||
expect(screen.getByRole('button', { name: '已全部保存' })).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it('updates dirty saved masks through the archive button', async () => {
|
||||
@@ -360,7 +369,8 @@ describe('VideoWorkspace', () => {
|
||||
});
|
||||
});
|
||||
|
||||
fireEvent.click(screen.getByRole('button', { name: '结构化归档保存' }));
|
||||
expect(screen.getByRole('button', { name: '保存 1 个改动' })).toBeInTheDocument();
|
||||
fireEvent.click(screen.getByRole('button', { name: '保存 1 个改动' }));
|
||||
|
||||
await waitFor(() => expect(apiMock.updateAnnotation).toHaveBeenCalledWith('99', {
|
||||
template_id: 2,
|
||||
@@ -415,6 +425,7 @@ describe('VideoWorkspace', () => {
|
||||
});
|
||||
|
||||
it('clears masks across the selected frame range', async () => {
|
||||
const confirmSpy = vi.spyOn(window, 'confirm').mockReturnValue(true);
|
||||
apiMock.getProjectFrames.mockResolvedValueOnce([
|
||||
{ id: 10, project_id: 1, frame_index: 0, image_url: '/frame-0.jpg', width: 640, height: 360 },
|
||||
{ id: 11, project_id: 1, frame_index: 1, image_url: '/frame-1.jpg', width: 640, height: 360 },
|
||||
@@ -436,7 +447,9 @@ describe('VideoWorkspace', () => {
|
||||
});
|
||||
|
||||
fireEvent.click(screen.getByRole('button', { name: '清空片段遮罩' }));
|
||||
expect(screen.getByText('请在播放进度条或视频处理进度条上点击/拖拽选择清空起止帧,再点击“确认清空”')).toBeInTheDocument();
|
||||
expect(screen.getByText('请选择清空模式,并在播放进度条或视频处理进度条上点击/拖拽选择清空起止帧,再点击“确认清空”')).toBeInTheDocument();
|
||||
expect(screen.getByRole('button', { name: '清空全部' })).toHaveAttribute('aria-pressed', 'true');
|
||||
expect(screen.getByRole('button', { name: '保留人工/AI' })).toBeInTheDocument();
|
||||
|
||||
const processingBar = screen.getByLabelText('视频处理进度条');
|
||||
vi.spyOn(processingBar, 'getBoundingClientRect').mockReturnValue({
|
||||
@@ -458,20 +471,129 @@ describe('VideoWorkspace', () => {
|
||||
|
||||
fireEvent.click(screen.getByRole('button', { name: '确认清空' }));
|
||||
|
||||
expect(confirmSpy).toHaveBeenCalledWith(expect.stringContaining('是否清除“人工/AI标注帧”?'));
|
||||
await waitFor(() => expect(apiMock.deleteAnnotation).toHaveBeenCalledWith('99'));
|
||||
expect(apiMock.deleteAnnotation).not.toHaveBeenCalledWith('100');
|
||||
expect(useStore.getState().masks.map((mask) => mask.id)).toEqual(['annotation-100']);
|
||||
expect(useStore.getState().selectedMaskIds).not.toContain('draft-1');
|
||||
expect(screen.getByText('已清空第 1-2 帧的 2 个遮罩,其中后端标注 1 个')).toBeInTheDocument();
|
||||
confirmSpy.mockRestore();
|
||||
});
|
||||
|
||||
it('auto-saves pending masks before exporting COCO', async () => {
|
||||
it('can clear only propagated masks while preserving manual or AI annotated frames', async () => {
|
||||
const confirmSpy = vi.spyOn(window, 'confirm').mockReturnValue(true);
|
||||
apiMock.getProjectFrames.mockResolvedValueOnce([
|
||||
{ id: 10, project_id: 1, frame_index: 0, image_url: '/frame-0.jpg', width: 640, height: 360 },
|
||||
{ id: 11, project_id: 1, frame_index: 1, image_url: '/frame-1.jpg', width: 640, height: 360 },
|
||||
]);
|
||||
apiMock.deleteAnnotation.mockResolvedValue(undefined);
|
||||
|
||||
render(<VideoWorkspace />);
|
||||
await waitFor(() => expect(useStore.getState().frames).toHaveLength(2));
|
||||
act(() => {
|
||||
useStore.setState({
|
||||
masks: [
|
||||
{ id: 'manual-1', annotationId: '98', frameId: '10', pathData: 'M 0 0 Z', label: 'Manual', color: '#ef4444', saved: true, saveStatus: 'saved' },
|
||||
{
|
||||
id: 'propagated-1',
|
||||
annotationId: '99',
|
||||
frameId: '11',
|
||||
pathData: 'M 1 1 Z',
|
||||
label: 'Tracked',
|
||||
color: '#3b82f6',
|
||||
saved: true,
|
||||
saveStatus: 'saved',
|
||||
metadata: { source_annotation_id: 7, source_mask_id: 'annotation-7' },
|
||||
},
|
||||
],
|
||||
selectedMaskIds: ['manual-1', 'propagated-1'],
|
||||
});
|
||||
});
|
||||
|
||||
fireEvent.click(screen.getByRole('button', { name: '清空片段遮罩' }));
|
||||
fireEvent.click(screen.getByRole('button', { name: '保留人工/AI' }));
|
||||
expect(screen.getByRole('button', { name: '保留人工/AI' })).toHaveAttribute('aria-pressed', 'true');
|
||||
fireEvent.click(screen.getByRole('button', { name: '确认清空' }));
|
||||
|
||||
expect(confirmSpy).not.toHaveBeenCalled();
|
||||
await waitFor(() => expect(apiMock.deleteAnnotation).toHaveBeenCalledWith('99'));
|
||||
expect(apiMock.deleteAnnotation).not.toHaveBeenCalledWith('98');
|
||||
expect(useStore.getState().masks.map((mask) => mask.id)).toEqual(['manual-1']);
|
||||
expect(useStore.getState().selectedMaskIds).toEqual(['manual-1']);
|
||||
expect(screen.getByText('已清空第 1-2 帧的 1 个自动传播遮罩,其中后端标注 1 个,人工/AI 标注帧已保留')).toBeInTheDocument();
|
||||
confirmSpy.mockRestore();
|
||||
});
|
||||
|
||||
it('cancels range clearing when manual or AI annotated frames are not confirmed', async () => {
|
||||
const confirmSpy = vi.spyOn(window, 'confirm').mockReturnValue(false);
|
||||
apiMock.getProjectFrames.mockResolvedValueOnce([
|
||||
{ id: 10, project_id: 1, frame_index: 0, image_url: '/frame-0.jpg', width: 640, height: 360 },
|
||||
{ id: 11, project_id: 1, frame_index: 1, image_url: '/frame-1.jpg', width: 640, height: 360 },
|
||||
]);
|
||||
|
||||
render(<VideoWorkspace />);
|
||||
await waitFor(() => expect(useStore.getState().frames).toHaveLength(2));
|
||||
act(() => {
|
||||
useStore.setState({
|
||||
masks: [
|
||||
{ id: 'annotation-99', annotationId: '99', frameId: '10', pathData: 'M 0 0 Z', label: 'Manual', color: '#06b6d4', saved: true, saveStatus: 'saved' },
|
||||
],
|
||||
});
|
||||
});
|
||||
|
||||
fireEvent.click(screen.getByRole('button', { name: '清空片段遮罩' }));
|
||||
fireEvent.click(screen.getByRole('button', { name: '确认清空' }));
|
||||
|
||||
expect(confirmSpy).toHaveBeenCalledWith(expect.stringContaining('是否清除“人工/AI标注帧”?'));
|
||||
expect(apiMock.deleteAnnotation).not.toHaveBeenCalled();
|
||||
expect(useStore.getState().masks.map((mask) => mask.id)).toEqual(['annotation-99']);
|
||||
expect(screen.getByText('已取消清空片段遮罩')).toBeInTheDocument();
|
||||
confirmSpy.mockRestore();
|
||||
});
|
||||
|
||||
it('does not ask for manual-frame confirmation when clearing propagated-only frames', async () => {
|
||||
const confirmSpy = vi.spyOn(window, 'confirm').mockReturnValue(true);
|
||||
apiMock.getProjectFrames.mockResolvedValueOnce([
|
||||
{ id: 10, project_id: 1, frame_index: 0, image_url: '/frame-0.jpg', width: 640, height: 360 },
|
||||
{ id: 11, project_id: 1, frame_index: 1, image_url: '/frame-1.jpg', width: 640, height: 360 },
|
||||
]);
|
||||
apiMock.deleteAnnotation.mockResolvedValue(undefined);
|
||||
|
||||
render(<VideoWorkspace />);
|
||||
await waitFor(() => expect(useStore.getState().frames).toHaveLength(2));
|
||||
act(() => {
|
||||
useStore.setState({
|
||||
masks: [
|
||||
{
|
||||
id: 'annotation-99',
|
||||
annotationId: '99',
|
||||
frameId: '10',
|
||||
pathData: 'M 0 0 Z',
|
||||
label: 'Propagated',
|
||||
color: '#06b6d4',
|
||||
saved: true,
|
||||
saveStatus: 'saved',
|
||||
metadata: { source: 'sam2_propagation', propagated_from_frame_id: 1 },
|
||||
},
|
||||
],
|
||||
});
|
||||
});
|
||||
|
||||
fireEvent.click(screen.getByRole('button', { name: '清空片段遮罩' }));
|
||||
fireEvent.click(screen.getByRole('button', { name: '确认清空' }));
|
||||
|
||||
expect(confirmSpy).not.toHaveBeenCalled();
|
||||
await waitFor(() => expect(apiMock.deleteAnnotation).toHaveBeenCalledWith('99'));
|
||||
confirmSpy.mockRestore();
|
||||
});
|
||||
|
||||
it('auto-saves pending masks before exporting segmentation results', async () => {
|
||||
apiMock.getProjectFrames.mockResolvedValueOnce([
|
||||
{ id: 10, project_id: 1, frame_index: 0, image_url: '/frame.jpg', width: 640, height: 360 },
|
||||
]);
|
||||
apiMock.buildAnnotationPayload.mockReturnValueOnce({ project_id: 1, frame_id: 10, mask_data: { polygons: [] } });
|
||||
apiMock.saveAnnotation.mockResolvedValueOnce({ id: 5 });
|
||||
apiMock.exportCoco.mockResolvedValueOnce(new Blob(['{}'], { type: 'application/json' }));
|
||||
apiMock.exportSegmentationResults.mockResolvedValueOnce(new Blob(['zip'], { type: 'application/zip' }));
|
||||
|
||||
render(<VideoWorkspace />);
|
||||
await waitFor(() => expect(useStore.getState().frames).toHaveLength(1));
|
||||
@@ -488,39 +610,167 @@ describe('VideoWorkspace', () => {
|
||||
});
|
||||
});
|
||||
|
||||
fireEvent.click(screen.getByRole('button', { name: '导出 JSON 标注集' }));
|
||||
fireEvent.click(screen.getByRole('button', { name: '分割结果导出' }));
|
||||
fireEvent.change(screen.getByLabelText('Mix_label 遮罩透明度'), { target: { value: '0.45' } });
|
||||
fireEvent.click(screen.getByRole('button', { name: '开始导出' }));
|
||||
|
||||
await waitFor(() => expect(apiMock.saveAnnotation).toHaveBeenCalled());
|
||||
expect(apiMock.exportCoco).toHaveBeenCalledWith('1');
|
||||
expect(apiMock.exportSegmentationResults).toHaveBeenCalledWith('1', {
|
||||
scope: 'current',
|
||||
outputs: ['separate', 'gt_label', 'pro_label', 'mix_label'],
|
||||
mixOpacity: 0.45,
|
||||
startFrame: undefined,
|
||||
endFrame: undefined,
|
||||
frameId: '10',
|
||||
});
|
||||
});
|
||||
|
||||
it('auto-saves pending masks before exporting PNG masks', async () => {
|
||||
it('exports a selected frame range with GT label masks', async () => {
|
||||
apiMock.getProjectFrames.mockResolvedValueOnce([
|
||||
{ id: 10, project_id: 1, frame_index: 0, image_url: '/frame.jpg', width: 640, height: 360 },
|
||||
{ id: 10, project_id: 1, frame_index: 0, image_url: '/frame-0.jpg', width: 640, height: 360, timestamp_ms: 0, source_frame_number: 0 },
|
||||
{ id: 11, project_id: 1, frame_index: 1, image_url: '/frame-1.jpg', width: 640, height: 360, timestamp_ms: 500, source_frame_number: 15 },
|
||||
{ id: 12, project_id: 1, frame_index: 2, image_url: '/frame-2.jpg', width: 640, height: 360, timestamp_ms: 1000, source_frame_number: 30 },
|
||||
]);
|
||||
apiMock.buildAnnotationPayload.mockReturnValueOnce({ project_id: 1, frame_id: 10, mask_data: { polygons: [] } });
|
||||
apiMock.saveAnnotation.mockResolvedValueOnce({ id: 5 });
|
||||
apiMock.exportMasks.mockResolvedValueOnce(new Blob(['zip'], { type: 'application/zip' }));
|
||||
apiMock.exportSegmentationResults.mockResolvedValueOnce(new Blob(['zip'], { type: 'application/zip' }));
|
||||
const downloads: string[] = [];
|
||||
const clickSpy = vi.spyOn(HTMLAnchorElement.prototype, 'click').mockImplementation(function mockClick(this: HTMLAnchorElement) {
|
||||
downloads.push(this.download);
|
||||
});
|
||||
useStore.setState({ currentProject: { id: '1', name: '病例 A/1', status: 'ready', video_path: 'uploads/demo.mp4' } });
|
||||
|
||||
render(<VideoWorkspace />);
|
||||
await waitFor(() => expect(useStore.getState().frames).toHaveLength(1));
|
||||
await waitFor(() => expect(useStore.getState().frames).toHaveLength(3));
|
||||
|
||||
fireEvent.click(screen.getByRole('button', { name: '分割结果导出' }));
|
||||
fireEvent.click(screen.getByRole('button', { name: '特定范围帧' }));
|
||||
fireEvent.change(screen.getByLabelText('导出起始帧'), { target: { value: '2' } });
|
||||
fireEvent.change(screen.getByLabelText('导出结束帧'), { target: { value: '3' } });
|
||||
fireEvent.click(screen.getByRole('button', { name: '分开 Mask' }));
|
||||
fireEvent.click(screen.getByRole('button', { name: 'Pro_label 彩色' }));
|
||||
fireEvent.click(screen.getByRole('button', { name: 'Mix_label 叠加' }));
|
||||
fireEvent.click(screen.getByRole('button', { name: '开始导出' }));
|
||||
|
||||
await waitFor(() => expect(apiMock.exportSegmentationResults).toHaveBeenCalledWith('1', {
|
||||
scope: 'range',
|
||||
outputs: ['gt_label'],
|
||||
mixOpacity: 0.3,
|
||||
startFrame: 2,
|
||||
endFrame: 3,
|
||||
frameId: undefined,
|
||||
}));
|
||||
expect(downloads[0]).toBe('病例_A_1_seg_T_0h00m00s500ms-0h00m01s000ms_P_2-3.zip');
|
||||
clickSpy.mockRestore();
|
||||
});
|
||||
|
||||
it('lets the timeline range picker update selected frame export bounds', async () => {
|
||||
apiMock.getProjectFrames.mockResolvedValueOnce([
|
||||
{ id: 10, project_id: 1, frame_index: 0, image_url: '/frame-0.jpg', width: 640, height: 360 },
|
||||
{ id: 11, project_id: 1, frame_index: 1, image_url: '/frame-1.jpg', width: 640, height: 360 },
|
||||
{ id: 12, project_id: 1, frame_index: 2, image_url: '/frame-2.jpg', width: 640, height: 360 },
|
||||
{ id: 13, project_id: 1, frame_index: 3, image_url: '/frame-3.jpg', width: 640, height: 360 },
|
||||
{ id: 14, project_id: 1, frame_index: 4, image_url: '/frame-4.jpg', width: 640, height: 360 },
|
||||
]);
|
||||
apiMock.exportSegmentationResults.mockResolvedValueOnce(new Blob(['zip'], { type: 'application/zip' }));
|
||||
|
||||
render(<VideoWorkspace />);
|
||||
await waitFor(() => expect(useStore.getState().frames).toHaveLength(5));
|
||||
|
||||
fireEvent.click(screen.getByRole('button', { name: '分割结果导出' }));
|
||||
fireEvent.click(screen.getByRole('button', { name: '特定范围帧' }));
|
||||
expect(screen.getByText('请在播放进度条或视频处理进度条上点击/拖拽选择导出起止帧,也可直接修改导出范围')).toBeInTheDocument();
|
||||
|
||||
const processingBar = screen.getByLabelText('视频处理进度条');
|
||||
vi.spyOn(processingBar, 'getBoundingClientRect').mockReturnValue({
|
||||
left: 0,
|
||||
right: 100,
|
||||
top: 0,
|
||||
bottom: 10,
|
||||
width: 100,
|
||||
height: 10,
|
||||
x: 0,
|
||||
y: 0,
|
||||
toJSON: () => ({}),
|
||||
});
|
||||
fireEvent.pointerDown(processingBar, { clientX: 25, pointerId: 1 });
|
||||
fireEvent.pointerMove(processingBar, { clientX: 100, pointerId: 1 });
|
||||
fireEvent.pointerUp(processingBar, { clientX: 100, pointerId: 1 });
|
||||
|
||||
expect(screen.getByLabelText('导出起始帧')).toHaveValue(2);
|
||||
expect(screen.getByLabelText('导出结束帧')).toHaveValue(5);
|
||||
|
||||
fireEvent.click(screen.getByRole('button', { name: '开始导出' }));
|
||||
|
||||
await waitFor(() => expect(apiMock.exportSegmentationResults).toHaveBeenCalledWith('1', {
|
||||
scope: 'range',
|
||||
outputs: ['separate', 'gt_label', 'pro_label', 'mix_label'],
|
||||
mixOpacity: 0.3,
|
||||
startFrame: 2,
|
||||
endFrame: 5,
|
||||
frameId: undefined,
|
||||
}));
|
||||
});
|
||||
|
||||
it('switches from export range selection to propagation range selection without starting propagation immediately', async () => {
|
||||
apiMock.getProjectFrames.mockResolvedValueOnce([
|
||||
{ id: 10, project_id: 1, frame_index: 0, image_url: '/frame-0.jpg', width: 640, height: 360 },
|
||||
{ id: 11, project_id: 1, frame_index: 1, image_url: '/frame-1.jpg', width: 640, height: 360 },
|
||||
{ id: 12, project_id: 1, frame_index: 2, image_url: '/frame-2.jpg', width: 640, height: 360 },
|
||||
]);
|
||||
|
||||
render(<VideoWorkspace />);
|
||||
await waitFor(() => expect(useStore.getState().frames).toHaveLength(3));
|
||||
act(() => {
|
||||
useStore.setState({
|
||||
masks: [{
|
||||
id: 'mask-1',
|
||||
id: 'annotation-8',
|
||||
annotationId: '8',
|
||||
frameId: '10',
|
||||
pathData: 'M 0 0 Z',
|
||||
label: 'AI Mask',
|
||||
color: '#06b6d4',
|
||||
segmentation: [[0, 0, 10, 0, 10, 10]],
|
||||
label: '胆囊',
|
||||
color: '#ff0000',
|
||||
segmentation: [[64, 36, 192, 36, 192, 108]],
|
||||
bbox: [64, 36, 128, 72],
|
||||
saveStatus: 'saved',
|
||||
}],
|
||||
});
|
||||
});
|
||||
|
||||
fireEvent.click(screen.getByRole('button', { name: '导出 PNG Mask ZIP' }));
|
||||
fireEvent.click(screen.getByRole('button', { name: '分割结果导出' }));
|
||||
fireEvent.click(screen.getByRole('button', { name: '特定范围帧' }));
|
||||
fireEvent.click(screen.getByRole('button', { name: '自动传播' }));
|
||||
|
||||
await waitFor(() => expect(apiMock.saveAnnotation).toHaveBeenCalled());
|
||||
expect(apiMock.exportMasks).toHaveBeenCalledWith('1');
|
||||
expect(screen.getByText('请在播放进度条或视频处理进度条上点击/拖拽选择传播起止帧,再点击“开始传播”')).toBeInTheDocument();
|
||||
expect(apiMock.queuePropagationTask).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('exports only the current frame when current image scope is selected', async () => {
|
||||
apiMock.getProjectFrames.mockResolvedValueOnce([
|
||||
{ id: 10, project_id: 1, frame_index: 0, image_url: '/frame-0.jpg', width: 640, height: 360 },
|
||||
{ id: 11, project_id: 1, frame_index: 1, image_url: '/frame-1.jpg', width: 640, height: 360 },
|
||||
]);
|
||||
apiMock.exportSegmentationResults.mockResolvedValueOnce(new Blob(['zip'], { type: 'application/zip' }));
|
||||
|
||||
render(<VideoWorkspace />);
|
||||
await waitFor(() => expect(useStore.getState().frames).toHaveLength(2));
|
||||
act(() => {
|
||||
useStore.setState({ currentFrameIndex: 1 });
|
||||
});
|
||||
|
||||
fireEvent.click(screen.getByRole('button', { name: '分割结果导出' }));
|
||||
fireEvent.click(screen.getByRole('button', { name: '当前图片' }));
|
||||
fireEvent.click(screen.getByRole('button', { name: 'GT_label 黑白' }));
|
||||
fireEvent.click(screen.getByRole('button', { name: 'Pro_label 彩色' }));
|
||||
fireEvent.click(screen.getByRole('button', { name: 'Mix_label 叠加' }));
|
||||
fireEvent.click(screen.getByRole('button', { name: '开始导出' }));
|
||||
|
||||
await waitFor(() => expect(apiMock.exportSegmentationResults).toHaveBeenCalledWith('1', {
|
||||
scope: 'current',
|
||||
outputs: ['separate'],
|
||||
mixOpacity: 0.3,
|
||||
startFrame: undefined,
|
||||
endFrame: undefined,
|
||||
frameId: '11',
|
||||
}));
|
||||
});
|
||||
|
||||
it('imports a GT mask for the current frame and hydrates saved annotations', async () => {
|
||||
@@ -547,13 +797,41 @@ describe('VideoWorkspace', () => {
|
||||
const fileInput = document.querySelector('input[type="file"]') as HTMLInputElement;
|
||||
const file = new File(['mask'], 'mask.png', { type: 'image/png' });
|
||||
fireEvent.change(fileInput, { target: { files: [file] } });
|
||||
expect(screen.getByText('导入结果预览')).toBeInTheDocument();
|
||||
await waitFor(() => expect(screen.getByRole('button', { name: '导入为未定义' })).not.toBeDisabled());
|
||||
fireEvent.click(screen.getByRole('button', { name: '导入为未定义' }));
|
||||
|
||||
await waitFor(() => expect(apiMock.importGtMask).toHaveBeenCalledWith(file, '1', '10'));
|
||||
await waitFor(() => expect(apiMock.importGtMask).toHaveBeenCalledWith(file, '1', '10', null, {
|
||||
unknownColorPolicy: 'undefined',
|
||||
}));
|
||||
await waitFor(() => expect(useStore.getState().masks).toEqual([
|
||||
expect.objectContaining({ id: 'annotation-88', label: 'GT Mask' }),
|
||||
]));
|
||||
});
|
||||
|
||||
it('lets users discard unknown GT mask classes before importing', async () => {
|
||||
apiMock.getProjectFrames.mockResolvedValueOnce([
|
||||
{ id: 10, project_id: 1, frame_index: 0, image_url: '/frame.jpg', width: 640, height: 360 },
|
||||
]);
|
||||
apiMock.importGtMask.mockResolvedValueOnce([]);
|
||||
apiMock.getProjectAnnotations.mockResolvedValueOnce([]).mockResolvedValueOnce([]);
|
||||
useStore.setState({ activeTemplateId: '2' });
|
||||
|
||||
render(<VideoWorkspace />);
|
||||
await waitFor(() => expect(useStore.getState().frames).toHaveLength(1));
|
||||
|
||||
const fileInput = document.querySelector('input[type="file"]') as HTMLInputElement;
|
||||
const file = new File(['mask'], 'color-mask.png', { type: 'image/png' });
|
||||
fireEvent.change(fileInput, { target: { files: [file] } });
|
||||
expect(screen.getByText('导入结果预览')).toBeInTheDocument();
|
||||
await waitFor(() => expect(screen.getByRole('button', { name: '舍弃未知类别' })).not.toBeDisabled());
|
||||
fireEvent.click(screen.getByRole('button', { name: '舍弃未知类别' }));
|
||||
|
||||
await waitFor(() => expect(apiMock.importGtMask).toHaveBeenCalledWith(file, '1', '10', '2', {
|
||||
unknownColorPolicy: 'discard',
|
||||
}));
|
||||
});
|
||||
|
||||
it('auto-propagates reference-frame masks through the configured frame range', async () => {
|
||||
apiMock.getProjectFrames.mockResolvedValueOnce([
|
||||
{ id: 10, project_id: 1, frame_index: 0, image_url: '/frame.jpg', width: 640, height: 360 },
|
||||
@@ -823,6 +1101,80 @@ describe('VideoWorkspace', () => {
|
||||
})));
|
||||
});
|
||||
|
||||
it('removes propagation history bars when clearing the same frame range', async () => {
|
||||
apiMock.getProjectFrames.mockResolvedValueOnce([
|
||||
{ id: 10, project_id: 1, frame_index: 0, image_url: '/frame-0.jpg', width: 640, height: 360 },
|
||||
{ id: 11, project_id: 1, frame_index: 1, image_url: '/frame-1.jpg', width: 640, height: 360 },
|
||||
{ id: 12, project_id: 1, frame_index: 2, image_url: '/frame-2.jpg', width: 640, height: 360 },
|
||||
{ id: 13, project_id: 1, frame_index: 3, image_url: '/frame-3.jpg', width: 640, height: 360 },
|
||||
{ id: 14, project_id: 1, frame_index: 4, image_url: '/frame-4.jpg', width: 640, height: 360 },
|
||||
]);
|
||||
apiMock.buildAnnotationPayload.mockReturnValue({
|
||||
project_id: 1,
|
||||
frame_id: 10,
|
||||
mask_data: {
|
||||
polygons: [[[0.1, 0.1], [0.3, 0.1], [0.3, 0.3]]],
|
||||
label: '胆囊',
|
||||
color: '#ff0000',
|
||||
},
|
||||
bbox: [0.1, 0.1, 0.2, 0.2],
|
||||
});
|
||||
apiMock.deleteAnnotation.mockResolvedValue(undefined);
|
||||
|
||||
render(<VideoWorkspace />);
|
||||
await waitFor(() => expect(useStore.getState().frames).toHaveLength(5));
|
||||
act(() => {
|
||||
useStore.setState({
|
||||
masks: [{
|
||||
id: 'annotation-8',
|
||||
annotationId: '8',
|
||||
frameId: '10',
|
||||
pathData: 'M 0 0 Z',
|
||||
label: '胆囊',
|
||||
color: '#ff0000',
|
||||
segmentation: [[64, 36, 192, 36, 192, 108]],
|
||||
bbox: [64, 36, 128, 72],
|
||||
}],
|
||||
});
|
||||
});
|
||||
|
||||
fireEvent.click(screen.getByRole('button', { name: '自动传播' }));
|
||||
const processingBar = screen.getByLabelText('视频处理进度条');
|
||||
vi.spyOn(processingBar, 'getBoundingClientRect').mockReturnValue({
|
||||
left: 0,
|
||||
right: 100,
|
||||
top: 0,
|
||||
bottom: 10,
|
||||
width: 100,
|
||||
height: 10,
|
||||
x: 0,
|
||||
y: 0,
|
||||
toJSON: () => ({}),
|
||||
});
|
||||
fireEvent.pointerDown(processingBar, { clientX: 25, pointerId: 1 });
|
||||
fireEvent.pointerMove(processingBar, { clientX: 100, pointerId: 1 });
|
||||
fireEvent.pointerUp(processingBar, { clientX: 100, pointerId: 1 });
|
||||
fireEvent.click(screen.getByRole('button', { name: '开始传播' }));
|
||||
|
||||
expect(await screen.findByTestId('propagation-history-segment')).toBeInTheDocument();
|
||||
|
||||
act(() => {
|
||||
useStore.setState({
|
||||
masks: [
|
||||
{ id: 'annotation-101', annotationId: '101', frameId: '11', pathData: 'M 1 1 Z', label: 'Propagated 1', color: '#ff0000', saved: true, saveStatus: 'saved', metadata: { source: 'sam2_propagation', propagated_from_frame_id: 10 } },
|
||||
{ id: 'annotation-102', annotationId: '102', frameId: '12', pathData: 'M 2 2 Z', label: 'Propagated 2', color: '#ff0000', saved: true, saveStatus: 'saved', metadata: { source: 'sam2_propagation', propagated_from_frame_id: 10 } },
|
||||
],
|
||||
});
|
||||
});
|
||||
|
||||
fireEvent.click(screen.getByRole('button', { name: '清空片段遮罩' }));
|
||||
fireEvent.click(screen.getByRole('button', { name: '确认清空' }));
|
||||
|
||||
await waitFor(() => expect(screen.queryByTestId('propagation-history-segment')).not.toBeInTheDocument());
|
||||
expect(apiMock.deleteAnnotation).toHaveBeenCalledWith('101');
|
||||
expect(apiMock.deleteAnnotation).toHaveBeenCalledWith('102');
|
||||
});
|
||||
|
||||
it('auto-propagates all reference-frame masks in both directions inside the selected range', async () => {
|
||||
apiMock.getProjectFrames.mockResolvedValueOnce([
|
||||
{ id: 10, project_id: 1, frame_index: 0, image_url: '/frame-0.jpg', width: 640, height: 360 },
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -84,6 +84,84 @@ describe('api client contracts', () => {
|
||||
expect(axiosMock.client.delete).toHaveBeenCalledWith('/api/projects/3');
|
||||
});
|
||||
|
||||
it('normalizes missing template class maskids without using priority as the public id', async () => {
|
||||
const { getTemplates } = await import('./api');
|
||||
axiosMock.client.get.mockResolvedValueOnce({
|
||||
data: [{
|
||||
id: 2,
|
||||
name: 'Template',
|
||||
mapping_rules: {
|
||||
classes: [
|
||||
{ id: 'c1', name: 'A', color: '#ff0000', zIndex: 100 },
|
||||
{ id: 'c2', name: 'B', color: '#00ff00', zIndex: 10, maskId: 7 },
|
||||
{ id: 'c3', name: 'C', color: '#0000ff', zIndex: 50 },
|
||||
],
|
||||
rules: [],
|
||||
},
|
||||
}],
|
||||
});
|
||||
|
||||
await expect(getTemplates()).resolves.toEqual([
|
||||
expect.objectContaining({
|
||||
classes: [
|
||||
expect.objectContaining({ id: 'c1', maskId: 1, zIndex: 100 }),
|
||||
expect.objectContaining({ id: 'c2', maskId: 7, zIndex: 10 }),
|
||||
expect.objectContaining({ id: 'c3', maskId: 2, zIndex: 50 }),
|
||||
],
|
||||
}),
|
||||
]);
|
||||
});
|
||||
|
||||
it('calls admin user management and audit endpoints', async () => {
|
||||
const {
|
||||
getAdminUsers,
|
||||
createAdminUser,
|
||||
updateAdminUser,
|
||||
deleteAdminUser,
|
||||
getAuditLogs,
|
||||
resetDemoFactory,
|
||||
} = await import('./api');
|
||||
axiosMock.client.get
|
||||
.mockResolvedValueOnce({ data: [{ id: 1, username: 'admin', role: 'admin', is_active: 1 }] })
|
||||
.mockResolvedValueOnce({ data: [{ id: 9, action: 'admin.user_created', created_at: 'now' }] });
|
||||
axiosMock.client.post.mockResolvedValueOnce({ data: { id: 2, username: 'doctor', role: 'annotator', is_active: 1 } });
|
||||
axiosMock.client.patch.mockResolvedValueOnce({ data: { id: 2, username: 'doctor', role: 'viewer', is_active: 1 } });
|
||||
axiosMock.client.delete.mockResolvedValueOnce({ data: null });
|
||||
|
||||
await expect(getAdminUsers()).resolves.toEqual([expect.objectContaining({ username: 'admin' })]);
|
||||
await createAdminUser({ username: 'doctor', password: 'secret123', role: 'annotator', is_active: true });
|
||||
await updateAdminUser(2, { role: 'viewer' });
|
||||
await deleteAdminUser(2);
|
||||
await expect(getAuditLogs(50)).resolves.toEqual([expect.objectContaining({ action: 'admin.user_created' })]);
|
||||
|
||||
expect(axiosMock.client.get).toHaveBeenNthCalledWith(1, '/api/admin/users');
|
||||
expect(axiosMock.client.post).toHaveBeenCalledWith('/api/admin/users', {
|
||||
username: 'doctor',
|
||||
password: 'secret123',
|
||||
role: 'annotator',
|
||||
is_active: true,
|
||||
});
|
||||
expect(axiosMock.client.patch).toHaveBeenCalledWith('/api/admin/users/2', { role: 'viewer' });
|
||||
expect(axiosMock.client.delete).toHaveBeenCalledWith('/api/admin/users/2');
|
||||
expect(axiosMock.client.get).toHaveBeenNthCalledWith(2, '/api/admin/audit-logs', { params: { limit: 50 } });
|
||||
|
||||
axiosMock.client.post.mockResolvedValueOnce({
|
||||
data: {
|
||||
admin_user: { id: 1, username: 'admin', role: 'admin', is_active: 1 },
|
||||
project: { id: 8, name: 'Data_MyVideo_1', status: 'pending', frame_count: 0, video_path: 'uploads/8/Data_MyVideo_1.mp4' },
|
||||
deleted_counts: { users: 1 },
|
||||
message: '演示环境已恢复出厂设置',
|
||||
},
|
||||
});
|
||||
await expect(resetDemoFactory('RESET_DEMO_FACTORY')).resolves.toEqual(expect.objectContaining({
|
||||
admin_user: expect.objectContaining({ username: 'admin' }),
|
||||
project: expect.objectContaining({ id: '8', name: 'Data_MyVideo_1', frames: 0 }),
|
||||
}));
|
||||
expect(axiosMock.client.post).toHaveBeenLastCalledWith('/api/admin/demo-factory-reset', {
|
||||
confirmation: 'RESET_DEMO_FACTORY',
|
||||
});
|
||||
});
|
||||
|
||||
it('normalizes legacy project status values returned by existing databases', async () => {
|
||||
const { getProjects } = await import('./api');
|
||||
axiosMock.client.get.mockResolvedValueOnce({
|
||||
@@ -123,6 +201,33 @@ describe('api client contracts', () => {
|
||||
});
|
||||
});
|
||||
|
||||
it('exports combined segmentation results with scope, outputs, and mix opacity params', async () => {
|
||||
const { exportSegmentationResults } = await import('./api');
|
||||
const blob = new Blob(['zip'], { type: 'application/zip' });
|
||||
axiosMock.client.get.mockResolvedValueOnce({ data: blob });
|
||||
|
||||
await expect(exportSegmentationResults('9', {
|
||||
scope: 'range',
|
||||
outputs: ['gt_label', 'pro_label', 'mix_label'],
|
||||
mixOpacity: 0.45,
|
||||
startFrame: 2,
|
||||
endFrame: 5,
|
||||
frameId: '12',
|
||||
})).resolves.toBe(blob);
|
||||
expect(axiosMock.client.get).toHaveBeenCalledWith('/api/export/9/results', {
|
||||
params: {
|
||||
scope: 'range',
|
||||
mask_type: undefined,
|
||||
outputs: 'gt_label,pro_label,mix_label',
|
||||
mix_opacity: 0.45,
|
||||
start_frame: 2,
|
||||
end_frame: 5,
|
||||
frame_id: 12,
|
||||
},
|
||||
responseType: 'blob',
|
||||
});
|
||||
});
|
||||
|
||||
it('loads dashboard overview from the backend summary endpoint', async () => {
|
||||
const { getDashboardOverview } = await import('./api');
|
||||
const overview = {
|
||||
@@ -319,7 +424,7 @@ describe('api client contracts', () => {
|
||||
const saved = [{ id: 1, project_id: 9, frame_id: 5, template_id: null, mask_data: null, points: null, bbox: null }];
|
||||
axiosMock.client.post.mockResolvedValueOnce({ data: saved });
|
||||
|
||||
await expect(importGtMask(file, '9', '5', '2')).resolves.toEqual(saved);
|
||||
await expect(importGtMask(file, '9', '5', '2', { unknownColorPolicy: 'discard' })).resolves.toEqual(saved);
|
||||
expect(axiosMock.client.post).toHaveBeenCalledWith(
|
||||
'/api/ai/import-gt-mask',
|
||||
expect.any(FormData),
|
||||
@@ -330,6 +435,7 @@ describe('api client contracts', () => {
|
||||
expect(form.get('project_id')).toBe('9');
|
||||
expect(form.get('frame_id')).toBe('5');
|
||||
expect(form.get('template_id')).toBe('2');
|
||||
expect(form.get('unknown_color_policy')).toBe('discard');
|
||||
});
|
||||
|
||||
it('builds annotation payloads from frontend masks and restores saved annotations to masks', async () => {
|
||||
@@ -344,6 +450,7 @@ describe('api client contracts', () => {
|
||||
classId: 'c1',
|
||||
className: '胆囊',
|
||||
classZIndex: 20,
|
||||
classMaskId: 7,
|
||||
segmentation: [[10, 10, 90, 10, 90, 40]],
|
||||
bbox: [10, 10, 80, 30],
|
||||
metadata: { geometry_smoothing: { strength: 35, method: 'chaikin' } },
|
||||
@@ -357,7 +464,7 @@ describe('api client contracts', () => {
|
||||
polygons: [[[0.1, 0.2], [0.9, 0.2], [0.9, 0.8]]],
|
||||
label: '胆囊',
|
||||
color: '#ff0000',
|
||||
class: { id: 'c1', name: '胆囊', color: '#ff0000', zIndex: 20 },
|
||||
class: { id: 'c1', name: '胆囊', color: '#ff0000', zIndex: 20, maskId: 7 },
|
||||
geometry_smoothing: { strength: 35, method: 'chaikin' },
|
||||
},
|
||||
bbox: [0.1, 0.2, 0.8, 0.6],
|
||||
@@ -372,7 +479,7 @@ describe('api client contracts', () => {
|
||||
polygons: [[[0.1, 0.2], [0.9, 0.2], [0.9, 0.8]]],
|
||||
label: '旧标签',
|
||||
color: '#06b6d4',
|
||||
class: { id: 'c1', name: '胆囊', color: '#ff0000', zIndex: 20 },
|
||||
class: { id: 'c1', name: '胆囊', color: '#ff0000', zIndex: 20, maskId: 7 },
|
||||
source: 'sam2.1_hiera_tiny_propagation',
|
||||
propagated_from_frame_id: 4,
|
||||
geometry_smoothing: { strength: 35, method: 'chaikin' },
|
||||
@@ -389,6 +496,7 @@ describe('api client contracts', () => {
|
||||
classId: 'c1',
|
||||
className: '胆囊',
|
||||
classZIndex: 20,
|
||||
classMaskId: 7,
|
||||
label: '胆囊',
|
||||
color: '#ff0000',
|
||||
saveStatus: 'saved',
|
||||
@@ -466,6 +574,45 @@ describe('api client contracts', () => {
|
||||
}));
|
||||
});
|
||||
|
||||
it('preserves propagation metadata when saving edited geometry without persisting preview-only smoothing fields', async () => {
|
||||
const { buildAnnotationPayload } = await import('./api');
|
||||
const frame = { id: '5', projectId: '9', index: 0, url: '/frame.jpg', width: 100, height: 50 };
|
||||
|
||||
expect(buildAnnotationPayload('9', {
|
||||
id: 'm1',
|
||||
frameId: '5',
|
||||
pathData: 'M 10 10 L 90 10 L 90 40 Z',
|
||||
label: 'Tracked',
|
||||
color: '#22c55e',
|
||||
segmentation: [[10, 10, 90, 10, 90, 40]],
|
||||
metadata: {
|
||||
source: 'sam2_propagation',
|
||||
propagated_from_frame_id: 1,
|
||||
source_annotation_id: 7,
|
||||
source_mask_id: 'annotation-7',
|
||||
propagation_seed_key: 'annotation:7',
|
||||
geometry_smoothing_preview: { strength: 35, method: 'chaikin' },
|
||||
},
|
||||
}, frame)).toEqual(expect.objectContaining({
|
||||
mask_data: expect.objectContaining({
|
||||
source: 'sam2_propagation',
|
||||
propagated_from_frame_id: 1,
|
||||
source_annotation_id: 7,
|
||||
source_mask_id: 'annotation-7',
|
||||
propagation_seed_key: 'annotation:7',
|
||||
}),
|
||||
}));
|
||||
expect(buildAnnotationPayload('9', {
|
||||
id: 'm1',
|
||||
frameId: '5',
|
||||
pathData: 'M 10 10 L 90 10 L 90 40 Z',
|
||||
label: 'Tracked',
|
||||
color: '#22c55e',
|
||||
segmentation: [[10, 10, 90, 10, 90, 40]],
|
||||
metadata: { geometry_smoothing_preview: { strength: 35, method: 'chaikin' } },
|
||||
}, frame)?.mask_data).not.toHaveProperty('geometry_smoothing_preview');
|
||||
});
|
||||
|
||||
it('normalizes positive and negative point prompts for AI prediction', async () => {
|
||||
const { predictMask } = await import('./api');
|
||||
axiosMock.client.post.mockResolvedValueOnce({
|
||||
|
||||
134
src/lib/api.ts
134
src/lib/api.ts
@@ -1,6 +1,7 @@
|
||||
import axios, { AxiosError } from 'axios';
|
||||
import { DEFAULT_AI_MODEL_ID, type AiModelId, type Frame, type Mask, type Project, type Template } from '../store/useStore';
|
||||
import { DEFAULT_AI_MODEL_ID, type AiModelId, type Frame, type Mask, type Project, type Template, type UserProfile } from '../store/useStore';
|
||||
import { API_BASE_URL } from './config';
|
||||
import { normalizeClassMaskIds } from './maskIds';
|
||||
|
||||
const apiClient = axios.create({
|
||||
baseURL: API_BASE_URL,
|
||||
@@ -28,18 +29,88 @@ apiClient.interceptors.response.use(
|
||||
(error: AxiosError) => {
|
||||
if (error.response?.status === 401) {
|
||||
localStorage.removeItem('token');
|
||||
window.location.reload();
|
||||
if (!error.config?.url?.includes('/api/auth/login')) {
|
||||
window.location.reload();
|
||||
}
|
||||
}
|
||||
return Promise.reject(error);
|
||||
}
|
||||
);
|
||||
|
||||
// Auth
|
||||
export async function login(username: string, password: string): Promise<{ token: string }> {
|
||||
export async function login(username: string, password: string): Promise<{ token: string; username: string; user: UserProfile }> {
|
||||
const response = await apiClient.post('/api/auth/login', { username, password });
|
||||
return response.data;
|
||||
}
|
||||
|
||||
export async function getCurrentUser(): Promise<UserProfile> {
|
||||
const response = await apiClient.get('/api/auth/me');
|
||||
return response.data;
|
||||
}
|
||||
|
||||
export interface AdminUser extends UserProfile {
|
||||
is_active: number;
|
||||
}
|
||||
|
||||
export interface AuditLog {
|
||||
id: number;
|
||||
actor_user_id?: number | null;
|
||||
action: string;
|
||||
target_type?: string | null;
|
||||
target_id?: string | null;
|
||||
detail?: Record<string, any> | null;
|
||||
created_at: string;
|
||||
}
|
||||
|
||||
export interface DemoFactoryResetResult {
|
||||
admin_user: AdminUser;
|
||||
project: Project;
|
||||
deleted_counts: Record<string, number>;
|
||||
message: string;
|
||||
}
|
||||
|
||||
export async function getAdminUsers(): Promise<AdminUser[]> {
|
||||
const response = await apiClient.get('/api/admin/users');
|
||||
return response.data;
|
||||
}
|
||||
|
||||
export async function createAdminUser(payload: {
|
||||
username: string;
|
||||
password: string;
|
||||
role: string;
|
||||
is_active: boolean;
|
||||
}): Promise<AdminUser> {
|
||||
const response = await apiClient.post('/api/admin/users', payload);
|
||||
return response.data;
|
||||
}
|
||||
|
||||
export async function updateAdminUser(id: number, payload: {
|
||||
username?: string;
|
||||
password?: string;
|
||||
role?: string;
|
||||
is_active?: boolean;
|
||||
}): Promise<AdminUser> {
|
||||
const response = await apiClient.patch(`/api/admin/users/${id}`, payload);
|
||||
return response.data;
|
||||
}
|
||||
|
||||
export async function deleteAdminUser(id: number): Promise<void> {
|
||||
await apiClient.delete(`/api/admin/users/${id}`);
|
||||
}
|
||||
|
||||
export async function getAuditLogs(limit = 100): Promise<AuditLog[]> {
|
||||
const response = await apiClient.get('/api/admin/audit-logs', { params: { limit } });
|
||||
return response.data;
|
||||
}
|
||||
|
||||
export async function resetDemoFactory(confirmation: string): Promise<DemoFactoryResetResult> {
|
||||
const response = await apiClient.post('/api/admin/demo-factory-reset', { confirmation });
|
||||
return {
|
||||
...response.data,
|
||||
project: mapProject(response.data.project),
|
||||
};
|
||||
}
|
||||
|
||||
// Projects
|
||||
function normalizeProjectStatus(status?: string): Project['status'] {
|
||||
const value = (status || 'pending').toLowerCase();
|
||||
@@ -103,7 +174,7 @@ function _mapTemplate(t: any): Template {
|
||||
id: String(t.id),
|
||||
name: t.name,
|
||||
description: t.description,
|
||||
classes: mapping.classes || [],
|
||||
classes: normalizeClassMaskIds(mapping.classes || []),
|
||||
rules: mapping.rules || [],
|
||||
createdAt: t.created_at,
|
||||
updatedAt: t.updated_at,
|
||||
@@ -120,7 +191,7 @@ export async function createTemplate(payload: {
|
||||
description?: string;
|
||||
color: string;
|
||||
z_index: number;
|
||||
classes?: { name: string; color: string; zIndex: number; category?: string }[];
|
||||
classes?: { name: string; color: string; zIndex: number; maskId?: number; category?: string }[];
|
||||
rules?: any[];
|
||||
}): Promise<Template> {
|
||||
const response = await apiClient.post('/api/templates', payload);
|
||||
@@ -298,6 +369,7 @@ export interface SavedAnnotation {
|
||||
name?: string;
|
||||
color?: string;
|
||||
zIndex?: number;
|
||||
maskId?: number;
|
||||
category?: string;
|
||||
};
|
||||
source?: string;
|
||||
@@ -326,6 +398,7 @@ export interface SaveAnnotationPayload {
|
||||
name?: string;
|
||||
color?: string;
|
||||
zIndex?: number;
|
||||
maskId?: number;
|
||||
category?: string;
|
||||
};
|
||||
geometry_smoothing?: GeometrySmoothingOptions;
|
||||
@@ -505,6 +578,16 @@ function normalizeGeometrySmoothing(value: unknown): GeometrySmoothingOptions |
|
||||
};
|
||||
}
|
||||
|
||||
function persistableMaskMetadata(metadata?: Record<string, unknown>): Record<string, unknown> {
|
||||
if (!metadata) return {};
|
||||
const {
|
||||
geometry_smoothing: _geometrySmoothing,
|
||||
geometry_smoothing_preview: _geometrySmoothingPreview,
|
||||
...rest
|
||||
} = metadata;
|
||||
return rest;
|
||||
}
|
||||
|
||||
function pixelSegmentationToNormalizedPolygons(
|
||||
segmentation: number[][] | undefined,
|
||||
width: number,
|
||||
@@ -534,21 +617,24 @@ export function buildAnnotationPayload(
|
||||
const polygons = pixelSegmentationToNormalizedPolygons(mask.segmentation, frame.width, frame.height);
|
||||
if (polygons.length === 0) return null;
|
||||
const effectiveTemplateId = mask.templateId || templateId || undefined;
|
||||
const classMetadata = mask.classId || mask.className || mask.classZIndex !== undefined
|
||||
const classMetadata = mask.classId || mask.className || mask.classZIndex !== undefined || mask.classMaskId !== undefined
|
||||
? {
|
||||
id: mask.classId,
|
||||
name: mask.className || mask.label,
|
||||
color: mask.color,
|
||||
zIndex: mask.classZIndex,
|
||||
maskId: mask.classMaskId,
|
||||
}
|
||||
: undefined;
|
||||
const geometrySmoothing = normalizeGeometrySmoothing(mask.metadata?.geometry_smoothing);
|
||||
const metadata = persistableMaskMetadata(mask.metadata);
|
||||
|
||||
const payload: SaveAnnotationPayload = {
|
||||
project_id: Number(projectId),
|
||||
frame_id: Number(frame.id),
|
||||
template_id: effectiveTemplateId ? Number(effectiveTemplateId) : undefined,
|
||||
mask_data: {
|
||||
...metadata,
|
||||
polygons,
|
||||
label: mask.label,
|
||||
color: mask.color,
|
||||
@@ -591,6 +677,7 @@ export function annotationToMask(annotation: SavedAnnotation, frame: Frame): Mas
|
||||
classId: classMetadata?.id,
|
||||
className: classMetadata?.name,
|
||||
classZIndex: classMetadata?.zIndex,
|
||||
classMaskId: classMetadata?.maskId,
|
||||
saveStatus: 'saved',
|
||||
saved: true,
|
||||
pathData: polygonToPath(firstPolygon, frame.width, frame.height),
|
||||
@@ -785,12 +872,14 @@ export async function importGtMask(
|
||||
projectId: string,
|
||||
frameId: string,
|
||||
templateId?: string | null,
|
||||
options: { unknownColorPolicy?: 'discard' | 'undefined' } = {},
|
||||
): Promise<SavedAnnotation[]> {
|
||||
const formData = new FormData();
|
||||
formData.append('file', file);
|
||||
formData.append('project_id', projectId);
|
||||
formData.append('frame_id', frameId);
|
||||
if (templateId) formData.append('template_id', templateId);
|
||||
if (options.unknownColorPolicy) formData.append('unknown_color_policy', options.unknownColorPolicy);
|
||||
const response = await apiClient.post('/api/ai/import-gt-mask', formData, {
|
||||
headers: { 'Content-Type': 'multipart/form-data' },
|
||||
});
|
||||
@@ -817,4 +906,37 @@ export async function exportMasks(projectId: string): Promise<Blob> {
|
||||
return response.data;
|
||||
}
|
||||
|
||||
export type SegmentationExportScope = 'all' | 'range' | 'current';
|
||||
export type SegmentationMaskType = 'separate' | 'gt_label' | 'both';
|
||||
export type SegmentationExportOutput = 'separate' | 'gt_label' | 'pro_label' | 'mix_label';
|
||||
|
||||
export interface ExportSegmentationResultsOptions {
|
||||
scope: SegmentationExportScope;
|
||||
maskType?: SegmentationMaskType;
|
||||
outputs?: SegmentationExportOutput[];
|
||||
mixOpacity?: number;
|
||||
startFrame?: number;
|
||||
endFrame?: number;
|
||||
frameId?: string;
|
||||
}
|
||||
|
||||
export async function exportSegmentationResults(
|
||||
projectId: string,
|
||||
options: ExportSegmentationResultsOptions,
|
||||
): Promise<Blob> {
|
||||
const response = await apiClient.get(`/api/export/${projectId}/results`, {
|
||||
params: {
|
||||
scope: options.scope,
|
||||
mask_type: options.maskType,
|
||||
outputs: options.outputs?.join(','),
|
||||
mix_opacity: options.mixOpacity,
|
||||
start_frame: options.startFrame,
|
||||
end_frame: options.endFrame,
|
||||
frame_id: options.frameId ? Number(options.frameId) : undefined,
|
||||
},
|
||||
responseType: 'blob',
|
||||
});
|
||||
return response.data;
|
||||
}
|
||||
|
||||
export default apiClient;
|
||||
|
||||
34
src/lib/maskIds.ts
Normal file
34
src/lib/maskIds.ts
Normal file
@@ -0,0 +1,34 @@
|
||||
import type { TemplateClass } from '../store/useStore';
|
||||
|
||||
export function normalizeClassMaskIds(classes: TemplateClass[] = []): TemplateClass[] {
|
||||
const used = new Set<number>();
|
||||
let nextMaskId = 1;
|
||||
|
||||
const nextAvailableMaskId = () => {
|
||||
while (used.has(nextMaskId)) nextMaskId += 1;
|
||||
const value = nextMaskId;
|
||||
used.add(value);
|
||||
nextMaskId += 1;
|
||||
return value;
|
||||
};
|
||||
|
||||
return classes.map((templateClass) => {
|
||||
const parsed = Number(templateClass.maskId);
|
||||
if (Number.isInteger(parsed) && parsed > 0 && !used.has(parsed)) {
|
||||
used.add(parsed);
|
||||
return { ...templateClass, maskId: parsed };
|
||||
}
|
||||
return { ...templateClass, maskId: nextAvailableMaskId() };
|
||||
});
|
||||
}
|
||||
|
||||
export function nextClassMaskId(classes: TemplateClass[] = []): number {
|
||||
const used = new Set(
|
||||
classes
|
||||
.map((templateClass) => Number(templateClass.maskId))
|
||||
.filter((value) => Number.isInteger(value) && value > 0),
|
||||
);
|
||||
let value = 1;
|
||||
while (used.has(value)) value += 1;
|
||||
return value;
|
||||
}
|
||||
@@ -94,4 +94,34 @@ describe('progress websocket client', () => {
|
||||
unsubscribeStatus();
|
||||
progressWS.disconnect();
|
||||
});
|
||||
|
||||
it('does not reconnect after an intentional disconnect', async () => {
|
||||
vi.useFakeTimers();
|
||||
const instances: any[] = [];
|
||||
class FakeWebSocket {
|
||||
static CONNECTING = 0;
|
||||
static OPEN = 1;
|
||||
readyState = FakeWebSocket.OPEN;
|
||||
onopen?: () => void;
|
||||
onmessage?: (event: MessageEvent) => void;
|
||||
onclose?: () => void;
|
||||
onerror?: () => void;
|
||||
constructor(public url: string) {
|
||||
instances.push(this);
|
||||
}
|
||||
close = vi.fn();
|
||||
send = vi.fn();
|
||||
}
|
||||
vi.stubGlobal('WebSocket', FakeWebSocket);
|
||||
|
||||
const { progressWS } = await import('./websocket');
|
||||
progressWS.connect();
|
||||
instances[0].onopen?.();
|
||||
progressWS.disconnect();
|
||||
instances[0].onclose?.();
|
||||
|
||||
vi.advanceTimersByTime(30000);
|
||||
|
||||
expect(instances).toHaveLength(1);
|
||||
});
|
||||
});
|
||||
|
||||
@@ -30,6 +30,7 @@ class ProgressWebSocket {
|
||||
private heartbeatInterval = 15000;
|
||||
private shouldReconnect = false;
|
||||
private shouldCloseAfterOpen = false;
|
||||
private manualDisconnect = false;
|
||||
private currentInterval = 3000;
|
||||
|
||||
constructor(url = WS_PROGRESS_URL) {
|
||||
@@ -43,6 +44,7 @@ class ProgressWebSocket {
|
||||
|
||||
this.shouldReconnect = true;
|
||||
this.shouldCloseAfterOpen = false;
|
||||
this.manualDisconnect = false;
|
||||
this.notifyStatus('connecting');
|
||||
|
||||
try {
|
||||
@@ -71,7 +73,9 @@ class ProgressWebSocket {
|
||||
};
|
||||
|
||||
this.ws.onclose = () => {
|
||||
console.log('[WebSocket] Connection closed');
|
||||
if (!this.manualDisconnect) {
|
||||
console.log('[WebSocket] Connection closed');
|
||||
}
|
||||
this.stopHeartbeat();
|
||||
this.ws = null;
|
||||
this.notifyStatus('disconnected');
|
||||
@@ -97,6 +101,7 @@ class ProgressWebSocket {
|
||||
|
||||
disconnect() {
|
||||
this.shouldReconnect = false;
|
||||
this.manualDisconnect = true;
|
||||
this.stopHeartbeat();
|
||||
if (this.reconnectTimer) {
|
||||
clearTimeout(this.reconnectTimer);
|
||||
|
||||
@@ -8,15 +8,17 @@ describe('useStore', () => {
|
||||
});
|
||||
|
||||
it('stores and clears auth state with localStorage', () => {
|
||||
useStore.getState().login('token-1');
|
||||
useStore.getState().login('token-1', { id: 1, username: 'admin', role: 'admin' });
|
||||
|
||||
expect(useStore.getState().isAuthenticated).toBe(true);
|
||||
expect(useStore.getState().token).toBe('token-1');
|
||||
expect(useStore.getState().currentUser?.username).toBe('admin');
|
||||
expect(localStorage.getItem('token')).toBe('token-1');
|
||||
|
||||
useStore.getState().logout();
|
||||
|
||||
expect(useStore.getState().isAuthenticated).toBe(false);
|
||||
expect(useStore.getState().currentUser).toBeNull();
|
||||
expect(useStore.getState().projects).toEqual([]);
|
||||
expect(useStore.getState().frames).toEqual([]);
|
||||
expect(localStorage.getItem('token')).toBeNull();
|
||||
@@ -32,6 +34,8 @@ describe('useStore', () => {
|
||||
useStore.getState().addMask({ id: 'm1', frameId: 'f1', pathData: 'M 0 0 Z', label: 'mask', color: '#fff' });
|
||||
useStore.getState().setSelectedMaskIds(['m1']);
|
||||
useStore.getState().setMaskPreviewOpacity(35);
|
||||
useStore.getState().setBrushSize(36);
|
||||
useStore.getState().setEraserSize(44);
|
||||
useStore.getState().updateMask('m1', { label: 'updated mask', saveStatus: 'dirty' });
|
||||
useStore.getState().addAnnotation({ id: 'a1', frameId: 'f1', type: 'mask', points: [], label: 'ann', color: '#fff' });
|
||||
useStore.getState().addTemplate({ id: 't1', name: 'Template', classes: [], rules: [] });
|
||||
@@ -44,6 +48,8 @@ describe('useStore', () => {
|
||||
expect(useStore.getState().currentFrameIndex).toBe(0);
|
||||
expect(useStore.getState().selectedMaskIds).toEqual(['m1']);
|
||||
expect(useStore.getState().maskPreviewOpacity).toBe(35);
|
||||
expect(useStore.getState().brushSize).toBe(36);
|
||||
expect(useStore.getState().eraserSize).toBe(44);
|
||||
expect(useStore.getState().masks[0]).toEqual(expect.objectContaining({ label: 'updated mask', saveStatus: 'dirty' }));
|
||||
expect(useStore.getState().annotations).toHaveLength(1);
|
||||
expect(useStore.getState().templates[0].name).toBe('Template 2');
|
||||
|
||||
@@ -24,6 +24,8 @@ export type AiModelId =
|
||||
| 'sam2.1_hiera_large';
|
||||
|
||||
export const DEFAULT_AI_MODEL_ID: AiModelId = 'sam2.1_hiera_tiny';
|
||||
export const DEFAULT_BRUSH_SIZE = 24;
|
||||
export const DEFAULT_ERASER_SIZE = 28;
|
||||
|
||||
export const SAM2_MODEL_OPTIONS: Array<{ id: AiModelId; label: string; shortLabel: string }> = [
|
||||
{ id: 'sam2.1_hiera_tiny', label: 'SAM 2.1 Tiny', shortLabel: 'tiny' },
|
||||
@@ -64,6 +66,7 @@ export interface Mask {
|
||||
classId?: string;
|
||||
className?: string;
|
||||
classZIndex?: number;
|
||||
classMaskId?: number;
|
||||
saveStatus?: 'draft' | 'saved' | 'dirty' | 'saving' | 'error';
|
||||
saved?: boolean;
|
||||
pathData: string;
|
||||
@@ -92,6 +95,7 @@ export interface TemplateClass {
|
||||
name: string;
|
||||
color: string;
|
||||
zIndex: number;
|
||||
maskId?: number;
|
||||
category?: string;
|
||||
description?: string;
|
||||
}
|
||||
@@ -104,11 +108,20 @@ export interface TemplateRule {
|
||||
operation: string;
|
||||
}
|
||||
|
||||
export interface UserProfile {
|
||||
id: number;
|
||||
username: string;
|
||||
role: string;
|
||||
is_active?: number;
|
||||
}
|
||||
|
||||
export interface AppState {
|
||||
// Auth
|
||||
isAuthenticated: boolean;
|
||||
token: string | null;
|
||||
login: (token: string) => void;
|
||||
currentUser: UserProfile | null;
|
||||
login: (token: string, user?: UserProfile | null) => void;
|
||||
setCurrentUser: (user: UserProfile | null) => void;
|
||||
logout: () => void;
|
||||
|
||||
// Projects
|
||||
@@ -129,6 +142,8 @@ export interface AppState {
|
||||
masks: Mask[];
|
||||
selectedMaskIds: string[];
|
||||
maskPreviewOpacity: number;
|
||||
brushSize: number;
|
||||
eraserSize: number;
|
||||
maskHistory: Mask[][];
|
||||
maskFuture: Mask[][];
|
||||
setActiveModule: (module: string) => void;
|
||||
@@ -142,6 +157,8 @@ export interface AppState {
|
||||
setMasks: (masks: Mask[]) => void;
|
||||
setSelectedMaskIds: (ids: string[]) => void;
|
||||
setMaskPreviewOpacity: (opacity: number) => void;
|
||||
setBrushSize: (size: number) => void;
|
||||
setEraserSize: (size: number) => void;
|
||||
clearMasks: () => void;
|
||||
undoMasks: () => void;
|
||||
redoMasks: () => void;
|
||||
@@ -169,17 +186,20 @@ export interface AppState {
|
||||
|
||||
export const useStore = create<AppState>((set) => ({
|
||||
// Auth
|
||||
isAuthenticated: false,
|
||||
token: null,
|
||||
login: (token: string) => {
|
||||
isAuthenticated: Boolean(localStorage.getItem('token')),
|
||||
token: localStorage.getItem('token'),
|
||||
currentUser: null,
|
||||
login: (token: string, user: UserProfile | null = null) => {
|
||||
localStorage.setItem('token', token);
|
||||
set({ isAuthenticated: true, token });
|
||||
set({ isAuthenticated: true, token, currentUser: user });
|
||||
},
|
||||
setCurrentUser: (currentUser: UserProfile | null) => set({ currentUser }),
|
||||
logout: () => {
|
||||
localStorage.removeItem('token');
|
||||
set({
|
||||
isAuthenticated: false,
|
||||
token: null,
|
||||
currentUser: null,
|
||||
currentProject: null,
|
||||
projects: [],
|
||||
templates: [],
|
||||
@@ -188,6 +208,8 @@ export const useStore = create<AppState>((set) => ({
|
||||
masks: [],
|
||||
selectedMaskIds: [],
|
||||
maskPreviewOpacity: 50,
|
||||
brushSize: DEFAULT_BRUSH_SIZE,
|
||||
eraserSize: DEFAULT_ERASER_SIZE,
|
||||
maskHistory: [],
|
||||
maskFuture: [],
|
||||
activeTemplateId: null,
|
||||
@@ -218,6 +240,8 @@ export const useStore = create<AppState>((set) => ({
|
||||
masks: [],
|
||||
selectedMaskIds: [],
|
||||
maskPreviewOpacity: 50,
|
||||
brushSize: DEFAULT_BRUSH_SIZE,
|
||||
eraserSize: DEFAULT_ERASER_SIZE,
|
||||
maskHistory: [],
|
||||
maskFuture: [],
|
||||
setActiveModule: (activeModule: string) => set({ activeModule }),
|
||||
@@ -254,6 +278,12 @@ export const useStore = create<AppState>((set) => ({
|
||||
setMaskPreviewOpacity: (maskPreviewOpacity: number) => set({
|
||||
maskPreviewOpacity: Math.min(Math.max(maskPreviewOpacity, 10), 100),
|
||||
}),
|
||||
setBrushSize: (brushSize: number) => set({
|
||||
brushSize: Math.round(Math.min(Math.max(brushSize, 4), 96)),
|
||||
}),
|
||||
setEraserSize: (eraserSize: number) => set({
|
||||
eraserSize: Math.round(Math.min(Math.max(eraserSize, 4), 128)),
|
||||
}),
|
||||
clearMasks: () =>
|
||||
set((state) => ({
|
||||
masks: [],
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
import { DEFAULT_AI_MODEL_ID, useStore } from '../store/useStore';
|
||||
import { DEFAULT_AI_MODEL_ID, DEFAULT_BRUSH_SIZE, DEFAULT_ERASER_SIZE, useStore } from '../store/useStore';
|
||||
|
||||
export function resetStore() {
|
||||
useStore.setState({
|
||||
isAuthenticated: false,
|
||||
token: null,
|
||||
currentUser: null,
|
||||
projects: [],
|
||||
currentProject: null,
|
||||
activeModule: 'workspace',
|
||||
@@ -15,6 +16,8 @@ export function resetStore() {
|
||||
masks: [],
|
||||
selectedMaskIds: [],
|
||||
maskPreviewOpacity: 50,
|
||||
brushSize: DEFAULT_BRUSH_SIZE,
|
||||
eraserSize: DEFAULT_ERASER_SIZE,
|
||||
maskHistory: [],
|
||||
maskFuture: [],
|
||||
templates: [],
|
||||
|
||||
Reference in New Issue
Block a user