feat: 打通全栈标注闭环、异步拆帧与模型状态
后端能力: - 新增 Celery app、worker task、ProcessingTask 模型、/api/tasks 查询接口和 media_task_runner,将 /api/media/parse 改为创建后台任务并由 worker 执行 FFmpeg/OpenCV/pydicom 拆帧。 - 新增 Redis 进度事件模块和 FastAPI Redis pub/sub 订阅,将 worker 任务进度广播到 /ws/progress;Dashboard 后端概览接口改为聚合 projects/frames/annotations/templates/processing_tasks。 - 统一项目状态为 pending/parsing/ready/error,新增共享 status 常量,并让前端兼容归一化旧状态值。 - 扩展 AI 后端:新增 SAM registry、SAM2 真实运行状态、SAM3 状态检测与文本语义推理适配入口,以及 /api/ai/models/status GPU/模型状态接口。 - 补齐标注保存/更新/删除、COCO/PNG mask 导出相关后端契约和模板 mapping_rules 打包/解包行为。 前端能力: - 新增运行时 API/WS 地址推导配置,前端 API 封装对齐 FastAPI 路由、字段映射、任务轮询、标注归档、导出下载和 AI 预测响应转换。 - Dashboard 改为读取 /api/dashboard/overview,并订阅 WebSocket progress/complete/error/status 更新解析队列和实时流转记录。 - 项目库导入视频/DICOM 后创建项目、上传媒体、触发异步解析并刷新真实项目列表。 - 工作区加载真实帧、无帧时触发解析任务、回显已保存标注、保存未归档 mask、更新 dirty mask、清空当前帧后端标注、导出 COCO JSON。 - Canvas 支持当前帧点/框提示调用后端 AI、渲染推理/已保存 mask、应用模板分类并维护保存状态计数;时间轴按项目 fps 播放。 - AI 页面新增 SAM2/SAM3 模型选择,预测请求携带 model;侧边栏和工作区新增真实 GPU/SAM 状态徽标。 - 模板库和本体面板接入真实模板 CRUD、分类编辑、拖拽排序、JSON 导入、默认腹腔镜分类和本地自定义分类选择。 测试与文档: - 新增 Vitest 配置、前端测试 setup、API/config/websocket/store/组件测试,覆盖登录、项目库、Dashboard、Canvas、工作区、模型状态、时间轴、本体和模板库。 - 新增 pytest 后端测试夹具和 auth/projects/templates/media/AI/export/dashboard/tasks/progress 测试,使用 SQLite、fake MinIO、fake SAM registry 和 Redis monkeypatch 隔离外部服务。 - 新增 doc/ 文档结构,冻结当前需求、设计、接口契约、测试计划、前端逐元素审计、实现地图和后续实施计划,并同步更新 README 与 AGENTS。 验证: - conda run -n seg_server pytest backend/tests:27 passed。 - npm run test:run:54 passed。 - npm run lint、npm run build、compileall、git diff --check 均通过;Vite 仅提示大 chunk 警告。
This commit is contained in:
13
.env.example
13
.env.example
@@ -7,3 +7,16 @@ GEMINI_API_KEY="MY_GEMINI_API_KEY"
|
||||
# AI Studio automatically injects this at runtime with the Cloud Run service URL.
|
||||
# Used for self-referential links, OAuth callbacks, and API endpoints.
|
||||
APP_URL="MY_APP_URL"
|
||||
|
||||
# Frontend API address. If unset, the app infers http://<current-browser-host>:8000.
|
||||
VITE_API_BASE_URL="http://192.168.3.11:8000"
|
||||
|
||||
# Optional WebSocket override. If unset, it is derived from VITE_API_BASE_URL.
|
||||
VITE_WS_PROGRESS_URL="ws://192.168.3.11:8000/ws/progress"
|
||||
|
||||
# Backend SAM runtime defaults. SAM 3 additionally requires the official sam3
|
||||
# package, Python 3.12+, PyTorch 2.7+, and a CUDA-capable GPU per Meta's repo.
|
||||
sam_default_model="sam2"
|
||||
sam_model_path="/home/wkmgc/Desktop/Seg_Server/models/sam2_hiera_tiny.pt"
|
||||
sam_model_config="configs/sam2/sam2_hiera_t.yaml"
|
||||
sam3_model_version="sam3.1"
|
||||
|
||||
322
AGENTS.md
322
AGENTS.md
@@ -1,17 +1,21 @@
|
||||
# AGENTS.md — AI 编码助手项目指南
|
||||
|
||||
> 本文件面向 AI 编码助手。阅读者应假设对该项目一无所知。以下所有信息均基于项目实际内容,不做假设性推断。
|
||||
> 本文件面向 AI 编码助手。阅读者应假设对该项目一无所知。以下信息基于当前仓库实际文件、脚本和源码;不要把早期设计目标当作已实现事实。任何代码和功能修改都要落实到文档和测试上,如果生成git commit信息,要逐个列点把所有修改都列上,重要的、大的修改放前面,不重要的、小的修改列在后面。
|
||||
|
||||
---
|
||||
|
||||
## 项目概述
|
||||
|
||||
本项目是一个**语义分割系统**(Semantic Segmentation System)的 Web 前端应用,用于 AI 驱动的图像/视频分割与标注。它提供了一个深色主题(Dark Mode)的单页应用(SPA),包含项目管理、分割工作区、AI 智能分割引擎和模板库四大核心模块。
|
||||
本项目是一个**语义分割系统**(Semantic Segmentation System),当前形态是 React 前端 + FastAPI 后端的全栈 Web 应用,用于视频/DICOM 医学影像上传、服务器端拆帧、交互式 Canvas 标注、SAM 2/SAM 3 可选辅助分割、模板分类管理和标注导出。
|
||||
|
||||
- **项目名称**: `react-example`(package.json 中的 name)
|
||||
- **部署目标**: Google AI Studio(Cloud Run)
|
||||
- **AI Studio 应用链接**: https://ai.studio/apps/2707f0e1-d453-4594-a618-fba53cb937c4
|
||||
- **业务文档**: `语义分割系统构建方案.docx`(项目根目录,未解析内容)
|
||||
- **项目名称**: `react-example`(`package.json` 中的 `name`)
|
||||
- **前端入口**: `src/main.tsx` → `src/App.tsx`
|
||||
- **前端服务入口**: `server.ts`(Express + Vite 中间件 / 生产静态服务,并保留少量旧版 mock API)
|
||||
- **后端入口**: `backend/main.py`(FastAPI)
|
||||
- **默认前端地址**: `http://localhost:3000`
|
||||
- **默认后端地址**: `http://localhost:8000`
|
||||
- **前端 API 配置**: `src/lib/config.ts`,优先读取 `VITE_API_BASE_URL`,未配置时按当前浏览器 hostname 推导 `http://<host>:8000`
|
||||
- **业务文档**: `语义分割系统构建方案.docx`(项目根目录)
|
||||
|
||||
---
|
||||
|
||||
@@ -21,15 +25,22 @@
|
||||
|------|------|
|
||||
| 前端框架 | React 19 + TypeScript 5.8 |
|
||||
| 构建工具 | Vite 6 |
|
||||
| 样式方案 | TailwindCSS 4 + 自定义深色主题 |
|
||||
| 状态管理 | React `useState`(无外部状态库) |
|
||||
| 路由 | 无路由库,基于 React 状态切换模块 |
|
||||
| Canvas 渲染 | Konva + react-konva |
|
||||
| 前端样式 | TailwindCSS 4 + 自定义深色主题 |
|
||||
| 前端状态 | Zustand(`src/store/useStore.ts`) |
|
||||
| 前端请求 | Axios(`src/lib/api.ts`) |
|
||||
| 实时通信 | WebSocket 客户端(`src/lib/websocket.ts`) |
|
||||
| Canvas 渲染 | Konva + react-konva + use-image |
|
||||
| 图标库 | lucide-react |
|
||||
| 动画 | motion |
|
||||
| AI SDK | @google/genai(Gemini API) |
|
||||
| 后端/服务器 | Express 4(单文件 `server.ts`) |
|
||||
| 运行时 | Node.js,ES Modules(`"type": "module"`) |
|
||||
| 动画依赖 | motion(在 `package.json` 中声明) |
|
||||
| AI SDK 依赖 | `@google/genai`(在 `package.json` 中声明;当前业务源码未直接调用) |
|
||||
| 后端框架 | FastAPI + Uvicorn |
|
||||
| ORM / 数据库 | SQLAlchemy + PostgreSQL |
|
||||
| 缓存 / 队列 Broker | Redis |
|
||||
| 后台任务 | Celery worker |
|
||||
| 对象存储 | MinIO |
|
||||
| AI 推理 | SAM 2 / SAM 3 可选模型 + PyTorch;`GET /api/ai/models/status` 返回真实 GPU/模型状态 |
|
||||
| 视频 / 影像处理 | FFmpeg / OpenCV / pydicom |
|
||||
| 运行时 | Node.js ES Modules;Python 3.11 后端环境 |
|
||||
|
||||
---
|
||||
|
||||
@@ -37,159 +48,244 @@
|
||||
|
||||
```
|
||||
Seg_Server/
|
||||
├── server.ts # Express 服务端入口(开发服务器 + 生产静态文件服务)
|
||||
├── index.html # SPA HTML 入口
|
||||
├── vite.config.ts # Vite 构建配置
|
||||
├── tsconfig.json # TypeScript 配置
|
||||
├── package.json # 依赖与脚本
|
||||
├── .env.example # 环境变量模板
|
||||
├── metadata.json # AI Studio 元数据(目前为空)
|
||||
├── src/
|
||||
│ ├── main.tsx # React 应用挂载点(StrictMode)
|
||||
│ ├── App.tsx # 根组件:模块路由 + 登录鉴权
|
||||
│ ├── index.css # TailwindCSS 导入 + 自定义工具类
|
||||
│ ├── lib/
|
||||
│ │ └── utils.ts # `cn()` 工具函数(clsx + tailwind-merge)
|
||||
│ └── components/
|
||||
│ ├── auth/
|
||||
│ │ └── Login.tsx # 登录页
|
||||
│ ├── layout/
|
||||
│ │ └── Sidebar.tsx # 左侧导航栏(w-16)
|
||||
│ ├── dashboard/
|
||||
│ │ └── Dashboard.tsx # 总体概况仪表盘
|
||||
│ ├── projects/
|
||||
│ │ └── ProjectLibrary.tsx # 项目库列表
|
||||
│ ├── workspace/
|
||||
│ │ ├── VideoWorkspace.tsx # 核心分割工作区布局
|
||||
│ │ ├── CanvasArea.tsx # Konva 画布(缩放/平移/选点)
|
||||
│ │ ├── ToolsPalette.tsx # 左侧工具栏
|
||||
│ │ ├── OntologyInspector.tsx # 右侧本体/属性检查面板
|
||||
│ │ └── FrameTimeline.tsx # 底部时间轴
|
||||
│ ├── ai/
|
||||
│ │ └── AISegmentation.tsx # AI 智能分割引擎界面
|
||||
│ └── templates/
|
||||
│ └── TemplateRegistry.tsx # 模板库管理
|
||||
├── server.ts # Express + Vite 前端入口;保留 /api/login、/api/projects、/api/templates mock
|
||||
├── index.html # SPA HTML 入口
|
||||
├── vite.config.ts # Vite 配置;含 @/* 路径别名与 DISABLE_HMR 逻辑
|
||||
├── tsconfig.json # TypeScript 配置;@/* 映射到项目根目录
|
||||
├── package.json # npm 依赖与脚本
|
||||
├── .env.example # AI Studio/Gemini 前端环境变量模板
|
||||
├── metadata.json # AI Studio 元数据
|
||||
├── public/
|
||||
│ └── logo.png # Sidebar 使用的 /logo.png
|
||||
├── doc/ # 当前实现审计、接口契约和后续实施文档
|
||||
├── start_services.sh # 本地一键启动 PostgreSQL/Redis/MinIO/FastAPI/Celery/前端
|
||||
├── backend/ # FastAPI 后端
|
||||
│ ├── main.py # 应用入口、lifespan、CORS、路由注册、WebSocket
|
||||
│ ├── config.py # Pydantic Settings;读取 backend/.env
|
||||
│ ├── database.py # SQLAlchemy Engine / Session
|
||||
│ ├── models.py # Project/Frame/Template/Annotation/Mask/ProcessingTask ORM
|
||||
│ ├── schemas.py # Pydantic 请求/响应模型
|
||||
│ ├── minio_client.py # MinIO 上传、下载、预签名 URL
|
||||
│ ├── redis_client.py # Redis 连接封装
|
||||
│ ├── celery_app.py # Celery app 配置
|
||||
│ ├── worker_tasks.py # Celery 任务入口
|
||||
│ ├── download_sam2.py # SAM 2 权重下载脚本
|
||||
│ ├── requirements.txt # Python 依赖
|
||||
│ ├── routers/
|
||||
│ │ ├── auth.py # /api/auth/login
|
||||
│ │ ├── projects.py # /api/projects 与 /api/projects/{id}/frames
|
||||
│ │ ├── templates.py # /api/templates
|
||||
│ │ ├── media.py # /api/media/upload、/upload/dicom、/parse
|
||||
│ │ ├── ai.py # /api/ai/predict、/models/status、/auto、/annotate
|
||||
│ │ └── export.py # /api/export/{project_id}/coco、/masks
|
||||
│ └── services/
|
||||
│ ├── frame_parser.py # FFmpeg/OpenCV 拆帧、pydicom 读片、帧上传
|
||||
│ ├── sam2_engine.py # SAM 2 懒加载推理封装和 fallback
|
||||
│ ├── sam3_engine.py # SAM 3 状态检测与文本语义推理适配器
|
||||
│ └── sam_registry.py # SAM 模型选择、GPU 状态与推理分发
|
||||
└── src/ # React 前端
|
||||
├── main.tsx # React StrictMode 挂载
|
||||
├── App.tsx # 登录拦截 + 模块切换
|
||||
├── index.css # TailwindCSS 导入 + 全局样式
|
||||
├── store/useStore.ts # Zustand 全局状态
|
||||
├── lib/api.ts # Axios API 封装
|
||||
├── lib/websocket.ts # 解析进度 WebSocket 客户端
|
||||
├── lib/utils.ts # cn() 工具函数
|
||||
└── components/ # 扁平化组件目录
|
||||
├── Login.tsx
|
||||
├── Sidebar.tsx
|
||||
├── Dashboard.tsx
|
||||
├── ProjectLibrary.tsx
|
||||
├── VideoWorkspace.tsx
|
||||
├── CanvasArea.tsx
|
||||
├── ToolsPalette.tsx
|
||||
├── OntologyInspector.tsx
|
||||
├── FrameTimeline.tsx
|
||||
├── AISegmentation.tsx
|
||||
└── TemplateRegistry.tsx
|
||||
```
|
||||
|
||||
以下目录/文件通常是运行产物或本地数据,已在 `.gitignore` 中忽略:`node_modules/`、`dist/`、`models/`、`uploads/`、`frames/`、`Data_*/`、`*.mp4`、`*.dcm`、`*.7z`、`backend/.env`、日志文件等。
|
||||
|
||||
`doc/` 目录是当前项目的事实文档入口。修改功能前优先查看:
|
||||
|
||||
- `doc/03-frontend-element-audit.md`:哪些前端元素是真功能,哪些是 Mock/UI-only。
|
||||
- `doc/04-api-contracts.md`:前后端接口契约,以及当前不一致点。
|
||||
- `doc/05-implementation-plan.md`:建议的后续实施顺序。
|
||||
|
||||
---
|
||||
|
||||
## 构建与运行命令
|
||||
|
||||
### 前端 / Node 入口
|
||||
|
||||
```bash
|
||||
# 安装依赖
|
||||
npm install
|
||||
|
||||
# 开发模式(启动 Express + Vite 中间件,端口 3000)
|
||||
# 开发模式:运行 tsx server.ts,Express 集成 Vite middleware,端口 3000
|
||||
npm run dev
|
||||
|
||||
# 生产构建(输出到 dist/)
|
||||
# 生产构建:输出 dist/
|
||||
npm run build
|
||||
|
||||
# 预览生产构建
|
||||
# Vite 预览
|
||||
npm run preview
|
||||
|
||||
# 生产模式启动(Node 直接运行 server.ts,需先 build)
|
||||
# 生产模式运行 server.ts,服务 dist/;仍保留 server.ts 中的旧版 mock API
|
||||
npm start
|
||||
|
||||
# 类型检查(不输出文件)
|
||||
# TypeScript 类型检查
|
||||
npm run lint
|
||||
|
||||
# 清理构建产物
|
||||
# 删除 dist/
|
||||
npm run clean
|
||||
```
|
||||
|
||||
**开发服务器地址**: `http://localhost:3000`
|
||||
### FastAPI 后端
|
||||
|
||||
**环境变量**(复制 `.env.example` 为 `.env.local`):
|
||||
- `GEMINI_API_KEY` — Gemini AI API 密钥(AI Studio 会自动注入)
|
||||
- `APP_URL` — 应用托管 URL(AI Studio 自动注入 Cloud Run 地址)
|
||||
```bash
|
||||
cd backend
|
||||
uvicorn main:app --host 0.0.0.0 --port 8000 --reload
|
||||
```
|
||||
|
||||
### 一键启动
|
||||
|
||||
```bash
|
||||
./start_services.sh
|
||||
```
|
||||
|
||||
该脚本会依次检查/启动 PostgreSQL、Redis、MinIO、FastAPI 后端、Celery worker 和前端。
|
||||
|
||||
---
|
||||
|
||||
## 运行时架构
|
||||
|
||||
### 前端
|
||||
- 单页应用,React 19 `StrictMode` 挂载。
|
||||
- 模块切换通过 `App.tsx` 中的 `activeModule` 状态控制,可选值:
|
||||
`'dashboard' | 'projects' | 'ai' | 'workspace' | 'templates'`
|
||||
- 默认进入 `workspace`(分割工作区)。
|
||||
- 未登录时全局拦截,显示 `Login` 组件。
|
||||
|
||||
### 后端 (`server.ts`)
|
||||
- Express 服务器,端口 `3000`。
|
||||
- **开发模式**: 集成 Vite 中间件(`middlewareMode: true`)。
|
||||
- **生产模式**: 静态文件服务 `dist/`,所有路由回退到 `index.html`。
|
||||
- **API 端点**(内存数据存储,无数据库):
|
||||
- `POST /api/login` — 认证(固定用户名 `admin`,密码 `123456`)
|
||||
- `GET /api/projects` — 返回项目列表
|
||||
- `GET /api/templates` — 返回模板列表
|
||||
- 单页应用,无路由库;模块切换由 `useStore().activeModule` 控制。
|
||||
- 模块值包括:`dashboard`、`projects`、`ai`、`workspace`、`templates`。
|
||||
- 默认模块是 `workspace`。
|
||||
- 未登录时渲染 `Login`。
|
||||
- 登录成功后 token 写入 `localStorage`,Axios request interceptor 会附加 `Authorization: Bearer <token>`。
|
||||
- `App.tsx` 在登录后调用 `getProjects()` 初始化项目列表。
|
||||
|
||||
### 部署
|
||||
- 面向 **Google AI Studio** / **Cloud Run** 部署。
|
||||
- `metadata.json` 用于 AI Studio 元数据配置(当前为空)。
|
||||
- `vite.config.ts` 中 HMR 可通过环境变量 `DISABLE_HMR=true` 关闭(AI Studio 环境下文件监听被禁用以防止 agent 编辑时闪烁)。
|
||||
### 后端
|
||||
|
||||
- 主后端是 `backend/main.py` 的 FastAPI 服务。
|
||||
- `lifespan` 启动时会:
|
||||
- 创建数据库表;
|
||||
- 检查/创建 MinIO bucket `seg-media`;
|
||||
- 测试 Redis 连接;
|
||||
- 后台 seed 默认模板;
|
||||
- 如果本地存在 `Data_MyVideo_1.mp4`,后台 seed 默认演示项目并拆前 100 帧。
|
||||
- API 路由包括:
|
||||
- `POST /api/auth/login`
|
||||
- `GET/POST/PATCH/DELETE /api/projects`
|
||||
- `GET/POST /api/projects/{project_id}/frames`
|
||||
- `GET/POST/PATCH/DELETE /api/templates`
|
||||
- `POST /api/media/upload`
|
||||
- `POST /api/media/upload/dicom`
|
||||
- `POST /api/media/parse`
|
||||
- `GET /api/tasks`
|
||||
- `GET /api/tasks/{task_id}`
|
||||
- `POST /api/ai/predict`
|
||||
- `GET /api/ai/models/status`
|
||||
- `POST /api/ai/auto`
|
||||
- `POST /api/ai/annotate`
|
||||
- `GET /api/ai/annotations`
|
||||
- `PATCH/DELETE /api/ai/annotations/{annotation_id}`
|
||||
- `GET /api/dashboard/overview`
|
||||
- `GET /api/export/{project_id}/coco`
|
||||
- `GET /api/export/{project_id}/masks`
|
||||
- `GET /health`
|
||||
- `WS /ws/progress`
|
||||
|
||||
### 存储
|
||||
|
||||
- PostgreSQL 存储项目、帧、模板、标注、mask 和后台任务元数据。
|
||||
- MinIO 存储上传视频、DICOM、拆出的帧、缩略图等对象;前端展示使用预签名 URL。
|
||||
- Redis 当前作为 Celery broker/result backend,并用于连接检查。
|
||||
|
||||
---
|
||||
|
||||
## 主要业务流程
|
||||
|
||||
1. 登录:`Login.tsx` 调用 `POST /api/auth/login`,默认开发凭证为 `admin / 123456`。
|
||||
2. 项目管理:`ProjectLibrary.tsx` 调用项目 API 创建项目、拉取列表。
|
||||
3. 上传资源:视频走 `/api/media/upload`;DICOM 批量走 `/api/media/upload/dicom`。
|
||||
4. 拆帧入队:前端调用 `/api/media/parse`;后端创建 `ProcessingTask` 并投递 Celery。
|
||||
5. worker 执行:Celery worker 用 FFmpeg 优先拆视频帧,失败后用 OpenCV fallback,DICOM 使用 pydicom,并持续更新任务进度。
|
||||
6. 帧展示:`VideoWorkspace.tsx` 调用 `/api/projects/{id}/frames`,`CanvasArea.tsx` 和 `FrameTimeline.tsx` 显示当前帧与时间轴缩略图。
|
||||
7. AI 分割:前端工具包括正向点、反向点和框选;后端 `ai.py` 期望按 `image_id`、`prompt_type`、`prompt_data`、`model` 调用 SAM registry。SAM 2 支持点/框/自动分割;SAM 3 入口支持文本语义推理,运行时不满足官方要求时会在状态接口中标为不可用。
|
||||
8. 模板管理:`TemplateRegistry.tsx` 管理分类、颜色和 z-index;`OntologyInspector.tsx` 在工作区显示当前模板分类树。
|
||||
9. 导出:后端支持 COCO JSON 和 PNG mask ZIP 导出。
|
||||
|
||||
---
|
||||
|
||||
## 当前实现注意事项
|
||||
|
||||
- `src/lib/config.ts` 会优先读取 `VITE_API_BASE_URL` 和 `VITE_WS_PROGRESS_URL`;未配置时按当前浏览器 hostname 推导后端 `:8000` 地址。
|
||||
- 前端 `predictMask()` 已按后端 `PredictRequest` 发送 `image_id`、`prompt_type`、`prompt_data`、`model`,并将后端 `polygons` 转成 Konva 可渲染的 `pathData`。
|
||||
- 前端 `exportCoco()` 已对齐后端 `/api/export/{project_id}/coco`;工作区“导出 JSON 标注集”按钮已绑定下载流程,导出前会先保存当前待归档 mask。
|
||||
- 工作区“结构化归档保存”按钮已接入 `POST /api/ai/annotate` 和 `PATCH /api/ai/annotations/{id}`;加载工作区时会通过 `GET /api/ai/annotations` 回显已保存标注。
|
||||
- 工作区“清空遮罩”会调用 `DELETE /api/ai/annotations/{id}` 删除当前帧已保存标注,并清空当前帧本地 mask。
|
||||
- 项目状态已统一为 `pending`、`parsing`、`ready`、`error`;前端 `src/lib/api.ts` 会兼容归一化旧库中可能存在的 `Ready`、`Parsing`、`Error`。
|
||||
- `server.ts` 仍有旧版 `/api/login`、`/api/projects`、`/api/templates` mock;当前前端真实 API 调用主要走 FastAPI 的 `/api/auth/*`、`/api/projects`、`/api/templates` 等接口。
|
||||
- `Dashboard.tsx` 初始统计、队列和活动日志来自 `GET /api/dashboard/overview`;解析队列来自 `processing_tasks`,Celery worker 通过 Redis pub/sub 的 `seg:progress` 频道推送细粒度进度,再由 FastAPI 广播到 `/ws/progress`。
|
||||
|
||||
---
|
||||
|
||||
## 代码风格与约定
|
||||
|
||||
### 样式规范
|
||||
- **深色主题**: 全局背景色以 `#0a0a0a`、`#111`、`#0d0d0d`、`#151515`、`#1e1e1e` 为主。
|
||||
- **强调色**: 青色(`cyan-400`/`cyan-500`)用于激活状态、按钮和关键指示器。
|
||||
- **工具类优先**: 全面使用 TailwindCSS 工具类,通过 `cn()` 合并条件类名。
|
||||
- **自定义工具类**: `index.css` 中定义 `.no-scrollbar` 用于隐藏滚动条。
|
||||
|
||||
- 深色主题为主,常见背景色包括 `#0a0a0a`、`#111`、`#0d0d0d`、`#151515`、`#1e1e1e`。
|
||||
- 青色(如 `cyan-400` / `cyan-500`)用于激活状态、主按钮和关键指标。
|
||||
- 前端样式主要使用 TailwindCSS 工具类,通过 `cn()` 合并条件类名。
|
||||
- `src/index.css` 使用 TailwindCSS 4 的 `@import "tailwindcss";`。
|
||||
|
||||
### 组件规范
|
||||
- 所有组件使用 **函数组件 + Hooks**,无类组件。
|
||||
- 组件按功能模块分目录存放在 `src/components/{module}/` 下。
|
||||
- Props 类型使用 TypeScript `interface` 定义。
|
||||
- 导入排序:React → 第三方库 → 内部模块 → 类型。
|
||||
|
||||
- 组件使用函数组件 + Hooks。
|
||||
- 当前组件目录是扁平结构:`src/components/*.tsx`,不是按模块子目录分层。
|
||||
- Props 类型优先使用 TypeScript `interface`。
|
||||
- UI 文本保持中文。
|
||||
- 代码与注释优先使用英文。
|
||||
|
||||
### 命名规范
|
||||
- 组件文件使用 **PascalCase**(如 `AISegmentation.tsx`)。
|
||||
- 工具文件使用 **camelCase**(如 `utils.ts`)。
|
||||
- 类型/接口使用 **PascalCase**。
|
||||
|
||||
### 语言约定
|
||||
- **界面文本**: 全部使用 **中文**(如 "核心分割工作区"、"AI智能分割引擎"、"导出 JSON 标注集")。
|
||||
- **代码与注释**: 使用英文。
|
||||
- 添加新 UI 文本时,**必须保持中文**。
|
||||
- 组件文件使用 PascalCase,例如 `AISegmentation.tsx`。
|
||||
- 工具文件使用 camelCase,例如 `utils.ts`。
|
||||
- 类型和接口使用 PascalCase。
|
||||
|
||||
---
|
||||
|
||||
## 测试策略
|
||||
|
||||
**当前状态:无测试文件。**
|
||||
当前仓库已配置前端 Vitest 测试和后端 pytest 测试。测试依据 `doc/07-current-requirements-freeze.md`、`doc/08-current-design-freeze.md` 和 `doc/09-test-plan.md`。
|
||||
|
||||
- 项目中不存在 `.test.` 或 `.spec.` 文件。
|
||||
- 无测试框架配置(如 Jest、Vitest、Playwright)。
|
||||
- 若需添加测试,建议在前端引入 Vitest(与 Vite 同生态)进行单元测试,或使用 Playwright 进行 E2E 测试。
|
||||
- 前端测试配置:`vitest.config.ts`,共享 setup 在 `src/test/setup.tsx`。
|
||||
- 前端测试命令:`npm run test:run`。
|
||||
- 后端测试依赖:`backend/requirements-dev.txt`。
|
||||
- 后端测试命令:`pytest backend/tests`,或在 `backend/` 目录执行 `pytest tests`。
|
||||
- 基础静态校验:`npm run lint`、`npm run build`、`python -m py_compile backend/routers/ai.py backend/routers/templates.py backend/schemas.py`。
|
||||
- 后端测试使用内存 SQLite、fake MinIO 和 fake SAM registry,不依赖真实 PostgreSQL、MinIO、Redis 或模型权重。
|
||||
|
||||
---
|
||||
|
||||
## 安全注意事项
|
||||
|
||||
- **硬编码凭证**: `server.ts` 中登录验证使用硬编码凭据(`admin` / `123456`),生产环境必须替换为真实的身份验证机制。
|
||||
- **Mock JWT**: 登录成功返回固定的 `fake-jwt-token-for-admin`,无实际的 JWT 签名验证。
|
||||
- **内存数据存储**: 所有项目/模板数据存储在内存中,服务重启后数据丢失。无持久化层。
|
||||
- **环境变量**: `GEMINI_API_KEY` 通过 `.env.local` 管理,已加入 `.gitignore`,不会误提交。
|
||||
- **CORS / 安全头**: Express 服务器目前未配置 CORS 策略或安全响应头(如 Helmet)。
|
||||
- FastAPI 登录是开发用硬编码凭证:`admin / 123456`。
|
||||
- 登录成功返回固定 token:`fake-jwt-token-for-admin`,没有真实 JWT 签名校验。
|
||||
- Axios 会附加 Bearer token,但后端大多数业务路由当前没有鉴权依赖。
|
||||
- `backend/.env` 被 `.gitignore` 忽略;不要提交真实数据库、MinIO、Redis、模型路径等敏感配置。
|
||||
- `start_services.sh` 中包含本机路径和 sudo 启动逻辑,迁移机器时要审查。
|
||||
- Express `server.ts` 的旧版 mock API 只适合开发/兼容场景,不能当生产鉴权或持久化方案。
|
||||
|
||||
---
|
||||
|
||||
## 关键依赖与注意事项
|
||||
## AI Studio / Vite 特定配置
|
||||
|
||||
- **React 19**: 使用 `createRoot` API,注意与 React 18 的部分差异。
|
||||
- **TailwindCSS 4**: 使用 `@import "tailwindcss"` 语法(非 v3 的 `@tailwind` 指令)。
|
||||
- **react-konva**: Canvas 交互核心,所有画布相关操作(缩放、选点、遮罩)均依赖此库。
|
||||
- **use-image**: 用于异步加载图片到 Konva 画布。
|
||||
- **路径别名**: `@/*` 映射到项目根目录(由 `vite.config.ts` 和 `tsconfig.json` 共同配置)。
|
||||
- **缺失资源**: `Sidebar.tsx` 引用了 `/Logo.png`,但项目根目录无此文件,运行时会 404。
|
||||
|
||||
---
|
||||
|
||||
## AI Studio 特定配置
|
||||
|
||||
- `vite.config.ts` 中通过 `loadEnv` 加载环境变量,并将 `GEMINI_API_KEY` 注入到 `process.env.GEMINI_API_KEY`。
|
||||
- AI Studio 会在部署时自动注入 `GEMINI_API_KEY` 和 `APP_URL`。
|
||||
- `DISABLE_HMR` 环境变量用于在 AI Studio agent 编辑模式下关闭 HMR,避免界面闪烁。**请勿修改此逻辑。**
|
||||
- `.env.example` 包含 `GEMINI_API_KEY` 和 `APP_URL`,说明这些值由 AI Studio 注入。
|
||||
- `vite.config.ts` 通过 `loadEnv` 把 `GEMINI_API_KEY` 注入到 `process.env.GEMINI_API_KEY`。
|
||||
- `vite.config.ts` 中的 `DISABLE_HMR` 逻辑用于关闭 HMR,避免 AI Studio agent 编辑时闪烁。**不要随意修改该逻辑。**
|
||||
|
||||
128
README.md
128
README.md
@@ -4,16 +4,16 @@
|
||||
|
||||
# 语义分割系统(SegServer)
|
||||
|
||||
> 基于 React + FastAPI + SAM 2 的全栈交互式图像/视频语义分割与标注平台。
|
||||
> 基于 React + FastAPI + 可选 SAM 2 / SAM 3 的全栈交互式图像/视频语义分割与标注平台。
|
||||
>
|
||||
> 支持本地多媒体资产上传、服务器端按帧解析、AI 视觉大模型实时推理(正反向选点、框选生成分割 Mask)、动态图层状态管理及最终标注数据结构化导出。
|
||||
> 支持本地多媒体资产上传、服务器端按帧解析、交互式 Canvas 标注、模板分类管理和标注数据结构化导出;工作区点/框 AI 推理默认走 SAM 2,语义文本可选择 SAM 3,前端会显示真实 GPU/模型状态。
|
||||
|
||||
---
|
||||
|
||||
## 核心功能
|
||||
|
||||
- **多媒体资产管理** — 支持视频(MP4/AVI/MOV)和 DICOM 医学影像的上传、存储与解析
|
||||
- **AI 智能分割引擎** — 集成 SAM 2 模型,支持点分割(point)、框分割(box)、语义分割(semantic)和自动分割(auto)
|
||||
- **AI 智能分割引擎** — 后端提供 SAM 2 / SAM 3 模型选择;SAM 2 支持点分割(point)、框分割(box)和自动分割(auto),SAM 3 入口支持文本语义提示并按真实运行环境显示可用性
|
||||
- **交互式画布标注** — 基于 Konva 的高性能 Canvas,支持缩放/平移/选点/框选,实时渲染 Mask 遮罩
|
||||
- **本体字典管理** — 可配置的分类体系、颜色映射、图层优先级(z-index)
|
||||
- **项目工作区** — 项目创建、帧浏览、多图层标注、进度追踪
|
||||
@@ -37,15 +37,16 @@
|
||||
│ ├── /api/auth 登录认证 │
|
||||
│ ├── /api/projects 项目 & 视频帧 CRUD │
|
||||
│ ├── /api/templates 本体字典(分类/颜色/z-index) │
|
||||
│ ├── /api/media 文件上传 & FFmpeg/pydicom 帧解析 │
|
||||
│ ├── /api/ai SAM 2 推理(点/框/语义/自动分割) │
|
||||
│ ├── /api/media 文件上传 & 异步拆帧任务创建 │
|
||||
│ ├── /api/tasks Celery 后台任务状态 │
|
||||
│ ├── /api/ai SAM 2 / SAM 3 推理与模型状态 │
|
||||
│ └── /api/export COCO JSON / PNG Masks 导出 │
|
||||
└──────────────────────────┬──────────────────────────────────┘
|
||||
│ SQLAlchemy 2.0
|
||||
┌──────────────────────────▼──────────────────────────────────┐
|
||||
│ 数据持久化层 │
|
||||
│ PostgreSQL 14 — 项目/帧/标注/Mask 元数据 │
|
||||
│ Redis 6 — 缓存 & 任务队列状态 │
|
||||
│ PostgreSQL 14 — 项目/帧/标注/Mask/Task 元数据 │
|
||||
│ Redis 6 — Celery broker/result backend + 进度 pub/sub │
|
||||
│ MinIO — 对象存储(原始视频/解析帧/Mask图像) │
|
||||
└─────────────────────────────────────────────────────────────┘
|
||||
```
|
||||
@@ -63,11 +64,12 @@
|
||||
| Canvas 渲染 | Konva + react-konva | - |
|
||||
| HTTP 客户端 | Axios | - |
|
||||
| 后端框架 | FastAPI | v0.136+ |
|
||||
| 数据库 ORM | SQLAlchemy + Alembic | 2.0+ |
|
||||
| 数据库 ORM | SQLAlchemy(依赖中包含 Alembic) | 2.0+ |
|
||||
| 数据库 | PostgreSQL | 14 |
|
||||
| 缓存 | Redis | 6 |
|
||||
| 队列 Broker | Redis | 6 |
|
||||
| 后台任务 | Celery worker | 5.6+ |
|
||||
| 对象存储 | MinIO | 2025+ |
|
||||
| AI 推理 | SAM 2 (Meta) + PyTorch | - |
|
||||
| AI 推理 | SAM 2 / SAM 3 (Meta) + PyTorch | - |
|
||||
| 视频处理 | FFmpeg + OpenCV | 4.4+ |
|
||||
| DICOM 处理 | pydicom | 3.0+ |
|
||||
|
||||
@@ -78,13 +80,17 @@
|
||||
```
|
||||
Seg_Server/
|
||||
├── backend/ # FastAPI 后端
|
||||
│ ├── main.py # 应用入口(CORS/生命周期/路由注册)
|
||||
│ ├── main.py # 应用入口(CORS/生命周期/路由注册/WebSocket)
|
||||
│ ├── config.py # 环境变量配置(Pydantic Settings)
|
||||
│ ├── database.py # SQLAlchemy 引擎 + Session
|
||||
│ ├── models.py # ORM 模型(Project/Frame/Template/Annotation/Mask)
|
||||
│ ├── models.py # ORM 模型(Project/Frame/Template/Annotation/Mask/ProcessingTask)
|
||||
│ ├── schemas.py # Pydantic 请求/响应校验模型
|
||||
│ ├── minio_client.py # MinIO 上传/下载/预签名URL封装
|
||||
│ ├── redis_client.py # Redis 连接封装
|
||||
│ ├── progress_events.py # 任务进度事件 payload 与 Redis 发布
|
||||
│ ├── statuses.py # 项目/任务状态常量
|
||||
│ ├── celery_app.py # Celery app 配置
|
||||
│ ├── worker_tasks.py # Celery 任务入口
|
||||
│ ├── download_sam2.py # SAM 2 模型权重自动下载脚本
|
||||
│ ├── requirements.txt # Python 依赖
|
||||
│ ├── routers/ # API 路由
|
||||
@@ -92,10 +98,12 @@ Seg_Server/
|
||||
│ │ ├── projects.py # 项目 & 帧 CRUD
|
||||
│ │ ├── templates.py # 本体字典管理
|
||||
│ │ ├── media.py # 上传 & 解析
|
||||
│ │ ├── ai.py # SAM 2 推理接口
|
||||
│ │ ├── ai.py # SAM 推理与模型状态接口
|
||||
│ │ └── export.py # 数据导出
|
||||
│ └── services/ # 业务服务
|
||||
│ ├── sam2_engine.py # SAM 2 推理引擎(懒加载 + stub降级)
|
||||
│ ├── sam3_engine.py # SAM 3 状态检测与文本语义推理适配器
|
||||
│ ├── sam_registry.py # SAM 模型选择、GPU 状态与推理分发
|
||||
│ └── frame_parser.py # FFmpeg 拆帧 / pydicom 读片
|
||||
├── src/ # React 前端
|
||||
│ ├── main.tsx # 应用挂载点
|
||||
@@ -121,8 +129,11 @@ Seg_Server/
|
||||
├── models/ # SAM 2 模型权重(.pt 文件)
|
||||
├── uploads/ # 临时上传目录
|
||||
├── frames/ # 临时帧目录
|
||||
├── doc/ # 当前实现审计、接口契约与后续实施文档
|
||||
├── public/
|
||||
│ └── logo.png # 侧边栏 Logo 静态资源
|
||||
├── start_services.sh # 一键启动所有服务脚本
|
||||
├── server.ts # 旧版 Express 入口(已弃用)
|
||||
├── server.ts # Express + Vite 前端入口(也保留少量旧版 mock API)
|
||||
├── index.html # SPA HTML 入口
|
||||
├── vite.config.ts # Vite 构建配置
|
||||
├── package.json # npm 依赖与脚本
|
||||
@@ -131,12 +142,23 @@ Seg_Server/
|
||||
|
||||
---
|
||||
|
||||
## 项目文档
|
||||
|
||||
当前实现审计与接口契约文档在 `doc/` 目录:
|
||||
|
||||
- `doc/01-purpose-and-word-summary.md` — 项目目的、Word 方案摘要与当前落地程度
|
||||
- `doc/03-frontend-element-audit.md` — 前端逐元素功能审计,标注真实可用、部分可用、Mock/UI-only、接口不通
|
||||
- `doc/04-api-contracts.md` — 前后端接口契约和已知不一致
|
||||
- `doc/06-fastapi-docs-explained.md` — `http://192.168.3.11:8000/docs` 的作用说明
|
||||
|
||||
---
|
||||
|
||||
## 环境准备
|
||||
|
||||
### 系统要求
|
||||
|
||||
- **OS**: Ubuntu 22.04 LTS
|
||||
- **GPU**: NVIDIA GPU(推荐 RTX 4090 或同等算力),用于 SAM 2 推理
|
||||
- **GPU**: NVIDIA GPU(推荐 RTX 4090 或同等算力),用于 SAM 推理;SAM 3 官方要求 Python 3.12+、PyTorch 2.7+ 和 CUDA 12.6+ 环境
|
||||
- **CUDA**: 12.x / 13.x
|
||||
- **Node.js**: 22.x+
|
||||
- **Python**: 3.11(通过 Miniconda/Anaconda 管理)
|
||||
@@ -223,19 +245,32 @@ python download_sam2.py
|
||||
|
||||
### 步骤 5: 配置环境变量
|
||||
|
||||
项目根目录已提供默认配置,如需修改请编辑以下文件:
|
||||
后端通过 `backend/config.py` 中的 Pydantic Settings 读取 `backend/.env`。如需覆盖默认值,请编辑以下文件:
|
||||
|
||||
**backend/.env**(数据库/Redis/MinIO/SAM 路径):
|
||||
```ini
|
||||
DATABASE_URL=postgresql://seguser:segpass123@localhost:5432/segserver
|
||||
REDIS_URL=redis://localhost:6379/0
|
||||
MINIO_ENDPOINT=localhost:9000
|
||||
MINIO_ACCESS_KEY=minioadmin
|
||||
MINIO_SECRET_KEY=minioadmin
|
||||
MINIO_BUCKET_NAME=seg-media
|
||||
SAM2_MODEL_PATH=/home/wkmgc/Desktop/Seg_Server/models/sam2_hiera_tiny.pt
|
||||
db_url=postgresql://seguser:segpass123@localhost:5432/segserver
|
||||
redis_url=redis://localhost:6379/0
|
||||
minio_endpoint=192.168.3.11:9000
|
||||
minio_access_key=minioadmin
|
||||
minio_secret_key=minioadmin
|
||||
minio_secure=false
|
||||
sam_model_path=/home/wkmgc/Desktop/Seg_Server/models/sam2_hiera_tiny.pt
|
||||
sam_model_config=configs/sam2/sam2_hiera_t.yaml
|
||||
sam_default_model=sam2
|
||||
sam3_model_version=sam3.1
|
||||
cors_origins=["http://localhost:3000","http://192.168.3.11:3000"]
|
||||
```
|
||||
|
||||
前端根目录的 `.env.example` 包含 AI Studio 注入变量和前端 API 配置:
|
||||
|
||||
```ini
|
||||
VITE_API_BASE_URL=http://192.168.3.11:8000
|
||||
VITE_WS_PROGRESS_URL=ws://192.168.3.11:8000/ws/progress
|
||||
```
|
||||
|
||||
如果未配置 `VITE_API_BASE_URL`,前端会按当前浏览器 hostname 推导 `http://<host>:8000`。
|
||||
|
||||
### 步骤 6: 启动后端服务
|
||||
|
||||
```bash
|
||||
@@ -252,7 +287,21 @@ nohup uvicorn main:app --host 0.0.0.0 --port 8000 > /tmp/fastapi.log 2>&1 &
|
||||
- 创建数据库表(如果不存在)
|
||||
- 检查 MinIO bucket 是否存在
|
||||
- 测试 Redis 连接
|
||||
- 懒加载 SAM 2 模型(权重存在且 sam2 包已安装时)
|
||||
- 懒加载 SAM 模型;`GET /api/ai/models/status` 会返回 SAM 2、SAM 3 与 GPU 的真实可用状态
|
||||
|
||||
### 步骤 6.1: 启动 Celery Worker
|
||||
|
||||
```bash
|
||||
cd ~/Desktop/Seg_Server/backend
|
||||
source ~/miniconda3/etc/profile.d/conda.sh
|
||||
conda activate seg_server
|
||||
celery -A celery_app:celery_app worker --loglevel=info --concurrency=1
|
||||
|
||||
# 或使用后台模式
|
||||
nohup celery -A celery_app:celery_app worker --loglevel=info --concurrency=1 > /tmp/celery.log 2>&1 &
|
||||
```
|
||||
|
||||
`POST /api/media/parse` 只创建 `processing_tasks` 记录并把任务投递给 Celery;真正的 FFmpeg/OpenCV/pydicom 拆帧由 worker 执行。worker 每次更新任务状态后会发布到 Redis `seg:progress` 频道,FastAPI 订阅后转发到 `/ws/progress`,前端 Dashboard 可实时更新。
|
||||
|
||||
### 步骤 7: 安装前端依赖并构建
|
||||
|
||||
@@ -286,7 +335,7 @@ cd ~/Desktop/Seg_Server
|
||||
./start_services.sh
|
||||
```
|
||||
|
||||
脚本将依次检查并启动:PostgreSQL → Redis → MinIO → FastAPI 后端 → 前端。
|
||||
脚本将依次检查并启动:PostgreSQL → Redis → MinIO → FastAPI 后端 → Celery Worker → 前端。
|
||||
|
||||
---
|
||||
|
||||
@@ -307,10 +356,12 @@ cd ~/Desktop/Seg_Server
|
||||
|
||||
```bash
|
||||
npm install # 安装依赖
|
||||
npm run dev # Vite 开发模式(端口 5173)
|
||||
npm run dev # 运行 tsx server.ts,Express + Vite 中间件(端口 3000)
|
||||
npm run build # 生产构建(输出到 dist/)
|
||||
npm run lint # TypeScript 类型检查
|
||||
npm start # Node.js 运行 server.ts(旧版)
|
||||
npm run test # Vitest watch 模式
|
||||
npm run test:run # Vitest 单次运行
|
||||
npm start # Node.js 运行 server.ts(生产静态服务 / 旧版 mock API)
|
||||
```
|
||||
|
||||
### 后端
|
||||
@@ -318,8 +369,11 @@ npm start # Node.js 运行 server.ts(旧版)
|
||||
```bash
|
||||
# 在 conda seg_server 环境中
|
||||
cd backend
|
||||
pip install -r requirements-dev.txt # 安装后端测试依赖
|
||||
pytest tests # 后端接口测试
|
||||
uvicorn main:app --host 0.0.0.0 --port 8000 --reload # 开发模式(热重载)
|
||||
uvicorn main:app --host 0.0.0.0 --port 8000 # 生产模式
|
||||
celery -A celery_app:celery_app worker --loglevel=info --concurrency=1 # 后台任务 worker
|
||||
```
|
||||
|
||||
---
|
||||
@@ -375,7 +429,25 @@ pip install -e . --no-build-isolation
|
||||
**检查清单**:
|
||||
1. 后端是否已启动(`curl http://localhost:8000/health`)
|
||||
2. `backend/.env` 中的 `cors_origins` 是否包含 `http://localhost:3000`
|
||||
3. 前端 `src/lib/api.ts` 中的 `baseURL` 是否为 `http://localhost:8000`
|
||||
3. 前端是否配置了正确的 `VITE_API_BASE_URL`;未配置时会按当前浏览器 hostname 推导 `http://<host>:8000`
|
||||
|
||||
### Q5: 如何验证 AI 推理或 COCO 导出接口
|
||||
|
||||
**当前状态**:
|
||||
|
||||
- 前端 `predictMask()` 已发送后端需要的 `image_id`、`prompt_type`、`prompt_data`,并把后端 `polygons` 转成 Konva `pathData`。
|
||||
- 工作区点选/框选会使用当前帧的数据库 `frame.id` 调用 `/api/ai/predict`。
|
||||
- 前端 `exportCoco()` 已对齐到 `/api/export/{projectId}/coco`。
|
||||
- 工作区“导出 JSON 标注集”按钮已绑定下载流程;导出前会先保存当前待归档的前端 mask。
|
||||
- 工作区“结构化归档保存”按钮会把当前项目未保存 mask 写入 `POST /api/ai/annotate`,并把 dirty mask 写入 `PATCH /api/ai/annotations/{id}`。
|
||||
- 工作区“清空遮罩”会通过 `DELETE /api/ai/annotations/{id}` 删除当前帧已保存标注,并清空当前帧本地 mask。
|
||||
|
||||
**验证**:
|
||||
|
||||
```bash
|
||||
curl http://localhost:8000/health
|
||||
curl http://localhost:8000/api/export/1/coco
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
|
||||
21
backend/celery_app.py
Normal file
21
backend/celery_app.py
Normal file
@@ -0,0 +1,21 @@
|
||||
"""Celery application for background processing."""
|
||||
|
||||
from celery import Celery
|
||||
|
||||
from config import settings
|
||||
|
||||
celery_app = Celery(
|
||||
"seg_server",
|
||||
broker=settings.redis_url,
|
||||
backend=settings.redis_url,
|
||||
include=["worker_tasks"],
|
||||
)
|
||||
|
||||
celery_app.conf.update(
|
||||
task_serializer="json",
|
||||
result_serializer="json",
|
||||
accept_content=["json"],
|
||||
timezone="Asia/Shanghai",
|
||||
enable_utc=True,
|
||||
task_track_started=True,
|
||||
)
|
||||
@@ -18,9 +18,11 @@ class Settings(BaseSettings):
|
||||
minio_secret_key: str = "minioadmin"
|
||||
minio_secure: bool = False
|
||||
|
||||
# SAM2
|
||||
# SAM
|
||||
sam_default_model: str = "sam2"
|
||||
sam_model_path: str = "/home/wkmgc/Desktop/Seg_Server/models/sam2_hiera_tiny.pt"
|
||||
sam_model_config: str = "configs/sam2/sam2_hiera_t.yaml"
|
||||
sam3_model_version: str = "sam3.1"
|
||||
|
||||
# App
|
||||
app_env: str = "development"
|
||||
|
||||
@@ -1,11 +1,13 @@
|
||||
"""FastAPI application entrypoint."""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
import tempfile
|
||||
from contextlib import asynccontextmanager
|
||||
from contextlib import asynccontextmanager, suppress
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
@@ -13,9 +15,11 @@ from fastapi.middleware.cors import CORSMiddleware
|
||||
from config import settings
|
||||
from database import Base, engine, SessionLocal
|
||||
from minio_client import ensure_bucket_exists, upload_file
|
||||
from redis_client import ping as redis_ping
|
||||
from progress_events import PROGRESS_CHANNEL
|
||||
from redis_client import get_redis_client, ping as redis_ping
|
||||
from statuses import PROJECT_STATUS_PENDING, PROJECT_STATUS_READY
|
||||
|
||||
from routers import projects, templates, media, ai, export, auth
|
||||
from routers import projects, templates, media, ai, export, auth, dashboard, tasks
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
@@ -45,7 +49,7 @@ def _seed_default_project_sync() -> None:
|
||||
project = Project(
|
||||
name="Data_MyVideo_1",
|
||||
description="默认演示视频",
|
||||
status="pending",
|
||||
status=PROJECT_STATUS_PENDING,
|
||||
source_type="video",
|
||||
parse_fps=30.0,
|
||||
)
|
||||
@@ -98,7 +102,7 @@ def _seed_default_project_sync() -> None:
|
||||
)
|
||||
db.add(frame)
|
||||
|
||||
project.status = "ready"
|
||||
project.status = PROJECT_STATUS_READY
|
||||
db.commit()
|
||||
logger.info("Seeded default project id=%s with %d frames", project.id, len(object_names))
|
||||
finally:
|
||||
@@ -165,6 +169,7 @@ def _seed_default_templates_sync() -> None:
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
"""Application lifespan: startup and shutdown hooks."""
|
||||
progress_listener: asyncio.Task | None = None
|
||||
# Startup
|
||||
logger.info("Starting up SegServer backend...")
|
||||
|
||||
@@ -187,6 +192,11 @@ async def lifespan(app: FastAPI):
|
||||
else:
|
||||
logger.warning("Redis connection failed.")
|
||||
|
||||
try:
|
||||
progress_listener = asyncio.create_task(_progress_pubsub_loop())
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.error("Failed to start Redis progress subscription: %s", exc)
|
||||
|
||||
# Seed default templates
|
||||
try:
|
||||
asyncio.create_task(asyncio.to_thread(_seed_default_templates_sync))
|
||||
@@ -203,6 +213,10 @@ async def lifespan(app: FastAPI):
|
||||
|
||||
# Shutdown
|
||||
logger.info("Shutting down SegServer backend...")
|
||||
if progress_listener is not None:
|
||||
progress_listener.cancel()
|
||||
with suppress(asyncio.CancelledError):
|
||||
await progress_listener
|
||||
engine.dispose()
|
||||
|
||||
|
||||
@@ -229,6 +243,8 @@ app.include_router(templates.router)
|
||||
app.include_router(media.router)
|
||||
app.include_router(ai.router)
|
||||
app.include_router(export.router)
|
||||
app.include_router(dashboard.router)
|
||||
app.include_router(tasks.router)
|
||||
|
||||
|
||||
@app.get("/health", tags=["Health"])
|
||||
@@ -269,6 +285,34 @@ class ConnectionManager:
|
||||
manager = ConnectionManager()
|
||||
|
||||
|
||||
async def _progress_pubsub_loop() -> None:
|
||||
"""Forward Redis task-progress events to connected WebSocket clients."""
|
||||
while True:
|
||||
pubsub = None
|
||||
try:
|
||||
pubsub = get_redis_client().pubsub()
|
||||
await asyncio.to_thread(pubsub.subscribe, PROGRESS_CHANNEL)
|
||||
logger.info("Subscribed to Redis progress channel: %s", PROGRESS_CHANNEL)
|
||||
while True:
|
||||
message = await asyncio.to_thread(pubsub.get_message, True, 1.0)
|
||||
if message is None:
|
||||
await asyncio.sleep(0)
|
||||
continue
|
||||
raw_data = message.get("data")
|
||||
payload = json.loads(raw_data) if isinstance(raw_data, str) else raw_data
|
||||
if isinstance(payload, dict):
|
||||
await manager.broadcast(payload)
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.error("Redis progress subscription failed: %s", exc)
|
||||
await asyncio.sleep(5)
|
||||
finally:
|
||||
if pubsub is not None:
|
||||
with suppress(Exception):
|
||||
await asyncio.to_thread(pubsub.close)
|
||||
|
||||
|
||||
@app.websocket("/ws/progress")
|
||||
async def websocket_progress(websocket: WebSocket):
|
||||
"""WebSocket endpoint for real-time parsing/AI progress updates."""
|
||||
@@ -284,7 +328,7 @@ async def websocket_progress(websocket: WebSocket):
|
||||
"type": "status",
|
||||
"status": "connected",
|
||||
"message": "Progress stream active",
|
||||
"timestamp": str(logging.time.time() if hasattr(logging, 'time') else __import__('time').time()),
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
})
|
||||
except WebSocketDisconnect:
|
||||
manager.disconnect(websocket)
|
||||
|
||||
@@ -14,6 +14,7 @@ from sqlalchemy.orm import relationship
|
||||
from sqlalchemy.sql import func
|
||||
|
||||
from database import Base
|
||||
from statuses import PROJECT_STATUS_PENDING
|
||||
|
||||
|
||||
class Project(Base):
|
||||
@@ -26,7 +27,7 @@ class Project(Base):
|
||||
description = Column(Text, nullable=True)
|
||||
video_path = Column(String(512), nullable=True)
|
||||
thumbnail_url = Column(String(512), nullable=True)
|
||||
status = Column(String(50), default="Ready", nullable=False)
|
||||
status = Column(String(50), default=PROJECT_STATUS_PENDING, nullable=False)
|
||||
source_type = Column(String(20), default="video", nullable=False) # video | dicom
|
||||
original_fps = Column(Float, nullable=True)
|
||||
parse_fps = Column(Float, default=30.0, nullable=False)
|
||||
@@ -39,6 +40,9 @@ class Project(Base):
|
||||
annotations = relationship(
|
||||
"Annotation", back_populates="project", cascade="all, delete-orphan"
|
||||
)
|
||||
tasks = relationship(
|
||||
"ProcessingTask", back_populates="project", cascade="all, delete-orphan"
|
||||
)
|
||||
|
||||
|
||||
class Frame(Base):
|
||||
@@ -121,3 +125,30 @@ class Mask(Base):
|
||||
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
|
||||
annotation = relationship("Annotation", back_populates="masks")
|
||||
|
||||
|
||||
class ProcessingTask(Base):
|
||||
"""Background task state persisted for dashboard and polling."""
|
||||
|
||||
__tablename__ = "processing_tasks"
|
||||
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
task_type = Column(String(80), nullable=False)
|
||||
status = Column(String(40), default="queued", nullable=False)
|
||||
progress = Column(Integer, default=0, nullable=False)
|
||||
message = Column(Text, nullable=True)
|
||||
project_id = Column(
|
||||
Integer, ForeignKey("projects.id", ondelete="CASCADE"), nullable=True
|
||||
)
|
||||
celery_task_id = Column(String(255), nullable=True)
|
||||
payload = Column(JSON, nullable=True)
|
||||
result = Column(JSON, nullable=True)
|
||||
error = Column(Text, nullable=True)
|
||||
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
started_at = Column(DateTime(timezone=True), nullable=True)
|
||||
finished_at = Column(DateTime(timezone=True), nullable=True)
|
||||
updated_at = Column(
|
||||
DateTime(timezone=True), server_default=func.now(), onupdate=func.now()
|
||||
)
|
||||
|
||||
project = relationship("Project", back_populates="tasks")
|
||||
|
||||
64
backend/progress_events.py
Normal file
64
backend/progress_events.py
Normal file
@@ -0,0 +1,64 @@
|
||||
"""Progress event payloads and Redis publication helpers."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
|
||||
from redis_client import get_redis_client
|
||||
from statuses import TASK_STATUS_FAILED, TASK_STATUS_SUCCESS
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
PROGRESS_CHANNEL = "seg:progress"
|
||||
|
||||
|
||||
def _iso_now() -> str:
|
||||
return datetime.now(timezone.utc).isoformat()
|
||||
|
||||
|
||||
def _event_type(task_status: str) -> str:
|
||||
if task_status == TASK_STATUS_SUCCESS:
|
||||
return "complete"
|
||||
if task_status == TASK_STATUS_FAILED:
|
||||
return "error"
|
||||
return "progress"
|
||||
|
||||
|
||||
def task_progress_payload(task: Any) -> dict[str, Any]:
|
||||
"""Build the WebSocket payload from a persisted processing task."""
|
||||
project = getattr(task, "project", None)
|
||||
project_name = getattr(project, "name", None)
|
||||
status = getattr(task, "status", "")
|
||||
updated_at = getattr(task, "updated_at", None)
|
||||
timestamp = updated_at.isoformat() if updated_at is not None else _iso_now()
|
||||
message = getattr(task, "message", None)
|
||||
|
||||
return {
|
||||
"type": _event_type(status),
|
||||
"taskId": f"task-{task.id}",
|
||||
"task_id": task.id,
|
||||
"project_id": getattr(task, "project_id", None),
|
||||
"projectName": project_name,
|
||||
"filename": project_name,
|
||||
"progress": getattr(task, "progress", 0),
|
||||
"status": message or status,
|
||||
"message": message,
|
||||
"error": getattr(task, "error", None),
|
||||
"timestamp": timestamp,
|
||||
}
|
||||
|
||||
|
||||
def publish_progress_event(payload: dict[str, Any]) -> None:
|
||||
"""Publish a JSON progress event without failing the worker on Redis errors."""
|
||||
try:
|
||||
get_redis_client().publish(PROGRESS_CHANNEL, json.dumps(payload, ensure_ascii=False))
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.warning("Failed to publish progress event: %s", exc)
|
||||
|
||||
|
||||
def publish_task_progress_event(task: Any) -> None:
|
||||
"""Publish a progress event for a ProcessingTask ORM object."""
|
||||
publish_progress_event(task_progress_payload(task))
|
||||
2
backend/requirements-dev.txt
Normal file
2
backend/requirements-dev.txt
Normal file
@@ -0,0 +1,2 @@
|
||||
pytest
|
||||
httpx
|
||||
@@ -1,18 +1,25 @@
|
||||
"""AI inference endpoints using SAM 2."""
|
||||
"""AI inference endpoints using selectable SAM runtimes."""
|
||||
|
||||
import logging
|
||||
from typing import Any, List
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from fastapi import APIRouter, Depends, HTTPException, Response, status
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from database import get_db
|
||||
from minio_client import download_file
|
||||
from models import Frame, Annotation
|
||||
from schemas import PredictRequest, PredictResponse, AnnotationOut, AnnotationCreate
|
||||
from services.sam2_engine import sam_engine
|
||||
from models import Project, Frame, Template, Annotation
|
||||
from schemas import (
|
||||
AiRuntimeStatus,
|
||||
PredictRequest,
|
||||
PredictResponse,
|
||||
AnnotationOut,
|
||||
AnnotationCreate,
|
||||
AnnotationUpdate,
|
||||
)
|
||||
from services.sam_registry import ModelUnavailableError, sam_registry
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter(prefix="/api/ai", tags=["AI"])
|
||||
@@ -35,14 +42,15 @@ def _load_frame_image(frame: Frame) -> np.ndarray:
|
||||
@router.post(
|
||||
"/predict",
|
||||
response_model=PredictResponse,
|
||||
summary="Run SAM 2 inference with a prompt",
|
||||
summary="Run SAM inference with a prompt",
|
||||
)
|
||||
def predict(payload: PredictRequest, db: Session = Depends(get_db)) -> dict:
|
||||
"""Execute SAM 2 segmentation given an image and a prompt.
|
||||
"""Execute selected SAM segmentation given an image and a prompt.
|
||||
|
||||
- **point**: `prompt_data` is a list of `[[x, y], ...]` normalized coordinates.
|
||||
- **point**: `prompt_data` is either a list of `[[x, y], ...]` normalized
|
||||
coordinates or `{ "points": [[x, y], ...], "labels": [1, 0, ...] }`.
|
||||
- **box**: `prompt_data` is `[x1, y1, x2, y2]` normalized coordinates.
|
||||
- **semantic**: Not yet implemented; falls back to auto segmentation.
|
||||
- **semantic**: SAM 3 text prompt when model=`sam3`; SAM 2 falls back to auto.
|
||||
"""
|
||||
frame = db.query(Frame).filter(Frame.id == payload.image_id).first()
|
||||
if not frame:
|
||||
@@ -54,30 +62,57 @@ def predict(payload: PredictRequest, db: Session = Depends(get_db)) -> dict:
|
||||
polygons: List[List[List[float]]] = []
|
||||
scores: List[float] = []
|
||||
|
||||
if prompt_type == "point":
|
||||
points = payload.prompt_data
|
||||
if not isinstance(points, list) or len(points) == 0:
|
||||
raise HTTPException(status_code=400, detail="Invalid point prompt data")
|
||||
labels = [1] * len(points)
|
||||
polygons, scores = sam_engine.predict_points(image, points, labels)
|
||||
try:
|
||||
if prompt_type == "point":
|
||||
point_payload = payload.prompt_data
|
||||
if isinstance(point_payload, dict):
|
||||
points = point_payload.get("points")
|
||||
labels = point_payload.get("labels")
|
||||
else:
|
||||
points = point_payload
|
||||
labels = None
|
||||
|
||||
elif prompt_type == "box":
|
||||
box = payload.prompt_data
|
||||
if not isinstance(box, list) or len(box) != 4:
|
||||
raise HTTPException(status_code=400, detail="Invalid box prompt data")
|
||||
polygons, scores = sam_engine.predict_box(image, box)
|
||||
if not isinstance(points, list) or len(points) == 0:
|
||||
raise HTTPException(status_code=400, detail="Invalid point prompt data")
|
||||
if not isinstance(labels, list) or len(labels) != len(points):
|
||||
labels = [1] * len(points)
|
||||
polygons, scores = sam_registry.predict_points(payload.model, image, points, labels)
|
||||
|
||||
elif prompt_type == "semantic":
|
||||
# Placeholder: use auto segmentation for now
|
||||
logger.info("Semantic prompt not implemented; using auto segmentation")
|
||||
polygons, scores = sam_engine.predict_auto(image)
|
||||
elif prompt_type == "box":
|
||||
box = payload.prompt_data
|
||||
if not isinstance(box, list) or len(box) != 4:
|
||||
raise HTTPException(status_code=400, detail="Invalid box prompt data")
|
||||
polygons, scores = sam_registry.predict_box(payload.model, image, box)
|
||||
|
||||
else:
|
||||
raise HTTPException(status_code=400, detail=f"Unsupported prompt_type: {prompt_type}")
|
||||
elif prompt_type == "semantic":
|
||||
text = payload.prompt_data if isinstance(payload.prompt_data, str) else ""
|
||||
polygons, scores = sam_registry.predict_semantic(payload.model, image, text)
|
||||
|
||||
else:
|
||||
raise HTTPException(status_code=400, detail=f"Unsupported prompt_type: {prompt_type}")
|
||||
except ModelUnavailableError as exc:
|
||||
raise HTTPException(status_code=503, detail=str(exc)) from exc
|
||||
except NotImplementedError as exc:
|
||||
raise HTTPException(status_code=400, detail=str(exc)) from exc
|
||||
except ValueError as exc:
|
||||
raise HTTPException(status_code=400, detail=str(exc)) from exc
|
||||
|
||||
return {"polygons": polygons, "scores": scores}
|
||||
|
||||
|
||||
@router.get(
|
||||
"/models/status",
|
||||
response_model=AiRuntimeStatus,
|
||||
summary="Get SAM model and GPU runtime status",
|
||||
)
|
||||
def model_status(selected_model: str | None = None) -> dict:
|
||||
"""Return real runtime availability for GPU, SAM 2, and SAM 3."""
|
||||
try:
|
||||
return sam_registry.runtime_status(selected_model)
|
||||
except ValueError as exc:
|
||||
raise HTTPException(status_code=400, detail=str(exc)) from exc
|
||||
|
||||
|
||||
@router.post(
|
||||
"/auto",
|
||||
response_model=PredictResponse,
|
||||
@@ -90,7 +125,10 @@ def auto_segment(image_id: int, db: Session = Depends(get_db)) -> dict:
|
||||
raise HTTPException(status_code=404, detail="Frame not found")
|
||||
|
||||
image = _load_frame_image(frame)
|
||||
polygons, scores = sam_engine.predict_auto(image)
|
||||
try:
|
||||
polygons, scores = sam_registry.predict_auto(None, image)
|
||||
except ModelUnavailableError as exc:
|
||||
raise HTTPException(status_code=503, detail=str(exc)) from exc
|
||||
|
||||
return {"polygons": polygons, "scores": scores}
|
||||
|
||||
@@ -106,7 +144,7 @@ def save_annotation(
|
||||
db: Session = Depends(get_db),
|
||||
) -> Annotation:
|
||||
"""Persist an annotation (mask, points, bbox) into the database."""
|
||||
project = db.query(Frame).filter(Frame.id == payload.project_id).first()
|
||||
project = db.query(Project).filter(Project.id == payload.project_id).first()
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
|
||||
@@ -121,3 +159,74 @@ def save_annotation(
|
||||
db.refresh(annotation)
|
||||
logger.info("Saved annotation id=%s project_id=%s", annotation.id, annotation.project_id)
|
||||
return annotation
|
||||
|
||||
|
||||
@router.get(
|
||||
"/annotations",
|
||||
response_model=List[AnnotationOut],
|
||||
summary="List saved annotations for a project",
|
||||
)
|
||||
def list_annotations(
|
||||
project_id: int,
|
||||
frame_id: int | None = None,
|
||||
db: Session = Depends(get_db),
|
||||
) -> List[Annotation]:
|
||||
"""Return persisted annotations for a project, optionally scoped to one frame."""
|
||||
project = db.query(Project).filter(Project.id == project_id).first()
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
|
||||
query = db.query(Annotation).filter(Annotation.project_id == project_id)
|
||||
if frame_id is not None:
|
||||
query = query.filter(Annotation.frame_id == frame_id)
|
||||
return query.order_by(Annotation.id).all()
|
||||
|
||||
|
||||
@router.patch(
|
||||
"/annotations/{annotation_id}",
|
||||
response_model=AnnotationOut,
|
||||
summary="Update a saved annotation",
|
||||
)
|
||||
def update_annotation(
|
||||
annotation_id: int,
|
||||
payload: AnnotationUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
) -> Annotation:
|
||||
"""Update mutable annotation fields persisted in the database."""
|
||||
annotation = db.query(Annotation).filter(Annotation.id == annotation_id).first()
|
||||
if not annotation:
|
||||
raise HTTPException(status_code=404, detail="Annotation not found")
|
||||
|
||||
updates = payload.model_dump(exclude_unset=True)
|
||||
if "template_id" in updates and updates["template_id"] is not None:
|
||||
template = db.query(Template).filter(Template.id == updates["template_id"]).first()
|
||||
if not template:
|
||||
raise HTTPException(status_code=404, detail="Template not found")
|
||||
|
||||
for field, value in updates.items():
|
||||
setattr(annotation, field, value)
|
||||
|
||||
db.commit()
|
||||
db.refresh(annotation)
|
||||
logger.info("Updated annotation id=%s", annotation.id)
|
||||
return annotation
|
||||
|
||||
|
||||
@router.delete(
|
||||
"/annotations/{annotation_id}",
|
||||
status_code=status.HTTP_204_NO_CONTENT,
|
||||
summary="Delete a saved annotation",
|
||||
)
|
||||
def delete_annotation(
|
||||
annotation_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
) -> Response:
|
||||
"""Delete an annotation and its derived mask rows through ORM cascade."""
|
||||
annotation = db.query(Annotation).filter(Annotation.id == annotation_id).first()
|
||||
if not annotation:
|
||||
raise HTTPException(status_code=404, detail="Annotation not found")
|
||||
|
||||
db.delete(annotation)
|
||||
db.commit()
|
||||
logger.info("Deleted annotation id=%s", annotation_id)
|
||||
return Response(status_code=status.HTTP_204_NO_CONTENT)
|
||||
|
||||
137
backend/routers/dashboard.py
Normal file
137
backend/routers/dashboard.py
Normal file
@@ -0,0 +1,137 @@
|
||||
"""Dashboard overview endpoints."""
|
||||
|
||||
import os
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from database import get_db
|
||||
from models import Annotation, Frame, ProcessingTask, Project, Template
|
||||
|
||||
router = APIRouter(prefix="/api/dashboard", tags=["Dashboard"])
|
||||
|
||||
ACTIVE_TASK_STATUSES = {"queued", "running"}
|
||||
|
||||
|
||||
def _system_load_percent() -> int:
|
||||
"""Return a real host load estimate without adding a psutil dependency."""
|
||||
try:
|
||||
load_1m = os.getloadavg()[0]
|
||||
cpu_count = os.cpu_count() or 1
|
||||
return min(100, max(0, round((load_1m / cpu_count) * 100)))
|
||||
except (AttributeError, OSError):
|
||||
return 0
|
||||
|
||||
|
||||
def _iso_or_none(value: datetime | None) -> str | None:
|
||||
if value is None:
|
||||
return None
|
||||
if value.tzinfo is None:
|
||||
value = value.replace(tzinfo=timezone.utc)
|
||||
return value.isoformat()
|
||||
|
||||
|
||||
def _task_payload(task: ProcessingTask) -> dict[str, Any]:
|
||||
return {
|
||||
"id": f"task-{task.id}",
|
||||
"task_id": task.id,
|
||||
"project_id": task.project_id or 0,
|
||||
"name": task.project.name if task.project else f"任务 {task.id}",
|
||||
"progress": task.progress,
|
||||
"status": task.message or task.status,
|
||||
"frame_count": (task.result or {}).get("frames_extracted", 0),
|
||||
"updated_at": _iso_or_none(task.updated_at),
|
||||
}
|
||||
|
||||
|
||||
@router.get("/overview", summary="Get dashboard overview")
|
||||
def get_dashboard_overview(db: Session = Depends(get_db)) -> dict[str, Any]:
|
||||
"""Return live dashboard data derived from persisted backend records."""
|
||||
project_count = db.query(func.count(Project.id)).scalar() or 0
|
||||
frame_count = db.query(func.count(Frame.id)).scalar() or 0
|
||||
annotation_count = db.query(func.count(Annotation.id)).scalar() or 0
|
||||
template_count = db.query(func.count(Template.id)).scalar() or 0
|
||||
active_task_count = (
|
||||
db.query(func.count(ProcessingTask.id))
|
||||
.filter(ProcessingTask.status.in_(ACTIVE_TASK_STATUSES))
|
||||
.scalar()
|
||||
or 0
|
||||
)
|
||||
|
||||
projects = db.query(Project).order_by(Project.updated_at.desc()).all()
|
||||
recent_tasks = (
|
||||
db.query(ProcessingTask)
|
||||
.order_by(ProcessingTask.created_at.desc())
|
||||
.limit(50)
|
||||
.all()
|
||||
)
|
||||
tasks = [_task_payload(task) for task in recent_tasks if task.status in ACTIVE_TASK_STATUSES]
|
||||
|
||||
activities: list[dict[str, Any]] = []
|
||||
for task in recent_tasks[:10]:
|
||||
project_name = task.project.name if task.project else f"项目 {task.project_id}"
|
||||
activities.append({
|
||||
"id": f"task-{task.id}",
|
||||
"kind": "task",
|
||||
"time": _iso_or_none(task.updated_at),
|
||||
"message": task.message or f"任务状态: {task.status}",
|
||||
"project": project_name,
|
||||
})
|
||||
|
||||
for project in projects[:10]:
|
||||
activities.append({
|
||||
"id": f"project-{project.id}",
|
||||
"kind": "project",
|
||||
"time": _iso_or_none(project.updated_at),
|
||||
"message": f"项目状态: {project.status}",
|
||||
"project": project.name,
|
||||
})
|
||||
|
||||
recent_annotations = (
|
||||
db.query(Annotation)
|
||||
.order_by(Annotation.updated_at.desc())
|
||||
.limit(10)
|
||||
.all()
|
||||
)
|
||||
for annotation in recent_annotations:
|
||||
project_name = annotation.project.name if annotation.project else f"项目 {annotation.project_id}"
|
||||
activities.append({
|
||||
"id": f"annotation-{annotation.id}",
|
||||
"kind": "annotation",
|
||||
"time": _iso_or_none(annotation.updated_at),
|
||||
"message": f"标注已更新 #{annotation.id}",
|
||||
"project": project_name,
|
||||
})
|
||||
|
||||
recent_templates = (
|
||||
db.query(Template)
|
||||
.order_by(Template.created_at.desc())
|
||||
.limit(10)
|
||||
.all()
|
||||
)
|
||||
for template in recent_templates:
|
||||
activities.append({
|
||||
"id": f"template-{template.id}",
|
||||
"kind": "template",
|
||||
"time": _iso_or_none(template.created_at),
|
||||
"message": f"模板可用: {template.name}",
|
||||
"project": "系统",
|
||||
})
|
||||
|
||||
activities.sort(key=lambda item: item["time"] or "", reverse=True)
|
||||
|
||||
return {
|
||||
"summary": {
|
||||
"project_count": project_count,
|
||||
"parsing_task_count": active_task_count,
|
||||
"annotation_count": annotation_count,
|
||||
"frame_count": frame_count,
|
||||
"template_count": template_count,
|
||||
"system_load_percent": _system_load_percent(),
|
||||
},
|
||||
"tasks": tasks,
|
||||
"activity": activities[:10],
|
||||
}
|
||||
@@ -1,10 +1,6 @@
|
||||
"""Media upload and parsing endpoints."""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
import subprocess
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import List, Optional
|
||||
|
||||
@@ -12,13 +8,12 @@ from fastapi import APIRouter, Depends, File, Form, HTTPException, UploadFile, s
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from database import get_db
|
||||
from minio_client import upload_file, get_presigned_url, download_file
|
||||
from models import Project, Frame
|
||||
from schemas import FrameOut
|
||||
from services.frame_parser import (
|
||||
parse_video, parse_dicom, upload_frames_to_minio,
|
||||
extract_thumbnail, get_video_fps,
|
||||
)
|
||||
from minio_client import upload_file, get_presigned_url
|
||||
from models import ProcessingTask, Project
|
||||
from progress_events import publish_task_progress_event
|
||||
from schemas import ProcessingTaskOut
|
||||
from statuses import PROJECT_STATUS_PARSING, PROJECT_STATUS_PENDING, TASK_STATUS_QUEUED
|
||||
from worker_tasks import parse_project_media
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter(prefix="/api/media", tags=["Media"])
|
||||
@@ -79,7 +74,7 @@ async def upload_media(
|
||||
project = Project(
|
||||
name=file.filename,
|
||||
description="Auto-created from upload",
|
||||
status="pending",
|
||||
status=PROJECT_STATUS_PENDING,
|
||||
video_path=object_name,
|
||||
source_type="video",
|
||||
)
|
||||
@@ -135,7 +130,7 @@ async def upload_dicom_batch(
|
||||
project = Project(
|
||||
name=first_name,
|
||||
description=f"DICOM series with {len(files)} files",
|
||||
status="pending",
|
||||
status=PROJECT_STATUS_PENDING,
|
||||
source_type="dicom",
|
||||
)
|
||||
db.add(project)
|
||||
@@ -168,19 +163,18 @@ async def upload_dicom_batch(
|
||||
@router.post(
|
||||
"/parse",
|
||||
status_code=status.HTTP_202_ACCEPTED,
|
||||
response_model=ProcessingTaskOut,
|
||||
summary="Trigger frame extraction",
|
||||
)
|
||||
def parse_media(
|
||||
project_id: int,
|
||||
source_type: Optional[str] = None,
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
"""Trigger frame extraction for a project's uploaded media.
|
||||
) -> ProcessingTask:
|
||||
"""Create a background task for media frame extraction.
|
||||
|
||||
* video: uses FFmpeg or OpenCV fallback, extracts thumbnail.
|
||||
* dicom: uses pydicom to read DCM frames.
|
||||
|
||||
Extracted frames are uploaded to MinIO and registered in the database.
|
||||
The Celery worker performs the heavy FFmpeg/OpenCV/pydicom work and
|
||||
updates the persisted task record as it progresses.
|
||||
"""
|
||||
project = db.query(Project).filter(Project.id == project_id).first()
|
||||
if not project:
|
||||
@@ -190,100 +184,24 @@ def parse_media(
|
||||
raise HTTPException(status_code=400, detail="Project has no media uploaded")
|
||||
|
||||
effective_source = source_type or project.source_type or "video"
|
||||
parse_fps = project.parse_fps or 30.0
|
||||
|
||||
tmp_dir = tempfile.mkdtemp(prefix=f"seg_parse_{project_id}_")
|
||||
output_dir = os.path.join(tmp_dir, "frames")
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
try:
|
||||
if effective_source == "dicom":
|
||||
# Download all dicom files from MinIO
|
||||
dcm_dir = os.path.join(tmp_dir, "dcm")
|
||||
os.makedirs(dcm_dir, exist_ok=True)
|
||||
|
||||
from minio_client import get_minio_client, BUCKET_NAME
|
||||
client = get_minio_client()
|
||||
prefix = project.video_path
|
||||
objects = list(client.list_objects(BUCKET_NAME, prefix=prefix, recursive=True))
|
||||
for obj in objects:
|
||||
if obj.object_name.lower().endswith(".dcm"):
|
||||
data = download_file(obj.object_name)
|
||||
local_dcm = os.path.join(dcm_dir, os.path.basename(obj.object_name))
|
||||
with open(local_dcm, "wb") as f:
|
||||
f.write(data)
|
||||
|
||||
frame_files = parse_dicom(dcm_dir, output_dir)
|
||||
else:
|
||||
# Video: download and parse
|
||||
media_bytes = download_file(project.video_path)
|
||||
local_path = os.path.join(tmp_dir, Path(project.video_path).name)
|
||||
with open(local_path, "wb") as f:
|
||||
f.write(media_bytes)
|
||||
|
||||
frame_files, original_fps = parse_video(local_path, output_dir, fps=int(parse_fps))
|
||||
project.original_fps = original_fps
|
||||
|
||||
# Extract thumbnail from first frame
|
||||
thumbnail_path = os.path.join(tmp_dir, "thumbnail.jpg")
|
||||
try:
|
||||
extract_thumbnail(local_path, thumbnail_path)
|
||||
with open(thumbnail_path, "rb") as f:
|
||||
thumb_data = f.read()
|
||||
thumb_object = f"projects/{project_id}/thumbnail.jpg"
|
||||
upload_file(thumb_object, thumb_data, content_type="image/jpeg", length=len(thumb_data))
|
||||
project.thumbnail_url = thumb_object
|
||||
logger.info("Uploaded thumbnail for project_id=%s", project_id)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.warning("Thumbnail extraction failed: %s", exc)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.error("Frame extraction failed: %s", exc)
|
||||
shutil.rmtree(tmp_dir, ignore_errors=True)
|
||||
raise HTTPException(status_code=500, detail="Frame extraction failed") from exc
|
||||
|
||||
# Upload frames to MinIO
|
||||
try:
|
||||
object_names = upload_frames_to_minio(frame_files, project_id)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.error("Frame upload failed: %s", exc)
|
||||
shutil.rmtree(tmp_dir, ignore_errors=True)
|
||||
raise HTTPException(status_code=500, detail="Frame upload to storage failed") from exc
|
||||
|
||||
# Register frames in DB
|
||||
frames_out = []
|
||||
for idx, obj_name in enumerate(object_names):
|
||||
local_frame = frame_files[idx]
|
||||
try:
|
||||
import cv2
|
||||
img = cv2.imread(local_frame)
|
||||
h, w = img.shape[:2] if img is not None else (None, None)
|
||||
except Exception: # noqa: BLE001
|
||||
h, w = None, None
|
||||
|
||||
frame = Frame(
|
||||
project_id=project_id,
|
||||
frame_index=idx,
|
||||
image_url=obj_name,
|
||||
width=w,
|
||||
height=h,
|
||||
)
|
||||
db.add(frame)
|
||||
frames_out.append(frame)
|
||||
|
||||
task = ProcessingTask(
|
||||
task_type=f"parse_{effective_source}",
|
||||
status=TASK_STATUS_QUEUED,
|
||||
progress=0,
|
||||
message="解析任务已入队",
|
||||
project_id=project_id,
|
||||
payload={"source_type": effective_source},
|
||||
)
|
||||
project.status = PROJECT_STATUS_PARSING
|
||||
db.add(task)
|
||||
db.commit()
|
||||
for f in frames_out:
|
||||
db.refresh(f)
|
||||
db.refresh(task)
|
||||
publish_task_progress_event(task)
|
||||
|
||||
# Cleanup temp files
|
||||
shutil.rmtree(tmp_dir, ignore_errors=True)
|
||||
|
||||
project.status = "ready"
|
||||
async_result = parse_project_media.delay(task.id)
|
||||
task.celery_task_id = async_result.id
|
||||
db.commit()
|
||||
db.refresh(task)
|
||||
|
||||
logger.info("Parsed %d frames for project_id=%s", len(frames_out), project_id)
|
||||
return {
|
||||
"project_id": project_id,
|
||||
"frames_extracted": len(frames_out),
|
||||
"status": "ready",
|
||||
"message": "Frame extraction completed successfully.",
|
||||
}
|
||||
logger.info("Queued parse task id=%s project_id=%s celery_id=%s", task.id, project_id, async_result.id)
|
||||
return task
|
||||
|
||||
37
backend/routers/tasks.py
Normal file
37
backend/routers/tasks.py
Normal file
@@ -0,0 +1,37 @@
|
||||
"""Processing task query endpoints."""
|
||||
|
||||
from typing import List
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from database import get_db
|
||||
from models import ProcessingTask
|
||||
from schemas import ProcessingTaskOut
|
||||
|
||||
router = APIRouter(prefix="/api/tasks", tags=["Tasks"])
|
||||
|
||||
|
||||
@router.get("", response_model=List[ProcessingTaskOut], summary="List processing tasks")
|
||||
def list_tasks(
|
||||
project_id: int | None = None,
|
||||
status: str | None = None,
|
||||
limit: int = 50,
|
||||
db: Session = Depends(get_db),
|
||||
) -> List[ProcessingTask]:
|
||||
"""Return recent background processing tasks."""
|
||||
query = db.query(ProcessingTask)
|
||||
if project_id is not None:
|
||||
query = query.filter(ProcessingTask.project_id == project_id)
|
||||
if status is not None:
|
||||
query = query.filter(ProcessingTask.status == status)
|
||||
return query.order_by(ProcessingTask.created_at.desc()).limit(limit).all()
|
||||
|
||||
|
||||
@router.get("/{task_id}", response_model=ProcessingTaskOut, summary="Get processing task")
|
||||
def get_task(task_id: int, db: Session = Depends(get_db)) -> ProcessingTask:
|
||||
"""Return one background task by id."""
|
||||
task = db.query(ProcessingTask).filter(ProcessingTask.id == task_id).first()
|
||||
if not task:
|
||||
raise HTTPException(status_code=404, detail="Task not found")
|
||||
return task
|
||||
@@ -18,9 +18,9 @@ def _pack_mapping_rules(data: dict) -> dict:
|
||||
"""Pack classes/rules into mapping_rules for DB storage."""
|
||||
mapping = data.get("mapping_rules") or {}
|
||||
if "classes" in data and data["classes"] is not None:
|
||||
mapping["classes"] = data["classes"]
|
||||
mapping["classes"] = data.pop("classes")
|
||||
if "rules" in data and data["rules"] is not None:
|
||||
mapping["rules"] = data["rules"]
|
||||
mapping["rules"] = data.pop("rules")
|
||||
data["mapping_rules"] = mapping
|
||||
return data
|
||||
|
||||
|
||||
@@ -70,6 +70,7 @@ class FrameOut(FrameBase):
|
||||
# ---------------------------------------------------------------------------
|
||||
class TemplateBase(BaseModel):
|
||||
name: str
|
||||
description: Optional[str] = None
|
||||
color: str
|
||||
z_index: int = 0
|
||||
mapping_rules: Optional[dict[str, Any]] = None
|
||||
@@ -83,6 +84,7 @@ class TemplateCreate(TemplateBase):
|
||||
|
||||
class TemplateUpdate(BaseModel):
|
||||
name: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
color: Optional[str] = None
|
||||
z_index: Optional[int] = None
|
||||
mapping_rules: Optional[dict[str, Any]] = None
|
||||
@@ -115,7 +117,7 @@ class AnnotationCreate(AnnotationBase):
|
||||
|
||||
class AnnotationUpdate(BaseModel):
|
||||
mask_data: Optional[dict[str, Any]] = None
|
||||
points: Optional[list[float]] = None
|
||||
points: Optional[list[list[float]]] = None
|
||||
bbox: Optional[list[float]] = None
|
||||
template_id: Optional[int] = None
|
||||
|
||||
@@ -148,6 +150,28 @@ class MaskOut(MaskBase):
|
||||
created_at: datetime
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Processing task schemas
|
||||
# ---------------------------------------------------------------------------
|
||||
class ProcessingTaskOut(BaseModel):
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
id: int
|
||||
task_type: str
|
||||
status: str
|
||||
progress: int
|
||||
message: Optional[str] = None
|
||||
project_id: Optional[int] = None
|
||||
celery_task_id: Optional[str] = None
|
||||
payload: Optional[dict[str, Any]] = None
|
||||
result: Optional[dict[str, Any]] = None
|
||||
error: Optional[str] = None
|
||||
created_at: datetime
|
||||
started_at: Optional[datetime] = None
|
||||
finished_at: Optional[datetime] = None
|
||||
updated_at: datetime
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# AI schemas
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -155,6 +179,7 @@ class PredictRequest(BaseModel):
|
||||
image_id: int
|
||||
prompt_type: str # point / box / semantic
|
||||
prompt_data: Any
|
||||
model: Optional[str] = None
|
||||
|
||||
|
||||
class PredictResponse(BaseModel):
|
||||
@@ -162,6 +187,37 @@ class PredictResponse(BaseModel):
|
||||
scores: Optional[list[float]] = None
|
||||
|
||||
|
||||
class AiModelStatus(BaseModel):
|
||||
id: str
|
||||
label: str
|
||||
available: bool
|
||||
loaded: bool = False
|
||||
device: str
|
||||
supports: list[str]
|
||||
message: str
|
||||
package_available: bool = False
|
||||
checkpoint_exists: bool = False
|
||||
checkpoint_path: Optional[str] = None
|
||||
python_ok: bool = True
|
||||
torch_ok: bool = True
|
||||
cuda_required: bool = False
|
||||
|
||||
|
||||
class GpuStatus(BaseModel):
|
||||
available: bool
|
||||
device: str
|
||||
name: Optional[str] = None
|
||||
torch_available: bool
|
||||
torch_version: Optional[str] = None
|
||||
cuda_version: Optional[str] = None
|
||||
|
||||
|
||||
class AiRuntimeStatus(BaseModel):
|
||||
selected_model: str
|
||||
gpu: GpuStatus
|
||||
models: list[AiModelStatus]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Export schemas
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
220
backend/services/media_task_runner.py
Normal file
220
backend/services/media_task_runner.py
Normal file
@@ -0,0 +1,220 @@
|
||||
"""Background media parsing runner used by Celery workers."""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
import tempfile
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from minio_client import BUCKET_NAME, download_file, get_minio_client, upload_file
|
||||
from models import Frame, ProcessingTask, Project
|
||||
from progress_events import publish_task_progress_event
|
||||
from services.frame_parser import (
|
||||
extract_thumbnail,
|
||||
parse_dicom,
|
||||
parse_video,
|
||||
upload_frames_to_minio,
|
||||
)
|
||||
from statuses import (
|
||||
PROJECT_STATUS_ERROR,
|
||||
PROJECT_STATUS_PARSING,
|
||||
PROJECT_STATUS_READY,
|
||||
TASK_STATUS_FAILED,
|
||||
TASK_STATUS_RUNNING,
|
||||
TASK_STATUS_SUCCESS,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _now() -> datetime:
|
||||
return datetime.now(timezone.utc)
|
||||
|
||||
|
||||
def _set_task_state(
|
||||
db: Session,
|
||||
task: ProcessingTask,
|
||||
*,
|
||||
status: str | None = None,
|
||||
progress: int | None = None,
|
||||
message: str | None = None,
|
||||
result: dict[str, Any] | None = None,
|
||||
error: str | None = None,
|
||||
started: bool = False,
|
||||
finished: bool = False,
|
||||
) -> None:
|
||||
if status is not None:
|
||||
task.status = status
|
||||
if progress is not None:
|
||||
task.progress = max(0, min(100, progress))
|
||||
if message is not None:
|
||||
task.message = message
|
||||
if result is not None:
|
||||
task.result = result
|
||||
if error is not None:
|
||||
task.error = error
|
||||
if started:
|
||||
task.started_at = _now()
|
||||
if finished:
|
||||
task.finished_at = _now()
|
||||
db.commit()
|
||||
db.refresh(task)
|
||||
publish_task_progress_event(task)
|
||||
|
||||
|
||||
def run_parse_media_task(db: Session, task_id: int) -> dict[str, Any]:
|
||||
"""Parse one project's media and update task progress in the database."""
|
||||
task = db.query(ProcessingTask).filter(ProcessingTask.id == task_id).first()
|
||||
if not task:
|
||||
raise ValueError(f"Task not found: {task_id}")
|
||||
|
||||
if task.project_id is None:
|
||||
_set_task_state(
|
||||
db,
|
||||
task,
|
||||
status=TASK_STATUS_FAILED,
|
||||
progress=100,
|
||||
message="任务缺少 project_id",
|
||||
error="Task has no project_id",
|
||||
finished=True,
|
||||
)
|
||||
raise ValueError("Task has no project_id")
|
||||
|
||||
project = db.query(Project).filter(Project.id == task.project_id).first()
|
||||
if not project:
|
||||
_set_task_state(
|
||||
db,
|
||||
task,
|
||||
status=TASK_STATUS_FAILED,
|
||||
progress=100,
|
||||
message="项目不存在",
|
||||
error="Project not found",
|
||||
finished=True,
|
||||
)
|
||||
raise ValueError(f"Project not found: {task.project_id}")
|
||||
|
||||
if not project.video_path:
|
||||
_set_task_state(
|
||||
db,
|
||||
task,
|
||||
status=TASK_STATUS_FAILED,
|
||||
progress=100,
|
||||
message="项目没有可解析媒体",
|
||||
error="Project has no media uploaded",
|
||||
finished=True,
|
||||
)
|
||||
project.status = PROJECT_STATUS_ERROR
|
||||
db.commit()
|
||||
raise ValueError("Project has no media uploaded")
|
||||
|
||||
project.status = PROJECT_STATUS_PARSING
|
||||
_set_task_state(db, task, status=TASK_STATUS_RUNNING, progress=5, message="后台解析已启动", started=True)
|
||||
|
||||
effective_source = (task.payload or {}).get("source_type") or project.source_type or "video"
|
||||
parse_fps = project.parse_fps or 30.0
|
||||
tmp_dir = tempfile.mkdtemp(prefix=f"seg_parse_{project.id}_")
|
||||
output_dir = os.path.join(tmp_dir, "frames")
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
try:
|
||||
_set_task_state(db, task, progress=15, message="正在下载媒体文件")
|
||||
if effective_source == "dicom":
|
||||
dcm_dir = os.path.join(tmp_dir, "dcm")
|
||||
os.makedirs(dcm_dir, exist_ok=True)
|
||||
|
||||
client = get_minio_client()
|
||||
objects = list(client.list_objects(BUCKET_NAME, prefix=project.video_path, recursive=True))
|
||||
for obj in objects:
|
||||
if obj.object_name.lower().endswith(".dcm"):
|
||||
data = download_file(obj.object_name)
|
||||
local_dcm = os.path.join(dcm_dir, os.path.basename(obj.object_name))
|
||||
with open(local_dcm, "wb") as f:
|
||||
f.write(data)
|
||||
|
||||
_set_task_state(db, task, progress=35, message="正在解析 DICOM 序列")
|
||||
frame_files = parse_dicom(dcm_dir, output_dir)
|
||||
else:
|
||||
media_bytes = download_file(project.video_path)
|
||||
local_path = os.path.join(tmp_dir, Path(project.video_path).name)
|
||||
with open(local_path, "wb") as f:
|
||||
f.write(media_bytes)
|
||||
|
||||
_set_task_state(db, task, progress=35, message="正在使用 FFmpeg/OpenCV 拆帧")
|
||||
frame_files, original_fps = parse_video(local_path, output_dir, fps=int(parse_fps))
|
||||
project.original_fps = original_fps
|
||||
|
||||
thumbnail_path = os.path.join(tmp_dir, "thumbnail.jpg")
|
||||
try:
|
||||
extract_thumbnail(local_path, thumbnail_path)
|
||||
with open(thumbnail_path, "rb") as f:
|
||||
thumb_data = f.read()
|
||||
thumb_object = f"projects/{project.id}/thumbnail.jpg"
|
||||
upload_file(thumb_object, thumb_data, content_type="image/jpeg", length=len(thumb_data))
|
||||
project.thumbnail_url = thumb_object
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.warning("Thumbnail extraction failed: %s", exc)
|
||||
|
||||
_set_task_state(db, task, progress=70, message="正在上传帧到对象存储")
|
||||
object_names = upload_frames_to_minio(frame_files, project.id)
|
||||
|
||||
_set_task_state(db, task, progress=85, message="正在写入帧索引")
|
||||
frames_out = []
|
||||
for idx, obj_name in enumerate(object_names):
|
||||
local_frame = frame_files[idx]
|
||||
try:
|
||||
import cv2
|
||||
|
||||
img = cv2.imread(local_frame)
|
||||
h, w = img.shape[:2] if img is not None else (None, None)
|
||||
except Exception: # noqa: BLE001
|
||||
h, w = None, None
|
||||
|
||||
frame = Frame(
|
||||
project_id=project.id,
|
||||
frame_index=idx,
|
||||
image_url=obj_name,
|
||||
width=w,
|
||||
height=h,
|
||||
)
|
||||
db.add(frame)
|
||||
frames_out.append(frame)
|
||||
|
||||
project.status = PROJECT_STATUS_READY
|
||||
db.commit()
|
||||
|
||||
result = {
|
||||
"project_id": project.id,
|
||||
"frames_extracted": len(frames_out),
|
||||
"status": PROJECT_STATUS_READY,
|
||||
"message": "Frame extraction completed successfully.",
|
||||
}
|
||||
_set_task_state(
|
||||
db,
|
||||
task,
|
||||
status=TASK_STATUS_SUCCESS,
|
||||
progress=100,
|
||||
message="解析完成",
|
||||
result=result,
|
||||
finished=True,
|
||||
)
|
||||
logger.info("Parsed %d frames for project_id=%s", len(frames_out), project.id)
|
||||
return result
|
||||
except Exception as exc: # noqa: BLE001
|
||||
project.status = PROJECT_STATUS_ERROR
|
||||
_set_task_state(
|
||||
db,
|
||||
task,
|
||||
status=TASK_STATUS_FAILED,
|
||||
progress=100,
|
||||
message="解析失败",
|
||||
error=str(exc),
|
||||
finished=True,
|
||||
)
|
||||
logger.error("Frame extraction failed: %s", exc)
|
||||
raise
|
||||
finally:
|
||||
shutil.rmtree(tmp_dir, ignore_errors=True)
|
||||
@@ -1,4 +1,4 @@
|
||||
"""SAM 2 engine wrapper with lazy loading and fallback stubs."""
|
||||
"""SAM 2 engine wrapper with lazy loading and explicit runtime status."""
|
||||
|
||||
import logging
|
||||
import os
|
||||
@@ -11,10 +11,18 @@ from config import settings
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Attempt to import SAM 2; fall back to stubs if unavailable.
|
||||
# Attempt to import PyTorch and SAM 2; fall back to stubs if unavailable.
|
||||
# ---------------------------------------------------------------------------
|
||||
try:
|
||||
import torch
|
||||
|
||||
TORCH_AVAILABLE = True
|
||||
except Exception as exc: # noqa: BLE001
|
||||
TORCH_AVAILABLE = False
|
||||
torch = None # type: ignore[assignment]
|
||||
logger.warning("PyTorch import failed (%s). SAM2 will be unavailable.", exc)
|
||||
|
||||
try:
|
||||
from sam2.build_sam import build_sam2
|
||||
from sam2.sam2_image_predictor import SAM2ImagePredictor
|
||||
|
||||
@@ -31,6 +39,8 @@ class SAM2Engine:
|
||||
def __init__(self) -> None:
|
||||
self._predictor: Optional[SAM2ImagePredictor] = None
|
||||
self._model_loaded = False
|
||||
self._loaded_device: str | None = None
|
||||
self._last_error: str | None = None
|
||||
|
||||
# -----------------------------------------------------------------------
|
||||
# Internal helpers
|
||||
@@ -40,34 +50,87 @@ class SAM2Engine:
|
||||
if self._model_loaded:
|
||||
return
|
||||
|
||||
if not TORCH_AVAILABLE:
|
||||
self._last_error = "PyTorch is not installed."
|
||||
logger.warning("PyTorch not available; skipping SAM2 model load.")
|
||||
self._model_loaded = True
|
||||
return
|
||||
|
||||
if not SAM2_AVAILABLE:
|
||||
self._last_error = "sam2 package is not installed."
|
||||
logger.warning("SAM2 not available; skipping model load.")
|
||||
self._model_loaded = True
|
||||
return
|
||||
|
||||
if not os.path.isfile(settings.sam_model_path):
|
||||
self._last_error = f"SAM2 checkpoint not found: {settings.sam_model_path}"
|
||||
logger.error("SAM checkpoint not found at %s", settings.sam_model_path)
|
||||
self._model_loaded = True
|
||||
return
|
||||
|
||||
try:
|
||||
device = self._best_device()
|
||||
model = build_sam2(
|
||||
settings.sam_model_config,
|
||||
settings.sam_model_path,
|
||||
device="cuda",
|
||||
device=device,
|
||||
)
|
||||
self._predictor = SAM2ImagePredictor(model)
|
||||
self._model_loaded = True
|
||||
logger.info("SAM 2 model loaded from %s", settings.sam_model_path)
|
||||
self._loaded_device = device
|
||||
self._last_error = None
|
||||
logger.info("SAM 2 model loaded from %s on %s", settings.sam_model_path, device)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
self._last_error = str(exc)
|
||||
logger.error("Failed to load SAM 2 model: %s", exc)
|
||||
self._model_loaded = True # Prevent repeated load attempts
|
||||
|
||||
def _best_device(self) -> str:
|
||||
if TORCH_AVAILABLE and torch is not None and torch.cuda.is_available():
|
||||
return "cuda"
|
||||
return "cpu"
|
||||
|
||||
def _ensure_ready(self) -> bool:
|
||||
"""Ensure the model is loaded; return whether it is usable."""
|
||||
self._load_model()
|
||||
return SAM2_AVAILABLE and self._predictor is not None
|
||||
|
||||
def status(self) -> dict:
|
||||
"""Return lightweight, real runtime status without forcing model load."""
|
||||
checkpoint_exists = os.path.isfile(settings.sam_model_path)
|
||||
device = self._loaded_device or self._best_device()
|
||||
available = bool(TORCH_AVAILABLE and SAM2_AVAILABLE and checkpoint_exists)
|
||||
if self._predictor is not None:
|
||||
message = "SAM 2 model loaded and ready."
|
||||
elif available:
|
||||
message = "SAM 2 dependencies and checkpoint are present; model will load on first inference."
|
||||
else:
|
||||
missing = []
|
||||
if not TORCH_AVAILABLE:
|
||||
missing.append("PyTorch")
|
||||
if not SAM2_AVAILABLE:
|
||||
missing.append("sam2 package")
|
||||
if not checkpoint_exists:
|
||||
missing.append("checkpoint")
|
||||
message = f"SAM 2 unavailable: missing {', '.join(missing)}."
|
||||
if self._last_error and not self._predictor:
|
||||
message = self._last_error
|
||||
return {
|
||||
"id": "sam2",
|
||||
"label": "SAM 2",
|
||||
"available": available,
|
||||
"loaded": self._predictor is not None,
|
||||
"device": device,
|
||||
"supports": ["point", "box", "auto"],
|
||||
"message": message,
|
||||
"package_available": SAM2_AVAILABLE,
|
||||
"checkpoint_exists": checkpoint_exists,
|
||||
"checkpoint_path": settings.sam_model_path,
|
||||
"python_ok": True,
|
||||
"torch_ok": TORCH_AVAILABLE,
|
||||
"cuda_required": False,
|
||||
}
|
||||
|
||||
# -----------------------------------------------------------------------
|
||||
# Public API
|
||||
# -----------------------------------------------------------------------
|
||||
|
||||
148
backend/services/sam3_engine.py
Normal file
148
backend/services/sam3_engine.py
Normal file
@@ -0,0 +1,148 @@
|
||||
"""SAM 3 engine adapter and runtime status.
|
||||
|
||||
The official facebookresearch/sam3 package currently targets Python 3.12+
|
||||
and CUDA-capable PyTorch. This adapter reports those requirements honestly and
|
||||
only performs inference when the local runtime can actually import and execute
|
||||
the package.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib.util
|
||||
import logging
|
||||
import sys
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
from config import settings
|
||||
from services.sam2_engine import SAM2Engine
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
try:
|
||||
import torch
|
||||
|
||||
TORCH_AVAILABLE = True
|
||||
except Exception as exc: # noqa: BLE001
|
||||
TORCH_AVAILABLE = False
|
||||
torch = None # type: ignore[assignment]
|
||||
logger.warning("PyTorch import failed (%s). SAM3 will be unavailable.", exc)
|
||||
|
||||
SAM3_PACKAGE_AVAILABLE = importlib.util.find_spec("sam3") is not None
|
||||
|
||||
|
||||
class SAM3Engine:
|
||||
"""Lazy SAM 3 image inference adapter."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._model: Any | None = None
|
||||
self._processor: Any | None = None
|
||||
self._model_loaded = False
|
||||
self._last_error: str | None = None
|
||||
|
||||
def _python_ok(self) -> bool:
|
||||
return sys.version_info >= (3, 12)
|
||||
|
||||
def _gpu_ok(self) -> bool:
|
||||
return bool(TORCH_AVAILABLE and torch is not None and torch.cuda.is_available())
|
||||
|
||||
def _can_load(self) -> bool:
|
||||
return bool(SAM3_PACKAGE_AVAILABLE and TORCH_AVAILABLE and self._python_ok() and self._gpu_ok())
|
||||
|
||||
def _load_model(self) -> None:
|
||||
if self._model_loaded:
|
||||
return
|
||||
if not self._can_load():
|
||||
self._last_error = self._status_message()
|
||||
self._model_loaded = True
|
||||
return
|
||||
|
||||
try:
|
||||
from sam3.model.sam3_image_processor import Sam3Processor
|
||||
from sam3.model_builder import build_sam3_image_model
|
||||
|
||||
self._model = build_sam3_image_model()
|
||||
self._processor = Sam3Processor(self._model)
|
||||
self._model_loaded = True
|
||||
self._last_error = None
|
||||
logger.info("SAM 3 image model loaded with version setting %s", settings.sam3_model_version)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
self._last_error = str(exc)
|
||||
self._model_loaded = True
|
||||
logger.error("Failed to load SAM 3 model: %s", exc)
|
||||
|
||||
def _ensure_ready(self) -> bool:
|
||||
self._load_model()
|
||||
return self._processor is not None
|
||||
|
||||
def _status_message(self) -> str:
|
||||
missing = []
|
||||
if not SAM3_PACKAGE_AVAILABLE:
|
||||
missing.append("sam3 package")
|
||||
if not self._python_ok():
|
||||
missing.append("Python 3.12+ runtime")
|
||||
if not TORCH_AVAILABLE:
|
||||
missing.append("PyTorch")
|
||||
if not self._gpu_ok():
|
||||
missing.append("CUDA GPU")
|
||||
if missing:
|
||||
return f"SAM 3 unavailable: missing {', '.join(missing)}."
|
||||
return "SAM 3 dependencies are present; model will load on first inference."
|
||||
|
||||
def status(self) -> dict:
|
||||
available = self._can_load()
|
||||
return {
|
||||
"id": "sam3",
|
||||
"label": "SAM 3",
|
||||
"available": available,
|
||||
"loaded": self._processor is not None,
|
||||
"device": "cuda" if self._gpu_ok() else "unavailable",
|
||||
"supports": ["semantic"],
|
||||
"message": "SAM 3 model loaded and ready." if self._processor is not None else (self._last_error or self._status_message()),
|
||||
"package_available": SAM3_PACKAGE_AVAILABLE,
|
||||
"checkpoint_exists": SAM3_PACKAGE_AVAILABLE,
|
||||
"checkpoint_path": f"official/HuggingFace ({settings.sam3_model_version})",
|
||||
"python_ok": self._python_ok(),
|
||||
"torch_ok": TORCH_AVAILABLE,
|
||||
"cuda_required": True,
|
||||
}
|
||||
|
||||
def predict_semantic(self, image: np.ndarray, text: str) -> tuple[list[list[list[float]]], list[float]]:
|
||||
if not text.strip():
|
||||
raise ValueError("SAM 3 semantic prompt requires non-empty text.")
|
||||
if not self._ensure_ready():
|
||||
raise RuntimeError(self.status()["message"])
|
||||
|
||||
pil_image = Image.fromarray(image)
|
||||
with torch.inference_mode(): # type: ignore[union-attr]
|
||||
state = self._processor.set_image(pil_image)
|
||||
output = self._processor.set_text_prompt(state=state, prompt=text.strip())
|
||||
|
||||
masks = output.get("masks", [])
|
||||
scores = output.get("scores", [])
|
||||
polygons = []
|
||||
for mask in masks:
|
||||
if hasattr(mask, "detach"):
|
||||
mask = mask.detach().cpu().numpy()
|
||||
if mask.ndim == 3:
|
||||
mask = mask[0]
|
||||
poly = SAM2Engine._mask_to_polygon(mask)
|
||||
if poly:
|
||||
polygons.append(poly)
|
||||
|
||||
if hasattr(scores, "detach"):
|
||||
scores = scores.detach().cpu().tolist()
|
||||
elif hasattr(scores, "tolist"):
|
||||
scores = scores.tolist()
|
||||
return polygons, list(scores)
|
||||
|
||||
def predict_points(self, *_args: Any, **_kwargs: Any) -> tuple[list[list[list[float]]], list[float]]:
|
||||
raise NotImplementedError("This backend currently exposes SAM 3 semantic text inference; use SAM 2 for point prompts.")
|
||||
|
||||
def predict_box(self, *_args: Any, **_kwargs: Any) -> tuple[list[list[list[float]]], list[float]]:
|
||||
raise NotImplementedError("This backend currently exposes SAM 3 semantic text inference; use SAM 2 for box prompts.")
|
||||
|
||||
|
||||
sam3_engine = SAM3Engine()
|
||||
80
backend/services/sam_registry.py
Normal file
80
backend/services/sam_registry.py
Normal file
@@ -0,0 +1,80 @@
|
||||
"""Model registry for SAM runtimes and GPU status."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from config import settings
|
||||
from services.sam2_engine import TORCH_AVAILABLE, sam_engine as sam2_engine
|
||||
from services.sam3_engine import sam3_engine
|
||||
|
||||
try:
|
||||
import torch
|
||||
except Exception: # noqa: BLE001
|
||||
torch = None # type: ignore[assignment]
|
||||
|
||||
|
||||
class ModelUnavailableError(RuntimeError):
|
||||
"""Raised when a selected model cannot run in this environment."""
|
||||
|
||||
|
||||
class SAMRegistry:
|
||||
"""Dispatch predictions to the selected SAM backend."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._engines = {
|
||||
"sam2": sam2_engine,
|
||||
"sam3": sam3_engine,
|
||||
}
|
||||
|
||||
def normalize_model_id(self, model_id: str | None) -> str:
|
||||
selected = (model_id or settings.sam_default_model or "sam2").lower()
|
||||
if selected not in self._engines:
|
||||
raise ValueError(f"Unsupported model: {model_id}")
|
||||
return selected
|
||||
|
||||
def runtime_status(self, selected_model: str | None = None) -> dict[str, Any]:
|
||||
return {
|
||||
"selected_model": self.normalize_model_id(selected_model),
|
||||
"gpu": self.gpu_status(),
|
||||
"models": [engine.status() for engine in self._engines.values()],
|
||||
}
|
||||
|
||||
def gpu_status(self) -> dict[str, Any]:
|
||||
cuda_available = bool(TORCH_AVAILABLE and torch is not None and torch.cuda.is_available())
|
||||
return {
|
||||
"available": cuda_available,
|
||||
"device": "cuda" if cuda_available else "cpu",
|
||||
"name": torch.cuda.get_device_name(0) if cuda_available else None,
|
||||
"torch_available": bool(TORCH_AVAILABLE),
|
||||
"torch_version": getattr(torch, "__version__", None) if torch is not None else None,
|
||||
"cuda_version": getattr(torch.version, "cuda", None) if torch is not None else None,
|
||||
}
|
||||
|
||||
def _engine(self, model_id: str | None) -> Any:
|
||||
return self._engines[self.normalize_model_id(model_id)]
|
||||
|
||||
def _ensure_available(self, model_id: str | None) -> Any:
|
||||
engine = self._engine(model_id)
|
||||
status = engine.status()
|
||||
if not status["available"]:
|
||||
raise ModelUnavailableError(status["message"])
|
||||
return engine
|
||||
|
||||
def predict_points(self, model_id: str | None, image: Any, points: list[list[float]], labels: list[int]):
|
||||
return self._ensure_available(model_id).predict_points(image, points, labels)
|
||||
|
||||
def predict_box(self, model_id: str | None, image: Any, box: list[float]):
|
||||
return self._ensure_available(model_id).predict_box(image, box)
|
||||
|
||||
def predict_auto(self, model_id: str | None, image: Any):
|
||||
return self._ensure_available(model_id).predict_auto(image)
|
||||
|
||||
def predict_semantic(self, model_id: str | None, image: Any, text: str):
|
||||
model = self.normalize_model_id(model_id)
|
||||
if model == "sam3":
|
||||
return self._ensure_available(model).predict_semantic(image, text)
|
||||
return self._ensure_available(model).predict_auto(image)
|
||||
|
||||
|
||||
sam_registry = SAMRegistry()
|
||||
11
backend/statuses.py
Normal file
11
backend/statuses.py
Normal file
@@ -0,0 +1,11 @@
|
||||
"""Shared status constants used across backend project/task flows."""
|
||||
|
||||
PROJECT_STATUS_PENDING = "pending"
|
||||
PROJECT_STATUS_PARSING = "parsing"
|
||||
PROJECT_STATUS_READY = "ready"
|
||||
PROJECT_STATUS_ERROR = "error"
|
||||
|
||||
TASK_STATUS_QUEUED = "queued"
|
||||
TASK_STATUS_RUNNING = "running"
|
||||
TASK_STATUS_SUCCESS = "success"
|
||||
TASK_STATUS_FAILED = "failed"
|
||||
72
backend/tests/conftest.py
Normal file
72
backend/tests/conftest.py
Normal file
@@ -0,0 +1,72 @@
|
||||
"""Shared pytest fixtures for backend API tests."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Iterator
|
||||
|
||||
import pytest
|
||||
from fastapi import FastAPI
|
||||
from fastapi.testclient import TestClient
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
from sqlalchemy.pool import StaticPool
|
||||
|
||||
BACKEND_DIR = Path(__file__).resolve().parents[1]
|
||||
if str(BACKEND_DIR) not in sys.path:
|
||||
sys.path.insert(0, str(BACKEND_DIR))
|
||||
|
||||
from database import Base, get_db # noqa: E402
|
||||
from main import websocket_progress # noqa: E402
|
||||
from routers import ai, auth, dashboard, export, media, projects, tasks, templates # noqa: E402
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def db_session() -> Iterator[Session]:
|
||||
engine = create_engine(
|
||||
"sqlite://",
|
||||
connect_args={"check_same_thread": False},
|
||||
poolclass=StaticPool,
|
||||
)
|
||||
TestingSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||
Base.metadata.create_all(bind=engine)
|
||||
session = TestingSessionLocal()
|
||||
try:
|
||||
yield session
|
||||
finally:
|
||||
session.close()
|
||||
Base.metadata.drop_all(bind=engine)
|
||||
engine.dispose()
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def app(db_session: Session) -> FastAPI:
|
||||
test_app = FastAPI()
|
||||
|
||||
def override_get_db() -> Iterator[Session]:
|
||||
yield db_session
|
||||
|
||||
test_app.dependency_overrides[get_db] = override_get_db
|
||||
test_app.include_router(auth.router)
|
||||
test_app.include_router(projects.router)
|
||||
test_app.include_router(templates.router)
|
||||
test_app.include_router(media.router)
|
||||
test_app.include_router(ai.router)
|
||||
test_app.include_router(export.router)
|
||||
test_app.include_router(dashboard.router)
|
||||
test_app.include_router(tasks.router)
|
||||
|
||||
@test_app.get("/health")
|
||||
def health_check() -> dict[str, str]:
|
||||
return {"status": "ok", "service": "SegServer"}
|
||||
|
||||
test_app.add_api_websocket_route("/ws/progress", websocket_progress)
|
||||
|
||||
return test_app
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def client(app: FastAPI) -> Iterator[TestClient]:
|
||||
with TestClient(app) as test_client:
|
||||
yield test_client
|
||||
248
backend/tests/test_ai.py
Normal file
248
backend/tests/test_ai.py
Normal file
@@ -0,0 +1,248 @@
|
||||
import numpy as np
|
||||
|
||||
|
||||
def _create_project_and_frame(client):
|
||||
project = client.post("/api/projects", json={"name": "AI Project"}).json()
|
||||
frame = client.post(f"/api/projects/{project['id']}/frames", json={
|
||||
"project_id": project["id"],
|
||||
"frame_index": 0,
|
||||
"image_url": "frames/0.jpg",
|
||||
"width": 640,
|
||||
"height": 360,
|
||||
}).json()
|
||||
template = client.post("/api/templates", json={
|
||||
"name": "Template",
|
||||
"color": "#06b6d4",
|
||||
"z_index": 0,
|
||||
"classes": [],
|
||||
"rules": [],
|
||||
}).json()
|
||||
return project, frame, template
|
||||
|
||||
|
||||
def test_predict_accepts_point_object_with_labels(client, monkeypatch):
|
||||
_, frame, _ = _create_project_and_frame(client)
|
||||
calls = {}
|
||||
|
||||
monkeypatch.setattr("routers.ai._load_frame_image", lambda frame: np.zeros((10, 10, 3), dtype=np.uint8))
|
||||
|
||||
def fake_predict_points(image, points, labels):
|
||||
calls["args"] = (points, labels)
|
||||
return (
|
||||
[[[0.1, 0.1], [0.9, 0.1], [0.9, 0.9]]],
|
||||
[0.95],
|
||||
)
|
||||
|
||||
monkeypatch.setattr("routers.ai.sam_registry.predict_points", lambda model, image, points, labels: fake_predict_points(image, points, labels))
|
||||
|
||||
response = client.post("/api/ai/predict", json={
|
||||
"image_id": frame["id"],
|
||||
"prompt_type": "point",
|
||||
"prompt_data": {"points": [[0.5, 0.5], [0.1, 0.1]], "labels": [1, 0]},
|
||||
})
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json()["scores"] == [0.95]
|
||||
assert calls["args"] == ([[0.5, 0.5], [0.1, 0.1]], [1, 0])
|
||||
|
||||
|
||||
def test_predict_box_and_semantic_fallback(client, monkeypatch):
|
||||
_, frame, _ = _create_project_and_frame(client)
|
||||
monkeypatch.setattr("routers.ai._load_frame_image", lambda frame: np.zeros((10, 10, 3), dtype=np.uint8))
|
||||
monkeypatch.setattr("routers.ai.sam_registry.predict_box", lambda model, image, box: (
|
||||
[[[0.2, 0.2], [0.8, 0.2], [0.8, 0.8]]],
|
||||
[0.8],
|
||||
))
|
||||
monkeypatch.setattr("routers.ai.sam_registry.predict_semantic", lambda model, image, text: (
|
||||
[[[0.0, 0.0], [1.0, 0.0], [1.0, 1.0]]],
|
||||
[0.5],
|
||||
))
|
||||
|
||||
box_response = client.post("/api/ai/predict", json={
|
||||
"image_id": frame["id"],
|
||||
"prompt_type": "box",
|
||||
"prompt_data": [0.2, 0.2, 0.8, 0.8],
|
||||
})
|
||||
semantic_response = client.post("/api/ai/predict", json={
|
||||
"image_id": frame["id"],
|
||||
"prompt_type": "semantic",
|
||||
"prompt_data": "胆囊",
|
||||
})
|
||||
|
||||
assert box_response.status_code == 200
|
||||
assert box_response.json()["scores"] == [0.8]
|
||||
assert semantic_response.status_code == 200
|
||||
assert semantic_response.json()["scores"] == [0.5]
|
||||
|
||||
|
||||
def test_model_status_reports_runtime(client, monkeypatch):
|
||||
monkeypatch.setattr("routers.ai.sam_registry.runtime_status", lambda selected_model=None: {
|
||||
"selected_model": selected_model or "sam2",
|
||||
"gpu": {
|
||||
"available": False,
|
||||
"device": "cpu",
|
||||
"name": None,
|
||||
"torch_available": True,
|
||||
"torch_version": "2.x",
|
||||
"cuda_version": None,
|
||||
},
|
||||
"models": [
|
||||
{
|
||||
"id": "sam2",
|
||||
"label": "SAM 2",
|
||||
"available": True,
|
||||
"loaded": False,
|
||||
"device": "cpu",
|
||||
"supports": ["point", "box", "auto"],
|
||||
"message": "ready",
|
||||
"package_available": True,
|
||||
"checkpoint_exists": True,
|
||||
"checkpoint_path": "model.pt",
|
||||
"python_ok": True,
|
||||
"torch_ok": True,
|
||||
"cuda_required": False,
|
||||
},
|
||||
{
|
||||
"id": "sam3",
|
||||
"label": "SAM 3",
|
||||
"available": False,
|
||||
"loaded": False,
|
||||
"device": "unavailable",
|
||||
"supports": ["semantic"],
|
||||
"message": "missing Python 3.12+ runtime",
|
||||
"package_available": False,
|
||||
"checkpoint_exists": False,
|
||||
"checkpoint_path": None,
|
||||
"python_ok": False,
|
||||
"torch_ok": True,
|
||||
"cuda_required": True,
|
||||
},
|
||||
],
|
||||
})
|
||||
|
||||
response = client.get("/api/ai/models/status?selected_model=sam3")
|
||||
|
||||
assert response.status_code == 200
|
||||
body = response.json()
|
||||
assert body["selected_model"] == "sam3"
|
||||
assert body["models"][1]["id"] == "sam3"
|
||||
assert body["models"][1]["available"] is False
|
||||
|
||||
|
||||
def test_predict_validation_errors(client, monkeypatch):
|
||||
project, _, _ = _create_project_and_frame(client)
|
||||
|
||||
assert client.post("/api/ai/predict", json={
|
||||
"image_id": 999,
|
||||
"prompt_type": "point",
|
||||
"prompt_data": [[0.5, 0.5]],
|
||||
}).status_code == 404
|
||||
|
||||
frame = client.post(f"/api/projects/{project['id']}/frames", json={
|
||||
"project_id": project["id"],
|
||||
"frame_index": 1,
|
||||
"image_url": "frames/1.jpg",
|
||||
}).json()
|
||||
monkeypatch.setattr("routers.ai._load_frame_image", lambda frame: np.zeros((10, 10, 3), dtype=np.uint8))
|
||||
assert client.post("/api/ai/predict", json={
|
||||
"image_id": frame["id"],
|
||||
"prompt_type": "box",
|
||||
"prompt_data": [0.1, 0.2],
|
||||
}).status_code == 400
|
||||
|
||||
|
||||
def test_save_annotation_validates_project_and_frame(client):
|
||||
project, frame, template = _create_project_and_frame(client)
|
||||
|
||||
saved = client.post("/api/ai/annotate", json={
|
||||
"project_id": project["id"],
|
||||
"frame_id": frame["id"],
|
||||
"template_id": template["id"],
|
||||
"mask_data": {"polygons": [[[0.1, 0.1], [0.9, 0.1], [0.9, 0.9]]]},
|
||||
"points": [[0.5, 0.5]],
|
||||
"bbox": [0.1, 0.1, 0.8, 0.8],
|
||||
})
|
||||
assert saved.status_code == 201
|
||||
assert saved.json()["project_id"] == project["id"]
|
||||
|
||||
listing = client.get(f"/api/ai/annotations?project_id={project['id']}")
|
||||
assert listing.status_code == 200
|
||||
assert listing.json()[0]["id"] == saved.json()["id"]
|
||||
|
||||
frame_listing = client.get(f"/api/ai/annotations?project_id={project['id']}&frame_id={frame['id']}")
|
||||
assert frame_listing.status_code == 200
|
||||
assert len(frame_listing.json()) == 1
|
||||
|
||||
missing_project = client.post("/api/ai/annotate", json={"project_id": 999})
|
||||
assert missing_project.status_code == 404
|
||||
|
||||
missing_frame = client.post("/api/ai/annotate", json={
|
||||
"project_id": project["id"],
|
||||
"frame_id": 999,
|
||||
})
|
||||
assert missing_frame.status_code == 404
|
||||
|
||||
missing_project_list = client.get("/api/ai/annotations?project_id=999")
|
||||
assert missing_project_list.status_code == 404
|
||||
|
||||
|
||||
def test_update_and_delete_annotation(client):
|
||||
project, frame, template = _create_project_and_frame(client)
|
||||
saved = client.post("/api/ai/annotate", json={
|
||||
"project_id": project["id"],
|
||||
"frame_id": frame["id"],
|
||||
"template_id": template["id"],
|
||||
"mask_data": {
|
||||
"polygons": [[[0.1, 0.1], [0.9, 0.1], [0.9, 0.9]]],
|
||||
"label": "AI Mask",
|
||||
"color": "#06b6d4",
|
||||
},
|
||||
"points": [[0.5, 0.5]],
|
||||
"bbox": [0.1, 0.1, 0.8, 0.8],
|
||||
}).json()
|
||||
|
||||
updated = client.patch(f"/api/ai/annotations/{saved['id']}", json={
|
||||
"template_id": template["id"],
|
||||
"mask_data": {
|
||||
"polygons": [[[0.2, 0.2], [0.8, 0.2], [0.8, 0.8]]],
|
||||
"label": "胆囊",
|
||||
"color": "#ff0000",
|
||||
"class": {"id": "c1", "name": "胆囊", "color": "#ff0000", "zIndex": 20},
|
||||
},
|
||||
"points": [[0.4, 0.4]],
|
||||
"bbox": [0.2, 0.2, 0.6, 0.6],
|
||||
})
|
||||
|
||||
assert updated.status_code == 200
|
||||
body = updated.json()
|
||||
assert body["mask_data"]["label"] == "胆囊"
|
||||
assert body["mask_data"]["class"]["id"] == "c1"
|
||||
assert body["points"] == [[0.4, 0.4]]
|
||||
assert body["bbox"] == [0.2, 0.2, 0.6, 0.6]
|
||||
|
||||
listing = client.get(f"/api/ai/annotations?project_id={project['id']}")
|
||||
assert listing.status_code == 200
|
||||
assert listing.json()[0]["mask_data"]["class"]["name"] == "胆囊"
|
||||
|
||||
deleted = client.delete(f"/api/ai/annotations/{saved['id']}")
|
||||
assert deleted.status_code == 204
|
||||
|
||||
empty_listing = client.get(f"/api/ai/annotations?project_id={project['id']}")
|
||||
assert empty_listing.status_code == 200
|
||||
assert empty_listing.json() == []
|
||||
|
||||
|
||||
def test_update_and_delete_annotation_validation(client):
|
||||
project, frame, template = _create_project_and_frame(client)
|
||||
saved = client.post("/api/ai/annotate", json={
|
||||
"project_id": project["id"],
|
||||
"frame_id": frame["id"],
|
||||
"template_id": template["id"],
|
||||
}).json()
|
||||
|
||||
assert client.patch("/api/ai/annotations/999", json={"bbox": [0, 0, 1, 1]}).status_code == 404
|
||||
assert client.delete("/api/ai/annotations/999").status_code == 404
|
||||
assert client.patch(
|
||||
f"/api/ai/annotations/{saved['id']}",
|
||||
json={"template_id": 999},
|
||||
).status_code == 404
|
||||
15
backend/tests/test_auth.py
Normal file
15
backend/tests/test_auth.py
Normal file
@@ -0,0 +1,15 @@
|
||||
def test_login_success(client):
|
||||
response = client.post("/api/auth/login", json={"username": "admin", "password": "123456"})
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {
|
||||
"token": "fake-jwt-token-for-admin",
|
||||
"username": "admin",
|
||||
}
|
||||
|
||||
|
||||
def test_login_rejects_invalid_credentials(client):
|
||||
response = client.post("/api/auth/login", json={"username": "admin", "password": "wrong"})
|
||||
|
||||
assert response.status_code == 401
|
||||
assert response.json()["detail"] == "Invalid credentials"
|
||||
69
backend/tests/test_dashboard.py
Normal file
69
backend/tests/test_dashboard.py
Normal file
@@ -0,0 +1,69 @@
|
||||
def test_dashboard_overview_uses_persisted_records(client, db_session):
|
||||
from models import ProcessingTask
|
||||
|
||||
project_pending = client.post("/api/projects", json={
|
||||
"name": "Pending Project",
|
||||
"status": "pending",
|
||||
}).json()
|
||||
project_ready = client.post("/api/projects", json={
|
||||
"name": "Ready Project",
|
||||
"status": "ready",
|
||||
}).json()
|
||||
frame = client.post(f"/api/projects/{project_pending['id']}/frames", json={
|
||||
"project_id": project_pending["id"],
|
||||
"frame_index": 0,
|
||||
"image_url": "frames/0.jpg",
|
||||
"width": 640,
|
||||
"height": 360,
|
||||
}).json()
|
||||
template = client.post("/api/templates", json={
|
||||
"name": "Dashboard Template",
|
||||
"color": "#06b6d4",
|
||||
"z_index": 0,
|
||||
"classes": [],
|
||||
"rules": [],
|
||||
}).json()
|
||||
annotation = client.post("/api/ai/annotate", json={
|
||||
"project_id": project_pending["id"],
|
||||
"frame_id": frame["id"],
|
||||
"template_id": template["id"],
|
||||
"mask_data": {"polygons": [[[0.1, 0.1], [0.9, 0.1], [0.9, 0.9]]]},
|
||||
})
|
||||
assert annotation.status_code == 201
|
||||
task = ProcessingTask(
|
||||
task_type="parse_video",
|
||||
status="running",
|
||||
progress=35,
|
||||
message="正在使用 FFmpeg/OpenCV 拆帧",
|
||||
project_id=project_pending["id"],
|
||||
payload={"source_type": "video"},
|
||||
)
|
||||
db_session.add(task)
|
||||
db_session.commit()
|
||||
db_session.refresh(task)
|
||||
|
||||
response = client.get("/api/dashboard/overview")
|
||||
|
||||
assert response.status_code == 200
|
||||
body = response.json()
|
||||
assert body["summary"]["project_count"] == 2
|
||||
assert body["summary"]["frame_count"] == 1
|
||||
assert body["summary"]["annotation_count"] == 1
|
||||
assert body["summary"]["template_count"] == 1
|
||||
assert body["summary"]["parsing_task_count"] == 1
|
||||
assert body["tasks"] == [
|
||||
{
|
||||
"id": f"task-{task.id}",
|
||||
"task_id": task.id,
|
||||
"project_id": project_pending["id"],
|
||||
"name": "Pending Project",
|
||||
"progress": 35,
|
||||
"status": "正在使用 FFmpeg/OpenCV 拆帧",
|
||||
"frame_count": 0,
|
||||
"updated_at": body["tasks"][0]["updated_at"],
|
||||
},
|
||||
]
|
||||
assert any(item["kind"] == "task" for item in body["activity"])
|
||||
assert any(item["kind"] == "annotation" for item in body["activity"])
|
||||
assert any(item["kind"] == "template" for item in body["activity"])
|
||||
assert all(item["name"] != "Ready Project" for item in body["tasks"])
|
||||
66
backend/tests/test_export.py
Normal file
66
backend/tests/test_export.py
Normal file
@@ -0,0 +1,66 @@
|
||||
import zipfile
|
||||
from io import BytesIO
|
||||
|
||||
|
||||
def _seed_export_data(client):
|
||||
project = client.post("/api/projects", json={"name": "Export Project"}).json()
|
||||
frame = client.post(f"/api/projects/{project['id']}/frames", json={
|
||||
"project_id": project["id"],
|
||||
"frame_index": 0,
|
||||
"image_url": "frames/0.jpg",
|
||||
"width": 100,
|
||||
"height": 50,
|
||||
}).json()
|
||||
template = client.post("/api/templates", json={
|
||||
"name": "Category",
|
||||
"color": "#06b6d4",
|
||||
"z_index": 0,
|
||||
"classes": [],
|
||||
"rules": [],
|
||||
}).json()
|
||||
annotation = client.post("/api/ai/annotate", json={
|
||||
"project_id": project["id"],
|
||||
"frame_id": frame["id"],
|
||||
"template_id": template["id"],
|
||||
"mask_data": {"polygons": [[[0.1, 0.2], [0.9, 0.2], [0.9, 0.8], [0.1, 0.8]]]},
|
||||
"points": [[0.5, 0.5]],
|
||||
"bbox": [0.1, 0.2, 0.8, 0.6],
|
||||
}).json()
|
||||
return project, frame, template, annotation
|
||||
|
||||
|
||||
def test_export_coco_json_structure(client):
|
||||
project, frame, _, _ = _seed_export_data(client)
|
||||
|
||||
response = client.get(f"/api/export/{project['id']}/coco")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.headers["content-type"].startswith("application/json")
|
||||
data = response.json()
|
||||
assert data["info"]["description"] == "Annotations for Export Project"
|
||||
assert data["images"][0] == {
|
||||
"id": frame["id"],
|
||||
"file_name": "frames/0.jpg",
|
||||
"width": 100,
|
||||
"height": 50,
|
||||
"frame_index": 0,
|
||||
}
|
||||
assert data["annotations"][0]["segmentation"] == [[10.0, 10.0, 90.0, 10.0, 90.0, 40.0, 10.0, 40.0]]
|
||||
assert data["annotations"][0]["bbox"] == [10.0, 10.0, 80.0, 30.000000000000004]
|
||||
assert data["categories"][0]["name"] == "Category"
|
||||
|
||||
|
||||
def test_export_masks_zip(client):
|
||||
project, _, _, annotation = _seed_export_data(client)
|
||||
|
||||
response = client.get(f"/api/export/{project['id']}/masks")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.headers["content-type"].startswith("application/zip")
|
||||
with zipfile.ZipFile(BytesIO(response.content)) as archive:
|
||||
assert archive.namelist() == [f"mask_{annotation['id']:06d}.png"]
|
||||
|
||||
|
||||
def test_export_missing_project_returns_404(client):
|
||||
assert client.get("/api/export/999/coco").status_code == 404
|
||||
assert client.get("/api/export/999/masks").status_code == 404
|
||||
15
backend/tests/test_main.py
Normal file
15
backend/tests/test_main.py
Normal file
@@ -0,0 +1,15 @@
|
||||
def test_health_endpoint(client):
|
||||
response = client.get("/health")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"status": "ok", "service": "SegServer"}
|
||||
|
||||
|
||||
def test_websocket_progress_heartbeat(client):
|
||||
with client.websocket_connect("/ws/progress") as websocket:
|
||||
websocket.send_text("ping")
|
||||
data = websocket.receive_json()
|
||||
|
||||
assert data["type"] == "status"
|
||||
assert data["status"] == "connected"
|
||||
assert data["message"] == "Progress stream active"
|
||||
142
backend/tests/test_media.py
Normal file
142
backend/tests/test_media.py
Normal file
@@ -0,0 +1,142 @@
|
||||
def test_upload_rejects_unsupported_file_type(client):
|
||||
response = client.post(
|
||||
"/api/media/upload",
|
||||
files={"file": ("notes.txt", b"text", "text/plain")},
|
||||
)
|
||||
|
||||
assert response.status_code == 400
|
||||
assert "Unsupported file type" in response.json()["detail"]
|
||||
|
||||
|
||||
def test_upload_auto_creates_project(client, monkeypatch):
|
||||
uploaded = []
|
||||
monkeypatch.setattr("routers.media.upload_file", lambda object_name, data, content_type, length: uploaded.append(object_name))
|
||||
monkeypatch.setattr("routers.media.get_presigned_url", lambda object_name, expires=3600: f"http://storage/{object_name}")
|
||||
|
||||
response = client.post(
|
||||
"/api/media/upload",
|
||||
files={"file": ("clip.mp4", b"video", "video/mp4")},
|
||||
)
|
||||
|
||||
assert response.status_code == 201
|
||||
data = response.json()
|
||||
assert data["project_id"] is not None
|
||||
assert data["object_name"] == f"uploads/{data['project_id']}/clip.mp4"
|
||||
assert uploaded == ["uploads/general/clip.mp4", f"uploads/{data['project_id']}/clip.mp4"]
|
||||
|
||||
|
||||
def test_upload_links_existing_project(client, monkeypatch):
|
||||
project = client.post("/api/projects", json={"name": "Existing"}).json()
|
||||
monkeypatch.setattr("routers.media.upload_file", lambda *args, **kwargs: None)
|
||||
monkeypatch.setattr("routers.media.get_presigned_url", lambda object_name, expires=3600: f"http://storage/{object_name}")
|
||||
|
||||
response = client.post(
|
||||
"/api/media/upload",
|
||||
data={"project_id": str(project["id"])},
|
||||
files={"file": ("clip.mp4", b"video", "video/mp4")},
|
||||
)
|
||||
|
||||
assert response.status_code == 201
|
||||
detail = client.get(f"/api/projects/{project['id']}").json()
|
||||
assert detail["video_path"] == f"uploads/{project['id']}/clip.mp4"
|
||||
|
||||
|
||||
def test_upload_dicom_batch_filters_files_and_creates_project(client, monkeypatch):
|
||||
uploaded = []
|
||||
monkeypatch.setattr("routers.media.upload_file", lambda object_name, data, content_type, length: uploaded.append(object_name))
|
||||
|
||||
response = client.post(
|
||||
"/api/media/upload/dicom",
|
||||
files=[
|
||||
("files", ("a.dcm", b"dcm", "application/dicom")),
|
||||
("files", ("skip.txt", b"text", "text/plain")),
|
||||
],
|
||||
)
|
||||
|
||||
assert response.status_code == 201
|
||||
data = response.json()
|
||||
assert data["uploaded_count"] == 1
|
||||
assert uploaded == [f"uploads/{data['project_id']}/dicom/a.dcm"]
|
||||
|
||||
|
||||
def test_parse_media_queues_background_task(client, monkeypatch):
|
||||
project = client.post("/api/projects", json={
|
||||
"name": "Parse Me",
|
||||
"video_path": "uploads/1/clip.mp4",
|
||||
"source_type": "video",
|
||||
"parse_fps": 5,
|
||||
}).json()
|
||||
|
||||
class FakeAsyncResult:
|
||||
id = "celery-1"
|
||||
|
||||
queued = []
|
||||
monkeypatch.setattr("routers.media.parse_project_media.delay", lambda task_id: queued.append(task_id) or FakeAsyncResult())
|
||||
published = []
|
||||
monkeypatch.setattr("routers.media.publish_task_progress_event", lambda task: published.append(task.id))
|
||||
|
||||
response = client.post(f"/api/media/parse?project_id={project['id']}")
|
||||
|
||||
assert response.status_code == 202
|
||||
data = response.json()
|
||||
assert data["task_type"] == "parse_video"
|
||||
assert data["status"] == "queued"
|
||||
assert data["progress"] == 0
|
||||
assert data["project_id"] == project["id"]
|
||||
assert data["celery_task_id"] == "celery-1"
|
||||
assert queued == [data["id"]]
|
||||
assert published == [data["id"]]
|
||||
|
||||
detail = client.get(f"/api/tasks/{data['id']}")
|
||||
assert detail.status_code == 200
|
||||
assert detail.json()["status"] == "queued"
|
||||
project_detail = client.get(f"/api/projects/{project['id']}").json()
|
||||
assert project_detail["status"] == "parsing"
|
||||
|
||||
|
||||
def test_parse_task_runner_registers_frames(client, db_session, monkeypatch, tmp_path):
|
||||
from models import ProcessingTask
|
||||
from services.media_task_runner import run_parse_media_task
|
||||
|
||||
project = client.post("/api/projects", json={
|
||||
"name": "Parse Me",
|
||||
"video_path": "uploads/1/clip.mp4",
|
||||
"source_type": "video",
|
||||
"parse_fps": 5,
|
||||
}).json()
|
||||
task = ProcessingTask(
|
||||
task_type="parse_video",
|
||||
status="queued",
|
||||
progress=0,
|
||||
project_id=project["id"],
|
||||
payload={"source_type": "video"},
|
||||
)
|
||||
db_session.add(task)
|
||||
db_session.commit()
|
||||
db_session.refresh(task)
|
||||
frame_file = tmp_path / "frame_000001.jpg"
|
||||
frame_file.write_bytes(b"fake image")
|
||||
|
||||
monkeypatch.setattr("services.media_task_runner.download_file", lambda object_name: b"video")
|
||||
monkeypatch.setattr("services.media_task_runner.parse_video", lambda local_path, output_dir, fps: ([str(frame_file)], 25.0))
|
||||
monkeypatch.setattr("services.media_task_runner.extract_thumbnail", lambda local_path, thumbnail_path: open(thumbnail_path, "wb").write(b"thumb"))
|
||||
monkeypatch.setattr("services.media_task_runner.upload_file", lambda *args, **kwargs: None)
|
||||
monkeypatch.setattr("services.media_task_runner.upload_frames_to_minio", lambda frame_files, project_id: [f"projects/{project_id}/frames/frame_000001.jpg"])
|
||||
published = []
|
||||
monkeypatch.setattr(
|
||||
"services.media_task_runner.publish_task_progress_event",
|
||||
lambda event_task: published.append((event_task.status, event_task.progress, event_task.message)),
|
||||
)
|
||||
|
||||
result = run_parse_media_task(db_session, task.id)
|
||||
|
||||
assert result["frames_extracted"] == 1
|
||||
db_session.refresh(task)
|
||||
assert task.status == "success"
|
||||
assert task.progress == 100
|
||||
assert ("running", 5, "后台解析已启动") in published
|
||||
assert ("success", 100, "解析完成") in published
|
||||
project_detail = client.get(f"/api/projects/{project['id']}").json()
|
||||
assert project_detail["status"] == "ready"
|
||||
frames = client.get(f"/api/projects/{project['id']}/frames").json()
|
||||
assert "frame_000001.jpg" in frames[0]["image_url"]
|
||||
42
backend/tests/test_progress_events.py
Normal file
42
backend/tests/test_progress_events.py
Normal file
@@ -0,0 +1,42 @@
|
||||
from types import SimpleNamespace
|
||||
|
||||
from progress_events import PROGRESS_CHANNEL, publish_progress_event, task_progress_payload
|
||||
|
||||
|
||||
def test_task_progress_payload_uses_dashboard_task_id_and_project_name():
|
||||
task = SimpleNamespace(
|
||||
id=12,
|
||||
project_id=7,
|
||||
project=SimpleNamespace(name="demo.mp4"),
|
||||
status="success",
|
||||
progress=100,
|
||||
message="解析完成",
|
||||
error=None,
|
||||
updated_at=None,
|
||||
)
|
||||
|
||||
payload = task_progress_payload(task)
|
||||
|
||||
assert payload["type"] == "complete"
|
||||
assert payload["taskId"] == "task-12"
|
||||
assert payload["task_id"] == 12
|
||||
assert payload["project_id"] == 7
|
||||
assert payload["filename"] == "demo.mp4"
|
||||
assert payload["projectName"] == "demo.mp4"
|
||||
assert payload["status"] == "解析完成"
|
||||
|
||||
|
||||
def test_publish_progress_event_writes_json_to_redis(monkeypatch):
|
||||
calls = []
|
||||
|
||||
class FakeRedis:
|
||||
def publish(self, channel, payload):
|
||||
calls.append((channel, payload))
|
||||
|
||||
monkeypatch.setattr("progress_events.get_redis_client", lambda: FakeRedis())
|
||||
|
||||
publish_progress_event({"type": "progress", "message": "正在下载媒体文件"})
|
||||
|
||||
assert calls
|
||||
assert calls[0][0] == PROGRESS_CHANNEL
|
||||
assert "正在下载媒体文件" in calls[0][1]
|
||||
56
backend/tests/test_projects.py
Normal file
56
backend/tests/test_projects.py
Normal file
@@ -0,0 +1,56 @@
|
||||
def test_project_crud_and_frames(client, monkeypatch):
|
||||
monkeypatch.setattr("routers.projects.get_presigned_url", lambda key, expires=3600: f"http://storage/{key}")
|
||||
|
||||
created = client.post("/api/projects", json={
|
||||
"name": "Demo",
|
||||
"description": "desc",
|
||||
"thumbnail_url": "thumb.jpg",
|
||||
"parse_fps": 12,
|
||||
})
|
||||
assert created.status_code == 201
|
||||
project_id = created.json()["id"]
|
||||
|
||||
frame = client.post(f"/api/projects/{project_id}/frames", json={
|
||||
"project_id": project_id,
|
||||
"frame_index": 0,
|
||||
"image_url": "frames/0.jpg",
|
||||
"width": 640,
|
||||
"height": 360,
|
||||
})
|
||||
assert frame.status_code == 201
|
||||
frame_id = frame.json()["id"]
|
||||
|
||||
listing = client.get("/api/projects")
|
||||
assert listing.status_code == 200
|
||||
assert listing.json()[0]["frame_count"] == 1
|
||||
assert listing.json()[0]["thumbnail_url"] == "http://storage/thumb.jpg"
|
||||
|
||||
frames = client.get(f"/api/projects/{project_id}/frames")
|
||||
assert frames.status_code == 200
|
||||
assert frames.json()[0]["image_url"] == "http://storage/frames/0.jpg"
|
||||
|
||||
single_frame = client.get(f"/api/projects/{project_id}/frames/{frame_id}")
|
||||
assert single_frame.status_code == 200
|
||||
assert single_frame.json()["frame_index"] == 0
|
||||
|
||||
updated = client.patch(f"/api/projects/{project_id}", json={"name": "Renamed", "status": "ready"})
|
||||
assert updated.status_code == 200
|
||||
assert updated.json()["name"] == "Renamed"
|
||||
assert updated.json()["status"] == "ready"
|
||||
|
||||
deleted = client.delete(f"/api/projects/{project_id}")
|
||||
assert deleted.status_code == 204
|
||||
assert client.get(f"/api/projects/{project_id}").status_code == 404
|
||||
|
||||
|
||||
def test_project_and_frame_404s(client):
|
||||
assert client.get("/api/projects/999").status_code == 404
|
||||
assert client.patch("/api/projects/999", json={"name": "x"}).status_code == 404
|
||||
assert client.delete("/api/projects/999").status_code == 404
|
||||
assert client.post("/api/projects/999/frames", json={
|
||||
"project_id": 999,
|
||||
"frame_index": 0,
|
||||
"image_url": "missing.jpg",
|
||||
}).status_code == 404
|
||||
assert client.get("/api/projects/999/frames").status_code == 404
|
||||
assert client.get("/api/projects/999/frames/1").status_code == 404
|
||||
39
backend/tests/test_templates.py
Normal file
39
backend/tests/test_templates.py
Normal file
@@ -0,0 +1,39 @@
|
||||
def test_template_crud_packs_and_unpacks_mapping_rules(client):
|
||||
payload = {
|
||||
"name": "Template",
|
||||
"color": "#06b6d4",
|
||||
"z_index": 0,
|
||||
"classes": [{"id": "c1", "name": "胆囊", "color": "#ff0000", "zIndex": 10}],
|
||||
"rules": [{"id": "r1", "name": "rule"}],
|
||||
}
|
||||
|
||||
created = client.post("/api/templates", json=payload)
|
||||
assert created.status_code == 201
|
||||
template_id = created.json()["id"]
|
||||
assert created.json()["classes"][0]["name"] == "胆囊"
|
||||
assert created.json()["rules"][0]["id"] == "r1"
|
||||
|
||||
listing = client.get("/api/templates")
|
||||
assert listing.status_code == 200
|
||||
assert listing.json()[0]["classes"][0]["name"] == "胆囊"
|
||||
|
||||
detail = client.get(f"/api/templates/{template_id}")
|
||||
assert detail.status_code == 200
|
||||
assert detail.json()["name"] == "Template"
|
||||
|
||||
updated = client.patch(f"/api/templates/{template_id}", json={
|
||||
"classes": [{"id": "c2", "name": "肝脏", "color": "#00ff00", "zIndex": 20}],
|
||||
"rules": [],
|
||||
})
|
||||
assert updated.status_code == 200
|
||||
assert updated.json()["classes"][0]["name"] == "肝脏"
|
||||
|
||||
deleted = client.delete(f"/api/templates/{template_id}")
|
||||
assert deleted.status_code == 204
|
||||
assert client.get(f"/api/templates/{template_id}").status_code == 404
|
||||
|
||||
|
||||
def test_template_404s(client):
|
||||
assert client.get("/api/templates/999").status_code == 404
|
||||
assert client.patch("/api/templates/999", json={"name": "x"}).status_code == 404
|
||||
assert client.delete("/api/templates/999").status_code == 404
|
||||
22
backend/worker_tasks.py
Normal file
22
backend/worker_tasks.py
Normal file
@@ -0,0 +1,22 @@
|
||||
"""Celery task definitions."""
|
||||
|
||||
import logging
|
||||
|
||||
from celery_app import celery_app
|
||||
from database import SessionLocal
|
||||
from services.media_task_runner import run_parse_media_task
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@celery_app.task(name="media.parse_project")
|
||||
def parse_project_media(task_id: int) -> dict:
|
||||
"""Run media parsing for one queued task."""
|
||||
db = SessionLocal()
|
||||
try:
|
||||
return run_parse_media_task(db, task_id)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.exception("Parse media task failed: task_id=%s", task_id)
|
||||
raise exc
|
||||
finally:
|
||||
db.close()
|
||||
58
doc/01-purpose-and-word-summary.md
Normal file
58
doc/01-purpose-and-word-summary.md
Normal file
@@ -0,0 +1,58 @@
|
||||
# 目的与 Word 方案摘要
|
||||
|
||||
## 为什么要做这个系统
|
||||
|
||||
Word 文档《语义分割系统构建方案.docx》的核心目标是建设一个面向视频和连续帧的智能语义分割标注系统,解决传统标注工具在以下场景中的痛点:
|
||||
|
||||
- 视频或连续帧数量大,逐帧人工画 mask 成本高。
|
||||
- 高分辨率图像上同时存在底图、点、框、多边形和遮罩,DOM 渲染难以支撑重交互。
|
||||
- AI 分割需要低延迟点选/框选反馈,普通 REST 往返在密集交互场景下体验较差。
|
||||
- 语义分割要求一个像素只能归属一个类别,因此需要模板、颜色、z-index 和类别优先级来解决遮罩重叠。
|
||||
- 历史 GT mask 如果只是作为静态像素图层叠加,后续修改不灵活;Word 方案希望把 mask 降维成可编辑的点区域。
|
||||
|
||||
所以这个系统的业务目的不是单纯播放视频,而是把“视频/DICOM 数据接入、拆帧、AI 辅助分割、语义分类、标注导出”串成一个工作台。
|
||||
|
||||
## Word 中的目标架构
|
||||
|
||||
Word 方案描述的理想系统包含:
|
||||
|
||||
- React/Vue + Konva 的高性能 Canvas 工作台。
|
||||
- FastAPI 后端,使用 WebSocket 处理实时交互与任务进度。
|
||||
- Celery + Redis 处理视频拆帧等长任务。
|
||||
- FFmpeg/OpenCV 解析视频,pydicom 解析医学影像。
|
||||
- 本地 CUDA 上的 SAM 3 推理。
|
||||
- GT mask 导入后通过距离变换、骨架提取、聚类等算法降维为点区域。
|
||||
- 模板库管理分类、颜色和 z-index,用于语义分割遮罩重叠裁决。
|
||||
- PostgreSQL 存储项目、帧、模板和点区域数据。
|
||||
|
||||
## 当前代码已落地的部分
|
||||
|
||||
| 目标 | 当前代码状态 | 依据 |
|
||||
|------|--------------|------|
|
||||
| React 前端工作台 | 已落地 | `src/App.tsx`、`src/components/*.tsx` |
|
||||
| Konva Canvas | 已落地 | `CanvasArea.tsx`、`AISegmentation.tsx` 使用 `react-konva` |
|
||||
| FastAPI 后端 | 已落地 | `backend/main.py` |
|
||||
| PostgreSQL ORM | 已落地 | `backend/database.py`、`backend/models.py` |
|
||||
| MinIO 对象存储 | 已落地 | `backend/minio_client.py` |
|
||||
| Redis 连接 | 已落地 | 用于 Celery broker/result backend,并通过 `seg:progress` pub/sub 转发任务进度 |
|
||||
| 视频拆帧 | 已落地 | `backend/services/frame_parser.py`、`backend/routers/media.py` |
|
||||
| DICOM 批量导入 | 部分落地 | 上传和解析存在,项目级体验还需完善 |
|
||||
| WebSocket 进度 | 已落地 | 拆帧进度写入任务表后发布到 Redis `seg:progress`,FastAPI 广播到 `/ws/progress` |
|
||||
| SAM 推理 | 部分落地 | 后端已有 SAM 2 / SAM 3 选择和真实模型状态接口;SAM 3 依赖官方运行环境,当前环境不满足时会标为不可用 |
|
||||
| 模板库 | 部分落地 | 分类、颜色、z-index 能存储和编辑;重叠裁决算法未落地 |
|
||||
| 标注持久化 | 部分落地 | 后端有 `Annotation` 表,前端已接入新增、回显、分类更新和当前帧删除;逐点几何编辑未落地 |
|
||||
| COCO / Mask 导出 | 部分落地 | `backend/routers/export.py`;COCO JSON 前端按钮已接入,PNG mask ZIP 尚未提供前端按钮 |
|
||||
|
||||
## 当前代码尚未落地的目标
|
||||
|
||||
- SAM 3:当前已提供 `sam3_engine.py` 适配入口和状态检测;要实际运行仍需安装官方 `facebookresearch/sam3` 依赖并满足 Python 3.12+、PyTorch 2.7+、CUDA 12.6+。
|
||||
- Celery 异步任务队列:已注册 Celery app 和拆帧 worker task,`/api/media/parse` 会创建任务表记录并入队。
|
||||
- GT mask 导入:当前前端没有 GT Label 导入入口,后端也没有对应路由。
|
||||
- Mask 到点区域的拓扑降维:当前没有距离变换、骨架提取、HDBSCAN 等实现。
|
||||
- 类别优先级融合:模板有 z-index,但没有后端融合算法。
|
||||
- 撤销/重做:工具栏有按钮,但没有历史栈。
|
||||
- 结构化归档保存:工作区按钮已调用 `POST /api/ai/annotate` 保存当前未归档 mask,并通过 `PATCH /api/ai/annotations/{id}` 更新 dirty mask。
|
||||
|
||||
## 结论
|
||||
|
||||
当前项目已经从 UI 原型推进到“可上传、可异步拆帧、可实时查看任务进度、可浏览项目帧、可维护模板、可点/框 AI 推理、可保存标注、可导出 COCO、可查看 Dashboard 后端概览”的全栈雏形,但离 Word 中描述的完整智能标注系统还有明显差距。下一阶段最重要的是继续补齐手工绘制、撤销重做和真实语义文本分割。
|
||||
104
doc/02-current-implementation-map.md
Normal file
104
doc/02-current-implementation-map.md
Normal file
@@ -0,0 +1,104 @@
|
||||
# 当前实现地图
|
||||
|
||||
## 运行入口
|
||||
|
||||
### 前端入口
|
||||
|
||||
- React 挂载:`src/main.tsx`
|
||||
- 根组件:`src/App.tsx`
|
||||
- 前端服务:`server.ts`
|
||||
- 默认访问:`http://localhost:3000`
|
||||
|
||||
`server.ts` 的角色比较特殊:它既负责在开发模式下创建 Vite middleware,也在生产模式下服务 `dist/`。同时它还保留了旧版 mock API:`/api/login`、`/api/projects`、`/api/templates`。当前前端业务 API 主要不走这些 mock,而是走 `src/lib/api.ts` 指向的 FastAPI。
|
||||
|
||||
### 后端入口
|
||||
|
||||
- FastAPI 应用:`backend/main.py`
|
||||
- 默认访问:`http://localhost:8000`
|
||||
- API 文档:`http://localhost:8000/docs`
|
||||
- 健康检查:`GET /health`
|
||||
|
||||
后端启动时会通过 lifespan 执行:
|
||||
|
||||
- 创建数据库表。
|
||||
- 检查 MinIO bucket。
|
||||
- 测试 Redis。
|
||||
- Seed 默认模板。
|
||||
- 如果存在 `Data_MyVideo_1.mp4`,创建默认项目并拆前 100 帧。
|
||||
|
||||
## 前端模块切换
|
||||
|
||||
`App.tsx` 使用 Zustand 中的 `activeModule` 做模块切换,没有使用路由库。
|
||||
|
||||
| activeModule | 组件 | 页面 |
|
||||
|--------------|------|------|
|
||||
| `dashboard` | `Dashboard` | 系统概况 |
|
||||
| `projects` | `ProjectLibrary` | 项目库 |
|
||||
| `workspace` | `VideoWorkspace` | 分割工作区 |
|
||||
| `ai` | `AISegmentation` | AI 智能分割页 |
|
||||
| `templates` | `TemplateRegistry` | 模板库 |
|
||||
|
||||
未登录时,`App.tsx` 直接渲染 `Login`。
|
||||
|
||||
## 全局状态
|
||||
|
||||
全局状态在 `src/store/useStore.ts` 中,主要包括:
|
||||
|
||||
- 登录状态:`isAuthenticated`、`token`
|
||||
- 项目:`projects`、`currentProject`
|
||||
- 工作区:`activeModule`、`activeTool`、`frames`、`currentFrameIndex`
|
||||
- 标注与 mask:`annotations`、`masks`
|
||||
- 模板:`templates`、`activeTemplateId`
|
||||
- UI:`isLoading`、`error`
|
||||
|
||||
当前状态管理是前端内存状态,没有持久化到 localStorage,除了登录 token。
|
||||
|
||||
## 数据流
|
||||
|
||||
### 登录
|
||||
|
||||
1. `Login.tsx` 调用 `login()`。
|
||||
2. `src/lib/api.ts` 请求 `POST /api/auth/login`。
|
||||
3. FastAPI `backend/routers/auth.py` 校验 `admin / 123456`。
|
||||
4. 前端把返回 token 写入 localStorage。
|
||||
|
||||
### 项目与拆帧
|
||||
|
||||
1. `ProjectLibrary.tsx` 调用 `getProjects()` 获取项目。
|
||||
2. 上传视频时先 `createProject()`,再 `uploadMedia()`,再 `parseMedia()`。
|
||||
3. 后端 `media.py` 把原始文件上传到 MinIO。
|
||||
4. `parseMedia()` 创建 `processing_tasks` 记录并投递 Celery worker。
|
||||
5. Celery worker 下载 MinIO 文件,调用 `frame_parser.py` 拆帧。
|
||||
6. worker 把拆出的帧重新上传 MinIO,写入 `frames` 表,并更新任务状态。
|
||||
7. 工作区通过 `GET /api/tasks/{id}` 等待任务完成,再通过 `GET /api/projects/{id}/frames` 获取预签名图片 URL。
|
||||
|
||||
### 工作区浏览
|
||||
|
||||
1. `VideoWorkspace.tsx` 根据 `currentProject.id` 加载帧。
|
||||
2. `CanvasArea.tsx` 用当前帧 URL 加载底图。
|
||||
3. `FrameTimeline.tsx` 显示缩略图和当前帧索引。
|
||||
4. 播放按钮会推进 `currentFrameIndex`,从而更新画布底图。
|
||||
|
||||
### 模板管理
|
||||
|
||||
1. `TemplateRegistry.tsx` 调用模板 API。
|
||||
2. 后端 `templates.py` 把 `classes` 和 `rules` 打包进 `mapping_rules` JSON 字段。
|
||||
3. `OntologyInspector.tsx` 读取全局 `templates` 和 `activeTemplateId` 展示分类树。
|
||||
|
||||
## 后端数据模型
|
||||
|
||||
| 模型 | 表 | 用途 |
|
||||
|------|----|------|
|
||||
| `Project` | `projects` | 项目元数据,包含视频路径、缩略图、状态、fps |
|
||||
| `Frame` | `frames` | 拆帧后的图片记录 |
|
||||
| `Template` | `templates` | 模板、本体类别、颜色、z-index、mapping_rules |
|
||||
| `Annotation` | `annotations` | 标注数据、点、bbox、mask_data |
|
||||
| `Mask` | `masks` | mask 文件元数据 |
|
||||
|
||||
## 当前主要风险点
|
||||
|
||||
- 前端 API/WS 地址虽然已支持环境变量和 hostname 推导,但部署时仍需要确认浏览器可访问 `:8000` 后端。
|
||||
- AI 语义文本提示在选择 SAM 3 且运行环境满足官方依赖时走 SAM 3;当前环境若不满足会在模型状态中标明不可用。
|
||||
- 工作区顶部“导出 JSON 标注集”和“结构化归档保存”已接入导出、标注新增和 dirty 标注更新;清空当前帧遮罩会删除对应后端标注。撤销重做和手工绘制仍未持久化。
|
||||
- Dashboard 初始统计、队列和活动日志来自后端聚合接口;解析队列来自 `processing_tasks`,worker 进度通过 Redis `seg:progress` 转发到 WebSocket。
|
||||
- 后端路由大多未做真实鉴权。
|
||||
146
doc/03-frontend-element-audit.md
Normal file
146
doc/03-frontend-element-audit.md
Normal file
@@ -0,0 +1,146 @@
|
||||
# 前端逐元素审计
|
||||
|
||||
状态说明:
|
||||
|
||||
- 真实可用:接真实状态或后端接口,可以完成主要动作。
|
||||
- 部分可用:能展示或完成一部分,但存在关键缺口。
|
||||
- Mock / UI-only:只有展示或本地状态变化,没有真实业务效果。
|
||||
- 接口不通:前端调用与后端接口不一致,按当前代码大概率失败。
|
||||
|
||||
## App 与导航
|
||||
|
||||
| 元素 | 位置 | 状态 | 说明 |
|
||||
|------|------|------|------|
|
||||
| 登录拦截 | `App.tsx` | 真实可用 | 未登录显示 `Login`,登录后显示主界面 |
|
||||
| 模块切换 | `Sidebar.tsx` + `App.tsx` | 真实可用 | 切换 `dashboard/projects/workspace/ai/templates` |
|
||||
| Logo | `Sidebar.tsx` | 真实可用 | 使用 `/logo.png`,文件存在于 `public/logo.png` |
|
||||
| GPU 状态圆标 | `Sidebar.tsx` | 真实可用 | 通过 `GET /api/ai/models/status` 显示 GPU/CPU 和当前模型可用性 |
|
||||
|
||||
## 登录页
|
||||
|
||||
| 元素 | 状态 | 说明 |
|
||||
|------|------|------|
|
||||
| 用户名/密码输入 | 真实可用 | 默认填入 `admin / 123456` |
|
||||
| 安全登录按钮 | 真实可用 | 调用 `POST /api/auth/login` |
|
||||
| 错误提示 | 真实可用 | 捕获后端错误并显示 |
|
||||
| 安全审计说明文字 | Mock / UI-only | UI 文案,没有真实审计功能 |
|
||||
|
||||
## Dashboard 系统概况
|
||||
|
||||
| 元素 | 状态 | 说明 |
|
||||
|------|------|------|
|
||||
| WebSocket 连接状态 | 真实可用 | 前端通过 `src/lib/config.ts` 推导或读取 `VITE_WS_PROGRESS_URL`,后端有 `/ws/progress` |
|
||||
| 解析队列任务 | 真实可用 | 初始数据来自 `GET /api/dashboard/overview`,按 `processing_tasks` queued/running 任务生成 |
|
||||
| WebSocket 更新任务 | 真实可用 | Celery worker 更新 `processing_tasks` 后发布 Redis `seg:progress`,FastAPI 广播 progress/complete/error |
|
||||
| 项目、任务、标注、系统负载统计 | 真实可用 | 初始数据来自 `GET /api/dashboard/overview`,系统负载按主机 load average 估算 |
|
||||
| 近期实时流转记录 | 真实可用 | 初始数据来自任务、项目、标注和模板记录;WebSocket status/complete/error 会继续追加 |
|
||||
|
||||
## 项目库 ProjectLibrary
|
||||
|
||||
| 元素 | 状态 | 说明 |
|
||||
|------|------|------|
|
||||
| 项目列表 | 真实可用 | 调用 `GET /api/projects` |
|
||||
| 项目卡片缩略图 | 真实可用 | 后端返回 MinIO 预签名 `thumbnail_url` 时显示 |
|
||||
| 点击项目进入工作区 | 真实可用 | 设置 `currentProject` 后切到 `workspace` |
|
||||
| 新建项目 | 真实可用 | 调用 `POST /api/projects` |
|
||||
| 导入视频文件 | 真实可用 | 创建项目、上传文件、触发拆帧、刷新项目列表 |
|
||||
| 解析 FPS 滑块 | 真实可用 | 值传入 `createProject({ parse_fps })` |
|
||||
| 导入 DICOM 序列 | 部分可用 | 可上传 `.dcm` 并触发解析;体验和错误反馈较粗 |
|
||||
| 项目状态徽标 | 真实可用 | 项目状态统一为 `pending/parsing/ready/error`,前端兼容归一化旧状态值 |
|
||||
| 更多按钮 | Mock / UI-only | 有图标,没有菜单或事件 |
|
||||
| alert 成功/失败提示 | 真实可用但粗糙 | 使用浏览器 `alert` |
|
||||
|
||||
## 工作区 VideoWorkspace
|
||||
|
||||
| 元素 | 状态 | 说明 |
|
||||
|------|------|------|
|
||||
| 当前项目名 | 真实可用 | 读取 `currentProject.name` |
|
||||
| 自动加载项目帧 | 真实可用 | 调用 `GET /api/projects/{id}/frames` |
|
||||
| 无帧时触发解析 | 真实可用 | 如果 `video_path` 存在会调用 `parseMedia()` 创建异步任务,并轮询 `GET /api/tasks/{id}` 等待完成 |
|
||||
| SAM 模型状态徽标 | 真实可用 | 调用 `GET /api/ai/models/status`,显示当前选择的 SAM 2/SAM 3 是否可用 |
|
||||
| 已保存标注回显 | 真实可用 | 加载工作区帧后调用 `GET /api/ai/annotations` 并渲染已保存 mask |
|
||||
| “导出 JSON 标注集”按钮 | 真实可用 | 导出前会保存未归档 mask,然后调用 `exportCoco()` 下载 JSON |
|
||||
| “结构化归档保存”按钮 | 真实可用 | 未保存 mask 写入 `POST /api/ai/annotate`;dirty mask 写入 `PATCH /api/ai/annotations/{id}` |
|
||||
|
||||
## CanvasArea 画布
|
||||
|
||||
| 元素 | 状态 | 说明 |
|
||||
|------|------|------|
|
||||
| 当前帧底图显示 | 真实可用 | `useImage(frameUrl)` 加载当前帧 URL |
|
||||
| 滚轮缩放 | 真实可用 | 改变 Konva Stage scale |
|
||||
| 拖拽平移 | 真实可用 | activeTool 为 `move` 时 Stage draggable |
|
||||
| 光标坐标显示 | 真实可用 | 根据 pointer position 计算 |
|
||||
| 正向/反向选点 | 部分可用 | UI 能加点,并按当前帧 `frame.id` 调用 `/api/ai/predict`;需点击归档保存才持久化 |
|
||||
| 框选 | 部分可用 | UI 能画框,并把框坐标归一化后调用后端推理;需点击归档保存才持久化 |
|
||||
| AI 推理中提示 | 真实可用 | 请求期间会显示 |
|
||||
| Mask 渲染 | 部分可用 | 前端会把推理/已保存标注转成 Konva `pathData` 渲染 |
|
||||
| 应用分类 | 真实可用 | 将当前选择的模板分类应用到本帧 mask;已保存 mask 会标为 dirty,归档保存时更新后端 |
|
||||
| 清空遮罩 | 真实可用 | 工作区中会删除当前帧已保存标注并清空当前帧本地 mask |
|
||||
| 保存状态计数 | 真实可用 | 底部显示已保存、未保存、待更新数量 |
|
||||
| 当前图层树文字 | Mock / UI-only | 固定显示 `OBJECT_VEHICLE_01` |
|
||||
|
||||
## ToolsPalette 工具栏
|
||||
|
||||
| 元素 | 状态 | 说明 |
|
||||
|------|------|------|
|
||||
| 拖拽/选择 | 真实可用 | 控制 Canvas 是否可拖拽 |
|
||||
| 多边形/矩形/圆/点/线 | Mock / UI-only | 只切换 activeTool,没有对应绘制逻辑 |
|
||||
| 区域合并/去除 | Mock / UI-only | 只切换 activeTool,没有后端或前端算法 |
|
||||
| 正向选点/反向选点/框选 | 部分可用 | 会影响 Canvas 交互,并能触发已对齐的 AI 推理接口 |
|
||||
| 魔法棒 SAM 触发 | 部分可用 | 切到 AI 页面;不是直接执行推理 |
|
||||
| 撤销/重做 | Mock / UI-only | 按钮无事件 |
|
||||
|
||||
## FrameTimeline 时间轴
|
||||
|
||||
| 元素 | 状态 | 说明 |
|
||||
|------|------|------|
|
||||
| 帧缩略图 | 真实可用 | 使用 `frames[].url` |
|
||||
| 点击缩略图跳帧 | 真实可用 | 调用 `setCurrentFrame(idx)` |
|
||||
| 顶部 range 拖动 | 真实可用 | 改变当前帧 |
|
||||
| 播放/暂停 | 真实可用 | 当前代码按 `parse_fps/original_fps` 推进帧,最多 30fps |
|
||||
| 方向键切帧 | Mock / UI-only | Word 提到,但当前没有键盘监听 |
|
||||
|
||||
## OntologyInspector 本体面板
|
||||
|
||||
| 元素 | 状态 | 说明 |
|
||||
|------|------|------|
|
||||
| 模板选择 | 部分可用 | 读取全局 templates,可切换 activeTemplateId |
|
||||
| 分类树展示 | 真实可用 | 显示模板 classes 和本地 customClasses |
|
||||
| 添加自定义分类 | 部分可用 | 只存在组件本地状态,不保存到后端 |
|
||||
| 置信度条 | Mock / UI-only | 固定 `0.9412` |
|
||||
| 拓扑锚点数量 | Mock / UI-only | 固定 `12 节点` |
|
||||
| 重新提取骨架按钮 | Mock / UI-only | 无事件 |
|
||||
|
||||
## AISegmentation 独立 AI 页
|
||||
|
||||
| 元素 | 状态 | 说明 |
|
||||
|------|------|------|
|
||||
| 模型选择 SAM2/SAM3 | 真实可用 | 选择写入 Zustand,`predictMask()` 会把 `model` 传给后端 SAM registry |
|
||||
| 正向/反向点 | 部分可用 | 可在当前项目帧上加点,并可调用 AI 推理接口 |
|
||||
| 语义文本输入 | 部分可用 | 纯文本会以 `semantic` prompt 调用后端;选择 SAM 3 且运行环境满足官方依赖时走 SAM 3 文本语义推理,否则状态接口会标明不可用 |
|
||||
| 参数开关 | Mock / UI-only | `cropMode`、`autoDeleteBg` 只改本地状态 |
|
||||
| 执行高精度语义分割 | 部分可用 | 使用当前项目帧调用 `/api/ai/predict`;没有当前帧时按钮禁用 |
|
||||
| 上传替换底图 | Mock / UI-only | 按钮无事件 |
|
||||
| 清空全体锚点 | 部分可用 | 清空前端 points 和 masks |
|
||||
| 退档推送至工作区重组 | 部分可用 | 只切回工作区,共用 masks store,但没有保存/确认流程 |
|
||||
| 背景图 | 部分可用 | 优先显示当前项目帧;没有项目帧时仍回退到 Unsplash 演示图 |
|
||||
|
||||
## TemplateRegistry 模板库
|
||||
|
||||
| 元素 | 状态 | 说明 |
|
||||
|------|------|------|
|
||||
| 模板列表 | 真实可用 | 调用 `GET /api/templates` |
|
||||
| 新建方案 | 真实可用 | 调用 `POST /api/templates` |
|
||||
| 编辑模板 | 真实可用 | 调用 `PATCH /api/templates/{id}` |
|
||||
| 删除模板 | 真实可用 | 调用 `DELETE /api/templates/{id}` |
|
||||
| 添加/删除分类 | 真实可用 | 保存在模板 `mapping_rules.classes` |
|
||||
| 拖拽排序 | 真实可用 | 重算 zIndex,保存时写后端 |
|
||||
| JSON 批量导入 | 部分可用 | 前端解析 JSON 并加入编辑态,保存后才落库 |
|
||||
| 载入腹腔镜 35 分类 | 真实可用 | 前端内置数据;后端也 seed 默认模板 |
|
||||
| mapping rules | 部分可用 | 可存 `rules`,但无实际映射执行引擎 |
|
||||
|
||||
## 总体结论
|
||||
|
||||
当前前端真实可用的主链路是:登录、Dashboard 后端概览、项目列表、新建项目、上传视频/DICOM、拆帧、浏览帧、播放帧、工作区点/框 AI 推理、标注保存/回显、COCO 导出、模板 CRUD。
|
||||
|
||||
当前最主要的 Mock 或未打通链路是:撤销重做、手工几何绘制、GT 导入、mask 降维点区域、真正的文本语义分割和语义优先级融合。
|
||||
193
doc/04-api-contracts.md
Normal file
193
doc/04-api-contracts.md
Normal file
@@ -0,0 +1,193 @@
|
||||
# 接口契约清单
|
||||
|
||||
## 前端 API 基础配置
|
||||
|
||||
位置:`src/lib/config.ts`、`src/lib/api.ts`
|
||||
|
||||
```ts
|
||||
API_BASE_URL = import.meta.env.VITE_API_BASE_URL || 'http://<current-browser-host>:8000'
|
||||
timeout: 30000
|
||||
```
|
||||
|
||||
前端 request interceptor 会从 localStorage 读取 `token`,附加:
|
||||
|
||||
```http
|
||||
Authorization: Bearer <token>
|
||||
```
|
||||
|
||||
当前后端多数接口没有鉴权依赖,所以这个 header 主要是前端侧行为。
|
||||
|
||||
## 前端封装的 API
|
||||
|
||||
| 函数 | 方法与路径 | 状态 | 说明 |
|
||||
|------|------------|------|------|
|
||||
| `login(username, password)` | `POST /api/auth/login` | 对齐 | 后端返回 `{ token, username }`,前端只使用 token |
|
||||
| `getProjects()` | `GET /api/projects` | 对齐 | 前端映射 `frame_count`、`thumbnail_url` 等字段 |
|
||||
| `createProject(payload)` | `POST /api/projects` | 对齐 | 支持 `name`、`description`、`parse_fps` |
|
||||
| `updateProject(id, payload)` | `PATCH /api/projects/{id}` | 对齐 | 后端是 `PATCH /api/projects/{id}` |
|
||||
| `deleteProject(id)` | `DELETE /api/projects/{id}` | 对齐 | 当前 UI 未明显接入 |
|
||||
| `getTemplates()` | `GET /api/templates` | 对齐 | 前端从 `mapping_rules` 取 classes/rules |
|
||||
| `createTemplate(payload)` | `POST /api/templates` | 对齐 | 后端会打包 classes/rules 到 mapping_rules |
|
||||
| `updateTemplate(id, payload)` | `PATCH /api/templates/{id}` | 对齐 | 模板编辑页使用 |
|
||||
| `deleteTemplate(id)` | `DELETE /api/templates/{id}` | 对齐 | 模板编辑页使用 |
|
||||
| `uploadMedia(file, projectId)` | `POST /api/media/upload` | 对齐 | multipart form-data |
|
||||
| `uploadDicomBatch(files, projectId)` | `POST /api/media/upload/dicom` | 对齐 | multipart form-data |
|
||||
| `parseMedia(projectId)` | `POST /api/media/parse?project_id=...` | 对齐 | 创建异步拆帧任务并返回 task |
|
||||
| `getTask(taskId)` | `GET /api/tasks/{task_id}` | 对齐 | 查询异步任务状态 |
|
||||
| `getProjectFrames(projectId)` | `GET /api/projects/{id}/frames` | 对齐 | 后端返回预签名 image_url |
|
||||
| `predictMask(payload)` | `POST /api/ai/predict` | 对齐 | 前端发送 `image_id/prompt_type/prompt_data/model`,并把后端 `polygons` 转为 `masks[].pathData` |
|
||||
| `getAiModelStatus(selectedModel?)` | `GET /api/ai/models/status` | 对齐 | 返回 GPU、SAM 2、SAM 3 的真实运行状态 |
|
||||
| `getProjectAnnotations(projectId, frameId?)` | `GET /api/ai/annotations` | 对齐 | 前端加载工作区时用于回显已保存标注 |
|
||||
| `saveAnnotation(payload)` | `POST /api/ai/annotate` | 对齐 | 工作区归档保存当前项目未保存 mask |
|
||||
| `updateAnnotation(annotationId, payload)` | `PATCH /api/ai/annotations/{annotation_id}` | 对齐 | 工作区归档保存 dirty mask |
|
||||
| `deleteAnnotation(annotationId)` | `DELETE /api/ai/annotations/{annotation_id}` | 对齐 | 工作区清空当前帧已保存标注 |
|
||||
| `getDashboardOverview()` | `GET /api/dashboard/overview` | 对齐 | Dashboard 初始统计、队列和活动日志 |
|
||||
| `exportCoco(projectId)` | `GET /api/export/{projectId}/coco` | 对齐 | 后端实际是 `GET /api/export/{project_id}/coco` |
|
||||
|
||||
## 后端 FastAPI 接口
|
||||
|
||||
以下列表来自当前运行的 OpenAPI:
|
||||
|
||||
| 方法 | 路径 | 用途 |
|
||||
|------|------|------|
|
||||
| POST | `/api/auth/login` | 登录 |
|
||||
| POST | `/api/projects` | 创建项目 |
|
||||
| GET | `/api/projects` | 项目列表 |
|
||||
| GET | `/api/projects/{project_id}` | 项目详情 |
|
||||
| PATCH | `/api/projects/{project_id}` | 更新项目 |
|
||||
| DELETE | `/api/projects/{project_id}` | 删除项目 |
|
||||
| POST | `/api/projects/{project_id}/frames` | 添加帧记录 |
|
||||
| GET | `/api/projects/{project_id}/frames` | 项目帧列表 |
|
||||
| GET | `/api/projects/{project_id}/frames/{frame_id}` | 单帧详情 |
|
||||
| POST | `/api/templates` | 创建模板 |
|
||||
| GET | `/api/templates` | 模板列表 |
|
||||
| GET | `/api/templates/{template_id}` | 模板详情 |
|
||||
| PATCH | `/api/templates/{template_id}` | 更新模板 |
|
||||
| DELETE | `/api/templates/{template_id}` | 删除模板 |
|
||||
| POST | `/api/media/upload` | 上传视频/图片/DICOM 单文件 |
|
||||
| POST | `/api/media/upload/dicom` | 批量上传 DICOM |
|
||||
| POST | `/api/media/parse` | 创建 Celery 拆帧任务 |
|
||||
| GET | `/api/tasks` | 查询后台任务列表 |
|
||||
| GET | `/api/tasks/{task_id}` | 查询单个后台任务 |
|
||||
| POST | `/api/ai/predict` | SAM 2 / SAM 3 可选推理 |
|
||||
| GET | `/api/ai/models/status` | GPU 和 SAM 模型状态 |
|
||||
| POST | `/api/ai/auto` | 自动分割 |
|
||||
| POST | `/api/ai/annotate` | 保存 AI 标注 |
|
||||
| GET | `/api/ai/annotations` | 查询项目标注,可选按帧过滤 |
|
||||
| PATCH | `/api/ai/annotations/{annotation_id}` | 更新已保存标注 |
|
||||
| DELETE | `/api/ai/annotations/{annotation_id}` | 删除已保存标注 |
|
||||
| GET | `/api/dashboard/overview` | Dashboard 聚合快照 |
|
||||
| GET | `/api/export/{project_id}/coco` | 导出 COCO JSON |
|
||||
| GET | `/api/export/{project_id}/masks` | 导出 PNG mask ZIP |
|
||||
| GET | `/health` | 健康检查 |
|
||||
| WS | `/ws/progress` | WebSocket 进度通道,未出现在 OpenAPI paths 中 |
|
||||
|
||||
## 关键请求体
|
||||
|
||||
### 登录
|
||||
|
||||
```json
|
||||
{
|
||||
"username": "admin",
|
||||
"password": "123456"
|
||||
}
|
||||
```
|
||||
|
||||
### 创建项目
|
||||
|
||||
```json
|
||||
{
|
||||
"name": "example.mp4",
|
||||
"description": "导入说明",
|
||||
"parse_fps": 30
|
||||
}
|
||||
```
|
||||
|
||||
### 创建/更新模板
|
||||
|
||||
```json
|
||||
{
|
||||
"name": "腹腔镜胆囊切除术",
|
||||
"color": "#06b6d4",
|
||||
"z_index": 0,
|
||||
"classes": [
|
||||
{
|
||||
"id": "cls-1",
|
||||
"name": "胆囊",
|
||||
"color": "#ffae00",
|
||||
"zIndex": 280,
|
||||
"category": "腹腔镜胆囊切除术"
|
||||
}
|
||||
],
|
||||
"rules": []
|
||||
}
|
||||
```
|
||||
|
||||
### AI 推理请求体
|
||||
|
||||
前端 `predictMask()` 当前已适配后端 `PredictRequest`:
|
||||
|
||||
```json
|
||||
{
|
||||
"image_id": 123,
|
||||
"model": "sam2",
|
||||
"prompt_type": "point",
|
||||
"prompt_data": {
|
||||
"points": [[0.5, 0.5]],
|
||||
"labels": [1]
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
`prompt_type` 支持:
|
||||
|
||||
- `point`
|
||||
- `box`
|
||||
- `semantic`,选择 `sam3` 时进入 SAM 3 文本语义推理;选择 `sam2` 时仍回退到 auto segmentation
|
||||
|
||||
后端响应:
|
||||
|
||||
```json
|
||||
{
|
||||
"polygons": [
|
||||
[[0.25, 0.25], [0.75, 0.25], [0.75, 0.75], [0.25, 0.75]]
|
||||
],
|
||||
"scores": [0.5]
|
||||
}
|
||||
```
|
||||
|
||||
前端会把上面的 `polygons` 转成:
|
||||
|
||||
```json
|
||||
{
|
||||
"masks": [
|
||||
{
|
||||
"pathData": "M 160 90 L 480 90 L 480 270 L 160 270 Z",
|
||||
"segmentation": [[160, 90, 480, 90, 480, 270, 160, 270]],
|
||||
"bbox": [160, 90, 320, 180]
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
## 已完成的接口对齐
|
||||
|
||||
- `updateProject()` 已从 `PUT` 改为 `PATCH`。
|
||||
- `exportCoco()` 已从 `/api/export/coco/{projectId}` 改为 `/api/export/{projectId}/coco`。
|
||||
- Canvas 已使用真实 `frame.id` 作为 `image_id`。
|
||||
- 点和框坐标已转成后端需要的归一化坐标。
|
||||
- 后端 `polygons` 已在前端转成 Konva 可渲染的 path。
|
||||
- `saveAnnotation()` 已接入 `POST /api/ai/annotate`。
|
||||
- `getProjectAnnotations()` 已接入 `GET /api/ai/annotations`。
|
||||
- `updateAnnotation()` 已接入 `PATCH /api/ai/annotations/{annotationId}`。
|
||||
- `deleteAnnotation()` 已接入 `DELETE /api/ai/annotations/{annotationId}`。
|
||||
- `parseMedia()` 已改为创建 Celery 后台任务,并返回 `ProcessingTask`。
|
||||
- `getTask()` 已接入 `GET /api/tasks/{taskId}`。
|
||||
- `getDashboardOverview()` 已从 `processing_tasks` 聚合解析队列。
|
||||
- 工作区导出按钮已调用 `exportCoco()`,并会先保存未归档 mask。
|
||||
|
||||
## 仍需处理的接口问题
|
||||
|
||||
- WebSocket 地址已从 `VITE_WS_PROGRESS_URL` 读取,未配置时从 `API_BASE_URL` 推导;部署时仍要确认浏览器能访问该地址。
|
||||
- Celery worker 进度会写 PostgreSQL 任务表,同时发布到 Redis `seg:progress`;FastAPI 订阅后广播到 `/ws/progress`。
|
||||
- 已保存标注目前支持分类级更新和整帧清空删除;逐点几何编辑器尚未实现。
|
||||
115
doc/05-implementation-plan.md
Normal file
115
doc/05-implementation-plan.md
Normal file
@@ -0,0 +1,115 @@
|
||||
# 后续实施建议
|
||||
|
||||
目标是把当前“能看、能上传、能拆帧”的系统推进到“能真实完成标注闭环”的系统。
|
||||
|
||||
## 阶段 1:先修接口契约(已完成基础对齐)
|
||||
|
||||
优先级最高。AI 点/框推理和 COCO 导出的基础契约已经按当前代码完成对齐。
|
||||
|
||||
已完成:
|
||||
|
||||
1. `src/lib/api.ts` 的 `updateProject()` 已改为 `PATCH`。
|
||||
2. `exportCoco()` 路径已改为 `/api/export/{projectId}/coco`。
|
||||
3. Canvas 调 AI 时已使用当前帧真实 `frame.id` 作为 `image_id`。
|
||||
4. Canvas 点/框坐标已转成后端需要的归一化坐标。
|
||||
5. 后端 `polygons` 已转成前端可渲染的 Konva path。
|
||||
|
||||
剩余边界:
|
||||
|
||||
1. SAM 3 真实推理需要独立满足官方 Python 3.12+、PyTorch 2.7+、CUDA 12.6+ 环境。
|
||||
2. 标注删除/更新接口已打通基础能力;逐点几何编辑器尚未实现。
|
||||
|
||||
## 阶段 2:打通标注保存(已完成基础闭环)
|
||||
|
||||
当前工作区可将未保存 mask 写入后端标注表,并在加载项目帧后回显。
|
||||
|
||||
已完成:
|
||||
|
||||
1. 前端根据 `Mask.segmentation` 构造后端需要的 normalized `mask_data.polygons`。
|
||||
2. 用户点击“结构化归档保存”后,未保存 mask 调用 `POST /api/ai/annotate`,dirty mask 调用 `PATCH /api/ai/annotations/{annotation_id}`。
|
||||
3. 后端保存或更新 `project_id`、`frame_id`、`template_id`、`mask_data`、`bbox`;具体分类写入 `mask_data.class`。
|
||||
4. 工作区加载帧后调用 `GET /api/ai/annotations` 回显已保存标注。
|
||||
5. 工作区“清空遮罩”调用 `DELETE /api/ai/annotations/{annotation_id}` 删除当前帧已保存标注。
|
||||
|
||||
剩余建议:
|
||||
|
||||
1. 加入保存冲突处理和批量保存错误提示。
|
||||
2. 增加逐点几何编辑器,让已保存 mask 的 polygon 本身可以被修改后 PATCH。
|
||||
|
||||
## 阶段 3:接入导出按钮(已完成 COCO JSON)
|
||||
|
||||
当前工作区“导出 JSON 标注集”会先保存未归档 mask,再调用 COCO 导出接口。
|
||||
|
||||
建议:
|
||||
|
||||
1. 增加“导出 PNG Mask ZIP”按钮,调用 `/api/export/{projectId}/masks`。
|
||||
2. 无标注时给出更明确的空导出提示。
|
||||
|
||||
## 阶段 4:替换 Dashboard mock
|
||||
|
||||
当前 Dashboard 已通过 `GET /api/dashboard/overview` 读取后端聚合快照,不再使用硬编码初始统计、队列或活动日志。
|
||||
|
||||
已完成:
|
||||
|
||||
- 聚合项目、帧、标注、模板数量和主机 load average。
|
||||
- 按 `processing_tasks` queued/running 任务生成解析队列。
|
||||
- 按最近任务、项目、标注、模板记录生成活动流。
|
||||
|
||||
剩余建议:
|
||||
|
||||
1. 为任务增加取消、重试和失败详情 UI。
|
||||
2. 为 Dashboard 增加任务历史筛选和失败详情入口。
|
||||
|
||||
## 阶段 5:异步拆帧和进度
|
||||
|
||||
Word 方案中提到 Celery + Redis。当前已经有 Celery app、worker task 和 `processing_tasks` 表。
|
||||
|
||||
已完成:
|
||||
|
||||
1. 新建 Celery app。
|
||||
2. `POST /api/media/parse` 只创建任务并立即返回 task id。
|
||||
3. worker 执行 FFmpeg/OpenCV/pydicom。
|
||||
4. worker 写 PostgreSQL 任务进度。
|
||||
5. worker 发布 Redis `seg:progress`,FastAPI 广播到 `/ws/progress`。
|
||||
|
||||
剩余建议:
|
||||
|
||||
1. 为任务增加取消、重试和失败详情接口。
|
||||
2. 前端 Dashboard 保留轮询兜底,并补充失败详情 UI。
|
||||
|
||||
Dashboard 的解析队列现在已经从“项目状态派生”升级为任务表驱动,实时推送也已通过 Redis/WebSocket 打通;剩余重点是任务控制。
|
||||
|
||||
## 阶段 6:GT 导入与点区域
|
||||
|
||||
这是 Word 方案中最复杂的部分,当前完全未实现。
|
||||
|
||||
建议拆成小步:
|
||||
|
||||
1. 先支持上传二值/多类别 mask。
|
||||
2. 后端按类别提取 connected components。
|
||||
3. 用 OpenCV distance transform 找正向点。
|
||||
4. 暂时不做骨架/HDBSCAN,先生成最小可用点集。
|
||||
5. 前端以可拖拽点显示并保存。
|
||||
6. 后续再做骨架和聚类增强。
|
||||
|
||||
## 阶段 7:模板优先级融合
|
||||
|
||||
当前模板有 z-index,但没有真正用于语义冲突裁决。
|
||||
|
||||
建议:
|
||||
|
||||
1. 标注保存时记录 template class id / name / zIndex。
|
||||
2. 导出 mask 时按 zIndex 从低到高覆盖。
|
||||
3. 同类 mask 做 union。
|
||||
4. 跨类重叠由高 zIndex 覆盖低 zIndex。
|
||||
|
||||
这一步完成后,系统才真正符合“语义分割一个像素一个类别”的目标。
|
||||
|
||||
## 阶段 8:清理 UI 文案与 Mock
|
||||
|
||||
建议统一这些文案和真实能力:
|
||||
|
||||
- SAM/GPU 状态已改为 `GET /api/ai/models/status` 驱动。
|
||||
- 撤销/重做按钮接历史栈,否则隐藏。
|
||||
- “重新提取内侧中轴树骨架”接真实接口,否则标为未实现。
|
||||
- AI 独立页不要固定 Unsplash 图,应从当前项目帧或上传文件进入。
|
||||
103
doc/06-fastapi-docs-explained.md
Normal file
103
doc/06-fastapi-docs-explained.md
Normal file
@@ -0,0 +1,103 @@
|
||||
# `/docs` 是什么
|
||||
|
||||
地址:
|
||||
|
||||
- 本机:`http://localhost:8000/docs`
|
||||
- 局域网:`http://192.168.3.11:8000/docs`
|
||||
|
||||
这个页面不是文件列表,也不是项目文档目录。它是 FastAPI 自动生成的 Swagger UI,用来展示和调试后端 HTTP API。
|
||||
|
||||
## 为什么会自动出现
|
||||
|
||||
FastAPI 会根据代码里的路由和 Pydantic schema 自动生成 OpenAPI 描述,然后用 Swagger UI 展示出来。
|
||||
|
||||
相关代码在:
|
||||
|
||||
- `backend/main.py` 创建 `FastAPI(...)`
|
||||
- `backend/routers/*.py` 定义 `@router.get(...)`、`@router.post(...)` 等接口
|
||||
- `backend/schemas.py` 定义请求体和响应体
|
||||
|
||||
## 页面上 GET / POST / PATCH / DELETE 是什么
|
||||
|
||||
这些是 HTTP 方法,不是文件。
|
||||
|
||||
| 方法 | 含义 | 例子 |
|
||||
|------|------|------|
|
||||
| GET | 读取数据 | `GET /api/projects` 获取项目列表 |
|
||||
| POST | 创建或触发动作 | `POST /api/media/upload` 上传文件 |
|
||||
| PATCH | 局部更新 | `PATCH /api/templates/{template_id}` 更新模板 |
|
||||
| DELETE | 删除 | `DELETE /api/projects/{project_id}` 删除项目 |
|
||||
|
||||
你看到的每一行,都是后端暴露给前端调用的一个接口。
|
||||
|
||||
## `/docs` 能做什么
|
||||
|
||||
可以:
|
||||
|
||||
- 查看后端目前有哪些接口。
|
||||
- 展开接口查看参数、请求体和响应格式。
|
||||
- 点击 `Try it out` 直接发请求测试后端。
|
||||
- 检查接口返回错误,比如 400、401、404、500。
|
||||
|
||||
不能:
|
||||
|
||||
- 查看前端页面源码。
|
||||
- 直接代表某个功能已经完整可用。
|
||||
- 展示 WebSocket 的完整交互,因为 OpenAPI 主要描述 HTTP 接口。
|
||||
|
||||
## 和前端有什么关系
|
||||
|
||||
前端的 `src/lib/api.ts` 会调用这些接口。例如:
|
||||
|
||||
- 登录页调用 `/api/auth/login`
|
||||
- 项目库调用 `/api/projects`
|
||||
- 上传视频调用 `/api/media/upload`
|
||||
- 拆帧调用 `/api/media/parse`
|
||||
- 模板库调用 `/api/templates`
|
||||
|
||||
所以 `/docs` 是检查“后端提供了什么”的地方;前端是否真的用对了,还要对照 `src/lib/api.ts`。
|
||||
|
||||
## 目前通过 `/docs` 能看到的接口
|
||||
|
||||
当前后端接口包括:
|
||||
|
||||
- Auth:登录
|
||||
- Projects:项目 CRUD、项目帧 CRUD
|
||||
- Templates:模板 CRUD
|
||||
- Media:上传视频/DICOM、触发拆帧
|
||||
- AI:SAM 2 / SAM 3 可选推理、模型状态、自动分割、保存标注
|
||||
- Export:导出 COCO JSON、导出 PNG masks
|
||||
- Health:健康检查
|
||||
|
||||
## 为什么看起来像“列举文件和请求”
|
||||
|
||||
因为 Swagger UI 默认按接口分组,把每个 endpoint 展开成一行。它列举的是“后端可被调用的功能入口”,不是项目文件。
|
||||
|
||||
真正的项目文件在本地目录里,例如:
|
||||
|
||||
- 前端:`src/components/*.tsx`
|
||||
- 后端路由:`backend/routers/*.py`
|
||||
- 后端模型:`backend/models.py`
|
||||
|
||||
## 如何用 `/docs` 验证一个接口
|
||||
|
||||
以项目列表为例:
|
||||
|
||||
1. 打开 `/docs`。
|
||||
2. 找到 `GET /api/projects`。
|
||||
3. 点开。
|
||||
4. 点击 `Try it out`。
|
||||
5. 点击 `Execute`。
|
||||
6. 查看 Response body。
|
||||
|
||||
如果这里能返回数据,但前端项目库加载失败,那问题多半在前端 API 地址、CORS、字段映射或浏览器网络请求。
|
||||
|
||||
## 另一个机器可读入口
|
||||
|
||||
OpenAPI JSON 在:
|
||||
|
||||
```text
|
||||
http://localhost:8000/openapi.json
|
||||
```
|
||||
|
||||
这是给工具读取的接口描述,Swagger UI 就是基于它渲染出来的。
|
||||
120
doc/07-current-requirements-freeze.md
Normal file
120
doc/07-current-requirements-freeze.md
Normal file
@@ -0,0 +1,120 @@
|
||||
# 当前需求冻结文档
|
||||
|
||||
冻结日期:2026-05-01
|
||||
|
||||
本文档描述当前仓库已经实现或明确保留为占位的需求。测试用例以本文档为准,不把早期设想或 Word 文档中的远期能力当作当前版本必须实现的功能。
|
||||
|
||||
## R1 登录与会话
|
||||
|
||||
- 系统提供登录页。
|
||||
- 默认开发凭证为 `admin / 123456`。
|
||||
- 登录成功后前端保存 token,并进入主应用。
|
||||
- 登录失败时显示错误信息。
|
||||
- 当前 token 是开发用固定 token,不做真实 JWT 校验。
|
||||
|
||||
## R2 项目管理
|
||||
|
||||
- 前端展示项目库,并从 `GET /api/projects` 获取项目列表。
|
||||
- 用户可以新建项目,前端调用 `POST /api/projects`。
|
||||
- 用户可以选择项目,进入工作区。
|
||||
- 用户可以导入视频文件,前端创建项目、上传文件、触发拆帧、刷新项目列表。
|
||||
- 用户可以导入 DICOM 序列,前端上传 DICOM、触发拆帧、刷新项目列表。
|
||||
- 后端支持项目创建、列表、详情、局部更新和删除。
|
||||
- 后端支持项目帧创建、列表和单帧查询。
|
||||
|
||||
## R3 媒体上传与拆帧
|
||||
|
||||
- 后端允许上传视频、图片、DICOM 文件,其他扩展名返回 400。
|
||||
- 未提供项目 ID 上传时,后端自动创建项目。
|
||||
- 提供项目 ID 上传时,后端把上传对象关联到该项目。
|
||||
- 拆帧接口根据项目 `source_type` 处理视频或 DICOM。
|
||||
- 拆帧完成后写入 `frames` 记录,并把项目状态设为 `ready`。
|
||||
- 拆帧接口会创建 `processing_tasks` 记录并投递 Celery worker。
|
||||
- 前端可通过 `GET /api/tasks/{task_id}` 查询任务状态。
|
||||
|
||||
## R4 工作区与帧浏览
|
||||
|
||||
- 工作区根据当前项目加载帧列表。
|
||||
- 若项目有媒体但无帧,工作区会尝试触发拆帧后重新加载。
|
||||
- Canvas 显示当前帧图片。
|
||||
- Canvas 支持滚轮缩放、移动工具拖拽、鼠标坐标显示。
|
||||
- 时间轴支持缩略图点击切帧、range 拖动切帧、播放/暂停顺序推进帧。
|
||||
- 播放帧率使用项目 `parse_fps` 或 `original_fps`,限制在 1 到 30 FPS。
|
||||
|
||||
## R5 工具栏
|
||||
|
||||
- 工具栏可以切换当前 active tool。
|
||||
- 正向点、反向点、框选工具会影响 Canvas 交互。
|
||||
- 魔法棒按钮切换到 AI 页面。
|
||||
- 多边形、矩形、圆、点、线、合并、去除、撤销、重做当前只提供 UI 状态或占位按钮,不完成真实绘制/算法。
|
||||
|
||||
## R6 AI 推理
|
||||
|
||||
- 前端可以在 AI 页面选择 `sam2` 或 `sam3`,选择结果存放在全局 store。
|
||||
- 前端和工作区通过 `GET /api/ai/models/status` 展示 GPU、SAM 2 和 SAM 3 的真实运行状态。
|
||||
- 前端 `predictMask()` 调用 `POST /api/ai/predict`。
|
||||
- 前端发送后端契约:`image_id`、`prompt_type`、`prompt_data`、`model`。
|
||||
- 点提示传 `{ points, labels }`,正向点 label 为 1,反向点 label 为 0。
|
||||
- 框选提示传归一化 `[x1, y1, x2, y2]`。
|
||||
- 语义文本提示传 `semantic`;选择 `sam3` 且环境满足依赖时走 SAM 3 文本语义推理,选择 `sam2` 时回退到自动分割。
|
||||
- 后端返回 `polygons` 和 `scores`。
|
||||
- 前端把后端 `polygons` 转成 Konva `pathData`、`segmentation`、`bbox`、`area`。
|
||||
- AI 推理结果先存放在前端 store 的 `masks` 中,点击“结构化归档保存”后持久化到后端标注表。
|
||||
|
||||
## R7 标注保存
|
||||
|
||||
- 后端提供 `POST /api/ai/annotate` 保存标注。
|
||||
- 保存时必须存在项目;如果传入 `frame_id`,帧也必须存在。
|
||||
- 后端提供 `GET /api/ai/annotations` 查询项目标注,可选按 `frame_id` 过滤。
|
||||
- 后端提供 `PATCH /api/ai/annotations/{annotation_id}` 更新已保存标注的 `mask_data`、`points`、`bbox` 和 `template_id`。
|
||||
- 后端提供 `DELETE /api/ai/annotations/{annotation_id}` 删除已保存标注。
|
||||
- 当前前端“结构化归档保存”会保存当前项目未保存 mask,并会更新已标记为 dirty 的已保存 mask。
|
||||
- 工作区“清空遮罩”会删除当前帧已保存标注,并清空当前帧未保存 mask。
|
||||
- 工作区加载项目帧后会查询已保存标注并回显。
|
||||
|
||||
## R8 模板库
|
||||
|
||||
- 前端展示模板列表,调用 `GET /api/templates`。
|
||||
- 用户可以新建、编辑、删除模板。
|
||||
- 模板分类存放在 `mapping_rules.classes`,规则存放在 `mapping_rules.rules`。
|
||||
- 前端支持添加/删除分类、拖拽排序后重算 `zIndex`、JSON 批量导入、加载腹腔镜默认分类。
|
||||
- 后端支持模板创建、列表、详情、局部更新和删除。
|
||||
|
||||
## R9 本体检查面板
|
||||
|
||||
- 工作区右侧可以选择模板。
|
||||
- 面板显示模板分类和组件本地自定义分类。
|
||||
- 用户可以选择具体分类;新 AI mask 会记录 `classId`、`className`、`classZIndex`,并在保存时写入 `mask_data.class`。
|
||||
- 添加自定义分类只存在组件本地状态,不保存到后端。
|
||||
- 置信度、拓扑锚点和重新提取骨架按钮当前为展示/占位。
|
||||
|
||||
## R10 Dashboard 与 WebSocket
|
||||
|
||||
- Dashboard 显示基础统计、解析队列和活动日志。
|
||||
- Dashboard 初始数据来自 `GET /api/dashboard/overview`。
|
||||
- 后端聚合项目数、处理中任务数、标注数、帧数、模板数和主机 load average。
|
||||
- 解析队列由 `processing_tasks` 中的 queued/running 任务生成;活动日志由最近任务、项目、标注和模板记录生成。
|
||||
- Dashboard 会连接 `/ws/progress`。
|
||||
- 收到 progress、complete、error、status 消息时,前端会更新队列或日志。
|
||||
- Celery worker 每次更新 `processing_tasks` 后会发布 Redis `seg:progress` 事件,FastAPI 订阅并广播给 `/ws/progress` 客户端。
|
||||
- 后端 WebSocket 接收到客户端消息后返回 status heartbeat。
|
||||
|
||||
## R11 导出
|
||||
|
||||
- 后端支持 `GET /api/export/{project_id}/coco` 导出 COCO JSON。
|
||||
- 后端支持 `GET /api/export/{project_id}/masks` 导出 PNG mask ZIP。
|
||||
- 当前前端 `exportCoco()` API 封装已对齐后端路径。
|
||||
- 工作区“导出 JSON 标注集”按钮已绑定下载事件;导出前会先保存当前未归档 mask。
|
||||
|
||||
## R12 配置
|
||||
|
||||
- 前端 API 地址由 `src/lib/config.ts` 统一推导。
|
||||
- `VITE_API_BASE_URL` 优先级高于自动推导。
|
||||
- `VITE_WS_PROGRESS_URL` 优先级高于从 API 地址推导 WebSocket 地址。
|
||||
- 未设置环境变量时,前端按当前浏览器 hostname 推导 `http://<host>:8000`。
|
||||
|
||||
## R13 文档与测试
|
||||
|
||||
- `doc/` 目录保存当前实现审计、接口契约、需求冻结、设计冻结和测试计划。
|
||||
- 测试应覆盖当前冻结需求中的真实功能、半可用行为和明确占位行为。
|
||||
- 对外部服务依赖 PostgreSQL、MinIO、Redis、SAM 模型的测试应使用 mock 或测试替身,不依赖真实服务可用性。
|
||||
155
doc/08-current-design-freeze.md
Normal file
155
doc/08-current-design-freeze.md
Normal file
@@ -0,0 +1,155 @@
|
||||
# 当前设计冻结文档
|
||||
|
||||
冻结日期:2026-05-01
|
||||
|
||||
本文档描述当前代码结构、数据流、接口契约和测试边界。后续实现如果改变这些设计,应同步更新本文档和测试。
|
||||
|
||||
## 总体架构
|
||||
|
||||
当前系统由三层组成:
|
||||
|
||||
- React + TypeScript 前端 SPA。
|
||||
- FastAPI 后端 API。
|
||||
- PostgreSQL、MinIO、Redis、SAM 2 / SAM 3 等外部基础设施。
|
||||
|
||||
开发时前端通过 `server.ts` 启动 Express + Vite middleware;后端通过 `backend/main.py` 启动 FastAPI。前端业务接口主要访问 FastAPI,不依赖 `server.ts` 中保留的旧 mock API。
|
||||
|
||||
## 前端模块
|
||||
|
||||
| 模块 | 文件 | 设计职责 |
|
||||
|------|------|----------|
|
||||
| 应用入口 | `src/App.tsx` | 根据登录状态和 `activeModule` 切换页面 |
|
||||
| 全局状态 | `src/store/useStore.ts` | Zustand store,保存项目、帧、模板、mask、工具状态 |
|
||||
| API 封装 | `src/lib/api.ts` | Axios 客户端、字段映射、AI 响应转换 |
|
||||
| 配置 | `src/lib/config.ts` | 推导 API 和 WebSocket 地址 |
|
||||
| WebSocket | `src/lib/websocket.ts` | 进度流连接、订阅和重连 |
|
||||
| 模型状态 | `src/components/ModelStatusBadge.tsx` | 展示 GPU 与当前 SAM 模型真实可用状态 |
|
||||
| 登录页 | `src/components/Login.tsx` | 调用登录 API,写入 store |
|
||||
| Dashboard | `src/components/Dashboard.tsx` | 展示统计和 WebSocket 进度消息 |
|
||||
| 项目库 | `src/components/ProjectLibrary.tsx` | 项目列表、新建、导入视频/DICOM |
|
||||
| 工作区 | `src/components/VideoWorkspace.tsx` | 加载帧和模板,组织工具栏、Canvas、本体面板、时间轴 |
|
||||
| Canvas | `src/components/CanvasArea.tsx` | 显示帧、缩放平移、点/框提示、渲染 mask |
|
||||
| 工具栏 | `src/components/ToolsPalette.tsx` | 切换工具和跳转 AI 页面 |
|
||||
| 时间轴 | `src/components/FrameTimeline.tsx` | 帧导航和播放 |
|
||||
| 本体面板 | `src/components/OntologyInspector.tsx` | 模板选择、分类树、本地自定义分类 |
|
||||
| AI 页面 | `src/components/AISegmentation.tsx` | 独立 AI 推理视图,使用当前项目帧 |
|
||||
| 模板库 | `src/components/TemplateRegistry.tsx` | 模板 CRUD、分类编辑、导入、排序 |
|
||||
|
||||
## 后端模块
|
||||
|
||||
| 模块 | 文件 | 设计职责 |
|
||||
|------|------|----------|
|
||||
| 应用入口 | `backend/main.py` | FastAPI app、CORS、路由注册、健康检查、WebSocket |
|
||||
| 配置 | `backend/config.py` | Pydantic settings |
|
||||
| 数据库 | `backend/database.py` | SQLAlchemy engine、session、Base |
|
||||
| 模型 | `backend/models.py` | Project、Frame、Template、Annotation、Mask、ProcessingTask |
|
||||
| Schema | `backend/schemas.py` | Pydantic 请求/响应模型 |
|
||||
| Auth | `backend/routers/auth.py` | 开发登录 |
|
||||
| Projects | `backend/routers/projects.py` | 项目与帧 CRUD |
|
||||
| Templates | `backend/routers/templates.py` | 模板 CRUD 和 mapping_rules 打包/解包 |
|
||||
| Media | `backend/routers/media.py` | 上传媒体和拆帧 |
|
||||
| AI | `backend/routers/ai.py` | SAM 2 / SAM 3 可选推理、模型状态和标注保存 |
|
||||
| Export | `backend/routers/export.py` | COCO 和 PNG mask 导出 |
|
||||
| SAM 2 | `backend/services/sam2_engine.py` | SAM 2 懒加载、状态检测和点/框/自动推理 |
|
||||
| SAM 3 | `backend/services/sam3_engine.py` | SAM 3 状态检测和文本语义推理适配 |
|
||||
| SAM Registry | `backend/services/sam_registry.py` | 模型选择、GPU 状态和推理分发 |
|
||||
|
||||
## 状态模型
|
||||
|
||||
前端 store 的核心对象:
|
||||
|
||||
- `Project`:项目基本信息、状态、帧数、fps、媒体路径。
|
||||
- `Frame`:帧 ID、项目 ID、索引、图片 URL、宽高。
|
||||
- `Template` / `TemplateClass`:模板和分类定义。
|
||||
- `Mask`:前端渲染用 mask,包含 `pathData`、`segmentation`、`bbox`、`area`。
|
||||
- `activeModule`:当前页面。
|
||||
- `activeTool`:当前工具。
|
||||
- `aiModel`:当前选择的 AI 模型,取值为 `sam2` 或 `sam3`。
|
||||
|
||||
## 关键数据流
|
||||
|
||||
### 登录
|
||||
|
||||
1. `Login` 收集用户名和密码。
|
||||
2. `login()` 调用 `POST /api/auth/login`。
|
||||
3. 成功后 store 写入 token,App 渲染主界面。
|
||||
|
||||
### 项目导入
|
||||
|
||||
1. `ProjectLibrary` 创建项目。
|
||||
2. 上传视频或 DICOM 到 `/api/media/upload` 或 `/api/media/upload/dicom`。
|
||||
3. 调用 `/api/media/parse` 创建异步拆帧任务。
|
||||
4. Celery worker 执行 FFmpeg/OpenCV/pydicom 拆帧,持续更新 `processing_tasks`,并发布 Redis `seg:progress`。
|
||||
5. 刷新项目列表。
|
||||
|
||||
### 工作区加载
|
||||
|
||||
1. `VideoWorkspace` 根据 `currentProject.id` 调用 `getProjectFrames()`。
|
||||
2. 若无帧但项目有 `video_path`,触发 `parseMedia()`,通过 `getTask()` 轮询任务完成后重新取帧。
|
||||
3. 帧数据映射为 store `Frame[]`。
|
||||
4. 当前帧传入 `CanvasArea`。
|
||||
|
||||
### AI 点/框推理
|
||||
|
||||
1. 用户在 Canvas 选择正向点、反向点或框选。
|
||||
2. `CanvasArea` 读取当前帧 ID 和宽高。
|
||||
3. `predictMask()` 归一化坐标并携带当前 `model` 调用 `/api/ai/predict`。
|
||||
4. 后端加载帧图片并通过 SAM registry 分发到 SAM 2 或 SAM 3。
|
||||
5. 前端把 `polygons` 转为 mask,写入 store。
|
||||
6. Canvas 按当前帧过滤并渲染 mask。
|
||||
7. 新 mask 会带上当前选择的模板分类元数据,包括 `classId`、`className`、`classZIndex` 和保存状态 `draft`。
|
||||
8. 用户点击“结构化归档保存”后,前端将像素 `segmentation` 转成 normalized `mask_data.polygons`;未保存 mask 调用 `POST /api/ai/annotate`,dirty mask 调用 `PATCH /api/ai/annotations/{annotation_id}`。
|
||||
9. 工作区加载项目帧后通过 `GET /api/ai/annotations` 取回已保存标注并转成前端 mask。
|
||||
10. 工作区“清空遮罩”删除当前帧已保存标注,并清除当前帧本地 mask。
|
||||
|
||||
### 模板管理
|
||||
|
||||
1. `TemplateRegistry` 从后端读取模板。
|
||||
2. 编辑态在组件本地维护分类列表。
|
||||
3. 保存时调用 `createTemplate()` 或 `updateTemplate()`。
|
||||
4. 后端把 `classes`、`rules` 打包进 `mapping_rules`。
|
||||
5. 返回时再解包给前端。
|
||||
6. `OntologyInspector` 可以选择具体分类;选择结果进入全局 store,供 `CanvasArea` 和 `AISegmentation` 新建/更新 mask 时使用。
|
||||
|
||||
### 导出
|
||||
|
||||
1. 后端根据项目、帧、标注和模板生成 COCO JSON。
|
||||
2. PNG mask 导出会把 normalized polygon 渲染为二值 mask 并打包 ZIP。
|
||||
3. 前端“导出 JSON 标注集”按钮会在导出前保存待归档标注,然后下载 COCO JSON。
|
||||
|
||||
## 接口契约
|
||||
|
||||
接口详情见 `doc/04-api-contracts.md`。测试中重点固定以下契约:
|
||||
|
||||
- `updateProject()` 使用 `PATCH /api/projects/{id}`。
|
||||
- `exportCoco()` 使用 `GET /api/export/{projectId}/coco`。
|
||||
- `predictMask()` 使用 `POST /api/ai/predict`,请求体为 `image_id`、`prompt_type`、`prompt_data`、`model`。
|
||||
- `saveAnnotation()` 使用 `POST /api/ai/annotate`。
|
||||
- `getProjectAnnotations()` 使用 `GET /api/ai/annotations`。
|
||||
- `updateAnnotation()` 使用 `PATCH /api/ai/annotations/{annotationId}`。
|
||||
- `deleteAnnotation()` 使用 `DELETE /api/ai/annotations/{annotationId}`。
|
||||
- 后端 `/api/ai/predict` 支持 point、box、semantic 三种 prompt_type,并通过 `model` 选择 SAM 2 或 SAM 3。
|
||||
- 后端 `/api/ai/models/status` 返回 GPU、SAM 2、SAM 3 的真实运行状态。
|
||||
- point prompt 支持旧数组形式和 `{ points, labels }` 对象形式。
|
||||
|
||||
## 外部依赖边界
|
||||
|
||||
测试不直接依赖以下真实服务:
|
||||
|
||||
- PostgreSQL:后端测试使用内存 SQLite。
|
||||
- MinIO:上传、下载、预签名 URL 使用 monkeypatch。
|
||||
- Redis:单测使用 monkeypatch 验证进度事件发布,不依赖真实 Redis 服务。
|
||||
- SAM:AI 推理测试使用 fake registry。
|
||||
- 浏览器 Canvas/Konva 图片加载:前端测试 mock `react-konva` 和 `use-image`。
|
||||
|
||||
## 已知占位设计
|
||||
|
||||
以下能力属于当前冻结版本的占位或半可用功能:
|
||||
|
||||
- Dashboard 初始快照来自 `GET /api/dashboard/overview`;解析队列由 `processing_tasks` queued/running 任务生成。
|
||||
- 多边形、矩形、圆、点、线手工绘制未实现。
|
||||
- 合并、去除、撤销、重做未实现。
|
||||
- 工作区导出 PNG mask ZIP 按钮尚未提供。
|
||||
- 已保存标注支持通过“应用分类”进入 dirty 状态并归档更新;暂未提供逐点几何编辑器。
|
||||
- SAM 3 文本语义分割取决于官方依赖和 GPU 运行环境;状态接口会暴露真实可用性。
|
||||
- 自定义分类只存在本地组件状态。
|
||||
49
doc/09-test-plan.md
Normal file
49
doc/09-test-plan.md
Normal file
@@ -0,0 +1,49 @@
|
||||
# 当前测试计划
|
||||
|
||||
本文档把 `doc/07-current-requirements-freeze.md` 中的冻结需求映射到测试。测试目标是覆盖当前真实行为和明确占位行为。
|
||||
|
||||
## 测试分层
|
||||
|
||||
| 层级 | 工具 | 覆盖范围 |
|
||||
|------|------|----------|
|
||||
| 前端单元/组件 | Vitest + Testing Library | API 封装、store、组件交互、Mock/UI-only 状态 |
|
||||
| 后端路由 | pytest + FastAPI TestClient | Auth、Projects、Templates、AI、Export、Media 的接口契约 |
|
||||
| 静态契约 | TypeScript / py_compile | 类型和 Python 语法 |
|
||||
|
||||
## 覆盖矩阵
|
||||
|
||||
| 需求 | 测试文件 | 覆盖点 |
|
||||
|------|----------|--------|
|
||||
| R1 登录与会话 | `src/components/Login.test.tsx`, `backend/tests/test_auth.py` | 成功登录、失败提示、后端 401 |
|
||||
| R2 项目管理 | `src/lib/api.test.ts`, `backend/tests/test_projects.py` | 前端字段映射、PATCH 更新、后端 CRUD、帧列表 |
|
||||
| R3 媒体上传与拆帧 | `backend/tests/test_media.py` | 扩展名校验、自动建项目、关联项目、创建异步任务、worker 注册帧 |
|
||||
| R4 工作区与帧浏览 | `src/components/VideoWorkspace.test.tsx`, `src/components/FrameTimeline.test.tsx` | 加载帧、无帧触发解析、切帧、播放 |
|
||||
| R5 工具栏 | `src/components/ToolsPalette.test.tsx` | 工具切换、AI 跳转、占位按钮存在 |
|
||||
| R6 AI 推理 | `src/lib/api.test.ts`, `src/components/CanvasArea.test.tsx`, `src/components/AISegmentation.test.tsx`, `src/components/ModelStatusBadge.test.tsx`, `backend/tests/test_ai.py` | 点/框/semantic 契约、模型选择、GPU/SAM 状态、状态徽标、坐标归一化、正负点 labels、polygons 转 path、后端 fake registry |
|
||||
| R7 标注保存 | `src/components/VideoWorkspace.test.tsx`, `src/components/CanvasArea.test.tsx`, `src/lib/api.test.ts`, `backend/tests/test_ai.py` | 保存标注、加载回显、更新 dirty 标注、清空删除已保存标注、项目不存在、帧不存在 |
|
||||
| R8 模板库 | `src/lib/api.test.ts`, `backend/tests/test_templates.py` | mapping_rules 解包/打包、模板 CRUD |
|
||||
| R9 本体检查面板 | `src/components/OntologyInspector.test.tsx` | 模板选择、分类展示、具体分类选择、自定义分类本地添加 |
|
||||
| R10 Dashboard 与 WebSocket | `src/lib/api.test.ts`, `src/lib/websocket.test.ts`, `src/components/Dashboard.test.tsx`, `backend/tests/test_dashboard.py`, `backend/tests/test_main.py`, `backend/tests/test_progress_events.py` | 后端概览接口、任务表驱动队列、Redis 进度事件 payload/发布、地址推导、消息订阅、队列更新、heartbeat |
|
||||
| R11 导出 | `src/components/VideoWorkspace.test.tsx`, `src/lib/api.test.ts`, `backend/tests/test_export.py` | COCO 按钮下载、导出前自动保存、COCO 路径、JSON 结构、mask ZIP |
|
||||
| R12 配置 | `src/lib/config.test.ts` | env 优先、hostname 推导、WS 推导 |
|
||||
| R13 文档与测试 | `doc/09-test-plan.md` | 测试覆盖矩阵 |
|
||||
|
||||
## 运行命令
|
||||
|
||||
```bash
|
||||
npm run test
|
||||
npm run test:run
|
||||
npm run lint
|
||||
npm run build
|
||||
|
||||
pip install -r backend/requirements-dev.txt
|
||||
pytest backend/tests
|
||||
python -m py_compile backend/routers/ai.py backend/routers/templates.py backend/schemas.py
|
||||
```
|
||||
|
||||
## 当前不做的测试
|
||||
|
||||
- 不启动真实 PostgreSQL、MinIO、Redis 或 SAM 模型。
|
||||
- 不做真实视频大文件拆帧性能测试。
|
||||
- 不用浏览器 E2E 验证视觉细节。
|
||||
- 不把当前明确 Mock/UI-only 的按钮当成真实业务成功路径测试。
|
||||
32
doc/README.md
Normal file
32
doc/README.md
Normal file
@@ -0,0 +1,32 @@
|
||||
# 项目文档索引
|
||||
|
||||
本目录用于记录当前代码库的真实状态、目标设计与实现差距。文档依据包括:
|
||||
|
||||
- 根目录 Word 文档:`语义分割系统构建方案.docx`
|
||||
- 前端源码:`src/App.tsx`、`src/components/*.tsx`、`src/lib/api.ts`、`src/store/useStore.ts`
|
||||
- 后端源码:`backend/main.py`、`backend/routers/*.py`、`backend/schemas.py`、`backend/models.py`
|
||||
- 运行时 OpenAPI:`http://localhost:8000/openapi.json`
|
||||
|
||||
## 文档结构
|
||||
|
||||
| 文档 | 内容 |
|
||||
|------|------|
|
||||
| [01-purpose-and-word-summary.md](./01-purpose-and-word-summary.md) | 为什么要做这个系统,Word 方案中的目标,以及当前代码的落地程度 |
|
||||
| [02-current-implementation-map.md](./02-current-implementation-map.md) | 当前系统怎么运行,前后端、存储、数据流具体怎么串起来 |
|
||||
| [03-frontend-element-audit.md](./03-frontend-element-audit.md) | 前端逐页面/逐元素审计:真实可用、半可用、Mock/UI-only、接口不通 |
|
||||
| [04-api-contracts.md](./04-api-contracts.md) | 前端 API 封装、后端 FastAPI 接口、已完成对齐项和剩余接口问题 |
|
||||
| [05-implementation-plan.md](./05-implementation-plan.md) | 后续要把 Mock 变成真实功能的建议实施顺序 |
|
||||
| [06-fastapi-docs-explained.md](./06-fastapi-docs-explained.md) | `http://192.168.3.11:8000/docs` 是什么,怎么看和怎么用 |
|
||||
| [07-current-requirements-freeze.md](./07-current-requirements-freeze.md) | 当前版本需求冻结,测试以此为准 |
|
||||
| [08-current-design-freeze.md](./08-current-design-freeze.md) | 当前版本设计冻结,记录模块、数据流和接口边界 |
|
||||
| [09-test-plan.md](./09-test-plan.md) | 需求到测试文件的覆盖矩阵和运行命令 |
|
||||
|
||||
## 状态标记
|
||||
|
||||
| 标记 | 含义 |
|
||||
|------|------|
|
||||
| 真实可用 | 已接真实前端状态或后端 API,按当前代码能完成主要动作 |
|
||||
| 部分可用 | 有真实数据或真实 UI,但存在关键缺口,例如只读、不能持久化、缺少错误处理 |
|
||||
| Mock / UI-only | 只有展示或本地状态变化,没有真实业务效果 |
|
||||
| 接口不通 | 前端调用和后端接口契约不一致,按当前代码大概率失败 |
|
||||
| 目标设计 | Word 方案中提出,但当前代码尚未实现 |
|
||||
1154
package-lock.json
generated
1154
package-lock.json
generated
File diff suppressed because it is too large
Load Diff
11
package.json
11
package.json
@@ -9,7 +9,9 @@
|
||||
"preview": "vite preview",
|
||||
"start": "node server.ts",
|
||||
"clean": "rm -rf dist",
|
||||
"lint": "tsc --noEmit"
|
||||
"lint": "tsc --noEmit",
|
||||
"test": "vitest",
|
||||
"test:run": "vitest run"
|
||||
},
|
||||
"dependencies": {
|
||||
"@google/genai": "^1.29.0",
|
||||
@@ -31,12 +33,17 @@
|
||||
"zustand": "^5.0.12"
|
||||
},
|
||||
"devDependencies": {
|
||||
"@testing-library/jest-dom": "^6.9.1",
|
||||
"@testing-library/react": "^16.3.2",
|
||||
"@testing-library/user-event": "^14.6.1",
|
||||
"@types/express": "^4.17.21",
|
||||
"@types/node": "^22.14.0",
|
||||
"autoprefixer": "^10.4.21",
|
||||
"jsdom": "^29.1.1",
|
||||
"tailwindcss": "^4.1.14",
|
||||
"tsx": "^4.21.0",
|
||||
"typescript": "~5.8.2",
|
||||
"vite": "^6.2.0"
|
||||
"vite": "^6.2.0",
|
||||
"vitest": "^4.1.5"
|
||||
}
|
||||
}
|
||||
|
||||
43
src/components/AISegmentation.test.tsx
Normal file
43
src/components/AISegmentation.test.tsx
Normal file
@@ -0,0 +1,43 @@
|
||||
import { fireEvent, render, screen } from '@testing-library/react';
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest';
|
||||
import { resetStore } from '../test/storeTestUtils';
|
||||
import { useStore } from '../store/useStore';
|
||||
import { AISegmentation } from './AISegmentation';
|
||||
|
||||
const apiMock = vi.hoisted(() => ({
|
||||
getAiModelStatus: vi.fn(),
|
||||
predictMask: vi.fn(),
|
||||
}));
|
||||
|
||||
vi.mock('../lib/api', () => ({
|
||||
getAiModelStatus: apiMock.getAiModelStatus,
|
||||
predictMask: apiMock.predictMask,
|
||||
}));
|
||||
|
||||
describe('AISegmentation', () => {
|
||||
beforeEach(() => {
|
||||
resetStore();
|
||||
vi.clearAllMocks();
|
||||
useStore.setState({
|
||||
frames: [{ id: 'frame-1', projectId: 'project-1', index: 0, url: '/frame.jpg', width: 640, height: 360 }],
|
||||
});
|
||||
apiMock.getAiModelStatus.mockResolvedValue({
|
||||
selected_model: 'sam2',
|
||||
gpu: { available: true, device: 'cuda', name: 'RTX 4090', torch_available: true },
|
||||
models: [
|
||||
{ id: 'sam2', label: 'SAM 2', available: true, loaded: false, device: 'cuda', supports: ['point', 'box'], message: 'SAM 2 ready', package_available: true, checkpoint_exists: true, python_ok: true, torch_ok: true, cuda_required: false },
|
||||
{ id: 'sam3', label: 'SAM 3', available: false, loaded: false, device: 'cuda', supports: ['semantic'], message: 'SAM 3 missing runtime', package_available: false, checkpoint_exists: false, python_ok: false, torch_ok: true, cuda_required: true },
|
||||
],
|
||||
});
|
||||
});
|
||||
|
||||
it('lets the user choose SAM3 for subsequent predictions', async () => {
|
||||
render(<AISegmentation onSendToWorkspace={vi.fn()} />);
|
||||
|
||||
const sam3Button = (await screen.findByText('SAM3')).closest('button')!;
|
||||
fireEvent.click(sam3Button);
|
||||
|
||||
expect(useStore.getState().aiModel).toBe('sam3');
|
||||
expect(await screen.findByText('SAM 3 missing runtime')).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
@@ -1,11 +1,11 @@
|
||||
import React, { useState, useCallback } from 'react';
|
||||
import React, { useState, useCallback, useEffect } from 'react';
|
||||
import { Target, PlusCircle, MinusCircle, SquareDashed, Sparkles, SendToBack, Image as ImageIcon, Undo, Redo, Loader2 } from 'lucide-react';
|
||||
import { cn } from '../lib/utils';
|
||||
import { Stage, Layer, Image as KonvaImage, Circle, Path, Group } from 'react-konva';
|
||||
import useImage from 'use-image';
|
||||
import { OntologyInspector } from './OntologyInspector';
|
||||
import { useStore } from '../store/useStore';
|
||||
import { predictMask } from '../lib/api';
|
||||
import { getAiModelStatus, predictMask, type AiRuntimeStatus } from '../lib/api';
|
||||
|
||||
interface AISegmentationProps {
|
||||
onSendToWorkspace: () => void;
|
||||
@@ -17,9 +17,15 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
|
||||
const masks = useStore((state) => state.masks);
|
||||
const addMask = useStore((state) => state.addMask);
|
||||
const clearMasks = useStore((state) => state.clearMasks);
|
||||
const frames = useStore((state) => state.frames);
|
||||
const currentFrameIndex = useStore((state) => state.currentFrameIndex);
|
||||
const activeTemplateId = useStore((state) => state.activeTemplateId);
|
||||
const activeClass = useStore((state) => state.activeClass);
|
||||
const aiModel = useStore((state) => state.aiModel);
|
||||
const setAiModel = useStore((state) => state.setAiModel);
|
||||
|
||||
const [modelSize, setModelSize] = useState('vit_l');
|
||||
const [semanticText, setSemanticText] = useState('');
|
||||
const [modelStatus, setModelStatus] = useState<AiRuntimeStatus | null>(null);
|
||||
const [autoDeleteBg, setAutoDeleteBg] = useState(true);
|
||||
const [cropMode, setCropMode] = useState(false);
|
||||
const [isInferencing, setIsInferencing] = useState(false);
|
||||
@@ -29,10 +35,29 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
|
||||
const [position, setPosition] = useState({ x: 0, y: 0 });
|
||||
const [points, setPoints] = useState<{ x: number, y: number, type: 'pos'|'neg' }[]>([]);
|
||||
const [cursorPos, setCursorPos] = useState({ x: 0, y: 0 });
|
||||
const [image] = useImage('https://images.unsplash.com/photo-1549317661-bd32c8ce0be2?q=80&w=2070&auto=format&fit=crop');
|
||||
const currentFrame = frames[currentFrameIndex] || null;
|
||||
const previewUrl = currentFrame?.url || 'https://images.unsplash.com/photo-1549317661-bd32c8ce0be2?q=80&w=2070&auto=format&fit=crop';
|
||||
const [image] = useImage(previewUrl);
|
||||
const frameMasks = currentFrame ? masks.filter((mask) => mask.frameId === currentFrame.id) : masks;
|
||||
const selectedModelStatus = modelStatus?.models.find((model) => model.id === aiModel);
|
||||
const modelCanInfer = selectedModelStatus?.available ?? true;
|
||||
|
||||
const effectiveTool = storeActiveTool;
|
||||
|
||||
useEffect(() => {
|
||||
let cancelled = false;
|
||||
getAiModelStatus(aiModel)
|
||||
.then((status) => {
|
||||
if (!cancelled) setModelStatus(status);
|
||||
})
|
||||
.catch(() => {
|
||||
if (!cancelled) setModelStatus(null);
|
||||
});
|
||||
return () => {
|
||||
cancelled = true;
|
||||
};
|
||||
}, [aiModel]);
|
||||
|
||||
const handleWheel = (e: any) => {
|
||||
e.evt.preventDefault();
|
||||
const scaleBy = 1.1;
|
||||
@@ -63,22 +88,44 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
|
||||
|
||||
const runInference = useCallback(async () => {
|
||||
if (points.length === 0 && !semanticText.trim()) return;
|
||||
if (!currentFrame?.id) {
|
||||
console.warn('AI inference skipped: no project frame is selected');
|
||||
return;
|
||||
}
|
||||
|
||||
const imageWidth = currentFrame.width || image?.naturalWidth || image?.width || 0;
|
||||
const imageHeight = currentFrame.height || image?.naturalHeight || image?.height || 0;
|
||||
if (imageWidth <= 0 || imageHeight <= 0) {
|
||||
console.warn('AI inference skipped: active frame dimensions are unavailable');
|
||||
return;
|
||||
}
|
||||
|
||||
setIsInferencing(true);
|
||||
try {
|
||||
const result = await predictMask({
|
||||
imageUrl: 'https://images.unsplash.com/photo-1549317661-bd32c8ce0be2?q=80&w=2070&auto=format&fit=crop',
|
||||
imageId: currentFrame.id,
|
||||
imageWidth,
|
||||
imageHeight,
|
||||
model: aiModel,
|
||||
points: points.map((p) => ({ x: p.x, y: p.y, type: p.type })),
|
||||
text: semanticText.trim() || undefined,
|
||||
modelSize,
|
||||
});
|
||||
|
||||
result.masks.forEach((m) => {
|
||||
const label = activeClass?.name || m.label;
|
||||
const color = activeClass?.color || m.color;
|
||||
addMask({
|
||||
id: m.id,
|
||||
frameId: 'frame-ai-1',
|
||||
frameId: currentFrame.id,
|
||||
templateId: activeTemplateId || undefined,
|
||||
classId: activeClass?.id,
|
||||
className: activeClass?.name,
|
||||
classZIndex: activeClass?.zIndex,
|
||||
saveStatus: 'draft',
|
||||
saved: false,
|
||||
pathData: m.pathData,
|
||||
label: m.label,
|
||||
color: m.color,
|
||||
label,
|
||||
color,
|
||||
segmentation: m.segmentation,
|
||||
bbox: m.bbox,
|
||||
area: m.area,
|
||||
@@ -89,7 +136,7 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
|
||||
} finally {
|
||||
setIsInferencing(false);
|
||||
}
|
||||
}, [points, semanticText, modelSize, addMask]);
|
||||
}, [activeClass, activeTemplateId, addMask, aiModel, currentFrame?.height, currentFrame?.id, currentFrame?.width, image?.height, image?.naturalHeight, image?.naturalWidth, image?.width, points, semanticText]);
|
||||
|
||||
const handleStageClick = (e: any) => {
|
||||
if (effectiveTool === 'move') return;
|
||||
@@ -117,17 +164,26 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
|
||||
{/* Model Select */}
|
||||
<div>
|
||||
<h3 className="text-[10px] font-bold text-gray-500 uppercase tracking-widest mb-3">视觉基础模型选型</h3>
|
||||
<div className="bg-[#111] border border-white/5 flex p-1 rounded-lg">
|
||||
{['vit_b', 'vit_l', 'vit_h'].map(m => (
|
||||
<div className="bg-[#111] border border-white/5 grid grid-cols-2 gap-1 p-1 rounded-lg">
|
||||
{(modelStatus?.models || [
|
||||
{ id: 'sam2' as const, label: 'SAM 2', available: true, message: '正在读取 SAM 2 状态' },
|
||||
{ id: 'sam3' as const, label: 'SAM 3', available: false, message: '正在读取 SAM 3 状态' },
|
||||
]).map((m) => (
|
||||
<button
|
||||
key={m}
|
||||
className={cn("flex-1 text-xs py-2 rounded-md transition-colors text-center uppercase tracking-wider font-mono", modelSize === m ? "bg-white/10 text-white font-medium shadow-sm" : "text-gray-500 hover:text-gray-300 hover:bg-white/5")}
|
||||
onClick={() => setModelSize(m)}
|
||||
key={m.id}
|
||||
className={cn("text-xs py-2 rounded-md transition-colors text-center uppercase tracking-wider font-mono", aiModel === m.id ? "bg-white/10 text-white font-medium shadow-sm" : "text-gray-500 hover:text-gray-300 hover:bg-white/5")}
|
||||
onClick={() => setAiModel(m.id)}
|
||||
title={m.message}
|
||||
>
|
||||
{m.split('_')[1]}
|
||||
{m.label.replace(' ', '')}
|
||||
<span className={cn("ml-1", m.available ? "text-emerald-400" : "text-amber-400")}>●</span>
|
||||
</button>
|
||||
))}
|
||||
</div>
|
||||
<div className="mt-2 text-[10px] text-gray-500 leading-relaxed">
|
||||
<div>{selectedModelStatus?.message || '正在读取模型状态...'}</div>
|
||||
<div>GPU: {modelStatus?.gpu.available ? `${modelStatus.gpu.name || 'CUDA'} 可用` : '不可用或未检测到 CUDA'}</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Prompt Tools */}
|
||||
@@ -206,16 +262,16 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
|
||||
<div className="p-6 bg-[#0a0a0a] border-t border-white/5 shrink-0 flex flex-col gap-3">
|
||||
<button
|
||||
onClick={runInference}
|
||||
disabled={isInferencing}
|
||||
disabled={isInferencing || !currentFrame || !modelCanInfer}
|
||||
className={cn(
|
||||
"w-full py-3.5 rounded-lg flex items-center justify-center gap-2 transition-all shadow-lg font-medium tracking-wide text-xs uppercase",
|
||||
isInferencing
|
||||
isInferencing || !currentFrame || !modelCanInfer
|
||||
? "bg-cyan-500/50 text-black/70 cursor-not-allowed"
|
||||
: "bg-cyan-500 hover:bg-cyan-400 text-black shadow-cyan-500/20 hover:shadow-cyan-500/40"
|
||||
)}
|
||||
>
|
||||
{isInferencing ? <Loader2 size={16} className="animate-spin" /> : <Sparkles size={16} />}
|
||||
{isInferencing ? '推理中...' : '执行高精度语义分割'}
|
||||
{isInferencing ? '推理中...' : modelCanInfer ? '执行高精度语义分割' : '当前模型不可用'}
|
||||
</button>
|
||||
<button
|
||||
onClick={onSendToWorkspace}
|
||||
@@ -231,7 +287,7 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
|
||||
<header className="h-16 border-b border-white/5 bg-[#111] flex items-center justify-between px-6 shrink-0">
|
||||
<div className="flex flex-col">
|
||||
<h2 className="text-sm font-semibold tracking-wide text-white">模型端推理侧可视化 (Visualizer)</h2>
|
||||
<span className="text-[10px] text-gray-500 uppercase tracking-widest font-mono">SAM 3 内核级动态即时渲染</span>
|
||||
<span className="text-[10px] text-gray-500 uppercase tracking-widest font-mono">{aiModel.toUpperCase()} 动态推理渲染</span>
|
||||
</div>
|
||||
<div className="flex items-center gap-4">
|
||||
<button className="w-8 h-8 rounded text-gray-400 hover:bg-white/5 hover:text-white flex items-center justify-center transition-colors" title="撤销操作 (Ctrl+Z)">
|
||||
@@ -276,7 +332,7 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
|
||||
)}
|
||||
|
||||
{/* AI Returned Masks */}
|
||||
{masks.map((mask) => (
|
||||
{frameMasks.map((mask) => (
|
||||
<Group key={mask.id} opacity={0.45}>
|
||||
<Path
|
||||
data={mask.pathData}
|
||||
@@ -309,7 +365,7 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
|
||||
<div className="absolute bottom-4 left-4 flex gap-4 text-[10px] font-mono text-gray-500 pointer-events-none">
|
||||
<span>光标坐标: {cursorPos.x.toFixed(2)}, {cursorPos.y.toFixed(2)}</span>
|
||||
<span>缩放比率: {(scale * 100).toFixed(0)}%</span>
|
||||
<span>遮罩数: {masks.length}</span>
|
||||
<span>遮罩数: {frameMasks.length}</span>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
130
src/components/CanvasArea.test.tsx
Normal file
130
src/components/CanvasArea.test.tsx
Normal file
@@ -0,0 +1,130 @@
|
||||
import { fireEvent, render, screen, waitFor } from '@testing-library/react';
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest';
|
||||
import { resetStore } from '../test/storeTestUtils';
|
||||
import { useStore } from '../store/useStore';
|
||||
import { CanvasArea } from './CanvasArea';
|
||||
|
||||
const apiMock = vi.hoisted(() => ({
|
||||
predictMask: vi.fn(),
|
||||
}));
|
||||
|
||||
vi.mock('../lib/api', () => ({
|
||||
predictMask: apiMock.predictMask,
|
||||
}));
|
||||
|
||||
describe('CanvasArea', () => {
|
||||
const frame = { id: 'frame-1', projectId: 'project-1', index: 0, url: '/frame.jpg', width: 640, height: 360 };
|
||||
|
||||
beforeEach(() => {
|
||||
resetStore();
|
||||
vi.clearAllMocks();
|
||||
});
|
||||
|
||||
it('calls AI prediction with the active frame when a point prompt is placed', async () => {
|
||||
useStore.setState({
|
||||
activeTemplateId: '2',
|
||||
activeClass: { id: 'c1', name: '胆囊', color: '#ff0000', zIndex: 20 },
|
||||
activeClassId: 'c1',
|
||||
});
|
||||
apiMock.predictMask.mockResolvedValueOnce({
|
||||
masks: [
|
||||
{
|
||||
id: 'mask-1',
|
||||
pathData: 'M 0 0 L 10 0 L 10 10 Z',
|
||||
label: 'AI Mask',
|
||||
color: '#06b6d4',
|
||||
segmentation: [[0, 0, 10, 0, 10, 10]],
|
||||
bbox: [0, 0, 10, 10],
|
||||
area: 100,
|
||||
},
|
||||
],
|
||||
});
|
||||
|
||||
render(<CanvasArea activeTool="point_pos" frame={frame} />);
|
||||
fireEvent.click(screen.getByTestId('konva-stage'));
|
||||
|
||||
await waitFor(() => expect(apiMock.predictMask).toHaveBeenCalledWith({
|
||||
imageId: 'frame-1',
|
||||
imageWidth: 640,
|
||||
imageHeight: 360,
|
||||
model: 'sam2',
|
||||
points: [{ x: 120, y: 80, type: 'pos' }],
|
||||
box: undefined,
|
||||
}));
|
||||
expect(useStore.getState().masks[0]).toEqual(expect.objectContaining({
|
||||
id: 'mask-1',
|
||||
frameId: 'frame-1',
|
||||
pathData: 'M 0 0 L 10 0 L 10 10 Z',
|
||||
templateId: '2',
|
||||
classId: 'c1',
|
||||
className: '胆囊',
|
||||
classZIndex: 20,
|
||||
label: '胆囊',
|
||||
color: '#ff0000',
|
||||
saveStatus: 'draft',
|
||||
}));
|
||||
});
|
||||
|
||||
it('renders only masks that belong to the current frame', () => {
|
||||
useStore.setState({
|
||||
masks: [
|
||||
{ id: 'm1', frameId: 'frame-1', pathData: 'M 0 0 Z', label: 'A', color: '#fff' },
|
||||
{ id: 'm2', frameId: 'frame-2', pathData: 'M 1 1 Z', label: 'B', color: '#000' },
|
||||
],
|
||||
});
|
||||
|
||||
render(<CanvasArea activeTool="move" frame={frame} />);
|
||||
|
||||
expect(screen.getAllByTestId('konva-path')).toHaveLength(1);
|
||||
expect(screen.getByText('遮罩数: 1')).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it('applies the selected class to current-frame masks and marks saved masks dirty', () => {
|
||||
useStore.setState({
|
||||
activeTemplateId: '2',
|
||||
activeClass: { id: 'c1', name: '胆囊', color: '#ff0000', zIndex: 20 },
|
||||
activeClassId: 'c1',
|
||||
masks: [
|
||||
{
|
||||
id: 'm1',
|
||||
frameId: 'frame-1',
|
||||
annotationId: '99',
|
||||
pathData: 'M 0 0 Z',
|
||||
label: '旧标签',
|
||||
color: '#06b6d4',
|
||||
saved: true,
|
||||
saveStatus: 'saved',
|
||||
},
|
||||
],
|
||||
});
|
||||
|
||||
render(<CanvasArea activeTool="move" frame={frame} />);
|
||||
fireEvent.click(screen.getByRole('button', { name: '应用分类' }));
|
||||
|
||||
expect(useStore.getState().masks[0]).toEqual(expect.objectContaining({
|
||||
templateId: '2',
|
||||
classId: 'c1',
|
||||
className: '胆囊',
|
||||
classZIndex: 20,
|
||||
label: '胆囊',
|
||||
color: '#ff0000',
|
||||
saveStatus: 'dirty',
|
||||
saved: false,
|
||||
}));
|
||||
});
|
||||
|
||||
it('delegates clear to the workspace handler so saved annotations can be deleted', () => {
|
||||
const onClearMasks = vi.fn();
|
||||
useStore.setState({
|
||||
masks: [
|
||||
{ id: 'm1', frameId: 'frame-1', pathData: 'M 0 0 Z', label: 'A', color: '#fff' },
|
||||
],
|
||||
});
|
||||
|
||||
render(<CanvasArea activeTool="move" frame={frame} onClearMasks={onClearMasks} />);
|
||||
fireEvent.click(screen.getByRole('button', { name: '清空遮罩' }));
|
||||
|
||||
expect(onClearMasks).toHaveBeenCalled();
|
||||
expect(useStore.getState().masks).toHaveLength(1);
|
||||
});
|
||||
});
|
||||
@@ -3,14 +3,15 @@ import { Stage, Layer, Image as KonvaImage, Circle, Rect, Path, Group } from 're
|
||||
import useImage from 'use-image';
|
||||
import { useStore } from '../store/useStore';
|
||||
import { predictMask } from '../lib/api';
|
||||
import { cn } from '../lib/utils';
|
||||
import type { Frame } from '../store/useStore';
|
||||
|
||||
interface CanvasAreaProps {
|
||||
activeTool: string;
|
||||
frameUrl: string;
|
||||
frame: Frame | null;
|
||||
onClearMasks?: () => void;
|
||||
}
|
||||
|
||||
export function CanvasArea({ activeTool, frameUrl }: CanvasAreaProps) {
|
||||
export function CanvasArea({ activeTool, frame, onClearMasks }: CanvasAreaProps) {
|
||||
const containerRef = useRef<HTMLDivElement>(null);
|
||||
const [stageSize, setStageSize] = useState({ width: 800, height: 600 });
|
||||
const [scale, setScale] = useState(1);
|
||||
@@ -24,13 +25,20 @@ export function CanvasArea({ activeTool, frameUrl }: CanvasAreaProps) {
|
||||
const masks = useStore((state) => state.masks);
|
||||
const addMask = useStore((state) => state.addMask);
|
||||
const clearMasks = useStore((state) => state.clearMasks);
|
||||
const setMasks = useStore((state) => state.setMasks);
|
||||
const storeActiveTool = useStore((state) => state.activeTool);
|
||||
const setActiveTool = useStore((state) => state.setActiveTool);
|
||||
const aiModel = useStore((state) => state.aiModel);
|
||||
const activeTemplateId = useStore((state) => state.activeTemplateId);
|
||||
const activeClass = useStore((state) => state.activeClass);
|
||||
|
||||
const effectiveTool = activeTool || storeActiveTool;
|
||||
|
||||
// Load the actual frame image
|
||||
const [image] = useImage(frameUrl || '');
|
||||
const [image] = useImage(frame?.url || '');
|
||||
const frameMasks = masks.filter((mask) => mask.frameId === frame?.id);
|
||||
const savedMaskCount = frameMasks.filter((mask) => mask.saveStatus === 'saved' || mask.saved).length;
|
||||
const draftMaskCount = frameMasks.filter((mask) => !mask.annotationId).length;
|
||||
const dirtyMaskCount = frameMasks.filter((mask) => mask.saveStatus === 'dirty').length;
|
||||
|
||||
useEffect(() => {
|
||||
const handleResize = () => {
|
||||
@@ -85,21 +93,44 @@ export function CanvasArea({ activeTool, frameUrl }: CanvasAreaProps) {
|
||||
};
|
||||
|
||||
const runInference = useCallback(async (promptPoints?: typeof points, promptBox?: { x1: number, y1: number, x2: number, y2: number }) => {
|
||||
if (!frame?.id) {
|
||||
console.warn('Inference skipped: no active frame');
|
||||
return;
|
||||
}
|
||||
|
||||
const imageWidth = frame.width || image?.naturalWidth || image?.width || 0;
|
||||
const imageHeight = frame.height || image?.naturalHeight || image?.height || 0;
|
||||
if (imageWidth <= 0 || imageHeight <= 0) {
|
||||
console.warn('Inference skipped: active frame dimensions are unavailable');
|
||||
return;
|
||||
}
|
||||
|
||||
setIsInferencing(true);
|
||||
try {
|
||||
const result = await predictMask({
|
||||
imageUrl: frameUrl || '',
|
||||
imageId: frame.id,
|
||||
imageWidth,
|
||||
imageHeight,
|
||||
model: aiModel,
|
||||
points: promptPoints?.map((p) => ({ x: p.x, y: p.y, type: p.type })),
|
||||
box: promptBox,
|
||||
});
|
||||
|
||||
result.masks.forEach((m) => {
|
||||
const label = activeClass?.name || m.label;
|
||||
const color = activeClass?.color || m.color;
|
||||
addMask({
|
||||
id: m.id,
|
||||
frameId: 'frame-1',
|
||||
frameId: frame.id,
|
||||
templateId: activeTemplateId || undefined,
|
||||
classId: activeClass?.id,
|
||||
className: activeClass?.name,
|
||||
classZIndex: activeClass?.zIndex,
|
||||
saveStatus: 'draft',
|
||||
saved: false,
|
||||
pathData: m.pathData,
|
||||
label: m.label,
|
||||
color: m.color,
|
||||
label,
|
||||
color,
|
||||
segmentation: m.segmentation,
|
||||
bbox: m.bbox,
|
||||
area: m.area,
|
||||
@@ -110,7 +141,33 @@ export function CanvasArea({ activeTool, frameUrl }: CanvasAreaProps) {
|
||||
} finally {
|
||||
setIsInferencing(false);
|
||||
}
|
||||
}, [addMask]);
|
||||
}, [activeClass, activeTemplateId, addMask, aiModel, frame?.height, frame?.id, frame?.width, image?.height, image?.naturalHeight, image?.naturalWidth, image?.width]);
|
||||
|
||||
const handleApplyActiveClass = () => {
|
||||
if (!frame?.id || !activeClass) return;
|
||||
setMasks(masks.map((mask) => {
|
||||
if (mask.frameId !== frame.id) return mask;
|
||||
return {
|
||||
...mask,
|
||||
templateId: activeTemplateId || mask.templateId,
|
||||
classId: activeClass.id,
|
||||
className: activeClass.name,
|
||||
classZIndex: activeClass.zIndex,
|
||||
label: activeClass.name,
|
||||
color: activeClass.color,
|
||||
saveStatus: mask.annotationId ? 'dirty' : 'draft',
|
||||
saved: Boolean(mask.annotationId) ? false : mask.saved,
|
||||
};
|
||||
}));
|
||||
};
|
||||
|
||||
const handleClearMasks = () => {
|
||||
if (onClearMasks) {
|
||||
onClearMasks();
|
||||
return;
|
||||
}
|
||||
clearMasks();
|
||||
};
|
||||
|
||||
const handleStageMouseDown = (e: any) => {
|
||||
if (effectiveTool === 'box_select') {
|
||||
@@ -199,7 +256,7 @@ export function CanvasArea({ activeTool, frameUrl }: CanvasAreaProps) {
|
||||
)}
|
||||
|
||||
{/* AI Returned Masks */}
|
||||
{masks.map((mask) => (
|
||||
{frameMasks.map((mask) => (
|
||||
<Group key={mask.id} opacity={0.5}>
|
||||
<Path
|
||||
data={mask.pathData}
|
||||
@@ -248,16 +305,29 @@ export function CanvasArea({ activeTool, frameUrl }: CanvasAreaProps) {
|
||||
<span>光标: {cursorPos.x.toFixed(2)}, {cursorPos.y.toFixed(2)}</span>
|
||||
<span>当前图层树: OBJECT_VEHICLE_01</span>
|
||||
<span>缩放比: {(scale * 100).toFixed(0)}%</span>
|
||||
<span>遮罩数: {masks.length}</span>
|
||||
<span>遮罩数: {frameMasks.length}</span>
|
||||
<span>已保存: {savedMaskCount}</span>
|
||||
<span>未保存: {draftMaskCount}</span>
|
||||
<span>待更新: {dirtyMaskCount}</span>
|
||||
</div>
|
||||
|
||||
{masks.length > 0 && (
|
||||
<button
|
||||
onClick={clearMasks}
|
||||
className="absolute bottom-4 right-4 text-xs bg-red-500/10 hover:bg-red-500/20 text-red-400 border border-red-500/20 px-3 py-1.5 rounded transition-colors"
|
||||
>
|
||||
清空遮罩
|
||||
</button>
|
||||
{frameMasks.length > 0 && (
|
||||
<div className="absolute bottom-4 right-4 flex gap-2">
|
||||
{activeClass && (
|
||||
<button
|
||||
onClick={handleApplyActiveClass}
|
||||
className="text-xs bg-cyan-500/10 hover:bg-cyan-500/20 text-cyan-300 border border-cyan-500/20 px-3 py-1.5 rounded transition-colors"
|
||||
>
|
||||
应用分类
|
||||
</button>
|
||||
)}
|
||||
<button
|
||||
onClick={handleClearMasks}
|
||||
className="text-xs bg-red-500/10 hover:bg-red-500/20 text-red-400 border border-red-500/20 px-3 py-1.5 rounded transition-colors"
|
||||
>
|
||||
清空遮罩
|
||||
</button>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
|
||||
115
src/components/Dashboard.test.tsx
Normal file
115
src/components/Dashboard.test.tsx
Normal file
@@ -0,0 +1,115 @@
|
||||
import { act, render, screen, waitFor } from '@testing-library/react';
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest';
|
||||
import { Dashboard } from './Dashboard';
|
||||
|
||||
const apiMock = vi.hoisted(() => ({
|
||||
getDashboardOverview: vi.fn(),
|
||||
}));
|
||||
|
||||
const wsMock = vi.hoisted(() => {
|
||||
const state = {
|
||||
callback: undefined as undefined | ((data: any) => void),
|
||||
connected: false,
|
||||
};
|
||||
return {
|
||||
state,
|
||||
progressWS: {
|
||||
connect: vi.fn(() => { state.connected = true; }),
|
||||
disconnect: vi.fn(() => { state.connected = false; }),
|
||||
isConnected: vi.fn(() => state.connected),
|
||||
onProgress: vi.fn((cb: (data: any) => void) => {
|
||||
state.callback = cb;
|
||||
return vi.fn();
|
||||
}),
|
||||
},
|
||||
};
|
||||
});
|
||||
|
||||
vi.mock('../lib/websocket', () => ({
|
||||
progressWS: wsMock.progressWS,
|
||||
}));
|
||||
|
||||
vi.mock('../lib/api', () => ({
|
||||
getDashboardOverview: apiMock.getDashboardOverview,
|
||||
}));
|
||||
|
||||
describe('Dashboard', () => {
|
||||
beforeEach(() => {
|
||||
vi.useRealTimers();
|
||||
vi.clearAllMocks();
|
||||
wsMock.state.connected = false;
|
||||
wsMock.state.callback = undefined;
|
||||
apiMock.getDashboardOverview.mockResolvedValue({
|
||||
summary: {
|
||||
project_count: 2,
|
||||
parsing_task_count: 1,
|
||||
annotation_count: 5,
|
||||
frame_count: 100,
|
||||
template_count: 3,
|
||||
system_load_percent: 12,
|
||||
},
|
||||
tasks: [
|
||||
{
|
||||
id: 'project-1',
|
||||
project_id: 1,
|
||||
name: '真实项目.mp4',
|
||||
progress: 60,
|
||||
status: 'pending',
|
||||
frame_count: 10,
|
||||
updated_at: '2026-05-01T00:00:00Z',
|
||||
},
|
||||
],
|
||||
activity: [
|
||||
{
|
||||
id: 'activity-1',
|
||||
kind: 'project',
|
||||
time: '2026-05-01T00:00:00Z',
|
||||
message: '项目状态: pending',
|
||||
project: '真实项目.mp4',
|
||||
},
|
||||
],
|
||||
});
|
||||
});
|
||||
|
||||
it('loads dashboard stats, tasks, and activity from the backend overview endpoint', async () => {
|
||||
render(<Dashboard />);
|
||||
|
||||
await waitFor(() => expect(apiMock.getDashboardOverview).toHaveBeenCalled());
|
||||
expect(screen.getByText('项目总数')).toBeInTheDocument();
|
||||
expect(screen.getByText('已存标注')).toBeInTheDocument();
|
||||
expect(screen.getByText('真实项目.mp4')).toBeInTheDocument();
|
||||
expect(screen.getByText('项目状态: pending')).toBeInTheDocument();
|
||||
expect(screen.queryByText('City_Driving_Dataset_004.mp4')).not.toBeInTheDocument();
|
||||
});
|
||||
|
||||
it('connects to the progress stream and updates progress tasks', async () => {
|
||||
render(<Dashboard />);
|
||||
|
||||
await waitFor(() => expect(wsMock.progressWS.connect).toHaveBeenCalled());
|
||||
|
||||
act(() => {
|
||||
wsMock.state.callback?.({
|
||||
type: 'progress',
|
||||
taskId: 'task-1',
|
||||
projectName: 'demo.mp4',
|
||||
progress: 44,
|
||||
status: '正在截取帧',
|
||||
});
|
||||
});
|
||||
|
||||
expect(await screen.findByText('demo.mp4')).toBeInTheDocument();
|
||||
expect(screen.getByText('44%')).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it('adds activity logs for complete and status messages', async () => {
|
||||
render(<Dashboard />);
|
||||
|
||||
act(() => {
|
||||
wsMock.state.callback?.({ type: 'status', message: 'Progress stream active' });
|
||||
wsMock.state.callback?.({ type: 'complete', taskId: '1', filename: 'done.mp4' });
|
||||
});
|
||||
|
||||
await waitFor(() => expect(screen.getByText('Progress stream active')).toBeInTheDocument());
|
||||
expect(screen.getByText('解析完成: done.mp4')).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
@@ -2,30 +2,68 @@ import React, { useState, useEffect } from 'react';
|
||||
import { Activity, Clock, Folders, CheckCircle2, Loader2 } from 'lucide-react';
|
||||
import { progressWS, type ProgressMessage } from '../lib/websocket';
|
||||
import { cn } from '../lib/utils';
|
||||
import { getDashboardOverview, type DashboardActivity, type DashboardOverview, type DashboardTask } from '../lib/api';
|
||||
|
||||
interface QueueTask {
|
||||
id: string;
|
||||
name: string;
|
||||
progress: number;
|
||||
status: string;
|
||||
}
|
||||
const emptySummary: DashboardOverview['summary'] = {
|
||||
project_count: 0,
|
||||
parsing_task_count: 0,
|
||||
annotation_count: 0,
|
||||
frame_count: 0,
|
||||
template_count: 0,
|
||||
system_load_percent: 0,
|
||||
};
|
||||
|
||||
export function Dashboard() {
|
||||
const [tasks, setTasks] = useState<QueueTask[]>([
|
||||
{ id: '1', name: 'City_Driving_Dataset_004.mp4', progress: 85, status: '正在截取帧 (30fps)' },
|
||||
{ id: '2', name: 'Pedestrian_Night_Vision_02.mkv', progress: 32, status: '正在截取帧 (60fps)' },
|
||||
{ id: '3', name: 'Drone_Mapping_Sector_7.avi', progress: 0, status: '队列排队等待中' },
|
||||
]);
|
||||
const [summary, setSummary] = useState<DashboardOverview['summary']>(emptySummary);
|
||||
const [tasks, setTasks] = useState<DashboardTask[]>([]);
|
||||
const [isConnected, setIsConnected] = useState(false);
|
||||
const [activityLog, setActivityLog] = useState<Array<{ time: string; message: string; project?: string }>>([
|
||||
{ time: '10 分钟前', message: '语义归档完成 54 帧', project: 'Highway_Data' },
|
||||
{ time: '25 分钟前', message: '项目解析开始', project: 'City_Driving_Dataset_004' },
|
||||
{ time: '1 小时前', message: '模板库更新: Cityscapes_v2', project: '系统' },
|
||||
{ time: '2 小时前', message: 'AI 推理完成 12 个实例', project: 'Nav_Cam_Left' },
|
||||
]);
|
||||
const [activityLog, setActivityLog] = useState<DashboardActivity[]>([]);
|
||||
const [isLoading, setIsLoading] = useState(true);
|
||||
const [loadError, setLoadError] = useState('');
|
||||
|
||||
useEffect(() => {
|
||||
let cancelled = false;
|
||||
|
||||
const loadOverview = () => {
|
||||
getDashboardOverview()
|
||||
.then((overview) => {
|
||||
if (cancelled) return;
|
||||
setSummary(overview.summary);
|
||||
setTasks((prev) => {
|
||||
if (prev.length === 0) return overview.tasks;
|
||||
const overviewIds = new Set(overview.tasks.map((task) => task.id));
|
||||
const wsOnly = prev.filter((task) => !task.id.startsWith('task-') && !overviewIds.has(task.id) && task.progress < 100);
|
||||
return [...overview.tasks, ...wsOnly];
|
||||
});
|
||||
setActivityLog((prev) => {
|
||||
if (prev.length === 0) return overview.activity;
|
||||
const byId = new Map(prev.map((item) => [item.id, item]));
|
||||
overview.activity.forEach((item) => byId.set(item.id, item));
|
||||
return Array.from(byId.values()).slice(0, 10);
|
||||
});
|
||||
setLoadError('');
|
||||
})
|
||||
.catch((err) => {
|
||||
console.error('Failed to load dashboard overview:', err);
|
||||
if (!cancelled) setLoadError('Dashboard 数据加载失败');
|
||||
})
|
||||
.finally(() => {
|
||||
if (!cancelled) setIsLoading(false);
|
||||
});
|
||||
};
|
||||
|
||||
loadOverview();
|
||||
const overviewInterval = setInterval(loadOverview, 5000);
|
||||
|
||||
return () => {
|
||||
cancelled = true;
|
||||
clearInterval(overviewInterval);
|
||||
};
|
||||
}, []);
|
||||
|
||||
useEffect(() => {
|
||||
let mounted = true;
|
||||
const taskTitle = (data: ProgressMessage) => data.filename || data.projectName || data.taskId || '后台任务';
|
||||
const timer = setTimeout(() => {
|
||||
if (mounted) progressWS.connect();
|
||||
}, 500);
|
||||
@@ -34,7 +72,7 @@ export function Dashboard() {
|
||||
if (!mounted) return;
|
||||
setIsConnected(progressWS.isConnected());
|
||||
|
||||
if (data.type === 'progress' && data.taskId && data.filename) {
|
||||
if (data.type === 'progress' && data.taskId) {
|
||||
setTasks((prev) => {
|
||||
const exists = prev.find((t) => t.id === data.taskId);
|
||||
if (exists) {
|
||||
@@ -48,9 +86,12 @@ export function Dashboard() {
|
||||
...prev,
|
||||
{
|
||||
id: data.taskId!,
|
||||
name: data.filename!,
|
||||
project_id: data.project_id ?? Number(data.task_id || 0),
|
||||
name: taskTitle(data),
|
||||
progress: data.progress ?? 0,
|
||||
status: data.status ?? '处理中',
|
||||
frame_count: 0,
|
||||
updated_at: new Date().toISOString(),
|
||||
},
|
||||
];
|
||||
});
|
||||
@@ -63,7 +104,7 @@ export function Dashboard() {
|
||||
)
|
||||
);
|
||||
setActivityLog((prev) => [
|
||||
{ time: '刚刚', message: `解析完成: ${data.filename || data.taskId}`, project: '系统' },
|
||||
{ id: `ws-complete-${Date.now()}`, kind: 'websocket', time: new Date().toISOString(), message: data.message || `解析完成: ${taskTitle(data)}`, project: data.projectName || '系统' },
|
||||
...prev.slice(0, 9),
|
||||
]);
|
||||
}
|
||||
@@ -71,14 +112,18 @@ export function Dashboard() {
|
||||
if (data.type === 'error' && data.taskId) {
|
||||
setTasks((prev) =>
|
||||
prev.map((t) =>
|
||||
t.id === data.taskId ? { ...t, status: `错误: ${data.message || '未知错误'}` } : t
|
||||
t.id === data.taskId ? { ...t, progress: data.progress ?? t.progress, status: `错误: ${data.error || data.message || '未知错误'}` } : t
|
||||
)
|
||||
);
|
||||
setActivityLog((prev) => [
|
||||
{ id: `ws-error-${Date.now()}`, kind: 'websocket', time: new Date().toISOString(), message: data.message || `解析失败: ${taskTitle(data)}`, project: data.projectName || '系统' },
|
||||
...prev.slice(0, 9),
|
||||
]);
|
||||
}
|
||||
|
||||
if (data.type === 'status') {
|
||||
setActivityLog((prev) => [
|
||||
{ time: '刚刚', message: data.message || '状态更新', project: '系统' },
|
||||
{ id: `ws-status-${Date.now()}`, kind: 'websocket', time: new Date().toISOString(), message: data.message || '状态更新', project: '系统' },
|
||||
...prev.slice(0, 9),
|
||||
]);
|
||||
}
|
||||
@@ -97,12 +142,24 @@ export function Dashboard() {
|
||||
}, []);
|
||||
|
||||
const stats = [
|
||||
{ label: '运行中项目', value: '14', icon: Folders, color: 'text-blue-400', bg: 'bg-blue-400/10' },
|
||||
{ label: '排队处理任务', value: tasks.length.toString(), icon: Clock, color: 'text-orange-400', bg: 'bg-orange-400/10' },
|
||||
{ label: '已归档批次', value: '128', icon: CheckCircle2, color: 'text-emerald-400', bg: 'bg-emerald-400/10' },
|
||||
{ label: '系统负载', value: '78%', icon: Activity, color: 'text-cyan-400', bg: 'bg-cyan-400/10' },
|
||||
{ label: '项目总数', value: summary.project_count.toString(), icon: Folders, color: 'text-blue-400', bg: 'bg-blue-400/10' },
|
||||
{ label: '处理任务', value: summary.parsing_task_count.toString(), icon: Clock, color: 'text-orange-400', bg: 'bg-orange-400/10' },
|
||||
{ label: '已存标注', value: summary.annotation_count.toString(), icon: CheckCircle2, color: 'text-emerald-400', bg: 'bg-emerald-400/10' },
|
||||
{ label: '系统负载', value: `${summary.system_load_percent}%`, icon: Activity, color: 'text-cyan-400', bg: 'bg-cyan-400/10' },
|
||||
];
|
||||
|
||||
function formatActivityTime(value: string | null): string {
|
||||
if (!value) return '未知时间';
|
||||
const date = new Date(value);
|
||||
if (Number.isNaN(date.getTime())) return value;
|
||||
return date.toLocaleString('zh-CN', {
|
||||
month: '2-digit',
|
||||
day: '2-digit',
|
||||
hour: '2-digit',
|
||||
minute: '2-digit',
|
||||
});
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="p-8 w-full h-full overflow-y-auto bg-[#0a0a0a]">
|
||||
<header className="mb-8">
|
||||
@@ -119,6 +176,7 @@ export function Dashboard() {
|
||||
</div>
|
||||
</div>
|
||||
<p className="text-gray-400 text-sm mt-1">系统全局数据吞吐状态与所有接入项目进度实时洞察驾驶舱。</p>
|
||||
{loadError && <p className="text-red-400 text-xs mt-2">{loadError}</p>}
|
||||
</header>
|
||||
|
||||
<div className="grid grid-cols-1 md:grid-cols-2 lg:grid-cols-4 gap-4 mb-8">
|
||||
@@ -140,8 +198,11 @@ export function Dashboard() {
|
||||
|
||||
<div className="grid grid-cols-1 lg:grid-cols-3 gap-6">
|
||||
<div className="lg:col-span-2 bg-[#111] border border-white/5 rounded-xl p-6 min-h-[400px]">
|
||||
<h2 className="text-sm font-medium text-gray-400 uppercase tracking-widest mb-6">解析队列 (FFmpeg 挂起任务)</h2>
|
||||
<h2 className="text-sm font-medium text-gray-400 uppercase tracking-widest mb-6">解析队列 (后台任务)</h2>
|
||||
<div className="space-y-4">
|
||||
{isLoading && (
|
||||
<div className="text-sm text-gray-500 text-center py-12">正在读取后端 Dashboard 数据...</div>
|
||||
)}
|
||||
{tasks.map((task) => (
|
||||
<div key={task.id} className="bg-[#0d0d0d] border border-white/5 p-4 rounded-lg">
|
||||
<div className="flex justify-between items-center mb-2">
|
||||
@@ -152,7 +213,7 @@ export function Dashboard() {
|
||||
<div className="h-full bg-gradient-to-r from-cyan-600 to-cyan-400 rounded-full transition-all duration-500" style={{ width: `${task.progress}%` }} />
|
||||
</div>
|
||||
<div className="text-xs text-gray-500 flex items-center gap-2">
|
||||
{task.status === '已完成' ? (
|
||||
{task.status === '已完成' || task.progress >= 100 ? (
|
||||
<CheckCircle2 size={12} className="text-emerald-400" />
|
||||
) : task.status.includes('错误') ? (
|
||||
<span className="text-red-400">●</span>
|
||||
@@ -160,10 +221,11 @@ export function Dashboard() {
|
||||
<Loader2 size={12} className="text-cyan-400 animate-spin" />
|
||||
)}
|
||||
{task.status}
|
||||
<span className="text-gray-600">帧: {task.frame_count}</span>
|
||||
</div>
|
||||
</div>
|
||||
))}
|
||||
{tasks.length === 0 && (
|
||||
{!isLoading && tasks.length === 0 && (
|
||||
<div className="text-sm text-gray-500 text-center py-12">当前无处理任务</div>
|
||||
)}
|
||||
</div>
|
||||
@@ -172,16 +234,22 @@ export function Dashboard() {
|
||||
<div className="bg-[#111] border border-white/5 rounded-xl p-6 min-h-[400px]">
|
||||
<h2 className="text-sm font-medium text-gray-400 uppercase tracking-widest mb-6">近期实时流转记录</h2>
|
||||
<div className="space-y-6 relative before:absolute before:inset-0 before:ml-[11px] before:-translate-x-px md:before:mx-auto md:before:translate-x-0 before:h-full before:w-0.5 before:bg-gradient-to-b before:from-transparent before:via-white/10 before:to-transparent">
|
||||
{activityLog.map((log, i) => (
|
||||
<div key={i} className="relative flex items-center justify-between md:justify-normal md:odd:flex-row-reverse group is-active">
|
||||
{isLoading && (
|
||||
<div className="text-sm text-gray-500 text-center py-12">正在读取近期流转记录...</div>
|
||||
)}
|
||||
{activityLog.map((log) => (
|
||||
<div key={log.id} className="relative flex items-center justify-between md:justify-normal md:odd:flex-row-reverse group is-active">
|
||||
<div className="flex items-center justify-center w-6 h-6 rounded-full border border-white/10 bg-[#111] group-[.is-active]:bg-cyan-500 group-[.is-active]:border-cyan-400 text-slate-500 group-[.is-active]:text-black shadow shrink-0 md:order-1 md:group-odd:-translate-x-1/2 md:group-even:translate-x-1/2 z-10" />
|
||||
<div className="w-[calc(100%-4rem)] md:w-[calc(50%-2.5rem)] bg-[#0d0d0d] p-3 rounded border border-white/5">
|
||||
<div className="text-xs text-gray-400 mb-1">{log.time}</div>
|
||||
<div className="text-xs text-gray-400 mb-1">{formatActivityTime(log.time)}</div>
|
||||
<div className="text-sm font-medium text-gray-200">{log.message}</div>
|
||||
<div className="text-xs text-gray-500">归属项目: {log.project}</div>
|
||||
</div>
|
||||
</div>
|
||||
))}
|
||||
{!isLoading && activityLog.length === 0 && (
|
||||
<div className="text-sm text-gray-500 text-center py-12">暂无近期流转记录</div>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
62
src/components/FrameTimeline.test.tsx
Normal file
62
src/components/FrameTimeline.test.tsx
Normal file
@@ -0,0 +1,62 @@
|
||||
import { act, fireEvent, render, screen } from '@testing-library/react';
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest';
|
||||
import { resetStore } from '../test/storeTestUtils';
|
||||
import { useStore } from '../store/useStore';
|
||||
import { FrameTimeline } from './FrameTimeline';
|
||||
|
||||
describe('FrameTimeline', () => {
|
||||
beforeEach(() => {
|
||||
resetStore();
|
||||
vi.useRealTimers();
|
||||
});
|
||||
|
||||
it('renders empty state when no frames are loaded', () => {
|
||||
render(<FrameTimeline />);
|
||||
|
||||
expect(screen.getByText('暂无帧数据')).toBeInTheDocument();
|
||||
expect(screen.getByText('0')).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it('changes the current frame through thumbnails and range input', () => {
|
||||
useStore.setState({
|
||||
frames: [
|
||||
{ id: 'f1', projectId: 'p1', index: 0, url: '/1.jpg', width: 640, height: 360 },
|
||||
{ id: 'f2', projectId: 'p1', index: 1, url: '/2.jpg', width: 640, height: 360 },
|
||||
{ id: 'f3', projectId: 'p1', index: 2, url: '/3.jpg', width: 640, height: 360 },
|
||||
],
|
||||
});
|
||||
|
||||
render(<FrameTimeline />);
|
||||
fireEvent.click(screen.getByAltText('frame-1'));
|
||||
expect(useStore.getState().currentFrameIndex).toBe(1);
|
||||
|
||||
fireEvent.change(screen.getByRole('slider'), { target: { value: '3' } });
|
||||
expect(useStore.getState().currentFrameIndex).toBe(2);
|
||||
});
|
||||
|
||||
it('plays forward using the project parse fps and stops at the end', () => {
|
||||
vi.useFakeTimers();
|
||||
useStore.setState({
|
||||
currentProject: { id: 'p1', name: 'P', status: 'ready', parse_fps: 10 },
|
||||
frames: [
|
||||
{ id: 'f1', projectId: 'p1', index: 0, url: '/1.jpg', width: 640, height: 360 },
|
||||
{ id: 'f2', projectId: 'p1', index: 1, url: '/2.jpg', width: 640, height: 360 },
|
||||
],
|
||||
});
|
||||
|
||||
const { container } = render(<FrameTimeline />);
|
||||
fireEvent.click(container.querySelector('button')!);
|
||||
|
||||
act(() => {
|
||||
vi.advanceTimersByTime(100);
|
||||
});
|
||||
|
||||
expect(useStore.getState().currentFrameIndex).toBe(1);
|
||||
|
||||
act(() => {
|
||||
vi.advanceTimersByTime(100);
|
||||
});
|
||||
|
||||
expect(screen.getByText('播放序列 (F5)')).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
@@ -1,16 +1,42 @@
|
||||
import React, { useState } from 'react';
|
||||
import React, { useEffect, useMemo, useState } from 'react';
|
||||
import { Play, Pause } from 'lucide-react';
|
||||
import { cn } from '../lib/utils';
|
||||
import { useStore } from '../store/useStore';
|
||||
|
||||
export function FrameTimeline() {
|
||||
const frames = useStore((state) => state.frames);
|
||||
const currentProject = useStore((state) => state.currentProject);
|
||||
const currentFrameIndex = useStore((state) => state.currentFrameIndex);
|
||||
const setCurrentFrame = useStore((state) => state.setCurrentFrame);
|
||||
const [isPlaying, setIsPlaying] = useState(false);
|
||||
|
||||
const totalFrames = frames.length;
|
||||
const currentFrame = totalFrames > 0 ? currentFrameIndex + 1 : 0;
|
||||
const playbackFps = useMemo(() => {
|
||||
const fps = currentProject?.parse_fps || currentProject?.original_fps || 12;
|
||||
return Math.min(Math.max(fps, 1), 30);
|
||||
}, [currentProject?.original_fps, currentProject?.parse_fps]);
|
||||
|
||||
useEffect(() => {
|
||||
if (!isPlaying || totalFrames <= 1) return;
|
||||
|
||||
const timer = window.setTimeout(() => {
|
||||
if (currentFrameIndex >= totalFrames - 1) {
|
||||
setIsPlaying(false);
|
||||
return;
|
||||
}
|
||||
|
||||
setCurrentFrame(currentFrameIndex + 1);
|
||||
}, 1000 / playbackFps);
|
||||
|
||||
return () => window.clearTimeout(timer);
|
||||
}, [currentFrameIndex, isPlaying, playbackFps, setCurrentFrame, totalFrames]);
|
||||
|
||||
useEffect(() => {
|
||||
if (totalFrames === 0) {
|
||||
setIsPlaying(false);
|
||||
}
|
||||
}, [totalFrames]);
|
||||
|
||||
// show frames around current frame
|
||||
const frameWindow = 20;
|
||||
@@ -45,8 +71,14 @@ export function FrameTimeline() {
|
||||
<div className="flex-1 flex items-center px-4 gap-6">
|
||||
<div className="flex flex-col items-center gap-2 px-4 border-r border-white/10 shrink-0">
|
||||
<button
|
||||
className="p-2 rounded-full bg-white/5 text-white hover:bg-white/10"
|
||||
onClick={() => setIsPlaying(!isPlaying)}
|
||||
className="p-2 rounded-full bg-white/5 text-white hover:bg-white/10 disabled:opacity-40 disabled:cursor-not-allowed"
|
||||
disabled={totalFrames <= 1}
|
||||
onClick={() => {
|
||||
if (currentFrameIndex >= totalFrames - 1) {
|
||||
setCurrentFrame(0);
|
||||
}
|
||||
setIsPlaying(!isPlaying);
|
||||
}}
|
||||
>
|
||||
{isPlaying ? <Pause size={20} fill="currentColor" /> : <Play size={20} fill="currentColor" />}
|
||||
</button>
|
||||
|
||||
42
src/components/Login.test.tsx
Normal file
42
src/components/Login.test.tsx
Normal file
@@ -0,0 +1,42 @@
|
||||
import { fireEvent, render, screen, waitFor } from '@testing-library/react';
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest';
|
||||
import { resetStore } from '../test/storeTestUtils';
|
||||
import { useStore } from '../store/useStore';
|
||||
import { Login } from './Login';
|
||||
|
||||
const apiMock = vi.hoisted(() => ({
|
||||
login: vi.fn(),
|
||||
}));
|
||||
|
||||
vi.mock('../lib/api', () => ({
|
||||
login: apiMock.login,
|
||||
}));
|
||||
|
||||
describe('Login', () => {
|
||||
beforeEach(() => {
|
||||
resetStore();
|
||||
vi.clearAllMocks();
|
||||
});
|
||||
|
||||
it('logs in with the development credentials and stores the token', async () => {
|
||||
apiMock.login.mockResolvedValueOnce({ token: 'fake-jwt-token-for-admin' });
|
||||
|
||||
render(<Login />);
|
||||
fireEvent.click(screen.getByRole('button', { name: '安全登录' }));
|
||||
|
||||
await waitFor(() => expect(apiMock.login).toHaveBeenCalledWith('admin', '123456'));
|
||||
expect(useStore.getState().isAuthenticated).toBe(true);
|
||||
expect(localStorage.getItem('token')).toBe('fake-jwt-token-for-admin');
|
||||
});
|
||||
|
||||
it('shows backend login errors', async () => {
|
||||
apiMock.login.mockRejectedValueOnce({ response: { data: { detail: 'Invalid credentials' } } });
|
||||
|
||||
render(<Login />);
|
||||
fireEvent.change(screen.getByDisplayValue('admin'), { target: { value: 'bad' } });
|
||||
fireEvent.click(screen.getByRole('button', { name: '安全登录' }));
|
||||
|
||||
expect(await screen.findByText('Invalid credentials')).toBeInTheDocument();
|
||||
expect(useStore.getState().isAuthenticated).toBe(false);
|
||||
});
|
||||
});
|
||||
45
src/components/ModelStatusBadge.test.tsx
Normal file
45
src/components/ModelStatusBadge.test.tsx
Normal file
@@ -0,0 +1,45 @@
|
||||
import { render, screen, waitFor } from '@testing-library/react';
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest';
|
||||
import { resetStore } from '../test/storeTestUtils';
|
||||
import { useStore } from '../store/useStore';
|
||||
import { ModelStatusBadge } from './ModelStatusBadge';
|
||||
|
||||
const apiMock = vi.hoisted(() => ({
|
||||
getAiModelStatus: vi.fn(),
|
||||
}));
|
||||
|
||||
vi.mock('../lib/api', () => ({
|
||||
getAiModelStatus: apiMock.getAiModelStatus,
|
||||
}));
|
||||
|
||||
describe('ModelStatusBadge', () => {
|
||||
beforeEach(() => {
|
||||
resetStore();
|
||||
vi.clearAllMocks();
|
||||
apiMock.getAiModelStatus.mockResolvedValue({
|
||||
selected_model: 'sam2',
|
||||
gpu: { available: true, device: 'cuda', name: 'RTX 4090', torch_available: true },
|
||||
models: [
|
||||
{ id: 'sam2', label: 'SAM 2', available: true, loaded: false, device: 'cuda', supports: ['point', 'box'], message: 'SAM 2 ready', package_available: true, checkpoint_exists: true, python_ok: true, torch_ok: true, cuda_required: false },
|
||||
{ id: 'sam3', label: 'SAM 3', available: false, loaded: false, device: 'cuda', supports: ['semantic'], message: 'SAM 3 missing runtime', package_available: false, checkpoint_exists: false, python_ok: false, torch_ok: true, cuda_required: true },
|
||||
],
|
||||
});
|
||||
});
|
||||
|
||||
it('loads real model status for the selected model', async () => {
|
||||
render(<ModelStatusBadge />);
|
||||
|
||||
expect(await screen.findByText('SAM 2 可用')).toBeInTheDocument();
|
||||
expect(apiMock.getAiModelStatus).toHaveBeenCalledWith('sam2');
|
||||
});
|
||||
|
||||
it('shows unavailable state when SAM3 is selected but not runnable', async () => {
|
||||
useStore.getState().setAiModel('sam3');
|
||||
|
||||
render(<ModelStatusBadge />);
|
||||
|
||||
await waitFor(() => expect(apiMock.getAiModelStatus).toHaveBeenCalledWith('sam3'));
|
||||
expect(await screen.findByText('SAM 3 不可用')).toBeInTheDocument();
|
||||
expect(screen.getByTitle('SAM 3 missing runtime')).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
56
src/components/ModelStatusBadge.tsx
Normal file
56
src/components/ModelStatusBadge.tsx
Normal file
@@ -0,0 +1,56 @@
|
||||
import React, { useEffect, useState } from 'react';
|
||||
import { Cpu, Loader2 } from 'lucide-react';
|
||||
import { getAiModelStatus, type AiRuntimeStatus } from '../lib/api';
|
||||
import { cn } from '../lib/utils';
|
||||
import { useStore } from '../store/useStore';
|
||||
|
||||
interface ModelStatusBadgeProps {
|
||||
compact?: boolean;
|
||||
}
|
||||
|
||||
export function ModelStatusBadge({ compact = false }: ModelStatusBadgeProps) {
|
||||
const aiModel = useStore((state) => state.aiModel);
|
||||
const [status, setStatus] = useState<AiRuntimeStatus | null>(null);
|
||||
const [isLoading, setIsLoading] = useState(true);
|
||||
|
||||
useEffect(() => {
|
||||
let cancelled = false;
|
||||
setIsLoading(true);
|
||||
getAiModelStatus(aiModel)
|
||||
.then((data) => {
|
||||
if (!cancelled) setStatus(data);
|
||||
})
|
||||
.catch(() => {
|
||||
if (!cancelled) setStatus(null);
|
||||
})
|
||||
.finally(() => {
|
||||
if (!cancelled) setIsLoading(false);
|
||||
});
|
||||
return () => {
|
||||
cancelled = true;
|
||||
};
|
||||
}, [aiModel]);
|
||||
|
||||
const model = status?.models.find((item) => item.id === aiModel);
|
||||
const ready = Boolean(model?.available);
|
||||
const gpuReady = Boolean(status?.gpu.available);
|
||||
const label = compact
|
||||
? (gpuReady ? 'GPU' : 'CPU')
|
||||
: `${model?.label || aiModel.toUpperCase()} ${ready ? '可用' : '不可用'}`;
|
||||
|
||||
return (
|
||||
<div
|
||||
className={cn(
|
||||
"inline-flex items-center gap-1.5 rounded border font-mono uppercase",
|
||||
compact ? "w-8 h-8 justify-center text-[9px]" : "px-2 py-0.5 text-[10px]",
|
||||
ready
|
||||
? "bg-emerald-500/10 text-emerald-400 border-emerald-500/20"
|
||||
: "bg-amber-500/10 text-amber-400 border-amber-500/20"
|
||||
)}
|
||||
title={model?.message || 'AI 模型状态读取中'}
|
||||
>
|
||||
{isLoading ? <Loader2 size={compact ? 12 : 10} className="animate-spin" /> : <Cpu size={compact ? 12 : 10} />}
|
||||
<span>{label}</span>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
60
src/components/OntologyInspector.test.tsx
Normal file
60
src/components/OntologyInspector.test.tsx
Normal file
@@ -0,0 +1,60 @@
|
||||
import { fireEvent, render, screen, within } from '@testing-library/react';
|
||||
import { beforeEach, describe, expect, it } from 'vitest';
|
||||
import { resetStore } from '../test/storeTestUtils';
|
||||
import { useStore } from '../store/useStore';
|
||||
import { OntologyInspector } from './OntologyInspector';
|
||||
|
||||
describe('OntologyInspector', () => {
|
||||
beforeEach(() => {
|
||||
resetStore();
|
||||
useStore.setState({
|
||||
templates: [
|
||||
{
|
||||
id: 't1',
|
||||
name: '腹腔镜模板',
|
||||
classes: [
|
||||
{ id: 'c1', name: '胆囊', color: '#ff0000', zIndex: 20, category: '器官' },
|
||||
{ id: 'c2', name: '肝脏', color: '#00ff00', zIndex: 10, category: '器官' },
|
||||
],
|
||||
rules: [],
|
||||
},
|
||||
],
|
||||
});
|
||||
});
|
||||
|
||||
it('shows template classes and changes the active template', () => {
|
||||
render(<OntologyInspector />);
|
||||
|
||||
fireEvent.change(screen.getByRole('combobox'), { target: { value: 't1' } });
|
||||
|
||||
expect(useStore.getState().activeTemplateId).toBe('t1');
|
||||
expect(screen.getByText('胆囊')).toBeInTheDocument();
|
||||
expect(screen.getByText('肝脏')).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it('selects a concrete class for subsequent masks', () => {
|
||||
render(<OntologyInspector />);
|
||||
|
||||
fireEvent.click(screen.getByText('胆囊'));
|
||||
|
||||
expect(useStore.getState().activeClassId).toBe('c1');
|
||||
expect(useStore.getState().activeClass).toEqual(expect.objectContaining({
|
||||
id: 'c1',
|
||||
name: '胆囊',
|
||||
zIndex: 20,
|
||||
}));
|
||||
});
|
||||
|
||||
it('adds custom classes locally without backend persistence', () => {
|
||||
const { container } = render(<OntologyInspector />);
|
||||
const customSection = screen.getByText('自定义分类').parentElement!;
|
||||
fireEvent.click(within(customSection).getByRole('button'));
|
||||
fireEvent.change(screen.getByPlaceholderText('分类名称'), { target: { value: '新局部分类' } });
|
||||
fireEvent.keyDown(screen.getByPlaceholderText('分类名称'), { key: 'Enter' });
|
||||
|
||||
expect(screen.getAllByText('新局部分类')).toHaveLength(2);
|
||||
expect(useStore.getState().activeClass).toEqual(expect.objectContaining({ name: '新局部分类' }));
|
||||
expect(useStore.getState().templates[0].classes).toHaveLength(2);
|
||||
expect(container).toHaveTextContent('2 个分类来自模板 + 1 个自定义');
|
||||
});
|
||||
});
|
||||
@@ -2,11 +2,16 @@ import React, { useState } from 'react';
|
||||
import { Layers, ChevronDown, Tag, Eye, Plus, X } from 'lucide-react';
|
||||
import { useStore } from '../store/useStore';
|
||||
import type { TemplateClass } from '../store/useStore';
|
||||
import { cn } from '../lib/utils';
|
||||
import { getActiveTemplate } from '../lib/templateSelection';
|
||||
|
||||
export function OntologyInspector() {
|
||||
const templates = useStore((state) => state.templates);
|
||||
const activeTemplateId = useStore((state) => state.activeTemplateId);
|
||||
const activeClassId = useStore((state) => state.activeClassId);
|
||||
const activeClass = useStore((state) => state.activeClass);
|
||||
const setActiveTemplateId = useStore((state) => state.setActiveTemplateId);
|
||||
const setActiveClass = useStore((state) => state.setActiveClass);
|
||||
|
||||
// Project-level custom classes (in addition to template classes)
|
||||
const [customClasses, setCustomClasses] = useState<TemplateClass[]>([]);
|
||||
@@ -14,10 +19,17 @@ export function OntologyInspector() {
|
||||
const [newClassName, setNewClassName] = useState('');
|
||||
const [newClassColor, setNewClassColor] = useState('#06b6d4');
|
||||
|
||||
const activeTemplate = templates.find((t) => t.id === activeTemplateId) || templates[0] || null;
|
||||
const activeTemplate = getActiveTemplate(templates, activeTemplateId);
|
||||
const templateClasses = activeTemplate?.classes || [];
|
||||
const allClasses = [...templateClasses, ...customClasses].sort((a, b) => b.zIndex - a.zIndex);
|
||||
|
||||
const handleSelectClass = (templateClass: TemplateClass) => {
|
||||
if (activeTemplate && !activeTemplateId) {
|
||||
setActiveTemplateId(activeTemplate.id);
|
||||
}
|
||||
setActiveClass(templateClass);
|
||||
};
|
||||
|
||||
const handleAddCustom = () => {
|
||||
if (!newClassName.trim()) return;
|
||||
const maxZ = allClasses.length > 0 ? Math.max(...allClasses.map((c) => c.zIndex)) : 0;
|
||||
@@ -29,6 +41,7 @@ export function OntologyInspector() {
|
||||
category: '自定义',
|
||||
};
|
||||
setCustomClasses([...customClasses, newClass]);
|
||||
handleSelectClass(newClass);
|
||||
setNewClassName('');
|
||||
setShowAddForm(false);
|
||||
};
|
||||
@@ -47,7 +60,10 @@ export function OntologyInspector() {
|
||||
<div className="relative">
|
||||
<select
|
||||
value={activeTemplate?.id || ''}
|
||||
onChange={(e) => setActiveTemplateId(e.target.value || null)}
|
||||
onChange={(e) => {
|
||||
setActiveTemplateId(e.target.value || null);
|
||||
setActiveClass(null);
|
||||
}}
|
||||
className="w-full bg-[#1a1a1a] border border-white/10 rounded-lg px-3 py-2 text-xs text-gray-300 appearance-none cursor-pointer focus:outline-none focus:border-cyan-500/50"
|
||||
>
|
||||
<option value="">-- 选择模板 --</option>
|
||||
@@ -73,7 +89,14 @@ export function OntologyInspector() {
|
||||
<div className="space-y-2">
|
||||
{allClasses.map(cls => (
|
||||
<div key={cls.id} className="flex flex-col gap-1">
|
||||
<div className="flex items-center justify-between p-2 rounded bg-white/5 hover:bg-white/10 cursor-pointer group transition-colors">
|
||||
<button
|
||||
type="button"
|
||||
onClick={() => handleSelectClass(cls)}
|
||||
className={cn(
|
||||
'flex items-center justify-between p-2 rounded bg-white/5 hover:bg-white/10 cursor-pointer group transition-colors text-left border',
|
||||
activeClassId === cls.id ? 'border-cyan-500/50 bg-cyan-500/10' : 'border-transparent',
|
||||
)}
|
||||
>
|
||||
<div className="flex items-center gap-2">
|
||||
<span className="w-2.5 h-2.5 rounded-sm" style={{ backgroundColor: cls.color }} />
|
||||
<span className="text-xs font-medium text-gray-200">{cls.name}</span>
|
||||
@@ -82,7 +105,7 @@ export function OntologyInspector() {
|
||||
<span className="text-[10px] text-gray-500 font-mono">z:{cls.zIndex}</span>
|
||||
<Eye size={14} className="text-gray-500 group-hover:text-gray-300" />
|
||||
</div>
|
||||
</div>
|
||||
</button>
|
||||
</div>
|
||||
))}
|
||||
{allClasses.length === 0 && (
|
||||
@@ -136,7 +159,9 @@ export function OntologyInspector() {
|
||||
<div className="bg-white/5 rounded-lg p-3">
|
||||
<div className="flex items-center gap-2 mb-3">
|
||||
<Tag size={12} className="text-cyan-400" />
|
||||
<span className="text-xs font-semibold text-gray-200">{activeTemplate?.name || '未选择'}</span>
|
||||
<span className="text-xs font-semibold text-gray-200">
|
||||
{activeClass?.name || activeTemplate?.name || '未选择'}
|
||||
</span>
|
||||
</div>
|
||||
<div className="space-y-3">
|
||||
<div className="space-y-1">
|
||||
|
||||
92
src/components/ProjectLibrary.test.tsx
Normal file
92
src/components/ProjectLibrary.test.tsx
Normal file
@@ -0,0 +1,92 @@
|
||||
import { fireEvent, render, screen, waitFor } from '@testing-library/react';
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest';
|
||||
import { resetStore } from '../test/storeTestUtils';
|
||||
import { useStore } from '../store/useStore';
|
||||
import { ProjectLibrary } from './ProjectLibrary';
|
||||
|
||||
const apiMock = vi.hoisted(() => ({
|
||||
getProjects: vi.fn(),
|
||||
createProject: vi.fn(),
|
||||
uploadMedia: vi.fn(),
|
||||
parseMedia: vi.fn(),
|
||||
uploadDicomBatch: vi.fn(),
|
||||
}));
|
||||
|
||||
vi.mock('../lib/api', () => ({
|
||||
getProjects: apiMock.getProjects,
|
||||
createProject: apiMock.createProject,
|
||||
uploadMedia: apiMock.uploadMedia,
|
||||
parseMedia: apiMock.parseMedia,
|
||||
uploadDicomBatch: apiMock.uploadDicomBatch,
|
||||
}));
|
||||
|
||||
describe('ProjectLibrary', () => {
|
||||
beforeEach(() => {
|
||||
resetStore();
|
||||
vi.clearAllMocks();
|
||||
apiMock.getProjects.mockResolvedValue([]);
|
||||
});
|
||||
|
||||
it('loads projects and selects one into the workspace', async () => {
|
||||
const onProjectSelect = vi.fn();
|
||||
apiMock.getProjects.mockResolvedValueOnce([
|
||||
{ id: 'p1', name: 'Demo Project', status: 'ready', frames: 3, fps: '30FPS' },
|
||||
]);
|
||||
|
||||
render(<ProjectLibrary onProjectSelect={onProjectSelect} />);
|
||||
|
||||
fireEvent.click(await screen.findByText('Demo Project'));
|
||||
expect(useStore.getState().currentProject?.id).toBe('p1');
|
||||
expect(onProjectSelect).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('creates a new project from the modal', async () => {
|
||||
apiMock.createProject.mockResolvedValueOnce({ id: 'p2', name: 'New Project', status: 'pending' });
|
||||
|
||||
render(<ProjectLibrary onProjectSelect={vi.fn()} />);
|
||||
fireEvent.click(screen.getByText('新建项目'));
|
||||
fireEvent.change(screen.getByPlaceholderText('输入项目名称'), { target: { value: 'New Project' } });
|
||||
fireEvent.change(screen.getByPlaceholderText('输入项目描述'), { target: { value: 'desc' } });
|
||||
fireEvent.click(screen.getByRole('button', { name: '创建' }));
|
||||
|
||||
await waitFor(() => expect(apiMock.createProject).toHaveBeenCalledWith({
|
||||
name: 'New Project',
|
||||
description: 'desc',
|
||||
}));
|
||||
expect(useStore.getState().projects[0]).toEqual(expect.objectContaining({ id: 'p2' }));
|
||||
});
|
||||
|
||||
it('imports video by creating a project, uploading media, parsing frames and refreshing projects', async () => {
|
||||
apiMock.createProject.mockResolvedValueOnce({ id: 'p3', name: 'clip.mp4', status: 'pending' });
|
||||
apiMock.uploadMedia.mockResolvedValueOnce({ url: 'http://file', id: 'object' });
|
||||
apiMock.parseMedia.mockResolvedValueOnce({ frames_extracted: 1 });
|
||||
apiMock.getProjects.mockResolvedValue([]);
|
||||
|
||||
const { container } = render(<ProjectLibrary onProjectSelect={vi.fn()} />);
|
||||
const input = container.querySelector('input[accept="video/*"]') as HTMLInputElement;
|
||||
const file = new File(['video'], 'clip.mp4', { type: 'video/mp4' });
|
||||
fireEvent.change(input, { target: { files: [file] } });
|
||||
fireEvent.click(await screen.findByRole('button', { name: '开始导入' }));
|
||||
|
||||
await waitFor(() => expect(apiMock.createProject).toHaveBeenCalledWith(expect.objectContaining({
|
||||
name: 'clip.mp4',
|
||||
parse_fps: 30,
|
||||
})));
|
||||
expect(apiMock.uploadMedia).toHaveBeenCalledWith(file, 'p3');
|
||||
expect(apiMock.parseMedia).toHaveBeenCalledWith('p3');
|
||||
});
|
||||
|
||||
it('imports only valid DICOM files and parses the returned project', async () => {
|
||||
apiMock.uploadDicomBatch.mockResolvedValueOnce({ project_id: 77, uploaded_count: 1, message: 'ok' });
|
||||
apiMock.parseMedia.mockResolvedValueOnce({ frames_extracted: 1 });
|
||||
|
||||
const { container } = render(<ProjectLibrary onProjectSelect={vi.fn()} />);
|
||||
const input = container.querySelector('input[accept=".dcm"]') as HTMLInputElement;
|
||||
const dcm = new File(['dcm'], 'scan.dcm', { type: 'application/dicom' });
|
||||
const ignored = new File(['txt'], 'notes.txt', { type: 'text/plain' });
|
||||
fireEvent.change(input, { target: { files: [dcm, ignored] } });
|
||||
|
||||
await waitFor(() => expect(apiMock.uploadDicomBatch).toHaveBeenCalledWith([dcm]));
|
||||
expect(apiMock.parseMedia).toHaveBeenCalledWith('77');
|
||||
});
|
||||
});
|
||||
@@ -212,11 +212,11 @@ export function ProjectLibrary({ onProjectSelect }: ProjectLibraryProps) {
|
||||
{proj.source_type === 'dicom' ? 'DICOM' : (proj.fps || '30FPS')}
|
||||
</span>
|
||||
<span className="backdrop-blur-md bg-black/40 text-gray-200 text-[10px] px-2 py-1 rounded border border-white/10 flex items-center gap-1 uppercase tracking-widest">
|
||||
{proj.status === 'Ready' ? (
|
||||
{proj.status === 'ready' ? (
|
||||
<><div className="w-1.5 h-1.5 bg-emerald-500 rounded-full" /> 已就绪</>
|
||||
) : proj.status === 'Parsing' ? (
|
||||
) : proj.status === 'parsing' ? (
|
||||
<><div className="w-1.5 h-1.5 bg-amber-500 rounded-full animate-pulse" /> 解析拆帧中</>
|
||||
) : proj.status === 'Error' ? (
|
||||
) : proj.status === 'error' ? (
|
||||
<><div className="w-1.5 h-1.5 bg-red-500 rounded-full" /> 异常</>
|
||||
) : (
|
||||
<><div className="w-1.5 h-1.5 bg-blue-500 rounded-full" /> 待处理</>
|
||||
|
||||
@@ -2,6 +2,7 @@ import React from 'react';
|
||||
import { Home, FolderOpen, Edit3, LayoutTemplate, BrainCircuit } from 'lucide-react';
|
||||
import { cn } from '../lib/utils';
|
||||
import type { ActiveModule } from '../App';
|
||||
import { ModelStatusBadge } from './ModelStatusBadge';
|
||||
|
||||
interface SidebarProps {
|
||||
activeModule: ActiveModule;
|
||||
@@ -47,9 +48,7 @@ export function Sidebar({ activeModule, setActiveModule }: SidebarProps) {
|
||||
})}
|
||||
</nav>
|
||||
<div className="mt-auto mb-4 flex flex-col gap-4">
|
||||
<div className="w-8 h-8 rounded-full border border-cyan-500/50 flex items-center justify-center text-[10px] text-cyan-400 font-bold cursor-pointer transition-all hover:bg-cyan-500/10">
|
||||
GPU
|
||||
</div>
|
||||
<ModelStatusBadge compact />
|
||||
</div>
|
||||
</aside>
|
||||
);
|
||||
|
||||
85
src/components/TemplateRegistry.test.tsx
Normal file
85
src/components/TemplateRegistry.test.tsx
Normal file
@@ -0,0 +1,85 @@
|
||||
import { fireEvent, render, screen, waitFor } from '@testing-library/react';
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest';
|
||||
import { resetStore } from '../test/storeTestUtils';
|
||||
import { useStore } from '../store/useStore';
|
||||
import { TemplateRegistry } from './TemplateRegistry';
|
||||
|
||||
const apiMock = vi.hoisted(() => ({
|
||||
getTemplates: vi.fn(),
|
||||
createTemplate: vi.fn(),
|
||||
updateTemplate: vi.fn(),
|
||||
deleteTemplate: vi.fn(),
|
||||
}));
|
||||
|
||||
vi.mock('../lib/api', () => ({
|
||||
getTemplates: apiMock.getTemplates,
|
||||
createTemplate: apiMock.createTemplate,
|
||||
updateTemplate: apiMock.updateTemplate,
|
||||
deleteTemplate: apiMock.deleteTemplate,
|
||||
}));
|
||||
|
||||
describe('TemplateRegistry', () => {
|
||||
beforeEach(() => {
|
||||
resetStore();
|
||||
vi.clearAllMocks();
|
||||
});
|
||||
|
||||
it('loads and displays templates with unpacked classes', async () => {
|
||||
apiMock.getTemplates.mockResolvedValueOnce([
|
||||
{
|
||||
id: 't1',
|
||||
name: '腹腔镜胆囊切除术',
|
||||
description: 'desc',
|
||||
classes: [{ id: 'c1', name: '胆囊', color: '#ff0000', zIndex: 10, category: '器官' }],
|
||||
rules: [],
|
||||
},
|
||||
]);
|
||||
|
||||
render(<TemplateRegistry />);
|
||||
|
||||
expect(await screen.findAllByText('腹腔镜胆囊切除术')).toHaveLength(2);
|
||||
expect(screen.getByText('胆囊')).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it('creates a template and stores it globally', async () => {
|
||||
apiMock.getTemplates.mockResolvedValueOnce([]);
|
||||
apiMock.createTemplate.mockResolvedValueOnce({
|
||||
id: 't2',
|
||||
name: 'New Template',
|
||||
description: 'desc',
|
||||
classes: [],
|
||||
rules: [],
|
||||
});
|
||||
|
||||
render(<TemplateRegistry />);
|
||||
fireEvent.click(screen.getByText('新建方案'));
|
||||
fireEvent.change(screen.getAllByRole('textbox')[0], { target: { value: 'New Template' } });
|
||||
fireEvent.change(screen.getAllByRole('textbox')[1], { target: { value: 'desc' } });
|
||||
fireEvent.click(screen.getByRole('button', { name: '保存' }));
|
||||
|
||||
await waitFor(() => expect(apiMock.createTemplate).toHaveBeenCalledWith(expect.objectContaining({
|
||||
name: 'New Template',
|
||||
description: 'desc',
|
||||
classes: [],
|
||||
rules: [],
|
||||
color: '#06b6d4',
|
||||
z_index: 0,
|
||||
})));
|
||||
expect(useStore.getState().templates[0]).toEqual(expect.objectContaining({ id: 't2' }));
|
||||
});
|
||||
|
||||
it('imports JSON classes into the edit modal before saving', async () => {
|
||||
apiMock.getTemplates.mockResolvedValueOnce([]);
|
||||
|
||||
render(<TemplateRegistry />);
|
||||
fireEvent.click(screen.getByText('新建方案'));
|
||||
fireEvent.change(screen.getAllByRole('textbox')[0], { target: { value: 'With Classes' } });
|
||||
fireEvent.click(screen.getByText('批量导入'));
|
||||
fireEvent.change(screen.getByPlaceholderText('[[[255,0,0], [0,255,0]], ["分类A", "分类B"]]'), {
|
||||
target: { value: '{"colors":[[255,0,0]],"names":["分类A"]}' },
|
||||
});
|
||||
fireEvent.click(screen.getByRole('button', { name: '导入' }));
|
||||
|
||||
expect(screen.getByText('分类A')).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
30
src/components/ToolsPalette.test.tsx
Normal file
30
src/components/ToolsPalette.test.tsx
Normal file
@@ -0,0 +1,30 @@
|
||||
import { fireEvent, render, screen } from '@testing-library/react';
|
||||
import { describe, expect, it, vi } from 'vitest';
|
||||
import { ToolsPalette } from './ToolsPalette';
|
||||
|
||||
describe('ToolsPalette', () => {
|
||||
it('switches tools and exposes UI-only placeholder buttons', () => {
|
||||
const setActiveTool = vi.fn();
|
||||
|
||||
render(<ToolsPalette activeTool="move" setActiveTool={setActiveTool} />);
|
||||
|
||||
fireEvent.click(screen.getByTitle('创建多边形 (P)'));
|
||||
fireEvent.click(screen.getByTitle('正向选点 (SAM)'));
|
||||
|
||||
expect(setActiveTool).toHaveBeenNthCalledWith(1, 'create_polygon');
|
||||
expect(setActiveTool).toHaveBeenNthCalledWith(2, 'point_pos');
|
||||
expect(screen.getByTitle('撤销操作 (Ctrl+Z)')).toBeInTheDocument();
|
||||
expect(screen.getByTitle('重做操作 (Ctrl+Shift+Z)')).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it('switches to SAM trigger and calls the AI navigation hook', () => {
|
||||
const setActiveTool = vi.fn();
|
||||
const onTriggerAI = vi.fn();
|
||||
|
||||
render(<ToolsPalette activeTool="move" setActiveTool={setActiveTool} onTriggerAI={onTriggerAI} />);
|
||||
fireEvent.click(screen.getByTitle('触发 SAM 推理 (Enter)'));
|
||||
|
||||
expect(setActiveTool).toHaveBeenCalledWith('sam_trigger');
|
||||
expect(onTriggerAI).toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
@@ -78,7 +78,7 @@ export function ToolsPalette({ activeTool, setActiveTool, onTriggerAI }: ToolsPa
|
||||
setActiveTool('sam_trigger');
|
||||
if (onTriggerAI) onTriggerAI();
|
||||
}}
|
||||
title="触发 SAM 3 推理 (Enter)"
|
||||
title="触发 SAM 推理 (Enter)"
|
||||
className={cn(
|
||||
"w-10 h-10 rounded-lg flex items-center justify-center transition-all",
|
||||
activeTool === 'sam_trigger'
|
||||
|
||||
259
src/components/VideoWorkspace.test.tsx
Normal file
259
src/components/VideoWorkspace.test.tsx
Normal file
@@ -0,0 +1,259 @@
|
||||
import { act, fireEvent, render, screen, waitFor } from '@testing-library/react';
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest';
|
||||
import { resetStore } from '../test/storeTestUtils';
|
||||
import { useStore } from '../store/useStore';
|
||||
import { VideoWorkspace } from './VideoWorkspace';
|
||||
|
||||
const apiMock = vi.hoisted(() => ({
|
||||
getProjectFrames: vi.fn(),
|
||||
parseMedia: vi.fn(),
|
||||
getTask: vi.fn(),
|
||||
getTemplates: vi.fn(),
|
||||
getProjectAnnotations: vi.fn(),
|
||||
saveAnnotation: vi.fn(),
|
||||
updateAnnotation: vi.fn(),
|
||||
deleteAnnotation: vi.fn(),
|
||||
exportCoco: vi.fn(),
|
||||
annotationToMask: vi.fn(),
|
||||
buildAnnotationPayload: vi.fn(),
|
||||
getAiModelStatus: vi.fn(),
|
||||
}));
|
||||
|
||||
vi.mock('../lib/api', () => ({
|
||||
getProjectFrames: apiMock.getProjectFrames,
|
||||
parseMedia: apiMock.parseMedia,
|
||||
getTask: apiMock.getTask,
|
||||
getTemplates: apiMock.getTemplates,
|
||||
getProjectAnnotations: apiMock.getProjectAnnotations,
|
||||
saveAnnotation: apiMock.saveAnnotation,
|
||||
updateAnnotation: apiMock.updateAnnotation,
|
||||
deleteAnnotation: apiMock.deleteAnnotation,
|
||||
exportCoco: apiMock.exportCoco,
|
||||
annotationToMask: apiMock.annotationToMask,
|
||||
buildAnnotationPayload: apiMock.buildAnnotationPayload,
|
||||
getAiModelStatus: apiMock.getAiModelStatus,
|
||||
}));
|
||||
|
||||
describe('VideoWorkspace', () => {
|
||||
beforeEach(() => {
|
||||
resetStore();
|
||||
vi.clearAllMocks();
|
||||
useStore.setState({ currentProject: { id: '1', name: 'Demo', status: 'ready', video_path: 'uploads/demo.mp4' } });
|
||||
apiMock.getTemplates.mockResolvedValue([]);
|
||||
apiMock.getProjectAnnotations.mockResolvedValue([]);
|
||||
apiMock.annotationToMask.mockReturnValue(null);
|
||||
apiMock.getTask.mockResolvedValue({ id: 1, status: 'success', progress: 100, message: '解析完成' });
|
||||
apiMock.getAiModelStatus.mockResolvedValue({
|
||||
selected_model: 'sam2',
|
||||
gpu: { available: false, device: 'cpu', name: null, torch_available: true },
|
||||
models: [
|
||||
{ id: 'sam2', label: 'SAM 2', available: true, loaded: false, device: 'cpu', supports: [], message: 'ready', package_available: true, checkpoint_exists: true, python_ok: true, torch_ok: true, cuda_required: false },
|
||||
{ id: 'sam3', label: 'SAM 3', available: false, loaded: false, device: 'unavailable', supports: [], message: 'missing', package_available: false, checkpoint_exists: false, python_ok: false, torch_ok: true, cuda_required: true },
|
||||
],
|
||||
});
|
||||
});
|
||||
|
||||
it('loads project frames into the workspace store', async () => {
|
||||
apiMock.getProjectFrames.mockResolvedValueOnce([
|
||||
{ id: 10, project_id: 1, frame_index: 0, image_url: '/frame.jpg', width: 640, height: 360 },
|
||||
]);
|
||||
|
||||
render(<VideoWorkspace />);
|
||||
|
||||
await waitFor(() => expect(useStore.getState().frames).toEqual([
|
||||
{ id: '10', projectId: '1', index: 0, url: '/frame.jpg', width: 640, height: 360 },
|
||||
]));
|
||||
expect(screen.getByText('Demo')).toBeInTheDocument();
|
||||
expect(apiMock.parseMedia).not.toHaveBeenCalled();
|
||||
expect(apiMock.getProjectAnnotations).toHaveBeenCalledWith('1');
|
||||
});
|
||||
|
||||
it('triggers parsing when a media project has no frames yet', async () => {
|
||||
apiMock.getProjectFrames
|
||||
.mockResolvedValueOnce([])
|
||||
.mockResolvedValueOnce([
|
||||
{ id: 11, project_id: 1, frame_index: 0, image_url: '/parsed.jpg', width: 320, height: 240 },
|
||||
]);
|
||||
apiMock.parseMedia.mockResolvedValueOnce({ id: 7, status: 'queued', progress: 0 });
|
||||
apiMock.getTask.mockResolvedValueOnce({ id: 7, status: 'success', progress: 100, message: '解析完成' });
|
||||
|
||||
render(<VideoWorkspace />);
|
||||
|
||||
await waitFor(() => expect(apiMock.parseMedia).toHaveBeenCalledWith('1'));
|
||||
expect(apiMock.getTask).toHaveBeenCalledWith(7);
|
||||
await waitFor(() => expect(useStore.getState().frames[0]).toEqual(expect.objectContaining({
|
||||
id: '11',
|
||||
url: '/parsed.jpg',
|
||||
})));
|
||||
});
|
||||
|
||||
it('hydrates saved annotations after loading frames', async () => {
|
||||
apiMock.getProjectFrames.mockResolvedValueOnce([
|
||||
{ id: 10, project_id: 1, frame_index: 0, image_url: '/frame.jpg', width: 640, height: 360 },
|
||||
]);
|
||||
apiMock.getProjectAnnotations.mockResolvedValueOnce([{ id: 99, frame_id: 10 }]);
|
||||
apiMock.annotationToMask.mockReturnValueOnce({
|
||||
id: 'annotation-99',
|
||||
annotationId: '99',
|
||||
frameId: '10',
|
||||
saved: true,
|
||||
pathData: 'M 0 0 Z',
|
||||
label: 'Saved',
|
||||
color: '#06b6d4',
|
||||
});
|
||||
|
||||
render(<VideoWorkspace />);
|
||||
|
||||
await waitFor(() => expect(useStore.getState().masks).toEqual([
|
||||
expect.objectContaining({ id: 'annotation-99', saved: true }),
|
||||
]));
|
||||
});
|
||||
|
||||
it('saves pending masks through the archive button', async () => {
|
||||
apiMock.getProjectFrames.mockResolvedValueOnce([
|
||||
{ id: 10, project_id: 1, frame_index: 0, image_url: '/frame.jpg', width: 640, height: 360 },
|
||||
]);
|
||||
apiMock.buildAnnotationPayload.mockReturnValueOnce({ project_id: 1, frame_id: 10, mask_data: { polygons: [] } });
|
||||
apiMock.saveAnnotation.mockResolvedValueOnce({ id: 5 });
|
||||
|
||||
render(<VideoWorkspace />);
|
||||
await waitFor(() => expect(useStore.getState().frames).toHaveLength(1));
|
||||
act(() => {
|
||||
useStore.setState({
|
||||
activeTemplateId: '2',
|
||||
masks: [{
|
||||
id: 'mask-1',
|
||||
frameId: '10',
|
||||
pathData: 'M 0 0 Z',
|
||||
label: 'AI Mask',
|
||||
color: '#06b6d4',
|
||||
segmentation: [[0, 0, 10, 0, 10, 10]],
|
||||
bbox: [0, 0, 10, 10],
|
||||
}],
|
||||
});
|
||||
});
|
||||
|
||||
fireEvent.click(screen.getByRole('button', { name: '结构化归档保存' }));
|
||||
|
||||
await waitFor(() => expect(apiMock.saveAnnotation).toHaveBeenCalledWith({
|
||||
project_id: 1,
|
||||
frame_id: 10,
|
||||
mask_data: { polygons: [] },
|
||||
}));
|
||||
expect(apiMock.buildAnnotationPayload).toHaveBeenCalledWith(
|
||||
'1',
|
||||
expect.objectContaining({ id: 'mask-1' }),
|
||||
expect.objectContaining({ id: '10' }),
|
||||
'2',
|
||||
);
|
||||
});
|
||||
|
||||
it('updates dirty saved masks through the archive button', async () => {
|
||||
apiMock.getProjectFrames.mockResolvedValueOnce([
|
||||
{ id: 10, project_id: 1, frame_index: 0, image_url: '/frame.jpg', width: 640, height: 360 },
|
||||
]);
|
||||
apiMock.buildAnnotationPayload.mockReturnValueOnce({
|
||||
project_id: 1,
|
||||
frame_id: 10,
|
||||
template_id: 2,
|
||||
mask_data: { polygons: [], label: '胆囊' },
|
||||
});
|
||||
apiMock.updateAnnotation.mockResolvedValueOnce({ id: 99 });
|
||||
|
||||
render(<VideoWorkspace />);
|
||||
await waitFor(() => expect(useStore.getState().frames).toHaveLength(1));
|
||||
act(() => {
|
||||
useStore.setState({
|
||||
activeTemplateId: '2',
|
||||
masks: [{
|
||||
id: 'annotation-99',
|
||||
annotationId: '99',
|
||||
frameId: '10',
|
||||
pathData: 'M 0 0 Z',
|
||||
label: '胆囊',
|
||||
color: '#ff0000',
|
||||
saveStatus: 'dirty',
|
||||
segmentation: [[0, 0, 10, 0, 10, 10]],
|
||||
bbox: [0, 0, 10, 10],
|
||||
}],
|
||||
});
|
||||
});
|
||||
|
||||
fireEvent.click(screen.getByRole('button', { name: '结构化归档保存' }));
|
||||
|
||||
await waitFor(() => expect(apiMock.updateAnnotation).toHaveBeenCalledWith('99', {
|
||||
template_id: 2,
|
||||
mask_data: { polygons: [], label: '胆囊' },
|
||||
points: undefined,
|
||||
bbox: undefined,
|
||||
}));
|
||||
expect(apiMock.saveAnnotation).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('deletes saved annotations when clearing current-frame masks', async () => {
|
||||
apiMock.getProjectFrames.mockResolvedValueOnce([
|
||||
{ id: 10, project_id: 1, frame_index: 0, image_url: '/frame.jpg', width: 640, height: 360 },
|
||||
]);
|
||||
apiMock.deleteAnnotation.mockResolvedValueOnce(undefined);
|
||||
|
||||
render(<VideoWorkspace />);
|
||||
await waitFor(() => expect(useStore.getState().frames).toHaveLength(1));
|
||||
act(() => {
|
||||
useStore.setState({
|
||||
masks: [
|
||||
{
|
||||
id: 'annotation-99',
|
||||
annotationId: '99',
|
||||
frameId: '10',
|
||||
pathData: 'M 0 0 Z',
|
||||
label: 'Saved',
|
||||
color: '#06b6d4',
|
||||
saved: true,
|
||||
saveStatus: 'saved',
|
||||
},
|
||||
{
|
||||
id: 'draft-1',
|
||||
frameId: '10',
|
||||
pathData: 'M 1 1 Z',
|
||||
label: 'Draft',
|
||||
color: '#ff0000',
|
||||
},
|
||||
],
|
||||
});
|
||||
});
|
||||
|
||||
fireEvent.click(screen.getByRole('button', { name: '清空遮罩' }));
|
||||
|
||||
await waitFor(() => expect(apiMock.deleteAnnotation).toHaveBeenCalledWith('99'));
|
||||
expect(useStore.getState().masks).toEqual([]);
|
||||
});
|
||||
|
||||
it('auto-saves pending masks before exporting COCO', async () => {
|
||||
apiMock.getProjectFrames.mockResolvedValueOnce([
|
||||
{ id: 10, project_id: 1, frame_index: 0, image_url: '/frame.jpg', width: 640, height: 360 },
|
||||
]);
|
||||
apiMock.buildAnnotationPayload.mockReturnValueOnce({ project_id: 1, frame_id: 10, mask_data: { polygons: [] } });
|
||||
apiMock.saveAnnotation.mockResolvedValueOnce({ id: 5 });
|
||||
apiMock.exportCoco.mockResolvedValueOnce(new Blob(['{}'], { type: 'application/json' }));
|
||||
|
||||
render(<VideoWorkspace />);
|
||||
await waitFor(() => expect(useStore.getState().frames).toHaveLength(1));
|
||||
act(() => {
|
||||
useStore.setState({
|
||||
masks: [{
|
||||
id: 'mask-1',
|
||||
frameId: '10',
|
||||
pathData: 'M 0 0 Z',
|
||||
label: 'AI Mask',
|
||||
color: '#06b6d4',
|
||||
segmentation: [[0, 0, 10, 0, 10, 10]],
|
||||
}],
|
||||
});
|
||||
});
|
||||
|
||||
fireEvent.click(screen.getByRole('button', { name: '导出 JSON 标注集' }));
|
||||
|
||||
await waitFor(() => expect(apiMock.saveAnnotation).toHaveBeenCalled());
|
||||
expect(apiMock.exportCoco).toHaveBeenCalledWith('1');
|
||||
});
|
||||
});
|
||||
@@ -1,10 +1,28 @@
|
||||
import React, { useEffect } from 'react';
|
||||
import React, { useCallback, useEffect, useMemo, useState } from 'react';
|
||||
import { useStore } from '../store/useStore';
|
||||
import { getProjectFrames, parseMedia, getTemplates } from '../lib/api';
|
||||
import {
|
||||
annotationToMask,
|
||||
buildAnnotationPayload,
|
||||
deleteAnnotation,
|
||||
exportCoco,
|
||||
getProjectAnnotations,
|
||||
getProjectFrames,
|
||||
getTask,
|
||||
getTemplates,
|
||||
parseMedia,
|
||||
saveAnnotation,
|
||||
updateAnnotation,
|
||||
} from '../lib/api';
|
||||
import { CanvasArea } from './CanvasArea';
|
||||
import { ToolsPalette } from './ToolsPalette';
|
||||
import { OntologyInspector } from './OntologyInspector';
|
||||
import { FrameTimeline } from './FrameTimeline';
|
||||
import { ModelStatusBadge } from './ModelStatusBadge';
|
||||
import type { Frame } from '../store/useStore';
|
||||
|
||||
function sleep(ms: number) {
|
||||
return new Promise((resolve) => setTimeout(resolve, ms));
|
||||
}
|
||||
|
||||
export function VideoWorkspace({ onNavigateToAI }: { onNavigateToAI?: () => void }) {
|
||||
const activeTool = useStore((state) => state.activeTool);
|
||||
@@ -12,8 +30,26 @@ export function VideoWorkspace({ onNavigateToAI }: { onNavigateToAI?: () => void
|
||||
const currentProject = useStore((state) => state.currentProject);
|
||||
const frames = useStore((state) => state.frames);
|
||||
const currentFrameIndex = useStore((state) => state.currentFrameIndex);
|
||||
const masks = useStore((state) => state.masks);
|
||||
const activeTemplateId = useStore((state) => state.activeTemplateId);
|
||||
const setFrames = useStore((state) => state.setFrames);
|
||||
const setCurrentFrame = useStore((state) => state.setCurrentFrame);
|
||||
const setMasks = useStore((state) => state.setMasks);
|
||||
const [isSaving, setIsSaving] = useState(false);
|
||||
const [isExporting, setIsExporting] = useState(false);
|
||||
const [statusMessage, setStatusMessage] = useState('');
|
||||
|
||||
const hydrateSavedAnnotations = useCallback(async (projectId: string, projectFrames: Frame[]) => {
|
||||
const frameById = new Map(projectFrames.map((frame) => [frame.id, frame]));
|
||||
const annotations = await getProjectAnnotations(projectId);
|
||||
const savedMasks = annotations
|
||||
.map((annotation) => {
|
||||
const frame = annotation.frame_id ? frameById.get(String(annotation.frame_id)) : null;
|
||||
return frame ? annotationToMask(annotation, frame) : null;
|
||||
})
|
||||
.filter((mask): mask is NonNullable<typeof mask> => Boolean(mask));
|
||||
setMasks(savedMasks);
|
||||
}, [setMasks]);
|
||||
|
||||
useEffect(() => {
|
||||
if (!currentProject?.id) return;
|
||||
@@ -25,34 +61,58 @@ export function VideoWorkspace({ onNavigateToAI }: { onNavigateToAI?: () => void
|
||||
if (cancelled) return;
|
||||
|
||||
if (data.length === 0 && currentProject.video_path) {
|
||||
// No frames yet but video exists → trigger parsing
|
||||
// No frames yet but video exists -> queue parsing and poll the task.
|
||||
try {
|
||||
await parseMedia(String(currentProject.id));
|
||||
const task = await parseMedia(String(currentProject.id));
|
||||
if (cancelled) return;
|
||||
setStatusMessage(`解析任务已入队 #${task.id}`);
|
||||
let completed = false;
|
||||
for (let attempt = 0; attempt < 60; attempt += 1) {
|
||||
const freshTask = await getTask(task.id);
|
||||
if (cancelled) return;
|
||||
setStatusMessage(freshTask.message || `解析进度 ${freshTask.progress}%`);
|
||||
if (freshTask.status === 'success') {
|
||||
completed = true;
|
||||
break;
|
||||
}
|
||||
if (freshTask.status === 'failed') {
|
||||
setStatusMessage(freshTask.error || '解析任务失败');
|
||||
return;
|
||||
}
|
||||
await sleep(2000);
|
||||
}
|
||||
if (!completed) {
|
||||
setStatusMessage('解析仍在后台运行,可稍后刷新工作区');
|
||||
return;
|
||||
}
|
||||
const fresh = await getProjectFrames(String(currentProject.id));
|
||||
if (cancelled) return;
|
||||
setFrames(fresh.map((f) => ({
|
||||
const mappedFrames = fresh.map((f) => ({
|
||||
id: String(f.id),
|
||||
projectId: String(f.project_id),
|
||||
index: f.frame_index,
|
||||
url: f.image_url,
|
||||
width: f.width ?? 0,
|
||||
height: f.height ?? 0,
|
||||
})));
|
||||
}));
|
||||
setFrames(mappedFrames);
|
||||
setCurrentFrame(0);
|
||||
await hydrateSavedAnnotations(String(currentProject.id), mappedFrames);
|
||||
} catch (err) {
|
||||
console.error('Parse failed:', err);
|
||||
}
|
||||
} else {
|
||||
setFrames(data.map((f) => ({
|
||||
const mappedFrames = data.map((f) => ({
|
||||
id: String(f.id),
|
||||
projectId: String(f.project_id),
|
||||
index: f.frame_index,
|
||||
url: f.image_url,
|
||||
width: f.width ?? 0,
|
||||
height: f.height ?? 0,
|
||||
})));
|
||||
}));
|
||||
setFrames(mappedFrames);
|
||||
setCurrentFrame(0);
|
||||
await hydrateSavedAnnotations(String(currentProject.id), mappedFrames);
|
||||
}
|
||||
} catch (err) {
|
||||
console.error('Failed to load frames:', err);
|
||||
@@ -61,7 +121,7 @@ export function VideoWorkspace({ onNavigateToAI }: { onNavigateToAI?: () => void
|
||||
|
||||
loadFrames();
|
||||
return () => { cancelled = true; };
|
||||
}, [currentProject?.id, setFrames, setCurrentFrame]);
|
||||
}, [currentProject?.id, currentProject?.video_path, hydrateSavedAnnotations, setFrames, setCurrentFrame]);
|
||||
|
||||
const templates = useStore((state) => state.templates);
|
||||
const setTemplates = useStore((state) => state.setTemplates);
|
||||
@@ -72,7 +132,121 @@ export function VideoWorkspace({ onNavigateToAI }: { onNavigateToAI?: () => void
|
||||
}
|
||||
}, [templates.length, setTemplates]);
|
||||
|
||||
const currentFrameUrl = frames[currentFrameIndex]?.url || '';
|
||||
const currentFrame = frames[currentFrameIndex] || null;
|
||||
const frameById = useMemo(() => new Map(frames.map((frame) => [frame.id, frame])), [frames]);
|
||||
const projectFrameIds = useMemo(() => new Set(frames.map((frame) => frame.id)), [frames]);
|
||||
|
||||
const savePendingAnnotations = useCallback(async ({ silent = false } = {}) => {
|
||||
if (!currentProject?.id) return 0;
|
||||
const projectMasks = masks.filter((mask) => projectFrameIds.has(mask.frameId));
|
||||
const pendingMasks = projectMasks.filter((mask) => !mask.annotationId);
|
||||
const dirtyMasks = projectMasks.filter((mask) => mask.annotationId && mask.saveStatus === 'dirty');
|
||||
if (pendingMasks.length === 0 && dirtyMasks.length === 0) {
|
||||
if (!silent) setStatusMessage('没有待保存标注');
|
||||
return 0;
|
||||
}
|
||||
|
||||
setIsSaving(true);
|
||||
setStatusMessage('正在保存标注...');
|
||||
try {
|
||||
const createPayloads = pendingMasks
|
||||
.map((mask) => {
|
||||
const frame = frameById.get(mask.frameId);
|
||||
return frame ? buildAnnotationPayload(currentProject.id, mask, frame, activeTemplateId) : null;
|
||||
})
|
||||
.filter((payload): payload is NonNullable<typeof payload> => Boolean(payload));
|
||||
|
||||
const updatePayloads = dirtyMasks
|
||||
.map((mask) => {
|
||||
const frame = frameById.get(mask.frameId);
|
||||
const payload = frame ? buildAnnotationPayload(currentProject.id, mask, frame, activeTemplateId) : null;
|
||||
if (!payload || !mask.annotationId) return null;
|
||||
const updatePayload = {
|
||||
template_id: payload.template_id,
|
||||
mask_data: payload.mask_data,
|
||||
points: payload.points,
|
||||
bbox: payload.bbox,
|
||||
};
|
||||
return { annotationId: mask.annotationId, payload: updatePayload };
|
||||
})
|
||||
.filter((item): item is NonNullable<typeof item> => Boolean(item));
|
||||
|
||||
if (createPayloads.length === 0 && updatePayloads.length === 0) {
|
||||
setStatusMessage('没有可保存的标注数据');
|
||||
return 0;
|
||||
}
|
||||
|
||||
await Promise.all([
|
||||
...createPayloads.map((payload) => saveAnnotation(payload)),
|
||||
...updatePayloads.map(({ annotationId, payload }) => updateAnnotation(annotationId, payload)),
|
||||
]);
|
||||
await hydrateSavedAnnotations(currentProject.id, frames);
|
||||
const savedCount = createPayloads.length + updatePayloads.length;
|
||||
setStatusMessage(`已保存 ${savedCount} 个标注`);
|
||||
return savedCount;
|
||||
} catch (err) {
|
||||
console.error('Save annotations failed:', err);
|
||||
setStatusMessage('保存失败,请检查后端服务');
|
||||
throw err;
|
||||
} finally {
|
||||
setIsSaving(false);
|
||||
}
|
||||
}, [activeTemplateId, currentProject?.id, frameById, frames, hydrateSavedAnnotations, masks, projectFrameIds]);
|
||||
|
||||
const handleClearCurrentFrameMasks = useCallback(async () => {
|
||||
if (!currentFrame) return;
|
||||
const frameMasks = masks.filter((mask) => mask.frameId === currentFrame.id);
|
||||
const annotationIds = frameMasks
|
||||
.map((mask) => mask.annotationId)
|
||||
.filter((annotationId): annotationId is string => Boolean(annotationId));
|
||||
|
||||
setIsSaving(true);
|
||||
setStatusMessage(annotationIds.length > 0 ? '正在删除已保存标注...' : '正在清空本帧遮罩...');
|
||||
try {
|
||||
await Promise.all(annotationIds.map((annotationId) => deleteAnnotation(annotationId)));
|
||||
setMasks(masks.filter((mask) => mask.frameId !== currentFrame.id));
|
||||
setStatusMessage(annotationIds.length > 0
|
||||
? `已删除 ${annotationIds.length} 个后端标注`
|
||||
: '已清空本帧未保存遮罩');
|
||||
} catch (err) {
|
||||
console.error('Delete annotations failed:', err);
|
||||
setStatusMessage('删除失败,请检查后端服务');
|
||||
} finally {
|
||||
setIsSaving(false);
|
||||
}
|
||||
}, [currentFrame, masks, setMasks]);
|
||||
|
||||
const handleSave = async () => {
|
||||
try {
|
||||
await savePendingAnnotations();
|
||||
} catch {
|
||||
// status message is set in savePendingAnnotations
|
||||
}
|
||||
};
|
||||
|
||||
const handleExport = async () => {
|
||||
if (!currentProject?.id) return;
|
||||
setIsExporting(true);
|
||||
setStatusMessage('正在准备导出...');
|
||||
try {
|
||||
await savePendingAnnotations({ silent: true });
|
||||
const blob = await exportCoco(currentProject.id);
|
||||
const url = URL.createObjectURL(blob);
|
||||
const link = document.createElement('a');
|
||||
link.href = url;
|
||||
link.download = `project_${currentProject.id}_coco.json`;
|
||||
document.body.appendChild(link);
|
||||
link.click();
|
||||
link.remove();
|
||||
URL.revokeObjectURL(url);
|
||||
setStatusMessage('COCO JSON 已导出');
|
||||
} catch (err) {
|
||||
console.error('Export failed:', err);
|
||||
setStatusMessage('导出失败,请检查后端服务');
|
||||
} finally {
|
||||
setIsExporting(false);
|
||||
}
|
||||
};
|
||||
|
||||
return (
|
||||
<div className="w-full h-full flex flex-col bg-[#0a0a0a]">
|
||||
@@ -84,14 +258,25 @@ export function VideoWorkspace({ onNavigateToAI }: { onNavigateToAI?: () => void
|
||||
<span className="text-sm text-white font-mono">{currentProject?.name || '未选择项目'}</span>
|
||||
</div>
|
||||
<div className="flex items-center gap-3">
|
||||
<div className="flex items-center gap-1.5 text-[10px] uppercase font-medium">
|
||||
<span className="px-2 py-0.5 rounded bg-green-500/10 text-green-400 border border-green-500/20">SAM 3 部署就绪</span>
|
||||
</div>
|
||||
<button className="px-4 py-1.5 bg-white/5 hover:bg-white/10 border border-white/10 rounded-md text-xs transition-colors text-white">
|
||||
导出 JSON 标注集
|
||||
{statusMessage && (
|
||||
<span className="text-[10px] text-gray-500 font-mono max-w-48 truncate" title={statusMessage}>
|
||||
{statusMessage}
|
||||
</span>
|
||||
)}
|
||||
<ModelStatusBadge />
|
||||
<button
|
||||
onClick={handleExport}
|
||||
disabled={!currentProject?.id || isExporting || isSaving}
|
||||
className="px-4 py-1.5 bg-white/5 hover:bg-white/10 border border-white/10 rounded-md text-xs transition-colors text-white disabled:opacity-40 disabled:cursor-not-allowed"
|
||||
>
|
||||
{isExporting ? '导出中...' : '导出 JSON 标注集'}
|
||||
</button>
|
||||
<button className="px-4 py-1.5 bg-cyan-600 hover:bg-cyan-500 text-white text-xs font-medium rounded-md transition-shadow shadow-lg shadow-cyan-900/20">
|
||||
结构化归档保存
|
||||
<button
|
||||
onClick={handleSave}
|
||||
disabled={!currentProject?.id || isSaving || isExporting}
|
||||
className="px-4 py-1.5 bg-cyan-600 hover:bg-cyan-500 text-white text-xs font-medium rounded-md transition-shadow shadow-lg shadow-cyan-900/20 disabled:opacity-40 disabled:cursor-not-allowed"
|
||||
>
|
||||
{isSaving ? '保存中...' : '结构化归档保存'}
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
@@ -102,7 +287,7 @@ export function VideoWorkspace({ onNavigateToAI }: { onNavigateToAI?: () => void
|
||||
|
||||
<div className="flex-1 relative flex items-center justify-center p-8 bg-[#151515] overflow-hidden">
|
||||
<div className="relative w-full h-full bg-[#1e1e1e] border border-white/5 shadow-2xl rounded-sm">
|
||||
<CanvasArea activeTool={activeTool} frameUrl={currentFrameUrl} />
|
||||
<CanvasArea activeTool={activeTool} frame={currentFrame} onClearMasks={handleClearCurrentFrameMasks} />
|
||||
</div>
|
||||
</div>
|
||||
|
||||
|
||||
361
src/lib/api.test.ts
Normal file
361
src/lib/api.test.ts
Normal file
@@ -0,0 +1,361 @@
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest';
|
||||
|
||||
const axiosMock = vi.hoisted(() => {
|
||||
const client = {
|
||||
get: vi.fn(),
|
||||
post: vi.fn(),
|
||||
patch: vi.fn(),
|
||||
delete: vi.fn(),
|
||||
interceptors: {
|
||||
request: { use: vi.fn() },
|
||||
response: { use: vi.fn() },
|
||||
},
|
||||
};
|
||||
return { client, create: vi.fn(() => client) };
|
||||
});
|
||||
|
||||
vi.mock('axios', () => ({
|
||||
default: {
|
||||
create: axiosMock.create,
|
||||
},
|
||||
}));
|
||||
|
||||
describe('api client contracts', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks();
|
||||
vi.setSystemTime(new Date('2026-05-01T00:00:00Z'));
|
||||
});
|
||||
|
||||
it('maps backend project fields into frontend project fields', async () => {
|
||||
const { getProjects } = await import('./api');
|
||||
axiosMock.client.get.mockResolvedValueOnce({
|
||||
data: [
|
||||
{
|
||||
id: 7,
|
||||
name: 'Demo',
|
||||
description: 'desc',
|
||||
status: 'ready',
|
||||
frame_count: 12,
|
||||
original_fps: 29.97,
|
||||
parse_fps: 10,
|
||||
thumbnail_url: 'thumb',
|
||||
video_path: 'uploads/demo.mp4',
|
||||
source_type: 'video',
|
||||
created_at: 'created',
|
||||
updated_at: 'updated',
|
||||
},
|
||||
],
|
||||
});
|
||||
|
||||
await expect(getProjects()).resolves.toEqual([
|
||||
expect.objectContaining({
|
||||
id: '7',
|
||||
name: 'Demo',
|
||||
status: 'ready',
|
||||
frames: 12,
|
||||
fps: '30FPS',
|
||||
thumbnail_url: 'thumb',
|
||||
video_path: 'uploads/demo.mp4',
|
||||
source_type: 'video',
|
||||
createdAt: 'created',
|
||||
updatedAt: 'updated',
|
||||
}),
|
||||
]);
|
||||
expect(axiosMock.client.get).toHaveBeenCalledWith('/api/projects');
|
||||
});
|
||||
|
||||
it('updates projects with PATCH instead of the old PUT contract', async () => {
|
||||
const { updateProject } = await import('./api');
|
||||
axiosMock.client.patch.mockResolvedValueOnce({ data: { id: 3, name: 'Renamed', status: 'ready' } });
|
||||
|
||||
await updateProject('3', { name: 'Renamed' } as any);
|
||||
|
||||
expect(axiosMock.client.patch).toHaveBeenCalledWith('/api/projects/3', { name: 'Renamed' });
|
||||
});
|
||||
|
||||
it('normalizes legacy project status values returned by existing databases', async () => {
|
||||
const { getProjects } = await import('./api');
|
||||
axiosMock.client.get.mockResolvedValueOnce({
|
||||
data: [
|
||||
{ id: 1, name: 'Old Ready', status: 'Ready' },
|
||||
{ id: 2, name: 'Old Parsing', status: 'Parsing' },
|
||||
{ id: 3, name: 'Old Error', status: 'Error' },
|
||||
],
|
||||
});
|
||||
|
||||
await expect(getProjects()).resolves.toEqual([
|
||||
expect.objectContaining({ status: 'ready' }),
|
||||
expect.objectContaining({ status: 'parsing' }),
|
||||
expect.objectContaining({ status: 'error' }),
|
||||
]);
|
||||
});
|
||||
|
||||
it('exports COCO from the backend route shape', async () => {
|
||||
const { exportCoco } = await import('./api');
|
||||
const blob = new Blob(['{}'], { type: 'application/json' });
|
||||
axiosMock.client.get.mockResolvedValueOnce({ data: blob });
|
||||
|
||||
await expect(exportCoco('9')).resolves.toBe(blob);
|
||||
expect(axiosMock.client.get).toHaveBeenCalledWith('/api/export/9/coco', {
|
||||
responseType: 'blob',
|
||||
});
|
||||
});
|
||||
|
||||
it('loads dashboard overview from the backend summary endpoint', async () => {
|
||||
const { getDashboardOverview } = await import('./api');
|
||||
const overview = {
|
||||
summary: {
|
||||
project_count: 2,
|
||||
parsing_task_count: 1,
|
||||
annotation_count: 5,
|
||||
frame_count: 100,
|
||||
template_count: 3,
|
||||
system_load_percent: 12,
|
||||
},
|
||||
tasks: [
|
||||
{ id: 'project-1', project_id: 1, name: 'Demo', progress: 60, status: 'pending', frame_count: 10, updated_at: 'now' },
|
||||
],
|
||||
activity: [
|
||||
{ id: 'project-1', kind: 'project', time: 'now', message: '项目状态: pending', project: 'Demo' },
|
||||
],
|
||||
};
|
||||
axiosMock.client.get.mockResolvedValueOnce({ data: overview });
|
||||
|
||||
await expect(getDashboardOverview()).resolves.toEqual(overview);
|
||||
expect(axiosMock.client.get).toHaveBeenCalledWith('/api/dashboard/overview');
|
||||
});
|
||||
|
||||
it('queues media parsing and reads processing task status', async () => {
|
||||
const { getTask, parseMedia } = await import('./api');
|
||||
const task = {
|
||||
id: 12,
|
||||
task_type: 'parse_video',
|
||||
status: 'queued',
|
||||
progress: 0,
|
||||
message: '解析任务已入队',
|
||||
project_id: 9,
|
||||
celery_task_id: 'celery-12',
|
||||
payload: { source_type: 'video' },
|
||||
result: null,
|
||||
error: null,
|
||||
created_at: 'created',
|
||||
started_at: null,
|
||||
finished_at: null,
|
||||
updated_at: 'updated',
|
||||
};
|
||||
axiosMock.client.post.mockResolvedValueOnce({ data: task });
|
||||
axiosMock.client.get.mockResolvedValueOnce({ data: { ...task, status: 'success', progress: 100 } });
|
||||
|
||||
await expect(parseMedia('9')).resolves.toEqual(task);
|
||||
expect(axiosMock.client.post).toHaveBeenCalledWith('/api/media/parse', null, {
|
||||
params: { project_id: '9' },
|
||||
});
|
||||
|
||||
await expect(getTask(12)).resolves.toEqual(expect.objectContaining({ status: 'success', progress: 100 }));
|
||||
expect(axiosMock.client.get).toHaveBeenCalledWith('/api/tasks/12');
|
||||
});
|
||||
|
||||
it('lists, saves, updates, and deletes annotations with the backend annotation contract', async () => {
|
||||
const { deleteAnnotation, getProjectAnnotations, saveAnnotation, updateAnnotation } = await import('./api');
|
||||
const saved = {
|
||||
id: 1,
|
||||
project_id: 9,
|
||||
frame_id: 5,
|
||||
template_id: 2,
|
||||
mask_data: { polygons: [[[0, 0], [1, 0], [1, 1]]] },
|
||||
points: null,
|
||||
bbox: null,
|
||||
created_at: 'created',
|
||||
updated_at: 'updated',
|
||||
};
|
||||
axiosMock.client.get.mockResolvedValueOnce({ data: [saved] });
|
||||
axiosMock.client.post.mockResolvedValueOnce({ data: saved });
|
||||
|
||||
await expect(getProjectAnnotations('9', '5')).resolves.toEqual([saved]);
|
||||
expect(axiosMock.client.get).toHaveBeenCalledWith('/api/ai/annotations', {
|
||||
params: { project_id: 9, frame_id: 5 },
|
||||
});
|
||||
|
||||
await expect(saveAnnotation({
|
||||
project_id: 9,
|
||||
frame_id: 5,
|
||||
template_id: 2,
|
||||
mask_data: { polygons: [[[0, 0], [1, 0], [1, 1]]], label: 'mask' },
|
||||
})).resolves.toEqual(saved);
|
||||
expect(axiosMock.client.post).toHaveBeenCalledWith('/api/ai/annotate', {
|
||||
project_id: 9,
|
||||
frame_id: 5,
|
||||
template_id: 2,
|
||||
mask_data: { polygons: [[[0, 0], [1, 0], [1, 1]]], label: 'mask' },
|
||||
});
|
||||
|
||||
axiosMock.client.patch.mockResolvedValueOnce({ data: { ...saved, mask_data: { ...saved.mask_data, label: 'updated' } } });
|
||||
await expect(updateAnnotation('1', {
|
||||
template_id: 2,
|
||||
mask_data: { polygons: [[[0, 0], [1, 0], [1, 1]]], label: 'updated' },
|
||||
})).resolves.toEqual(expect.objectContaining({ mask_data: expect.objectContaining({ label: 'updated' }) }));
|
||||
expect(axiosMock.client.patch).toHaveBeenCalledWith('/api/ai/annotations/1', {
|
||||
template_id: 2,
|
||||
mask_data: { polygons: [[[0, 0], [1, 0], [1, 1]]], label: 'updated' },
|
||||
});
|
||||
|
||||
axiosMock.client.delete.mockResolvedValueOnce({ data: null });
|
||||
await expect(deleteAnnotation('1')).resolves.toBeUndefined();
|
||||
expect(axiosMock.client.delete).toHaveBeenCalledWith('/api/ai/annotations/1');
|
||||
});
|
||||
|
||||
it('builds annotation payloads from frontend masks and restores saved annotations to masks', async () => {
|
||||
const { annotationToMask, buildAnnotationPayload } = await import('./api');
|
||||
const frame = { id: '5', projectId: '9', index: 0, url: '/frame.jpg', width: 100, height: 50 };
|
||||
const payload = buildAnnotationPayload('9', {
|
||||
id: 'm1',
|
||||
frameId: '5',
|
||||
pathData: 'M 10 10 L 90 10 L 90 40 Z',
|
||||
label: '胆囊',
|
||||
color: '#ff0000',
|
||||
classId: 'c1',
|
||||
className: '胆囊',
|
||||
classZIndex: 20,
|
||||
segmentation: [[10, 10, 90, 10, 90, 40]],
|
||||
bbox: [10, 10, 80, 30],
|
||||
}, frame, '2');
|
||||
|
||||
expect(payload).toEqual({
|
||||
project_id: 9,
|
||||
frame_id: 5,
|
||||
template_id: 2,
|
||||
mask_data: {
|
||||
polygons: [[[0.1, 0.2], [0.9, 0.2], [0.9, 0.8]]],
|
||||
label: '胆囊',
|
||||
color: '#ff0000',
|
||||
class: { id: 'c1', name: '胆囊', color: '#ff0000', zIndex: 20 },
|
||||
},
|
||||
bbox: [0.1, 0.2, 0.8, 0.6],
|
||||
});
|
||||
|
||||
expect(annotationToMask({
|
||||
id: 3,
|
||||
project_id: 9,
|
||||
frame_id: 5,
|
||||
template_id: 2,
|
||||
mask_data: {
|
||||
polygons: [[[0.1, 0.2], [0.9, 0.2], [0.9, 0.8]]],
|
||||
label: '旧标签',
|
||||
color: '#06b6d4',
|
||||
class: { id: 'c1', name: '胆囊', color: '#ff0000', zIndex: 20 },
|
||||
},
|
||||
points: null,
|
||||
bbox: null,
|
||||
created_at: 'created',
|
||||
updated_at: 'updated',
|
||||
}, frame)).toEqual(expect.objectContaining({
|
||||
id: 'annotation-3',
|
||||
annotationId: '3',
|
||||
frameId: '5',
|
||||
templateId: '2',
|
||||
classId: 'c1',
|
||||
className: '胆囊',
|
||||
classZIndex: 20,
|
||||
label: '胆囊',
|
||||
color: '#ff0000',
|
||||
saveStatus: 'saved',
|
||||
saved: true,
|
||||
pathData: 'M 10 10 L 90 10 L 90 40 Z',
|
||||
bbox: [10, 10, 80, 30],
|
||||
}));
|
||||
});
|
||||
|
||||
it('normalizes positive and negative point prompts for AI prediction', async () => {
|
||||
const { predictMask } = await import('./api');
|
||||
axiosMock.client.post.mockResolvedValueOnce({
|
||||
data: {
|
||||
polygons: [[[0.25, 0.25], [0.75, 0.25], [0.75, 0.75], [0.25, 0.75]]],
|
||||
scores: [0.9],
|
||||
},
|
||||
});
|
||||
|
||||
const result = await predictMask({
|
||||
imageId: '42',
|
||||
imageWidth: 400,
|
||||
imageHeight: 200,
|
||||
points: [
|
||||
{ x: 200, y: 100, type: 'pos' },
|
||||
{ x: 40, y: 20, type: 'neg' },
|
||||
],
|
||||
});
|
||||
|
||||
expect(axiosMock.client.post).toHaveBeenCalledWith('/api/ai/predict', {
|
||||
image_id: 42,
|
||||
prompt_type: 'point',
|
||||
prompt_data: {
|
||||
points: [[0.5, 0.5], [0.1, 0.1]],
|
||||
labels: [1, 0],
|
||||
},
|
||||
model: 'sam2',
|
||||
});
|
||||
expect(result.masks[0]).toEqual(expect.objectContaining({
|
||||
pathData: 'M 100 50 L 300 50 L 300 150 L 100 150 Z',
|
||||
segmentation: [[100, 50, 300, 50, 300, 150, 100, 150]],
|
||||
bbox: [100, 50, 200, 100],
|
||||
area: 20000,
|
||||
confidence: 0.9,
|
||||
}));
|
||||
});
|
||||
|
||||
it('normalizes box prompts for AI prediction', async () => {
|
||||
const { predictMask } = await import('./api');
|
||||
axiosMock.client.post.mockResolvedValueOnce({ data: { polygons: [], scores: [] } });
|
||||
|
||||
await predictMask({
|
||||
imageId: '5',
|
||||
imageWidth: 640,
|
||||
imageHeight: 320,
|
||||
box: { x1: 64, y1: 32, x2: 320, y2: 160 },
|
||||
});
|
||||
|
||||
expect(axiosMock.client.post).toHaveBeenCalledWith('/api/ai/predict', {
|
||||
image_id: 5,
|
||||
prompt_type: 'box',
|
||||
prompt_data: [0.1, 0.1, 0.5, 0.5],
|
||||
model: 'sam2',
|
||||
});
|
||||
});
|
||||
|
||||
it('uses semantic prompt type for text-only AI prediction', async () => {
|
||||
const { predictMask } = await import('./api');
|
||||
axiosMock.client.post.mockResolvedValueOnce({ data: { polygons: [], scores: [] } });
|
||||
|
||||
await predictMask({
|
||||
imageId: '6',
|
||||
imageWidth: 640,
|
||||
imageHeight: 360,
|
||||
model: 'sam3',
|
||||
text: '分割胆囊',
|
||||
});
|
||||
|
||||
expect(axiosMock.client.post).toHaveBeenCalledWith('/api/ai/predict', {
|
||||
image_id: 6,
|
||||
prompt_type: 'semantic',
|
||||
prompt_data: '分割胆囊',
|
||||
model: 'sam3',
|
||||
});
|
||||
});
|
||||
|
||||
it('loads AI model and GPU runtime status', async () => {
|
||||
const { getAiModelStatus } = await import('./api');
|
||||
const status = {
|
||||
selected_model: 'sam2',
|
||||
gpu: { available: false, device: 'cpu', name: null, torch_available: true, torch_version: '2.x', cuda_version: null },
|
||||
models: [
|
||||
{ id: 'sam2', label: 'SAM 2', available: true, loaded: false, device: 'cpu', supports: ['point'], message: 'ready', package_available: true, checkpoint_exists: true, checkpoint_path: 'model.pt', python_ok: true, torch_ok: true, cuda_required: false },
|
||||
{ id: 'sam3', label: 'SAM 3', available: false, loaded: false, device: 'unavailable', supports: ['semantic'], message: 'missing runtime', package_available: false, checkpoint_exists: false, checkpoint_path: null, python_ok: false, torch_ok: true, cuda_required: true },
|
||||
],
|
||||
};
|
||||
axiosMock.client.get.mockResolvedValueOnce({ data: status });
|
||||
|
||||
await expect(getAiModelStatus('sam3')).resolves.toEqual(status);
|
||||
expect(axiosMock.client.get).toHaveBeenCalledWith('/api/ai/models/status', {
|
||||
params: { selected_model: 'sam3' },
|
||||
});
|
||||
});
|
||||
});
|
||||
409
src/lib/api.ts
409
src/lib/api.ts
@@ -1,8 +1,9 @@
|
||||
import axios, { AxiosError } from 'axios';
|
||||
import type { Project, Template } from '../store/useStore';
|
||||
import type { AiModelId, Frame, Mask, Project, Template } from '../store/useStore';
|
||||
import { API_BASE_URL } from './config';
|
||||
|
||||
const apiClient = axios.create({
|
||||
baseURL: 'http://192.168.3.11:8000',
|
||||
baseURL: API_BASE_URL,
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
},
|
||||
@@ -40,37 +41,20 @@ export async function login(username: string, password: string): Promise<{ token
|
||||
}
|
||||
|
||||
// Projects
|
||||
export async function getProjects(): Promise<Project[]> {
|
||||
const response = await apiClient.get('/api/projects');
|
||||
return response.data.map((p: any) => ({
|
||||
id: String(p.id),
|
||||
name: p.name,
|
||||
description: p.description,
|
||||
status: p.status,
|
||||
frames: p.frame_count ?? 0,
|
||||
fps: p.original_fps ? `${Math.round(p.original_fps)}FPS` : '30FPS',
|
||||
thumbnail_url: p.thumbnail_url,
|
||||
video_path: p.video_path,
|
||||
source_type: p.source_type,
|
||||
original_fps: p.original_fps,
|
||||
parse_fps: p.parse_fps,
|
||||
createdAt: p.created_at,
|
||||
updatedAt: p.updated_at,
|
||||
}));
|
||||
function normalizeProjectStatus(status?: string): Project['status'] {
|
||||
const value = (status || 'pending').toLowerCase();
|
||||
if (value === 'ready') return 'ready';
|
||||
if (value === 'parsing' || value === 'queued' || value === 'running') return 'parsing';
|
||||
if (value === 'error' || value === 'failed') return 'error';
|
||||
return 'pending';
|
||||
}
|
||||
|
||||
export async function createProject(payload: {
|
||||
name: string;
|
||||
description?: string;
|
||||
parse_fps?: number;
|
||||
}): Promise<Project> {
|
||||
const response = await apiClient.post('/api/projects', payload);
|
||||
const p = response.data;
|
||||
function mapProject(p: any): Project {
|
||||
return {
|
||||
id: String(p.id),
|
||||
name: p.name,
|
||||
description: p.description,
|
||||
status: p.status,
|
||||
status: normalizeProjectStatus(p.status),
|
||||
frames: p.frame_count ?? 0,
|
||||
fps: p.original_fps ? `${Math.round(p.original_fps)}FPS` : '30FPS',
|
||||
thumbnail_url: p.thumbnail_url,
|
||||
@@ -83,9 +67,23 @@ export async function createProject(payload: {
|
||||
};
|
||||
}
|
||||
|
||||
export async function getProjects(): Promise<Project[]> {
|
||||
const response = await apiClient.get('/api/projects');
|
||||
return response.data.map(mapProject);
|
||||
}
|
||||
|
||||
export async function createProject(payload: {
|
||||
name: string;
|
||||
description?: string;
|
||||
parse_fps?: number;
|
||||
}): Promise<Project> {
|
||||
const response = await apiClient.post('/api/projects', payload);
|
||||
return mapProject(response.data);
|
||||
}
|
||||
|
||||
export async function updateProject(id: string, payload: Partial<Project>): Promise<Project> {
|
||||
const response = await apiClient.put(`/api/projects/${id}`, payload);
|
||||
return response.data;
|
||||
const response = await apiClient.patch(`/api/projects/${id}`, payload);
|
||||
return mapProject(response.data);
|
||||
}
|
||||
|
||||
export async function deleteProject(id: string): Promise<void> {
|
||||
@@ -170,26 +168,46 @@ export async function uploadDicomBatch(files: File[], projectId?: string): Promi
|
||||
return response.data;
|
||||
}
|
||||
|
||||
export async function parseMedia(projectId: string): Promise<{
|
||||
project_id: number;
|
||||
frames_extracted: number;
|
||||
status: string;
|
||||
message: string;
|
||||
}> {
|
||||
export interface ProcessingTask {
|
||||
id: number;
|
||||
task_type: string;
|
||||
status: 'queued' | 'running' | 'success' | 'failed' | string;
|
||||
progress: number;
|
||||
message?: string | null;
|
||||
project_id?: number | null;
|
||||
celery_task_id?: string | null;
|
||||
payload?: Record<string, unknown> | null;
|
||||
result?: Record<string, unknown> | null;
|
||||
error?: string | null;
|
||||
created_at: string;
|
||||
started_at?: string | null;
|
||||
finished_at?: string | null;
|
||||
updated_at: string;
|
||||
}
|
||||
|
||||
export async function parseMedia(projectId: string): Promise<ProcessingTask> {
|
||||
const response = await apiClient.post('/api/media/parse', null, {
|
||||
params: { project_id: projectId },
|
||||
});
|
||||
return response.data;
|
||||
}
|
||||
|
||||
// AI Prediction
|
||||
export async function predictMask(payload: {
|
||||
imageUrl: string;
|
||||
export async function getTask(taskId: string | number): Promise<ProcessingTask> {
|
||||
const response = await apiClient.get(`/api/tasks/${taskId}`);
|
||||
return response.data;
|
||||
}
|
||||
|
||||
interface PredictMaskPayload {
|
||||
imageId: string;
|
||||
imageWidth: number;
|
||||
imageHeight: number;
|
||||
model?: AiModelId;
|
||||
points?: { x: number; y: number; type: 'pos' | 'neg' }[];
|
||||
box?: { x1: number; y1: number; x2: number; y2: number };
|
||||
text?: string;
|
||||
modelSize?: string;
|
||||
}): Promise<{
|
||||
}
|
||||
|
||||
interface PredictMaskResult {
|
||||
masks: Array<{
|
||||
id: string;
|
||||
pathData: string;
|
||||
@@ -200,14 +218,319 @@ export async function predictMask(payload: {
|
||||
area: number;
|
||||
confidence: number;
|
||||
}>;
|
||||
}> {
|
||||
const response = await apiClient.post('/api/ai/predict', payload);
|
||||
}
|
||||
|
||||
export interface AiModelStatus {
|
||||
id: AiModelId;
|
||||
label: string;
|
||||
available: boolean;
|
||||
loaded: boolean;
|
||||
device: string;
|
||||
supports: string[];
|
||||
message: string;
|
||||
package_available: boolean;
|
||||
checkpoint_exists: boolean;
|
||||
checkpoint_path?: string | null;
|
||||
python_ok: boolean;
|
||||
torch_ok: boolean;
|
||||
cuda_required: boolean;
|
||||
}
|
||||
|
||||
export interface AiRuntimeStatus {
|
||||
selected_model: AiModelId;
|
||||
gpu: {
|
||||
available: boolean;
|
||||
device: string;
|
||||
name?: string | null;
|
||||
torch_available: boolean;
|
||||
torch_version?: string | null;
|
||||
cuda_version?: string | null;
|
||||
};
|
||||
models: AiModelStatus[];
|
||||
}
|
||||
|
||||
export interface SavedAnnotation {
|
||||
id: number;
|
||||
project_id: number;
|
||||
frame_id: number | null;
|
||||
template_id: number | null;
|
||||
mask_data: {
|
||||
polygons?: number[][][];
|
||||
label?: string;
|
||||
color?: string;
|
||||
class?: {
|
||||
id?: string;
|
||||
name?: string;
|
||||
color?: string;
|
||||
zIndex?: number;
|
||||
category?: string;
|
||||
};
|
||||
} | null;
|
||||
points: number[][] | null;
|
||||
bbox: number[] | null;
|
||||
created_at: string;
|
||||
updated_at: string;
|
||||
}
|
||||
|
||||
export interface SaveAnnotationPayload {
|
||||
project_id: number;
|
||||
frame_id?: number;
|
||||
template_id?: number;
|
||||
mask_data?: {
|
||||
polygons: number[][][];
|
||||
label?: string;
|
||||
color?: string;
|
||||
class?: {
|
||||
id?: string;
|
||||
name?: string;
|
||||
color?: string;
|
||||
zIndex?: number;
|
||||
category?: string;
|
||||
};
|
||||
};
|
||||
points?: number[][];
|
||||
bbox?: number[];
|
||||
}
|
||||
|
||||
export type UpdateAnnotationPayload = Omit<SaveAnnotationPayload, 'project_id' | 'frame_id'>;
|
||||
|
||||
export interface DashboardTask {
|
||||
id: string;
|
||||
task_id?: number;
|
||||
project_id: number;
|
||||
name: string;
|
||||
progress: number;
|
||||
status: string;
|
||||
frame_count: number;
|
||||
updated_at: string | null;
|
||||
}
|
||||
|
||||
export interface DashboardActivity {
|
||||
id: string;
|
||||
kind: 'project' | 'annotation' | 'template' | string;
|
||||
time: string | null;
|
||||
message: string;
|
||||
project: string;
|
||||
}
|
||||
|
||||
export interface DashboardOverview {
|
||||
summary: {
|
||||
project_count: number;
|
||||
parsing_task_count: number;
|
||||
annotation_count: number;
|
||||
frame_count: number;
|
||||
template_count: number;
|
||||
system_load_percent: number;
|
||||
};
|
||||
tasks: DashboardTask[];
|
||||
activity: DashboardActivity[];
|
||||
}
|
||||
|
||||
function clamp01(value: number): number {
|
||||
return Math.min(Math.max(value, 0), 1);
|
||||
}
|
||||
|
||||
function normalizePoint(point: { x: number; y: number }, width: number, height: number): [number, number] {
|
||||
return [
|
||||
clamp01(point.x / Math.max(width, 1)),
|
||||
clamp01(point.y / Math.max(height, 1)),
|
||||
];
|
||||
}
|
||||
|
||||
function polygonToPath(points: number[][], width: number, height: number): string {
|
||||
if (points.length === 0) return '';
|
||||
return points
|
||||
.map(([x, y], index) => {
|
||||
const px = x * width;
|
||||
const py = y * height;
|
||||
return `${index === 0 ? 'M' : 'L'} ${px} ${py}`;
|
||||
})
|
||||
.join(' ')
|
||||
.concat(' Z');
|
||||
}
|
||||
|
||||
function polygonToBbox(points: number[][], width: number, height: number): [number, number, number, number] {
|
||||
const xs = points.map(([x]) => x * width);
|
||||
const ys = points.map(([, y]) => y * height);
|
||||
const minX = Math.min(...xs);
|
||||
const minY = Math.min(...ys);
|
||||
const maxX = Math.max(...xs);
|
||||
const maxY = Math.max(...ys);
|
||||
return [minX, minY, maxX - minX, maxY - minY];
|
||||
}
|
||||
|
||||
function pixelSegmentationToNormalizedPolygons(
|
||||
segmentation: number[][] | undefined,
|
||||
width: number,
|
||||
height: number,
|
||||
): number[][][] {
|
||||
if (!segmentation) return [];
|
||||
return segmentation
|
||||
.map((poly) => {
|
||||
const points: number[][] = [];
|
||||
for (let i = 0; i < poly.length - 1; i += 2) {
|
||||
points.push([
|
||||
clamp01(poly[i] / Math.max(width, 1)),
|
||||
clamp01(poly[i + 1] / Math.max(height, 1)),
|
||||
]);
|
||||
}
|
||||
return points;
|
||||
})
|
||||
.filter((points) => points.length > 0);
|
||||
}
|
||||
|
||||
export function buildAnnotationPayload(
|
||||
projectId: string,
|
||||
mask: Mask,
|
||||
frame: Frame,
|
||||
templateId?: string | null,
|
||||
): SaveAnnotationPayload | null {
|
||||
const polygons = pixelSegmentationToNormalizedPolygons(mask.segmentation, frame.width, frame.height);
|
||||
if (polygons.length === 0) return null;
|
||||
const effectiveTemplateId = mask.templateId || templateId || undefined;
|
||||
const classMetadata = mask.classId || mask.className || mask.classZIndex !== undefined
|
||||
? {
|
||||
id: mask.classId,
|
||||
name: mask.className || mask.label,
|
||||
color: mask.color,
|
||||
zIndex: mask.classZIndex,
|
||||
}
|
||||
: undefined;
|
||||
|
||||
return {
|
||||
project_id: Number(projectId),
|
||||
frame_id: Number(frame.id),
|
||||
template_id: effectiveTemplateId ? Number(effectiveTemplateId) : undefined,
|
||||
mask_data: {
|
||||
polygons,
|
||||
label: mask.label,
|
||||
color: mask.color,
|
||||
...(classMetadata ? { class: classMetadata } : {}),
|
||||
},
|
||||
bbox: mask.bbox
|
||||
? [
|
||||
clamp01(mask.bbox[0] / Math.max(frame.width, 1)),
|
||||
clamp01(mask.bbox[1] / Math.max(frame.height, 1)),
|
||||
clamp01(mask.bbox[2] / Math.max(frame.width, 1)),
|
||||
clamp01(mask.bbox[3] / Math.max(frame.height, 1)),
|
||||
]
|
||||
: undefined,
|
||||
};
|
||||
}
|
||||
|
||||
export function annotationToMask(annotation: SavedAnnotation, frame: Frame): Mask | null {
|
||||
const polygons = annotation.mask_data?.polygons || [];
|
||||
const firstPolygon = polygons[0];
|
||||
if (!firstPolygon || firstPolygon.length === 0) return null;
|
||||
const bbox = polygonToBbox(firstPolygon, frame.width, frame.height);
|
||||
const classMetadata = annotation.mask_data?.class;
|
||||
return {
|
||||
id: `annotation-${annotation.id}`,
|
||||
annotationId: String(annotation.id),
|
||||
frameId: String(annotation.frame_id),
|
||||
templateId: annotation.template_id ? String(annotation.template_id) : undefined,
|
||||
classId: classMetadata?.id,
|
||||
className: classMetadata?.name,
|
||||
classZIndex: classMetadata?.zIndex,
|
||||
saveStatus: 'saved',
|
||||
saved: true,
|
||||
pathData: polygonToPath(firstPolygon, frame.width, frame.height),
|
||||
label: classMetadata?.name || annotation.mask_data?.label || `Annotation ${annotation.id}`,
|
||||
color: classMetadata?.color || annotation.mask_data?.color || '#06b6d4',
|
||||
segmentation: polygons.map((polygon) => polygon.flatMap(([x, y]) => [x * frame.width, y * frame.height])),
|
||||
bbox,
|
||||
area: bbox[2] * bbox[3],
|
||||
};
|
||||
}
|
||||
|
||||
export async function predictMask(payload: PredictMaskPayload): Promise<PredictMaskResult> {
|
||||
let prompt_type: 'point' | 'box' | 'semantic';
|
||||
let prompt_data: unknown;
|
||||
|
||||
if (payload.box) {
|
||||
prompt_type = 'box';
|
||||
prompt_data = [
|
||||
clamp01(payload.box.x1 / Math.max(payload.imageWidth, 1)),
|
||||
clamp01(payload.box.y1 / Math.max(payload.imageHeight, 1)),
|
||||
clamp01(payload.box.x2 / Math.max(payload.imageWidth, 1)),
|
||||
clamp01(payload.box.y2 / Math.max(payload.imageHeight, 1)),
|
||||
];
|
||||
} else if (payload.points && payload.points.length > 0) {
|
||||
prompt_type = 'point';
|
||||
prompt_data = {
|
||||
points: payload.points.map((point) => normalizePoint(point, payload.imageWidth, payload.imageHeight)),
|
||||
labels: payload.points.map((point) => (point.type === 'neg' ? 0 : 1)),
|
||||
};
|
||||
} else {
|
||||
prompt_type = 'semantic';
|
||||
prompt_data = payload.text?.trim() || '';
|
||||
}
|
||||
|
||||
const response = await apiClient.post('/api/ai/predict', {
|
||||
image_id: Number(payload.imageId),
|
||||
prompt_type,
|
||||
prompt_data,
|
||||
model: payload.model || 'sam2',
|
||||
});
|
||||
|
||||
const polygons: number[][][] = response.data.polygons || [];
|
||||
const scores: number[] = response.data.scores || [];
|
||||
return {
|
||||
masks: polygons.map((polygon, index) => {
|
||||
const bbox = polygonToBbox(polygon, payload.imageWidth, payload.imageHeight);
|
||||
return {
|
||||
id: `mask-${payload.imageId}-${Date.now()}-${index}`,
|
||||
pathData: polygonToPath(polygon, payload.imageWidth, payload.imageHeight),
|
||||
label: prompt_type === 'semantic' ? (payload.text?.trim() || 'AI Mask') : 'AI Mask',
|
||||
color: '#06b6d4',
|
||||
segmentation: [polygon.flatMap(([x, y]) => [x * payload.imageWidth, y * payload.imageHeight])],
|
||||
bbox,
|
||||
area: bbox[2] * bbox[3],
|
||||
confidence: scores[index] ?? 0,
|
||||
};
|
||||
}),
|
||||
};
|
||||
}
|
||||
|
||||
export async function getAiModelStatus(selectedModel?: AiModelId): Promise<AiRuntimeStatus> {
|
||||
const response = await apiClient.get('/api/ai/models/status', {
|
||||
params: selectedModel ? { selected_model: selectedModel } : undefined,
|
||||
});
|
||||
return response.data;
|
||||
}
|
||||
|
||||
export async function getProjectAnnotations(projectId: string, frameId?: string): Promise<SavedAnnotation[]> {
|
||||
const response = await apiClient.get('/api/ai/annotations', {
|
||||
params: {
|
||||
project_id: Number(projectId),
|
||||
...(frameId ? { frame_id: Number(frameId) } : {}),
|
||||
},
|
||||
});
|
||||
return response.data;
|
||||
}
|
||||
|
||||
export async function saveAnnotation(payload: SaveAnnotationPayload): Promise<SavedAnnotation> {
|
||||
const response = await apiClient.post('/api/ai/annotate', payload);
|
||||
return response.data;
|
||||
}
|
||||
|
||||
export async function updateAnnotation(annotationId: string, payload: UpdateAnnotationPayload): Promise<SavedAnnotation> {
|
||||
const response = await apiClient.patch(`/api/ai/annotations/${annotationId}`, payload);
|
||||
return response.data;
|
||||
}
|
||||
|
||||
export async function deleteAnnotation(annotationId: string): Promise<void> {
|
||||
await apiClient.delete(`/api/ai/annotations/${annotationId}`);
|
||||
}
|
||||
|
||||
export async function getDashboardOverview(): Promise<DashboardOverview> {
|
||||
const response = await apiClient.get('/api/dashboard/overview');
|
||||
return response.data;
|
||||
}
|
||||
|
||||
// Export
|
||||
export async function exportCoco(projectId: string): Promise<Blob> {
|
||||
const response = await apiClient.get(`/api/export/coco/${projectId}`, {
|
||||
const response = await apiClient.get(`/api/export/${projectId}/coco`, {
|
||||
responseType: 'blob',
|
||||
});
|
||||
return response.data;
|
||||
|
||||
38
src/lib/config.test.ts
Normal file
38
src/lib/config.test.ts
Normal file
@@ -0,0 +1,38 @@
|
||||
import { afterEach, describe, expect, it, vi } from 'vitest';
|
||||
|
||||
describe('frontend runtime config', () => {
|
||||
afterEach(() => {
|
||||
vi.unstubAllEnvs();
|
||||
vi.resetModules();
|
||||
});
|
||||
|
||||
it('prefers explicit VITE_API_BASE_URL and trims trailing slashes', async () => {
|
||||
vi.stubEnv('VITE_API_BASE_URL', 'http://api.example.test:8000///');
|
||||
|
||||
const config = await import('./config');
|
||||
|
||||
expect(config.API_BASE_URL).toBe('http://api.example.test:8000');
|
||||
});
|
||||
|
||||
it('infers the API host from the current browser hostname', async () => {
|
||||
const config = await import('./config');
|
||||
|
||||
expect(config.API_BASE_URL).toBe('http://seg.local:8000');
|
||||
});
|
||||
|
||||
it('derives websocket URL from API URL unless explicitly configured', async () => {
|
||||
vi.stubEnv('VITE_API_BASE_URL', 'https://seg.example.test');
|
||||
|
||||
const config = await import('./config');
|
||||
|
||||
expect(config.WS_PROGRESS_URL).toBe('wss://seg.example.test/ws/progress');
|
||||
});
|
||||
|
||||
it('prefers explicit VITE_WS_PROGRESS_URL', async () => {
|
||||
vi.stubEnv('VITE_WS_PROGRESS_URL', 'ws://custom/ws/progress');
|
||||
|
||||
const config = await import('./config');
|
||||
|
||||
expect(config.WS_PROGRESS_URL).toBe('ws://custom/ws/progress');
|
||||
});
|
||||
});
|
||||
29
src/lib/config.ts
Normal file
29
src/lib/config.ts
Normal file
@@ -0,0 +1,29 @@
|
||||
const DEFAULT_API_BASE_URL = 'http://192.168.3.11:8000';
|
||||
|
||||
function trimTrailingSlash(value: string): string {
|
||||
return value.replace(/\/+$/, '');
|
||||
}
|
||||
|
||||
function inferApiBaseUrl(): string {
|
||||
const envUrl = import.meta.env.VITE_API_BASE_URL;
|
||||
if (envUrl) return trimTrailingSlash(envUrl);
|
||||
|
||||
if (typeof window !== 'undefined' && window.location.hostname) {
|
||||
return `${window.location.protocol}//${window.location.hostname}:8000`;
|
||||
}
|
||||
|
||||
return DEFAULT_API_BASE_URL;
|
||||
}
|
||||
|
||||
export const API_BASE_URL = inferApiBaseUrl();
|
||||
|
||||
function inferWsProgressUrl(): string {
|
||||
const envUrl = import.meta.env.VITE_WS_PROGRESS_URL;
|
||||
if (envUrl) return envUrl;
|
||||
|
||||
const url = new URL('/ws/progress', API_BASE_URL);
|
||||
url.protocol = url.protocol === 'https:' ? 'wss:' : 'ws:';
|
||||
return url.toString();
|
||||
}
|
||||
|
||||
export const WS_PROGRESS_URL = inferWsProgressUrl();
|
||||
15
src/lib/templateSelection.ts
Normal file
15
src/lib/templateSelection.ts
Normal file
@@ -0,0 +1,15 @@
|
||||
import type { Template, TemplateClass } from '../store/useStore';
|
||||
|
||||
export function getActiveTemplate(templates: Template[], activeTemplateId: string | null): Template | null {
|
||||
return templates.find((template) => template.id === activeTemplateId) || templates[0] || null;
|
||||
}
|
||||
|
||||
export function getActiveClass(
|
||||
templates: Template[],
|
||||
activeTemplateId: string | null,
|
||||
activeClassId: string | null,
|
||||
): TemplateClass | null {
|
||||
const template = getActiveTemplate(templates, activeTemplateId);
|
||||
if (!template) return null;
|
||||
return template.classes.find((templateClass) => templateClass.id === activeClassId) || null;
|
||||
}
|
||||
46
src/lib/websocket.test.ts
Normal file
46
src/lib/websocket.test.ts
Normal file
@@ -0,0 +1,46 @@
|
||||
import { afterEach, describe, expect, it, vi } from 'vitest';
|
||||
|
||||
describe('progress websocket client', () => {
|
||||
afterEach(() => {
|
||||
vi.restoreAllMocks();
|
||||
vi.resetModules();
|
||||
vi.unstubAllGlobals();
|
||||
});
|
||||
|
||||
it('connects using the configured URL and reports open state', async () => {
|
||||
const instances: any[] = [];
|
||||
class FakeWebSocket {
|
||||
static CONNECTING = 0;
|
||||
static OPEN = 1;
|
||||
readyState = FakeWebSocket.OPEN;
|
||||
onopen?: () => void;
|
||||
onmessage?: (event: MessageEvent) => void;
|
||||
onclose?: () => void;
|
||||
onerror?: () => void;
|
||||
constructor(public url: string) {
|
||||
instances.push(this);
|
||||
}
|
||||
close = vi.fn();
|
||||
}
|
||||
vi.stubGlobal('WebSocket', FakeWebSocket);
|
||||
|
||||
const { progressWS } = await import('./websocket');
|
||||
progressWS.connect();
|
||||
|
||||
expect(instances[0].url).toContain('/ws/progress');
|
||||
expect(progressWS.isConnected()).toBe(true);
|
||||
});
|
||||
|
||||
it('subscribes and unsubscribes progress callbacks', async () => {
|
||||
const { progressWS } = await import('./websocket');
|
||||
const callback = vi.fn();
|
||||
|
||||
const unsubscribe = progressWS.onProgress(callback);
|
||||
(progressWS as any).callbacks.forEach((cb: any) => cb({ type: 'status', message: 'ok' }));
|
||||
unsubscribe();
|
||||
(progressWS as any).callbacks.forEach((cb: any) => cb({ type: 'status', message: 'again' }));
|
||||
|
||||
expect(callback).toHaveBeenCalledTimes(1);
|
||||
expect(callback).toHaveBeenCalledWith({ type: 'status', message: 'ok' });
|
||||
});
|
||||
});
|
||||
@@ -1,12 +1,18 @@
|
||||
import { WS_PROGRESS_URL } from './config';
|
||||
|
||||
type ProgressCallback = (data: ProgressMessage) => void;
|
||||
|
||||
interface ProgressMessage {
|
||||
type: 'progress' | 'status' | 'error' | 'complete';
|
||||
taskId?: string;
|
||||
task_id?: number;
|
||||
project_id?: number;
|
||||
projectName?: string;
|
||||
filename?: string;
|
||||
progress?: number;
|
||||
status?: string;
|
||||
message?: string;
|
||||
error?: string;
|
||||
timestamp?: string;
|
||||
}
|
||||
|
||||
@@ -21,7 +27,7 @@ class ProgressWebSocket {
|
||||
private shouldCloseAfterOpen = false;
|
||||
private currentInterval = 3000;
|
||||
|
||||
constructor(url = 'ws://192.168.3.11:8000/ws/progress') {
|
||||
constructor(url = WS_PROGRESS_URL) {
|
||||
this.url = url;
|
||||
}
|
||||
|
||||
|
||||
56
src/store/useStore.test.ts
Normal file
56
src/store/useStore.test.ts
Normal file
@@ -0,0 +1,56 @@
|
||||
import { beforeEach, describe, expect, it } from 'vitest';
|
||||
import { resetStore } from '../test/storeTestUtils';
|
||||
import { useStore } from './useStore';
|
||||
|
||||
describe('useStore', () => {
|
||||
beforeEach(() => {
|
||||
resetStore();
|
||||
});
|
||||
|
||||
it('stores and clears auth state with localStorage', () => {
|
||||
useStore.getState().login('token-1');
|
||||
|
||||
expect(useStore.getState().isAuthenticated).toBe(true);
|
||||
expect(useStore.getState().token).toBe('token-1');
|
||||
expect(localStorage.getItem('token')).toBe('token-1');
|
||||
|
||||
useStore.getState().logout();
|
||||
|
||||
expect(useStore.getState().isAuthenticated).toBe(false);
|
||||
expect(useStore.getState().projects).toEqual([]);
|
||||
expect(useStore.getState().frames).toEqual([]);
|
||||
expect(localStorage.getItem('token')).toBeNull();
|
||||
});
|
||||
|
||||
it('manages projects, frames, masks, annotations and templates', () => {
|
||||
const project = { id: '1', name: 'Project', status: 'ready' as const };
|
||||
useStore.getState().addProject(project);
|
||||
useStore.getState().updateProject({ ...project, name: 'Updated' });
|
||||
useStore.getState().setCurrentProject(project);
|
||||
useStore.getState().setFrames([{ id: 'f1', projectId: '1', index: 0, url: '/f1.jpg', width: 640, height: 360 }]);
|
||||
useStore.getState().setCurrentFrame(0);
|
||||
useStore.getState().addMask({ id: 'm1', frameId: 'f1', pathData: 'M 0 0 Z', label: 'mask', color: '#fff' });
|
||||
useStore.getState().updateMask('m1', { label: 'updated mask', saveStatus: 'dirty' });
|
||||
useStore.getState().addAnnotation({ id: 'a1', frameId: 'f1', type: 'mask', points: [], label: 'ann', color: '#fff' });
|
||||
useStore.getState().addTemplate({ id: 't1', name: 'Template', classes: [], rules: [] });
|
||||
useStore.getState().updateTemplate({ id: 't1', name: 'Template 2', classes: [], rules: [] });
|
||||
useStore.getState().setActiveClass({ id: 'c1', name: '胆囊', color: '#ff0000', zIndex: 10 });
|
||||
|
||||
expect(useStore.getState().projects[0].name).toBe('Updated');
|
||||
expect(useStore.getState().currentProject?.id).toBe('1');
|
||||
expect(useStore.getState().frames).toHaveLength(1);
|
||||
expect(useStore.getState().currentFrameIndex).toBe(0);
|
||||
expect(useStore.getState().masks[0]).toEqual(expect.objectContaining({ label: 'updated mask', saveStatus: 'dirty' }));
|
||||
expect(useStore.getState().annotations).toHaveLength(1);
|
||||
expect(useStore.getState().templates[0].name).toBe('Template 2');
|
||||
expect(useStore.getState().activeClassId).toBe('c1');
|
||||
|
||||
useStore.getState().removeAnnotation('a1');
|
||||
useStore.getState().clearMasks();
|
||||
useStore.getState().removeTemplate('t1');
|
||||
|
||||
expect(useStore.getState().annotations).toEqual([]);
|
||||
expect(useStore.getState().masks).toEqual([]);
|
||||
expect(useStore.getState().templates).toEqual([]);
|
||||
});
|
||||
});
|
||||
@@ -4,7 +4,7 @@ export interface Project {
|
||||
id: string;
|
||||
name: string;
|
||||
description?: string;
|
||||
status: 'Ready' | 'Parsing' | 'Error';
|
||||
status: 'pending' | 'parsing' | 'ready' | 'error';
|
||||
fps?: string;
|
||||
frames?: number;
|
||||
thumbnail?: string;
|
||||
@@ -17,6 +17,8 @@ export interface Project {
|
||||
updatedAt?: string;
|
||||
}
|
||||
|
||||
export type AiModelId = 'sam2' | 'sam3';
|
||||
|
||||
export interface Frame {
|
||||
id: string;
|
||||
projectId: string;
|
||||
@@ -42,6 +44,13 @@ export interface Annotation {
|
||||
export interface Mask {
|
||||
id: string;
|
||||
frameId: string;
|
||||
annotationId?: string;
|
||||
templateId?: string;
|
||||
classId?: string;
|
||||
className?: string;
|
||||
classZIndex?: number;
|
||||
saveStatus?: 'draft' | 'saved' | 'dirty' | 'saving' | 'error';
|
||||
saved?: boolean;
|
||||
pathData: string;
|
||||
label: string;
|
||||
color: string;
|
||||
@@ -96,24 +105,32 @@ export interface AppState {
|
||||
// Workspace
|
||||
activeModule: string;
|
||||
activeTool: string;
|
||||
aiModel: AiModelId;
|
||||
frames: Frame[];
|
||||
currentFrameIndex: number;
|
||||
annotations: Annotation[];
|
||||
masks: Mask[];
|
||||
setActiveModule: (module: string) => void;
|
||||
setActiveTool: (tool: string) => void;
|
||||
setAiModel: (model: AiModelId) => void;
|
||||
setFrames: (frames: Frame[]) => void;
|
||||
setCurrentFrame: (index: number) => void;
|
||||
addAnnotation: (annotation: Annotation) => void;
|
||||
addMask: (mask: Mask) => void;
|
||||
updateMask: (id: string, updates: Partial<Mask>) => void;
|
||||
setMasks: (masks: Mask[]) => void;
|
||||
clearMasks: () => void;
|
||||
removeAnnotation: (id: string) => void;
|
||||
|
||||
// Templates
|
||||
templates: Template[];
|
||||
activeTemplateId: string | null;
|
||||
activeClassId: string | null;
|
||||
activeClass: TemplateClass | null;
|
||||
setTemplates: (templates: Template[]) => void;
|
||||
setActiveTemplateId: (id: string | null) => void;
|
||||
setActiveClassId: (id: string | null) => void;
|
||||
setActiveClass: (templateClass: TemplateClass | null) => void;
|
||||
addTemplate: (template: Template) => void;
|
||||
updateTemplate: (template: Template) => void;
|
||||
removeTemplate: (id: string) => void;
|
||||
@@ -144,6 +161,9 @@ export const useStore = create<AppState>((set) => ({
|
||||
frames: [],
|
||||
annotations: [],
|
||||
masks: [],
|
||||
activeTemplateId: null,
|
||||
activeClassId: null,
|
||||
activeClass: null,
|
||||
});
|
||||
},
|
||||
|
||||
@@ -162,18 +182,25 @@ export const useStore = create<AppState>((set) => ({
|
||||
// Workspace
|
||||
activeModule: 'workspace',
|
||||
activeTool: 'move',
|
||||
aiModel: 'sam2',
|
||||
frames: [],
|
||||
currentFrameIndex: 0,
|
||||
annotations: [],
|
||||
masks: [],
|
||||
setActiveModule: (activeModule: string) => set({ activeModule }),
|
||||
setActiveTool: (activeTool: string) => set({ activeTool }),
|
||||
setAiModel: (aiModel: AiModelId) => set({ aiModel }),
|
||||
setFrames: (frames: Frame[]) => set({ frames }),
|
||||
setCurrentFrame: (currentFrameIndex: number) => set({ currentFrameIndex }),
|
||||
addAnnotation: (annotation: Annotation) =>
|
||||
set((state) => ({ annotations: [...state.annotations, annotation] })),
|
||||
addMask: (mask: Mask) =>
|
||||
set((state) => ({ masks: [...state.masks, mask] })),
|
||||
updateMask: (id: string, updates: Partial<Mask>) =>
|
||||
set((state) => ({
|
||||
masks: state.masks.map((mask) => (mask.id === id ? { ...mask, ...updates } : mask)),
|
||||
})),
|
||||
setMasks: (masks: Mask[]) => set({ masks }),
|
||||
clearMasks: () => set({ masks: [] }),
|
||||
removeAnnotation: (id: string) =>
|
||||
set((state) => ({
|
||||
@@ -183,8 +210,15 @@ export const useStore = create<AppState>((set) => ({
|
||||
// Templates
|
||||
templates: [],
|
||||
activeTemplateId: null,
|
||||
activeClassId: null,
|
||||
activeClass: null,
|
||||
setTemplates: (templates: Template[]) => set({ templates }),
|
||||
setActiveTemplateId: (activeTemplateId: string | null) => set({ activeTemplateId }),
|
||||
setActiveClassId: (activeClassId: string | null) => set({ activeClassId }),
|
||||
setActiveClass: (activeClass: TemplateClass | null) => set({
|
||||
activeClass,
|
||||
activeClassId: activeClass?.id || null,
|
||||
}),
|
||||
addTemplate: (template: Template) =>
|
||||
set((state) => ({ templates: [...state.templates, template] })),
|
||||
updateTemplate: (template: Template) =>
|
||||
|
||||
66
src/test/setup.tsx
Normal file
66
src/test/setup.tsx
Normal file
@@ -0,0 +1,66 @@
|
||||
import React from 'react';
|
||||
import { afterEach, vi } from 'vitest';
|
||||
import { cleanup } from '@testing-library/react';
|
||||
import '@testing-library/jest-dom/vitest';
|
||||
|
||||
afterEach(() => {
|
||||
cleanup();
|
||||
localStorage.clear();
|
||||
});
|
||||
|
||||
vi.stubGlobal('alert', vi.fn());
|
||||
vi.stubGlobal('confirm', vi.fn(() => true));
|
||||
URL.createObjectURL = vi.fn(() => 'blob:mock-url');
|
||||
URL.revokeObjectURL = vi.fn();
|
||||
HTMLAnchorElement.prototype.click = vi.fn();
|
||||
|
||||
function makeStageEvent(x = 120, y = 80) {
|
||||
const stage = {
|
||||
getPointerPosition: () => ({ x, y }),
|
||||
getRelativePointerPosition: () => ({ x, y }),
|
||||
scaleX: () => 1,
|
||||
x: () => 0,
|
||||
y: () => 0,
|
||||
};
|
||||
|
||||
return {
|
||||
evt: { preventDefault: vi.fn(), deltaY: -1 },
|
||||
target: {
|
||||
getStage: () => stage,
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
vi.mock('react-konva', () => ({
|
||||
Stage: ({ children, onClick, onMouseDown, onMouseUp, onMouseMove, onWheel }: any) => (
|
||||
<div
|
||||
data-testid="konva-stage"
|
||||
onClick={() => onClick?.(makeStageEvent())}
|
||||
onMouseDown={() => onMouseDown?.(makeStageEvent())}
|
||||
onMouseUp={() => onMouseUp?.(makeStageEvent(260, 200))}
|
||||
onMouseMove={() => onMouseMove?.(makeStageEvent(180, 120))}
|
||||
onWheel={() => onWheel?.(makeStageEvent())}
|
||||
>
|
||||
{children}
|
||||
</div>
|
||||
),
|
||||
Layer: ({ children }: any) => <div data-testid="konva-layer">{children}</div>,
|
||||
Group: ({ children }: any) => <div data-testid="konva-group">{children}</div>,
|
||||
Image: ({ image }: any) => <img data-testid="konva-image" alt="" src={image?.src || ''} />,
|
||||
Circle: (props: any) => <span data-testid="konva-circle" data-fill={props.fill} />,
|
||||
Rect: (props: any) => <span data-testid="konva-rect" data-width={props.width} />,
|
||||
Path: (props: any) => <span data-testid="konva-path" data-path={props.data} data-fill={props.fill} />,
|
||||
}));
|
||||
|
||||
vi.mock('use-image', () => ({
|
||||
default: (src: string) => [
|
||||
{
|
||||
src,
|
||||
width: 640,
|
||||
height: 360,
|
||||
naturalWidth: 640,
|
||||
naturalHeight: 360,
|
||||
},
|
||||
'loaded',
|
||||
],
|
||||
}));
|
||||
23
src/test/storeTestUtils.ts
Normal file
23
src/test/storeTestUtils.ts
Normal file
@@ -0,0 +1,23 @@
|
||||
import { useStore } from '../store/useStore';
|
||||
|
||||
export function resetStore() {
|
||||
useStore.setState({
|
||||
isAuthenticated: false,
|
||||
token: null,
|
||||
projects: [],
|
||||
currentProject: null,
|
||||
activeModule: 'workspace',
|
||||
activeTool: 'move',
|
||||
aiModel: 'sam2',
|
||||
frames: [],
|
||||
currentFrameIndex: 0,
|
||||
annotations: [],
|
||||
masks: [],
|
||||
templates: [],
|
||||
activeTemplateId: null,
|
||||
activeClassId: null,
|
||||
activeClass: null,
|
||||
isLoading: false,
|
||||
error: null,
|
||||
});
|
||||
}
|
||||
6
src/vite-env.d.ts
vendored
Normal file
6
src/vite-env.d.ts
vendored
Normal file
@@ -0,0 +1,6 @@
|
||||
/// <reference types="vite/client" />
|
||||
|
||||
interface ImportMetaEnv {
|
||||
readonly VITE_API_BASE_URL?: string;
|
||||
readonly VITE_WS_PROGRESS_URL?: string;
|
||||
}
|
||||
@@ -12,7 +12,7 @@ echo " 语义分割系统全栈启动"
|
||||
echo "========================================"
|
||||
|
||||
# 1. 检查 PostgreSQL
|
||||
echo "[1/5] 检查 PostgreSQL..."
|
||||
echo "[1/6] 检查 PostgreSQL..."
|
||||
if ! pg_isready -q; then
|
||||
echo "Wkmgc" | sudo -S systemctl start postgresql
|
||||
sleep 1
|
||||
@@ -20,7 +20,7 @@ fi
|
||||
pg_isready && echo " ✓ PostgreSQL 就绪"
|
||||
|
||||
# 2. 检查 Redis
|
||||
echo "[2/5] 检查 Redis..."
|
||||
echo "[2/6] 检查 Redis..."
|
||||
if ! redis-cli ping > /dev/null 2>&1; then
|
||||
echo "Wkmgc" | sudo -S systemctl start redis-server
|
||||
sleep 1
|
||||
@@ -28,7 +28,7 @@ fi
|
||||
redis-cli ping && echo " ✓ Redis 就绪"
|
||||
|
||||
# 3. 检查 MinIO
|
||||
echo "[3/5] 检查 MinIO..."
|
||||
echo "[3/6] 检查 MinIO..."
|
||||
if ! curl -s http://localhost:9000/minio/health/live > /dev/null; then
|
||||
nohup minio server /home/wkmgc/minio_data --console-address :9001 > /tmp/minio.log 2>&1 &
|
||||
sleep 3
|
||||
@@ -36,7 +36,7 @@ fi
|
||||
curl -s http://localhost:9000/minio/health/live > /dev/null && echo " ✓ MinIO 就绪 (http://localhost:9001)"
|
||||
|
||||
# 4. 启动 FastAPI 后端
|
||||
echo "[4/5] 启动 FastAPI 后端..."
|
||||
echo "[4/6] 启动 FastAPI 后端..."
|
||||
source /home/wkmgc/miniconda3/etc/profile.d/conda.sh
|
||||
conda activate "$CONDA_ENV"
|
||||
cd "$PROJECT_DIR/backend"
|
||||
@@ -44,8 +44,15 @@ nohup uvicorn main:app --host 0.0.0.0 --port 8000 --reload > /tmp/fastapi.log 2>
|
||||
sleep 2
|
||||
echo " ✓ FastAPI 已启动 (http://localhost:8000/docs)"
|
||||
|
||||
# 5. 启动前端
|
||||
echo "[5/5] 启动前端..."
|
||||
# 5. 启动 Celery Worker
|
||||
echo "[5/6] 启动 Celery Worker..."
|
||||
cd "$PROJECT_DIR/backend"
|
||||
nohup celery -A celery_app:celery_app worker --loglevel=info --concurrency=1 > /tmp/celery.log 2>&1 &
|
||||
sleep 2
|
||||
echo " ✓ Celery Worker 已启动"
|
||||
|
||||
# 6. 启动前端
|
||||
echo "[6/6] 启动前端..."
|
||||
cd "$PROJECT_DIR"
|
||||
nohup npm start > /tmp/frontend.log 2>&1 &
|
||||
sleep 2
|
||||
@@ -61,6 +68,7 @@ echo "MinIO: http://localhost:9001"
|
||||
echo ""
|
||||
echo "日志文件:"
|
||||
echo " FastAPI: /tmp/fastapi.log"
|
||||
echo " Celery: /tmp/celery.log"
|
||||
echo " 前端: /tmp/frontend.log"
|
||||
echo " MinIO: /tmp/minio.log"
|
||||
echo "========================================"
|
||||
|
||||
24
vitest.config.ts
Normal file
24
vitest.config.ts
Normal file
@@ -0,0 +1,24 @@
|
||||
import react from '@vitejs/plugin-react';
|
||||
import path from 'path';
|
||||
import { defineConfig } from 'vitest/config';
|
||||
|
||||
export default defineConfig({
|
||||
plugins: [react()],
|
||||
resolve: {
|
||||
alias: {
|
||||
'@': path.resolve(__dirname, '.'),
|
||||
},
|
||||
},
|
||||
test: {
|
||||
environment: 'jsdom',
|
||||
environmentOptions: {
|
||||
jsdom: {
|
||||
url: 'http://seg.local:3000',
|
||||
},
|
||||
},
|
||||
globals: true,
|
||||
setupFiles: './src/test/setup.tsx',
|
||||
include: ['src/**/*.{test,spec}.{ts,tsx}'],
|
||||
css: false,
|
||||
},
|
||||
});
|
||||
Reference in New Issue
Block a user