feat: complete the full-stack annotation loop, async frame extraction, and model status

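Backend:

- Add the Celery app, worker task, ProcessingTask model, /api/tasks query endpoints, and media_task_runner; /api/media/parse now creates a background task, and the worker runs FFmpeg/OpenCV/pydicom frame extraction.

- Add a Redis progress-event module and a FastAPI Redis pub/sub subscription that broadcasts worker task progress to /ws/progress; the Dashboard overview endpoint now aggregates projects/frames/annotations/templates/processing_tasks.

- Unify project status as pending/parsing/ready/error, add shared status constants, and have the frontend normalize legacy status values.

- Extend the AI backend: add the SAM registry, real SAM2 runtime status, SAM3 status detection with a text-prompt semantic inference adapter entry point, and the /api/ai/models/status endpoint for GPU/model status.

- Fill in the backend contracts for saving/updating/deleting annotations and for COCO/PNG mask export, plus the pack/unpack behavior of template mapping_rules.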

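Frontend:

- Add runtime API/WS address derivation config; align the frontend API wrapper with FastAPI routes, field mappings, task polling, annotation archiving, export downloads, and AI prediction response conversion.

- The Dashboard now reads /api/dashboard/overview and subscribes to WebSocket progress/complete/error/status messages to update the parsing queue and the live activity log.

- After importing a video/DICOM, the project library creates the project, uploads the media, triggers async parsing, and refreshes the real project list.

- The workspace loads real frames, triggers a parsing task when no frames exist, restores saved annotations, saves unarchived masks, updates dirty masks, clears backend annotations for the current frame, and exports COCO JSON.

- Canvas supports point/box prompts on the current frame that call the backend AI, renders inferred/saved masks, applies template classes, and maintains the saved-state count; the timeline plays at the project fps.

- The AI page adds SAM2/SAM3 model selection, and prediction requests carry model; the sidebar and workspace add real GPU/SAM status badges.

- The template library and ontology panel are wired to real template CRUD, class editing, drag-and-drop reordering, JSON import, default laparoscopy classes, and local custom class selection.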

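Tests and docs:

- Add the Vitest config, frontend test setup, and API/config/websocket/store/component tests covering login, project library, Dashboard, Canvas, workspace, model status, timeline, ontology, and template library.

- Add pytest backend fixtures and auth/projects/templates/media/AI/export/dashboard/tasks/progress tests; SQLite, a fake MinIO, a fake SAM registry, and a Redis monkeypatch isolate external services.

- Add the doc/ documentation structure, freezing the current requirements, design, API contracts, test plan, per-element frontend audit, implementation map, and follow-up implementation plan; update README and AGENTS to match.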

Verification:

- conda run -n seg_server pytest backend/tests: 27 passed.

- npm run test:run: 54 passed.

- npm run lint, npm run build, compileall, and git diff --check all pass; Vite only reports a large-chunk warning.
2026-05-01 13:29:14 +08:00
parent 4d65c37c73
commit f020ff3b4f
78 changed files with 7089 additions and 456 deletions


@@ -7,3 +7,16 @@ GEMINI_API_KEY="MY_GEMINI_API_KEY"
 # AI Studio automatically injects this at runtime with the Cloud Run service URL.
 # Used for self-referential links, OAuth callbacks, and API endpoints.
 APP_URL="MY_APP_URL"
+# Frontend API address. If unset, the app infers http://<current-browser-host>:8000.
+VITE_API_BASE_URL="http://192.168.3.11:8000"
+# Optional WebSocket override. If unset, it is derived from VITE_API_BASE_URL.
+VITE_WS_PROGRESS_URL="ws://192.168.3.11:8000/ws/progress"
+# Backend SAM runtime defaults. SAM 3 additionally requires the official sam3
+# package, Python 3.12+, PyTorch 2.7+, and a CUDA-capable GPU per Meta's repo.
+sam_default_model="sam2"
+sam_model_path="/home/wkmgc/Desktop/Seg_Server/models/sam2_hiera_tiny.pt"
+sam_model_config="configs/sam2/sam2_hiera_t.yaml"
+sam3_model_version="sam3.1"
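
The fallback rule stated in the `.env.example` comments (infer `http://<current-browser-host>:8000` when `VITE_API_BASE_URL` is unset, and derive the WebSocket URL from the API base) can be sketched as follows. This is an illustrative Python rendering under stated assumptions, not the contents of `src/lib/config.ts`: the function names, the `env` dictionary, and the `http`→`ws` / `https`→`wss` scheme mapping are hypothetical.

```python
from urllib.parse import urlsplit, urlunsplit


def resolve_api_base(env: dict[str, str], browser_host: str) -> str:
    """Return the configured API base, or infer http://<host>:8000."""
    configured = env.get("VITE_API_BASE_URL", "").strip()
    if configured:
        return configured.rstrip("/")
    return f"http://{browser_host}:8000"


def resolve_ws_progress_url(env: dict[str, str], browser_host: str) -> str:
    """Return the WebSocket override, or derive it from the API base."""
    override = env.get("VITE_WS_PROGRESS_URL", "").strip()
    if override:
        return override
    parts = urlsplit(resolve_api_base(env, browser_host))
    # Assumption: an https API base implies a wss progress socket.
    ws_scheme = "wss" if parts.scheme == "https" else "ws"
    return urlunsplit((ws_scheme, parts.netloc, "/ws/progress", "", ""))
```

With both variables unset and the page served from `localhost`, the sketch resolves to `http://localhost:8000` and `ws://localhost:8000/ws/progress`.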

322
AGENTS.md

@@ -1,17 +1,21 @@
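 # AGENTS.md: Project Guide for AI Coding Assistants
-> This file is written for AI coding assistants. Readers should assume no prior knowledge of the project. All information below is based on the project's actual contents, with no speculative inference.
+> This file is written for AI coding assistants. Readers should assume no prior knowledge of the project. The information below is based on the actual files, scripts, and source code in the current repository; do not treat early design goals as implemented facts. Every code or feature change must be reflected in the docs and tests. When generating a git commit message, list every change as its own bullet, with major changes first and minor ones last.
 ---
 ## Project Overview
-This project is a **Semantic Segmentation System** web frontend application for AI-driven image/video segmentation and annotation. It provides a dark-mode single-page application (SPA) with four core modules: project management, segmentation workspace, AI segmentation engine, and template library.
+This project is a **Semantic Segmentation System**. In its current form it is a full-stack web application (React frontend + FastAPI backend) for video/DICOM medical image upload, server-side frame extraction, interactive Canvas annotation, optional SAM 2/SAM 3 assisted segmentation, template class management, and annotation export.
-- **Project name**: `react-example` (the name in package.json)
-- **Deployment target**: Google AI Studio (Cloud Run)
-- **AI Studio app link**: https://ai.studio/apps/2707f0e1-d453-4594-a618-fba53cb937c4
-- **Business document**: `语义分割系统构建方案.docx` (project root, contents not parsed)
+- **Project name**: `react-example` (the `name` in `package.json`)
+- **Frontend entry**: `src/main.tsx`, `src/App.tsx`
+- **Frontend server entry**: `server.ts` (Express + Vite middleware / production static serving; keeps a few legacy mock APIs)
+- **Backend entry**: `backend/main.py` (FastAPI)
+- **Default frontend address**: `http://localhost:3000`
+- **Default backend address**: `http://localhost:8000`
+- **Frontend API config**: `src/lib/config.ts`; reads `VITE_API_BASE_URL` first and otherwise infers `http://<host>:8000` from the current browser hostname
+- **Business document**: `语义分割系统构建方案.docx` (project root)
 ---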
@@ -21,15 +25,22 @@
 |------|------|
 | Frontend framework | React 19 + TypeScript 5.8 |
 | Build tool | Vite 6 |
-| Styling | TailwindCSS 4 + custom dark theme |
-| State management | React `useState` (no external state library) |
-| Routing | No router library; modules are switched via React state |
-| Canvas rendering | Konva + react-konva |
+| Frontend styling | TailwindCSS 4 + custom dark theme |
+| Frontend state | Zustand (`src/store/useStore.ts`) |
+| Frontend requests | Axios (`src/lib/api.ts`) |
+| Real-time communication | WebSocket client (`src/lib/websocket.ts`) |
+| Canvas rendering | Konva + react-konva + use-image |
 | Icon library | lucide-react |
-| Animation | motion |
-| AI SDK | @google/genai (Gemini API) |
-| Backend/server | Express 4 (single file `server.ts`) |
-| Runtime | Node.js (ES Modules, `"type": "module"`) |
+| Animation dependency | motion (declared in `package.json`) |
+| AI SDK dependency | `@google/genai` (declared in `package.json`; not called directly by the current application code) |
+| Backend framework | FastAPI + Uvicorn |
+| ORM / database | SQLAlchemy + PostgreSQL |
+| Cache / queue broker | Redis |
+| Background tasks | Celery worker |
+| Object storage | MinIO |
+| AI inference | Optional SAM 2 / SAM 3 models + PyTorch; `GET /api/ai/models/status` returns the real GPU/model status |
+| Video / image processing | FFmpeg / OpenCV / pydicom |
+| Runtime | Node.js ES Modules; Python 3.11 backend environment |
 ---
@@ -37,159 +48,244 @@
 ```
 Seg_Server/
-├── server.ts                 # Express server entry (dev server + production static file serving)
+├── server.ts                 # Express + Vite frontend entry; keeps the /api/login, /api/projects, /api/templates mocks
 ├── index.html                # SPA HTML entry
-├── vite.config.ts            # Vite build config
-├── tsconfig.json             # TypeScript config
-├── package.json              # Dependencies and scripts
-├── .env.example              # Environment variable template
-├── metadata.json             # AI Studio metadata (currently empty)
-├── src/
-│   ├── main.tsx              # React mount point (StrictMode)
-│   ├── App.tsx               # Root component: module routing + login guard
-│   ├── index.css             # TailwindCSS import + custom utility classes
-│   ├── lib/
-│   │   └── utils.ts          # `cn()` utility (clsx + tailwind-merge)
-│   └── components/
-│       ├── auth/
-│       │   └── Login.tsx             # Login page
-│       ├── layout/
-│       │   └── Sidebar.tsx           # Left navigation bar (w-16)
-│       ├── dashboard/
-│       │   └── Dashboard.tsx         # Overview dashboard
-│       ├── projects/
-│       │   └── ProjectLibrary.tsx    # Project library list
-│       ├── workspace/
-│       │   ├── VideoWorkspace.tsx    # Core segmentation workspace layout
-│       │   ├── CanvasArea.tsx        # Konva canvas (zoom/pan/point selection)
-│       │   ├── ToolsPalette.tsx      # Left tool palette
-│       │   ├── OntologyInspector.tsx # Right ontology/property inspector panel
-│       │   └── FrameTimeline.tsx     # Bottom timeline
-│       ├── ai/
-│       │   └── AISegmentation.tsx    # AI segmentation engine UI
-│       └── templates/
-│           └── TemplateRegistry.tsx  # Template library management
+├── vite.config.ts            # Vite config; includes the @/* path alias and the DISABLE_HMR logic
+├── tsconfig.json             # TypeScript config; @/* maps to the project root
+├── package.json              # npm dependencies and scripts
+├── .env.example              # AI Studio/Gemini frontend environment variable template
+├── metadata.json             # AI Studio metadata
+├── public/
+│   └── logo.png              # /logo.png used by the Sidebar
+├── doc/                      # Current implementation audit, API contracts, and follow-up implementation docs
+├── start_services.sh         # One-command local startup for PostgreSQL/Redis/MinIO/FastAPI/Celery/frontend
+├── backend/                  # FastAPI backend
+│   ├── main.py               # App entry, lifespan, CORS, router registration, WebSocket
+│   ├── config.py             # Pydantic Settings; reads backend/.env
+│   ├── database.py           # SQLAlchemy Engine / Session
+│   ├── models.py             # Project/Frame/Template/Annotation/Mask/ProcessingTask ORM
+│   ├── schemas.py            # Pydantic request/response models
+│   ├── minio_client.py       # MinIO upload, download, presigned URLs
+│   ├── redis_client.py       # Redis connection wrapper
+│   ├── celery_app.py         # Celery app config
+│   ├── worker_tasks.py       # Celery task entry points
+│   ├── download_sam2.py      # SAM 2 weight download script
+│   ├── requirements.txt      # Python dependencies
+│   ├── routers/
+│   │   ├── auth.py           # /api/auth/login
+│   │   ├── projects.py       # /api/projects and /api/projects/{id}/frames
+│   │   ├── templates.py      # /api/templates
+│   │   ├── media.py          # /api/media/upload, /upload/dicom, /parse
+│   │   ├── ai.py             # /api/ai/predict, /models/status, /auto, /annotate
+│   │   └── export.py         # /api/export/{project_id}/coco, /masks
+│   └── services/
+│       ├── frame_parser.py   # FFmpeg/OpenCV frame extraction, pydicom reading, frame upload
+│       ├── sam2_engine.py    # Lazily loaded SAM 2 inference wrapper with fallback
+│       ├── sam3_engine.py    # SAM 3 status detection and text-prompt semantic inference adapter
+│       └── sam_registry.py   # SAM model selection, GPU status, and inference dispatch
+└── src/                      # React frontend
+    ├── main.tsx              # React StrictMode mount
+    ├── App.tsx               # Login guard + module switching
+    ├── index.css             # TailwindCSS import + global styles
+    ├── store/useStore.ts     # Zustand global state
+    ├── lib/api.ts            # Axios API wrapper
+    ├── lib/websocket.ts      # Parsing-progress WebSocket client
+    ├── lib/utils.ts          # cn() utility
+    └── components/           # Flat component directory
+        ├── Login.tsx
+        ├── Sidebar.tsx
+        ├── Dashboard.tsx
+        ├── ProjectLibrary.tsx
+        ├── VideoWorkspace.tsx
+        ├── CanvasArea.tsx
+        ├── ToolsPalette.tsx
+        ├── OntologyInspector.tsx
+        ├── FrameTimeline.tsx
+        ├── AISegmentation.tsx
+        └── TemplateRegistry.tsx
 ```
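+The following directories/files are usually build output or local data and are ignored in `.gitignore`: `node_modules/`, `dist/`, `models/`, `uploads/`, `frames/`, `Data_*/`, `*.mp4`, `*.dcm`, `*.7z`, `backend/.env`, log files, and so on.
+The `doc/` directory is the entry point for factual documentation of the current project. Before changing a feature, read these first:
+- `doc/03-frontend-element-audit.md`: which frontend elements are real features and which are Mock/UI-only.
+- `doc/04-api-contracts.md`: the frontend/backend API contracts and the current inconsistencies.
+- `doc/05-implementation-plan.md`: the suggested order of follow-up work.
 ---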
 ## Build and Run Commands
+### Frontend / Node entry
 ```bash
-# Install dependencies
 npm install
-# Development mode (starts Express + Vite middleware, port 3000)
+# Development mode: runs tsx server.ts (Express with Vite middleware, port 3000)
 npm run dev
 # Production build, output to dist/
 npm run build
-# Preview the production build
+# Vite preview
 npm run preview
-# Production start (Node runs server.ts directly; build first)
+# Production mode: runs server.ts and serves dist/; the legacy mock APIs in server.ts remain
 npm start
-# Type check (no files emitted)
+# TypeScript type check
 npm run lint
-# Clean build artifacts
+# Delete dist/
 npm run clean
 ```
-**Dev server address**: `http://localhost:3000`
-**Environment variables** (copy `.env.example` to `.env.local`):
-- `GEMINI_API_KEY`: Gemini AI API key (injected automatically by AI Studio)
-- `APP_URL`: app hosting URL (AI Studio injects the Cloud Run address automatically)
+### FastAPI backend
+```bash
+cd backend
+uvicorn main:app --host 0.0.0.0 --port 8000 --reload
+```
+### One-command startup
+```bash
+./start_services.sh
+```
+The script checks/starts PostgreSQL, Redis, MinIO, the FastAPI backend, the Celery worker, and the frontend, in that order.
 ---
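 ## Runtime Architecture
 ### Frontend
-- Single-page application mounted with React 19 `StrictMode`.
-- Module switching is controlled by the `activeModule` state in `App.tsx`; allowed values:
-  `'dashboard' | 'projects' | 'ai' | 'workspace' | 'templates'`
-- Defaults to `workspace` (the segmentation workspace).
-- When not logged in, all modules are blocked and the `Login` component is shown.
-### Backend (`server.ts`)
-- Express server on port `3000`
-- **Development mode**: integrates Vite middleware (`middlewareMode: true`)
-- **Production mode**: serves static files from `dist/`; all routes fall back to `index.html`
-- **API endpoints** (in-memory data store, no database):
-  - `POST /api/login`: authentication (fixed username `admin`, password `123456`)
-  - `GET /api/projects`: returns the project list
-  - `GET /api/templates`: returns the template list
-### Deployment
-- Targets **Google AI Studio** / **Cloud Run** deployment.
-- `metadata.json` holds the AI Studio metadata config (currently empty)
-- HMR in `vite.config.ts` can be disabled with the environment variable `DISABLE_HMR=true` (in AI Studio, file watching is disabled to prevent flicker while the agent edits).
+- Single-page application without a router library; module switching is controlled by `useStore().activeModule`.
+- Module values: `dashboard`, `projects`, `ai`, `workspace`, `templates`
+- The default module is `workspace`
+- When not logged in, `Login` is rendered
+- After a successful login the token is written to `localStorage`, and the Axios request interceptor attaches `Authorization: Bearer <token>`
+- After login, `App.tsx` calls `getProjects()` to initialize the project list.
+### Backend
+- The main backend is the FastAPI service in `backend/main.py`
+- On startup, `lifespan`: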
+  - creates the database tables;
+  - checks/creates the MinIO bucket `seg-media`;
+  - tests the Redis connection;
+  - seeds the default templates in the background;
+  - if `Data_MyVideo_1.mp4` exists locally, seeds the default demo project in the background and extracts the first 100 frames.
+- API routes:
+  - `POST /api/auth/login`
+  - `GET/POST/PATCH/DELETE /api/projects`
+  - `GET/POST /api/projects/{project_id}/frames`
+  - `GET/POST/PATCH/DELETE /api/templates`
+  - `POST /api/media/upload`
+  - `POST /api/media/upload/dicom`
+  - `POST /api/media/parse`
+  - `GET /api/tasks`
+  - `GET /api/tasks/{task_id}`
+  - `POST /api/ai/predict`
+  - `GET /api/ai/models/status`
+  - `POST /api/ai/auto`
+  - `POST /api/ai/annotate`
+  - `GET /api/ai/annotations`
+  - `PATCH/DELETE /api/ai/annotations/{annotation_id}`
+  - `GET /api/dashboard/overview`
+  - `GET /api/export/{project_id}/coco`
+  - `GET /api/export/{project_id}/masks`
+  - `GET /health`
+  - `WS /ws/progress`
+### Storage
+- PostgreSQL stores metadata for projects, frames, templates, annotations, masks, and background tasks.
+- MinIO stores uploaded videos, DICOM files, extracted frames, thumbnails, and other objects; the frontend displays them through presigned URLs.
+- Redis currently serves as the Celery broker/result backend and is also used for the connection check.
+---
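+## Main Business Flow
+1. Login: `Login.tsx` calls `POST /api/auth/login`; the default development credentials are `admin / 123456`.
+2. Project management: `ProjectLibrary.tsx` calls the project API to create projects and fetch the list.
+3. Upload: videos go through `/api/media/upload`; DICOM batches go through `/api/media/upload/dicom`.
+4. Enqueue frame extraction: the frontend calls `/api/media/parse`; the backend creates a `ProcessingTask` and dispatches it to Celery.
+5. Worker execution: the Celery worker extracts video frames with FFmpeg first and falls back to OpenCV on failure; DICOM uses pydicom; task progress is updated continuously.
+6. Frame display: `VideoWorkspace.tsx` calls `/api/projects/{id}/frames`; `CanvasArea.tsx` and `FrameTimeline.tsx` show the current frame and the timeline thumbnails.
+7. AI segmentation: the frontend tools are positive points, negative points, and box selection; the backend `ai.py` expects `image_id`, `prompt_type`, `prompt_data`, and `model`, and calls the SAM registry. SAM 2 supports point/box/auto segmentation; the SAM 3 entry point supports text-prompt semantic inference and is marked unavailable in the status endpoint when the runtime does not meet the official requirements.
+8. Template management: `TemplateRegistry.tsx` manages classes, colors, and z-index; `OntologyInspector.tsx` shows the current template's class tree in the workspace.
+9. Export: the backend supports COCO JSON and PNG mask ZIP export.
+---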
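+## Notes on the Current Implementation
+- `src/lib/config.ts` reads `VITE_API_BASE_URL` and `VITE_WS_PROGRESS_URL` first; when they are unset it infers the backend `:8000` address from the current browser hostname.
+- The frontend `predictMask()` sends `image_id`, `prompt_type`, `prompt_data`, and `model` as required by the backend `PredictRequest`, and converts the backend `polygons` into `pathData` that Konva can render.
+- The frontend `exportCoco()` is aligned with the backend `/api/export/{project_id}/coco`; the workspace "导出 JSON 标注集" (export JSON annotation set) button is bound to the download flow and saves any pending masks before exporting.
+- The workspace "结构化归档保存" (structured archive save) button is wired to `POST /api/ai/annotate` and `PATCH /api/ai/annotations/{id}`; when the workspace loads, saved annotations are restored via `GET /api/ai/annotations`.
+- The workspace "清空遮罩" (clear masks) action calls `DELETE /api/ai/annotations/{id}` to delete the saved annotations of the current frame and clears the current frame's local masks.
+- Project status is unified as `pending`, `parsing`, `ready`, `error`; the frontend `src/lib/api.ts` normalizes legacy values such as `Ready`, `Parsing`, and `Error` that may exist in older databases.
+- `server.ts` still has the legacy `/api/login`, `/api/projects`, `/api/templates` mocks; real frontend API calls mainly go to the FastAPI `/api/auth/*`, `/api/projects`, `/api/templates`, and related endpoints.
+- The initial statistics, queue, and activity log in `Dashboard.tsx` come from `GET /api/dashboard/overview`; the parsing queue comes from `processing_tasks`. The Celery worker pushes fine-grained progress through the Redis pub/sub `seg:progress` channel, and FastAPI broadcasts it to `/ws/progress`.
 ---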
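 ## Code Style and Conventions
 ### Styling
-- **Dark theme**: global backgrounds are mainly `#0a0a0a`, `#111`, `#0d0d0d`, `#151515`, `#1e1e1e`.
-- **Accent color**: cyan (`cyan-400`/`cyan-500`) for active states, buttons, and key indicators
-- **Utility-first**: TailwindCSS utility classes throughout, with conditional class names merged via `cn()`
-- **Custom utility class**: `index.css` defines `.no-scrollbar` to hide scrollbars
+- Dark theme by default; common background colors include `#0a0a0a`, `#111`, `#0d0d0d`, `#151515`, `#1e1e1e`
+- Cyan (e.g. `cyan-400` / `cyan-500`) is used for active states, primary buttons, and key metrics
+- Frontend styling mainly uses TailwindCSS utility classes, with conditional class names merged via `cn()`
+- `src/index.css` uses the TailwindCSS 4 `@import "tailwindcss";`
 ### Components
-- All components are **function components + Hooks**; there are no class components.
-- Components are grouped by feature module under `src/components/{module}/`
-- Props types are defined with TypeScript `interface`
-- Import order: React → third-party libraries → internal modules → types
+- Components are function components + Hooks
+- The component directory is currently flat: `src/components/*.tsx`, not split into per-module subdirectories
+- Props types preferably use TypeScript `interface`
+- UI text stays in Chinese.
+- Code and comments preferably use English.
 ### Naming
-- Component files use **PascalCase** (e.g. `AISegmentation.tsx`).
-- Utility files use **camelCase** (e.g. `utils.ts`).
-- Types/interfaces use **PascalCase**.
-### Language conventions
-- **UI text**: all in **Chinese** (e.g. "核心分割工作区", "AI智能分割引擎", "导出 JSON 标注集")
-- **Code and comments**: English
-- New UI text **must stay in Chinese**.
+- Component files use PascalCase, e.g. `AISegmentation.tsx`
+- Utility files use camelCase, e.g. `utils.ts`
+- Types and interfaces use PascalCase
 ---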
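 ## Testing Strategy
-**Current state: no test files.**
-- No `.test.` or `.spec.` files exist in the project
-- No test framework is configured (e.g. Jest, Vitest, Playwright)
-- If tests are added, consider Vitest (same ecosystem as Vite) for frontend unit tests, or Playwright for E2E tests
+The repository now has frontend Vitest tests and backend pytest tests. The tests are based on `doc/07-current-requirements-freeze.md`, `doc/08-current-design-freeze.md`, and `doc/09-test-plan.md`
+- Frontend test config: `vitest.config.ts`; shared setup in `src/test/setup.tsx`
+- Frontend test command: `npm run test:run`
+- Backend test dependencies: `backend/requirements-dev.txt`
+- Backend test command: `pytest backend/tests`, or `pytest tests` from the `backend/` directory
+- Basic static checks: `npm run lint`, `npm run build`, `python -m py_compile backend/routers/ai.py backend/routers/templates.py backend/schemas.py`
+- Backend tests use in-memory SQLite, a fake MinIO, and a fake SAM registry; they do not depend on a real PostgreSQL, MinIO, Redis, or model weights.
 ---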
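 ## Security Notes
-- **Hard-coded credentials**: login validation in `server.ts` uses hard-coded credentials (`admin` / `123456`); production must replace this with a real authentication mechanism
-- **Mock JWT**: a successful login returns the fixed `fake-jwt-token-for-admin`, with no real JWT signing or verification
-- **In-memory data store**: all project/template data lives in memory and is lost on restart; there is no persistence layer
-- **Environment variables**: `GEMINI_API_KEY` is managed through `.env.local`, which is in `.gitignore` and will not be committed by accident
-- **CORS / security headers**: the Express server currently configures no CORS policy or security response headers (e.g. Helmet)
+- FastAPI login uses hard-coded development credentials: `admin / 123456`
+- A successful login returns a fixed token, `fake-jwt-token-for-admin`, with no real JWT signing or verification.
+- Axios attaches the Bearer token, but most backend business routes currently have no auth dependency
+- `backend/.env` is ignored by `.gitignore`; do not commit real database, MinIO, Redis, or model-path settings or other sensitive configuration
+- `start_services.sh` contains machine-specific paths and sudo startup logic; review it when moving to another machine
+- The legacy mock APIs in the Express `server.ts` are only suitable for development/compatibility and must not be used as production auth or persistence.
 ---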
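-## Key Dependencies and Notes
-- **React 19**: uses the `createRoot` API; note some differences from React 18
-- **TailwindCSS 4**: uses the `@import "tailwindcss"` syntax (not the v3 `@tailwind` directives)
-- **react-konva**: the core of Canvas interaction; all canvas operations (zoom, point selection, masks) depend on it.
-- **use-image**: loads images asynchronously into the Konva canvas.
-- **Path alias**: `@/*` maps to the project root (configured in both `vite.config.ts` and `tsconfig.json`).
-- **Missing asset**: `Sidebar.tsx` references `/Logo.png`, but the project root has no such file, so it returns 404 at runtime.
----
-## AI Studio Specific Configuration
-- `vite.config.ts` loads environment variables via `loadEnv` and injects `GEMINI_API_KEY` into `process.env.GEMINI_API_KEY`.
-- AI Studio injects `GEMINI_API_KEY` and `APP_URL` automatically at deploy time.
-- The `DISABLE_HMR` environment variable turns off HMR in AI Studio agent editing mode to avoid UI flicker. **Do not modify this logic.**
+## AI Studio / Vite Specific Configuration
+- `.env.example` contains `GEMINI_API_KEY` and `APP_URL` and notes that these values are injected by AI Studio
+- `vite.config.ts` uses `loadEnv` to inject `GEMINI_API_KEY` into `process.env.GEMINI_API_KEY`
+- The `DISABLE_HMR` logic in `vite.config.ts` turns off HMR to avoid flicker while the AI Studio agent edits. **Do not change this logic casually.**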
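
The AGENTS.md text above says the worker extracts video frames with FFmpeg first and falls back to OpenCV on failure. A minimal sketch of that try-then-fall-back pattern follows. It is not the contents of `backend/services/frame_parser.py`: the function names, the `frame_%06d.jpg` naming, and the injected `fallback` callable (standing in for an OpenCV-based extractor) are assumptions made for illustration.

```python
import subprocess
from pathlib import Path
from typing import Callable


def build_ffmpeg_command(video: Path, out_dir: Path, fps: float) -> list[str]:
    """Build an ffmpeg command that samples `fps` frames per second as JPEGs."""
    return [
        "ffmpeg", "-y",
        "-i", str(video),
        "-vf", f"fps={fps}",
        str(out_dir / "frame_%06d.jpg"),
    ]


def extract_frames(
    video: Path,
    out_dir: Path,
    fps: float,
    fallback: Callable[[Path, Path, float], list[Path]],
) -> list[Path]:
    """Try ffmpeg first; if it is missing or exits non-zero, use `fallback`."""
    out_dir.mkdir(parents=True, exist_ok=True)
    try:
        subprocess.run(
            build_ffmpeg_command(video, out_dir, fps),
            check=True,
            capture_output=True,
        )
    except (OSError, subprocess.CalledProcessError):
        # Hypothetical hook: the real project would call its OpenCV path here.
        return fallback(video, out_dir, fps)
    return sorted(out_dir.glob("frame_*.jpg"))
```

Passing the fallback in as a callable keeps the sketch testable without OpenCV installed; a real worker would also report progress between these steps.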

128
README.md

@@ -4,16 +4,16 @@
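 # Semantic Segmentation System (SegServer)
-> A full-stack interactive image/video semantic segmentation and annotation platform built on React + FastAPI + SAM 2.
+> A full-stack interactive image/video semantic segmentation and annotation platform built on React + FastAPI + optional SAM 2 / SAM 3.
 >
-> Supports local media asset upload, server-side frame-by-frame parsing, real-time inference with a large vision model (positive/negative point and box prompts that generate segmentation masks), dynamic layer state management, and structured export of the final annotation data.
+> Supports local media asset upload, server-side frame-by-frame parsing, interactive Canvas annotation, template class management, and structured export of annotation data. Point/box AI inference in the workspace uses SAM 2 by default, SAM 3 can be selected for semantic text prompts, and the frontend shows the real GPU/model status.
 ---
 ## Core Features
 - **Media asset management**: upload, storage, and parsing of video (MP4/AVI/MOV) and DICOM medical images
-- **AI segmentation engine**: integrates the SAM 2 model with point, box, semantic, and auto segmentation
+- **AI segmentation engine**: the backend offers SAM 2 / SAM 3 model selection; SAM 2 supports point, box, and auto segmentation, while the SAM 3 entry point supports text-prompt semantic prompts and reports availability based on the real runtime environment
 - **Interactive canvas annotation**: high-performance Konva-based Canvas with zoom/pan/point/box selection and real-time mask rendering
 - **Ontology management**: configurable class hierarchy, color mapping, and layer priority (z-index)
 - **Project workspace**: project creation, frame browsing, multi-layer annotation, progress tracking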
@@ -37,15 +37,16 @@
 │  ├── /api/auth      Login authentication                    │
 │  ├── /api/projects  Project & video frame CRUD              │
 │  ├── /api/templates Ontology (classes/colors/z-index)       │
-│  ├── /api/media     File upload & FFmpeg/pydicom parsing    │
-│  ├── /api/ai        SAM 2: point/box/semantic/auto inference│
+│  ├── /api/media     File upload & async parse-task creation │
+│  ├── /api/tasks     Celery background task status           │
+│  ├── /api/ai        SAM 2 / SAM 3 inference and model status│
 │  └── /api/export    COCO JSON / PNG mask export             │
 └──────────────────────────┬──────────────────────────────────┘
                            │ SQLAlchemy 2.0
 ┌──────────────────────────▼──────────────────────────────────┐
 │                   Data persistence layer                    │
-│  PostgreSQL 14: project/frame/annotation/mask metadata      │
-│  Redis 6: cache & task queue state                          │
+│  PostgreSQL 14: project/frame/annotation/mask/task metadata │
+│  Redis 6: Celery broker/result backend + progress pub/sub   │
 │  MinIO: object storage (raw video/parsed frames/mask images)│
 └─────────────────────────────────────────────────────────────┘
 ```
@@ -63,11 +64,12 @@
 | Canvas rendering | Konva + react-konva | - |
 | HTTP client | Axios | - |
 | Backend framework | FastAPI | v0.136+ |
-| Database ORM | SQLAlchemy + Alembic | 2.0+ |
+| Database ORM | SQLAlchemy (Alembic is included in the dependencies) | 2.0+ |
 | Database | PostgreSQL | 14 |
-| Cache | Redis | 6 |
+| Queue broker | Redis | 6 |
+| Background tasks | Celery worker | 5.6+ |
 | Object storage | MinIO | 2025+ |
-| AI inference | SAM 2 (Meta) + PyTorch | - |
+| AI inference | SAM 2 / SAM 3 (Meta) + PyTorch | - |
 | Video processing | FFmpeg + OpenCV | 4.4+ |
 | DICOM processing | pydicom | 3.0+ |
@@ -78,13 +80,17 @@
 ```
 Seg_Server/
 ├── backend/                 # FastAPI backend
-│   ├── main.py              # App entry (CORS/lifespan/router registration)
+│   ├── main.py              # App entry (CORS/lifespan/router registration/WebSocket)
 │   ├── config.py            # Environment config (Pydantic Settings)
 │   ├── database.py          # SQLAlchemy engine + Session
-│   ├── models.py            # ORM models (Project/Frame/Template/Annotation/Mask)
+│   ├── models.py            # ORM models (Project/Frame/Template/Annotation/Mask/ProcessingTask)
 │   ├── schemas.py           # Pydantic request/response validation models
 │   ├── minio_client.py      # MinIO upload/download/presigned URL wrapper
 │   ├── redis_client.py      # Redis connection wrapper
+│   ├── progress_events.py   # Task progress event payloads and Redis publishing
+│   ├── statuses.py          # Project/task status constants
+│   ├── celery_app.py        # Celery app config
+│   ├── worker_tasks.py      # Celery task entry points
 │   ├── download_sam2.py     # SAM 2 model weight auto-download script
 │   ├── requirements.txt     # Python dependencies
 │   ├── routers/             # API routes
@@ -92,10 +98,12 @@ Seg_Server/
 │   │   ├── projects.py      # Project & frame CRUD
 │   │   ├── templates.py     # Ontology management
 │   │   ├── media.py         # Upload & parsing
-│   │   ├── ai.py            # SAM 2 inference endpoints
+│   │   ├── ai.py            # SAM inference and model status endpoints
 │   │   └── export.py        # Data export
 │   └── services/            # Business services
 │       ├── sam2_engine.py   # SAM 2 inference engine (lazy loading + stub fallback)
+│       ├── sam3_engine.py   # SAM 3 status detection and text-prompt semantic inference adapter
+│       ├── sam_registry.py  # SAM model selection, GPU status, and inference dispatch
 │       └── frame_parser.py  # FFmpeg frame extraction / pydicom reading
 ├── src/                     # React frontend
 │   ├── main.tsx             # App mount point
@@ -121,8 +129,11 @@ Seg_Server/
 ├── models/                  # SAM 2 model weights (.pt files)
 ├── uploads/                 # Temporary upload directory
 ├── frames/                  # Temporary frame directory
+├── doc/                     # Current implementation audit, API contracts, and follow-up implementation docs
+├── public/
+│   └── logo.png             # Sidebar logo static asset
 ├── start_services.sh        # One-command startup script for all services
-├── server.ts                # Legacy Express entry (deprecated)
+├── server.ts                # Express + Vite frontend entry (also keeps a few legacy mock APIs)
 ├── index.html               # SPA HTML entry
 ├── vite.config.ts           # Vite build config
 ├── package.json             # npm dependencies and scripts
@@ -131,12 +142,23 @@ Seg_Server/
 ---
+## Project Documentation
+The current implementation audit and API contract documents live in the `doc/` directory:
+- `doc/01-purpose-and-word-summary.md`: project purpose, summary of the Word proposal, and how much of it is currently implemented
+- `doc/03-frontend-element-audit.md`: per-element frontend audit, marking each element as fully working, partially working, Mock/UI-only, or not connected to an API
+- `doc/04-api-contracts.md`: frontend/backend API contracts and known inconsistencies
+- `doc/06-fastapi-docs-explained.md`: explains what `http://192.168.3.11:8000/docs` is for
+---
 ## Environment Setup
 ### System Requirements
 - **OS**: Ubuntu 22.04 LTS
-- **GPU**: NVIDIA GPU (RTX 4090 or equivalent recommended) for SAM 2 inference
+- **GPU**: NVIDIA GPU (RTX 4090 or equivalent recommended) for SAM inference; SAM 3 officially requires Python 3.12+, PyTorch 2.7+, and a CUDA 12.6+ environment
 - **CUDA**: 12.x / 13.x
 - **Node.js**: 22.x+
 - **Python**: 3.11 (managed via Miniconda/Anaconda)
@@ -223,19 +245,32 @@ python download_sam2.py
 ### Step 5: Configure environment variables
-The project root ships a default configuration; to change it, edit the following files:
+The backend reads `backend/.env` through the Pydantic Settings in `backend/config.py`. To override the defaults, edit the following file:
 **backend/.env** (database/Redis/MinIO/SAM paths):
 ```ini
-DATABASE_URL=postgresql://seguser:segpass123@localhost:5432/segserver
-REDIS_URL=redis://localhost:6379/0
-MINIO_ENDPOINT=localhost:9000
-MINIO_ACCESS_KEY=minioadmin
-MINIO_SECRET_KEY=minioadmin
-MINIO_BUCKET_NAME=seg-media
-SAM2_MODEL_PATH=/home/wkmgc/Desktop/Seg_Server/models/sam2_hiera_tiny.pt
+db_url=postgresql://seguser:segpass123@localhost:5432/segserver
+redis_url=redis://localhost:6379/0
+minio_endpoint=192.168.3.11:9000
+minio_access_key=minioadmin
+minio_secret_key=minioadmin
+minio_secure=false
+sam_model_path=/home/wkmgc/Desktop/Seg_Server/models/sam2_hiera_tiny.pt
+sam_model_config=configs/sam2/sam2_hiera_t.yaml
+sam_default_model=sam2
+sam3_model_version=sam3.1
+cors_origins=["http://localhost:3000","http://192.168.3.11:3000"]
 ```
+The `.env.example` in the frontend root contains the AI Studio injected variables and the frontend API configuration:
+```ini
+VITE_API_BASE_URL=http://192.168.3.11:8000
+VITE_WS_PROGRESS_URL=ws://192.168.3.11:8000/ws/progress
+```
+If `VITE_API_BASE_URL` is not configured, the frontend infers `http://<host>:8000` from the current browser hostname.
 ### Step 6: Start the backend service
 ```bash
@@ -252,7 +287,21 @@ nohup uvicorn main:app --host 0.0.0.0 --port 8000 > /tmp/fastapi.log 2>&1 &
 - Creates the database tables (if they do not exist)
 - Checks whether the MinIO bucket exists
 - Tests the Redis connection
-- Lazily loads the SAM 2 model (when the weights exist and the sam2 package is installed)
+- Lazily loads the SAM models; `GET /api/ai/models/status` returns the real availability of SAM 2, SAM 3, and the GPU
+### Step 6.1: Start the Celery worker
+```bash
+cd ~/Desktop/Seg_Server/backend
+source ~/miniconda3/etc/profile.d/conda.sh
+conda activate seg_server
+celery -A celery_app:celery_app worker --loglevel=info --concurrency=1
+# Or run it in the background
+nohup celery -A celery_app:celery_app worker --loglevel=info --concurrency=1 > /tmp/celery.log 2>&1 &
+```
+`POST /api/media/parse` only creates a `processing_tasks` record and dispatches the task to Celery; the actual FFmpeg/OpenCV/pydicom frame extraction runs in the worker. After each task status update the worker publishes to the Redis `seg:progress` channel; FastAPI subscribes and forwards the events to `/ws/progress`, so the frontend Dashboard can update in real time.
 ### Step 7: Install frontend dependencies and build
@@ -286,7 +335,7 @@ cd ~/Desktop/Seg_Server
 ./start_services.sh
 ```
-The script checks and starts, in order: PostgreSQL → Redis → MinIO → FastAPI backend → frontend.
+The script checks and starts, in order: PostgreSQL → Redis → MinIO → FastAPI backend → Celery worker → frontend.
 ---
@@ -307,10 +356,12 @@ cd ~/Desktop/Seg_Server
 ```bash
 npm install          # Install dependencies
-npm run dev          # Vite development mode (port 5173)
+npm run dev          # Runs tsx server.ts: Express + Vite middleware (port 3000)
 npm run build        # Production build (output to dist/)
 npm run lint         # TypeScript type check
-npm start            # Run server.ts with Node.js (legacy)
+npm run test         # Vitest watch mode
+npm run test:run     # Vitest single run
+npm start            # Run server.ts with Node.js (production static serving / legacy mock API)
 ```
 ### Backend
@@ -318,8 +369,11 @@ npm start # Node.js 运行 server.ts旧版
 ```bash
 # In the conda seg_server environment
 cd backend
+pip install -r requirements-dev.txt                     # Install backend test dependencies
+pytest tests                                            # Backend API tests
 uvicorn main:app --host 0.0.0.0 --port 8000 --reload    # Development mode (hot reload)
 uvicorn main:app --host 0.0.0.0 --port 8000             # Production mode
+celery -A celery_app:celery_app worker --loglevel=info --concurrency=1  # Background task worker
 ```
 ---
@@ -375,7 +429,25 @@ pip install -e . --no-build-isolation
 **Checklist**:
 1. Is the backend running (`curl http://localhost:8000/health`)?
 2. Does `cors_origins` in `backend/.env` include `http://localhost:3000`?
-3. Is `baseURL` in the frontend `src/lib/api.ts` set to `http://localhost:8000`?
+3. Is the frontend configured with the correct `VITE_API_BASE_URL`? If unset, it infers `http://<host>:8000` from the current browser hostname
+### Q5: How to verify the AI inference or COCO export endpoints
+**Current state**:
+- The frontend `predictMask()` sends the `image_id`, `prompt_type`, and `prompt_data` the backend needs, and converts the backend `polygons` into Konva `pathData`.
+- Point/box selection in the workspace calls `/api/ai/predict` with the database `frame.id` of the current frame.
+- The frontend `exportCoco()` is aligned with `/api/export/{projectId}/coco`.
+- The workspace "导出 JSON 标注集" (export JSON annotation set) button is bound to the download flow; pending frontend masks are saved before exporting.
+- The workspace "结构化归档保存" (structured archive save) button writes the current project's unsaved masks via `POST /api/ai/annotate` and writes dirty masks via `PATCH /api/ai/annotations/{id}`.
+- The workspace "清空遮罩" (clear masks) action deletes the saved annotations of the current frame via `DELETE /api/ai/annotations/{id}` and clears the current frame's local masks.
+**Verification**:
+```bash
+curl http://localhost:8000/health
+curl http://localhost:8000/api/export/1/coco
+```
 ---
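
The README notes above say the frontend converts the backend `polygons` into Konva `pathData`. Konva's `Path` shape takes an SVG path string, so the conversion can be sketched as below. The actual response schema is not shown in this commit; the sketch assumes each polygon is a list of `(x, y)` pairs already in canvas pixel coordinates, and the function name is hypothetical.

```python
from typing import Sequence

Point = Sequence[float]


def polygons_to_path_data(polygons: Sequence[Sequence[Point]]) -> str:
    """Convert polygons to one SVG path string using M/L/Z commands."""
    segments: list[str] = []
    for polygon in polygons:
        if len(polygon) < 3:
            continue  # fewer than three points cannot enclose an area
        first, *rest = polygon
        commands = [f"M {first[0]:g} {first[1]:g}"]
        commands += [f"L {x:g} {y:g}" for x, y in rest]
        commands.append("Z")  # close the ring back to the first point
        segments.append(" ".join(commands))
    return " ".join(segments)
```

Each polygon becomes its own closed subpath, so a mask made of several disjoint regions still renders as a single path string.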

21
backend/celery_app.py Normal file

@@ -0,0 +1,21 @@
+"""Celery application for background processing."""
+from celery import Celery
+from config import settings
+celery_app = Celery(
+    "seg_server",
+    broker=settings.redis_url,
+    backend=settings.redis_url,
+    include=["worker_tasks"],
+)
+celery_app.conf.update(
+    task_serializer="json",
+    result_serializer="json",
+    accept_content=["json"],
+    timezone="Asia/Shanghai",
+    enable_utc=True,
+    task_track_started=True,
+)
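
The commit message describes a Redis progress-event module through which the worker publishes task progress. The sketch below shows one way such a publisher could look. It is not the contents of `progress_events.py`: the payload field names and helper names are assumptions, the channel name `seg:progress` is taken from the README text, and `client` is any object exposing `publish(channel, message)` as redis-py's `Redis` client does.

```python
import json
from datetime import datetime, timezone
from typing import Any, Protocol

PROGRESS_CHANNEL = "seg:progress"  # channel name as given in the README


class Publisher(Protocol):
    def publish(self, channel: str, message: str) -> Any: ...


def build_progress_event(task_id: int, status: str, progress: int,
                         message: str = "") -> dict[str, Any]:
    """Assemble a JSON-serializable progress payload (field names assumed)."""
    return {
        "type": "progress",
        "task_id": task_id,
        "status": status,
        "progress": max(0, min(100, progress)),  # clamp to a percentage
        "message": message,
        "timestamp": datetime.now(timezone.utc).isoformat(),
    }


def publish_progress(client: Publisher, event: dict[str, Any]) -> None:
    """Publish the event as a JSON string so subscribers can json.loads it."""
    client.publish(PROGRESS_CHANNEL, json.dumps(event, ensure_ascii=False))
```

Publishing a JSON string rather than a raw dict matches the subscriber side, which decodes the message before broadcasting it to WebSocket clients.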


@@ -18,9 +18,11 @@ class Settings(BaseSettings):
     minio_secret_key: str = "minioadmin"
     minio_secure: bool = False
-    # SAM2
+    # SAM
+    sam_default_model: str = "sam2"
     sam_model_path: str = "/home/wkmgc/Desktop/Seg_Server/models/sam2_hiera_tiny.pt"
     sam_model_config: str = "configs/sam2/sam2_hiera_t.yaml"
+    sam3_model_version: str = "sam3.1"
     # App
     app_env: str = "development"


@@ -1,11 +1,13 @@
 """FastAPI application entrypoint."""
 import asyncio
+import json
 import logging
 import os
 import shutil
 import tempfile
-from contextlib import asynccontextmanager
+from contextlib import asynccontextmanager, suppress
+from datetime import datetime, timezone
 from fastapi import FastAPI, WebSocket, WebSocketDisconnect
 from fastapi.middleware.cors import CORSMiddleware
@@ -13,9 +15,11 @@ from fastapi.middleware.cors import CORSMiddleware
 from config import settings
 from database import Base, engine, SessionLocal
 from minio_client import ensure_bucket_exists, upload_file
-from redis_client import ping as redis_ping
+from progress_events import PROGRESS_CHANNEL
+from redis_client import get_redis_client, ping as redis_ping
+from statuses import PROJECT_STATUS_PENDING, PROJECT_STATUS_READY
-from routers import projects, templates, media, ai, export, auth
+from routers import projects, templates, media, ai, export, auth, dashboard, tasks
 logging.basicConfig(
     level=logging.INFO,
@@ -45,7 +49,7 @@ def _seed_default_project_sync() -> None:
         project = Project(
             name="Data_MyVideo_1",
             description="默认演示视频",
-            status="pending",
+            status=PROJECT_STATUS_PENDING,
             source_type="video",
             parse_fps=30.0,
         )
@@ -98,7 +102,7 @@ def _seed_default_project_sync() -> None:
             )
             db.add(frame)
-        project.status = "ready"
+        project.status = PROJECT_STATUS_READY
         db.commit()
         logger.info("Seeded default project id=%s with %d frames", project.id, len(object_names))
     finally:
@@ -165,6 +169,7 @@ def _seed_default_templates_sync() -> None:
 @asynccontextmanager
 async def lifespan(app: FastAPI):
     """Application lifespan: startup and shutdown hooks."""
+    progress_listener: asyncio.Task | None = None
     # Startup
     logger.info("Starting up SegServer backend...")
@@ -187,6 +192,11 @@ async def lifespan(app: FastAPI):
     else:
         logger.warning("Redis connection failed.")
+    try:
+        progress_listener = asyncio.create_task(_progress_pubsub_loop())
+    except Exception as exc:  # noqa: BLE001
+        logger.error("Failed to start Redis progress subscription: %s", exc)
     # Seed default templates
     try:
         asyncio.create_task(asyncio.to_thread(_seed_default_templates_sync))
@@ -203,6 +213,10 @@ async def lifespan(app: FastAPI):
     # Shutdown
     logger.info("Shutting down SegServer backend...")
+    if progress_listener is not None:
+        progress_listener.cancel()
+        with suppress(asyncio.CancelledError):
+            await progress_listener
     engine.dispose()
@@ -229,6 +243,8 @@ app.include_router(templates.router)
 app.include_router(media.router)
 app.include_router(ai.router)
 app.include_router(export.router)
+app.include_router(dashboard.router)
+app.include_router(tasks.router)
 @app.get("/health", tags=["Health"])
@@ -269,6 +285,34 @@ class ConnectionManager:
 manager = ConnectionManager()
+async def _progress_pubsub_loop() -> None:
+    """Forward Redis task-progress events to connected WebSocket clients."""
+    while True:
+        pubsub = None
+        try:
+            pubsub = get_redis_client().pubsub()
+            await asyncio.to_thread(pubsub.subscribe, PROGRESS_CHANNEL)
+            logger.info("Subscribed to Redis progress channel: %s", PROGRESS_CHANNEL)
+            while True:
+                message = await asyncio.to_thread(pubsub.get_message, True, 1.0)
+                if message is None:
+                    await asyncio.sleep(0)
+                    continue
+                raw_data = message.get("data")
+                payload = json.loads(raw_data) if isinstance(raw_data, str) else raw_data
+                if isinstance(payload, dict):
+                    await manager.broadcast(payload)
+        except asyncio.CancelledError:
+            raise
+        except Exception as exc:  # noqa: BLE001
+            logger.error("Redis progress subscription failed: %s", exc)
+            await asyncio.sleep(5)
+        finally:
+            if pubsub is not None:
+                with suppress(Exception):
+                    await asyncio.to_thread(pubsub.close)
 @app.websocket("/ws/progress")
 async def websocket_progress(websocket: WebSocket):
     """WebSocket endpoint for real-time parsing/AI progress updates."""
@@ -284,7 +328,7 @@ async def websocket_progress(websocket: WebSocket):
                 "type": "status",
                 "status": "connected",
                 "message": "Progress stream active",
-                "timestamp": str(logging.time.time() if hasattr(logging, 'time') else __import__('time').time()),
+                "timestamp": datetime.now(timezone.utc).isoformat(),
             })
     except WebSocketDisconnect:
         manager.disconnect(websocket)
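The `_progress_pubsub_loop` added in this file parses `data` only when it is a `str`. If the client returned by `get_redis_client()` is not created with `decode_responses=True` (an assumption, since that module is not part of this diff), redis-py delivers payloads as `bytes`, and the loop would skip them silently. A malformed JSON payload would also raise into the outer `except Exception` handler and trigger the 5-second resubscribe. A minimal sketch of a more tolerant decoder follows; `decode_progress_message` is a hypothetical helper, not something this commit defines.

```python
from __future__ import annotations

import json
from typing import Any


def decode_progress_message(message: dict[str, Any] | None) -> dict[str, Any] | None:
    """Return the JSON object carried by a pub/sub message, or None to skip it."""
    if not message or message.get("type") != "message":
        # Empty polls and subscribe confirmations carry no event payload.
        return None
    raw = message.get("data")
    if isinstance(raw, (bytes, bytearray)):
        # Clients created without decode_responses=True hand back bytes.
        raw = raw.decode("utf-8", errors="replace")
    if not isinstance(raw, str):
        return None
    try:
        payload = json.loads(raw)
    except json.JSONDecodeError:
        # Drop a malformed event instead of tearing down the subscription.
        return None
    return payload if isinstance(payload, dict) else None
```

Inside the loop, the result could be passed to `manager.broadcast` whenever it is not `None`, which keeps one bad event from forcing a reconnect.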


@@ -14,6 +14,7 @@ from sqlalchemy.orm import relationship
 from sqlalchemy.sql import func
 from database import Base
+from statuses import PROJECT_STATUS_PENDING
 class Project(Base):
@@ -26,7 +27,7 @@ class Project(Base):
     description = Column(Text, nullable=True)
     video_path = Column(String(512), nullable=True)
     thumbnail_url = Column(String(512), nullable=True)
-    status = Column(String(50), default="Ready", nullable=False)
+    status = Column(String(50), default=PROJECT_STATUS_PENDING, nullable=False)
     source_type = Column(String(20), default="video", nullable=False)  # video | dicom
     original_fps = Column(Float, nullable=True)
     parse_fps = Column(Float, default=30.0, nullable=False)
@@ -39,6 +40,9 @@ class Project(Base):
     annotations = relationship(
         "Annotation", back_populates="project", cascade="all, delete-orphan"
     )
+    tasks = relationship(
+        "ProcessingTask", back_populates="project", cascade="all, delete-orphan"
+    )
 class Frame(Base):
@@ -121,3 +125,30 @@ class Mask(Base):
     created_at = Column(DateTime(timezone=True), server_default=func.now())
     annotation = relationship("Annotation", back_populates="masks")
+class ProcessingTask(Base):
+    """Background task state persisted for dashboard and polling."""
+    __tablename__ = "processing_tasks"
+    id = Column(Integer, primary_key=True, index=True)
+    task_type = Column(String(80), nullable=False)
+    status = Column(String(40), default="queued", nullable=False)
+    progress = Column(Integer, default=0, nullable=False)
+    message = Column(Text, nullable=True)
+    project_id = Column(
+        Integer, ForeignKey("projects.id", ondelete="CASCADE"), nullable=True
+    )
+    celery_task_id = Column(String(255), nullable=True)
+    payload = Column(JSON, nullable=True)
+    result = Column(JSON, nullable=True)
+    error = Column(Text, nullable=True)
+    created_at = Column(DateTime(timezone=True), server_default=func.now())
+    started_at = Column(DateTime(timezone=True), nullable=True)
+    finished_at = Column(DateTime(timezone=True), nullable=True)
+    updated_at = Column(
+        DateTime(timezone=True), server_default=func.now(), onupdate=func.now()
+    )
+    project = relationship("Project", back_populates="tasks")


@@ -0,0 +1,64 @@
"""Progress event payloads and Redis publication helpers."""
from __future__ import annotations
import json
import logging
from datetime import datetime, timezone
from typing import Any
from redis_client import get_redis_client
from statuses import TASK_STATUS_FAILED, TASK_STATUS_SUCCESS
logger = logging.getLogger(__name__)
PROGRESS_CHANNEL = "seg:progress"
def _iso_now() -> str:
return datetime.now(timezone.utc).isoformat()
def _event_type(task_status: str) -> str:
if task_status == TASK_STATUS_SUCCESS:
return "complete"
if task_status == TASK_STATUS_FAILED:
return "error"
return "progress"
def task_progress_payload(task: Any) -> dict[str, Any]:
"""Build the WebSocket payload from a persisted processing task."""
project = getattr(task, "project", None)
project_name = getattr(project, "name", None)
status = getattr(task, "status", "")
updated_at = getattr(task, "updated_at", None)
timestamp = updated_at.isoformat() if updated_at is not None else _iso_now()
message = getattr(task, "message", None)
return {
"type": _event_type(status),
"taskId": f"task-{task.id}",
"task_id": task.id,
"project_id": getattr(task, "project_id", None),
"projectName": project_name,
"filename": project_name,
"progress": getattr(task, "progress", 0),
"status": message or status,
"message": message,
"error": getattr(task, "error", None),
"timestamp": timestamp,
}
def publish_progress_event(payload: dict[str, Any]) -> None:
"""Publish a JSON progress event without failing the worker on Redis errors."""
try:
get_redis_client().publish(PROGRESS_CHANNEL, json.dumps(payload, ensure_ascii=False))
except Exception as exc: # noqa: BLE001
logger.warning("Failed to publish progress event: %s", exc)
def publish_task_progress_event(task: Any) -> None:
"""Publish a progress event for a ProcessingTask ORM object."""
publish_progress_event(task_progress_payload(task))
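One detail of `task_progress_payload` is easy to miss: the `status` field carries the human-readable `message` whenever one is set, so a consumer probably should branch on `type` rather than on `status`. The sketch below reproduces only that part of the payload with a stand-in task object. The string values `"success"` and `"failed"` are assumptions, because `statuses.py` is not shown in this diff, and `sketch_payload` is a hypothetical name.

```python
from types import SimpleNamespace
from typing import Any

TASK_STATUS_SUCCESS = "success"  # assumed value; the real constant lives in statuses.py
TASK_STATUS_FAILED = "failed"    # assumed value; the real constant lives in statuses.py


def event_type(task_status: str) -> str:
    """Map a persisted task status to the WebSocket event type."""
    if task_status == TASK_STATUS_SUCCESS:
        return "complete"
    if task_status == TASK_STATUS_FAILED:
        return "error"
    return "progress"


def sketch_payload(task: Any) -> dict[str, Any]:
    """Reduced version of the payload: only the fields discussed here."""
    status = getattr(task, "status", "")
    message = getattr(task, "message", None)
    return {
        "type": event_type(status),
        "taskId": f"task-{task.id}",
        "progress": getattr(task, "progress", 0),
        "status": message or status,  # the message wins over the raw status
        "error": getattr(task, "error", None),
    }


running = SimpleNamespace(id=7, status="running", progress=40, message=None, error=None)
failed = SimpleNamespace(
    id=8, status=TASK_STATUS_FAILED, progress=100,
    message="Frame extraction failed", error="ffmpeg exited with code 1",
)
```

With these stand-ins, `sketch_payload(running)` reports `type="progress"` and the raw status, while `sketch_payload(failed)` reports `type="error"` and puts the message text in `status`.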


@@ -0,0 +1,2 @@
pytest
httpx


@@ -1,18 +1,25 @@
-"""AI inference endpoints using SAM 2."""
+"""AI inference endpoints using selectable SAM runtimes."""
 import logging
 from typing import Any, List
 import cv2
 import numpy as np
-from fastapi import APIRouter, Depends, HTTPException, status
+from fastapi import APIRouter, Depends, HTTPException, Response, status
 from sqlalchemy.orm import Session
 from database import get_db
 from minio_client import download_file
-from models import Frame, Annotation
-from schemas import PredictRequest, PredictResponse, AnnotationOut, AnnotationCreate
-from services.sam2_engine import sam_engine
+from models import Project, Frame, Template, Annotation
+from schemas import (
+    AiRuntimeStatus,
+    PredictRequest,
+    PredictResponse,
+    AnnotationOut,
+    AnnotationCreate,
+    AnnotationUpdate,
+)
+from services.sam_registry import ModelUnavailableError, sam_registry
 logger = logging.getLogger(__name__)
 router = APIRouter(prefix="/api/ai", tags=["AI"])
@@ -35,14 +42,15 @@ def _load_frame_image(frame: Frame) -> np.ndarray:
 @router.post(
     "/predict",
     response_model=PredictResponse,
-    summary="Run SAM 2 inference with a prompt",
+    summary="Run SAM inference with a prompt",
 )
 def predict(payload: PredictRequest, db: Session = Depends(get_db)) -> dict:
-    """Execute SAM 2 segmentation given an image and a prompt.
+    """Execute selected SAM segmentation given an image and a prompt.
-    - **point**: `prompt_data` is a list of `[[x, y], ...]` normalized coordinates.
+    - **point**: `prompt_data` is either a list of `[[x, y], ...]` normalized
+      coordinates or `{ "points": [[x, y], ...], "labels": [1, 0, ...] }`.
     - **box**: `prompt_data` is `[x1, y1, x2, y2]` normalized coordinates.
-    - **semantic**: Not yet implemented; falls back to auto segmentation.
+    - **semantic**: SAM 3 text prompt when model=`sam3`; SAM 2 falls back to auto.
     """
     frame = db.query(Frame).filter(Frame.id == payload.image_id).first()
     if not frame:
@@ -54,30 +62,57 @@ def predict(payload: PredictRequest, db: Session = Depends(get_db)) -> dict:
     polygons: List[List[List[float]]] = []
     scores: List[float] = []
-    if prompt_type == "point":
-        points = payload.prompt_data
-        if not isinstance(points, list) or len(points) == 0:
-            raise HTTPException(status_code=400, detail="Invalid point prompt data")
-        labels = [1] * len(points)
-        polygons, scores = sam_engine.predict_points(image, points, labels)
-    elif prompt_type == "box":
-        box = payload.prompt_data
-        if not isinstance(box, list) or len(box) != 4:
-            raise HTTPException(status_code=400, detail="Invalid box prompt data")
-        polygons, scores = sam_engine.predict_box(image, box)
-    elif prompt_type == "semantic":
-        # Placeholder: use auto segmentation for now
-        logger.info("Semantic prompt not implemented; using auto segmentation")
-        polygons, scores = sam_engine.predict_auto(image)
-    else:
-        raise HTTPException(status_code=400, detail=f"Unsupported prompt_type: {prompt_type}")
+    try:
+        if prompt_type == "point":
+            point_payload = payload.prompt_data
+            if isinstance(point_payload, dict):
+                points = point_payload.get("points")
+                labels = point_payload.get("labels")
+            else:
+                points = point_payload
+                labels = None
+            if not isinstance(points, list) or len(points) == 0:
+                raise HTTPException(status_code=400, detail="Invalid point prompt data")
+            if not isinstance(labels, list) or len(labels) != len(points):
+                labels = [1] * len(points)
+            polygons, scores = sam_registry.predict_points(payload.model, image, points, labels)
+        elif prompt_type == "box":
+            box = payload.prompt_data
+            if not isinstance(box, list) or len(box) != 4:
+                raise HTTPException(status_code=400, detail="Invalid box prompt data")
+            polygons, scores = sam_registry.predict_box(payload.model, image, box)
+        elif prompt_type == "semantic":
+            text = payload.prompt_data if isinstance(payload.prompt_data, str) else ""
+            polygons, scores = sam_registry.predict_semantic(payload.model, image, text)
+        else:
+            raise HTTPException(status_code=400, detail=f"Unsupported prompt_type: {prompt_type}")
+    except ModelUnavailableError as exc:
+        raise HTTPException(status_code=503, detail=str(exc)) from exc
+    except NotImplementedError as exc:
+        raise HTTPException(status_code=400, detail=str(exc)) from exc
+    except ValueError as exc:
+        raise HTTPException(status_code=400, detail=str(exc)) from exc
     return {"polygons": polygons, "scores": scores}
+@router.get(
+    "/models/status",
+    response_model=AiRuntimeStatus,
+    summary="Get SAM model and GPU runtime status",
+)
+def model_status(selected_model: str | None = None) -> dict:
+    """Return real runtime availability for GPU, SAM 2, and SAM 3."""
+    try:
+        return sam_registry.runtime_status(selected_model)
+    except ValueError as exc:
+        raise HTTPException(status_code=400, detail=str(exc)) from exc
 @router.post(
     "/auto",
     response_model=PredictResponse,
@@ -90,7 +125,10 @@ def auto_segment(image_id: int, db: Session = Depends(get_db)) -> dict:
         raise HTTPException(status_code=404, detail="Frame not found")
     image = _load_frame_image(frame)
-    polygons, scores = sam_engine.predict_auto(image)
+    try:
+        polygons, scores = sam_registry.predict_auto(None, image)
+    except ModelUnavailableError as exc:
+        raise HTTPException(status_code=503, detail=str(exc)) from exc
     return {"polygons": polygons, "scores": scores}
@@ -106,7 +144,7 @@ def save_annotation(
     db: Session = Depends(get_db),
 ) -> Annotation:
     """Persist an annotation (mask, points, bbox) into the database."""
-    project = db.query(Frame).filter(Frame.id == payload.project_id).first()
+    project = db.query(Project).filter(Project.id == payload.project_id).first()
     if not project:
         raise HTTPException(status_code=404, detail="Project not found")
@@ -121,3 +159,74 @@ def save_annotation(
     db.refresh(annotation)
     logger.info("Saved annotation id=%s project_id=%s", annotation.id, annotation.project_id)
     return annotation
+@router.get(
+    "/annotations",
+    response_model=List[AnnotationOut],
+    summary="List saved annotations for a project",
+)
+def list_annotations(
+    project_id: int,
+    frame_id: int | None = None,
+    db: Session = Depends(get_db),
+) -> List[Annotation]:
+    """Return persisted annotations for a project, optionally scoped to one frame."""
+    project = db.query(Project).filter(Project.id == project_id).first()
+    if not project:
+        raise HTTPException(status_code=404, detail="Project not found")
+    query = db.query(Annotation).filter(Annotation.project_id == project_id)
+    if frame_id is not None:
+        query = query.filter(Annotation.frame_id == frame_id)
+    return query.order_by(Annotation.id).all()
+@router.patch(
+    "/annotations/{annotation_id}",
+    response_model=AnnotationOut,
+    summary="Update a saved annotation",
+)
+def update_annotation(
+    annotation_id: int,
+    payload: AnnotationUpdate,
+    db: Session = Depends(get_db),
+) -> Annotation:
+    """Update mutable annotation fields persisted in the database."""
+    annotation = db.query(Annotation).filter(Annotation.id == annotation_id).first()
+    if not annotation:
+        raise HTTPException(status_code=404, detail="Annotation not found")
+    updates = payload.model_dump(exclude_unset=True)
+    if "template_id" in updates and updates["template_id"] is not None:
+        template = db.query(Template).filter(Template.id == updates["template_id"]).first()
+        if not template:
+            raise HTTPException(status_code=404, detail="Template not found")
+    for field, value in updates.items():
+        setattr(annotation, field, value)
+    db.commit()
+    db.refresh(annotation)
+    logger.info("Updated annotation id=%s", annotation.id)
+    return annotation
+@router.delete(
+    "/annotations/{annotation_id}",
+    status_code=status.HTTP_204_NO_CONTENT,
+    summary="Delete a saved annotation",
+)
+def delete_annotation(
+    annotation_id: int,
+    db: Session = Depends(get_db),
+) -> Response:
+    """Delete an annotation and its derived mask rows through ORM cascade."""
+    annotation = db.query(Annotation).filter(Annotation.id == annotation_id).first()
+    if not annotation:
+        raise HTTPException(status_code=404, detail="Annotation not found")
+    db.delete(annotation)
+    db.commit()
+    logger.info("Deleted annotation id=%s", annotation_id)
+    return Response(status_code=status.HTTP_204_NO_CONTENT)
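The point-prompt branch of `predict` accepts two shapes for `prompt_data` and quietly falls back to all-positive labels when `labels` is missing or its length does not match `points`. The normalization can be isolated as a pure function for unit testing. This is a sketch that mirrors the branch logic shown in the diff; `normalize_point_prompt` is a hypothetical name, and it raises `ValueError` where the endpoint raises `HTTPException(400)`.

```python
from __future__ import annotations

from typing import Any


def normalize_point_prompt(prompt_data: Any) -> tuple[list, list[int]]:
    """Return (points, labels) from either accepted prompt_data shape."""
    if isinstance(prompt_data, dict):
        points = prompt_data.get("points")
        labels = prompt_data.get("labels")
    else:
        points, labels = prompt_data, None

    if not isinstance(points, list) or len(points) == 0:
        raise ValueError("Invalid point prompt data")
    if not isinstance(labels, list) or len(labels) != len(points):
        # Missing or mismatched labels: treat every point as a positive click.
        labels = [1] * len(points)
    return points, labels
```

A stricter variant could reject a mismatched `labels` list instead of replacing it, so that a client bug which drops a negative click does not go unnoticed. Whether that is desirable depends on how the frontend builds its requests.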


@@ -0,0 +1,137 @@
"""Dashboard overview endpoints."""
import os
from datetime import datetime, timezone
from typing import Any
from fastapi import APIRouter, Depends
from sqlalchemy import func
from sqlalchemy.orm import Session
from database import get_db
from models import Annotation, Frame, ProcessingTask, Project, Template
router = APIRouter(prefix="/api/dashboard", tags=["Dashboard"])
ACTIVE_TASK_STATUSES = {"queued", "running"}
def _system_load_percent() -> int:
"""Return a real host load estimate without adding a psutil dependency."""
try:
load_1m = os.getloadavg()[0]
cpu_count = os.cpu_count() or 1
return min(100, max(0, round((load_1m / cpu_count) * 100)))
except (AttributeError, OSError):
return 0
def _iso_or_none(value: datetime | None) -> str | None:
if value is None:
return None
if value.tzinfo is None:
value = value.replace(tzinfo=timezone.utc)
return value.isoformat()
def _task_payload(task: ProcessingTask) -> dict[str, Any]:
return {
"id": f"task-{task.id}",
"task_id": task.id,
"project_id": task.project_id or 0,
"name": task.project.name if task.project else f"任务 {task.id}",
"progress": task.progress,
"status": task.message or task.status,
"frame_count": (task.result or {}).get("frames_extracted", 0),
"updated_at": _iso_or_none(task.updated_at),
}
@router.get("/overview", summary="Get dashboard overview")
def get_dashboard_overview(db: Session = Depends(get_db)) -> dict[str, Any]:
"""Return live dashboard data derived from persisted backend records."""
project_count = db.query(func.count(Project.id)).scalar() or 0
frame_count = db.query(func.count(Frame.id)).scalar() or 0
annotation_count = db.query(func.count(Annotation.id)).scalar() or 0
template_count = db.query(func.count(Template.id)).scalar() or 0
active_task_count = (
db.query(func.count(ProcessingTask.id))
.filter(ProcessingTask.status.in_(ACTIVE_TASK_STATUSES))
.scalar()
or 0
)
projects = db.query(Project).order_by(Project.updated_at.desc()).all()
recent_tasks = (
db.query(ProcessingTask)
.order_by(ProcessingTask.created_at.desc())
.limit(50)
.all()
)
tasks = [_task_payload(task) for task in recent_tasks if task.status in ACTIVE_TASK_STATUSES]
activities: list[dict[str, Any]] = []
for task in recent_tasks[:10]:
project_name = task.project.name if task.project else f"项目 {task.project_id}"
activities.append({
"id": f"task-{task.id}",
"kind": "task",
"time": _iso_or_none(task.updated_at),
"message": task.message or f"任务状态: {task.status}",
"project": project_name,
})
for project in projects[:10]:
activities.append({
"id": f"project-{project.id}",
"kind": "project",
"time": _iso_or_none(project.updated_at),
"message": f"项目状态: {project.status}",
"project": project.name,
})
recent_annotations = (
db.query(Annotation)
.order_by(Annotation.updated_at.desc())
.limit(10)
.all()
)
for annotation in recent_annotations:
project_name = annotation.project.name if annotation.project else f"项目 {annotation.project_id}"
activities.append({
"id": f"annotation-{annotation.id}",
"kind": "annotation",
"time": _iso_or_none(annotation.updated_at),
"message": f"标注已更新 #{annotation.id}",
"project": project_name,
})
recent_templates = (
db.query(Template)
.order_by(Template.created_at.desc())
.limit(10)
.all()
)
for template in recent_templates:
activities.append({
"id": f"template-{template.id}",
"kind": "template",
"time": _iso_or_none(template.created_at),
"message": f"模板可用: {template.name}",
"project": "系统",
})
activities.sort(key=lambda item: item["time"] or "", reverse=True)
return {
"summary": {
"project_count": project_count,
"parsing_task_count": active_task_count,
"annotation_count": annotation_count,
"frame_count": frame_count,
"template_count": template_count,
"system_load_percent": _system_load_percent(),
},
"tasks": tasks,
"activity": activities[:10],
}
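The `system_load_percent` figure is derived from the 1-minute load average divided by the core count, clamped to 0..100. `os.getloadavg()` exists only on Unix-like systems, which is why `_system_load_percent` falls back to `0` on `AttributeError` or `OSError`. The arithmetic can be checked in isolation with a pure function; `load_percent` is a hypothetical name used only for this sketch.

```python
from __future__ import annotations


def load_percent(load_1m: float, cpu_count: int | None) -> int:
    """Convert a 1-minute load average into a clamped 0..100 percentage."""
    cores = cpu_count or 1  # os.cpu_count() may return None
    return min(100, max(0, round((load_1m / cores) * 100)))
```

For example, a load of 2.0 on 4 cores gives 50, and a load of 12.0 on 4 cores is clamped to 100. A load average counts runnable (and on Linux also uninterruptible) tasks, so the value is a rough saturation estimate rather than CPU utilization.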


@@ -1,10 +1,6 @@
 """Media upload and parsing endpoints."""
 import logging
-import os
-import shutil
-import subprocess
-import tempfile
 from pathlib import Path
 from typing import List, Optional
@@ -12,13 +8,12 @@ from fastapi import APIRouter, Depends, File, Form, HTTPException, UploadFile, s
 from sqlalchemy.orm import Session
 from database import get_db
-from minio_client import upload_file, get_presigned_url, download_file
-from models import Project, Frame
-from schemas import FrameOut
-from services.frame_parser import (
-    parse_video, parse_dicom, upload_frames_to_minio,
-    extract_thumbnail, get_video_fps,
-)
+from minio_client import upload_file, get_presigned_url
+from models import ProcessingTask, Project
+from progress_events import publish_task_progress_event
+from schemas import ProcessingTaskOut
+from statuses import PROJECT_STATUS_PARSING, PROJECT_STATUS_PENDING, TASK_STATUS_QUEUED
+from worker_tasks import parse_project_media
 logger = logging.getLogger(__name__)
 router = APIRouter(prefix="/api/media", tags=["Media"])
@@ -79,7 +74,7 @@ async def upload_media(
     project = Project(
         name=file.filename,
         description="Auto-created from upload",
-        status="pending",
+        status=PROJECT_STATUS_PENDING,
         video_path=object_name,
         source_type="video",
     )
@@ -135,7 +130,7 @@ async def upload_dicom_batch(
     project = Project(
         name=first_name,
         description=f"DICOM series with {len(files)} files",
-        status="pending",
+        status=PROJECT_STATUS_PENDING,
         source_type="dicom",
     )
     db.add(project)
@@ -168,19 +163,18 @@ async def upload_dicom_batch(
 @router.post(
     "/parse",
     status_code=status.HTTP_202_ACCEPTED,
+    response_model=ProcessingTaskOut,
     summary="Trigger frame extraction",
 )
 def parse_media(
     project_id: int,
     source_type: Optional[str] = None,
     db: Session = Depends(get_db),
-) -> dict:
+) -> ProcessingTask:
-    """Trigger frame extraction for a project's uploaded media.
+    """Create a background task for media frame extraction.
-    * video: uses FFmpeg or OpenCV fallback, extracts thumbnail.
-    * dicom: uses pydicom to read DCM frames.
-    Extracted frames are uploaded to MinIO and registered in the database.
+    The Celery worker performs the heavy FFmpeg/OpenCV/pydicom work and
+    updates the persisted task record as it progresses.
     """
     project = db.query(Project).filter(Project.id == project_id).first()
     if not project:
@@ -190,100 +184,24 @@ def parse_media(
         raise HTTPException(status_code=400, detail="Project has no media uploaded")
     effective_source = source_type or project.source_type or "video"
-    parse_fps = project.parse_fps or 30.0
-    tmp_dir = tempfile.mkdtemp(prefix=f"seg_parse_{project_id}_")
-    output_dir = os.path.join(tmp_dir, "frames")
-    os.makedirs(output_dir, exist_ok=True)
-    try:
-        if effective_source == "dicom":
-            # Download all dicom files from MinIO
-            dcm_dir = os.path.join(tmp_dir, "dcm")
-            os.makedirs(dcm_dir, exist_ok=True)
-            from minio_client import get_minio_client, BUCKET_NAME
-            client = get_minio_client()
-            prefix = project.video_path
-            objects = list(client.list_objects(BUCKET_NAME, prefix=prefix, recursive=True))
-            for obj in objects:
-                if obj.object_name.lower().endswith(".dcm"):
-                    data = download_file(obj.object_name)
-                    local_dcm = os.path.join(dcm_dir, os.path.basename(obj.object_name))
-                    with open(local_dcm, "wb") as f:
-                        f.write(data)
-            frame_files = parse_dicom(dcm_dir, output_dir)
-        else:
-            # Video: download and parse
-            media_bytes = download_file(project.video_path)
-            local_path = os.path.join(tmp_dir, Path(project.video_path).name)
-            with open(local_path, "wb") as f:
-                f.write(media_bytes)
-            frame_files, original_fps = parse_video(local_path, output_dir, fps=int(parse_fps))
-            project.original_fps = original_fps
-            # Extract thumbnail from first frame
-            thumbnail_path = os.path.join(tmp_dir, "thumbnail.jpg")
-            try:
-                extract_thumbnail(local_path, thumbnail_path)
-                with open(thumbnail_path, "rb") as f:
-                    thumb_data = f.read()
-                thumb_object = f"projects/{project_id}/thumbnail.jpg"
-                upload_file(thumb_object, thumb_data, content_type="image/jpeg", length=len(thumb_data))
-                project.thumbnail_url = thumb_object
-                logger.info("Uploaded thumbnail for project_id=%s", project_id)
-            except Exception as exc:  # noqa: BLE001
-                logger.warning("Thumbnail extraction failed: %s", exc)
-    except Exception as exc:  # noqa: BLE001
-        logger.error("Frame extraction failed: %s", exc)
-        shutil.rmtree(tmp_dir, ignore_errors=True)
-        raise HTTPException(status_code=500, detail="Frame extraction failed") from exc
-    # Upload frames to MinIO
-    try:
-        object_names = upload_frames_to_minio(frame_files, project_id)
-    except Exception as exc:  # noqa: BLE001
-        logger.error("Frame upload failed: %s", exc)
-        shutil.rmtree(tmp_dir, ignore_errors=True)
-        raise HTTPException(status_code=500, detail="Frame upload to storage failed") from exc
-    # Register frames in DB
-    frames_out = []
-    for idx, obj_name in enumerate(object_names):
-        local_frame = frame_files[idx]
-        try:
-            import cv2
-            img = cv2.imread(local_frame)
-            h, w = img.shape[:2] if img is not None else (None, None)
-        except Exception:  # noqa: BLE001
-            h, w = None, None
-        frame = Frame(
-            project_id=project_id,
-            frame_index=idx,
-            image_url=obj_name,
-            width=w,
-            height=h,
-        )
-        db.add(frame)
-        frames_out.append(frame)
+    task = ProcessingTask(
+        task_type=f"parse_{effective_source}",
+        status=TASK_STATUS_QUEUED,
+        progress=0,
+        message="解析任务已入队",
+        project_id=project_id,
+        payload={"source_type": effective_source},
+    )
+    project.status = PROJECT_STATUS_PARSING
+    db.add(task)
     db.commit()
-    for f in frames_out:
-        db.refresh(f)
-    # Cleanup temp files
-    shutil.rmtree(tmp_dir, ignore_errors=True)
-    project.status = "ready"
+    db.refresh(task)
+    publish_task_progress_event(task)
+    async_result = parse_project_media.delay(task.id)
+    task.celery_task_id = async_result.id
     db.commit()
+    db.refresh(task)
-    logger.info("Parsed %d frames for project_id=%s", len(frames_out), project_id)
+    logger.info("Queued parse task id=%s project_id=%s celery_id=%s", task.id, project_id, async_result.id)
-    return {
-        "project_id": project_id,
-        "frames_extracted": len(frames_out),
-        "status": "ready",
-        "message": "Frame extraction completed successfully.",
-    }
+    return task

backend/routers/tasks.py Normal file

@@ -0,0 +1,37 @@
"""Processing task query endpoints."""
from typing import List
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.orm import Session
from database import get_db
from models import ProcessingTask
from schemas import ProcessingTaskOut
router = APIRouter(prefix="/api/tasks", tags=["Tasks"])
@router.get("", response_model=List[ProcessingTaskOut], summary="List processing tasks")
def list_tasks(
project_id: int | None = None,
status: str | None = None,
limit: int = 50,
db: Session = Depends(get_db),
) -> List[ProcessingTask]:
"""Return recent background processing tasks."""
query = db.query(ProcessingTask)
if project_id is not None:
query = query.filter(ProcessingTask.project_id == project_id)
if status is not None:
query = query.filter(ProcessingTask.status == status)
return query.order_by(ProcessingTask.created_at.desc()).limit(limit).all()
@router.get("/{task_id}", response_model=ProcessingTaskOut, summary="Get processing task")
def get_task(task_id: int, db: Session = Depends(get_db)) -> ProcessingTask:
"""Return one background task by id."""
task = db.query(ProcessingTask).filter(ProcessingTask.id == task_id).first()
if not task:
raise HTTPException(status_code=404, detail="Task not found")
return task
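Since `/api/media/parse` now returns immediately with a task record, a client that cannot use the WebSocket stream has to poll `GET /api/tasks/{task_id}` until the task reaches a terminal status. A minimal polling sketch follows. `wait_for_task` and the injected `fetch_task` callable are hypothetical (the transport is left to the caller), and the terminal status strings `"success"` and `"failed"` are assumptions, because `statuses.py` is not part of this excerpt.

```python
from __future__ import annotations

import time
from typing import Any, Callable

TERMINAL_STATUSES = {"success", "failed"}  # assumed values of the TASK_STATUS_* constants


def wait_for_task(
    fetch_task: Callable[[int], dict[str, Any]],
    task_id: int,
    *,
    interval: float = 1.0,
    timeout: float = 300.0,
    sleep: Callable[[float], None] = time.sleep,
    clock: Callable[[], float] = time.monotonic,
) -> dict[str, Any]:
    """Poll fetch_task(task_id) until the task is terminal or the timeout expires."""
    deadline = clock() + timeout
    while True:
        task = fetch_task(task_id)
        if task.get("status") in TERMINAL_STATUSES:
            return task
        if clock() >= deadline:
            raise TimeoutError(f"Task {task_id} did not finish within {timeout} seconds")
        sleep(interval)
```

Injecting `sleep` and `clock` keeps the helper testable without real delays. In an application, `fetch_task` would wrap an HTTP GET against the endpoint above and return the decoded JSON body.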


@@ -18,9 +18,9 @@ def _pack_mapping_rules(data: dict) -> dict:
     """Pack classes/rules into mapping_rules for DB storage."""
     mapping = data.get("mapping_rules") or {}
     if "classes" in data and data["classes"] is not None:
-        mapping["classes"] = data["classes"]
+        mapping["classes"] = data.pop("classes")
     if "rules" in data and data["rules"] is not None:
-        mapping["rules"] = data["rules"]
+        mapping["rules"] = data.pop("rules")
     data["mapping_rules"] = mapping
     return data
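Switching from `data["classes"]` to `data.pop("classes")` means the top-level `classes` and `rules` keys no longer remain in the dict after packing. Presumably this is so the result can be passed to the ORM model without unexpected keys, although the calling code is not shown in this excerpt. The sketch below illustrates the same packing on a copy. Unlike the commit's `_pack_mapping_rules`, this hypothetical `pack_mapping_rules` does not mutate its input and also drops keys whose value is `None`.

```python
from typing import Any


def pack_mapping_rules(data: dict[str, Any]) -> dict[str, Any]:
    """Move top-level classes/rules into mapping_rules, leaving the input untouched."""
    packed = dict(data)
    mapping = dict(packed.get("mapping_rules") or {})
    for key in ("classes", "rules"):
        value = packed.pop(key, None)  # remove the top-level key in every case
        if value is not None:
            mapping[key] = value
    packed["mapping_rules"] = mapping
    return packed
```

Working on a copy avoids surprising a caller that reuses the original dict after the call. The commit's in-place version is equally valid when the dict is a throwaway produced by `model_dump()`.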


@@ -70,6 +70,7 @@ class FrameOut(FrameBase):
 # ---------------------------------------------------------------------------
 class TemplateBase(BaseModel):
     name: str
+    description: Optional[str] = None
     color: str
     z_index: int = 0
     mapping_rules: Optional[dict[str, Any]] = None
@@ -83,6 +84,7 @@ class TemplateCreate(TemplateBase):
 class TemplateUpdate(BaseModel):
     name: Optional[str] = None
+    description: Optional[str] = None
     color: Optional[str] = None
     z_index: Optional[int] = None
     mapping_rules: Optional[dict[str, Any]] = None
@@ -115,7 +117,7 @@ class AnnotationCreate(AnnotationBase):
 class AnnotationUpdate(BaseModel):
     mask_data: Optional[dict[str, Any]] = None
-    points: Optional[list[float]] = None
+    points: Optional[list[list[float]]] = None
     bbox: Optional[list[float]] = None
     template_id: Optional[int] = None
@@ -148,6 +150,28 @@ class MaskOut(MaskBase):
     created_at: datetime
+# ---------------------------------------------------------------------------
+# Processing task schemas
+# ---------------------------------------------------------------------------
+class ProcessingTaskOut(BaseModel):
+    model_config = ConfigDict(from_attributes=True)
+    id: int
+    task_type: str
+    status: str
+    progress: int
+    message: Optional[str] = None
+    project_id: Optional[int] = None
+    celery_task_id: Optional[str] = None
+    payload: Optional[dict[str, Any]] = None
+    result: Optional[dict[str, Any]] = None
+    error: Optional[str] = None
+    created_at: datetime
+    started_at: Optional[datetime] = None
+    finished_at: Optional[datetime] = None
+    updated_at: datetime
 # ---------------------------------------------------------------------------
 # AI schemas
 # ---------------------------------------------------------------------------
@@ -155,6 +179,7 @@ class PredictRequest(BaseModel):
     image_id: int
     prompt_type: str  # point / box / semantic
     prompt_data: Any
+    model: Optional[str] = None
 class PredictResponse(BaseModel):
@@ -162,6 +187,37 @@ class PredictResponse(BaseModel):
     scores: Optional[list[float]] = None
+class AiModelStatus(BaseModel):
+    id: str
+    label: str
+    available: bool
+    loaded: bool = False
+    device: str
+    supports: list[str]
+    message: str
+    package_available: bool = False
+    checkpoint_exists: bool = False
+    checkpoint_path: Optional[str] = None
+    python_ok: bool = True
+    torch_ok: bool = True
+    cuda_required: bool = False
+class GpuStatus(BaseModel):
+    available: bool
+    device: str
+    name: Optional[str] = None
+    torch_available: bool
+    torch_version: Optional[str] = None
+    cuda_version: Optional[str] = None
+class AiRuntimeStatus(BaseModel):
+    selected_model: str
+    gpu: GpuStatus
+    models: list[AiModelStatus]
 # ---------------------------------------------------------------------------
 # Export schemas
 # ---------------------------------------------------------------------------


@@ -0,0 +1,220 @@
"""Background media parsing runner used by Celery workers."""
import logging
import os
import shutil
import tempfile
from datetime import datetime, timezone
from pathlib import Path
from typing import Any
from sqlalchemy.orm import Session
from minio_client import BUCKET_NAME, download_file, get_minio_client, upload_file
from models import Frame, ProcessingTask, Project
from progress_events import publish_task_progress_event
from services.frame_parser import (
extract_thumbnail,
parse_dicom,
parse_video,
upload_frames_to_minio,
)
from statuses import (
PROJECT_STATUS_ERROR,
PROJECT_STATUS_PARSING,
PROJECT_STATUS_READY,
TASK_STATUS_FAILED,
TASK_STATUS_RUNNING,
TASK_STATUS_SUCCESS,
)
logger = logging.getLogger(__name__)

def _now() -> datetime:
    return datetime.now(timezone.utc)


def _set_task_state(
    db: Session,
    task: ProcessingTask,
    *,
    status: str | None = None,
    progress: int | None = None,
    message: str | None = None,
    result: dict[str, Any] | None = None,
    error: str | None = None,
    started: bool = False,
    finished: bool = False,
) -> None:
    if status is not None:
        task.status = status
    if progress is not None:
        task.progress = max(0, min(100, progress))
    if message is not None:
        task.message = message
    if result is not None:
        task.result = result
    if error is not None:
        task.error = error
    if started:
        task.started_at = _now()
    if finished:
        task.finished_at = _now()
    db.commit()
    db.refresh(task)
    publish_task_progress_event(task)


def run_parse_media_task(db: Session, task_id: int) -> dict[str, Any]:
    """Parse one project's media and update task progress in the database."""
    task = db.query(ProcessingTask).filter(ProcessingTask.id == task_id).first()
    if not task:
        raise ValueError(f"Task not found: {task_id}")
    if task.project_id is None:
        _set_task_state(
            db,
            task,
            status=TASK_STATUS_FAILED,
            progress=100,
            message="任务缺少 project_id",
            error="Task has no project_id",
            finished=True,
        )
        raise ValueError("Task has no project_id")

    project = db.query(Project).filter(Project.id == task.project_id).first()
    if not project:
        _set_task_state(
            db,
            task,
            status=TASK_STATUS_FAILED,
            progress=100,
            message="项目不存在",
            error="Project not found",
            finished=True,
        )
        raise ValueError(f"Project not found: {task.project_id}")
    if not project.video_path:
        _set_task_state(
            db,
            task,
            status=TASK_STATUS_FAILED,
            progress=100,
            message="项目没有可解析媒体",
            error="Project has no media uploaded",
            finished=True,
        )
        project.status = PROJECT_STATUS_ERROR
        db.commit()
        raise ValueError("Project has no media uploaded")

    project.status = PROJECT_STATUS_PARSING
    _set_task_state(db, task, status=TASK_STATUS_RUNNING, progress=5, message="后台解析已启动", started=True)
    effective_source = (task.payload or {}).get("source_type") or project.source_type or "video"
    parse_fps = project.parse_fps or 30.0
    tmp_dir = tempfile.mkdtemp(prefix=f"seg_parse_{project.id}_")
    output_dir = os.path.join(tmp_dir, "frames")
    os.makedirs(output_dir, exist_ok=True)

    try:
        _set_task_state(db, task, progress=15, message="正在下载媒体文件")
        if effective_source == "dicom":
            dcm_dir = os.path.join(tmp_dir, "dcm")
            os.makedirs(dcm_dir, exist_ok=True)
            client = get_minio_client()
            objects = list(client.list_objects(BUCKET_NAME, prefix=project.video_path, recursive=True))
            for obj in objects:
                if obj.object_name.lower().endswith(".dcm"):
                    data = download_file(obj.object_name)
                    local_dcm = os.path.join(dcm_dir, os.path.basename(obj.object_name))
                    with open(local_dcm, "wb") as f:
                        f.write(data)
            _set_task_state(db, task, progress=35, message="正在解析 DICOM 序列")
            frame_files = parse_dicom(dcm_dir, output_dir)
        else:
            media_bytes = download_file(project.video_path)
            local_path = os.path.join(tmp_dir, Path(project.video_path).name)
            with open(local_path, "wb") as f:
                f.write(media_bytes)
            _set_task_state(db, task, progress=35, message="正在使用 FFmpeg/OpenCV 拆帧")
            frame_files, original_fps = parse_video(local_path, output_dir, fps=int(parse_fps))
            project.original_fps = original_fps
            thumbnail_path = os.path.join(tmp_dir, "thumbnail.jpg")
            try:
                extract_thumbnail(local_path, thumbnail_path)
                with open(thumbnail_path, "rb") as f:
                    thumb_data = f.read()
                thumb_object = f"projects/{project.id}/thumbnail.jpg"
                upload_file(thumb_object, thumb_data, content_type="image/jpeg", length=len(thumb_data))
                project.thumbnail_url = thumb_object
            except Exception as exc:  # noqa: BLE001
                logger.warning("Thumbnail extraction failed: %s", exc)

        _set_task_state(db, task, progress=70, message="正在上传帧到对象存储")
        object_names = upload_frames_to_minio(frame_files, project.id)
        _set_task_state(db, task, progress=85, message="正在写入帧索引")
        frames_out = []
        for idx, obj_name in enumerate(object_names):
            local_frame = frame_files[idx]
            try:
                import cv2
                img = cv2.imread(local_frame)
                h, w = img.shape[:2] if img is not None else (None, None)
            except Exception:  # noqa: BLE001
                h, w = None, None
            frame = Frame(
                project_id=project.id,
                frame_index=idx,
                image_url=obj_name,
                width=w,
                height=h,
            )
            db.add(frame)
            frames_out.append(frame)
        project.status = PROJECT_STATUS_READY
        db.commit()

        result = {
            "project_id": project.id,
            "frames_extracted": len(frames_out),
            "status": PROJECT_STATUS_READY,
            "message": "Frame extraction completed successfully.",
        }
        _set_task_state(
            db,
            task,
            status=TASK_STATUS_SUCCESS,
            progress=100,
            message="解析完成",
            result=result,
            finished=True,
        )
        logger.info("Parsed %d frames for project_id=%s", len(frames_out), project.id)
        return result
    except Exception as exc:  # noqa: BLE001
        project.status = PROJECT_STATUS_ERROR
        _set_task_state(
            db,
            task,
            status=TASK_STATUS_FAILED,
            progress=100,
            message="解析失败",
            error=str(exc),
            finished=True,
        )
        logger.error("Frame extraction failed: %s", exc)
        raise
    finally:
        shutil.rmtree(tmp_dir, ignore_errors=True)
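
The partial-update behavior of `_set_task_state` above can be sketched in isolation. This is a minimal illustration, not the project's code: `FakeTask` and `apply_task_state` are hypothetical names, and the database commit and Redis publish steps are deliberately left out.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class FakeTask:
    """Hypothetical stand-in for ProcessingTask; no database involved."""
    status: str = "queued"
    progress: int = 0
    message: str = ""


def apply_task_state(
    task: FakeTask,
    *,
    status: Optional[str] = None,
    progress: Optional[int] = None,
    message: Optional[str] = None,
) -> FakeTask:
    # None means "leave this field unchanged", mirroring _set_task_state.
    if status is not None:
        task.status = status
    if progress is not None:
        # Clamp into 0..100 so an out-of-range value cannot reach the UI.
        task.progress = max(0, min(100, progress))
    if message is not None:
        task.message = message
    return task


task = FakeTask()
apply_task_state(task, status="running", progress=5, message="started")
apply_task_state(task, progress=150)  # clamped; status and message untouched
```

Keeping every field keyword-only and optional lets each pipeline stage report only what changed, which is how the runner moves from 5 to 15, 35, 70, 85 and 100 without repeating the status.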


@@ -1,4 +1,4 @@
-"""SAM 2 engine wrapper with lazy loading and fallback stubs."""
+"""SAM 2 engine wrapper with lazy loading and explicit runtime status."""
 import logging
 import os
@@ -11,10 +11,18 @@ from config import settings
 logger = logging.getLogger(__name__)
 # ---------------------------------------------------------------------------
-# Attempt to import SAM 2; fall back to stubs if unavailable.
+# Attempt to import PyTorch and SAM 2; fall back to stubs if unavailable.
 # ---------------------------------------------------------------------------
 try:
     import torch
+    TORCH_AVAILABLE = True
+except Exception as exc:  # noqa: BLE001
+    TORCH_AVAILABLE = False
+    torch = None  # type: ignore[assignment]
+    logger.warning("PyTorch import failed (%s). SAM2 will be unavailable.", exc)
+try:
     from sam2.build_sam import build_sam2
     from sam2.sam2_image_predictor import SAM2ImagePredictor
@@ -31,6 +39,8 @@ class SAM2Engine:
     def __init__(self) -> None:
         self._predictor: Optional[SAM2ImagePredictor] = None
         self._model_loaded = False
+        self._loaded_device: str | None = None
+        self._last_error: str | None = None
     # -----------------------------------------------------------------------
     # Internal helpers
@@ -40,34 +50,87 @@ class SAM2Engine:
         if self._model_loaded:
             return
+        if not TORCH_AVAILABLE:
+            self._last_error = "PyTorch is not installed."
+            logger.warning("PyTorch not available; skipping SAM2 model load.")
+            self._model_loaded = True
+            return
         if not SAM2_AVAILABLE:
+            self._last_error = "sam2 package is not installed."
             logger.warning("SAM2 not available; skipping model load.")
             self._model_loaded = True
             return
         if not os.path.isfile(settings.sam_model_path):
+            self._last_error = f"SAM2 checkpoint not found: {settings.sam_model_path}"
             logger.error("SAM checkpoint not found at %s", settings.sam_model_path)
             self._model_loaded = True
             return
         try:
+            device = self._best_device()
             model = build_sam2(
                 settings.sam_model_config,
                 settings.sam_model_path,
-                device="cuda",
+                device=device,
             )
             self._predictor = SAM2ImagePredictor(model)
             self._model_loaded = True
-            logger.info("SAM 2 model loaded from %s", settings.sam_model_path)
+            self._loaded_device = device
+            self._last_error = None
+            logger.info("SAM 2 model loaded from %s on %s", settings.sam_model_path, device)
         except Exception as exc:  # noqa: BLE001
+            self._last_error = str(exc)
             logger.error("Failed to load SAM 2 model: %s", exc)
             self._model_loaded = True  # Prevent repeated load attempts
+    def _best_device(self) -> str:
+        if TORCH_AVAILABLE and torch is not None and torch.cuda.is_available():
+            return "cuda"
+        return "cpu"
     def _ensure_ready(self) -> bool:
         """Ensure the model is loaded; return whether it is usable."""
         self._load_model()
         return SAM2_AVAILABLE and self._predictor is not None
+    def status(self) -> dict:
+        """Return lightweight, real runtime status without forcing model load."""
+        checkpoint_exists = os.path.isfile(settings.sam_model_path)
+        device = self._loaded_device or self._best_device()
+        available = bool(TORCH_AVAILABLE and SAM2_AVAILABLE and checkpoint_exists)
+        if self._predictor is not None:
+            message = "SAM 2 model loaded and ready."
+        elif available:
+            message = "SAM 2 dependencies and checkpoint are present; model will load on first inference."
+        else:
+            missing = []
+            if not TORCH_AVAILABLE:
+                missing.append("PyTorch")
+            if not SAM2_AVAILABLE:
+                missing.append("sam2 package")
+            if not checkpoint_exists:
+                missing.append("checkpoint")
+            message = f"SAM 2 unavailable: missing {', '.join(missing)}."
+        if self._last_error and not self._predictor:
+            message = self._last_error
+        return {
+            "id": "sam2",
+            "label": "SAM 2",
+            "available": available,
+            "loaded": self._predictor is not None,
+            "device": device,
+            "supports": ["point", "box", "auto"],
+            "message": message,
+            "package_available": SAM2_AVAILABLE,
+            "checkpoint_exists": checkpoint_exists,
+            "checkpoint_path": settings.sam_model_path,
+            "python_ok": True,
+            "torch_ok": TORCH_AVAILABLE,
+            "cuda_required": False,
+        }
     # -----------------------------------------------------------------------
     # Public API
     # -----------------------------------------------------------------------
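
The "attempt the load once, remember why it failed" pattern used by `_load_model` above can be shown without PyTorch. This is only a sketch under assumptions: `LazyResource` and `failing_loader` are hypothetical names, and the loader stands in for `build_sam2`.

```python
from typing import Callable, Optional


class LazyResource:
    """Hypothetical helper: load at most once and keep the last error."""

    def __init__(self, loader: Callable[[], object]) -> None:
        self._loader = loader
        self._resource: Optional[object] = None
        self._attempted = False
        self.last_error: Optional[str] = None

    def ensure_ready(self) -> bool:
        if not self._attempted:
            # Mark the attempt first so a failing loader is not retried
            # on every request.
            self._attempted = True
            try:
                self._resource = self._loader()
                self.last_error = None
            except Exception as exc:  # noqa: BLE001
                self.last_error = str(exc)
        return self._resource is not None


calls = []


def failing_loader() -> object:
    calls.append(1)
    raise RuntimeError("checkpoint not found")


resource = LazyResource(failing_loader)
first = resource.ensure_ready()
second = resource.ensure_ready()
```

The trade-off is that a transient failure stays sticky until the process restarts; the stored error text is what a status endpoint can surface in the meantime.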


@@ -0,0 +1,148 @@
"""SAM 3 engine adapter and runtime status.
The official facebookresearch/sam3 package currently targets Python 3.12+
and CUDA-capable PyTorch. This adapter reports those requirements honestly and
only performs inference when the local runtime can actually import and execute
the package.
"""
from __future__ import annotations
import importlib.util
import logging
import sys
from typing import Any
import numpy as np
from PIL import Image
from config import settings
from services.sam2_engine import SAM2Engine
logger = logging.getLogger(__name__)

try:
    import torch
    TORCH_AVAILABLE = True
except Exception as exc:  # noqa: BLE001
    TORCH_AVAILABLE = False
    torch = None  # type: ignore[assignment]
    logger.warning("PyTorch import failed (%s). SAM3 will be unavailable.", exc)

SAM3_PACKAGE_AVAILABLE = importlib.util.find_spec("sam3") is not None


class SAM3Engine:
    """Lazy SAM 3 image inference adapter."""

    def __init__(self) -> None:
        self._model: Any | None = None
        self._processor: Any | None = None
        self._model_loaded = False
        self._last_error: str | None = None

    def _python_ok(self) -> bool:
        return sys.version_info >= (3, 12)

    def _gpu_ok(self) -> bool:
        return bool(TORCH_AVAILABLE and torch is not None and torch.cuda.is_available())

    def _can_load(self) -> bool:
        return bool(SAM3_PACKAGE_AVAILABLE and TORCH_AVAILABLE and self._python_ok() and self._gpu_ok())

    def _load_model(self) -> None:
        if self._model_loaded:
            return
        if not self._can_load():
            self._last_error = self._status_message()
            self._model_loaded = True
            return
        try:
            from sam3.model.sam3_image_processor import Sam3Processor
            from sam3.model_builder import build_sam3_image_model
            self._model = build_sam3_image_model()
            self._processor = Sam3Processor(self._model)
            self._model_loaded = True
            self._last_error = None
            logger.info("SAM 3 image model loaded with version setting %s", settings.sam3_model_version)
        except Exception as exc:  # noqa: BLE001
            self._last_error = str(exc)
            self._model_loaded = True
            logger.error("Failed to load SAM 3 model: %s", exc)

    def _ensure_ready(self) -> bool:
        self._load_model()
        return self._processor is not None

    def _status_message(self) -> str:
        missing = []
        if not SAM3_PACKAGE_AVAILABLE:
            missing.append("sam3 package")
        if not self._python_ok():
            missing.append("Python 3.12+ runtime")
        if not TORCH_AVAILABLE:
            missing.append("PyTorch")
        if not self._gpu_ok():
            missing.append("CUDA GPU")
        if missing:
            return f"SAM 3 unavailable: missing {', '.join(missing)}."
        return "SAM 3 dependencies are present; model will load on first inference."

    def status(self) -> dict:
        available = self._can_load()
        return {
            "id": "sam3",
            "label": "SAM 3",
            "available": available,
            "loaded": self._processor is not None,
            "device": "cuda" if self._gpu_ok() else "unavailable",
            "supports": ["semantic"],
            "message": "SAM 3 model loaded and ready." if self._processor is not None else (self._last_error or self._status_message()),
            "package_available": SAM3_PACKAGE_AVAILABLE,
            "checkpoint_exists": SAM3_PACKAGE_AVAILABLE,
            "checkpoint_path": f"official/HuggingFace ({settings.sam3_model_version})",
            "python_ok": self._python_ok(),
            "torch_ok": TORCH_AVAILABLE,
            "cuda_required": True,
        }

    def predict_semantic(self, image: np.ndarray, text: str) -> tuple[list[list[list[float]]], list[float]]:
        if not text.strip():
            raise ValueError("SAM 3 semantic prompt requires non-empty text.")
        if not self._ensure_ready():
            raise RuntimeError(self.status()["message"])
        pil_image = Image.fromarray(image)
        with torch.inference_mode():  # type: ignore[union-attr]
            state = self._processor.set_image(pil_image)
            output = self._processor.set_text_prompt(state=state, prompt=text.strip())
        masks = output.get("masks", [])
        scores = output.get("scores", [])
        polygons = []
        for mask in masks:
            if hasattr(mask, "detach"):
                mask = mask.detach().cpu().numpy()
            if mask.ndim == 3:
                mask = mask[0]
            poly = SAM2Engine._mask_to_polygon(mask)
            if poly:
                polygons.append(poly)
        if hasattr(scores, "detach"):
            scores = scores.detach().cpu().tolist()
        elif hasattr(scores, "tolist"):
            scores = scores.tolist()
        return polygons, list(scores)

    def predict_points(self, *_args: Any, **_kwargs: Any) -> tuple[list[list[list[float]]], list[float]]:
        raise NotImplementedError("This backend currently exposes SAM 3 semantic text inference; use SAM 2 for point prompts.")

    def predict_box(self, *_args: Any, **_kwargs: Any) -> tuple[list[list[list[float]]], list[float]]:
        raise NotImplementedError("This backend currently exposes SAM 3 semantic text inference; use SAM 2 for box prompts.")


sam3_engine = SAM3Engine()
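
The requirement check behind `_status_message` above can be exercised as a pure function, which makes it easy to test on a machine without a GPU. A minimal sketch, assuming the same four requirements; `sam3_status_message` and its parameters are hypothetical names, not part of the adapter.

```python
import sys


def sam3_status_message(
    package_available: bool,
    torch_available: bool,
    cuda_available: bool,
    python_version: tuple = sys.version_info[:2],
) -> str:
    # Collect every unmet requirement so the caller sees the full list at once.
    missing = []
    if not package_available:
        missing.append("sam3 package")
    if python_version < (3, 12):
        missing.append("Python 3.12+ runtime")
    if not torch_available:
        missing.append("PyTorch")
    if not cuda_available:
        missing.append("CUDA GPU")
    if missing:
        return f"SAM 3 unavailable: missing {', '.join(missing)}."
    return "SAM 3 dependencies are present; model will load on first inference."


message = sam3_status_message(False, True, False, python_version=(3, 11))
```

Passing the interpreter version and capability flags in as arguments, rather than reading them inside the function, is what lets a test simulate an unsupported runtime.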


@@ -0,0 +1,80 @@
"""Model registry for SAM runtimes and GPU status."""
from __future__ import annotations
from typing import Any
from config import settings
from services.sam2_engine import TORCH_AVAILABLE, sam_engine as sam2_engine
from services.sam3_engine import sam3_engine

try:
    import torch
except Exception:  # noqa: BLE001
    torch = None  # type: ignore[assignment]


class ModelUnavailableError(RuntimeError):
    """Raised when a selected model cannot run in this environment."""


class SAMRegistry:
    """Dispatch predictions to the selected SAM backend."""

    def __init__(self) -> None:
        self._engines = {
            "sam2": sam2_engine,
            "sam3": sam3_engine,
        }

    def normalize_model_id(self, model_id: str | None) -> str:
        selected = (model_id or settings.sam_default_model or "sam2").lower()
        if selected not in self._engines:
            raise ValueError(f"Unsupported model: {model_id}")
        return selected

    def runtime_status(self, selected_model: str | None = None) -> dict[str, Any]:
        return {
            "selected_model": self.normalize_model_id(selected_model),
            "gpu": self.gpu_status(),
            "models": [engine.status() for engine in self._engines.values()],
        }

    def gpu_status(self) -> dict[str, Any]:
        cuda_available = bool(TORCH_AVAILABLE and torch is not None and torch.cuda.is_available())
        return {
            "available": cuda_available,
            "device": "cuda" if cuda_available else "cpu",
            "name": torch.cuda.get_device_name(0) if cuda_available else None,
            "torch_available": bool(TORCH_AVAILABLE),
            "torch_version": getattr(torch, "__version__", None) if torch is not None else None,
            "cuda_version": getattr(torch.version, "cuda", None) if torch is not None else None,
        }

    def _engine(self, model_id: str | None) -> Any:
        return self._engines[self.normalize_model_id(model_id)]

    def _ensure_available(self, model_id: str | None) -> Any:
        engine = self._engine(model_id)
        status = engine.status()
        if not status["available"]:
            raise ModelUnavailableError(status["message"])
        return engine

    def predict_points(self, model_id: str | None, image: Any, points: list[list[float]], labels: list[int]):
        return self._ensure_available(model_id).predict_points(image, points, labels)

    def predict_box(self, model_id: str | None, image: Any, box: list[float]):
        return self._ensure_available(model_id).predict_box(image, box)

    def predict_auto(self, model_id: str | None, image: Any):
        return self._ensure_available(model_id).predict_auto(image)

    def predict_semantic(self, model_id: str | None, image: Any, text: str):
        model = self.normalize_model_id(model_id)
        if model == "sam3":
            return self._ensure_available(model).predict_semantic(image, text)
        return self._ensure_available(model).predict_auto(image)


sam_registry = SAMRegistry()
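
The dispatch rules of `SAMRegistry` above (default model, case-insensitive ids, availability gate, semantic fallback to automatic segmentation) can be tried with stub engines. This is a sketch only: `FakeEngine`, `ENGINES` and `DEFAULT_MODEL` are hypothetical stand-ins for the real engines and `settings.sam_default_model`.

```python
class FakeEngine:
    """Hypothetical engine exposing only what the dispatch logic reads."""

    def __init__(self, name: str, available: bool) -> None:
        self.name = name
        self.available = available

    def status(self) -> dict:
        return {"available": self.available, "message": f"{self.name} unavailable"}

    def predict_auto(self, image):
        return f"{self.name}:auto"

    def predict_semantic(self, image, text):
        return f"{self.name}:semantic:{text}"


class ModelUnavailableError(RuntimeError):
    pass


ENGINES = {"sam2": FakeEngine("sam2", True), "sam3": FakeEngine("sam3", False)}
DEFAULT_MODEL = "sam2"  # assumption: stands in for settings.sam_default_model


def normalize_model_id(model_id):
    selected = (model_id or DEFAULT_MODEL).lower()
    if selected not in ENGINES:
        raise ValueError(f"Unsupported model: {model_id}")
    return selected


def predict_semantic(model_id, image, text):
    selected = normalize_model_id(model_id)
    engine = ENGINES[selected]
    status = engine.status()
    if not status["available"]:
        raise ModelUnavailableError(status["message"])
    # Only SAM 3 accepts text; other models fall back to automatic masks.
    if selected == "sam3":
        return engine.predict_semantic(image, text)
    return engine.predict_auto(image)


fallback = predict_semantic(None, image=None, text="gallbladder")
```

Raising a dedicated `ModelUnavailableError` rather than returning empty masks gives the router a clear signal to map to an HTTP error, although which status code the router chooses is not shown in this diff.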

backend/statuses.py Normal file

@@ -0,0 +1,11 @@
"""Shared status constants used across backend project/task flows."""
PROJECT_STATUS_PENDING = "pending"
PROJECT_STATUS_PARSING = "parsing"
PROJECT_STATUS_READY = "ready"
PROJECT_STATUS_ERROR = "error"
TASK_STATUS_QUEUED = "queued"
TASK_STATUS_RUNNING = "running"
TASK_STATUS_SUCCESS = "success"
TASK_STATUS_FAILED = "failed"

backend/tests/conftest.py Normal file

@@ -0,0 +1,72 @@
"""Shared pytest fixtures for backend API tests."""
from __future__ import annotations
import sys
from pathlib import Path
from typing import Iterator
import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient
from sqlalchemy import create_engine
from sqlalchemy.orm import Session, sessionmaker
from sqlalchemy.pool import StaticPool
BACKEND_DIR = Path(__file__).resolve().parents[1]
if str(BACKEND_DIR) not in sys.path:
    sys.path.insert(0, str(BACKEND_DIR))

from database import Base, get_db  # noqa: E402
from main import websocket_progress  # noqa: E402
from routers import ai, auth, dashboard, export, media, projects, tasks, templates  # noqa: E402


@pytest.fixture()
def db_session() -> Iterator[Session]:
    engine = create_engine(
        "sqlite://",
        connect_args={"check_same_thread": False},
        poolclass=StaticPool,
    )
    TestingSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
    Base.metadata.create_all(bind=engine)
    session = TestingSessionLocal()
    try:
        yield session
    finally:
        session.close()
        Base.metadata.drop_all(bind=engine)
        engine.dispose()


@pytest.fixture()
def app(db_session: Session) -> FastAPI:
    test_app = FastAPI()

    def override_get_db() -> Iterator[Session]:
        yield db_session

    test_app.dependency_overrides[get_db] = override_get_db
    test_app.include_router(auth.router)
    test_app.include_router(projects.router)
    test_app.include_router(templates.router)
    test_app.include_router(media.router)
    test_app.include_router(ai.router)
    test_app.include_router(export.router)
    test_app.include_router(dashboard.router)
    test_app.include_router(tasks.router)

    @test_app.get("/health")
    def health_check() -> dict[str, str]:
        return {"status": "ok", "service": "SegServer"}

    test_app.add_api_websocket_route("/ws/progress", websocket_progress)
    return test_app


@pytest.fixture()
def client(app: FastAPI) -> Iterator[TestClient]:
    with TestClient(app) as test_client:
        yield test_client
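
The `db_session` fixture above pairs the in-memory `sqlite://` URL with `StaticPool`. A likely reason, stated here as an assumption about the author's intent, is that every new connection to an in-memory SQLite database starts empty, so the tables created by `create_all` would be invisible on any other connection. The standard-library sketch below demonstrates that isolation without SQLAlchemy.

```python
import sqlite3

# Two separate connections to ":memory:" each get their own empty database.
first = sqlite3.connect(":memory:")
second = sqlite3.connect(":memory:")

first.execute("CREATE TABLE projects (id INTEGER PRIMARY KEY, name TEXT)")
first.execute("INSERT INTO projects (name) VALUES ('demo')")
rows_on_first = first.execute("SELECT name FROM projects").fetchall()

try:
    second.execute("SELECT name FROM projects")
    second_sees_table = True
except sqlite3.OperationalError:
    # "no such table": the second connection never saw the CREATE TABLE.
    second_sees_table = False

first.close()
second.close()
```

Reusing a single connection for the whole fixture, which is what a static pool does, keeps the schema and the rows written by the API handlers visible to the session the test inspects.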

backend/tests/test_ai.py Normal file

@@ -0,0 +1,248 @@
import numpy as np


def _create_project_and_frame(client):
    project = client.post("/api/projects", json={"name": "AI Project"}).json()
    frame = client.post(f"/api/projects/{project['id']}/frames", json={
        "project_id": project["id"],
        "frame_index": 0,
        "image_url": "frames/0.jpg",
        "width": 640,
        "height": 360,
    }).json()
    template = client.post("/api/templates", json={
        "name": "Template",
        "color": "#06b6d4",
        "z_index": 0,
        "classes": [],
        "rules": [],
    }).json()
    return project, frame, template


def test_predict_accepts_point_object_with_labels(client, monkeypatch):
    _, frame, _ = _create_project_and_frame(client)
    calls = {}
    monkeypatch.setattr("routers.ai._load_frame_image", lambda frame: np.zeros((10, 10, 3), dtype=np.uint8))

    def fake_predict_points(image, points, labels):
        calls["args"] = (points, labels)
        return (
            [[[0.1, 0.1], [0.9, 0.1], [0.9, 0.9]]],
            [0.95],
        )

    monkeypatch.setattr("routers.ai.sam_registry.predict_points", lambda model, image, points, labels: fake_predict_points(image, points, labels))
    response = client.post("/api/ai/predict", json={
        "image_id": frame["id"],
        "prompt_type": "point",
        "prompt_data": {"points": [[0.5, 0.5], [0.1, 0.1]], "labels": [1, 0]},
    })
    assert response.status_code == 200
    assert response.json()["scores"] == [0.95]
    assert calls["args"] == ([[0.5, 0.5], [0.1, 0.1]], [1, 0])


def test_predict_box_and_semantic_fallback(client, monkeypatch):
    _, frame, _ = _create_project_and_frame(client)
    monkeypatch.setattr("routers.ai._load_frame_image", lambda frame: np.zeros((10, 10, 3), dtype=np.uint8))
    monkeypatch.setattr("routers.ai.sam_registry.predict_box", lambda model, image, box: (
        [[[0.2, 0.2], [0.8, 0.2], [0.8, 0.8]]],
        [0.8],
    ))
    monkeypatch.setattr("routers.ai.sam_registry.predict_semantic", lambda model, image, text: (
        [[[0.0, 0.0], [1.0, 0.0], [1.0, 1.0]]],
        [0.5],
    ))
    box_response = client.post("/api/ai/predict", json={
        "image_id": frame["id"],
        "prompt_type": "box",
        "prompt_data": [0.2, 0.2, 0.8, 0.8],
    })
    semantic_response = client.post("/api/ai/predict", json={
        "image_id": frame["id"],
        "prompt_type": "semantic",
        "prompt_data": "胆囊",
    })
    assert box_response.status_code == 200
    assert box_response.json()["scores"] == [0.8]
    assert semantic_response.status_code == 200
    assert semantic_response.json()["scores"] == [0.5]


def test_model_status_reports_runtime(client, monkeypatch):
    monkeypatch.setattr("routers.ai.sam_registry.runtime_status", lambda selected_model=None: {
        "selected_model": selected_model or "sam2",
        "gpu": {
            "available": False,
            "device": "cpu",
            "name": None,
            "torch_available": True,
            "torch_version": "2.x",
            "cuda_version": None,
        },
        "models": [
            {
                "id": "sam2",
                "label": "SAM 2",
                "available": True,
                "loaded": False,
                "device": "cpu",
                "supports": ["point", "box", "auto"],
                "message": "ready",
                "package_available": True,
                "checkpoint_exists": True,
                "checkpoint_path": "model.pt",
                "python_ok": True,
                "torch_ok": True,
                "cuda_required": False,
            },
            {
                "id": "sam3",
                "label": "SAM 3",
                "available": False,
                "loaded": False,
                "device": "unavailable",
                "supports": ["semantic"],
                "message": "missing Python 3.12+ runtime",
                "package_available": False,
                "checkpoint_exists": False,
                "checkpoint_path": None,
                "python_ok": False,
                "torch_ok": True,
                "cuda_required": True,
            },
        ],
    })
    response = client.get("/api/ai/models/status?selected_model=sam3")
    assert response.status_code == 200
    body = response.json()
    assert body["selected_model"] == "sam3"
    assert body["models"][1]["id"] == "sam3"
    assert body["models"][1]["available"] is False


def test_predict_validation_errors(client, monkeypatch):
    project, _, _ = _create_project_and_frame(client)
    assert client.post("/api/ai/predict", json={
        "image_id": 999,
        "prompt_type": "point",
        "prompt_data": [[0.5, 0.5]],
    }).status_code == 404
    frame = client.post(f"/api/projects/{project['id']}/frames", json={
        "project_id": project["id"],
        "frame_index": 1,
        "image_url": "frames/1.jpg",
    }).json()
    monkeypatch.setattr("routers.ai._load_frame_image", lambda frame: np.zeros((10, 10, 3), dtype=np.uint8))
    assert client.post("/api/ai/predict", json={
        "image_id": frame["id"],
        "prompt_type": "box",
        "prompt_data": [0.1, 0.2],
    }).status_code == 400


def test_save_annotation_validates_project_and_frame(client):
    project, frame, template = _create_project_and_frame(client)
    saved = client.post("/api/ai/annotate", json={
        "project_id": project["id"],
        "frame_id": frame["id"],
        "template_id": template["id"],
        "mask_data": {"polygons": [[[0.1, 0.1], [0.9, 0.1], [0.9, 0.9]]]},
        "points": [[0.5, 0.5]],
        "bbox": [0.1, 0.1, 0.8, 0.8],
    })
    assert saved.status_code == 201
    assert saved.json()["project_id"] == project["id"]
    listing = client.get(f"/api/ai/annotations?project_id={project['id']}")
    assert listing.status_code == 200
    assert listing.json()[0]["id"] == saved.json()["id"]
    frame_listing = client.get(f"/api/ai/annotations?project_id={project['id']}&frame_id={frame['id']}")
    assert frame_listing.status_code == 200
    assert len(frame_listing.json()) == 1
    missing_project = client.post("/api/ai/annotate", json={"project_id": 999})
    assert missing_project.status_code == 404
    missing_frame = client.post("/api/ai/annotate", json={
        "project_id": project["id"],
        "frame_id": 999,
    })
    assert missing_frame.status_code == 404
    missing_project_list = client.get("/api/ai/annotations?project_id=999")
    assert missing_project_list.status_code == 404


def test_update_and_delete_annotation(client):
    project, frame, template = _create_project_and_frame(client)
    saved = client.post("/api/ai/annotate", json={
        "project_id": project["id"],
        "frame_id": frame["id"],
        "template_id": template["id"],
        "mask_data": {
            "polygons": [[[0.1, 0.1], [0.9, 0.1], [0.9, 0.9]]],
            "label": "AI Mask",
            "color": "#06b6d4",
        },
        "points": [[0.5, 0.5]],
        "bbox": [0.1, 0.1, 0.8, 0.8],
    }).json()
    updated = client.patch(f"/api/ai/annotations/{saved['id']}", json={
        "template_id": template["id"],
        "mask_data": {
            "polygons": [[[0.2, 0.2], [0.8, 0.2], [0.8, 0.8]]],
            "label": "胆囊",
            "color": "#ff0000",
            "class": {"id": "c1", "name": "胆囊", "color": "#ff0000", "zIndex": 20},
        },
        "points": [[0.4, 0.4]],
        "bbox": [0.2, 0.2, 0.6, 0.6],
    })
    assert updated.status_code == 200
    body = updated.json()
    assert body["mask_data"]["label"] == "胆囊"
    assert body["mask_data"]["class"]["id"] == "c1"
    assert body["points"] == [[0.4, 0.4]]
    assert body["bbox"] == [0.2, 0.2, 0.6, 0.6]
    listing = client.get(f"/api/ai/annotations?project_id={project['id']}")
    assert listing.status_code == 200
    assert listing.json()[0]["mask_data"]["class"]["name"] == "胆囊"
    deleted = client.delete(f"/api/ai/annotations/{saved['id']}")
    assert deleted.status_code == 204
    empty_listing = client.get(f"/api/ai/annotations?project_id={project['id']}")
    assert empty_listing.status_code == 200
    assert empty_listing.json() == []


def test_update_and_delete_annotation_validation(client):
    project, frame, template = _create_project_and_frame(client)
    saved = client.post("/api/ai/annotate", json={
        "project_id": project["id"],
        "frame_id": frame["id"],
        "template_id": template["id"],
    }).json()
    assert client.patch("/api/ai/annotations/999", json={"bbox": [0, 0, 1, 1]}).status_code == 404
    assert client.delete("/api/ai/annotations/999").status_code == 404
    assert client.patch(
        f"/api/ai/annotations/{saved['id']}",
        json={"template_id": 999},
    ).status_code == 404
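
The prediction tests above send point prompts in two shapes (a bare list of points, and an object with `points` and `labels`) and expect a two-element box to be rejected. The router code is not part of this excerpt, so the following is only a guess at the kind of normalization that would satisfy those tests; `normalize_point_prompt` and `validate_box` are hypothetical helpers, and the default label of 1 for bare points is an assumption.

```python
def normalize_point_prompt(prompt_data):
    """Return (points, labels) from either accepted shape (hypothetical)."""
    if isinstance(prompt_data, dict):
        points = prompt_data.get("points", [])
        labels = prompt_data.get("labels") or [1] * len(points)
    else:
        points = prompt_data
        # Assumption: a bare point is treated as a foreground click.
        labels = [1] * len(points)
    if len(points) != len(labels):
        raise ValueError("points and labels must have the same length")
    return points, labels


def validate_box(prompt_data):
    """Require exactly four numbers for a box prompt (hypothetical)."""
    if not isinstance(prompt_data, (list, tuple)) or len(prompt_data) != 4:
        raise ValueError("box prompt must contain four numbers")
    return [float(value) for value in prompt_data]


points, labels = normalize_point_prompt(
    {"points": [[0.5, 0.5], [0.1, 0.1]], "labels": [1, 0]}
)
```

In an API handler, a `ValueError` from helpers like these would typically be translated into the 400 response that `test_predict_validation_errors` checks for.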


@@ -0,0 +1,15 @@
def test_login_success(client):
    response = client.post("/api/auth/login", json={"username": "admin", "password": "123456"})
    assert response.status_code == 200
    assert response.json() == {
        "token": "fake-jwt-token-for-admin",
        "username": "admin",
    }


def test_login_rejects_invalid_credentials(client):
    response = client.post("/api/auth/login", json={"username": "admin", "password": "wrong"})
    assert response.status_code == 401
    assert response.json()["detail"] == "Invalid credentials"


@@ -0,0 +1,69 @@
def test_dashboard_overview_uses_persisted_records(client, db_session):
    from models import ProcessingTask

    project_pending = client.post("/api/projects", json={
        "name": "Pending Project",
        "status": "pending",
    }).json()
    project_ready = client.post("/api/projects", json={
        "name": "Ready Project",
        "status": "ready",
    }).json()
    frame = client.post(f"/api/projects/{project_pending['id']}/frames", json={
        "project_id": project_pending["id"],
        "frame_index": 0,
        "image_url": "frames/0.jpg",
        "width": 640,
        "height": 360,
    }).json()
    template = client.post("/api/templates", json={
        "name": "Dashboard Template",
        "color": "#06b6d4",
        "z_index": 0,
        "classes": [],
        "rules": [],
    }).json()
    annotation = client.post("/api/ai/annotate", json={
        "project_id": project_pending["id"],
        "frame_id": frame["id"],
        "template_id": template["id"],
        "mask_data": {"polygons": [[[0.1, 0.1], [0.9, 0.1], [0.9, 0.9]]]},
    })
    assert annotation.status_code == 201
    task = ProcessingTask(
        task_type="parse_video",
        status="running",
        progress=35,
        message="正在使用 FFmpeg/OpenCV 拆帧",
        project_id=project_pending["id"],
        payload={"source_type": "video"},
    )
    db_session.add(task)
    db_session.commit()
    db_session.refresh(task)

    response = client.get("/api/dashboard/overview")
    assert response.status_code == 200
    body = response.json()
    assert body["summary"]["project_count"] == 2
    assert body["summary"]["frame_count"] == 1
    assert body["summary"]["annotation_count"] == 1
    assert body["summary"]["template_count"] == 1
    assert body["summary"]["parsing_task_count"] == 1
    assert body["tasks"] == [
        {
            "id": f"task-{task.id}",
            "task_id": task.id,
            "project_id": project_pending["id"],
            "name": "Pending Project",
            "progress": 35,
            "status": "正在使用 FFmpeg/OpenCV 拆帧",
            "frame_count": 0,
            "updated_at": body["tasks"][0]["updated_at"],
        },
    ]
    assert any(item["kind"] == "task" for item in body["activity"])
    assert any(item["kind"] == "annotation" for item in body["activity"])
    assert any(item["kind"] == "template" for item in body["activity"])
    assert all(item["name"] != "Ready Project" for item in body["tasks"])


@@ -0,0 +1,66 @@
import zipfile
from io import BytesIO


def _seed_export_data(client):
    project = client.post("/api/projects", json={"name": "Export Project"}).json()
    frame = client.post(f"/api/projects/{project['id']}/frames", json={
        "project_id": project["id"],
        "frame_index": 0,
        "image_url": "frames/0.jpg",
        "width": 100,
        "height": 50,
    }).json()
    template = client.post("/api/templates", json={
        "name": "Category",
        "color": "#06b6d4",
        "z_index": 0,
        "classes": [],
        "rules": [],
    }).json()
    annotation = client.post("/api/ai/annotate", json={
        "project_id": project["id"],
        "frame_id": frame["id"],
        "template_id": template["id"],
        "mask_data": {"polygons": [[[0.1, 0.2], [0.9, 0.2], [0.9, 0.8], [0.1, 0.8]]]},
        "points": [[0.5, 0.5]],
        "bbox": [0.1, 0.2, 0.8, 0.6],
    }).json()
    return project, frame, template, annotation


def test_export_coco_json_structure(client):
    project, frame, _, _ = _seed_export_data(client)
    response = client.get(f"/api/export/{project['id']}/coco")
    assert response.status_code == 200
    assert response.headers["content-type"].startswith("application/json")
    data = response.json()
    assert data["info"]["description"] == "Annotations for Export Project"
    assert data["images"][0] == {
        "id": frame["id"],
        "file_name": "frames/0.jpg",
        "width": 100,
        "height": 50,
        "frame_index": 0,
    }
    assert data["annotations"][0]["segmentation"] == [[10.0, 10.0, 90.0, 10.0, 90.0, 40.0, 10.0, 40.0]]
    assert data["annotations"][0]["bbox"] == [10.0, 10.0, 80.0, 30.000000000000004]
    assert data["categories"][0]["name"] == "Category"


def test_export_masks_zip(client):
    project, _, _, annotation = _seed_export_data(client)
    response = client.get(f"/api/export/{project['id']}/masks")
    assert response.status_code == 200
    assert response.headers["content-type"].startswith("application/zip")
    with zipfile.ZipFile(BytesIO(response.content)) as archive:
        assert archive.namelist() == [f"mask_{annotation['id']:06d}.png"]


def test_export_missing_project_returns_404(client):
    assert client.get("/api/export/999/coco").status_code == 404
    assert client.get("/api/export/999/masks").status_code == 404


@@ -0,0 +1,15 @@
def test_health_endpoint(client):
    response = client.get("/health")

    assert response.status_code == 200
    assert response.json() == {"status": "ok", "service": "SegServer"}


def test_websocket_progress_heartbeat(client):
    with client.websocket_connect("/ws/progress") as websocket:
        websocket.send_text("ping")
        data = websocket.receive_json()

    assert data["type"] == "status"
    assert data["status"] == "connected"
    assert data["message"] == "Progress stream active"

142
backend/tests/test_media.py Normal file

@@ -0,0 +1,142 @@
def test_upload_rejects_unsupported_file_type(client):
    response = client.post(
        "/api/media/upload",
        files={"file": ("notes.txt", b"text", "text/plain")},
    )

    assert response.status_code == 400
    assert "Unsupported file type" in response.json()["detail"]


def test_upload_auto_creates_project(client, monkeypatch):
    uploaded = []
    monkeypatch.setattr("routers.media.upload_file", lambda object_name, data, content_type, length: uploaded.append(object_name))
    monkeypatch.setattr("routers.media.get_presigned_url", lambda object_name, expires=3600: f"http://storage/{object_name}")

    response = client.post(
        "/api/media/upload",
        files={"file": ("clip.mp4", b"video", "video/mp4")},
    )

    assert response.status_code == 201
    data = response.json()
    assert data["project_id"] is not None
    assert data["object_name"] == f"uploads/{data['project_id']}/clip.mp4"
    assert uploaded == ["uploads/general/clip.mp4", f"uploads/{data['project_id']}/clip.mp4"]


def test_upload_links_existing_project(client, monkeypatch):
    project = client.post("/api/projects", json={"name": "Existing"}).json()
    monkeypatch.setattr("routers.media.upload_file", lambda *args, **kwargs: None)
    monkeypatch.setattr("routers.media.get_presigned_url", lambda object_name, expires=3600: f"http://storage/{object_name}")

    response = client.post(
        "/api/media/upload",
        data={"project_id": str(project["id"])},
        files={"file": ("clip.mp4", b"video", "video/mp4")},
    )

    assert response.status_code == 201
    detail = client.get(f"/api/projects/{project['id']}").json()
    assert detail["video_path"] == f"uploads/{project['id']}/clip.mp4"


def test_upload_dicom_batch_filters_files_and_creates_project(client, monkeypatch):
    uploaded = []
    monkeypatch.setattr("routers.media.upload_file", lambda object_name, data, content_type, length: uploaded.append(object_name))

    response = client.post(
        "/api/media/upload/dicom",
        files=[
            ("files", ("a.dcm", b"dcm", "application/dicom")),
            ("files", ("skip.txt", b"text", "text/plain")),
        ],
    )

    assert response.status_code == 201
    data = response.json()
    assert data["uploaded_count"] == 1
    assert uploaded == [f"uploads/{data['project_id']}/dicom/a.dcm"]


def test_parse_media_queues_background_task(client, monkeypatch):
    project = client.post("/api/projects", json={
        "name": "Parse Me",
        "video_path": "uploads/1/clip.mp4",
        "source_type": "video",
        "parse_fps": 5,
    }).json()

    class FakeAsyncResult:
        id = "celery-1"

    # Replace the Celery dispatch and the Redis publish with in-memory recorders.
    queued = []
    monkeypatch.setattr("routers.media.parse_project_media.delay", lambda task_id: queued.append(task_id) or FakeAsyncResult())
    published = []
    monkeypatch.setattr("routers.media.publish_task_progress_event", lambda task: published.append(task.id))

    response = client.post(f"/api/media/parse?project_id={project['id']}")

    assert response.status_code == 202
    data = response.json()
    assert data["task_type"] == "parse_video"
    assert data["status"] == "queued"
    assert data["progress"] == 0
    assert data["project_id"] == project["id"]
    assert data["celery_task_id"] == "celery-1"
    assert queued == [data["id"]]
    assert published == [data["id"]]

    detail = client.get(f"/api/tasks/{data['id']}")
    assert detail.status_code == 200
    assert detail.json()["status"] == "queued"
    project_detail = client.get(f"/api/projects/{project['id']}").json()
    assert project_detail["status"] == "parsing"


def test_parse_task_runner_registers_frames(client, db_session, monkeypatch, tmp_path):
    from models import ProcessingTask
    from services.media_task_runner import run_parse_media_task

    project = client.post("/api/projects", json={
        "name": "Parse Me",
        "video_path": "uploads/1/clip.mp4",
        "source_type": "video",
        "parse_fps": 5,
    }).json()
    task = ProcessingTask(
        task_type="parse_video",
        status="queued",
        progress=0,
        project_id=project["id"],
        payload={"source_type": "video"},
    )
    db_session.add(task)
    db_session.commit()
    db_session.refresh(task)

    frame_file = tmp_path / "frame_000001.jpg"
    frame_file.write_bytes(b"fake image")

    # Stub out storage and frame parsing so the runner touches no external service.
    monkeypatch.setattr("services.media_task_runner.download_file", lambda object_name: b"video")
    monkeypatch.setattr("services.media_task_runner.parse_video", lambda local_path, output_dir, fps: ([str(frame_file)], 25.0))
    monkeypatch.setattr("services.media_task_runner.extract_thumbnail", lambda local_path, thumbnail_path: open(thumbnail_path, "wb").write(b"thumb"))
    monkeypatch.setattr("services.media_task_runner.upload_file", lambda *args, **kwargs: None)
    monkeypatch.setattr("services.media_task_runner.upload_frames_to_minio", lambda frame_files, project_id: [f"projects/{project_id}/frames/frame_000001.jpg"])
    published = []
    monkeypatch.setattr(
        "services.media_task_runner.publish_task_progress_event",
        lambda event_task: published.append((event_task.status, event_task.progress, event_task.message)),
    )

    result = run_parse_media_task(db_session, task.id)

    assert result["frames_extracted"] == 1
    db_session.refresh(task)
    assert task.status == "success"
    assert task.progress == 100
    assert ("running", 5, "后台解析已启动") in published
    assert ("success", 100, "解析完成") in published
    project_detail = client.get(f"/api/projects/{project['id']}").json()
    assert project_detail["status"] == "ready"
    frames = client.get(f"/api/projects/{project['id']}/frames").json()
    assert "frame_000001.jpg" in frames[0]["image_url"]


@@ -0,0 +1,42 @@
from types import SimpleNamespace

from progress_events import PROGRESS_CHANNEL, publish_progress_event, task_progress_payload


def test_task_progress_payload_uses_dashboard_task_id_and_project_name():
    task = SimpleNamespace(
        id=12,
        project_id=7,
        project=SimpleNamespace(name="demo.mp4"),
        status="success",
        progress=100,
        message="解析完成",
        error=None,
        updated_at=None,
    )

    payload = task_progress_payload(task)

    assert payload["type"] == "complete"
    assert payload["taskId"] == "task-12"
    assert payload["task_id"] == 12
    assert payload["project_id"] == 7
    assert payload["filename"] == "demo.mp4"
    assert payload["projectName"] == "demo.mp4"
    assert payload["status"] == "解析完成"


def test_publish_progress_event_writes_json_to_redis(monkeypatch):
    calls = []

    class FakeRedis:
        def publish(self, channel, payload):
            calls.append((channel, payload))

    monkeypatch.setattr("progress_events.get_redis_client", lambda: FakeRedis())

    publish_progress_event({"type": "progress", "message": "正在下载媒体文件"})

    assert calls
    assert calls[0][0] == PROGRESS_CHANNEL
    assert "正在下载媒体文件" in calls[0][1]


@@ -0,0 +1,56 @@
def test_project_crud_and_frames(client, monkeypatch):
    monkeypatch.setattr("routers.projects.get_presigned_url", lambda key, expires=3600: f"http://storage/{key}")

    created = client.post("/api/projects", json={
        "name": "Demo",
        "description": "desc",
        "thumbnail_url": "thumb.jpg",
        "parse_fps": 12,
    })
    assert created.status_code == 201
    project_id = created.json()["id"]

    frame = client.post(f"/api/projects/{project_id}/frames", json={
        "project_id": project_id,
        "frame_index": 0,
        "image_url": "frames/0.jpg",
        "width": 640,
        "height": 360,
    })
    assert frame.status_code == 201
    frame_id = frame.json()["id"]

    listing = client.get("/api/projects")
    assert listing.status_code == 200
    assert listing.json()[0]["frame_count"] == 1
    assert listing.json()[0]["thumbnail_url"] == "http://storage/thumb.jpg"

    frames = client.get(f"/api/projects/{project_id}/frames")
    assert frames.status_code == 200
    assert frames.json()[0]["image_url"] == "http://storage/frames/0.jpg"

    single_frame = client.get(f"/api/projects/{project_id}/frames/{frame_id}")
    assert single_frame.status_code == 200
    assert single_frame.json()["frame_index"] == 0

    updated = client.patch(f"/api/projects/{project_id}", json={"name": "Renamed", "status": "ready"})
    assert updated.status_code == 200
    assert updated.json()["name"] == "Renamed"
    assert updated.json()["status"] == "ready"

    deleted = client.delete(f"/api/projects/{project_id}")
    assert deleted.status_code == 204
    assert client.get(f"/api/projects/{project_id}").status_code == 404


def test_project_and_frame_404s(client):
    assert client.get("/api/projects/999").status_code == 404
    assert client.patch("/api/projects/999", json={"name": "x"}).status_code == 404
    assert client.delete("/api/projects/999").status_code == 404
    assert client.post("/api/projects/999/frames", json={
        "project_id": 999,
        "frame_index": 0,
        "image_url": "missing.jpg",
    }).status_code == 404
    assert client.get("/api/projects/999/frames").status_code == 404
    assert client.get("/api/projects/999/frames/1").status_code == 404


@@ -0,0 +1,39 @@
def test_template_crud_packs_and_unpacks_mapping_rules(client):
    payload = {
        "name": "Template",
        "color": "#06b6d4",
        "z_index": 0,
        "classes": [{"id": "c1", "name": "胆囊", "color": "#ff0000", "zIndex": 10}],
        "rules": [{"id": "r1", "name": "rule"}],
    }

    created = client.post("/api/templates", json=payload)
    assert created.status_code == 201
    template_id = created.json()["id"]
    assert created.json()["classes"][0]["name"] == "胆囊"
    assert created.json()["rules"][0]["id"] == "r1"

    listing = client.get("/api/templates")
    assert listing.status_code == 200
    assert listing.json()[0]["classes"][0]["name"] == "胆囊"

    detail = client.get(f"/api/templates/{template_id}")
    assert detail.status_code == 200
    assert detail.json()["name"] == "Template"

    updated = client.patch(f"/api/templates/{template_id}", json={
        "classes": [{"id": "c2", "name": "肝脏", "color": "#00ff00", "zIndex": 20}],
        "rules": [],
    })
    assert updated.status_code == 200
    assert updated.json()["classes"][0]["name"] == "肝脏"

    deleted = client.delete(f"/api/templates/{template_id}")
    assert deleted.status_code == 204
    assert client.get(f"/api/templates/{template_id}").status_code == 404


def test_template_404s(client):
    assert client.get("/api/templates/999").status_code == 404
    assert client.patch("/api/templates/999", json={"name": "x"}).status_code == 404
    assert client.delete("/api/templates/999").status_code == 404

22
backend/worker_tasks.py Normal file

@@ -0,0 +1,22 @@
"""Celery task definitions."""
import logging

from celery_app import celery_app
from database import SessionLocal
from services.media_task_runner import run_parse_media_task

logger = logging.getLogger(__name__)


@celery_app.task(name="media.parse_project")
def parse_project_media(task_id: int) -> dict:
    """Run media parsing for one queued task."""
    db = SessionLocal()
    try:
        return run_parse_media_task(db, task_id)
    except Exception as exc:  # noqa: BLE001
        # Log with traceback, then re-raise so Celery marks the task as failed.
        logger.exception("Parse media task failed: task_id=%s", task_id)
        raise exc
    finally:
        db.close()


@@ -0,0 +1,58 @@
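# Purpose and Summary of the Word Proposal

## Why build this system

The core goal of the Word document 《语义分割系统构建方案.docx》 is to build an intelligent semantic segmentation annotation system for video and consecutive frames. It targets the pain points of traditional annotation tools in these scenarios:

- Videos and frame sequences are large, and hand-drawing a mask on every frame is expensive.
- A high-resolution image carries a base image, points, boxes, polygons, and masks at the same time; DOM rendering cannot sustain that level of interaction.
- AI segmentation needs low-latency feedback for point and box prompts; plain REST round trips feel slow under dense interaction.
- Semantic segmentation requires each pixel to belong to exactly one class, so templates, colors, z-index, and class priority are needed to resolve overlapping masks.
- Historical GT masks that are only overlaid as static pixel layers are hard to modify later; the Word proposal wants masks reduced to editable point regions.

The business purpose is therefore not just video playback. It is to chain "video/DICOM ingestion, frame extraction, AI-assisted segmentation, semantic classification, annotation export" into a single workbench.

## Target architecture in the Word proposal

The ideal system described in the Word proposal includes:

- A high-performance Canvas workbench built on React/Vue + Konva.
- A FastAPI backend that uses WebSocket for real-time interaction and task progress.
- Celery + Redis for long-running jobs such as video frame extraction.
- FFmpeg/OpenCV for video parsing, pydicom for medical images.
- SAM 3 inference on local CUDA.
- GT mask import followed by reduction to point regions through distance transform, skeleton extraction, clustering, and similar algorithms.
- A template library that manages classes, colors, and z-index, used to resolve overlapping semantic masks.
- PostgreSQL storage for projects, frames, templates, and point-region data.

## What the current code already delivers

| Goal | Current status | Evidence |
|------|----------------|----------|
| React frontend workbench | Done | `src/App.tsx`, `src/components/*.tsx` |
| Konva Canvas | Done | `CanvasArea.tsx` and `AISegmentation.tsx` use `react-konva` |
| FastAPI backend | Done | `backend/main.py` |
| PostgreSQL ORM | Done | `backend/database.py`, `backend/models.py` |
| MinIO object storage | Done | `backend/minio_client.py` |
| Redis connection | Done | Used as the Celery broker/result backend, and forwards task progress over `seg:progress` pub/sub |
| Video frame extraction | Done | `backend/services/frame_parser.py`, `backend/routers/media.py` |
| DICOM batch import | Partial | Upload and parsing exist; the project-level experience still needs work |
| WebSocket progress | Done | Frame-extraction progress is written to the task table, published to Redis `seg:progress`, and broadcast by FastAPI to `/ws/progress` |
| SAM inference | Partial | The backend has SAM 2 / SAM 3 selection and a real model status endpoint; SAM 3 depends on the official runtime and is marked unavailable when the current environment does not satisfy it |
| Template library | Partial | Classes, colors, and z-index can be stored and edited; the overlap-resolution algorithm is not implemented |
| Annotation persistence | Partial | The backend has an `Annotation` table, and the frontend supports create, reload, class update, and delete for the current frame; per-vertex geometry editing is not implemented |
| COCO / mask export | Partial | `backend/routers/export.py`; the COCO JSON button is wired in the frontend, the PNG mask ZIP has no frontend button yet |

## Goals not yet fully delivered

- SAM 3: `sam3_engine.py` provides an adapter entry point and status detection; actually running it still requires the official `facebookresearch/sam3` dependencies plus Python 3.12+, PyTorch 2.7+, and CUDA 12.6+.
- Celery async task queue: the Celery app and the frame-extraction worker task are registered, and `/api/media/parse` creates a task-table record and enqueues it.
- GT mask import: the frontend has no GT label import entry, and the backend has no matching route.
- Topological reduction of masks to point regions: there is no distance transform, skeleton extraction, HDBSCAN, or similar implementation.
- Class-priority fusion: templates carry a z-index, but the backend has no fusion algorithm.
- Undo/redo: the toolbar has buttons but no history stack.
- Structured archive save: the workspace button calls `POST /api/ai/annotate` to save unsaved masks and `PATCH /api/ai/annotations/{id}` to update dirty masks.

## Conclusion

The project has moved from a UI prototype to a full-stack skeleton that can upload media, extract frames asynchronously, show task progress in real time, browse project frames, maintain templates, run point/box AI inference, save annotations, export COCO, and show a backend-driven Dashboard overview. It is still clearly short of the complete intelligent annotation system described in the Word proposal. The most important next steps are manual drawing, undo/redo, and real semantic text segmentation.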


@@ -0,0 +1,104 @@
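# Current Implementation Map

## Entry points

### Frontend entry

- React mount: `src/main.tsx`
- Root component: `src/App.tsx`
- Frontend server: `server.ts`
- Default address: `http://localhost:3000`

`server.ts` plays an unusual role: it creates the Vite middleware in development mode and serves `dist/` in production mode. It also still carries the legacy mock APIs `/api/login`, `/api/projects`, and `/api/templates`. Current frontend business calls mostly bypass these mocks and go to the FastAPI backend that `src/lib/api.ts` points at.

### Backend entry

- FastAPI app: `backend/main.py`
- Default address: `http://localhost:8000`
- API docs: `http://localhost:8000/docs`
- Health check: `GET /health`

On startup the backend runs the following through its lifespan handler:

- Create database tables.
- Check the MinIO bucket.
- Test Redis.
- Seed the default templates.
- If `Data_MyVideo_1.mp4` exists, create a default project and extract its first 100 frames.

## Frontend module switching

`App.tsx` switches modules with `activeModule` from Zustand; no router library is used.

| activeModule | Component | Page |
|--------------|-----------|------|
| `dashboard` | `Dashboard` | System overview |
| `projects` | `ProjectLibrary` | Project library |
| `workspace` | `VideoWorkspace` | Segmentation workspace |
| `ai` | `AISegmentation` | AI segmentation page |
| `templates` | `TemplateRegistry` | Template library |

When the user is not logged in, `App.tsx` renders `Login` directly.

## Global state

Global state lives in `src/store/useStore.ts` and mainly covers:

- Login: `isAuthenticated`, `token`
- Projects: `projects`, `currentProject`
- Workspace: `activeModule`, `activeTool`, `frames`, `currentFrameIndex`
- Annotations and masks: `annotations`, `masks`
- Templates: `templates`, `activeTemplateId`
- UI: `isLoading`, `error`

State is held in frontend memory and is not persisted to localStorage, except for the login token.

## Data flow

### Login

1. `Login.tsx` calls `login()`.
2. `src/lib/api.ts` sends `POST /api/auth/login`.
3. FastAPI `backend/routers/auth.py` checks `admin / 123456`.
4. The frontend writes the returned token to localStorage.

### Projects and frame extraction

1. `ProjectLibrary.tsx` calls `getProjects()` to fetch projects.
2. A video upload runs `createProject()`, then `uploadMedia()`, then `parseMedia()`.
3. The backend `media.py` uploads the original file to MinIO.
4. `parseMedia()` creates a `processing_tasks` record and dispatches it to a Celery worker.
5. The Celery worker downloads the file from MinIO and calls `frame_parser.py` to extract frames.
6. The worker uploads the extracted frames back to MinIO, writes the `frames` table, and updates the task status.
7. The workspace waits for the task through `GET /api/tasks/{id}`, then fetches presigned image URLs through `GET /api/projects/{id}/frames`.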
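
Step 7 above can be sketched as a small polling loop. This is an illustration only, not the frontend's actual code (which is TypeScript): `wait_for_task`, `fetch_task`, and the polling limits are hypothetical names. The only facts taken from this commit are the `queued`, `running`, and `success` status values seen in the tests; any other status is treated as terminal here because the failure status name is not shown.

```python
import time
from typing import Callable

# Statuses that mean "keep waiting"; both appear in the backend tests.
ACTIVE_STATUSES = {"queued", "running"}


def wait_for_task(
    fetch_task: Callable[[int], dict],
    task_id: int,
    interval: float = 1.0,
    max_polls: int = 600,
    sleep: Callable[[float], None] = time.sleep,
) -> dict:
    """Poll a task until it leaves the active statuses, then return it.

    fetch_task is a hypothetical callable standing in for
    GET /api/tasks/{task_id}; it must return a dict with a "status" key.
    """
    for _ in range(max_polls):
        task = fetch_task(task_id)
        if task["status"] not in ACTIVE_STATUSES:
            return task  # success or any other terminal state
        sleep(interval)
    raise TimeoutError(f"task {task_id} still active after {max_polls} polls")
```

Injecting `fetch_task` and `sleep` keeps the loop testable without a running backend; the caller still has to inspect the returned status before loading frames.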
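### Workspace browsing

1. `VideoWorkspace.tsx` loads frames based on `currentProject.id`.
2. `CanvasArea.tsx` loads the base image from the current frame URL.
3. `FrameTimeline.tsx` shows thumbnails and the current frame index.
4. The play button advances `currentFrameIndex`, which updates the canvas base image.

### Template management

1. `TemplateRegistry.tsx` calls the template API.
2. The backend `templates.py` packs `classes` and `rules` into the `mapping_rules` JSON field.
3. `OntologyInspector.tsx` reads the global `templates` and `activeTemplateId` to render the class tree.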
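
The packing in step 2 can be pictured with the sketch below. It is not the code in `templates.py`; `pack_template` and `unpack_template` are hypothetical helpers that only mirror the behavior the template tests check: `classes` and `rules` go into `mapping_rules` on write and come back as top-level fields on read.

```python
def pack_template(payload: dict) -> dict:
    """Move classes/rules from an API payload into a mapping_rules JSON field."""
    record = {k: v for k, v in payload.items() if k not in ("classes", "rules")}
    record["mapping_rules"] = {
        "classes": payload.get("classes", []),
        "rules": payload.get("rules", []),
    }
    return record


def unpack_template(record: dict) -> dict:
    """Expose mapping_rules.classes/rules as top-level response fields."""
    mapping = record.get("mapping_rules") or {}
    response = {k: v for k, v in record.items() if k != "mapping_rules"}
    response["classes"] = mapping.get("classes", [])
    response["rules"] = mapping.get("rules", [])
    return response
```

Keeping both lists in a single JSON column avoids extra tables for classes and rules, at the cost of not being able to query individual classes in SQL.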
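## Backend data models

| Model | Table | Purpose |
|-------|-------|---------|
| `Project` | `projects` | Project metadata, including video path, thumbnail, status, fps |
| `Frame` | `frames` | Image records produced by frame extraction |
| `Template` | `templates` | Templates, ontology classes, colors, z-index, mapping_rules |
| `Annotation` | `annotations` | Annotation data, points, bbox, mask_data |
| `Mask` | `masks` | Mask file metadata |

## Current main risks

- The frontend API/WS addresses support environment variables and hostname derivation, but each deployment still has to confirm that the browser can reach the backend on `:8000`.
- AI semantic text prompts use SAM 3 when SAM 3 is selected and the runtime satisfies the official dependencies; otherwise the model status reports it as unavailable.
- The workspace buttons “导出 JSON 标注集” (export JSON annotation set) and “结构化归档保存” (structured archive save) are wired to export, annotation creation, and dirty-annotation update; clearing the current frame's masks deletes the matching backend annotations. Undo/redo and manual drawing are still not persisted.
- The Dashboard's initial statistics, queue, and activity log come from the backend aggregation endpoint; the parsing queue comes from `processing_tasks`, and worker progress is forwarded to WebSocket through Redis `seg:progress`.
- Most backend routes have no real authentication.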


@@ -0,0 +1,146 @@
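# Frontend Element-by-Element Audit

Status legend:

- Working: backed by real state or backend endpoints; the main action can be completed.
- Partial: displays or completes part of the action, but has a key gap.
- Mock / UI-only: display or local state change only, with no real business effect.
- Broken contract: the frontend call does not match the backend endpoint and will most likely fail with the current code.

## App and navigation

| Element | Location | Status | Notes |
|---------|----------|--------|-------|
| Login gate | `App.tsx` | Working | Shows `Login` when logged out, the main UI after login |
| Module switching | `Sidebar.tsx` + `App.tsx` | Working | Switches between `dashboard/projects/workspace/ai/templates` |
| Logo | `Sidebar.tsx` | Working | Uses `/logo.png`; the file exists at `public/logo.png` |
| GPU status badge | `Sidebar.tsx` | Working | Shows GPU/CPU and current model availability via `GET /api/ai/models/status` |

## Login page

| Element | Status | Notes |
|---------|--------|-------|
| Username/password inputs | Working | Pre-filled with `admin / 123456` |
| Secure login button | Working | Calls `POST /api/auth/login` |
| Error message | Working | Catches backend errors and displays them |
| Security audit notice text | Mock / UI-only | UI copy only; there is no real audit feature |

## Dashboard system overview

| Element | Status | Notes |
|---------|--------|-------|
| WebSocket connection status | Working | The frontend derives the address in `src/lib/config.ts` or reads `VITE_WS_PROGRESS_URL`; the backend exposes `/ws/progress` |
| Parsing queue tasks | Working | Initial data comes from `GET /api/dashboard/overview`, built from queued/running `processing_tasks` |
| WebSocket task updates | Working | The Celery worker updates `processing_tasks` and publishes to Redis `seg:progress`; FastAPI broadcasts progress/complete/error |
| Project, task, annotation, and system load statistics | Working | Initial data comes from `GET /api/dashboard/overview`; system load is estimated from the host load average |
| Recent live activity log | Working | Initial data comes from task, project, annotation, and template records; WebSocket status/complete/error events are appended |

## Project library (ProjectLibrary)

| Element | Status | Notes |
|---------|--------|-------|
| Project list | Working | Calls `GET /api/projects` |
| Project card thumbnail | Working | Shown when the backend returns a MinIO presigned `thumbnail_url` |
| Click a project to enter the workspace | Working | Sets `currentProject`, then switches to `workspace` |
| New project | Working | Calls `POST /api/projects` |
| Import video file | Working | Creates the project, uploads the file, triggers frame extraction, refreshes the project list |
| Parse FPS slider | Working | The value is passed to `createProject({ parse_fps })` |
| Import DICOM series | Partial | Can upload `.dcm` files and trigger parsing; the experience and error feedback are rough |
| Project status badge | Working | Project status is unified as `pending/parsing/ready/error`; the frontend normalizes legacy status values |
| More button | Mock / UI-only | Has an icon but no menu or event |
| Success/failure alerts | Working but rough | Uses the browser `alert` |

## Workspace (VideoWorkspace)

| Element | Status | Notes |
|---------|--------|-------|
| Current project name | Working | Reads `currentProject.name` |
| Auto-load project frames | Working | Calls `GET /api/projects/{id}/frames` |
| Trigger parsing when there are no frames | Working | If `video_path` exists, calls `parseMedia()` to create an async task and polls `GET /api/tasks/{id}` until it finishes |
| SAM model status badge | Working | Calls `GET /api/ai/models/status` and shows whether the selected SAM 2/SAM 3 is available |
| Reload saved annotations | Working | After loading workspace frames, calls `GET /api/ai/annotations` and renders saved masks |
| “导出 JSON 标注集” (export JSON annotation set) button | Working | Saves unsaved masks first, then calls `exportCoco()` to download the JSON |
| “结构化归档保存” (structured archive save) button | Working | Unsaved masks go to `POST /api/ai/annotate`; dirty masks go to `PATCH /api/ai/annotations/{id}` |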
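## CanvasArea

| Element | Status | Notes |
|---------|--------|-------|
| Current frame base image | Working | `useImage(frameUrl)` loads the current frame URL |
| Wheel zoom | Working | Changes the Konva Stage scale |
| Drag to pan | Working | The Stage is draggable when activeTool is `move` |
| Cursor coordinate readout | Working | Computed from the pointer position |
| Positive/negative point prompts | Partial | The UI adds points and calls `/api/ai/predict` with the current `frame.id`; results persist only after an archive save |
| Box prompt | Partial | The UI draws a box, normalizes its coordinates, and calls backend inference; results persist only after an archive save |
| "AI inference running" hint | Working | Shown while the request is in flight |
| Mask rendering | Partial | The frontend converts inferred and saved annotations to Konva `pathData` for rendering |
| Apply class | Working | Applies the selected template class to masks on this frame; saved masks are marked dirty and updated on the backend at archive save |
| Clear masks | Working | In the workspace, deletes saved annotations for the current frame and clears its local masks |
| Save-state counters | Working | The footer shows saved, unsaved, and pending-update counts |
| Current layer tree text | Mock / UI-only | Hard-coded to `OBJECT_VEHICLE_01` |

## ToolsPalette

| Element | Status | Notes |
|---------|--------|-------|
| Drag/select | Working | Controls whether the Canvas is draggable |
| Polygon/rectangle/circle/point/line | Mock / UI-only | Only switches activeTool; there is no drawing logic |
| Region merge/subtract | Mock / UI-only | Only switches activeTool; there is no backend or frontend algorithm |
| Positive point/negative point/box | Partial | Affects Canvas interaction and can trigger the aligned AI inference endpoint |
| Magic wand SAM trigger | Partial | Switches to the AI page; does not run inference directly |
| Undo/redo | Mock / UI-only | Buttons have no event handlers |

## FrameTimeline

| Element | Status | Notes |
|---------|--------|-------|
| Frame thumbnails | Working | Uses `frames[].url` |
| Click thumbnail to jump | Working | Calls `setCurrentFrame(idx)` |
| Top range slider | Working | Changes the current frame |
| Play/pause | Working | Advances frames at `parse_fps/original_fps`, capped at 30fps |
| Arrow-key frame stepping | Mock / UI-only | Mentioned in the Word proposal, but there is no keyboard listener |

## OntologyInspector

| Element | Status | Notes |
|---------|--------|-------|
| Template selector | Partial | Reads the global templates; can switch activeTemplateId |
| Class tree | Working | Shows template classes and local customClasses |
| Add custom class | Partial | Lives only in component-local state; not saved to the backend |
| Confidence bar | Mock / UI-only | Hard-coded to `0.9412` |
| Topology anchor count | Mock / UI-only | Hard-coded to `12 节点` (12 nodes) |
| Re-extract skeleton button | Mock / UI-only | No event handler |

## AISegmentation standalone AI page

| Element | Status | Notes |
|---------|--------|-------|
| Model selector SAM2/SAM3 | Working | The selection is written to Zustand; `predictMask()` passes `model` to the backend SAM registry |
| Positive/negative points | Partial | Points can be added on the current project frame and sent to the AI inference endpoint |
| Semantic text input | Partial | Plain text is sent as a `semantic` prompt; with SAM 3 selected and the official dependencies satisfied it runs SAM 3 text inference, otherwise the status endpoint reports it as unavailable |
| Parameter toggles | Mock / UI-only | `cropMode` and `autoDeleteBg` only change local state |
| Run high-precision semantic segmentation | Partial | Calls `/api/ai/predict` with the current project frame; the button is disabled when there is no current frame |
| Upload replacement base image | Mock / UI-only | Button has no event handler |
| Clear all anchors | Partial | Clears frontend points and masks |
| Push back to workspace | Partial | Only switches back to the workspace and shares the masks store; there is no save/confirm flow |
| Background image | Partial | Prefers the current project frame; falls back to an Unsplash demo image when there is none |

## TemplateRegistry

| Element | Status | Notes |
|---------|--------|-------|
| Template list | Working | Calls `GET /api/templates` |
| New template | Working | Calls `POST /api/templates` |
| Edit template | Working | Calls `PATCH /api/templates/{id}` |
| Delete template | Working | Calls `DELETE /api/templates/{id}` |
| Add/remove class | Working | Stored in the template's `mapping_rules.classes` |
| Drag-to-reorder | Working | Recomputes zIndex; written to the backend on save |
| JSON bulk import | Partial | The frontend parses the JSON into the edit state; it is stored only after saving |
| Load the 35 laparoscopy classes | Working | Built-in frontend data; the backend also seeds a default template |
| mapping rules | Partial | `rules` can be stored, but no engine executes the mappings |

## Overall conclusion

The main chains that really work in the frontend today are login, the backend-driven Dashboard overview, project listing, project creation, video/DICOM upload, frame extraction, frame browsing, frame playback, point/box AI inference in the workspace, annotation save and reload, COCO export, and template CRUD.

The main mocked or unconnected chains are undo/redo, manual geometry drawing, GT import, mask-to-point-region reduction, true text-driven semantic segmentation, and semantic priority fusion.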

193
doc/04-api-contracts.md Normal file

@@ -0,0 +1,193 @@
# API Contract Checklist

## Frontend API base configuration

Location: `src/lib/config.ts`, `src/lib/api.ts`
```ts
API_BASE_URL = import.meta.env.VITE_API_BASE_URL || 'http://<current-browser-host>:8000'
timeout: 30000
```
The frontend request interceptor reads `token` from localStorage and attaches:
```http
Authorization: Bearer <token>
```
Most backend endpoints currently have no authentication dependency, so this header is mainly a frontend-side behavior.

## APIs wrapped by the frontend

| Function | Method and path | Status | Notes |
|----------|-----------------|--------|-------|
| `login(username, password)` | `POST /api/auth/login` | Aligned | The backend returns `{ token, username }`; the frontend only uses the token |
| `getProjects()` | `GET /api/projects` | Aligned | The frontend maps fields such as `frame_count` and `thumbnail_url` |
| `createProject(payload)` | `POST /api/projects` | Aligned | Supports `name`, `description`, `parse_fps` |
| `updateProject(id, payload)` | `PATCH /api/projects/{id}` | Aligned | The backend route is `PATCH /api/projects/{id}` |
| `deleteProject(id)` | `DELETE /api/projects/{id}` | Aligned | Not visibly wired into the current UI |
| `getTemplates()` | `GET /api/templates` | Aligned | The frontend reads classes/rules from `mapping_rules` |
| `createTemplate(payload)` | `POST /api/templates` | Aligned | The backend packs classes/rules into mapping_rules |
| `updateTemplate(id, payload)` | `PATCH /api/templates/{id}` | Aligned | Used by the template editor |
| `deleteTemplate(id)` | `DELETE /api/templates/{id}` | Aligned | Used by the template editor |
| `uploadMedia(file, projectId)` | `POST /api/media/upload` | Aligned | multipart form-data |
| `uploadDicomBatch(files, projectId)` | `POST /api/media/upload/dicom` | Aligned | multipart form-data |
| `parseMedia(projectId)` | `POST /api/media/parse?project_id=...` | Aligned | Creates an async frame-extraction task and returns the task |
| `getTask(taskId)` | `GET /api/tasks/{task_id}` | Aligned | Queries async task status |
| `getProjectFrames(projectId)` | `GET /api/projects/{id}/frames` | Aligned | The backend returns presigned image_url values |
| `predictMask(payload)` | `POST /api/ai/predict` | Aligned | The frontend sends `image_id/prompt_type/prompt_data/model` and converts the backend `polygons` to `masks[].pathData` |
| `getAiModelStatus(selectedModel?)` | `GET /api/ai/models/status` | Aligned | Returns the real runtime status of the GPU, SAM 2, and SAM 3 |
| `getProjectAnnotations(projectId, frameId?)` | `GET /api/ai/annotations` | Aligned | Used to reload saved annotations when the workspace loads |
| `saveAnnotation(payload)` | `POST /api/ai/annotate` | Aligned | Workspace archive save for unsaved masks in the current project |
| `updateAnnotation(annotationId, payload)` | `PATCH /api/ai/annotations/{annotation_id}` | Aligned | Workspace archive save for dirty masks |
| `deleteAnnotation(annotationId)` | `DELETE /api/ai/annotations/{annotation_id}` | Aligned | Workspace clears saved annotations on the current frame |
| `getDashboardOverview()` | `GET /api/dashboard/overview` | Aligned | Dashboard initial statistics, queue, and activity log |
| `exportCoco(projectId)` | `GET /api/export/{projectId}/coco` | Aligned | The backend route is `GET /api/export/{project_id}/coco` |

## Backend FastAPI endpoints

The list below comes from the OpenAPI document of the running backend:

| Method | Path | Purpose |
|--------|------|---------|
| POST | `/api/auth/login` | Login |
| POST | `/api/projects` | Create project |
| GET | `/api/projects` | List projects |
| GET | `/api/projects/{project_id}` | Project detail |
| PATCH | `/api/projects/{project_id}` | Update project |
| DELETE | `/api/projects/{project_id}` | Delete project |
| POST | `/api/projects/{project_id}/frames` | Add a frame record |
| GET | `/api/projects/{project_id}/frames` | List project frames |
| GET | `/api/projects/{project_id}/frames/{frame_id}` | Single frame detail |
| POST | `/api/templates` | Create template |
| GET | `/api/templates` | List templates |
| GET | `/api/templates/{template_id}` | Template detail |
| PATCH | `/api/templates/{template_id}` | Update template |
| DELETE | `/api/templates/{template_id}` | Delete template |
| POST | `/api/media/upload` | Upload a single video/image/DICOM file |
| POST | `/api/media/upload/dicom` | Batch-upload DICOM |
| POST | `/api/media/parse` | Create a Celery frame-extraction task |
| GET | `/api/tasks` | List background tasks |
| GET | `/api/tasks/{task_id}` | Query one background task |
| POST | `/api/ai/predict` | Inference with selectable SAM 2 / SAM 3 |
| GET | `/api/ai/models/status` | GPU and SAM model status |
| POST | `/api/ai/auto` | Automatic segmentation |
| POST | `/api/ai/annotate` | Save an AI annotation |
| GET | `/api/ai/annotations` | Query project annotations, optionally filtered by frame |
| PATCH | `/api/ai/annotations/{annotation_id}` | Update a saved annotation |
| DELETE | `/api/ai/annotations/{annotation_id}` | Delete a saved annotation |
| GET | `/api/dashboard/overview` | Dashboard aggregate snapshot |
| GET | `/api/export/{project_id}/coco` | Export COCO JSON |
| GET | `/api/export/{project_id}/masks` | Export PNG mask ZIP |
| GET | `/health` | Health check |
| WS | `/ws/progress` | WebSocket progress channel; not listed in the OpenAPI paths |

## Key request bodies

### Login
```json
{
"username": "admin",
"password": "123456"
}
```
### Create project
```json
{
"name": "example.mp4",
"description": "导入说明",
"parse_fps": 30
}
```
### Create/update template
```json
{
"name": "腹腔镜胆囊切除术",
"color": "#06b6d4",
"z_index": 0,
"classes": [
{
"id": "cls-1",
"name": "胆囊",
"color": "#ffae00",
"zIndex": 280,
"category": "腹腔镜胆囊切除术"
}
],
"rules": []
}
```
### AI inference request body

The frontend `predictMask()` is now adapted to the backend `PredictRequest`:
```json
{
"image_id": 123,
"model": "sam2",
"prompt_type": "point",
"prompt_data": {
"points": [[0.5, 0.5]],
"labels": [1]
}
}
```
`prompt_type` supports:

- `point`
- `box`
- `semantic`: with `sam3` selected it runs SAM 3 text-driven semantic inference; with `sam2` selected it still falls back to auto segmentation

Backend response:
```json
{
"polygons": [
[[0.25, 0.25], [0.75, 0.25], [0.75, 0.75], [0.25, 0.75]]
],
"scores": [0.5]
}
```
The frontend converts the `polygons` above into:
```json
{
"masks": [
{
"pathData": "M 160 90 L 480 90 L 480 270 L 160 270 Z",
"segmentation": [[160, 90, 480, 90, 480, 270, 160, 270]],
"bbox": [160, 90, 320, 180]
}
]
}
```
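
The conversion between the two JSON samples above can be sketched as follows. This is an illustration in Python, not the frontend's TypeScript: `polygon_to_mask` and `_fmt` are hypothetical names, and the 640x360 frame size is inferred from the sample numbers (0.25 maps to 160 and 90), not stated in the source.

```python
from typing import Sequence


def _fmt(value: float) -> str:
    """Format a pixel coordinate without trailing zeros (160.0 -> '160')."""
    return f"{value:.2f}".rstrip("0").rstrip(".")


def polygon_to_mask(polygon: Sequence[Sequence[float]], width: int, height: int) -> dict:
    """Scale one normalized polygon to pixels and derive path, segmentation, bbox."""
    # Normalized [x, y] pairs in 0..1 are scaled by the frame size.
    pixels = [(x * width, y * height) for x, y in polygon]
    xs = [x for x, _ in pixels]
    ys = [y for _, y in pixels]

    # SVG-style path: move to the first vertex, line to the rest, then close.
    commands = [
        f"{'M' if index == 0 else 'L'} {_fmt(x)} {_fmt(y)}"
        for index, (x, y) in enumerate(pixels)
    ]
    return {
        "pathData": " ".join(commands + ["Z"]),
        # COCO-style flat list: [x1, y1, x2, y2, ...]
        "segmentation": [[coord for point in pixels for coord in point]],
        # COCO-style bbox: [x, y, width, height]
        "bbox": [min(xs), min(ys), max(xs) - min(xs), max(ys) - min(ys)],
    }
```

Called with the sample polygon and a 640x360 frame, the sketch yields the same `pathData`, `segmentation`, and `bbox` as the JSON above.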
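## Contract alignment already completed

- `updateProject()` changed from `PUT` to `PATCH`.
- `exportCoco()` changed from `/api/export/coco/{projectId}` to `/api/export/{projectId}/coco`.
- The Canvas uses the real `frame.id` as `image_id`.
- Point and box coordinates are converted to the normalized coordinates the backend expects.
- Backend `polygons` are converted on the frontend into paths that Konva can render.
- `saveAnnotation()` is wired to `POST /api/ai/annotate`.
- `getProjectAnnotations()` is wired to `GET /api/ai/annotations`.
- `updateAnnotation()` is wired to `PATCH /api/ai/annotations/{annotationId}`.
- `deleteAnnotation()` is wired to `DELETE /api/ai/annotations/{annotationId}`.
- `parseMedia()` now creates a Celery background task and returns a `ProcessingTask`.
- `getTask()` is wired to `GET /api/tasks/{taskId}`.
- `getDashboardOverview()` aggregates the parsing queue from `processing_tasks`.
- The workspace export button calls `exportCoco()` and saves unsaved masks first.

## Contract issues still open

- The WebSocket address is read from `VITE_WS_PROGRESS_URL` and derived from `API_BASE_URL` when unset; each deployment still has to confirm that the browser can reach it.
- Celery worker progress is written to the PostgreSQL task table and published to Redis `seg:progress`; FastAPI subscribes and broadcasts it to `/ws/progress`.
- Saved annotations currently support class-level updates and whole-frame clearing; a per-vertex geometry editor is not implemented.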


@@ -0,0 +1,115 @@
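# Recommendations for Next Steps

The goal is to move the system from "can view, upload, and extract frames" to "can really complete the annotation loop".

## Stage 1: Fix the API contracts first (basic alignment done)

Highest priority. The basic contracts for AI point/box inference and COCO export are now aligned with the current code.

Done:

1. `updateProject()` in `src/lib/api.ts` changed to `PATCH`.
2. The `exportCoco()` path changed to `/api/export/{projectId}/coco`.
3. The Canvas uses the current frame's real `frame.id` as `image_id` when calling AI.
4. Canvas point/box coordinates are converted to the normalized coordinates the backend expects.
5. Backend `polygons` are converted to Konva paths the frontend can render.

Remaining limits:

1. Real SAM 3 inference needs a separate environment that satisfies the official Python 3.12+, PyTorch 2.7+, and CUDA 12.6+ requirements.
2. The annotation delete/update endpoints cover the basics; a per-vertex geometry editor is not implemented.

## Stage 2: Connect annotation saving (basic loop done)

The workspace can write unsaved masks to the backend annotation table and reload them after loading project frames.

Done:

1. The frontend builds the normalized `mask_data.polygons` the backend expects from `Mask.segmentation`.
2. When the user clicks “结构化归档保存” (structured archive save), unsaved masks call `POST /api/ai/annotate` and dirty masks call `PATCH /api/ai/annotations/{annotation_id}`.
3. The backend saves or updates `project_id`, `frame_id`, `template_id`, `mask_data`, and `bbox`; the concrete class is written to `mask_data.class`.
4. After loading frames, the workspace calls `GET /api/ai/annotations` to reload saved annotations.
5. The workspace “清空遮罩” (clear masks) action calls `DELETE /api/ai/annotations/{annotation_id}` to delete saved annotations on the current frame.

Remaining suggestions:

1. Add save-conflict handling and error messages for batch saves.
2. Add a per-vertex geometry editor so the polygon of a saved mask can be modified and then PATCHed.

## Stage 3: Wire the export buttons (COCO JSON done)

The workspace “导出 JSON 标注集” (export JSON annotation set) button saves unsaved masks first, then calls the COCO export endpoint.

Suggestions:

1. Add an "Export PNG Mask ZIP" button that calls `/api/export/{projectId}/masks`.
2. Show a clearer empty-export message when there are no annotations.

## Stage 4: Replace the Dashboard mock

The Dashboard now reads a backend aggregate snapshot through `GET /api/dashboard/overview` and no longer uses hard-coded initial statistics, queues, or activity logs.

Done:

- Aggregate project, frame, annotation, and template counts plus the host load average.
- Build the parsing queue from queued/running `processing_tasks`.
- Build the activity stream from the latest task, project, annotation, and template records.

Remaining suggestions:

1. Add cancel, retry, and failure-detail UI for tasks.
2. Add task-history filtering and a failure-detail entry to the Dashboard.

## Stage 5: Async frame extraction and progress

The Word proposal calls for Celery + Redis. A Celery app, a worker task, and the `processing_tasks` table now exist.

Done:

1. Create the Celery app.
2. `POST /api/media/parse` only creates a task and returns the task id immediately.
3. The worker runs FFmpeg/OpenCV/pydicom.
4. The worker writes task progress to PostgreSQL.
5. The worker publishes to Redis `seg:progress`, and FastAPI broadcasts to `/ws/progress`.

Remaining suggestions:

1. Add cancel, retry, and failure-detail endpoints for tasks.
2. Keep polling as a fallback in the frontend Dashboard, and add failure-detail UI.

The Dashboard parsing queue has moved from being derived from project status to being driven by the task table, and real-time push works through Redis/WebSocket; the remaining focus is task control.

## Stage 6: GT import and point regions

This is the most complex part of the Word proposal and is completely unimplemented.

Suggested small steps:

1. Start by supporting upload of binary/multi-class masks.
2. Extract connected components per class on the backend.
3. Use the OpenCV distance transform to find positive points.
4. Skip skeleton/HDBSCAN for now and generate a minimal usable point set first.
5. Show the points as draggable in the frontend and save them.
6. Add skeleton and clustering enhancements later.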
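
To make step 3 concrete, the sketch below picks the foreground pixel farthest from the background, which is the kind of point a distance transform yields. It is a dependency-free stand-in, not the proposed implementation: the plan above names OpenCV's distance transform, while this version uses a plain breadth-first search with city-block distance, and `deepest_interior_point` is a hypothetical name.

```python
from collections import deque
from typing import List, Optional, Tuple


def deepest_interior_point(mask: List[List[int]]) -> Optional[Tuple[int, int]]:
    """Return (x, y) of the foreground pixel farthest from any background pixel.

    mask is a row-major grid of 0/1 values; pixels outside the grid count as
    background. Returns None when the mask has no foreground.
    """
    height = len(mask)
    width = len(mask[0]) if height else 0
    steps = ((1, 0), (-1, 0), (0, 1), (0, -1))
    dist = [[0] * width for _ in range(height)]  # 0 means "not reached yet"
    queue = deque()

    def is_background(x: int, y: int) -> bool:
        return not (0 <= x < width and 0 <= y < height) or not mask[y][x]

    # Seed with foreground pixels that touch background or the image border.
    for y in range(height):
        for x in range(width):
            if mask[y][x] and any(is_background(x + dx, y + dy) for dx, dy in steps):
                dist[y][x] = 1
                queue.append((x, y))

    # Breadth-first search inward: each layer is one step farther from the edge.
    while queue:
        x, y = queue.popleft()
        for dx, dy in steps:
            nx, ny = x + dx, y + dy
            if not is_background(nx, ny) and dist[ny][nx] == 0:
                dist[ny][nx] = dist[y][x] + 1
                queue.append((nx, ny))

    best = None
    for y in range(height):
        for x in range(width):
            if dist[y][x] and (best is None or dist[y][x] > dist[best[1]][best[0]]):
                best = (x, y)
    return best
```

Running this once per connected component would give one positive point per region; ties are resolved by scan order, which a real implementation may want to refine.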
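## Stage 7: Template priority fusion

Templates have a z-index, but it is not yet used to resolve semantic conflicts.

Suggestions:

1. Record the template class id / name / zIndex when saving an annotation.
2. When exporting masks, paint them in ascending zIndex order.
3. Union masks of the same class.
4. For overlaps across classes, the higher zIndex overwrites the lower one.

Only after this step does the system truly meet the "one pixel, one class" goal of semantic segmentation.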
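
The overwrite rule suggested above can be sketched on small boolean grids. This is a minimal illustration under assumptions, not backend code: `fuse_by_z_index` and the `class_name`, `zIndex`, and `pixels` keys are hypothetical, and a real exporter would rasterize polygons first.

```python
from typing import List, Optional


def fuse_by_z_index(masks: List[dict], width: int, height: int) -> List[List[Optional[str]]]:
    """Paint masks from low to high zIndex so every pixel keeps exactly one class.

    Each mask is a dict with hypothetical keys:
      "class_name": str, "zIndex": int, "pixels": row-major grid of booleans.
    """
    label_map: List[List[Optional[str]]] = [[None] * width for _ in range(height)]
    # sorted() is stable, so masks with equal zIndex keep their input order.
    for mask in sorted(masks, key=lambda item: item["zIndex"]):
        for y in range(height):
            for x in range(width):
                if mask["pixels"][y][x]:
                    # Later (higher zIndex) masks overwrite earlier ones.
                    label_map[y][x] = mask["class_name"]
    return label_map
```

Masks of the same class are unioned implicitly because they write the same label, and cross-class overlaps resolve in favor of the higher zIndex.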
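## Stage 8: Clean up UI copy and mocks

Bring these labels in line with what the system can really do:

- SAM/GPU status is now driven by `GET /api/ai/models/status`.
- Connect the undo/redo buttons to a history stack, or hide them.
- Connect the “重新提取内侧中轴树骨架” (re-extract medial axis skeleton) button to a real endpoint, or mark it as not implemented.
- The standalone AI page should not fall back to a fixed Unsplash image; it should be entered from the current project frame or an uploaded file.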


@@ -0,0 +1,103 @@
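# What `/docs` Is

Addresses:

- Local: `http://localhost:8000/docs`
- LAN: `http://192.168.3.11:8000/docs`

This page is neither a file listing nor the project documentation directory. It is the Swagger UI that FastAPI generates automatically to display and debug the backend HTTP API.

## Why it appears automatically

FastAPI builds an OpenAPI description from the routes and Pydantic schemas in the code, then renders it with Swagger UI.

The relevant code is in:

- `backend/main.py`, which creates `FastAPI(...)`
- `backend/routers/*.py`, which define endpoints with `@router.get(...)`, `@router.post(...)`, and so on
- `backend/schemas.py`, which defines request and response bodies

## What GET / POST / PATCH / DELETE mean on the page

These are HTTP methods, not files.

| Method | Meaning | Example |
|--------|---------|---------|
| GET | Read data | `GET /api/projects` fetches the project list |
| POST | Create or trigger an action | `POST /api/media/upload` uploads a file |
| PATCH | Partial update | `PATCH /api/templates/{template_id}` updates a template |
| DELETE | Delete | `DELETE /api/projects/{project_id}` deletes a project |

Each row you see is one endpoint that the backend exposes for the frontend to call.

## What `/docs` can do

It can:

- Show which endpoints the backend currently has.
- Expand an endpoint to show its parameters, request body, and response format.
- Send a test request to the backend through `Try it out`.
- Inspect error responses such as 400, 401, 404, and 500.

It cannot:

- Show frontend page source.
- Prove by itself that a feature is fully usable.
- Show the full WebSocket interaction, because OpenAPI mainly describes HTTP endpoints.

## How it relates to the frontend

The frontend's `src/lib/api.ts` calls these endpoints. For example:

- The login page calls `/api/auth/login`
- The project library calls `/api/projects`
- Video upload calls `/api/media/upload`
- Frame extraction calls `/api/media/parse`
- The template library calls `/api/templates`

So `/docs` is where you check what the backend provides; whether the frontend uses it correctly still has to be checked against `src/lib/api.ts`.

## Endpoints currently visible through `/docs`

The backend endpoints currently include:

- Auth: login
- Projects: project CRUD, project frame CRUD
- Templates: template CRUD
- Media: video/DICOM upload, trigger frame extraction
- AI: inference with selectable SAM 2 / SAM 3, model status, automatic segmentation, annotation saving
- Export: COCO JSON export, PNG mask export
- Health: health check

## Why it looks like a "list of files and requests"

Swagger UI groups endpoints by default and shows each endpoint as one row. What it lists are the backend's callable entry points, not project files.

The real project files are in the local directory, for example:

- Frontend: `src/components/*.tsx`
- Backend routes: `backend/routers/*.py`
- Backend models: `backend/models.py`

## How to verify an endpoint with `/docs`

Using the project list as an example:

1. Open `/docs`.
2. Find `GET /api/projects`.
3. Expand it.
4. Click `Try it out`.
5. Click `Execute`.
6. Check the Response body.

If data comes back here but the frontend project library fails to load, the problem is most likely in the frontend API address, CORS, field mapping, or the browser's network request.

## Another machine-readable entry point

The OpenAPI JSON is at: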
```text
http://localhost:8000/openapi.json
```
This is the API description meant for tools; Swagger UI is rendered from it.


@@ -0,0 +1,120 @@
# 当前需求冻结文档
冻结日期2026-05-01
本文档描述当前仓库已经实现或明确保留为占位的需求。测试用例以本文档为准,不把早期设想或 Word 文档中的远期能力当作当前版本必须实现的功能。
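## R1 Login and session
- The system provides a login page.
- The default development credentials are `admin / 123456`.
- After a successful login the frontend stores the token and enters the main application.
- A failed login shows an error message.
- The current token is a fixed development token; no real JWT validation is performed.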
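## R2 Project management
- The frontend shows the project library and fetches the project list from `GET /api/projects`.
- Users can create a project; the frontend calls `POST /api/projects`.
- Users can select a project to enter the workspace.
- Users can import a video file; the frontend creates a project, uploads the file, triggers frame extraction, and refreshes the project list.
- Users can import a DICOM series; the frontend uploads the DICOM files, triggers frame extraction, and refreshes the project list.
- The backend supports project create, list, detail, partial update, and delete.
- The backend supports frame create, list, and single-frame lookup.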
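## R3 Media upload and frame extraction
- The backend accepts video, image, and DICOM uploads; any other extension returns 400.
- When no project ID is supplied with an upload, the backend creates a project automatically.
- When a project ID is supplied, the backend associates the uploaded object with that project.
- The frame extraction endpoint handles video or DICOM according to the project's `source_type`.
- When extraction finishes, `frames` records are written and the project status is set to `ready`.
- The frame extraction endpoint creates a `processing_tasks` record and dispatches it to a Celery worker.
- The frontend can query task status through `GET /api/tasks/{task_id}`.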
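## R4 Workspace and frame browsing
- The workspace loads the frame list for the current project.
- If the project has media but no frames, the workspace tries to trigger frame extraction and then reloads.
- The Canvas shows the current frame image.
- The Canvas supports wheel zoom, dragging with the move tool, and a mouse coordinate readout.
- The timeline supports switching frames by clicking thumbnails, switching frames by dragging the range slider, and play/pause that advances frames in order.
- Playback uses the project's `parse_fps` or `original_fps`, clamped to between 1 and 30 FPS.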
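The playback rate rule can be sketched as a small clamp. This is a hedged illustration: the real logic lives in the TypeScript timeline component, the preference for `parse_fps` over `original_fps` is an assumption read from the order in which the requirement names them, and the `fallback` parameter is hypothetical.

```python
def playback_fps(parse_fps, original_fps, fallback=1.0):
    """Pick a playback rate and clamp it to the 1-30 FPS range.

    Assumption: parse_fps wins when it is a positive number, otherwise
    original_fps is used, otherwise the (hypothetical) fallback.
    """
    for candidate in (parse_fps, original_fps):
        if candidate is not None and candidate > 0:
            return max(1.0, min(30.0, float(candidate)))
    return fallback
```

A 60 FPS source therefore plays at 30 FPS, and a 0.5 FPS extraction plays at 1 FPS.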
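## R5 Tool palette
- The tool palette switches the current active tool.
- The positive point, negative point, and box tools affect Canvas interaction.
- The magic wand button switches to the AI page.
- Polygon, rectangle, circle, point, line, merge, subtract, undo, and redo currently provide only UI state or placeholder buttons; no real drawing or algorithm is implemented.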
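## R6 AI inference
- The frontend can select `sam2` or `sam3` on the AI page; the selection is stored in the global store.
- The frontend and the workspace show the real runtime status of the GPU, SAM 2, and SAM 3 through `GET /api/ai/models/status`.
- The frontend's `predictMask()` calls `POST /api/ai/predict`.
- The frontend sends the backend contract: `image_id`, `prompt_type`, `prompt_data`, `model`.
- A point prompt sends `{ points, labels }`; a positive point has label 1 and a negative point has label 0.
- A box prompt sends a normalized `[x1, y1, x2, y2]`.
- A semantic text prompt sends `semantic`; with `sam3` selected and its dependencies satisfied, the request goes through SAM 3 text-driven semantic inference, while `sam2` falls back to automatic segmentation.
- The backend returns `polygons` and `scores`.
- The frontend converts the backend `polygons` into Konva `pathData`, `segmentation`, `bbox`, and `area`.
- AI results are first held in `masks` in the frontend store and are persisted to the backend annotation table only after the user clicks "结构化归档保存" (structured archive save).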
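The point-prompt contract above can be sketched as a request builder. This is an illustration under stated assumptions, not the project's `predictMask()`: normalization is assumed to mean dividing pixel coordinates by the frame width and height, each point is assumed to be serialized as an `[x, y]` pair, and `build_point_prompt` is a name invented here.

```python
def build_point_prompt(image_id, width, height, points, model="sam2"):
    """Build a request body for POST /api/ai/predict with a point prompt.

    `points` is a list of (x_px, y_px, kind) tuples, kind being "pos" or "neg".
    Assumption: coordinates are normalized to [0, 1] by the frame size.
    """
    if width <= 0 or height <= 0:
        raise ValueError("frame dimensions must be positive")
    return {
        "image_id": image_id,
        "prompt_type": "point",
        "prompt_data": {
            "points": [[x / width, y / height] for x, y, _ in points],
            # Positive points carry label 1, negative points label 0.
            "labels": [1 if kind == "pos" else 0 for _, _, kind in points],
        },
        "model": model,
    }
```

Keeping the width and height check in the builder mirrors the frontend guard that skips inference when frame dimensions are unavailable.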
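## R7 Annotation saving
- The backend provides `POST /api/ai/annotate` to save annotations.
- The project must exist when saving; if `frame_id` is passed, the frame must exist too.
- The backend provides `GET /api/ai/annotations` to query a project's annotations, optionally filtered by `frame_id`.
- The backend provides `PATCH /api/ai/annotations/{annotation_id}` to update the `mask_data`, `points`, `bbox`, and `template_id` of a saved annotation.
- The backend provides `DELETE /api/ai/annotations/{annotation_id}` to delete a saved annotation.
- The frontend's "结构化归档保存" (structured archive save) currently saves the unsaved masks of the current project and also updates saved masks that are marked dirty.
- The workspace's "清空遮罩" (clear masks) deletes the saved annotations of the current frame and clears its unsaved masks.
- After loading project frames, the workspace queries saved annotations and renders them.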
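## R8 Template library
- The frontend shows the template list by calling `GET /api/templates`.
- Users can create, edit, and delete templates.
- Template classes are stored in `mapping_rules.classes` and rules in `mapping_rules.rules`.
- The frontend supports adding and deleting classes, recomputing `zIndex` after drag-and-drop reordering, bulk JSON import, and loading the default laparoscopy classes.
- The backend supports template create, list, detail, partial update, and delete.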
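## R9 Ontology inspector panel
- A template can be selected on the right side of the workspace.
- The panel shows the template classes and the component-local custom classes.
- Users can select a specific class; a new AI mask records `classId`, `className`, and `classZIndex`, which are written to `mask_data.class` on save.
- Added custom classes exist only in component-local state and are not saved to the backend.
- Confidence, topology anchors, and the re-extract skeleton button are currently display-only placeholders.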
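## R10 Dashboard and WebSocket
- The Dashboard shows basic statistics, the parse queue, and the activity log.
- The Dashboard's initial data comes from `GET /api/dashboard/overview`.
- The backend aggregates the project count, in-progress task count, annotation count, frame count, template count, and host load average.
- The parse queue is built from queued/running tasks in `processing_tasks`; the activity log is built from recent task, project, annotation, and template records.
- The Dashboard connects to `/ws/progress`.
- On progress, complete, error, and status messages the frontend updates the queue or the log.
- Each time the Celery worker updates `processing_tasks` it publishes a Redis `seg:progress` event; FastAPI subscribes and broadcasts it to `/ws/progress` clients.
- When the backend WebSocket receives a client message it replies with a status heartbeat.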
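How a client might fold these messages into a queue can be sketched as a pure function. Only the four message kinds (progress, complete, error, status) come from the requirement; the field names `type`, `task_id`, and `progress` are assumptions made for illustration and may not match the real payload.

```python
def apply_progress_message(queue, message):
    """Return a new queue (task_id -> percent) after applying one message.

    Field names "type", "task_id", and "progress" are hypothetical.
    """
    updated = dict(queue)
    kind = message.get("type")
    task_id = message.get("task_id")
    if kind == "progress" and task_id is not None:
        updated[task_id] = message.get("progress", 0)
    elif kind in ("complete", "error") and task_id is not None:
        # Finished or failed tasks leave the parse queue.
        updated.pop(task_id, None)
    # "status" heartbeats carry no task change, so the queue is left as is.
    return updated
```

Returning a new dictionary instead of mutating the input keeps the update easy to test and fits store-style state management.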
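## R11 Export
- The backend supports `GET /api/export/{project_id}/coco` to export COCO JSON.
- The backend supports `GET /api/export/{project_id}/masks` to export a ZIP of PNG masks.
- The frontend's `exportCoco()` API wrapper is aligned with the backend path.
- The workspace's "导出 JSON 标注集" (export JSON annotation set) button is bound to a download handler; unarchived masks are saved before export.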
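## R12 Configuration
- The frontend API address is derived in one place, `src/lib/config.ts`.
- `VITE_API_BASE_URL` takes priority over automatic derivation.
- `VITE_WS_PROGRESS_URL` takes priority over deriving the WebSocket address from the API address.
- When no environment variable is set, the frontend derives `http://<host>:8000` from the current browser hostname.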
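The precedence rules above can be sketched as two resolver functions. This is a hedged model of the behavior, not a port of `src/lib/config.ts`: the function names are invented here, and the mapping from the API address to the WebSocket address (`http` to `ws`, `https` to `wss`, with `/ws/progress` appended) is an assumption.

```python
def resolve_api_base(env, hostname):
    """Explicit VITE_API_BASE_URL wins; otherwise derive http://<host>:8000."""
    explicit = env.get("VITE_API_BASE_URL")
    if explicit:
        return explicit.rstrip("/")
    return f"http://{hostname}:8000"


def resolve_ws_progress_url(env, hostname):
    """Explicit VITE_WS_PROGRESS_URL wins; otherwise derive from the API base."""
    explicit = env.get("VITE_WS_PROGRESS_URL")
    if explicit:
        return explicit
    api_base = resolve_api_base(env, hostname)
    # Assumed mapping: http -> ws, https -> wss, same host and port.
    if api_base.startswith("https://"):
        ws_base = "wss://" + api_base[len("https://"):]
    elif api_base.startswith("http://"):
        ws_base = "ws://" + api_base[len("http://"):]
    else:
        ws_base = api_base
    return ws_base + "/ws/progress"
```

Passing the environment as a plain dictionary keeps both functions testable without touching real environment variables.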
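## R13 Documentation and tests
- The `doc/` directory holds the current implementation audit, API contracts, requirements freeze, design freeze, and test plan.
- Tests should cover the real features, partially working behavior, and explicit placeholder behavior in the frozen requirements.
- Tests that depend on external services such as PostgreSQL, MinIO, Redis, or the SAM models should use mocks or test doubles and must not depend on the real services being available.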


@@ -0,0 +1,155 @@
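# Current design freeze
Freeze date: 2026-05-01
This document describes the current code structure, data flows, API contracts, and test boundaries. Any later implementation that changes these designs should update this document and the tests together.
## Overall architecture
The system currently has three layers:
- A React + TypeScript frontend SPA.
- A FastAPI backend API.
- External infrastructure: PostgreSQL, MinIO, Redis, and SAM 2 / SAM 3.
In development the frontend is started through `server.ts` (Express + Vite middleware) and the backend through `backend/main.py` (FastAPI). Frontend business calls go mainly to FastAPI and do not depend on the legacy mock API retained in `server.ts`.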
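## Frontend modules
| Module | File | Responsibility |
|------|------|----------|
| App entry | `src/App.tsx` | Switches pages based on login state and `activeModule` |
| Global state | `src/store/useStore.ts` | Zustand store holding projects, frames, templates, masks, and tool state |
| API wrapper | `src/lib/api.ts` | Axios client, field mapping, AI response conversion |
| Configuration | `src/lib/config.ts` | Derives the API and WebSocket addresses |
| WebSocket | `src/lib/websocket.ts` | Progress stream connection, subscription, and reconnection |
| Model status | `src/components/ModelStatusBadge.tsx` | Shows the real availability of the GPU and the current SAM model |
| Login page | `src/components/Login.tsx` | Calls the login API and writes to the store |
| Dashboard | `src/components/Dashboard.tsx` | Shows statistics and WebSocket progress messages |
| Project library | `src/components/ProjectLibrary.tsx` | Project list, creation, video/DICOM import |
| Workspace | `src/components/VideoWorkspace.tsx` | Loads frames and templates; lays out the tool palette, Canvas, ontology panel, and timeline |
| Canvas | `src/components/CanvasArea.tsx` | Shows the frame, zoom and pan, point/box prompts, mask rendering |
| Tool palette | `src/components/ToolsPalette.tsx` | Switches tools and jumps to the AI page |
| Timeline | `src/components/FrameTimeline.tsx` | Frame navigation and playback |
| Ontology panel | `src/components/OntologyInspector.tsx` | Template selection, class tree, local custom classes |
| AI page | `src/components/AISegmentation.tsx` | Standalone AI inference view that uses the current project frame |
| Template library | `src/components/TemplateRegistry.tsx` | Template CRUD, class editing, import, reordering |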
## Backend modules
| Module | File | Responsibility |
|------|------|----------|
| App entry | `backend/main.py` | FastAPI app, CORS, router registration, health check, WebSocket |
| Configuration | `backend/config.py` | Pydantic settings |
| Database | `backend/database.py` | SQLAlchemy engine, session, Base |
| Models | `backend/models.py` | Project, Frame, Template, Annotation, Mask, ProcessingTask |
| Schema | `backend/schemas.py` | Pydantic request/response models |
| Auth | `backend/routers/auth.py` | Development login |
| Projects | `backend/routers/projects.py` | Project and frame CRUD |
| Templates | `backend/routers/templates.py` | Template CRUD and mapping_rules packing/unpacking |
| Media | `backend/routers/media.py` | Media upload and frame extraction |
| AI | `backend/routers/ai.py` | Inference with a selectable SAM 2 / SAM 3 model, model status, and annotation saving |
| Export | `backend/routers/export.py` | COCO and PNG mask export |
| SAM 2 | `backend/services/sam2_engine.py` | SAM 2 lazy loading, status detection, and point/box/automatic inference |
| SAM 3 | `backend/services/sam3_engine.py` | SAM 3 status detection and text-driven semantic inference adapter |
| SAM Registry | `backend/services/sam_registry.py` | Model selection, GPU status, and inference dispatch |
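## State model
Core objects in the frontend store:
- `Project`: basic project information, status, frame count, fps, media path.
- `Frame`: frame ID, project ID, index, image URL, width and height.
- `Template` / `TemplateClass`: template and class definitions.
- `Mask`: the mask used for frontend rendering, containing `pathData`, `segmentation`, `bbox`, and `area`.
- `activeModule`: the current page.
- `activeTool`: the current tool.
- `aiModel`: the currently selected AI model, either `sam2` or `sam3`.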
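## Key data flows
### Login
1. `Login` collects the username and password.
2. `login()` calls `POST /api/auth/login`.
3. On success the store records the token and App renders the main interface.
### Project import
1. `ProjectLibrary` creates a project.
2. The video or DICOM is uploaded to `/api/media/upload` or `/api/media/upload/dicom`.
3. `/api/media/parse` is called to create an asynchronous frame extraction task.
4. The Celery worker runs FFmpeg/OpenCV/pydicom extraction, keeps updating `processing_tasks`, and publishes Redis `seg:progress` events.
5. The project list is refreshed.
### Workspace loading
1. `VideoWorkspace` calls `getProjectFrames()` with `currentProject.id`.
2. If there are no frames but the project has a `video_path`, it triggers `parseMedia()`, polls the task through `getTask()` until it completes, and then fetches frames again.
3. The frame data is mapped to the store's `Frame[]`.
4. The current frame is passed to `CanvasArea`.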
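### AI point/box inference
1. The user selects the positive point, negative point, or box tool on the Canvas.
2. `CanvasArea` reads the current frame ID and dimensions.
3. `predictMask()` normalizes the coordinates and calls `/api/ai/predict` with the current `model`.
4. The backend loads the frame image and dispatches to SAM 2 or SAM 3 through the SAM registry.
5. The frontend converts `polygons` into masks and writes them to the store.
6. The Canvas filters masks by the current frame and renders them.
7. A new mask carries the metadata of the currently selected template class, including `classId`, `className`, and `classZIndex`, plus the save status `draft`.
8. After the user clicks "结构化归档保存" (structured archive save), the frontend converts the pixel `segmentation` into normalized `mask_data.polygons`; unsaved masks go through `POST /api/ai/annotate` and dirty masks through `PATCH /api/ai/annotations/{annotation_id}`.
9. After loading project frames, the workspace fetches saved annotations through `GET /api/ai/annotations` and converts them into frontend masks.
10. The workspace's "清空遮罩" (clear masks) deletes the saved annotations of the current frame and clears its local masks.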
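Steps 5 and 8 rely on converting between normalized polygons and pixel-space mask fields. The sketch below shows one way that round trip could work. It is not the code in `src/lib/api.ts`: the polygon layout `[[x, y], ...]`, the `[x, y, width, height]` bbox order, and the use of the shoelace formula for `area` are all assumptions, and both function names are invented here.

```python
def polygon_to_mask_fields(polygon, width, height):
    """Convert one normalized polygon [[x, y], ...] into pixel-space mask fields."""
    pixels = [(x * width, y * height) for x, y in polygon]
    segmentation = [coord for point in pixels for coord in point]
    # ":g" drops trailing zeros (10.0 -> "10") and keeps 6 significant digits.
    commands = [f"{'M' if i == 0 else 'L'} {x:g} {y:g}" for i, (x, y) in enumerate(pixels)]
    path_data = " ".join(commands + ["Z"])
    xs = [x for x, _ in pixels]
    ys = [y for _, y in pixels]
    bbox = [min(xs), min(ys), max(xs) - min(xs), max(ys) - min(ys)]
    # Shoelace formula: polygon area in square pixels.
    closed = zip(pixels, pixels[1:] + pixels[:1])
    area = abs(sum(x1 * y2 - x2 * y1 for (x1, y1), (x2, y2) in closed)) / 2
    return {"pathData": path_data, "segmentation": [segmentation], "bbox": bbox, "area": area}


def segmentation_to_normalized(segmentation, width, height):
    """Convert a flat pixel list [x0, y0, x1, y1, ...] back to [[x, y], ...] in [0, 1]."""
    return [[segmentation[i] / width, segmentation[i + 1] / height]
            for i in range(0, len(segmentation) - 1, 2)]
```

Storing normalized polygons on the backend and pixel values only in the frontend keeps saved annotations independent of the resolution at which a frame is displayed.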
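### Template management
1. `TemplateRegistry` reads templates from the backend.
2. In edit mode the class list is kept in component-local state.
3. Saving calls `createTemplate()` or `updateTemplate()`.
4. The backend packs `classes` and `rules` into `mapping_rules`.
5. On return it unpacks them again for the frontend.
6. `OntologyInspector` can select a specific class; the selection goes into the global store and is used by `CanvasArea` and `AISegmentation` when creating or updating masks.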
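The packing and unpacking in steps 4 and 5 can be sketched as a pair of helpers. Only the `mapping_rules.classes` and `mapping_rules.rules` keys come from this documentation; the function names and the empty-list defaults for a missing `mapping_rules` are assumptions.

```python
def pack_mapping_rules(classes, rules):
    """Store template classes and rules together under one mapping_rules value."""
    return {"classes": list(classes), "rules": list(rules)}


def unpack_mapping_rules(mapping_rules):
    """Split mapping_rules back into (classes, rules).

    Assumption: a missing or empty mapping_rules yields two empty lists.
    """
    mapping_rules = mapping_rules or {}
    return mapping_rules.get("classes", []), mapping_rules.get("rules", [])
```

With a pair like this the frontend can work with flat `classes` and `rules` fields while the database keeps a single JSON column.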
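### Export
1. The backend builds the COCO JSON from projects, frames, annotations, and templates.
2. PNG mask export renders each normalized polygon into a binary mask and packs the results into a ZIP.
3. The frontend's "导出 JSON 标注集" (export JSON annotation set) button saves pending annotations before export and then downloads the COCO JSON.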
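The rasterization in step 2 can be illustrated with a dependency-free scanline fill. This is only a sketch of the idea: the real exporter may well use an imaging library instead, and the even-odd rule sampled at pixel centers is an assumption chosen for clarity. PNG encoding and ZIP packing are left out.

```python
def rasterize_polygon(polygon, width, height):
    """Return a height x width grid of 0/1 for one normalized polygon.

    Uses the even-odd rule, sampling each pixel at its center.
    """
    pixels = [(x * width, y * height) for x, y in polygon]
    edges = list(zip(pixels, pixels[1:] + pixels[:1]))
    mask = [[0] * width for _ in range(height)]
    for row in range(height):
        cy = row + 0.5
        crossings = []
        for (x1, y1), (x2, y2) in edges:
            # An edge counts only if the scanline passes between its endpoints.
            if (y1 <= cy) != (y2 <= cy):
                crossings.append(x1 + (cy - y1) * (x2 - x1) / (y2 - y1))
        crossings.sort()
        # Pixels between each pair of crossings are inside the polygon.
        for left, right in zip(crossings[0::2], crossings[1::2]):
            for col in range(width):
                if left <= col + 0.5 < right:
                    mask[row][col] = 1
    return mask
```

For example, a centered square covering the middle half of a 4x4 frame fills exactly the four inner pixels.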
## API contracts
See `doc/04-api-contracts.md` for the full API details. The tests pin down the following contracts:
- `updateProject()` uses `PATCH /api/projects/{id}`.
- `exportCoco()` uses `GET /api/export/{projectId}/coco`.
- `predictMask()` uses `POST /api/ai/predict` with a request body of `image_id`, `prompt_type`, `prompt_data`, and `model`.
- `saveAnnotation()` uses `POST /api/ai/annotate`.
- `getProjectAnnotations()` uses `GET /api/ai/annotations`.
- `updateAnnotation()` uses `PATCH /api/ai/annotations/{annotationId}`.
- `deleteAnnotation()` uses `DELETE /api/ai/annotations/{annotationId}`.
- The backend `/api/ai/predict` supports the point, box, and semantic prompt types and selects SAM 2 or SAM 3 through `model`.
- The backend `/api/ai/models/status` returns the real runtime status of the GPU, SAM 2, and SAM 3.
- A point prompt accepts both the legacy array form and the `{ points, labels }` object form.
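## External dependency boundaries
Tests do not depend directly on the following real services:
- PostgreSQL: backend tests use in-memory SQLite.
- MinIO: upload, download, and presigned URLs are monkeypatched.
- Redis: unit tests use monkeypatch to verify progress event publishing and do not depend on a real Redis server.
- SAM: AI inference tests use a fake registry.
- Browser Canvas/Konva image loading: frontend tests mock `react-konva` and `use-image`.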
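## Known placeholder designs
The following capabilities are placeholders or only partially working in the current frozen version:
- The Dashboard's initial snapshot comes from `GET /api/dashboard/overview`; the parse queue is built from queued/running `processing_tasks` tasks.
- Manual drawing of polygons, rectangles, circles, points, and lines is not implemented.
- Merge, subtract, undo, and redo are not implemented.
- The workspace does not yet offer a button for exporting the PNG mask ZIP.
- Saved annotations can enter the dirty state through "应用分类" (apply class) and be updated on archive; a per-vertex geometry editor is not yet available.
- SAM 3 text-driven semantic segmentation depends on the official dependencies and a GPU runtime; the status endpoint exposes its real availability.
- Custom classes exist only in local component state.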

49
doc/09-test-plan.md Normal file

@@ -0,0 +1,49 @@
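# Current test plan
This document maps the frozen requirements in `doc/07-current-requirements-freeze.md` to tests. The goal is to cover the current real behavior and the explicit placeholder behavior.
## Test layers
| Layer | Tools | Coverage |
|------|------|----------|
| Frontend unit/component | Vitest + Testing Library | API wrapper, store, component interaction, Mock/UI-only states |
| Backend routes | pytest + FastAPI TestClient | API contracts for Auth, Projects, Templates, AI, Export, and Media |
| Static contracts | TypeScript / py_compile | Types and Python syntax |
## Coverage matrix
| Requirement | Test files | Coverage points |
|------|----------|--------|
| R1 Login and session | `src/components/Login.test.tsx`, `backend/tests/test_auth.py` | Successful login, failure message, backend 401 |
| R2 Project management | `src/lib/api.test.ts`, `backend/tests/test_projects.py` | Frontend field mapping, PATCH update, backend CRUD, frame list |
| R3 Media upload and frame extraction | `backend/tests/test_media.py` | Extension validation, automatic project creation, project association, async task creation, worker frame registration |
| R4 Workspace and frame browsing | `src/components/VideoWorkspace.test.tsx`, `src/components/FrameTimeline.test.tsx` | Frame loading, parse trigger when there are no frames, frame switching, playback |
| R5 Tool palette | `src/components/ToolsPalette.test.tsx` | Tool switching, AI navigation, presence of placeholder buttons |
| R6 AI inference | `src/lib/api.test.ts`, `src/components/CanvasArea.test.tsx`, `src/components/AISegmentation.test.tsx`, `src/components/ModelStatusBadge.test.tsx`, `backend/tests/test_ai.py` | Point/box/semantic contract, model selection, GPU/SAM status, status badge, coordinate normalization, positive/negative point labels, polygons-to-path conversion, backend fake registry |
| R7 Annotation saving | `src/components/VideoWorkspace.test.tsx`, `src/components/CanvasArea.test.tsx`, `src/lib/api.test.ts`, `backend/tests/test_ai.py` | Saving annotations, loading and rendering them, updating dirty annotations, clear deleting saved annotations, missing project, missing frame |
| R8 Template library | `src/lib/api.test.ts`, `backend/tests/test_templates.py` | mapping_rules unpacking/packing, template CRUD |
| R9 Ontology inspector panel | `src/components/OntologyInspector.test.tsx` | Template selection, class display, specific class selection, local custom class addition |
| R10 Dashboard and WebSocket | `src/lib/api.test.ts`, `src/lib/websocket.test.ts`, `src/components/Dashboard.test.tsx`, `backend/tests/test_dashboard.py`, `backend/tests/test_main.py`, `backend/tests/test_progress_events.py` | Backend overview endpoint, task-table-driven queue, Redis progress event payload/publishing, address derivation, message subscription, queue update, heartbeat |
| R11 Export | `src/components/VideoWorkspace.test.tsx`, `src/lib/api.test.ts`, `backend/tests/test_export.py` | COCO button download, auto-save before export, COCO path, JSON structure, mask ZIP |
| R12 Configuration | `src/lib/config.test.ts` | env priority, hostname derivation, WS derivation |
| R13 Documentation and tests | `doc/09-test-plan.md` | Test coverage matrix |
## Commands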
```bash
npm run test
npm run test:run
npm run lint
npm run build
pip install -r backend/requirements-dev.txt
pytest backend/tests
python -m py_compile backend/routers/ai.py backend/routers/templates.py backend/schemas.py
```
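## Tests currently out of scope
- No real PostgreSQL, MinIO, Redis, or SAM model is started.
- No performance testing of frame extraction on real large video files.
- No browser E2E verification of visual details.
- Buttons that are explicitly Mock/UI-only are not tested as real business success paths.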

32
doc/README.md Normal file

@@ -0,0 +1,32 @@
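# Project documentation index
This directory records the real state of the current codebase, the target design, and the gaps between them. The documents are based on:
- The Word document in the repository root: `语义分割系统构建方案.docx`
- Frontend source: `src/App.tsx`, `src/components/*.tsx`, `src/lib/api.ts`, `src/store/useStore.ts`
- Backend source: `backend/main.py`, `backend/routers/*.py`, `backend/schemas.py`, `backend/models.py`
- Runtime OpenAPI: `http://localhost:8000/openapi.json`
## Document structure
| Document | Content |
|------|------|
| [01-purpose-and-word-summary.md](./01-purpose-and-word-summary.md) | Why this system is being built, the goals in the Word proposal, and how far the current code has come |
| [02-current-implementation-map.md](./02-current-implementation-map.md) | How the current system runs and how the frontend, backend, storage, and data flows fit together |
| [03-frontend-element-audit.md](./03-frontend-element-audit.md) | Page-by-page, element-by-element frontend audit: working, partially working, Mock/UI-only, broken interface |
| [04-api-contracts.md](./04-api-contracts.md) | Frontend API wrapper, backend FastAPI endpoints, completed alignment items, and remaining interface issues |
| [05-implementation-plan.md](./05-implementation-plan.md) | Suggested order for turning mocks into real features |
| [06-fastapi-docs-explained.md](./06-fastapi-docs-explained.md) | What `http://192.168.3.11:8000/docs` is, how to read it, and how to use it |
| [07-current-requirements-freeze.md](./07-current-requirements-freeze.md) | Requirements freeze for the current version; tests follow it |
| [08-current-design-freeze.md](./08-current-design-freeze.md) | Design freeze for the current version, recording modules, data flows, and interface boundaries |
| [09-test-plan.md](./09-test-plan.md) | Requirement-to-test-file coverage matrix and commands |
## Status labels
| Label | Meaning |
|------|------|
| Working (真实可用) | Wired to real frontend state or backend APIs; the main action can be completed with the current code |
| Partially working (部分可用) | Has real data or real UI but a key gap remains, such as being read-only, lacking persistence, or missing error handling |
| Mock / UI-only | Only display or local state changes; no real business effect |
| Broken interface (接口不通) | The frontend call and the backend contract disagree; very likely to fail with the current code |
| Target design (目标设计) | Proposed in the Word proposal but not yet implemented in the current code |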

1154
package-lock.json generated

File diff suppressed because it is too large Load Diff


@@ -9,7 +9,9 @@
     "preview": "vite preview",
     "start": "node server.ts",
     "clean": "rm -rf dist",
-    "lint": "tsc --noEmit"
+    "lint": "tsc --noEmit",
+    "test": "vitest",
+    "test:run": "vitest run"
   },
   "dependencies": {
     "@google/genai": "^1.29.0",
@@ -31,12 +33,17 @@
     "zustand": "^5.0.12"
   },
   "devDependencies": {
+    "@testing-library/jest-dom": "^6.9.1",
+    "@testing-library/react": "^16.3.2",
+    "@testing-library/user-event": "^14.6.1",
     "@types/express": "^4.17.21",
     "@types/node": "^22.14.0",
     "autoprefixer": "^10.4.21",
+    "jsdom": "^29.1.1",
     "tailwindcss": "^4.1.14",
     "tsx": "^4.21.0",
     "typescript": "~5.8.2",
-    "vite": "^6.2.0"
+    "vite": "^6.2.0",
+    "vitest": "^4.1.5"
   }
 }


@@ -0,0 +1,43 @@
import { fireEvent, render, screen } from '@testing-library/react';
import { beforeEach, describe, expect, it, vi } from 'vitest';
import { resetStore } from '../test/storeTestUtils';
import { useStore } from '../store/useStore';
import { AISegmentation } from './AISegmentation';

const apiMock = vi.hoisted(() => ({
  getAiModelStatus: vi.fn(),
  predictMask: vi.fn(),
}));

vi.mock('../lib/api', () => ({
  getAiModelStatus: apiMock.getAiModelStatus,
  predictMask: apiMock.predictMask,
}));

describe('AISegmentation', () => {
  beforeEach(() => {
    resetStore();
    vi.clearAllMocks();
    useStore.setState({
      frames: [{ id: 'frame-1', projectId: 'project-1', index: 0, url: '/frame.jpg', width: 640, height: 360 }],
    });
    apiMock.getAiModelStatus.mockResolvedValue({
      selected_model: 'sam2',
      gpu: { available: true, device: 'cuda', name: 'RTX 4090', torch_available: true },
      models: [
        { id: 'sam2', label: 'SAM 2', available: true, loaded: false, device: 'cuda', supports: ['point', 'box'], message: 'SAM 2 ready', package_available: true, checkpoint_exists: true, python_ok: true, torch_ok: true, cuda_required: false },
        { id: 'sam3', label: 'SAM 3', available: false, loaded: false, device: 'cuda', supports: ['semantic'], message: 'SAM 3 missing runtime', package_available: false, checkpoint_exists: false, python_ok: false, torch_ok: true, cuda_required: true },
      ],
    });
  });

  it('lets the user choose SAM3 for subsequent predictions', async () => {
    render(<AISegmentation onSendToWorkspace={vi.fn()} />);

    const sam3Button = (await screen.findByText('SAM3')).closest('button')!;
    fireEvent.click(sam3Button);

    expect(useStore.getState().aiModel).toBe('sam3');
    expect(await screen.findByText('SAM 3 missing runtime')).toBeInTheDocument();
  });
});


@@ -1,11 +1,11 @@
import React, { useState, useCallback } from 'react'; import React, { useState, useCallback, useEffect } from 'react';
import { Target, PlusCircle, MinusCircle, SquareDashed, Sparkles, SendToBack, Image as ImageIcon, Undo, Redo, Loader2 } from 'lucide-react'; import { Target, PlusCircle, MinusCircle, SquareDashed, Sparkles, SendToBack, Image as ImageIcon, Undo, Redo, Loader2 } from 'lucide-react';
import { cn } from '../lib/utils'; import { cn } from '../lib/utils';
import { Stage, Layer, Image as KonvaImage, Circle, Path, Group } from 'react-konva'; import { Stage, Layer, Image as KonvaImage, Circle, Path, Group } from 'react-konva';
import useImage from 'use-image'; import useImage from 'use-image';
import { OntologyInspector } from './OntologyInspector'; import { OntologyInspector } from './OntologyInspector';
import { useStore } from '../store/useStore'; import { useStore } from '../store/useStore';
import { predictMask } from '../lib/api'; import { getAiModelStatus, predictMask, type AiRuntimeStatus } from '../lib/api';
interface AISegmentationProps { interface AISegmentationProps {
onSendToWorkspace: () => void; onSendToWorkspace: () => void;
@@ -17,9 +17,15 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
const masks = useStore((state) => state.masks); const masks = useStore((state) => state.masks);
const addMask = useStore((state) => state.addMask); const addMask = useStore((state) => state.addMask);
const clearMasks = useStore((state) => state.clearMasks); const clearMasks = useStore((state) => state.clearMasks);
const frames = useStore((state) => state.frames);
const currentFrameIndex = useStore((state) => state.currentFrameIndex);
const activeTemplateId = useStore((state) => state.activeTemplateId);
const activeClass = useStore((state) => state.activeClass);
const aiModel = useStore((state) => state.aiModel);
const setAiModel = useStore((state) => state.setAiModel);
const [modelSize, setModelSize] = useState('vit_l');
const [semanticText, setSemanticText] = useState(''); const [semanticText, setSemanticText] = useState('');
const [modelStatus, setModelStatus] = useState<AiRuntimeStatus | null>(null);
const [autoDeleteBg, setAutoDeleteBg] = useState(true); const [autoDeleteBg, setAutoDeleteBg] = useState(true);
const [cropMode, setCropMode] = useState(false); const [cropMode, setCropMode] = useState(false);
const [isInferencing, setIsInferencing] = useState(false); const [isInferencing, setIsInferencing] = useState(false);
@@ -29,10 +35,29 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
const [position, setPosition] = useState({ x: 0, y: 0 }); const [position, setPosition] = useState({ x: 0, y: 0 });
const [points, setPoints] = useState<{ x: number, y: number, type: 'pos'|'neg' }[]>([]); const [points, setPoints] = useState<{ x: number, y: number, type: 'pos'|'neg' }[]>([]);
const [cursorPos, setCursorPos] = useState({ x: 0, y: 0 }); const [cursorPos, setCursorPos] = useState({ x: 0, y: 0 });
const [image] = useImage('https://images.unsplash.com/photo-1549317661-bd32c8ce0be2?q=80&w=2070&auto=format&fit=crop'); const currentFrame = frames[currentFrameIndex] || null;
const previewUrl = currentFrame?.url || 'https://images.unsplash.com/photo-1549317661-bd32c8ce0be2?q=80&w=2070&auto=format&fit=crop';
const [image] = useImage(previewUrl);
const frameMasks = currentFrame ? masks.filter((mask) => mask.frameId === currentFrame.id) : masks;
const selectedModelStatus = modelStatus?.models.find((model) => model.id === aiModel);
const modelCanInfer = selectedModelStatus?.available ?? true;
const effectiveTool = storeActiveTool; const effectiveTool = storeActiveTool;
useEffect(() => {
let cancelled = false;
getAiModelStatus(aiModel)
.then((status) => {
if (!cancelled) setModelStatus(status);
})
.catch(() => {
if (!cancelled) setModelStatus(null);
});
return () => {
cancelled = true;
};
}, [aiModel]);
const handleWheel = (e: any) => { const handleWheel = (e: any) => {
e.evt.preventDefault(); e.evt.preventDefault();
const scaleBy = 1.1; const scaleBy = 1.1;
@@ -63,22 +88,44 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
const runInference = useCallback(async () => { const runInference = useCallback(async () => {
if (points.length === 0 && !semanticText.trim()) return; if (points.length === 0 && !semanticText.trim()) return;
if (!currentFrame?.id) {
console.warn('AI inference skipped: no project frame is selected');
return;
}
const imageWidth = currentFrame.width || image?.naturalWidth || image?.width || 0;
const imageHeight = currentFrame.height || image?.naturalHeight || image?.height || 0;
if (imageWidth <= 0 || imageHeight <= 0) {
console.warn('AI inference skipped: active frame dimensions are unavailable');
return;
}
setIsInferencing(true); setIsInferencing(true);
try { try {
const result = await predictMask({ const result = await predictMask({
imageUrl: 'https://images.unsplash.com/photo-1549317661-bd32c8ce0be2?q=80&w=2070&auto=format&fit=crop', imageId: currentFrame.id,
imageWidth,
imageHeight,
model: aiModel,
points: points.map((p) => ({ x: p.x, y: p.y, type: p.type })), points: points.map((p) => ({ x: p.x, y: p.y, type: p.type })),
text: semanticText.trim() || undefined, text: semanticText.trim() || undefined,
modelSize,
}); });
result.masks.forEach((m) => { result.masks.forEach((m) => {
const label = activeClass?.name || m.label;
const color = activeClass?.color || m.color;
addMask({ addMask({
id: m.id, id: m.id,
frameId: 'frame-ai-1', frameId: currentFrame.id,
templateId: activeTemplateId || undefined,
classId: activeClass?.id,
className: activeClass?.name,
classZIndex: activeClass?.zIndex,
saveStatus: 'draft',
saved: false,
pathData: m.pathData, pathData: m.pathData,
label: m.label, label,
color: m.color, color,
segmentation: m.segmentation, segmentation: m.segmentation,
bbox: m.bbox, bbox: m.bbox,
area: m.area, area: m.area,
@@ -89,7 +136,7 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
} finally { } finally {
setIsInferencing(false); setIsInferencing(false);
} }
}, [points, semanticText, modelSize, addMask]); }, [activeClass, activeTemplateId, addMask, aiModel, currentFrame?.height, currentFrame?.id, currentFrame?.width, image?.height, image?.naturalHeight, image?.naturalWidth, image?.width, points, semanticText]);
const handleStageClick = (e: any) => { const handleStageClick = (e: any) => {
if (effectiveTool === 'move') return; if (effectiveTool === 'move') return;
@@ -117,17 +164,26 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
{/* Model Select */} {/* Model Select */}
<div> <div>
<h3 className="text-[10px] font-bold text-gray-500 uppercase tracking-widest mb-3"></h3> <h3 className="text-[10px] font-bold text-gray-500 uppercase tracking-widest mb-3"></h3>
<div className="bg-[#111] border border-white/5 flex p-1 rounded-lg"> <div className="bg-[#111] border border-white/5 grid grid-cols-2 gap-1 p-1 rounded-lg">
{['vit_b', 'vit_l', 'vit_h'].map(m => ( {(modelStatus?.models || [
{ id: 'sam2' as const, label: 'SAM 2', available: true, message: '正在读取 SAM 2 状态' },
{ id: 'sam3' as const, label: 'SAM 3', available: false, message: '正在读取 SAM 3 状态' },
]).map((m) => (
<button <button
key={m} key={m.id}
className={cn("flex-1 text-xs py-2 rounded-md transition-colors text-center uppercase tracking-wider font-mono", modelSize === m ? "bg-white/10 text-white font-medium shadow-sm" : "text-gray-500 hover:text-gray-300 hover:bg-white/5")} className={cn("text-xs py-2 rounded-md transition-colors text-center uppercase tracking-wider font-mono", aiModel === m.id ? "bg-white/10 text-white font-medium shadow-sm" : "text-gray-500 hover:text-gray-300 hover:bg-white/5")}
onClick={() => setModelSize(m)} onClick={() => setAiModel(m.id)}
title={m.message}
> >
{m.split('_')[1]} {m.label.replace(' ', '')}
<span className={cn("ml-1", m.available ? "text-emerald-400" : "text-amber-400")}></span>
</button> </button>
))} ))}
</div> </div>
<div className="mt-2 text-[10px] text-gray-500 leading-relaxed">
<div>{selectedModelStatus?.message || '正在读取模型状态...'}</div>
<div>GPU: {modelStatus?.gpu.available ? `${modelStatus.gpu.name || 'CUDA'} 可用` : '不可用或未检测到 CUDA'}</div>
</div>
</div> </div>
{/* Prompt Tools */} {/* Prompt Tools */}
@@ -206,16 +262,16 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
<div className="p-6 bg-[#0a0a0a] border-t border-white/5 shrink-0 flex flex-col gap-3"> <div className="p-6 bg-[#0a0a0a] border-t border-white/5 shrink-0 flex flex-col gap-3">
<button <button
onClick={runInference} onClick={runInference}
disabled={isInferencing} disabled={isInferencing || !currentFrame || !modelCanInfer}
className={cn( className={cn(
"w-full py-3.5 rounded-lg flex items-center justify-center gap-2 transition-all shadow-lg font-medium tracking-wide text-xs uppercase", "w-full py-3.5 rounded-lg flex items-center justify-center gap-2 transition-all shadow-lg font-medium tracking-wide text-xs uppercase",
isInferencing isInferencing || !currentFrame || !modelCanInfer
? "bg-cyan-500/50 text-black/70 cursor-not-allowed" ? "bg-cyan-500/50 text-black/70 cursor-not-allowed"
: "bg-cyan-500 hover:bg-cyan-400 text-black shadow-cyan-500/20 hover:shadow-cyan-500/40" : "bg-cyan-500 hover:bg-cyan-400 text-black shadow-cyan-500/20 hover:shadow-cyan-500/40"
)} )}
> >
{isInferencing ? <Loader2 size={16} className="animate-spin" /> : <Sparkles size={16} />} {isInferencing ? <Loader2 size={16} className="animate-spin" /> : <Sparkles size={16} />}
{isInferencing ? '推理中...' : '执行高精度语义分割'} {isInferencing ? '推理中...' : modelCanInfer ? '执行高精度语义分割' : '当前模型不可用'}
</button> </button>
<button <button
onClick={onSendToWorkspace} onClick={onSendToWorkspace}
@@ -231,7 +287,7 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
<header className="h-16 border-b border-white/5 bg-[#111] flex items-center justify-between px-6 shrink-0"> <header className="h-16 border-b border-white/5 bg-[#111] flex items-center justify-between px-6 shrink-0">
<div className="flex flex-col"> <div className="flex flex-col">
<h2 className="text-sm font-semibold tracking-wide text-white"> (Visualizer)</h2> <h2 className="text-sm font-semibold tracking-wide text-white"> (Visualizer)</h2>
<span className="text-[10px] text-gray-500 uppercase tracking-widest font-mono">SAM 3 </span> <span className="text-[10px] text-gray-500 uppercase tracking-widest font-mono">{aiModel.toUpperCase()} </span>
</div> </div>
<div className="flex items-center gap-4"> <div className="flex items-center gap-4">
<button className="w-8 h-8 rounded text-gray-400 hover:bg-white/5 hover:text-white flex items-center justify-center transition-colors" title="撤销操作 (Ctrl+Z)"> <button className="w-8 h-8 rounded text-gray-400 hover:bg-white/5 hover:text-white flex items-center justify-center transition-colors" title="撤销操作 (Ctrl+Z)">
@@ -276,7 +332,7 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
)} )}
{/* AI Returned Masks */} {/* AI Returned Masks */}
{masks.map((mask) => ( {frameMasks.map((mask) => (
<Group key={mask.id} opacity={0.45}> <Group key={mask.id} opacity={0.45}>
<Path <Path
data={mask.pathData} data={mask.pathData}
@@ -309,7 +365,7 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
<div className="absolute bottom-4 left-4 flex gap-4 text-[10px] font-mono text-gray-500 pointer-events-none"> <div className="absolute bottom-4 left-4 flex gap-4 text-[10px] font-mono text-gray-500 pointer-events-none">
<span>: {cursorPos.x.toFixed(2)}, {cursorPos.y.toFixed(2)}</span> <span>: {cursorPos.x.toFixed(2)}, {cursorPos.y.toFixed(2)}</span>
<span>: {(scale * 100).toFixed(0)}%</span> <span>: {(scale * 100).toFixed(0)}%</span>
<span>: {masks.length}</span> <span>: {frameMasks.length}</span>
</div> </div>
</div> </div>
</div> </div>


@@ -0,0 +1,130 @@
import { fireEvent, render, screen, waitFor } from '@testing-library/react';
import { beforeEach, describe, expect, it, vi } from 'vitest';
import { resetStore } from '../test/storeTestUtils';
import { useStore } from '../store/useStore';
import { CanvasArea } from './CanvasArea';

const apiMock = vi.hoisted(() => ({
  predictMask: vi.fn(),
}));

vi.mock('../lib/api', () => ({
  predictMask: apiMock.predictMask,
}));

describe('CanvasArea', () => {
  const frame = { id: 'frame-1', projectId: 'project-1', index: 0, url: '/frame.jpg', width: 640, height: 360 };

  beforeEach(() => {
    resetStore();
    vi.clearAllMocks();
  });

  it('calls AI prediction with the active frame when a point prompt is placed', async () => {
    useStore.setState({
      activeTemplateId: '2',
      activeClass: { id: 'c1', name: '胆囊', color: '#ff0000', zIndex: 20 },
      activeClassId: 'c1',
    });
    apiMock.predictMask.mockResolvedValueOnce({
      masks: [
        {
          id: 'mask-1',
          pathData: 'M 0 0 L 10 0 L 10 10 Z',
          label: 'AI Mask',
          color: '#06b6d4',
          segmentation: [[0, 0, 10, 0, 10, 10]],
          bbox: [0, 0, 10, 10],
          area: 100,
        },
      ],
    });

    render(<CanvasArea activeTool="point_pos" frame={frame} />);
    fireEvent.click(screen.getByTestId('konva-stage'));

    await waitFor(() => expect(apiMock.predictMask).toHaveBeenCalledWith({
      imageId: 'frame-1',
      imageWidth: 640,
      imageHeight: 360,
      model: 'sam2',
      points: [{ x: 120, y: 80, type: 'pos' }],
      box: undefined,
    }));
    expect(useStore.getState().masks[0]).toEqual(expect.objectContaining({
      id: 'mask-1',
      frameId: 'frame-1',
      pathData: 'M 0 0 L 10 0 L 10 10 Z',
      templateId: '2',
      classId: 'c1',
      className: '胆囊',
      classZIndex: 20,
      label: '胆囊',
      color: '#ff0000',
      saveStatus: 'draft',
    }));
  });

  it('renders only masks that belong to the current frame', () => {
    useStore.setState({
      masks: [
        { id: 'm1', frameId: 'frame-1', pathData: 'M 0 0 Z', label: 'A', color: '#fff' },
        { id: 'm2', frameId: 'frame-2', pathData: 'M 1 1 Z', label: 'B', color: '#000' },
      ],
    });

    render(<CanvasArea activeTool="move" frame={frame} />);

    expect(screen.getAllByTestId('konva-path')).toHaveLength(1);
    expect(screen.getByText('遮罩数: 1')).toBeInTheDocument();
  });

  it('applies the selected class to current-frame masks and marks saved masks dirty', () => {
    useStore.setState({
      activeTemplateId: '2',
      activeClass: { id: 'c1', name: '胆囊', color: '#ff0000', zIndex: 20 },
      activeClassId: 'c1',
      masks: [
        {
          id: 'm1',
          frameId: 'frame-1',
          annotationId: '99',
          pathData: 'M 0 0 Z',
          label: '旧标签',
          color: '#06b6d4',
          saved: true,
          saveStatus: 'saved',
        },
      ],
    });

    render(<CanvasArea activeTool="move" frame={frame} />);
    fireEvent.click(screen.getByRole('button', { name: '应用分类' }));

    expect(useStore.getState().masks[0]).toEqual(expect.objectContaining({
      templateId: '2',
      classId: 'c1',
      className: '胆囊',
      classZIndex: 20,
      label: '胆囊',
      color: '#ff0000',
      saveStatus: 'dirty',
      saved: false,
    }));
  });

  it('delegates clear to the workspace handler so saved annotations can be deleted', () => {
    const onClearMasks = vi.fn();
    useStore.setState({
      masks: [
        { id: 'm1', frameId: 'frame-1', pathData: 'M 0 0 Z', label: 'A', color: '#fff' },
      ],
    });

    render(<CanvasArea activeTool="move" frame={frame} onClearMasks={onClearMasks} />);
    fireEvent.click(screen.getByRole('button', { name: '清空遮罩' }));

    expect(onClearMasks).toHaveBeenCalled();
    expect(useStore.getState().masks).toHaveLength(1);
  });
});


@@ -3,14 +3,15 @@ import { Stage, Layer, Image as KonvaImage, Circle, Rect, Path, Group } from 're
import useImage from 'use-image'; import useImage from 'use-image';
import { useStore } from '../store/useStore'; import { useStore } from '../store/useStore';
import { predictMask } from '../lib/api'; import { predictMask } from '../lib/api';
import { cn } from '../lib/utils'; import type { Frame } from '../store/useStore';
interface CanvasAreaProps { interface CanvasAreaProps {
activeTool: string; activeTool: string;
frameUrl: string; frame: Frame | null;
onClearMasks?: () => void;
} }
export function CanvasArea({ activeTool, frameUrl }: CanvasAreaProps) { export function CanvasArea({ activeTool, frame, onClearMasks }: CanvasAreaProps) {
const containerRef = useRef<HTMLDivElement>(null); const containerRef = useRef<HTMLDivElement>(null);
const [stageSize, setStageSize] = useState({ width: 800, height: 600 }); const [stageSize, setStageSize] = useState({ width: 800, height: 600 });
const [scale, setScale] = useState(1); const [scale, setScale] = useState(1);
@@ -24,13 +25,20 @@ export function CanvasArea({ activeTool, frameUrl }: CanvasAreaProps) {
const masks = useStore((state) => state.masks); const masks = useStore((state) => state.masks);
const addMask = useStore((state) => state.addMask); const addMask = useStore((state) => state.addMask);
const clearMasks = useStore((state) => state.clearMasks); const clearMasks = useStore((state) => state.clearMasks);
const setMasks = useStore((state) => state.setMasks);
const storeActiveTool = useStore((state) => state.activeTool); const storeActiveTool = useStore((state) => state.activeTool);
const setActiveTool = useStore((state) => state.setActiveTool); const aiModel = useStore((state) => state.aiModel);
const activeTemplateId = useStore((state) => state.activeTemplateId);
const activeClass = useStore((state) => state.activeClass);
const effectiveTool = activeTool || storeActiveTool; const effectiveTool = activeTool || storeActiveTool;
// Load the actual frame image // Load the actual frame image
const [image] = useImage(frameUrl || ''); const [image] = useImage(frame?.url || '');
const frameMasks = masks.filter((mask) => mask.frameId === frame?.id);
const savedMaskCount = frameMasks.filter((mask) => mask.saveStatus === 'saved' || mask.saved).length;
const draftMaskCount = frameMasks.filter((mask) => !mask.annotationId).length;
const dirtyMaskCount = frameMasks.filter((mask) => mask.saveStatus === 'dirty').length;
useEffect(() => { useEffect(() => {
const handleResize = () => { const handleResize = () => {
@@ -85,21 +93,44 @@ export function CanvasArea({ activeTool, frameUrl }: CanvasAreaProps) {
}; };
const runInference = useCallback(async (promptPoints?: typeof points, promptBox?: { x1: number, y1: number, x2: number, y2: number }) => { const runInference = useCallback(async (promptPoints?: typeof points, promptBox?: { x1: number, y1: number, x2: number, y2: number }) => {
if (!frame?.id) {
console.warn('Inference skipped: no active frame');
return;
}
const imageWidth = frame.width || image?.naturalWidth || image?.width || 0;
const imageHeight = frame.height || image?.naturalHeight || image?.height || 0;
if (imageWidth <= 0 || imageHeight <= 0) {
console.warn('Inference skipped: active frame dimensions are unavailable');
return;
}
setIsInferencing(true); setIsInferencing(true);
try { try {
const result = await predictMask({ const result = await predictMask({
imageUrl: frameUrl || '', imageId: frame.id,
imageWidth,
imageHeight,
model: aiModel,
points: promptPoints?.map((p) => ({ x: p.x, y: p.y, type: p.type })), points: promptPoints?.map((p) => ({ x: p.x, y: p.y, type: p.type })),
box: promptBox, box: promptBox,
}); });
result.masks.forEach((m) => { result.masks.forEach((m) => {
const label = activeClass?.name || m.label;
const color = activeClass?.color || m.color;
addMask({ addMask({
id: m.id, id: m.id,
frameId: 'frame-1', frameId: frame.id,
templateId: activeTemplateId || undefined,
classId: activeClass?.id,
className: activeClass?.name,
classZIndex: activeClass?.zIndex,
saveStatus: 'draft',
saved: false,
pathData: m.pathData, pathData: m.pathData,
label: m.label, label,
color: m.color, color,
segmentation: m.segmentation, segmentation: m.segmentation,
bbox: m.bbox, bbox: m.bbox,
area: m.area, area: m.area,
@@ -110,7 +141,33 @@ export function CanvasArea({ activeTool, frameUrl }: CanvasAreaProps) {
} finally { } finally {
setIsInferencing(false); setIsInferencing(false);
} }
}, [addMask]); }, [activeClass, activeTemplateId, addMask, aiModel, frame?.height, frame?.id, frame?.width, image?.height, image?.naturalHeight, image?.naturalWidth, image?.width]);
const handleApplyActiveClass = () => {
if (!frame?.id || !activeClass) return;
setMasks(masks.map((mask) => {
if (mask.frameId !== frame.id) return mask;
return {
...mask,
templateId: activeTemplateId || mask.templateId,
classId: activeClass.id,
className: activeClass.name,
classZIndex: activeClass.zIndex,
label: activeClass.name,
color: activeClass.color,
saveStatus: mask.annotationId ? 'dirty' : 'draft',
saved: Boolean(mask.annotationId) ? false : mask.saved,
};
}));
};
const handleClearMasks = () => {
if (onClearMasks) {
onClearMasks();
return;
}
clearMasks();
};
const handleStageMouseDown = (e: any) => { const handleStageMouseDown = (e: any) => {
if (effectiveTool === 'box_select') { if (effectiveTool === 'box_select') {
@@ -199,7 +256,7 @@ export function CanvasArea({ activeTool, frameUrl }: CanvasAreaProps) {
)} )}
{/* AI Returned Masks */} {/* AI Returned Masks */}
{masks.map((mask) => ( {frameMasks.map((mask) => (
<Group key={mask.id} opacity={0.5}> <Group key={mask.id} opacity={0.5}>
<Path <Path
data={mask.pathData} data={mask.pathData}
@@ -248,16 +305,29 @@ export function CanvasArea({ activeTool, frameUrl }: CanvasAreaProps) {
<span>: {cursorPos.x.toFixed(2)}, {cursorPos.y.toFixed(2)}</span> <span>: {cursorPos.x.toFixed(2)}, {cursorPos.y.toFixed(2)}</span>
<span>当前图层树: OBJECT_VEHICLE_01</span> <span>当前图层树: OBJECT_VEHICLE_01</span>
<span>: {(scale * 100).toFixed(0)}%</span> <span>: {(scale * 100).toFixed(0)}%</span>
<span>: {masks.length}</span> <span>: {frameMasks.length}</span>
<span>: {savedMaskCount}</span>
<span>: {draftMaskCount}</span>
<span>: {dirtyMaskCount}</span>
</div> </div>
{masks.length > 0 && ( {frameMasks.length > 0 && (
<button <div className="absolute bottom-4 right-4 flex gap-2">
onClick={clearMasks} {activeClass && (
className="absolute bottom-4 right-4 text-xs bg-red-500/10 hover:bg-red-500/20 text-red-400 border border-red-500/20 px-3 py-1.5 rounded transition-colors" <button
> onClick={handleApplyActiveClass}
className="text-xs bg-cyan-500/10 hover:bg-cyan-500/20 text-cyan-300 border border-cyan-500/20 px-3 py-1.5 rounded transition-colors"
</button> >
</button>
)}
<button
onClick={handleClearMasks}
className="text-xs bg-red-500/10 hover:bg-red-500/20 text-red-400 border border-red-500/20 px-3 py-1.5 rounded transition-colors"
>
</button>
</div>
)} )}
</div> </div>
); );
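
Review note: the save-status transition in `handleApplyActiveClass` can be read in isolation as a pure function. The sketch below is a minimal restatement under assumptions: `MaskLike` and `ClassLike` are hypothetical, trimmed-down shapes (the real types live in `../store/useStore`), and `applyClassToFrameMasks` is an illustrative name, not a function in this commit.

```typescript
// Hypothetical, reduced shapes for illustration only.
type SaveStatus = 'draft' | 'saved' | 'dirty';

interface MaskLike {
  id: string;
  frameId: string;
  annotationId?: string;
  templateId?: string;
  classId?: string;
  className?: string;
  classZIndex?: number;
  label: string;
  color: string;
  saveStatus?: SaveStatus;
  saved?: boolean;
}

interface ClassLike {
  id: string;
  name: string;
  color: string;
  zIndex: number;
}

// Re-labels every mask on the given frame with the active class.
// Masks that already have a backend annotation become 'dirty' (they need an update);
// masks that were never saved stay 'draft'. Masks on other frames are untouched.
function applyClassToFrameMasks(
  masks: MaskLike[],
  frameId: string,
  activeClass: ClassLike,
  templateId: string | null,
): MaskLike[] {
  return masks.map((mask): MaskLike => {
    if (mask.frameId !== frameId) return mask;
    const persisted = Boolean(mask.annotationId);
    return {
      ...mask,
      templateId: templateId || mask.templateId,
      classId: activeClass.id,
      className: activeClass.name,
      classZIndex: activeClass.zIndex,
      label: activeClass.name,
      color: activeClass.color,
      saveStatus: persisted ? 'dirty' : 'draft',
      saved: persisted ? false : mask.saved,
    };
  });
}
```

Keeping the transition pure like this would make the dirty/draft rule testable without rendering the canvas, though the component test in this commit already covers the persisted case through the UI.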


@@ -0,0 +1,115 @@
import { act, render, screen, waitFor } from '@testing-library/react';
import { beforeEach, describe, expect, it, vi } from 'vitest';
import { Dashboard } from './Dashboard';
const apiMock = vi.hoisted(() => ({
getDashboardOverview: vi.fn(),
}));
const wsMock = vi.hoisted(() => {
const state = {
callback: undefined as undefined | ((data: any) => void),
connected: false,
};
return {
state,
progressWS: {
connect: vi.fn(() => { state.connected = true; }),
disconnect: vi.fn(() => { state.connected = false; }),
isConnected: vi.fn(() => state.connected),
onProgress: vi.fn((cb: (data: any) => void) => {
state.callback = cb;
return vi.fn();
}),
},
};
});
vi.mock('../lib/websocket', () => ({
progressWS: wsMock.progressWS,
}));
vi.mock('../lib/api', () => ({
getDashboardOverview: apiMock.getDashboardOverview,
}));
describe('Dashboard', () => {
beforeEach(() => {
vi.useRealTimers();
vi.clearAllMocks();
wsMock.state.connected = false;
wsMock.state.callback = undefined;
apiMock.getDashboardOverview.mockResolvedValue({
summary: {
project_count: 2,
parsing_task_count: 1,
annotation_count: 5,
frame_count: 100,
template_count: 3,
system_load_percent: 12,
},
tasks: [
{
id: 'project-1',
project_id: 1,
name: '真实项目.mp4',
progress: 60,
status: 'pending',
frame_count: 10,
updated_at: '2026-05-01T00:00:00Z',
},
],
activity: [
{
id: 'activity-1',
kind: 'project',
time: '2026-05-01T00:00:00Z',
message: '项目状态: pending',
project: '真实项目.mp4',
},
],
});
});
it('loads dashboard stats, tasks, and activity from the backend overview endpoint', async () => {
render(<Dashboard />);
await waitFor(() => expect(apiMock.getDashboardOverview).toHaveBeenCalled());
expect(screen.getByText('项目总数')).toBeInTheDocument();
expect(screen.getByText('已存标注')).toBeInTheDocument();
expect(screen.getByText('真实项目.mp4')).toBeInTheDocument();
expect(screen.getByText('项目状态: pending')).toBeInTheDocument();
expect(screen.queryByText('City_Driving_Dataset_004.mp4')).not.toBeInTheDocument();
});
it('connects to the progress stream and updates progress tasks', async () => {
render(<Dashboard />);
await waitFor(() => expect(wsMock.progressWS.connect).toHaveBeenCalled());
act(() => {
wsMock.state.callback?.({
type: 'progress',
taskId: 'task-1',
projectName: 'demo.mp4',
progress: 44,
status: '正在截取帧',
});
});
expect(await screen.findByText('demo.mp4')).toBeInTheDocument();
expect(screen.getByText('44%')).toBeInTheDocument();
});
it('adds activity logs for complete and status messages', async () => {
render(<Dashboard />);
act(() => {
wsMock.state.callback?.({ type: 'status', message: 'Progress stream active' });
wsMock.state.callback?.({ type: 'complete', taskId: '1', filename: 'done.mp4' });
});
await waitFor(() => expect(screen.getByText('Progress stream active')).toBeInTheDocument());
expect(screen.getByText('解析完成: done.mp4')).toBeInTheDocument();
});
});


@@ -2,30 +2,68 @@ import React, { useState, useEffect } from 'react';
import { Activity, Clock, Folders, CheckCircle2, Loader2 } from 'lucide-react'; import { Activity, Clock, Folders, CheckCircle2, Loader2 } from 'lucide-react';
import { progressWS, type ProgressMessage } from '../lib/websocket'; import { progressWS, type ProgressMessage } from '../lib/websocket';
import { cn } from '../lib/utils'; import { cn } from '../lib/utils';
import { getDashboardOverview, type DashboardActivity, type DashboardOverview, type DashboardTask } from '../lib/api';
interface QueueTask { const emptySummary: DashboardOverview['summary'] = {
id: string; project_count: 0,
name: string; parsing_task_count: 0,
progress: number; annotation_count: 0,
status: string; frame_count: 0,
} template_count: 0,
system_load_percent: 0,
};
export function Dashboard() { export function Dashboard() {
const [tasks, setTasks] = useState<QueueTask[]>([ const [summary, setSummary] = useState<DashboardOverview['summary']>(emptySummary);
{ id: '1', name: 'City_Driving_Dataset_004.mp4', progress: 85, status: '正在截取帧 (30fps)' }, const [tasks, setTasks] = useState<DashboardTask[]>([]);
{ id: '2', name: 'Pedestrian_Night_Vision_02.mkv', progress: 32, status: '正在截取帧 (60fps)' },
{ id: '3', name: 'Drone_Mapping_Sector_7.avi', progress: 0, status: '队列排队等待中' },
]);
const [isConnected, setIsConnected] = useState(false); const [isConnected, setIsConnected] = useState(false);
const [activityLog, setActivityLog] = useState<Array<{ time: string; message: string; project?: string }>>([ const [activityLog, setActivityLog] = useState<DashboardActivity[]>([]);
{ time: '10 分钟前', message: '语义归档完成 54 帧', project: 'Highway_Data' }, const [isLoading, setIsLoading] = useState(true);
{ time: '25 分钟前', message: '项目解析开始', project: 'City_Driving_Dataset_004' }, const [loadError, setLoadError] = useState('');
{ time: '1 小时前', message: '模板库更新: Cityscapes_v2', project: '系统' },
{ time: '2 小时前', message: 'AI 推理完成 12 个实例', project: 'Nav_Cam_Left' }, useEffect(() => {
]); let cancelled = false;
const loadOverview = () => {
getDashboardOverview()
.then((overview) => {
if (cancelled) return;
setSummary(overview.summary);
setTasks((prev) => {
if (prev.length === 0) return overview.tasks;
const overviewIds = new Set(overview.tasks.map((task) => task.id));
const wsOnly = prev.filter((task) => !task.id.startsWith('task-') && !overviewIds.has(task.id) && task.progress < 100);
return [...overview.tasks, ...wsOnly];
});
setActivityLog((prev) => {
if (prev.length === 0) return overview.activity;
const byId = new Map(prev.map((item) => [item.id, item]));
overview.activity.forEach((item) => byId.set(item.id, item));
return Array.from(byId.values()).slice(0, 10);
});
setLoadError('');
})
.catch((err) => {
console.error('Failed to load dashboard overview:', err);
if (!cancelled) setLoadError('Dashboard 数据加载失败');
})
.finally(() => {
if (!cancelled) setIsLoading(false);
});
};
loadOverview();
const overviewInterval = setInterval(loadOverview, 5000);
return () => {
cancelled = true;
clearInterval(overviewInterval);
};
}, []);
useEffect(() => { useEffect(() => {
let mounted = true; let mounted = true;
const taskTitle = (data: ProgressMessage) => data.filename || data.projectName || data.taskId || '后台任务';
const timer = setTimeout(() => { const timer = setTimeout(() => {
if (mounted) progressWS.connect(); if (mounted) progressWS.connect();
}, 500); }, 500);
@@ -34,7 +72,7 @@ export function Dashboard() {
if (!mounted) return; if (!mounted) return;
setIsConnected(progressWS.isConnected()); setIsConnected(progressWS.isConnected());
if (data.type === 'progress' && data.taskId && data.filename) { if (data.type === 'progress' && data.taskId) {
setTasks((prev) => { setTasks((prev) => {
const exists = prev.find((t) => t.id === data.taskId); const exists = prev.find((t) => t.id === data.taskId);
if (exists) { if (exists) {
@@ -48,9 +86,12 @@ export function Dashboard() {
...prev, ...prev,
{ {
id: data.taskId!, id: data.taskId!,
name: data.filename!, project_id: data.project_id ?? Number(data.task_id || 0),
name: taskTitle(data),
progress: data.progress ?? 0, progress: data.progress ?? 0,
status: data.status ?? '处理中', status: data.status ?? '处理中',
frame_count: 0,
updated_at: new Date().toISOString(),
}, },
]; ];
}); });
@@ -63,7 +104,7 @@ export function Dashboard() {
) )
); );
setActivityLog((prev) => [ setActivityLog((prev) => [
{ time: '刚刚', message: `解析完成: ${data.filename || data.taskId}`, project: '系统' }, { id: `ws-complete-${Date.now()}`, kind: 'websocket', time: new Date().toISOString(), message: data.message || `解析完成: ${taskTitle(data)}`, project: data.projectName || '系统' },
...prev.slice(0, 9), ...prev.slice(0, 9),
]); ]);
} }
@@ -71,14 +112,18 @@ export function Dashboard() {
if (data.type === 'error' && data.taskId) { if (data.type === 'error' && data.taskId) {
setTasks((prev) => setTasks((prev) =>
prev.map((t) => prev.map((t) =>
t.id === data.taskId ? { ...t, status: `错误: ${data.message || '未知错误'}` } : t t.id === data.taskId ? { ...t, progress: data.progress ?? t.progress, status: `错误: ${data.error || data.message || '未知错误'}` } : t
) )
); );
setActivityLog((prev) => [
{ id: `ws-error-${Date.now()}`, kind: 'websocket', time: new Date().toISOString(), message: data.message || `解析失败: ${taskTitle(data)}`, project: data.projectName || '系统' },
...prev.slice(0, 9),
]);
} }
if (data.type === 'status') { if (data.type === 'status') {
setActivityLog((prev) => [ setActivityLog((prev) => [
{ time: '刚刚', message: data.message || '状态更新', project: '系统' }, { id: `ws-status-${Date.now()}`, kind: 'websocket', time: new Date().toISOString(), message: data.message || '状态更新', project: '系统' },
...prev.slice(0, 9), ...prev.slice(0, 9),
]); ]);
} }
@@ -97,12 +142,24 @@ export function Dashboard() {
}, []); }, []);
const stats = [ const stats = [
{ label: '运行中项目', value: '14', icon: Folders, color: 'text-blue-400', bg: 'bg-blue-400/10' }, { label: '项目总数', value: summary.project_count.toString(), icon: Folders, color: 'text-blue-400', bg: 'bg-blue-400/10' },
{ label: '排队处理任务', value: tasks.length.toString(), icon: Clock, color: 'text-orange-400', bg: 'bg-orange-400/10' }, { label: '处理任务', value: summary.parsing_task_count.toString(), icon: Clock, color: 'text-orange-400', bg: 'bg-orange-400/10' },
{ label: '已归档批次', value: '128', icon: CheckCircle2, color: 'text-emerald-400', bg: 'bg-emerald-400/10' }, { label: '已存标注', value: summary.annotation_count.toString(), icon: CheckCircle2, color: 'text-emerald-400', bg: 'bg-emerald-400/10' },
{ label: '系统负载', value: '78%', icon: Activity, color: 'text-cyan-400', bg: 'bg-cyan-400/10' }, { label: '系统负载', value: `${summary.system_load_percent}%`, icon: Activity, color: 'text-cyan-400', bg: 'bg-cyan-400/10' },
]; ];
function formatActivityTime(value: string | null): string {
if (!value) return '未知时间';
const date = new Date(value);
if (Number.isNaN(date.getTime())) return value;
return date.toLocaleString('zh-CN', {
month: '2-digit',
day: '2-digit',
hour: '2-digit',
minute: '2-digit',
});
}
return ( return (
<div className="p-8 w-full h-full overflow-y-auto bg-[#0a0a0a]"> <div className="p-8 w-full h-full overflow-y-auto bg-[#0a0a0a]">
<header className="mb-8"> <header className="mb-8">
@@ -119,6 +176,7 @@ export function Dashboard() {
</div> </div>
</div> </div>
<p className="text-gray-400 text-sm mt-1"></p> <p className="text-gray-400 text-sm mt-1"></p>
{loadError && <p className="text-red-400 text-xs mt-2">{loadError}</p>}
</header> </header>
<div className="grid grid-cols-1 md:grid-cols-2 lg:grid-cols-4 gap-4 mb-8"> <div className="grid grid-cols-1 md:grid-cols-2 lg:grid-cols-4 gap-4 mb-8">
@@ -140,8 +198,11 @@ export function Dashboard() {
<div className="grid grid-cols-1 lg:grid-cols-3 gap-6"> <div className="grid grid-cols-1 lg:grid-cols-3 gap-6">
<div className="lg:col-span-2 bg-[#111] border border-white/5 rounded-xl p-6 min-h-[400px]"> <div className="lg:col-span-2 bg-[#111] border border-white/5 rounded-xl p-6 min-h-[400px]">
<h2 className="text-sm font-medium text-gray-400 uppercase tracking-widest mb-6"> (FFmpeg )</h2> <h2 className="text-sm font-medium text-gray-400 uppercase tracking-widest mb-6"> ()</h2>
<div className="space-y-4"> <div className="space-y-4">
{isLoading && (
<div className="text-sm text-gray-500 text-center py-12"> Dashboard ...</div>
)}
{tasks.map((task) => ( {tasks.map((task) => (
<div key={task.id} className="bg-[#0d0d0d] border border-white/5 p-4 rounded-lg"> <div key={task.id} className="bg-[#0d0d0d] border border-white/5 p-4 rounded-lg">
<div className="flex justify-between items-center mb-2"> <div className="flex justify-between items-center mb-2">
@@ -152,7 +213,7 @@ export function Dashboard() {
<div className="h-full bg-gradient-to-r from-cyan-600 to-cyan-400 rounded-full transition-all duration-500" style={{ width: `${task.progress}%` }} /> <div className="h-full bg-gradient-to-r from-cyan-600 to-cyan-400 rounded-full transition-all duration-500" style={{ width: `${task.progress}%` }} />
</div> </div>
<div className="text-xs text-gray-500 flex items-center gap-2"> <div className="text-xs text-gray-500 flex items-center gap-2">
{task.status === '已完成' ? ( {task.status === '已完成' || task.progress >= 100 ? (
<CheckCircle2 size={12} className="text-emerald-400" /> <CheckCircle2 size={12} className="text-emerald-400" />
) : task.status.includes('错误') ? ( ) : task.status.includes('错误') ? (
<span className="text-red-400"></span> <span className="text-red-400"></span>
@@ -160,10 +221,11 @@ export function Dashboard() {
<Loader2 size={12} className="text-cyan-400 animate-spin" /> <Loader2 size={12} className="text-cyan-400 animate-spin" />
)} )}
{task.status} {task.status}
<span className="text-gray-600">: {task.frame_count}</span>
</div> </div>
</div> </div>
))} ))}
{tasks.length === 0 && ( {!isLoading && tasks.length === 0 && (
<div className="text-sm text-gray-500 text-center py-12"></div> <div className="text-sm text-gray-500 text-center py-12"></div>
)} )}
</div> </div>
@@ -172,16 +234,22 @@ export function Dashboard() {
<div className="bg-[#111] border border-white/5 rounded-xl p-6 min-h-[400px]"> <div className="bg-[#111] border border-white/5 rounded-xl p-6 min-h-[400px]">
<h2 className="text-sm font-medium text-gray-400 uppercase tracking-widest mb-6"></h2> <h2 className="text-sm font-medium text-gray-400 uppercase tracking-widest mb-6"></h2>
<div className="space-y-6 relative before:absolute before:inset-0 before:ml-[11px] before:-translate-x-px md:before:mx-auto md:before:translate-x-0 before:h-full before:w-0.5 before:bg-gradient-to-b before:from-transparent before:via-white/10 before:to-transparent"> <div className="space-y-6 relative before:absolute before:inset-0 before:ml-[11px] before:-translate-x-px md:before:mx-auto md:before:translate-x-0 before:h-full before:w-0.5 before:bg-gradient-to-b before:from-transparent before:via-white/10 before:to-transparent">
{activityLog.map((log, i) => ( {isLoading && (
<div key={i} className="relative flex items-center justify-between md:justify-normal md:odd:flex-row-reverse group is-active"> <div className="text-sm text-gray-500 text-center py-12">...</div>
)}
{activityLog.map((log) => (
<div key={log.id} className="relative flex items-center justify-between md:justify-normal md:odd:flex-row-reverse group is-active">
<div className="flex items-center justify-center w-6 h-6 rounded-full border border-white/10 bg-[#111] group-[.is-active]:bg-cyan-500 group-[.is-active]:border-cyan-400 text-slate-500 group-[.is-active]:text-black shadow shrink-0 md:order-1 md:group-odd:-translate-x-1/2 md:group-even:translate-x-1/2 z-10" /> <div className="flex items-center justify-center w-6 h-6 rounded-full border border-white/10 bg-[#111] group-[.is-active]:bg-cyan-500 group-[.is-active]:border-cyan-400 text-slate-500 group-[.is-active]:text-black shadow shrink-0 md:order-1 md:group-odd:-translate-x-1/2 md:group-even:translate-x-1/2 z-10" />
<div className="w-[calc(100%-4rem)] md:w-[calc(50%-2.5rem)] bg-[#0d0d0d] p-3 rounded border border-white/5"> <div className="w-[calc(100%-4rem)] md:w-[calc(50%-2.5rem)] bg-[#0d0d0d] p-3 rounded border border-white/5">
<div className="text-xs text-gray-400 mb-1">{log.time}</div> <div className="text-xs text-gray-400 mb-1">{formatActivityTime(log.time)}</div>
<div className="text-sm font-medium text-gray-200">{log.message}</div> <div className="text-sm font-medium text-gray-200">{log.message}</div>
<div className="text-xs text-gray-500">: {log.project}</div> <div className="text-xs text-gray-500">: {log.project}</div>
</div> </div>
</div> </div>
))} ))}
{!isLoading && activityLog.length === 0 && (
<div className="text-sm text-gray-500 text-center py-12"></div>
)}
</div> </div>
</div> </div>
</div> </div>
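
Review note: the activity-log merge inside `loadOverview` relies on `Map` insertion order. The following is a minimal sketch of that merge as a standalone function, assuming a hypothetical `ActivityItem` shape and an illustrative helper name `mergeActivity` (neither is defined in this commit).

```typescript
// Hypothetical shape; the real DashboardActivity type is exported from ../lib/api.
interface ActivityItem {
  id: string;
  kind: string;
  time: string | null;
  message: string;
  project?: string;
}

// Merges freshly fetched activity into the existing log, de-duplicating by id.
// A Map keeps the position of a key's first insertion, so existing entries stay
// where they are (with updated content) and unseen entries are appended.
function mergeActivity(prev: ActivityItem[], incoming: ActivityItem[], limit = 10): ActivityItem[] {
  if (prev.length === 0) return incoming;
  const byId = new Map(prev.map((item): [string, ActivityItem] => [item.id, item]));
  incoming.forEach((item) => byId.set(item.id, item));
  return Array.from(byId.values()).slice(0, limit);
}
```

One consequence worth checking during review: because new entries are appended and the result is cut to the first `limit` items, a log that is already full may never show newly fetched entries. Whether that matters depends on how the backend orders `activity`, which this diff does not show.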


@@ -0,0 +1,62 @@
import { act, fireEvent, render, screen } from '@testing-library/react';
import { beforeEach, describe, expect, it, vi } from 'vitest';
import { resetStore } from '../test/storeTestUtils';
import { useStore } from '../store/useStore';
import { FrameTimeline } from './FrameTimeline';
describe('FrameTimeline', () => {
beforeEach(() => {
resetStore();
vi.useRealTimers();
});
it('renders empty state when no frames are loaded', () => {
render(<FrameTimeline />);
expect(screen.getByText('暂无帧数据')).toBeInTheDocument();
expect(screen.getByText('0')).toBeInTheDocument();
});
it('changes the current frame through thumbnails and range input', () => {
useStore.setState({
frames: [
{ id: 'f1', projectId: 'p1', index: 0, url: '/1.jpg', width: 640, height: 360 },
{ id: 'f2', projectId: 'p1', index: 1, url: '/2.jpg', width: 640, height: 360 },
{ id: 'f3', projectId: 'p1', index: 2, url: '/3.jpg', width: 640, height: 360 },
],
});
render(<FrameTimeline />);
fireEvent.click(screen.getByAltText('frame-1'));
expect(useStore.getState().currentFrameIndex).toBe(1);
fireEvent.change(screen.getByRole('slider'), { target: { value: '3' } });
expect(useStore.getState().currentFrameIndex).toBe(2);
});
it('plays forward using the project parse fps and stops at the end', () => {
vi.useFakeTimers();
useStore.setState({
currentProject: { id: 'p1', name: 'P', status: 'ready', parse_fps: 10 },
frames: [
{ id: 'f1', projectId: 'p1', index: 0, url: '/1.jpg', width: 640, height: 360 },
{ id: 'f2', projectId: 'p1', index: 1, url: '/2.jpg', width: 640, height: 360 },
],
});
const { container } = render(<FrameTimeline />);
fireEvent.click(container.querySelector('button')!);
act(() => {
vi.advanceTimersByTime(100);
});
expect(useStore.getState().currentFrameIndex).toBe(1);
act(() => {
vi.advanceTimersByTime(100);
});
expect(screen.getByText('播放序列 (F5)')).toBeInTheDocument();
});
});


@@ -1,16 +1,42 @@
import React, { useState } from 'react'; import React, { useEffect, useMemo, useState } from 'react';
import { Play, Pause } from 'lucide-react'; import { Play, Pause } from 'lucide-react';
import { cn } from '../lib/utils'; import { cn } from '../lib/utils';
import { useStore } from '../store/useStore'; import { useStore } from '../store/useStore';
export function FrameTimeline() { export function FrameTimeline() {
const frames = useStore((state) => state.frames); const frames = useStore((state) => state.frames);
const currentProject = useStore((state) => state.currentProject);
const currentFrameIndex = useStore((state) => state.currentFrameIndex); const currentFrameIndex = useStore((state) => state.currentFrameIndex);
const setCurrentFrame = useStore((state) => state.setCurrentFrame); const setCurrentFrame = useStore((state) => state.setCurrentFrame);
const [isPlaying, setIsPlaying] = useState(false); const [isPlaying, setIsPlaying] = useState(false);
const totalFrames = frames.length; const totalFrames = frames.length;
const currentFrame = totalFrames > 0 ? currentFrameIndex + 1 : 0; const currentFrame = totalFrames > 0 ? currentFrameIndex + 1 : 0;
const playbackFps = useMemo(() => {
const fps = currentProject?.parse_fps || currentProject?.original_fps || 12;
return Math.min(Math.max(fps, 1), 30);
}, [currentProject?.original_fps, currentProject?.parse_fps]);
useEffect(() => {
if (!isPlaying || totalFrames <= 1) return;
const timer = window.setTimeout(() => {
if (currentFrameIndex >= totalFrames - 1) {
setIsPlaying(false);
return;
}
setCurrentFrame(currentFrameIndex + 1);
}, 1000 / playbackFps);
return () => window.clearTimeout(timer);
}, [currentFrameIndex, isPlaying, playbackFps, setCurrentFrame, totalFrames]);
useEffect(() => {
if (totalFrames === 0) {
setIsPlaying(false);
}
}, [totalFrames]);
// show frames around current frame // show frames around current frame
const frameWindow = 20; const frameWindow = 20;
@@ -45,8 +71,14 @@ export function FrameTimeline() {
<div className="flex-1 flex items-center px-4 gap-6"> <div className="flex-1 flex items-center px-4 gap-6">
<div className="flex flex-col items-center gap-2 px-4 border-r border-white/10 shrink-0"> <div className="flex flex-col items-center gap-2 px-4 border-r border-white/10 shrink-0">
<button <button
className="p-2 rounded-full bg-white/5 text-white hover:bg-white/10" className="p-2 rounded-full bg-white/5 text-white hover:bg-white/10 disabled:opacity-40 disabled:cursor-not-allowed"
onClick={() => setIsPlaying(!isPlaying)} disabled={totalFrames <= 1}
onClick={() => {
if (currentFrameIndex >= totalFrames - 1) {
setCurrentFrame(0);
}
setIsPlaying(!isPlaying);
}}
> >
{isPlaying ? <Pause size={20} fill="currentColor" /> : <Play size={20} fill="currentColor" />} {isPlaying ? <Pause size={20} fill="currentColor" /> : <Play size={20} fill="currentColor" />}
</button> </button>
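
Review note: the playback rate in `FrameTimeline` comes from a fallback chain plus a clamp. A minimal sketch of that calculation follows, assuming a hypothetical `ProjectFps` shape and illustrative helper names (`resolvePlaybackFps`, `frameDelayMs`); the constants mirror the literals in the diff.

```typescript
// Hypothetical minimal project shape; only the two fps fields read by the timeline.
interface ProjectFps {
  parse_fps?: number;
  original_fps?: number;
}

const DEFAULT_FPS = 12;
const MIN_FPS = 1;
const MAX_FPS = 30;

// Prefers the parse fps, then the source fps, then a default.
// Note that `||` treats 0 as missing, so a parse_fps of 0 falls through.
function resolvePlaybackFps(project?: ProjectFps | null): number {
  const fps = project?.parse_fps || project?.original_fps || DEFAULT_FPS;
  return Math.min(Math.max(fps, MIN_FPS), MAX_FPS);
}

// Delay between frame advances, as passed to setTimeout in the component.
function frameDelayMs(project?: ProjectFps | null): number {
  return 1000 / resolvePlaybackFps(project);
}
```

With `parse_fps: 10` the delay is 100 ms, which is the step the fake-timer test advances by.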


@@ -0,0 +1,42 @@
import { fireEvent, render, screen, waitFor } from '@testing-library/react';
import { beforeEach, describe, expect, it, vi } from 'vitest';
import { resetStore } from '../test/storeTestUtils';
import { useStore } from '../store/useStore';
import { Login } from './Login';
const apiMock = vi.hoisted(() => ({
login: vi.fn(),
}));
vi.mock('../lib/api', () => ({
login: apiMock.login,
}));
describe('Login', () => {
beforeEach(() => {
resetStore();
vi.clearAllMocks();
});
it('logs in with the development credentials and stores the token', async () => {
apiMock.login.mockResolvedValueOnce({ token: 'fake-jwt-token-for-admin' });
render(<Login />);
fireEvent.click(screen.getByRole('button', { name: '安全登录' }));
await waitFor(() => expect(apiMock.login).toHaveBeenCalledWith('admin', '123456'));
expect(useStore.getState().isAuthenticated).toBe(true);
expect(localStorage.getItem('token')).toBe('fake-jwt-token-for-admin');
});
it('shows backend login errors', async () => {
apiMock.login.mockRejectedValueOnce({ response: { data: { detail: 'Invalid credentials' } } });
render(<Login />);
fireEvent.change(screen.getByDisplayValue('admin'), { target: { value: 'bad' } });
fireEvent.click(screen.getByRole('button', { name: '安全登录' }));
expect(await screen.findByText('Invalid credentials')).toBeInTheDocument();
expect(useStore.getState().isAuthenticated).toBe(false);
});
});


@@ -0,0 +1,45 @@
import { render, screen, waitFor } from '@testing-library/react';
import { beforeEach, describe, expect, it, vi } from 'vitest';
import { resetStore } from '../test/storeTestUtils';
import { useStore } from '../store/useStore';
import { ModelStatusBadge } from './ModelStatusBadge';
const apiMock = vi.hoisted(() => ({
getAiModelStatus: vi.fn(),
}));
vi.mock('../lib/api', () => ({
getAiModelStatus: apiMock.getAiModelStatus,
}));
describe('ModelStatusBadge', () => {
beforeEach(() => {
resetStore();
vi.clearAllMocks();
apiMock.getAiModelStatus.mockResolvedValue({
selected_model: 'sam2',
gpu: { available: true, device: 'cuda', name: 'RTX 4090', torch_available: true },
models: [
{ id: 'sam2', label: 'SAM 2', available: true, loaded: false, device: 'cuda', supports: ['point', 'box'], message: 'SAM 2 ready', package_available: true, checkpoint_exists: true, python_ok: true, torch_ok: true, cuda_required: false },
{ id: 'sam3', label: 'SAM 3', available: false, loaded: false, device: 'cuda', supports: ['semantic'], message: 'SAM 3 missing runtime', package_available: false, checkpoint_exists: false, python_ok: false, torch_ok: true, cuda_required: true },
],
});
});
it('loads real model status for the selected model', async () => {
render(<ModelStatusBadge />);
expect(await screen.findByText('SAM 2 可用')).toBeInTheDocument();
expect(apiMock.getAiModelStatus).toHaveBeenCalledWith('sam2');
});
it('shows unavailable state when SAM3 is selected but not runnable', async () => {
useStore.getState().setAiModel('sam3');
render(<ModelStatusBadge />);
await waitFor(() => expect(apiMock.getAiModelStatus).toHaveBeenCalledWith('sam3'));
expect(await screen.findByText('SAM 3 不可用')).toBeInTheDocument();
expect(screen.getByTitle('SAM 3 missing runtime')).toBeInTheDocument();
});
});


@@ -0,0 +1,56 @@
import React, { useEffect, useState } from 'react';
import { Cpu, Loader2 } from 'lucide-react';
import { getAiModelStatus, type AiRuntimeStatus } from '../lib/api';
import { cn } from '../lib/utils';
import { useStore } from '../store/useStore';
interface ModelStatusBadgeProps {
compact?: boolean;
}
export function ModelStatusBadge({ compact = false }: ModelStatusBadgeProps) {
const aiModel = useStore((state) => state.aiModel);
const [status, setStatus] = useState<AiRuntimeStatus | null>(null);
const [isLoading, setIsLoading] = useState(true);
useEffect(() => {
let cancelled = false;
setIsLoading(true);
getAiModelStatus(aiModel)
.then((data) => {
if (!cancelled) setStatus(data);
})
.catch(() => {
if (!cancelled) setStatus(null);
})
.finally(() => {
if (!cancelled) setIsLoading(false);
});
return () => {
cancelled = true;
};
}, [aiModel]);
const model = status?.models.find((item) => item.id === aiModel);
const ready = Boolean(model?.available);
const gpuReady = Boolean(status?.gpu.available);
const label = compact
? (gpuReady ? 'GPU' : 'CPU')
: `${model?.label || aiModel.toUpperCase()} ${ready ? '可用' : '不可用'}`;
return (
<div
className={cn(
"inline-flex items-center gap-1.5 rounded border font-mono uppercase",
compact ? "w-8 h-8 justify-center text-[9px]" : "px-2 py-0.5 text-[10px]",
ready
? "bg-emerald-500/10 text-emerald-400 border-emerald-500/20"
: "bg-amber-500/10 text-amber-400 border-amber-500/20"
)}
title={model?.message || 'AI 模型状态读取中'}
>
{isLoading ? <Loader2 size={compact ? 12 : 10} className="animate-spin" /> : <Cpu size={compact ? 12 : 10} />}
<span>{label}</span>
</div>
);
}
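
Review note: the badge text in `ModelStatusBadge` is derived from three inputs (runtime status, selected model id, compact flag). A minimal sketch of that derivation as a pure function is shown below, assuming hypothetical reduced types `ModelInfo` and `RuntimeStatus` and an illustrative name `badgeLabel`; the real `AiRuntimeStatus` type comes from `../lib/api`.

```typescript
// Hypothetical reduced shapes; only the fields the label depends on.
interface ModelInfo {
  id: string;
  label?: string;
  available: boolean;
}

interface RuntimeStatus {
  gpu: { available: boolean };
  models: ModelInfo[];
}

// Compact mode shows only the device class; full mode shows the model label
// plus an availability suffix. A null status (load failed or still pending)
// is treated as "not available" and falls back to the upper-cased model id.
function badgeLabel(status: RuntimeStatus | null, modelId: string, compact = false): string {
  const model = status?.models.find((item) => item.id === modelId);
  const ready = Boolean(model?.available);
  const gpuReady = Boolean(status?.gpu.available);
  if (compact) return gpuReady ? 'GPU' : 'CPU';
  return `${model?.label || modelId.toUpperCase()} ${ready ? '可用' : '不可用'}`;
}
```

Note that a failed status request and a genuinely unavailable model produce the same label here, as they do in the component; only the `title` text differs.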


@@ -0,0 +1,60 @@
import { fireEvent, render, screen, within } from '@testing-library/react';
import { beforeEach, describe, expect, it } from 'vitest';
import { resetStore } from '../test/storeTestUtils';
import { useStore } from '../store/useStore';
import { OntologyInspector } from './OntologyInspector';
describe('OntologyInspector', () => {
beforeEach(() => {
resetStore();
useStore.setState({
templates: [
{
id: 't1',
name: '腹腔镜模板',
classes: [
{ id: 'c1', name: '胆囊', color: '#ff0000', zIndex: 20, category: '器官' },
{ id: 'c2', name: '肝脏', color: '#00ff00', zIndex: 10, category: '器官' },
],
rules: [],
},
],
});
});
it('shows template classes and changes the active template', () => {
render(<OntologyInspector />);
fireEvent.change(screen.getByRole('combobox'), { target: { value: 't1' } });
expect(useStore.getState().activeTemplateId).toBe('t1');
expect(screen.getByText('胆囊')).toBeInTheDocument();
expect(screen.getByText('肝脏')).toBeInTheDocument();
});
it('selects a concrete class for subsequent masks', () => {
render(<OntologyInspector />);
fireEvent.click(screen.getByText('胆囊'));
expect(useStore.getState().activeClassId).toBe('c1');
expect(useStore.getState().activeClass).toEqual(expect.objectContaining({
id: 'c1',
name: '胆囊',
zIndex: 20,
}));
});
it('adds custom classes locally without backend persistence', () => {
const { container } = render(<OntologyInspector />);
const customSection = screen.getByText('自定义分类').parentElement!;
fireEvent.click(within(customSection).getByRole('button'));
fireEvent.change(screen.getByPlaceholderText('分类名称'), { target: { value: '新局部分类' } });
fireEvent.keyDown(screen.getByPlaceholderText('分类名称'), { key: 'Enter' });
expect(screen.getAllByText('新局部分类')).toHaveLength(2);
expect(useStore.getState().activeClass).toEqual(expect.objectContaining({ name: '新局部分类' }));
expect(useStore.getState().templates[0].classes).toHaveLength(2);
expect(container).toHaveTextContent('2 个分类来自模板 + 1 个自定义');
});
});

@@ -2,11 +2,16 @@ import React, { useState } from 'react';
 import { Layers, ChevronDown, Tag, Eye, Plus, X } from 'lucide-react';
 import { useStore } from '../store/useStore';
 import type { TemplateClass } from '../store/useStore';
+import { cn } from '../lib/utils';
+import { getActiveTemplate } from '../lib/templateSelection';
 export function OntologyInspector() {
 const templates = useStore((state) => state.templates);
 const activeTemplateId = useStore((state) => state.activeTemplateId);
+const activeClassId = useStore((state) => state.activeClassId);
+const activeClass = useStore((state) => state.activeClass);
 const setActiveTemplateId = useStore((state) => state.setActiveTemplateId);
+const setActiveClass = useStore((state) => state.setActiveClass);
 // Project-level custom classes (in addition to template classes)
 const [customClasses, setCustomClasses] = useState<TemplateClass[]>([]);
@@ -14,10 +19,17 @@ export function OntologyInspector() {
 const [newClassName, setNewClassName] = useState('');
 const [newClassColor, setNewClassColor] = useState('#06b6d4');
-const activeTemplate = templates.find((t) => t.id === activeTemplateId) || templates[0] || null;
+const activeTemplate = getActiveTemplate(templates, activeTemplateId);
 const templateClasses = activeTemplate?.classes || [];
 const allClasses = [...templateClasses, ...customClasses].sort((a, b) => b.zIndex - a.zIndex);
+const handleSelectClass = (templateClass: TemplateClass) => {
+if (activeTemplate && !activeTemplateId) {
+setActiveTemplateId(activeTemplate.id);
+}
+setActiveClass(templateClass);
+};
 const handleAddCustom = () => {
 if (!newClassName.trim()) return;
 const maxZ = allClasses.length > 0 ? Math.max(...allClasses.map((c) => c.zIndex)) : 0;
@@ -29,6 +41,7 @@ export function OntologyInspector() {
 category: '自定义',
 };
 setCustomClasses([...customClasses, newClass]);
+handleSelectClass(newClass);
 setNewClassName('');
 setShowAddForm(false);
 };
@@ -47,7 +60,10 @@ export function OntologyInspector() {
 <div className="relative">
 <select
 value={activeTemplate?.id || ''}
-onChange={(e) => setActiveTemplateId(e.target.value || null)}
+onChange={(e) => {
+setActiveTemplateId(e.target.value || null);
+setActiveClass(null);
+}}
 className="w-full bg-[#1a1a1a] border border-white/10 rounded-lg px-3 py-2 text-xs text-gray-300 appearance-none cursor-pointer focus:outline-none focus:border-cyan-500/50"
 >
 <option value="">-- --</option>
@@ -73,7 +89,14 @@ export function OntologyInspector() {
 <div className="space-y-2">
 {allClasses.map(cls => (
 <div key={cls.id} className="flex flex-col gap-1">
-<div className="flex items-center justify-between p-2 rounded bg-white/5 hover:bg-white/10 cursor-pointer group transition-colors">
+<button
+type="button"
+onClick={() => handleSelectClass(cls)}
+className={cn(
+'flex items-center justify-between p-2 rounded bg-white/5 hover:bg-white/10 cursor-pointer group transition-colors text-left border',
+activeClassId === cls.id ? 'border-cyan-500/50 bg-cyan-500/10' : 'border-transparent',
+)}
+>
 <div className="flex items-center gap-2">
 <span className="w-2.5 h-2.5 rounded-sm" style={{ backgroundColor: cls.color }} />
 <span className="text-xs font-medium text-gray-200">{cls.name}</span>
@@ -82,7 +105,7 @@ export function OntologyInspector() {
 <span className="text-[10px] text-gray-500 font-mono">z:{cls.zIndex}</span>
 <Eye size={14} className="text-gray-500 group-hover:text-gray-300" />
 </div>
-</div>
+</button>
 </div>
 ))}
 {allClasses.length === 0 && (
@@ -136,7 +159,9 @@ export function OntologyInspector() {
 <div className="bg-white/5 rounded-lg p-3">
 <div className="flex items-center gap-2 mb-3">
 <Tag size={12} className="text-cyan-400" />
-<span className="text-xs font-semibold text-gray-200">{activeTemplate?.name || '未选择'}</span>
+<span className="text-xs font-semibold text-gray-200">
+{activeClass?.name || activeTemplate?.name || '未选择'}
+</span>
 </div>
 <div className="space-y-3">
 <div className="space-y-1">

@@ -0,0 +1,92 @@
import { fireEvent, render, screen, waitFor } from '@testing-library/react';
import { beforeEach, describe, expect, it, vi } from 'vitest';
import { resetStore } from '../test/storeTestUtils';
import { useStore } from '../store/useStore';
import { ProjectLibrary } from './ProjectLibrary';
const apiMock = vi.hoisted(() => ({
getProjects: vi.fn(),
createProject: vi.fn(),
uploadMedia: vi.fn(),
parseMedia: vi.fn(),
uploadDicomBatch: vi.fn(),
}));
vi.mock('../lib/api', () => ({
getProjects: apiMock.getProjects,
createProject: apiMock.createProject,
uploadMedia: apiMock.uploadMedia,
parseMedia: apiMock.parseMedia,
uploadDicomBatch: apiMock.uploadDicomBatch,
}));
describe('ProjectLibrary', () => {
beforeEach(() => {
resetStore();
vi.clearAllMocks();
apiMock.getProjects.mockResolvedValue([]);
});
it('loads projects and selects one into the workspace', async () => {
const onProjectSelect = vi.fn();
apiMock.getProjects.mockResolvedValueOnce([
{ id: 'p1', name: 'Demo Project', status: 'ready', frames: 3, fps: '30FPS' },
]);
render(<ProjectLibrary onProjectSelect={onProjectSelect} />);
fireEvent.click(await screen.findByText('Demo Project'));
expect(useStore.getState().currentProject?.id).toBe('p1');
expect(onProjectSelect).toHaveBeenCalled();
});
it('creates a new project from the modal', async () => {
apiMock.createProject.mockResolvedValueOnce({ id: 'p2', name: 'New Project', status: 'pending' });
render(<ProjectLibrary onProjectSelect={vi.fn()} />);
fireEvent.click(screen.getByText('新建项目'));
fireEvent.change(screen.getByPlaceholderText('输入项目名称'), { target: { value: 'New Project' } });
fireEvent.change(screen.getByPlaceholderText('输入项目描述'), { target: { value: 'desc' } });
fireEvent.click(screen.getByRole('button', { name: '创建' }));
await waitFor(() => expect(apiMock.createProject).toHaveBeenCalledWith({
name: 'New Project',
description: 'desc',
}));
expect(useStore.getState().projects[0]).toEqual(expect.objectContaining({ id: 'p2' }));
});
it('imports video by creating a project, uploading media, parsing frames and refreshing projects', async () => {
apiMock.createProject.mockResolvedValueOnce({ id: 'p3', name: 'clip.mp4', status: 'pending' });
apiMock.uploadMedia.mockResolvedValueOnce({ url: 'http://file', id: 'object' });
apiMock.parseMedia.mockResolvedValueOnce({ frames_extracted: 1 });
apiMock.getProjects.mockResolvedValue([]);
const { container } = render(<ProjectLibrary onProjectSelect={vi.fn()} />);
const input = container.querySelector('input[accept="video/*"]') as HTMLInputElement;
const file = new File(['video'], 'clip.mp4', { type: 'video/mp4' });
fireEvent.change(input, { target: { files: [file] } });
fireEvent.click(await screen.findByRole('button', { name: '开始导入' }));
await waitFor(() => expect(apiMock.createProject).toHaveBeenCalledWith(expect.objectContaining({
name: 'clip.mp4',
parse_fps: 30,
})));
expect(apiMock.uploadMedia).toHaveBeenCalledWith(file, 'p3');
expect(apiMock.parseMedia).toHaveBeenCalledWith('p3');
});
it('imports only valid DICOM files and parses the returned project', async () => {
apiMock.uploadDicomBatch.mockResolvedValueOnce({ project_id: 77, uploaded_count: 1, message: 'ok' });
apiMock.parseMedia.mockResolvedValueOnce({ frames_extracted: 1 });
const { container } = render(<ProjectLibrary onProjectSelect={vi.fn()} />);
const input = container.querySelector('input[accept=".dcm"]') as HTMLInputElement;
const dcm = new File(['dcm'], 'scan.dcm', { type: 'application/dicom' });
const ignored = new File(['txt'], 'notes.txt', { type: 'text/plain' });
fireEvent.change(input, { target: { files: [dcm, ignored] } });
await waitFor(() => expect(apiMock.uploadDicomBatch).toHaveBeenCalledWith([dcm]));
expect(apiMock.parseMedia).toHaveBeenCalledWith('77');
});
});

@@ -212,11 +212,11 @@ export function ProjectLibrary({ onProjectSelect }: ProjectLibraryProps) {
 {proj.source_type === 'dicom' ? 'DICOM' : (proj.fps || '30FPS')}
 </span>
 <span className="backdrop-blur-md bg-black/40 text-gray-200 text-[10px] px-2 py-1 rounded border border-white/10 flex items-center gap-1 uppercase tracking-widest">
-{proj.status === 'Ready' ? (
+{proj.status === 'ready' ? (
 <><div className="w-1.5 h-1.5 bg-emerald-500 rounded-full" /> </>
-) : proj.status === 'Parsing' ? (
+) : proj.status === 'parsing' ? (
 <><div className="w-1.5 h-1.5 bg-amber-500 rounded-full animate-pulse" /> </>
-) : proj.status === 'Error' ? (
+) : proj.status === 'error' ? (
 <><div className="w-1.5 h-1.5 bg-red-500 rounded-full" /> </>
 ) : (
 <><div className="w-1.5 h-1.5 bg-blue-500 rounded-full" /> </>
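
The hunk above switches the status comparisons from capitalised values (`'Ready'`, `'Parsing'`, `'Error'`) to the unified lowercase set, and the commit message says the frontend normalises legacy values. A minimal sketch of such a normaliser follows. The function name `normalizeProjectStatus` and the fallback to `'pending'` for unknown input are assumptions for illustration, not taken from the commit.

```typescript
// The four unified statuses named in the commit message.
type ProjectStatus = 'pending' | 'parsing' | 'ready' | 'error';

const KNOWN_STATUSES = new Set<string>(['pending', 'parsing', 'ready', 'error']);

// Hypothetical helper: maps legacy values such as 'Ready' to 'ready'.
// Assumption: anything unrecognised is treated as 'pending'.
function normalizeProjectStatus(raw: string | null | undefined): ProjectStatus {
  const lowered = (raw ?? '').trim().toLowerCase();
  return KNOWN_STATUSES.has(lowered) ? (lowered as ProjectStatus) : 'pending';
}
```

Normalising once at the API boundary would let components compare against the lowercase constants only, as the updated JSX does.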

@@ -2,6 +2,7 @@ import React from 'react';
 import { Home, FolderOpen, Edit3, LayoutTemplate, BrainCircuit } from 'lucide-react';
 import { cn } from '../lib/utils';
 import type { ActiveModule } from '../App';
+import { ModelStatusBadge } from './ModelStatusBadge';
 interface SidebarProps {
 activeModule: ActiveModule;
@@ -47,9 +48,7 @@ export function Sidebar({ activeModule, setActiveModule }: SidebarProps) {
 })}
 </nav>
 <div className="mt-auto mb-4 flex flex-col gap-4">
-<div className="w-8 h-8 rounded-full border border-cyan-500/50 flex items-center justify-center text-[10px] text-cyan-400 font-bold cursor-pointer transition-all hover:bg-cyan-500/10">
-GPU
-</div>
+<ModelStatusBadge compact />
 </div>
 </aside>
 );

@@ -0,0 +1,85 @@
import { fireEvent, render, screen, waitFor } from '@testing-library/react';
import { beforeEach, describe, expect, it, vi } from 'vitest';
import { resetStore } from '../test/storeTestUtils';
import { useStore } from '../store/useStore';
import { TemplateRegistry } from './TemplateRegistry';
const apiMock = vi.hoisted(() => ({
getTemplates: vi.fn(),
createTemplate: vi.fn(),
updateTemplate: vi.fn(),
deleteTemplate: vi.fn(),
}));
vi.mock('../lib/api', () => ({
getTemplates: apiMock.getTemplates,
createTemplate: apiMock.createTemplate,
updateTemplate: apiMock.updateTemplate,
deleteTemplate: apiMock.deleteTemplate,
}));
describe('TemplateRegistry', () => {
beforeEach(() => {
resetStore();
vi.clearAllMocks();
});
it('loads and displays templates with unpacked classes', async () => {
apiMock.getTemplates.mockResolvedValueOnce([
{
id: 't1',
name: '腹腔镜胆囊切除术',
description: 'desc',
classes: [{ id: 'c1', name: '胆囊', color: '#ff0000', zIndex: 10, category: '器官' }],
rules: [],
},
]);
render(<TemplateRegistry />);
expect(await screen.findAllByText('腹腔镜胆囊切除术')).toHaveLength(2);
expect(screen.getByText('胆囊')).toBeInTheDocument();
});
it('creates a template and stores it globally', async () => {
apiMock.getTemplates.mockResolvedValueOnce([]);
apiMock.createTemplate.mockResolvedValueOnce({
id: 't2',
name: 'New Template',
description: 'desc',
classes: [],
rules: [],
});
render(<TemplateRegistry />);
fireEvent.click(screen.getByText('新建方案'));
fireEvent.change(screen.getAllByRole('textbox')[0], { target: { value: 'New Template' } });
fireEvent.change(screen.getAllByRole('textbox')[1], { target: { value: 'desc' } });
fireEvent.click(screen.getByRole('button', { name: '保存' }));
await waitFor(() => expect(apiMock.createTemplate).toHaveBeenCalledWith(expect.objectContaining({
name: 'New Template',
description: 'desc',
classes: [],
rules: [],
color: '#06b6d4',
z_index: 0,
})));
expect(useStore.getState().templates[0]).toEqual(expect.objectContaining({ id: 't2' }));
});
it('imports JSON classes into the edit modal before saving', async () => {
apiMock.getTemplates.mockResolvedValueOnce([]);
render(<TemplateRegistry />);
fireEvent.click(screen.getByText('新建方案'));
fireEvent.change(screen.getAllByRole('textbox')[0], { target: { value: 'With Classes' } });
fireEvent.click(screen.getByText('批量导入'));
fireEvent.change(screen.getByPlaceholderText('[[[255,0,0], [0,255,0]], ["分类A", "分类B"]]'), {
target: { value: '{"colors":[[255,0,0]],"names":["分类A"]}' },
});
fireEvent.click(screen.getByRole('button', { name: '导入' }));
expect(screen.getByText('分类A')).toBeInTheDocument();
});
});

@@ -0,0 +1,30 @@
import { fireEvent, render, screen } from '@testing-library/react';
import { describe, expect, it, vi } from 'vitest';
import { ToolsPalette } from './ToolsPalette';
describe('ToolsPalette', () => {
it('switches tools and exposes UI-only placeholder buttons', () => {
const setActiveTool = vi.fn();
render(<ToolsPalette activeTool="move" setActiveTool={setActiveTool} />);
fireEvent.click(screen.getByTitle('创建多边形 (P)'));
fireEvent.click(screen.getByTitle('正向选点 (SAM)'));
expect(setActiveTool).toHaveBeenNthCalledWith(1, 'create_polygon');
expect(setActiveTool).toHaveBeenNthCalledWith(2, 'point_pos');
expect(screen.getByTitle('撤销操作 (Ctrl+Z)')).toBeInTheDocument();
expect(screen.getByTitle('重做操作 (Ctrl+Shift+Z)')).toBeInTheDocument();
});
it('switches to SAM trigger and calls the AI navigation hook', () => {
const setActiveTool = vi.fn();
const onTriggerAI = vi.fn();
render(<ToolsPalette activeTool="move" setActiveTool={setActiveTool} onTriggerAI={onTriggerAI} />);
fireEvent.click(screen.getByTitle('触发 SAM 推理 (Enter)'));
expect(setActiveTool).toHaveBeenCalledWith('sam_trigger');
expect(onTriggerAI).toHaveBeenCalled();
});
});

@@ -78,7 +78,7 @@ export function ToolsPalette({ activeTool, setActiveTool, onTriggerAI }: ToolsPa
 setActiveTool('sam_trigger');
 if (onTriggerAI) onTriggerAI();
 }}
-title="触发 SAM 3 推理 (Enter)"
+title="触发 SAM 推理 (Enter)"
 className={cn(
 "w-10 h-10 rounded-lg flex items-center justify-center transition-all",
 activeTool === 'sam_trigger'

@@ -0,0 +1,259 @@
import { act, fireEvent, render, screen, waitFor } from '@testing-library/react';
import { beforeEach, describe, expect, it, vi } from 'vitest';
import { resetStore } from '../test/storeTestUtils';
import { useStore } from '../store/useStore';
import { VideoWorkspace } from './VideoWorkspace';
const apiMock = vi.hoisted(() => ({
getProjectFrames: vi.fn(),
parseMedia: vi.fn(),
getTask: vi.fn(),
getTemplates: vi.fn(),
getProjectAnnotations: vi.fn(),
saveAnnotation: vi.fn(),
updateAnnotation: vi.fn(),
deleteAnnotation: vi.fn(),
exportCoco: vi.fn(),
annotationToMask: vi.fn(),
buildAnnotationPayload: vi.fn(),
getAiModelStatus: vi.fn(),
}));
vi.mock('../lib/api', () => ({
getProjectFrames: apiMock.getProjectFrames,
parseMedia: apiMock.parseMedia,
getTask: apiMock.getTask,
getTemplates: apiMock.getTemplates,
getProjectAnnotations: apiMock.getProjectAnnotations,
saveAnnotation: apiMock.saveAnnotation,
updateAnnotation: apiMock.updateAnnotation,
deleteAnnotation: apiMock.deleteAnnotation,
exportCoco: apiMock.exportCoco,
annotationToMask: apiMock.annotationToMask,
buildAnnotationPayload: apiMock.buildAnnotationPayload,
getAiModelStatus: apiMock.getAiModelStatus,
}));
describe('VideoWorkspace', () => {
beforeEach(() => {
resetStore();
vi.clearAllMocks();
useStore.setState({ currentProject: { id: '1', name: 'Demo', status: 'ready', video_path: 'uploads/demo.mp4' } });
apiMock.getTemplates.mockResolvedValue([]);
apiMock.getProjectAnnotations.mockResolvedValue([]);
apiMock.annotationToMask.mockReturnValue(null);
apiMock.getTask.mockResolvedValue({ id: 1, status: 'success', progress: 100, message: '解析完成' });
apiMock.getAiModelStatus.mockResolvedValue({
selected_model: 'sam2',
gpu: { available: false, device: 'cpu', name: null, torch_available: true },
models: [
{ id: 'sam2', label: 'SAM 2', available: true, loaded: false, device: 'cpu', supports: [], message: 'ready', package_available: true, checkpoint_exists: true, python_ok: true, torch_ok: true, cuda_required: false },
{ id: 'sam3', label: 'SAM 3', available: false, loaded: false, device: 'unavailable', supports: [], message: 'missing', package_available: false, checkpoint_exists: false, python_ok: false, torch_ok: true, cuda_required: true },
],
});
});
it('loads project frames into the workspace store', async () => {
apiMock.getProjectFrames.mockResolvedValueOnce([
{ id: 10, project_id: 1, frame_index: 0, image_url: '/frame.jpg', width: 640, height: 360 },
]);
render(<VideoWorkspace />);
await waitFor(() => expect(useStore.getState().frames).toEqual([
{ id: '10', projectId: '1', index: 0, url: '/frame.jpg', width: 640, height: 360 },
]));
expect(screen.getByText('Demo')).toBeInTheDocument();
expect(apiMock.parseMedia).not.toHaveBeenCalled();
expect(apiMock.getProjectAnnotations).toHaveBeenCalledWith('1');
});
it('triggers parsing when a media project has no frames yet', async () => {
apiMock.getProjectFrames
.mockResolvedValueOnce([])
.mockResolvedValueOnce([
{ id: 11, project_id: 1, frame_index: 0, image_url: '/parsed.jpg', width: 320, height: 240 },
]);
apiMock.parseMedia.mockResolvedValueOnce({ id: 7, status: 'queued', progress: 0 });
apiMock.getTask.mockResolvedValueOnce({ id: 7, status: 'success', progress: 100, message: '解析完成' });
render(<VideoWorkspace />);
await waitFor(() => expect(apiMock.parseMedia).toHaveBeenCalledWith('1'));
expect(apiMock.getTask).toHaveBeenCalledWith(7);
await waitFor(() => expect(useStore.getState().frames[0]).toEqual(expect.objectContaining({
id: '11',
url: '/parsed.jpg',
})));
});
it('hydrates saved annotations after loading frames', async () => {
apiMock.getProjectFrames.mockResolvedValueOnce([
{ id: 10, project_id: 1, frame_index: 0, image_url: '/frame.jpg', width: 640, height: 360 },
]);
apiMock.getProjectAnnotations.mockResolvedValueOnce([{ id: 99, frame_id: 10 }]);
apiMock.annotationToMask.mockReturnValueOnce({
id: 'annotation-99',
annotationId: '99',
frameId: '10',
saved: true,
pathData: 'M 0 0 Z',
label: 'Saved',
color: '#06b6d4',
});
render(<VideoWorkspace />);
await waitFor(() => expect(useStore.getState().masks).toEqual([
expect.objectContaining({ id: 'annotation-99', saved: true }),
]));
});
it('saves pending masks through the archive button', async () => {
apiMock.getProjectFrames.mockResolvedValueOnce([
{ id: 10, project_id: 1, frame_index: 0, image_url: '/frame.jpg', width: 640, height: 360 },
]);
apiMock.buildAnnotationPayload.mockReturnValueOnce({ project_id: 1, frame_id: 10, mask_data: { polygons: [] } });
apiMock.saveAnnotation.mockResolvedValueOnce({ id: 5 });
render(<VideoWorkspace />);
await waitFor(() => expect(useStore.getState().frames).toHaveLength(1));
act(() => {
useStore.setState({
activeTemplateId: '2',
masks: [{
id: 'mask-1',
frameId: '10',
pathData: 'M 0 0 Z',
label: 'AI Mask',
color: '#06b6d4',
segmentation: [[0, 0, 10, 0, 10, 10]],
bbox: [0, 0, 10, 10],
}],
});
});
fireEvent.click(screen.getByRole('button', { name: '结构化归档保存' }));
await waitFor(() => expect(apiMock.saveAnnotation).toHaveBeenCalledWith({
project_id: 1,
frame_id: 10,
mask_data: { polygons: [] },
}));
expect(apiMock.buildAnnotationPayload).toHaveBeenCalledWith(
'1',
expect.objectContaining({ id: 'mask-1' }),
expect.objectContaining({ id: '10' }),
'2',
);
});
it('updates dirty saved masks through the archive button', async () => {
apiMock.getProjectFrames.mockResolvedValueOnce([
{ id: 10, project_id: 1, frame_index: 0, image_url: '/frame.jpg', width: 640, height: 360 },
]);
apiMock.buildAnnotationPayload.mockReturnValueOnce({
project_id: 1,
frame_id: 10,
template_id: 2,
mask_data: { polygons: [], label: '胆囊' },
});
apiMock.updateAnnotation.mockResolvedValueOnce({ id: 99 });
render(<VideoWorkspace />);
await waitFor(() => expect(useStore.getState().frames).toHaveLength(1));
act(() => {
useStore.setState({
activeTemplateId: '2',
masks: [{
id: 'annotation-99',
annotationId: '99',
frameId: '10',
pathData: 'M 0 0 Z',
label: '胆囊',
color: '#ff0000',
saveStatus: 'dirty',
segmentation: [[0, 0, 10, 0, 10, 10]],
bbox: [0, 0, 10, 10],
}],
});
});
fireEvent.click(screen.getByRole('button', { name: '结构化归档保存' }));
await waitFor(() => expect(apiMock.updateAnnotation).toHaveBeenCalledWith('99', {
template_id: 2,
mask_data: { polygons: [], label: '胆囊' },
points: undefined,
bbox: undefined,
}));
expect(apiMock.saveAnnotation).not.toHaveBeenCalled();
});
it('deletes saved annotations when clearing current-frame masks', async () => {
apiMock.getProjectFrames.mockResolvedValueOnce([
{ id: 10, project_id: 1, frame_index: 0, image_url: '/frame.jpg', width: 640, height: 360 },
]);
apiMock.deleteAnnotation.mockResolvedValueOnce(undefined);
render(<VideoWorkspace />);
await waitFor(() => expect(useStore.getState().frames).toHaveLength(1));
act(() => {
useStore.setState({
masks: [
{
id: 'annotation-99',
annotationId: '99',
frameId: '10',
pathData: 'M 0 0 Z',
label: 'Saved',
color: '#06b6d4',
saved: true,
saveStatus: 'saved',
},
{
id: 'draft-1',
frameId: '10',
pathData: 'M 1 1 Z',
label: 'Draft',
color: '#ff0000',
},
],
});
});
fireEvent.click(screen.getByRole('button', { name: '清空遮罩' }));
await waitFor(() => expect(apiMock.deleteAnnotation).toHaveBeenCalledWith('99'));
expect(useStore.getState().masks).toEqual([]);
});
it('auto-saves pending masks before exporting COCO', async () => {
apiMock.getProjectFrames.mockResolvedValueOnce([
{ id: 10, project_id: 1, frame_index: 0, image_url: '/frame.jpg', width: 640, height: 360 },
]);
apiMock.buildAnnotationPayload.mockReturnValueOnce({ project_id: 1, frame_id: 10, mask_data: { polygons: [] } });
apiMock.saveAnnotation.mockResolvedValueOnce({ id: 5 });
apiMock.exportCoco.mockResolvedValueOnce(new Blob(['{}'], { type: 'application/json' }));
render(<VideoWorkspace />);
await waitFor(() => expect(useStore.getState().frames).toHaveLength(1));
act(() => {
useStore.setState({
masks: [{
id: 'mask-1',
frameId: '10',
pathData: 'M 0 0 Z',
label: 'AI Mask',
color: '#06b6d4',
segmentation: [[0, 0, 10, 0, 10, 10]],
}],
});
});
fireEvent.click(screen.getByRole('button', { name: '导出 JSON 标注集' }));
await waitFor(() => expect(apiMock.saveAnnotation).toHaveBeenCalled());
expect(apiMock.exportCoco).toHaveBeenCalledWith('1');
});
});
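
The "triggers parsing when a media project has no frames yet" test above exercises a queue-then-poll flow: `parseMedia` returns a task, and the workspace polls `getTask` until the task succeeds, fails, or the attempt budget runs out. A minimal standalone sketch of that loop follows. The names `pollTask`, `TaskSnapshot`, `PollOutcome`, and `PollOptions` are hypothetical; only the `'success'`/`'failed'` statuses and the 60-attempt, 2000 ms defaults are taken from the diff.

```typescript
// Hypothetical task shape, modelled on the fields the workspace reads
// (status, progress, message, error); not copied from the commit.
interface TaskSnapshot {
  id: number;
  status: string;
  progress: number;
  message?: string;
  error?: string;
}

type PollOutcome =
  | { kind: 'success'; task: TaskSnapshot }
  | { kind: 'failed'; task: TaskSnapshot }
  | { kind: 'timeout' };

interface PollOptions {
  maxAttempts?: number;
  intervalMs?: number;
  sleep?: (ms: number) => Promise<void>; // injectable so tests need no real timers
}

const defaultSleep = (ms: number): Promise<void> =>
  new Promise((resolve) => setTimeout(resolve, ms));

// Polls until the task reaches a terminal status or the attempts are used up.
async function pollTask(
  fetchTask: (id: number) => Promise<TaskSnapshot>,
  taskId: number,
  options: PollOptions = {},
): Promise<PollOutcome> {
  const { maxAttempts = 60, intervalMs = 2000, sleep = defaultSleep } = options;
  for (let attempt = 0; attempt < maxAttempts; attempt += 1) {
    const task = await fetchTask(taskId);
    if (task.status === 'success') return { kind: 'success', task };
    if (task.status === 'failed') return { kind: 'failed', task };
    await sleep(intervalMs);
  }
  return { kind: 'timeout' };
}
```

Returning a tagged outcome rather than throwing keeps the three cases (frames ready, task failed, still running in the background) explicit for the caller.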

@@ -1,10 +1,28 @@
-import React, { useEffect } from 'react';
+import React, { useCallback, useEffect, useMemo, useState } from 'react';
 import { useStore } from '../store/useStore';
-import { getProjectFrames, parseMedia, getTemplates } from '../lib/api';
+import {
+annotationToMask,
+buildAnnotationPayload,
+deleteAnnotation,
+exportCoco,
+getProjectAnnotations,
+getProjectFrames,
+getTask,
+getTemplates,
+parseMedia,
+saveAnnotation,
+updateAnnotation,
+} from '../lib/api';
 import { CanvasArea } from './CanvasArea';
 import { ToolsPalette } from './ToolsPalette';
 import { OntologyInspector } from './OntologyInspector';
 import { FrameTimeline } from './FrameTimeline';
+import { ModelStatusBadge } from './ModelStatusBadge';
+import type { Frame } from '../store/useStore';
+function sleep(ms: number) {
+return new Promise((resolve) => setTimeout(resolve, ms));
+}
 export function VideoWorkspace({ onNavigateToAI }: { onNavigateToAI?: () => void }) {
 const activeTool = useStore((state) => state.activeTool);
@@ -12,8 +30,26 @@ export function VideoWorkspace({ onNavigateToAI }: { onNavigateToAI?: () => void
 const currentProject = useStore((state) => state.currentProject);
 const frames = useStore((state) => state.frames);
 const currentFrameIndex = useStore((state) => state.currentFrameIndex);
+const masks = useStore((state) => state.masks);
+const activeTemplateId = useStore((state) => state.activeTemplateId);
 const setFrames = useStore((state) => state.setFrames);
 const setCurrentFrame = useStore((state) => state.setCurrentFrame);
+const setMasks = useStore((state) => state.setMasks);
+const [isSaving, setIsSaving] = useState(false);
+const [isExporting, setIsExporting] = useState(false);
+const [statusMessage, setStatusMessage] = useState('');
+const hydrateSavedAnnotations = useCallback(async (projectId: string, projectFrames: Frame[]) => {
+const frameById = new Map(projectFrames.map((frame) => [frame.id, frame]));
+const annotations = await getProjectAnnotations(projectId);
+const savedMasks = annotations
+.map((annotation) => {
+const frame = annotation.frame_id ? frameById.get(String(annotation.frame_id)) : null;
+return frame ? annotationToMask(annotation, frame) : null;
+})
+.filter((mask): mask is NonNullable<typeof mask> => Boolean(mask));
+setMasks(savedMasks);
+}, [setMasks]);
 useEffect(() => {
 if (!currentProject?.id) return;
@@ -25,34 +61,58 @@ export function VideoWorkspace({ onNavigateToAI }: { onNavigateToAI?: () => void
 if (cancelled) return;
 if (data.length === 0 && currentProject.video_path) {
-// No frames yet but video exists → trigger parsing
+// No frames yet but video exists -> queue parsing and poll the task.
 try {
-await parseMedia(String(currentProject.id));
+const task = await parseMedia(String(currentProject.id));
 if (cancelled) return;
+setStatusMessage(`解析任务已入队 #${task.id}`);
+let completed = false;
+for (let attempt = 0; attempt < 60; attempt += 1) {
+const freshTask = await getTask(task.id);
+if (cancelled) return;
+setStatusMessage(freshTask.message || `解析进度 ${freshTask.progress}%`);
+if (freshTask.status === 'success') {
+completed = true;
+break;
+}
+if (freshTask.status === 'failed') {
+setStatusMessage(freshTask.error || '解析任务失败');
+return;
+}
+await sleep(2000);
+}
+if (!completed) {
+setStatusMessage('解析仍在后台运行,可稍后刷新工作区');
+return;
+}
 const fresh = await getProjectFrames(String(currentProject.id));
 if (cancelled) return;
-setFrames(fresh.map((f) => ({
+const mappedFrames = fresh.map((f) => ({
 id: String(f.id),
 projectId: String(f.project_id),
 index: f.frame_index,
 url: f.image_url,
 width: f.width ?? 0,
 height: f.height ?? 0,
-})));
+}));
+setFrames(mappedFrames);
 setCurrentFrame(0);
+await hydrateSavedAnnotations(String(currentProject.id), mappedFrames);
 } catch (err) {
 console.error('Parse failed:', err);
 }
 } else {
-setFrames(data.map((f) => ({
+const mappedFrames = data.map((f) => ({
 id: String(f.id),
 projectId: String(f.project_id),
 index: f.frame_index,
 url: f.image_url,
 width: f.width ?? 0,
 height: f.height ?? 0,
-})));
+}));
+setFrames(mappedFrames);
 setCurrentFrame(0);
+await hydrateSavedAnnotations(String(currentProject.id), mappedFrames);
 }
 } catch (err) {
 console.error('Failed to load frames:', err);
@@ -61,7 +121,7 @@ export function VideoWorkspace({ onNavigateToAI }: { onNavigateToAI?: () => void
 loadFrames();
 return () => { cancelled = true; };
-}, [currentProject?.id, setFrames, setCurrentFrame]);
+}, [currentProject?.id, currentProject?.video_path, hydrateSavedAnnotations, setFrames, setCurrentFrame]);
 const templates = useStore((state) => state.templates);
 const setTemplates = useStore((state) => state.setTemplates);
@@ -72,7 +132,121 @@ export function VideoWorkspace({ onNavigateToAI }: { onNavigateToAI?: () => void
 }
 }, [templates.length, setTemplates]);
-const currentFrameUrl = frames[currentFrameIndex]?.url || '';
+const currentFrame = frames[currentFrameIndex] || null;
+const frameById = useMemo(() => new Map(frames.map((frame) => [frame.id, frame])), [frames]);
+const projectFrameIds = useMemo(() => new Set(frames.map((frame) => frame.id)), [frames]);
+const savePendingAnnotations = useCallback(async ({ silent = false } = {}) => {
+if (!currentProject?.id) return 0;
+const projectMasks = masks.filter((mask) => projectFrameIds.has(mask.frameId));
+const pendingMasks = projectMasks.filter((mask) => !mask.annotationId);
+const dirtyMasks = projectMasks.filter((mask) => mask.annotationId && mask.saveStatus === 'dirty');
+if (pendingMasks.length === 0 && dirtyMasks.length === 0) {
+if (!silent) setStatusMessage('没有待保存标注');
+return 0;
+}
+setIsSaving(true);
+setStatusMessage('正在保存标注...');
+try {
+const createPayloads = pendingMasks
+.map((mask) => {
+const frame = frameById.get(mask.frameId);
+return frame ? buildAnnotationPayload(currentProject.id, mask, frame, activeTemplateId) : null;
+})
+.filter((payload): payload is NonNullable<typeof payload> => Boolean(payload));
+const updatePayloads = dirtyMasks
+.map((mask) => {
+const frame = frameById.get(mask.frameId);
+const payload = frame ? buildAnnotationPayload(currentProject.id, mask, frame, activeTemplateId) : null;
+if (!payload || !mask.annotationId) return null;
+const updatePayload = {
+template_id: payload.template_id,
+mask_data: payload.mask_data,
+points: payload.points,
+bbox: payload.bbox,
+};
+return { annotationId: mask.annotationId, payload: updatePayload };
+})
+.filter((item): item is NonNullable<typeof item> => Boolean(item));
+if (createPayloads.length === 0 && updatePayloads.length === 0) {
+setStatusMessage('没有可保存的标注数据');
+return 0;
+}
+await Promise.all([
+...createPayloads.map((payload) => saveAnnotation(payload)),
+...updatePayloads.map(({ annotationId, payload }) => updateAnnotation(annotationId, payload)),
+]);
+await hydrateSavedAnnotations(currentProject.id, frames);
+const savedCount = createPayloads.length + updatePayloads.length;
+setStatusMessage(`已保存 ${savedCount} 个标注`);
+return savedCount;
+} catch (err) {
console.error('Save annotations failed:', err);
setStatusMessage('保存失败,请检查后端服务');
throw err;
} finally {
setIsSaving(false);
}
}, [activeTemplateId, currentProject?.id, frameById, frames, hydrateSavedAnnotations, masks, projectFrameIds]);
const handleClearCurrentFrameMasks = useCallback(async () => {
if (!currentFrame) return;
const frameMasks = masks.filter((mask) => mask.frameId === currentFrame.id);
const annotationIds = frameMasks
.map((mask) => mask.annotationId)
.filter((annotationId): annotationId is string => Boolean(annotationId));
setIsSaving(true);
setStatusMessage(annotationIds.length > 0 ? '正在删除已保存标注...' : '正在清空本帧遮罩...');
try {
await Promise.all(annotationIds.map((annotationId) => deleteAnnotation(annotationId)));
setMasks(masks.filter((mask) => mask.frameId !== currentFrame.id));
setStatusMessage(annotationIds.length > 0
? `已删除 ${annotationIds.length} 个后端标注`
: '已清空本帧未保存遮罩');
} catch (err) {
console.error('Delete annotations failed:', err);
setStatusMessage('删除失败,请检查后端服务');
} finally {
setIsSaving(false);
}
}, [currentFrame, masks, setMasks]);
const handleSave = async () => {
try {
await savePendingAnnotations();
} catch {
// status message is set in savePendingAnnotations
}
};
const handleExport = async () => {
if (!currentProject?.id) return;
setIsExporting(true);
setStatusMessage('正在准备导出...');
try {
await savePendingAnnotations({ silent: true });
const blob = await exportCoco(currentProject.id);
const url = URL.createObjectURL(blob);
const link = document.createElement('a');
link.href = url;
link.download = `project_${currentProject.id}_coco.json`;
document.body.appendChild(link);
link.click();
link.remove();
URL.revokeObjectURL(url);
setStatusMessage('COCO JSON 已导出');
} catch (err) {
console.error('Export failed:', err);
setStatusMessage('导出失败,请检查后端服务');
} finally {
setIsExporting(false);
}
};
 return (
 <div className="w-full h-full flex flex-col bg-[#0a0a0a]">
@@ -84,14 +258,25 @@ export function VideoWorkspace({ onNavigateToAI }: { onNavigateToAI?: () => void
 <span className="text-sm text-white font-mono">{currentProject?.name || '未选择项目'}</span>
 </div>
 <div className="flex items-center gap-3">
-<div className="flex items-center gap-1.5 text-[10px] uppercase font-medium">
-<span className="px-2 py-0.5 rounded bg-green-500/10 text-green-400 border border-green-500/20">SAM 3 </span>
-</div>
-<button className="px-4 py-1.5 bg-white/5 hover:bg-white/10 border border-white/10 rounded-md text-xs transition-colors text-white">
- JSON
+{statusMessage && (
+<span className="text-[10px] text-gray-500 font-mono max-w-48 truncate" title={statusMessage}>
+{statusMessage}
+</span>
+)}
+<ModelStatusBadge />
+<button
+onClick={handleExport}
+disabled={!currentProject?.id || isExporting || isSaving}
+className="px-4 py-1.5 bg-white/5 hover:bg-white/10 border border-white/10 rounded-md text-xs transition-colors text-white disabled:opacity-40 disabled:cursor-not-allowed"
+>
+{isExporting ? '导出中...' : '导出 JSON 标注集'}
 </button>
-<button className="px-4 py-1.5 bg-cyan-600 hover:bg-cyan-500 text-white text-xs font-medium rounded-md transition-shadow shadow-lg shadow-cyan-900/20">
+<button
+onClick={handleSave}
+disabled={!currentProject?.id || isSaving || isExporting}
+className="px-4 py-1.5 bg-cyan-600 hover:bg-cyan-500 text-white text-xs font-medium rounded-md transition-shadow shadow-lg shadow-cyan-900/20 disabled:opacity-40 disabled:cursor-not-allowed"
+>
+{isSaving ? '保存中...' : '结构化归档保存'}
 </button>
 </div>
 </div>
@@ -102,7 +287,7 @@ export function VideoWorkspace({ onNavigateToAI }: { onNavigateToAI?: () => void
 <div className="flex-1 relative flex items-center justify-center p-8 bg-[#151515] overflow-hidden">
 <div className="relative w-full h-full bg-[#1e1e1e] border border-white/5 shadow-2xl rounded-sm">
-<CanvasArea activeTool={activeTool} frameUrl={currentFrameUrl} />
+<CanvasArea activeTool={activeTool} frame={currentFrame} onClearMasks={handleClearCurrentFrameMasks} />
 </div>
 </div>
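
The save path in this file first narrows the store's masks to the current project's frames, then splits them into masks that still need to be created (no `annotationId`) and masks that need an update (`annotationId` present and `saveStatus === 'dirty'`). A minimal sketch of that partition step follows. `MaskLike` and `planSave` are hypothetical names used only for illustration; the real `Mask` type lives in the store and carries more fields.

```typescript
// Hypothetical, pared-down mask shape; not the store's actual Mask type.
interface MaskLike {
  id: string;
  frameId: string;
  annotationId?: string;
  saveStatus?: string;
}

interface SavePlan<T> {
  toCreate: T[];
  toUpdate: T[];
}

// Split masks into "create" and "update" groups, ignoring masks that belong
// to frames outside the current project and masks that are already saved.
function planSave<T extends MaskLike>(masks: T[], projectFrameIds: Set<string>): SavePlan<T> {
  const projectMasks = masks.filter((mask) => projectFrameIds.has(mask.frameId));
  return {
    toCreate: projectMasks.filter((mask) => !mask.annotationId),
    toUpdate: projectMasks.filter((mask) => Boolean(mask.annotationId) && mask.saveStatus === 'dirty'),
  };
}
```

Keeping the partition separate from the network calls makes the "nothing to save" early return easy to test without mocking the API client.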

src/lib/api.test.ts (new file, 361 lines)

@@ -0,0 +1,361 @@
import { beforeEach, describe, expect, it, vi } from 'vitest';
const axiosMock = vi.hoisted(() => {
const client = {
get: vi.fn(),
post: vi.fn(),
patch: vi.fn(),
delete: vi.fn(),
interceptors: {
request: { use: vi.fn() },
response: { use: vi.fn() },
},
};
return { client, create: vi.fn(() => client) };
});
vi.mock('axios', () => ({
default: {
create: axiosMock.create,
},
}));
describe('api client contracts', () => {
beforeEach(() => {
vi.clearAllMocks();
vi.setSystemTime(new Date('2026-05-01T00:00:00Z'));
});
it('maps backend project fields into frontend project fields', async () => {
const { getProjects } = await import('./api');
axiosMock.client.get.mockResolvedValueOnce({
data: [
{
id: 7,
name: 'Demo',
description: 'desc',
status: 'ready',
frame_count: 12,
original_fps: 29.97,
parse_fps: 10,
thumbnail_url: 'thumb',
video_path: 'uploads/demo.mp4',
source_type: 'video',
created_at: 'created',
updated_at: 'updated',
},
],
});
await expect(getProjects()).resolves.toEqual([
expect.objectContaining({
id: '7',
name: 'Demo',
status: 'ready',
frames: 12,
fps: '30FPS',
thumbnail_url: 'thumb',
video_path: 'uploads/demo.mp4',
source_type: 'video',
createdAt: 'created',
updatedAt: 'updated',
}),
]);
expect(axiosMock.client.get).toHaveBeenCalledWith('/api/projects');
});
it('updates projects with PATCH instead of the old PUT contract', async () => {
const { updateProject } = await import('./api');
axiosMock.client.patch.mockResolvedValueOnce({ data: { id: 3, name: 'Renamed', status: 'ready' } });
await updateProject('3', { name: 'Renamed' } as any);
expect(axiosMock.client.patch).toHaveBeenCalledWith('/api/projects/3', { name: 'Renamed' });
});
it('normalizes legacy project status values returned by existing databases', async () => {
const { getProjects } = await import('./api');
axiosMock.client.get.mockResolvedValueOnce({
data: [
{ id: 1, name: 'Old Ready', status: 'Ready' },
{ id: 2, name: 'Old Parsing', status: 'Parsing' },
{ id: 3, name: 'Old Error', status: 'Error' },
],
});
await expect(getProjects()).resolves.toEqual([
expect.objectContaining({ status: 'ready' }),
expect.objectContaining({ status: 'parsing' }),
expect.objectContaining({ status: 'error' }),
]);
});
it('exports COCO from the backend route shape', async () => {
const { exportCoco } = await import('./api');
const blob = new Blob(['{}'], { type: 'application/json' });
axiosMock.client.get.mockResolvedValueOnce({ data: blob });
await expect(exportCoco('9')).resolves.toBe(blob);
expect(axiosMock.client.get).toHaveBeenCalledWith('/api/export/9/coco', {
responseType: 'blob',
});
});
it('loads dashboard overview from the backend summary endpoint', async () => {
const { getDashboardOverview } = await import('./api');
const overview = {
summary: {
project_count: 2,
parsing_task_count: 1,
annotation_count: 5,
frame_count: 100,
template_count: 3,
system_load_percent: 12,
},
tasks: [
{ id: 'project-1', project_id: 1, name: 'Demo', progress: 60, status: 'pending', frame_count: 10, updated_at: 'now' },
],
activity: [
{ id: 'project-1', kind: 'project', time: 'now', message: '项目状态: pending', project: 'Demo' },
],
};
axiosMock.client.get.mockResolvedValueOnce({ data: overview });
await expect(getDashboardOverview()).resolves.toEqual(overview);
expect(axiosMock.client.get).toHaveBeenCalledWith('/api/dashboard/overview');
});
it('queues media parsing and reads processing task status', async () => {
const { getTask, parseMedia } = await import('./api');
const task = {
id: 12,
task_type: 'parse_video',
status: 'queued',
progress: 0,
message: '解析任务已入队',
project_id: 9,
celery_task_id: 'celery-12',
payload: { source_type: 'video' },
result: null,
error: null,
created_at: 'created',
started_at: null,
finished_at: null,
updated_at: 'updated',
};
axiosMock.client.post.mockResolvedValueOnce({ data: task });
axiosMock.client.get.mockResolvedValueOnce({ data: { ...task, status: 'success', progress: 100 } });
await expect(parseMedia('9')).resolves.toEqual(task);
expect(axiosMock.client.post).toHaveBeenCalledWith('/api/media/parse', null, {
params: { project_id: '9' },
});
await expect(getTask(12)).resolves.toEqual(expect.objectContaining({ status: 'success', progress: 100 }));
expect(axiosMock.client.get).toHaveBeenCalledWith('/api/tasks/12');
});
it('lists, saves, updates, and deletes annotations with the backend annotation contract', async () => {
const { deleteAnnotation, getProjectAnnotations, saveAnnotation, updateAnnotation } = await import('./api');
const saved = {
id: 1,
project_id: 9,
frame_id: 5,
template_id: 2,
mask_data: { polygons: [[[0, 0], [1, 0], [1, 1]]] },
points: null,
bbox: null,
created_at: 'created',
updated_at: 'updated',
};
axiosMock.client.get.mockResolvedValueOnce({ data: [saved] });
axiosMock.client.post.mockResolvedValueOnce({ data: saved });
await expect(getProjectAnnotations('9', '5')).resolves.toEqual([saved]);
expect(axiosMock.client.get).toHaveBeenCalledWith('/api/ai/annotations', {
params: { project_id: 9, frame_id: 5 },
});
await expect(saveAnnotation({
project_id: 9,
frame_id: 5,
template_id: 2,
mask_data: { polygons: [[[0, 0], [1, 0], [1, 1]]], label: 'mask' },
})).resolves.toEqual(saved);
expect(axiosMock.client.post).toHaveBeenCalledWith('/api/ai/annotate', {
project_id: 9,
frame_id: 5,
template_id: 2,
mask_data: { polygons: [[[0, 0], [1, 0], [1, 1]]], label: 'mask' },
});
axiosMock.client.patch.mockResolvedValueOnce({ data: { ...saved, mask_data: { ...saved.mask_data, label: 'updated' } } });
await expect(updateAnnotation('1', {
template_id: 2,
mask_data: { polygons: [[[0, 0], [1, 0], [1, 1]]], label: 'updated' },
})).resolves.toEqual(expect.objectContaining({ mask_data: expect.objectContaining({ label: 'updated' }) }));
expect(axiosMock.client.patch).toHaveBeenCalledWith('/api/ai/annotations/1', {
template_id: 2,
mask_data: { polygons: [[[0, 0], [1, 0], [1, 1]]], label: 'updated' },
});
axiosMock.client.delete.mockResolvedValueOnce({ data: null });
await expect(deleteAnnotation('1')).resolves.toBeUndefined();
expect(axiosMock.client.delete).toHaveBeenCalledWith('/api/ai/annotations/1');
});
it('builds annotation payloads from frontend masks and restores saved annotations to masks', async () => {
const { annotationToMask, buildAnnotationPayload } = await import('./api');
const frame = { id: '5', projectId: '9', index: 0, url: '/frame.jpg', width: 100, height: 50 };
const payload = buildAnnotationPayload('9', {
id: 'm1',
frameId: '5',
pathData: 'M 10 10 L 90 10 L 90 40 Z',
label: '胆囊',
color: '#ff0000',
classId: 'c1',
className: '胆囊',
classZIndex: 20,
segmentation: [[10, 10, 90, 10, 90, 40]],
bbox: [10, 10, 80, 30],
}, frame, '2');
expect(payload).toEqual({
project_id: 9,
frame_id: 5,
template_id: 2,
mask_data: {
polygons: [[[0.1, 0.2], [0.9, 0.2], [0.9, 0.8]]],
label: '胆囊',
color: '#ff0000',
class: { id: 'c1', name: '胆囊', color: '#ff0000', zIndex: 20 },
},
bbox: [0.1, 0.2, 0.8, 0.6],
});
expect(annotationToMask({
id: 3,
project_id: 9,
frame_id: 5,
template_id: 2,
mask_data: {
polygons: [[[0.1, 0.2], [0.9, 0.2], [0.9, 0.8]]],
label: '旧标签',
color: '#06b6d4',
class: { id: 'c1', name: '胆囊', color: '#ff0000', zIndex: 20 },
},
points: null,
bbox: null,
created_at: 'created',
updated_at: 'updated',
}, frame)).toEqual(expect.objectContaining({
id: 'annotation-3',
annotationId: '3',
frameId: '5',
templateId: '2',
classId: 'c1',
className: '胆囊',
classZIndex: 20,
label: '胆囊',
color: '#ff0000',
saveStatus: 'saved',
saved: true,
pathData: 'M 10 10 L 90 10 L 90 40 Z',
bbox: [10, 10, 80, 30],
}));
});
it('normalizes positive and negative point prompts for AI prediction', async () => {
const { predictMask } = await import('./api');
axiosMock.client.post.mockResolvedValueOnce({
data: {
polygons: [[[0.25, 0.25], [0.75, 0.25], [0.75, 0.75], [0.25, 0.75]]],
scores: [0.9],
},
});
const result = await predictMask({
imageId: '42',
imageWidth: 400,
imageHeight: 200,
points: [
{ x: 200, y: 100, type: 'pos' },
{ x: 40, y: 20, type: 'neg' },
],
});
expect(axiosMock.client.post).toHaveBeenCalledWith('/api/ai/predict', {
image_id: 42,
prompt_type: 'point',
prompt_data: {
points: [[0.5, 0.5], [0.1, 0.1]],
labels: [1, 0],
},
model: 'sam2',
});
expect(result.masks[0]).toEqual(expect.objectContaining({
pathData: 'M 100 50 L 300 50 L 300 150 L 100 150 Z',
segmentation: [[100, 50, 300, 50, 300, 150, 100, 150]],
bbox: [100, 50, 200, 100],
area: 20000,
confidence: 0.9,
}));
});
it('normalizes box prompts for AI prediction', async () => {
const { predictMask } = await import('./api');
axiosMock.client.post.mockResolvedValueOnce({ data: { polygons: [], scores: [] } });
await predictMask({
imageId: '5',
imageWidth: 640,
imageHeight: 320,
box: { x1: 64, y1: 32, x2: 320, y2: 160 },
});
expect(axiosMock.client.post).toHaveBeenCalledWith('/api/ai/predict', {
image_id: 5,
prompt_type: 'box',
prompt_data: [0.1, 0.1, 0.5, 0.5],
model: 'sam2',
});
});
it('uses semantic prompt type for text-only AI prediction', async () => {
const { predictMask } = await import('./api');
axiosMock.client.post.mockResolvedValueOnce({ data: { polygons: [], scores: [] } });
await predictMask({
imageId: '6',
imageWidth: 640,
imageHeight: 360,
model: 'sam3',
text: '分割胆囊',
});
expect(axiosMock.client.post).toHaveBeenCalledWith('/api/ai/predict', {
image_id: 6,
prompt_type: 'semantic',
prompt_data: '分割胆囊',
model: 'sam3',
});
});
it('loads AI model and GPU runtime status', async () => {
const { getAiModelStatus } = await import('./api');
const status = {
selected_model: 'sam2',
gpu: { available: false, device: 'cpu', name: null, torch_available: true, torch_version: '2.x', cuda_version: null },
models: [
{ id: 'sam2', label: 'SAM 2', available: true, loaded: false, device: 'cpu', supports: ['point'], message: 'ready', package_available: true, checkpoint_exists: true, checkpoint_path: 'model.pt', python_ok: true, torch_ok: true, cuda_required: false },
{ id: 'sam3', label: 'SAM 3', available: false, loaded: false, device: 'unavailable', supports: ['semantic'], message: 'missing runtime', package_available: false, checkpoint_exists: false, checkpoint_path: null, python_ok: false, torch_ok: true, cuda_required: true },
],
};
axiosMock.client.get.mockResolvedValueOnce({ data: status });
await expect(getAiModelStatus('sam3')).resolves.toEqual(status);
expect(axiosMock.client.get).toHaveBeenCalledWith('/api/ai/models/status', {
params: { selected_model: 'sam3' },
});
});
});


@@ -1,8 +1,9 @@
 import axios, { AxiosError } from 'axios';
-import type { Project, Template } from '../store/useStore';
+import type { AiModelId, Frame, Mask, Project, Template } from '../store/useStore';
+import { API_BASE_URL } from './config';
 const apiClient = axios.create({
-baseURL: 'http://192.168.3.11:8000',
+baseURL: API_BASE_URL,
 headers: {
 'Content-Type': 'application/json',
 },
@@ -40,37 +41,20 @@ export async function login(username: string, password: string): Promise<{ token
 }
 // Projects
-export async function getProjects(): Promise<Project[]> {
-const response = await apiClient.get('/api/projects');
-return response.data.map((p: any) => ({
-id: String(p.id),
-name: p.name,
-description: p.description,
-status: p.status,
-frames: p.frame_count ?? 0,
-fps: p.original_fps ? `${Math.round(p.original_fps)}FPS` : '30FPS',
-thumbnail_url: p.thumbnail_url,
-video_path: p.video_path,
-source_type: p.source_type,
-original_fps: p.original_fps,
-parse_fps: p.parse_fps,
-createdAt: p.created_at,
-updatedAt: p.updated_at,
-}));
+function normalizeProjectStatus(status?: string): Project['status'] {
+const value = (status || 'pending').toLowerCase();
+if (value === 'ready') return 'ready';
+if (value === 'parsing' || value === 'queued' || value === 'running') return 'parsing';
+if (value === 'error' || value === 'failed') return 'error';
+return 'pending';
 }
-export async function createProject(payload: {
-name: string;
-description?: string;
-parse_fps?: number;
-}): Promise<Project> {
-const response = await apiClient.post('/api/projects', payload);
-const p = response.data;
+function mapProject(p: any): Project {
 return {
 id: String(p.id),
 name: p.name,
 description: p.description,
-status: p.status,
+status: normalizeProjectStatus(p.status),
 frames: p.frame_count ?? 0,
 fps: p.original_fps ? `${Math.round(p.original_fps)}FPS` : '30FPS',
 thumbnail_url: p.thumbnail_url,
@@ -83,9 +67,23 @@ export async function createProject(payload: {
 };
 }
+export async function getProjects(): Promise<Project[]> {
+const response = await apiClient.get('/api/projects');
+return response.data.map(mapProject);
+}
+export async function createProject(payload: {
+name: string;
+description?: string;
+parse_fps?: number;
+}): Promise<Project> {
+const response = await apiClient.post('/api/projects', payload);
+return mapProject(response.data);
+}
 export async function updateProject(id: string, payload: Partial<Project>): Promise<Project> {
-const response = await apiClient.put(`/api/projects/${id}`, payload);
-return response.data;
+const response = await apiClient.patch(`/api/projects/${id}`, payload);
+return mapProject(response.data);
 }
 export async function deleteProject(id: string): Promise<void> {
@@ -170,26 +168,46 @@ export async function uploadDicomBatch(files: File[], projectId?: string): Promi
 return response.data;
 }
-export async function parseMedia(projectId: string): Promise<{
-project_id: number;
-frames_extracted: number;
-status: string;
-message: string;
-}> {
+export interface ProcessingTask {
+id: number;
+task_type: string;
+status: 'queued' | 'running' | 'success' | 'failed' | string;
+progress: number;
+message?: string | null;
+project_id?: number | null;
+celery_task_id?: string | null;
+payload?: Record<string, unknown> | null;
+result?: Record<string, unknown> | null;
+error?: string | null;
+created_at: string;
+started_at?: string | null;
+finished_at?: string | null;
+updated_at: string;
+}
+export async function parseMedia(projectId: string): Promise<ProcessingTask> {
 const response = await apiClient.post('/api/media/parse', null, {
 params: { project_id: projectId },
 });
 return response.data;
 }
-// AI Prediction
-export async function predictMask(payload: {
-imageUrl: string;
+export async function getTask(taskId: string | number): Promise<ProcessingTask> {
+const response = await apiClient.get(`/api/tasks/${taskId}`);
+return response.data;
+}
+interface PredictMaskPayload {
+imageId: string;
+imageWidth: number;
+imageHeight: number;
+model?: AiModelId;
 points?: { x: number; y: number; type: 'pos' | 'neg' }[];
 box?: { x1: number; y1: number; x2: number; y2: number };
 text?: string;
-modelSize?: string;
-}): Promise<{
+}
+interface PredictMaskResult {
 masks: Array<{
 id: string;
 pathData: string;
@@ -200,14 +218,319 @@ export async function predictMask(payload: {
 area: number;
 confidence: number;
 }>;
-}> {
-const response = await apiClient.post('/api/ai/predict', payload);
+}
export interface AiModelStatus {
id: AiModelId;
label: string;
available: boolean;
loaded: boolean;
device: string;
supports: string[];
message: string;
package_available: boolean;
checkpoint_exists: boolean;
checkpoint_path?: string | null;
python_ok: boolean;
torch_ok: boolean;
cuda_required: boolean;
}
export interface AiRuntimeStatus {
selected_model: AiModelId;
gpu: {
available: boolean;
device: string;
name?: string | null;
torch_available: boolean;
torch_version?: string | null;
cuda_version?: string | null;
};
models: AiModelStatus[];
}
export interface SavedAnnotation {
id: number;
project_id: number;
frame_id: number | null;
template_id: number | null;
mask_data: {
polygons?: number[][][];
label?: string;
color?: string;
class?: {
id?: string;
name?: string;
color?: string;
zIndex?: number;
category?: string;
};
} | null;
points: number[][] | null;
bbox: number[] | null;
created_at: string;
updated_at: string;
}
export interface SaveAnnotationPayload {
project_id: number;
frame_id?: number;
template_id?: number;
mask_data?: {
polygons: number[][][];
label?: string;
color?: string;
class?: {
id?: string;
name?: string;
color?: string;
zIndex?: number;
category?: string;
};
};
points?: number[][];
bbox?: number[];
}
export type UpdateAnnotationPayload = Omit<SaveAnnotationPayload, 'project_id' | 'frame_id'>;
export interface DashboardTask {
id: string;
task_id?: number;
project_id: number;
name: string;
progress: number;
status: string;
frame_count: number;
updated_at: string | null;
}
export interface DashboardActivity {
id: string;
kind: 'project' | 'annotation' | 'template' | string;
time: string | null;
message: string;
project: string;
}
export interface DashboardOverview {
summary: {
project_count: number;
parsing_task_count: number;
annotation_count: number;
frame_count: number;
template_count: number;
system_load_percent: number;
};
tasks: DashboardTask[];
activity: DashboardActivity[];
}
function clamp01(value: number): number {
return Math.min(Math.max(value, 0), 1);
}
function normalizePoint(point: { x: number; y: number }, width: number, height: number): [number, number] {
return [
clamp01(point.x / Math.max(width, 1)),
clamp01(point.y / Math.max(height, 1)),
];
}
function polygonToPath(points: number[][], width: number, height: number): string {
if (points.length === 0) return '';
return points
.map(([x, y], index) => {
const px = x * width;
const py = y * height;
return `${index === 0 ? 'M' : 'L'} ${px} ${py}`;
})
.join(' ')
.concat(' Z');
}
function polygonToBbox(points: number[][], width: number, height: number): [number, number, number, number] {
const xs = points.map(([x]) => x * width);
const ys = points.map(([, y]) => y * height);
const minX = Math.min(...xs);
const minY = Math.min(...ys);
const maxX = Math.max(...xs);
const maxY = Math.max(...ys);
return [minX, minY, maxX - minX, maxY - minY];
}
function pixelSegmentationToNormalizedPolygons(
segmentation: number[][] | undefined,
width: number,
height: number,
): number[][][] {
if (!segmentation) return [];
return segmentation
.map((poly) => {
const points: number[][] = [];
for (let i = 0; i < poly.length - 1; i += 2) {
points.push([
clamp01(poly[i] / Math.max(width, 1)),
clamp01(poly[i + 1] / Math.max(height, 1)),
]);
}
return points;
})
.filter((points) => points.length > 0);
}
export function buildAnnotationPayload(
projectId: string,
mask: Mask,
frame: Frame,
templateId?: string | null,
): SaveAnnotationPayload | null {
const polygons = pixelSegmentationToNormalizedPolygons(mask.segmentation, frame.width, frame.height);
if (polygons.length === 0) return null;
const effectiveTemplateId = mask.templateId || templateId || undefined;
const classMetadata = mask.classId || mask.className || mask.classZIndex !== undefined
? {
id: mask.classId,
name: mask.className || mask.label,
color: mask.color,
zIndex: mask.classZIndex,
}
: undefined;
return {
project_id: Number(projectId),
frame_id: Number(frame.id),
template_id: effectiveTemplateId ? Number(effectiveTemplateId) : undefined,
mask_data: {
polygons,
label: mask.label,
color: mask.color,
...(classMetadata ? { class: classMetadata } : {}),
},
bbox: mask.bbox
? [
clamp01(mask.bbox[0] / Math.max(frame.width, 1)),
clamp01(mask.bbox[1] / Math.max(frame.height, 1)),
clamp01(mask.bbox[2] / Math.max(frame.width, 1)),
clamp01(mask.bbox[3] / Math.max(frame.height, 1)),
]
: undefined,
};
}
export function annotationToMask(annotation: SavedAnnotation, frame: Frame): Mask | null {
const polygons = annotation.mask_data?.polygons || [];
const firstPolygon = polygons[0];
if (!firstPolygon || firstPolygon.length === 0) return null;
const bbox = polygonToBbox(firstPolygon, frame.width, frame.height);
const classMetadata = annotation.mask_data?.class;
return {
id: `annotation-${annotation.id}`,
annotationId: String(annotation.id),
frameId: String(annotation.frame_id),
templateId: annotation.template_id ? String(annotation.template_id) : undefined,
classId: classMetadata?.id,
className: classMetadata?.name,
classZIndex: classMetadata?.zIndex,
saveStatus: 'saved',
saved: true,
pathData: polygonToPath(firstPolygon, frame.width, frame.height),
label: classMetadata?.name || annotation.mask_data?.label || `Annotation ${annotation.id}`,
color: classMetadata?.color || annotation.mask_data?.color || '#06b6d4',
segmentation: polygons.map((polygon) => polygon.flatMap(([x, y]) => [x * frame.width, y * frame.height])),
bbox,
area: bbox[2] * bbox[3],
};
}
export async function predictMask(payload: PredictMaskPayload): Promise<PredictMaskResult> {
let prompt_type: 'point' | 'box' | 'semantic';
let prompt_data: unknown;
if (payload.box) {
prompt_type = 'box';
prompt_data = [
clamp01(payload.box.x1 / Math.max(payload.imageWidth, 1)),
clamp01(payload.box.y1 / Math.max(payload.imageHeight, 1)),
clamp01(payload.box.x2 / Math.max(payload.imageWidth, 1)),
clamp01(payload.box.y2 / Math.max(payload.imageHeight, 1)),
];
} else if (payload.points && payload.points.length > 0) {
prompt_type = 'point';
prompt_data = {
points: payload.points.map((point) => normalizePoint(point, payload.imageWidth, payload.imageHeight)),
labels: payload.points.map((point) => (point.type === 'neg' ? 0 : 1)),
};
} else {
prompt_type = 'semantic';
prompt_data = payload.text?.trim() || '';
}
const response = await apiClient.post('/api/ai/predict', {
image_id: Number(payload.imageId),
prompt_type,
prompt_data,
model: payload.model || 'sam2',
});
const polygons: number[][][] = response.data.polygons || [];
const scores: number[] = response.data.scores || [];
return {
masks: polygons.map((polygon, index) => {
const bbox = polygonToBbox(polygon, payload.imageWidth, payload.imageHeight);
return {
id: `mask-${payload.imageId}-${Date.now()}-${index}`,
pathData: polygonToPath(polygon, payload.imageWidth, payload.imageHeight),
label: prompt_type === 'semantic' ? (payload.text?.trim() || 'AI Mask') : 'AI Mask',
color: '#06b6d4',
segmentation: [polygon.flatMap(([x, y]) => [x * payload.imageWidth, y * payload.imageHeight])],
bbox,
area: bbox[2] * bbox[3],
confidence: scores[index] ?? 0,
};
}),
};
}
export async function getAiModelStatus(selectedModel?: AiModelId): Promise<AiRuntimeStatus> {
const response = await apiClient.get('/api/ai/models/status', {
params: selectedModel ? { selected_model: selectedModel } : undefined,
});
return response.data;
}
export async function getProjectAnnotations(projectId: string, frameId?: string): Promise<SavedAnnotation[]> {
const response = await apiClient.get('/api/ai/annotations', {
params: {
project_id: Number(projectId),
...(frameId ? { frame_id: Number(frameId) } : {}),
},
});
return response.data;
}
export async function saveAnnotation(payload: SaveAnnotationPayload): Promise<SavedAnnotation> {
const response = await apiClient.post('/api/ai/annotate', payload);
return response.data;
}
export async function updateAnnotation(annotationId: string, payload: UpdateAnnotationPayload): Promise<SavedAnnotation> {
const response = await apiClient.patch(`/api/ai/annotations/${annotationId}`, payload);
return response.data;
}
export async function deleteAnnotation(annotationId: string): Promise<void> {
await apiClient.delete(`/api/ai/annotations/${annotationId}`);
}
export async function getDashboardOverview(): Promise<DashboardOverview> {
const response = await apiClient.get('/api/dashboard/overview');
 return response.data;
 }
 // Export
 export async function exportCoco(projectId: string): Promise<Blob> {
-const response = await apiClient.get(`/api/export/coco/${projectId}`, {
+const response = await apiClient.get(`/api/export/${projectId}/coco`, {
 responseType: 'blob',
 });
 return response.data;
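
The API layer above keeps pixel coordinates on the frontend and normalized `[0, 1]` coordinates on the wire: masks are divided by the frame size before saving, and backend polygons are multiplied back before being rendered as SVG paths. The sketch below isolates that conversion, assuming the same flat `[x0, y0, x1, y1, ...]` segmentation layout the tests use. `toNormalizedPolygon` and `toSvgPath` are illustrative names, not exports of the module.

```typescript
type Point = [number, number];

// Clamp to the [0, 1] range used for normalized coordinates.
function clamp01(value: number): number {
  return Math.min(Math.max(value, 0), 1);
}

// Flat pixel list [x0, y0, x1, y1, ...] -> normalized [[x, y], ...].
// Math.max(size, 1) guards against frames whose width/height is 0.
function toNormalizedPolygon(flat: number[], width: number, height: number): Point[] {
  const points: Point[] = [];
  for (let i = 0; i < flat.length - 1; i += 2) {
    points.push([
      clamp01(flat[i] / Math.max(width, 1)),
      clamp01(flat[i + 1] / Math.max(height, 1)),
    ]);
  }
  return points;
}

// Normalized polygon -> closed SVG path in pixel space.
function toSvgPath(points: Point[], width: number, height: number): string {
  if (points.length === 0) return '';
  const segments = points.map(
    ([x, y], index) => `${index === 0 ? 'M' : 'L'} ${x * width} ${y * height}`,
  );
  return `${segments.join(' ')} Z`;
}
```

Because the conversion goes through floating-point division and multiplication, a pixel value is not guaranteed to survive a round trip bit-for-bit for arbitrary frame sizes; code that compares paths may want to round first.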

src/lib/config.test.ts (new file, 38 lines)

@@ -0,0 +1,38 @@
import { afterEach, describe, expect, it, vi } from 'vitest';
describe('frontend runtime config', () => {
afterEach(() => {
vi.unstubAllEnvs();
vi.resetModules();
});
it('prefers explicit VITE_API_BASE_URL and trims trailing slashes', async () => {
vi.stubEnv('VITE_API_BASE_URL', 'http://api.example.test:8000///');
const config = await import('./config');
expect(config.API_BASE_URL).toBe('http://api.example.test:8000');
});
it('infers the API host from the current browser hostname', async () => {
const config = await import('./config');
expect(config.API_BASE_URL).toBe('http://seg.local:8000');
});
it('derives websocket URL from API URL unless explicitly configured', async () => {
vi.stubEnv('VITE_API_BASE_URL', 'https://seg.example.test');
const config = await import('./config');
expect(config.WS_PROGRESS_URL).toBe('wss://seg.example.test/ws/progress');
});
it('prefers explicit VITE_WS_PROGRESS_URL', async () => {
vi.stubEnv('VITE_WS_PROGRESS_URL', 'ws://custom/ws/progress');
const config = await import('./config');
expect(config.WS_PROGRESS_URL).toBe('ws://custom/ws/progress');
});
});

src/lib/config.ts Normal file

@@ -0,0 +1,29 @@
const DEFAULT_API_BASE_URL = 'http://192.168.3.11:8000';
function trimTrailingSlash(value: string): string {
return value.replace(/\/+$/, '');
}
function inferApiBaseUrl(): string {
const envUrl = import.meta.env.VITE_API_BASE_URL;
if (envUrl) return trimTrailingSlash(envUrl);
if (typeof window !== 'undefined' && window.location.hostname) {
return `${window.location.protocol}//${window.location.hostname}:8000`;
}
return DEFAULT_API_BASE_URL;
}
export const API_BASE_URL = inferApiBaseUrl();
function inferWsProgressUrl(): string {
const envUrl = import.meta.env.VITE_WS_PROGRESS_URL;
if (envUrl) return envUrl;
const url = new URL('/ws/progress', API_BASE_URL);
url.protocol = url.protocol === 'https:' ? 'wss:' : 'ws:';
return url.toString();
}
export const WS_PROGRESS_URL = inferWsProgressUrl();
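
The derivation in `config.ts` reads `import.meta.env` and `window.location` at module load, which is why the tests above re-import the module after stubbing. A minimal sketch of the same precedence rules as pure functions follows, so they can be exercised without Vite or a browser. The names `RuntimeEnv`, `BrowserLocation`, `resolveApiBaseUrl`, and `resolveWsProgressUrl` are hypothetical and do not appear in this commit; only the precedence order and the fallback address are taken from the diff.

```typescript
// Hypothetical pure variant of the logic in src/lib/config.ts.
// Inputs are passed in explicitly instead of being read from globals.
interface RuntimeEnv {
  VITE_API_BASE_URL?: string;
  VITE_WS_PROGRESS_URL?: string;
}

interface BrowserLocation {
  protocol: string; // e.g. "http:"
  hostname: string; // e.g. "seg.local"
}

const FALLBACK_API_BASE_URL = 'http://192.168.3.11:8000';

function trimTrailingSlash(value: string): string {
  return value.replace(/\/+$/, '');
}

// Precedence: explicit env value, then current page host on port 8000, then the fallback.
function resolveApiBaseUrl(env: RuntimeEnv, location?: BrowserLocation): string {
  if (env.VITE_API_BASE_URL) return trimTrailingSlash(env.VITE_API_BASE_URL);
  if (location && location.hostname) {
    return `${location.protocol}//${location.hostname}:8000`;
  }
  return FALLBACK_API_BASE_URL;
}

// Precedence: explicit env value, otherwise swap http(s) for ws(s) on the API origin.
function resolveWsProgressUrl(env: RuntimeEnv, apiBaseUrl: string): string {
  if (env.VITE_WS_PROGRESS_URL) return env.VITE_WS_PROGRESS_URL;
  const url = new URL('/ws/progress', apiBaseUrl);
  url.protocol = url.protocol === 'https:' ? 'wss:' : 'ws:';
  return url.toString();
}
```

Passing the environment and location as arguments is only one possible design; it trades the convenience of module-level constants for tests that need no `vi.stubEnv` or `vi.resetModules`.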


@@ -0,0 +1,15 @@
import type { Template, TemplateClass } from '../store/useStore';
export function getActiveTemplate(templates: Template[], activeTemplateId: string | null): Template | null {
return templates.find((template) => template.id === activeTemplateId) || templates[0] || null;
}
export function getActiveClass(
templates: Template[],
activeTemplateId: string | null,
activeClassId: string | null,
): TemplateClass | null {
const template = getActiveTemplate(templates, activeTemplateId);
if (!template) return null;
return template.classes.find((templateClass) => templateClass.id === activeClassId) || null;
}

src/lib/websocket.test.ts Normal file

@@ -0,0 +1,46 @@
import { afterEach, describe, expect, it, vi } from 'vitest';
describe('progress websocket client', () => {
afterEach(() => {
vi.restoreAllMocks();
vi.resetModules();
vi.unstubAllGlobals();
});
it('connects using the configured URL and reports open state', async () => {
const instances: any[] = [];
class FakeWebSocket {
static CONNECTING = 0;
static OPEN = 1;
readyState = FakeWebSocket.OPEN;
onopen?: () => void;
onmessage?: (event: MessageEvent) => void;
onclose?: () => void;
onerror?: () => void;
constructor(public url: string) {
instances.push(this);
}
close = vi.fn();
}
vi.stubGlobal('WebSocket', FakeWebSocket);
const { progressWS } = await import('./websocket');
progressWS.connect();
expect(instances[0].url).toContain('/ws/progress');
expect(progressWS.isConnected()).toBe(true);
});
it('subscribes and unsubscribes progress callbacks', async () => {
const { progressWS } = await import('./websocket');
const callback = vi.fn();
const unsubscribe = progressWS.onProgress(callback);
(progressWS as any).callbacks.forEach((cb: any) => cb({ type: 'status', message: 'ok' }));
unsubscribe();
(progressWS as any).callbacks.forEach((cb: any) => cb({ type: 'status', message: 'again' }));
expect(callback).toHaveBeenCalledTimes(1);
expect(callback).toHaveBeenCalledWith({ type: 'status', message: 'ok' });
});
});
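
The second test relies on `onProgress` returning an unsubscribe function and on a `callbacks` collection inside the client. The body of `ProgressWebSocket` is not part of this hunk, so the following is only a sketch of that subscribe/unsubscribe pattern under those two assumptions. `ListenerRegistry`, `subscribe`, and `emit` are hypothetical names, not the project's API.

```typescript
// Hypothetical listener registry illustrating the pattern the test exercises.
type Listener<T> = (message: T) => void;

class ListenerRegistry<T> {
  private listeners = new Set<Listener<T>>();

  // Registers a listener and returns a function that removes it again.
  subscribe(listener: Listener<T>): () => void {
    this.listeners.add(listener);
    return () => {
      this.listeners.delete(listener);
    };
  }

  // Delivers a message to every listener registered at the time of the call.
  // Iterating over a copy lets a listener unsubscribe during dispatch.
  emit(message: T): void {
    Array.from(this.listeners).forEach((listener) => listener(message));
  }
}
```

With a public `emit` like this, a test could drive messages through the registry instead of reaching into a private field with `as any`.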


@@ -1,12 +1,18 @@
import { WS_PROGRESS_URL } from './config';
type ProgressCallback = (data: ProgressMessage) => void;
interface ProgressMessage {
type: 'progress' | 'status' | 'error' | 'complete';
taskId?: string;
task_id?: number;
project_id?: number;
projectName?: string;
filename?: string;
progress?: number;
status?: string;
message?: string;
error?: string;
timestamp?: string;
}
@@ -21,7 +27,7 @@ class ProgressWebSocket {
private shouldCloseAfterOpen = false;
private currentInterval = 3000;
-constructor(url = 'ws://192.168.3.11:8000/ws/progress') {
+constructor(url = WS_PROGRESS_URL) {
this.url = url;
}


@@ -0,0 +1,56 @@
import { beforeEach, describe, expect, it } from 'vitest';
import { resetStore } from '../test/storeTestUtils';
import { useStore } from './useStore';
describe('useStore', () => {
beforeEach(() => {
resetStore();
});
it('stores and clears auth state with localStorage', () => {
useStore.getState().login('token-1');
expect(useStore.getState().isAuthenticated).toBe(true);
expect(useStore.getState().token).toBe('token-1');
expect(localStorage.getItem('token')).toBe('token-1');
useStore.getState().logout();
expect(useStore.getState().isAuthenticated).toBe(false);
expect(useStore.getState().projects).toEqual([]);
expect(useStore.getState().frames).toEqual([]);
expect(localStorage.getItem('token')).toBeNull();
});
it('manages projects, frames, masks, annotations and templates', () => {
const project = { id: '1', name: 'Project', status: 'ready' as const };
useStore.getState().addProject(project);
useStore.getState().updateProject({ ...project, name: 'Updated' });
useStore.getState().setCurrentProject(project);
useStore.getState().setFrames([{ id: 'f1', projectId: '1', index: 0, url: '/f1.jpg', width: 640, height: 360 }]);
useStore.getState().setCurrentFrame(0);
useStore.getState().addMask({ id: 'm1', frameId: 'f1', pathData: 'M 0 0 Z', label: 'mask', color: '#fff' });
useStore.getState().updateMask('m1', { label: 'updated mask', saveStatus: 'dirty' });
useStore.getState().addAnnotation({ id: 'a1', frameId: 'f1', type: 'mask', points: [], label: 'ann', color: '#fff' });
useStore.getState().addTemplate({ id: 't1', name: 'Template', classes: [], rules: [] });
useStore.getState().updateTemplate({ id: 't1', name: 'Template 2', classes: [], rules: [] });
useStore.getState().setActiveClass({ id: 'c1', name: '胆囊', color: '#ff0000', zIndex: 10 });
expect(useStore.getState().projects[0].name).toBe('Updated');
expect(useStore.getState().currentProject?.id).toBe('1');
expect(useStore.getState().frames).toHaveLength(1);
expect(useStore.getState().currentFrameIndex).toBe(0);
expect(useStore.getState().masks[0]).toEqual(expect.objectContaining({ label: 'updated mask', saveStatus: 'dirty' }));
expect(useStore.getState().annotations).toHaveLength(1);
expect(useStore.getState().templates[0].name).toBe('Template 2');
expect(useStore.getState().activeClassId).toBe('c1');
useStore.getState().removeAnnotation('a1');
useStore.getState().clearMasks();
useStore.getState().removeTemplate('t1');
expect(useStore.getState().annotations).toEqual([]);
expect(useStore.getState().masks).toEqual([]);
expect(useStore.getState().templates).toEqual([]);
});
});


@@ -4,7 +4,7 @@ export interface Project {
id: string;
name: string;
description?: string;
-status: 'Ready' | 'Parsing' | 'Error';
+status: 'pending' | 'parsing' | 'ready' | 'error';
fps?: string;
frames?: number;
thumbnail?: string;
@@ -17,6 +17,8 @@ export interface Project {
updatedAt?: string;
}
export type AiModelId = 'sam2' | 'sam3';
export interface Frame {
id: string;
projectId: string;
@@ -42,6 +44,13 @@ export interface Annotation {
export interface Mask {
id: string;
frameId: string;
annotationId?: string;
templateId?: string;
classId?: string;
className?: string;
classZIndex?: number;
saveStatus?: 'draft' | 'saved' | 'dirty' | 'saving' | 'error';
saved?: boolean;
pathData: string;
label: string;
color: string;
@@ -96,24 +105,32 @@ export interface AppState {
// Workspace
activeModule: string;
activeTool: string;
aiModel: AiModelId;
frames: Frame[];
currentFrameIndex: number;
annotations: Annotation[];
masks: Mask[];
setActiveModule: (module: string) => void;
setActiveTool: (tool: string) => void;
setAiModel: (model: AiModelId) => void;
setFrames: (frames: Frame[]) => void;
setCurrentFrame: (index: number) => void;
addAnnotation: (annotation: Annotation) => void;
addMask: (mask: Mask) => void;
updateMask: (id: string, updates: Partial<Mask>) => void;
setMasks: (masks: Mask[]) => void;
clearMasks: () => void;
removeAnnotation: (id: string) => void;
// Templates
templates: Template[];
activeTemplateId: string | null;
activeClassId: string | null;
activeClass: TemplateClass | null;
setTemplates: (templates: Template[]) => void;
setActiveTemplateId: (id: string | null) => void;
setActiveClassId: (id: string | null) => void;
setActiveClass: (templateClass: TemplateClass | null) => void;
addTemplate: (template: Template) => void;
updateTemplate: (template: Template) => void;
removeTemplate: (id: string) => void;
@@ -144,6 +161,9 @@ export const useStore = create<AppState>((set) => ({
frames: [],
annotations: [],
masks: [],
activeTemplateId: null,
activeClassId: null,
activeClass: null,
});
},
@@ -162,18 +182,25 @@ export const useStore = create<AppState>((set) => ({
// Workspace
activeModule: 'workspace',
activeTool: 'move',
aiModel: 'sam2',
frames: [],
currentFrameIndex: 0,
annotations: [],
masks: [],
setActiveModule: (activeModule: string) => set({ activeModule }),
setActiveTool: (activeTool: string) => set({ activeTool }),
setAiModel: (aiModel: AiModelId) => set({ aiModel }),
setFrames: (frames: Frame[]) => set({ frames }),
setCurrentFrame: (currentFrameIndex: number) => set({ currentFrameIndex }),
addAnnotation: (annotation: Annotation) =>
set((state) => ({ annotations: [...state.annotations, annotation] })),
addMask: (mask: Mask) =>
set((state) => ({ masks: [...state.masks, mask] })),
updateMask: (id: string, updates: Partial<Mask>) =>
set((state) => ({
masks: state.masks.map((mask) => (mask.id === id ? { ...mask, ...updates } : mask)),
})),
setMasks: (masks: Mask[]) => set({ masks }),
clearMasks: () => set({ masks: [] }),
removeAnnotation: (id: string) =>
set((state) => ({
@@ -183,8 +210,15 @@ export const useStore = create<AppState>((set) => ({
// Templates
templates: [],
activeTemplateId: null,
activeClassId: null,
activeClass: null,
setTemplates: (templates: Template[]) => set({ templates }),
setActiveTemplateId: (activeTemplateId: string | null) => set({ activeTemplateId }),
setActiveClassId: (activeClassId: string | null) => set({ activeClassId }),
setActiveClass: (activeClass: TemplateClass | null) => set({
activeClass,
activeClassId: activeClass?.id || null,
}),
addTemplate: (template: Template) =>
set((state) => ({ templates: [...state.templates, template] })),
updateTemplate: (template: Template) =>

src/test/setup.tsx Normal file

@@ -0,0 +1,66 @@
import React from 'react';
import { afterEach, vi } from 'vitest';
import { cleanup } from '@testing-library/react';
import '@testing-library/jest-dom/vitest';
afterEach(() => {
cleanup();
localStorage.clear();
});
vi.stubGlobal('alert', vi.fn());
vi.stubGlobal('confirm', vi.fn(() => true));
URL.createObjectURL = vi.fn(() => 'blob:mock-url');
URL.revokeObjectURL = vi.fn();
HTMLAnchorElement.prototype.click = vi.fn();
function makeStageEvent(x = 120, y = 80) {
const stage = {
getPointerPosition: () => ({ x, y }),
getRelativePointerPosition: () => ({ x, y }),
scaleX: () => 1,
x: () => 0,
y: () => 0,
};
return {
evt: { preventDefault: vi.fn(), deltaY: -1 },
target: {
getStage: () => stage,
},
};
}
vi.mock('react-konva', () => ({
Stage: ({ children, onClick, onMouseDown, onMouseUp, onMouseMove, onWheel }: any) => (
<div
data-testid="konva-stage"
onClick={() => onClick?.(makeStageEvent())}
onMouseDown={() => onMouseDown?.(makeStageEvent())}
onMouseUp={() => onMouseUp?.(makeStageEvent(260, 200))}
onMouseMove={() => onMouseMove?.(makeStageEvent(180, 120))}
onWheel={() => onWheel?.(makeStageEvent())}
>
{children}
</div>
),
Layer: ({ children }: any) => <div data-testid="konva-layer">{children}</div>,
Group: ({ children }: any) => <div data-testid="konva-group">{children}</div>,
Image: ({ image }: any) => <img data-testid="konva-image" alt="" src={image?.src || ''} />,
Circle: (props: any) => <span data-testid="konva-circle" data-fill={props.fill} />,
Rect: (props: any) => <span data-testid="konva-rect" data-width={props.width} />,
Path: (props: any) => <span data-testid="konva-path" data-path={props.data} data-fill={props.fill} />,
}));
vi.mock('use-image', () => ({
default: (src: string) => [
{
src,
width: 640,
height: 360,
naturalWidth: 640,
naturalHeight: 360,
},
'loaded',
],
}));


@@ -0,0 +1,23 @@
import { useStore } from '../store/useStore';
export function resetStore() {
useStore.setState({
isAuthenticated: false,
token: null,
projects: [],
currentProject: null,
activeModule: 'workspace',
activeTool: 'move',
aiModel: 'sam2',
frames: [],
currentFrameIndex: 0,
annotations: [],
masks: [],
templates: [],
activeTemplateId: null,
activeClassId: null,
activeClass: null,
isLoading: false,
error: null,
});
}

src/vite-env.d.ts vendored Normal file

@@ -0,0 +1,6 @@
/// <reference types="vite/client" />
interface ImportMetaEnv {
readonly VITE_API_BASE_URL?: string;
readonly VITE_WS_PROGRESS_URL?: string;
}


@@ -12,7 +12,7 @@ echo " 语义分割系统全栈启动"
echo "========================================"
# 1. Check PostgreSQL
-echo "[1/5] 检查 PostgreSQL..."
+echo "[1/6] 检查 PostgreSQL..."
if ! pg_isready -q; then
echo "Wkmgc" | sudo -S systemctl start postgresql
sleep 1
@@ -20,7 +20,7 @@ fi
pg_isready && echo " ✓ PostgreSQL 就绪"
# 2. Check Redis
-echo "[2/5] 检查 Redis..."
+echo "[2/6] 检查 Redis..."
if ! redis-cli ping > /dev/null 2>&1; then
echo "Wkmgc" | sudo -S systemctl start redis-server
sleep 1
@@ -28,7 +28,7 @@ fi
redis-cli ping && echo " ✓ Redis 就绪"
# 3. Check MinIO
-echo "[3/5] 检查 MinIO..."
+echo "[3/6] 检查 MinIO..."
if ! curl -s http://localhost:9000/minio/health/live > /dev/null; then
nohup minio server /home/wkmgc/minio_data --console-address :9001 > /tmp/minio.log 2>&1 &
sleep 3
@@ -36,7 +36,7 @@ fi
curl -s http://localhost:9000/minio/health/live > /dev/null && echo " ✓ MinIO 就绪 (http://localhost:9001)"
# 4. Start the FastAPI backend
-echo "[4/5] 启动 FastAPI 后端..."
+echo "[4/6] 启动 FastAPI 后端..."
source /home/wkmgc/miniconda3/etc/profile.d/conda.sh
conda activate "$CONDA_ENV"
cd "$PROJECT_DIR/backend"
@@ -44,8 +44,15 @@ nohup uvicorn main:app --host 0.0.0.0 --port 8000 --reload > /tmp/fastapi.log 2>
sleep 2
echo " ✓ FastAPI 已启动 (http://localhost:8000/docs)"
-# 5. Start the frontend
-echo "[5/5] 启动前端..."
+# 5. Start the Celery worker
+echo "[5/6] 启动 Celery Worker..."
cd "$PROJECT_DIR/backend"
nohup celery -A celery_app:celery_app worker --loglevel=info --concurrency=1 > /tmp/celery.log 2>&1 &
sleep 2
echo " ✓ Celery Worker 已启动"
# 6. Start the frontend
echo "[6/6] 启动前端..."
cd "$PROJECT_DIR"
nohup npm start > /tmp/frontend.log 2>&1 &
sleep 2
@@ -61,6 +68,7 @@ echo "MinIO: http://localhost:9001"
echo ""
echo "日志文件:"
echo " FastAPI: /tmp/fastapi.log"
echo " Celery: /tmp/celery.log"
echo " 前端: /tmp/frontend.log"
echo " MinIO: /tmp/minio.log"
echo "========================================"

vitest.config.ts Normal file

@@ -0,0 +1,24 @@
import react from '@vitejs/plugin-react';
import path from 'path';
import { defineConfig } from 'vitest/config';
export default defineConfig({
plugins: [react()],
resolve: {
alias: {
'@': path.resolve(__dirname, '.'),
},
},
test: {
environment: 'jsdom',
environmentOptions: {
jsdom: {
url: 'http://seg.local:3000',
},
},
globals: true,
setupFiles: './src/test/setup.tsx',
include: ['src/**/*.{test,spec}.{ts,tsx}'],
css: false,
},
});