feat: 打通全栈标注闭环、异步拆帧与模型状态
后端能力: - 新增 Celery app、worker task、ProcessingTask 模型、/api/tasks 查询接口和 media_task_runner,将 /api/media/parse 改为创建后台任务并由 worker 执行 FFmpeg/OpenCV/pydicom 拆帧。 - 新增 Redis 进度事件模块和 FastAPI Redis pub/sub 订阅,将 worker 任务进度广播到 /ws/progress;Dashboard 后端概览接口改为聚合 projects/frames/annotations/templates/processing_tasks。 - 统一项目状态为 pending/parsing/ready/error,新增共享 status 常量,并让前端兼容归一化旧状态值。 - 扩展 AI 后端:新增 SAM registry、SAM2 真实运行状态、SAM3 状态检测与文本语义推理适配入口,以及 /api/ai/models/status GPU/模型状态接口。 - 补齐标注保存/更新/删除、COCO/PNG mask 导出相关后端契约和模板 mapping_rules 打包/解包行为。 前端能力: - 新增运行时 API/WS 地址推导配置,前端 API 封装对齐 FastAPI 路由、字段映射、任务轮询、标注归档、导出下载和 AI 预测响应转换。 - Dashboard 改为读取 /api/dashboard/overview,并订阅 WebSocket progress/complete/error/status 更新解析队列和实时流转记录。 - 项目库导入视频/DICOM 后创建项目、上传媒体、触发异步解析并刷新真实项目列表。 - 工作区加载真实帧、无帧时触发解析任务、回显已保存标注、保存未归档 mask、更新 dirty mask、清空当前帧后端标注、导出 COCO JSON。 - Canvas 支持当前帧点/框提示调用后端 AI、渲染推理/已保存 mask、应用模板分类并维护保存状态计数;时间轴按项目 fps 播放。 - AI 页面新增 SAM2/SAM3 模型选择,预测请求携带 model;侧边栏和工作区新增真实 GPU/SAM 状态徽标。 - 模板库和本体面板接入真实模板 CRUD、分类编辑、拖拽排序、JSON 导入、默认腹腔镜分类和本地自定义分类选择。 测试与文档: - 新增 Vitest 配置、前端测试 setup、API/config/websocket/store/组件测试,覆盖登录、项目库、Dashboard、Canvas、工作区、模型状态、时间轴、本体和模板库。 - 新增 pytest 后端测试夹具和 auth/projects/templates/media/AI/export/dashboard/tasks/progress 测试,使用 SQLite、fake MinIO、fake SAM registry 和 Redis monkeypatch 隔离外部服务。 - 新增 doc/ 文档结构,冻结当前需求、设计、接口契约、测试计划、前端逐元素审计、实现地图和后续实施计划,并同步更新 README 与 AGENTS。 验证: - conda run -n seg_server pytest backend/tests:27 passed。 - npm run test:run:54 passed。 - npm run lint、npm run build、compileall、git diff --check 均通过;Vite 仅提示大 chunk 警告。
This commit is contained in:
43
src/components/AISegmentation.test.tsx
Normal file
43
src/components/AISegmentation.test.tsx
Normal file
@@ -0,0 +1,43 @@
|
||||
import { fireEvent, render, screen } from '@testing-library/react';
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest';
|
||||
import { resetStore } from '../test/storeTestUtils';
|
||||
import { useStore } from '../store/useStore';
|
||||
import { AISegmentation } from './AISegmentation';
|
||||
|
||||
const apiMock = vi.hoisted(() => ({
|
||||
getAiModelStatus: vi.fn(),
|
||||
predictMask: vi.fn(),
|
||||
}));
|
||||
|
||||
vi.mock('../lib/api', () => ({
|
||||
getAiModelStatus: apiMock.getAiModelStatus,
|
||||
predictMask: apiMock.predictMask,
|
||||
}));
|
||||
|
||||
describe('AISegmentation', () => {
|
||||
beforeEach(() => {
|
||||
resetStore();
|
||||
vi.clearAllMocks();
|
||||
useStore.setState({
|
||||
frames: [{ id: 'frame-1', projectId: 'project-1', index: 0, url: '/frame.jpg', width: 640, height: 360 }],
|
||||
});
|
||||
apiMock.getAiModelStatus.mockResolvedValue({
|
||||
selected_model: 'sam2',
|
||||
gpu: { available: true, device: 'cuda', name: 'RTX 4090', torch_available: true },
|
||||
models: [
|
||||
{ id: 'sam2', label: 'SAM 2', available: true, loaded: false, device: 'cuda', supports: ['point', 'box'], message: 'SAM 2 ready', package_available: true, checkpoint_exists: true, python_ok: true, torch_ok: true, cuda_required: false },
|
||||
{ id: 'sam3', label: 'SAM 3', available: false, loaded: false, device: 'cuda', supports: ['semantic'], message: 'SAM 3 missing runtime', package_available: false, checkpoint_exists: false, python_ok: false, torch_ok: true, cuda_required: true },
|
||||
],
|
||||
});
|
||||
});
|
||||
|
||||
it('lets the user choose SAM3 for subsequent predictions', async () => {
|
||||
render(<AISegmentation onSendToWorkspace={vi.fn()} />);
|
||||
|
||||
const sam3Button = (await screen.findByText('SAM3')).closest('button')!;
|
||||
fireEvent.click(sam3Button);
|
||||
|
||||
expect(useStore.getState().aiModel).toBe('sam3');
|
||||
expect(await screen.findByText('SAM 3 missing runtime')).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
@@ -1,11 +1,11 @@
|
||||
import React, { useState, useCallback } from 'react';
|
||||
import React, { useState, useCallback, useEffect } from 'react';
|
||||
import { Target, PlusCircle, MinusCircle, SquareDashed, Sparkles, SendToBack, Image as ImageIcon, Undo, Redo, Loader2 } from 'lucide-react';
|
||||
import { cn } from '../lib/utils';
|
||||
import { Stage, Layer, Image as KonvaImage, Circle, Path, Group } from 'react-konva';
|
||||
import useImage from 'use-image';
|
||||
import { OntologyInspector } from './OntologyInspector';
|
||||
import { useStore } from '../store/useStore';
|
||||
import { predictMask } from '../lib/api';
|
||||
import { getAiModelStatus, predictMask, type AiRuntimeStatus } from '../lib/api';
|
||||
|
||||
interface AISegmentationProps {
|
||||
onSendToWorkspace: () => void;
|
||||
@@ -17,9 +17,15 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
|
||||
const masks = useStore((state) => state.masks);
|
||||
const addMask = useStore((state) => state.addMask);
|
||||
const clearMasks = useStore((state) => state.clearMasks);
|
||||
const frames = useStore((state) => state.frames);
|
||||
const currentFrameIndex = useStore((state) => state.currentFrameIndex);
|
||||
const activeTemplateId = useStore((state) => state.activeTemplateId);
|
||||
const activeClass = useStore((state) => state.activeClass);
|
||||
const aiModel = useStore((state) => state.aiModel);
|
||||
const setAiModel = useStore((state) => state.setAiModel);
|
||||
|
||||
const [modelSize, setModelSize] = useState('vit_l');
|
||||
const [semanticText, setSemanticText] = useState('');
|
||||
const [modelStatus, setModelStatus] = useState<AiRuntimeStatus | null>(null);
|
||||
const [autoDeleteBg, setAutoDeleteBg] = useState(true);
|
||||
const [cropMode, setCropMode] = useState(false);
|
||||
const [isInferencing, setIsInferencing] = useState(false);
|
||||
@@ -29,10 +35,29 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
|
||||
const [position, setPosition] = useState({ x: 0, y: 0 });
|
||||
const [points, setPoints] = useState<{ x: number, y: number, type: 'pos'|'neg' }[]>([]);
|
||||
const [cursorPos, setCursorPos] = useState({ x: 0, y: 0 });
|
||||
const [image] = useImage('https://images.unsplash.com/photo-1549317661-bd32c8ce0be2?q=80&w=2070&auto=format&fit=crop');
|
||||
const currentFrame = frames[currentFrameIndex] || null;
|
||||
const previewUrl = currentFrame?.url || 'https://images.unsplash.com/photo-1549317661-bd32c8ce0be2?q=80&w=2070&auto=format&fit=crop';
|
||||
const [image] = useImage(previewUrl);
|
||||
const frameMasks = currentFrame ? masks.filter((mask) => mask.frameId === currentFrame.id) : masks;
|
||||
const selectedModelStatus = modelStatus?.models.find((model) => model.id === aiModel);
|
||||
const modelCanInfer = selectedModelStatus?.available ?? true;
|
||||
|
||||
const effectiveTool = storeActiveTool;
|
||||
|
||||
useEffect(() => {
|
||||
let cancelled = false;
|
||||
getAiModelStatus(aiModel)
|
||||
.then((status) => {
|
||||
if (!cancelled) setModelStatus(status);
|
||||
})
|
||||
.catch(() => {
|
||||
if (!cancelled) setModelStatus(null);
|
||||
});
|
||||
return () => {
|
||||
cancelled = true;
|
||||
};
|
||||
}, [aiModel]);
|
||||
|
||||
const handleWheel = (e: any) => {
|
||||
e.evt.preventDefault();
|
||||
const scaleBy = 1.1;
|
||||
@@ -63,22 +88,44 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
|
||||
|
||||
const runInference = useCallback(async () => {
|
||||
if (points.length === 0 && !semanticText.trim()) return;
|
||||
if (!currentFrame?.id) {
|
||||
console.warn('AI inference skipped: no project frame is selected');
|
||||
return;
|
||||
}
|
||||
|
||||
const imageWidth = currentFrame.width || image?.naturalWidth || image?.width || 0;
|
||||
const imageHeight = currentFrame.height || image?.naturalHeight || image?.height || 0;
|
||||
if (imageWidth <= 0 || imageHeight <= 0) {
|
||||
console.warn('AI inference skipped: active frame dimensions are unavailable');
|
||||
return;
|
||||
}
|
||||
|
||||
setIsInferencing(true);
|
||||
try {
|
||||
const result = await predictMask({
|
||||
imageUrl: 'https://images.unsplash.com/photo-1549317661-bd32c8ce0be2?q=80&w=2070&auto=format&fit=crop',
|
||||
imageId: currentFrame.id,
|
||||
imageWidth,
|
||||
imageHeight,
|
||||
model: aiModel,
|
||||
points: points.map((p) => ({ x: p.x, y: p.y, type: p.type })),
|
||||
text: semanticText.trim() || undefined,
|
||||
modelSize,
|
||||
});
|
||||
|
||||
result.masks.forEach((m) => {
|
||||
const label = activeClass?.name || m.label;
|
||||
const color = activeClass?.color || m.color;
|
||||
addMask({
|
||||
id: m.id,
|
||||
frameId: 'frame-ai-1',
|
||||
frameId: currentFrame.id,
|
||||
templateId: activeTemplateId || undefined,
|
||||
classId: activeClass?.id,
|
||||
className: activeClass?.name,
|
||||
classZIndex: activeClass?.zIndex,
|
||||
saveStatus: 'draft',
|
||||
saved: false,
|
||||
pathData: m.pathData,
|
||||
label: m.label,
|
||||
color: m.color,
|
||||
label,
|
||||
color,
|
||||
segmentation: m.segmentation,
|
||||
bbox: m.bbox,
|
||||
area: m.area,
|
||||
@@ -89,7 +136,7 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
|
||||
} finally {
|
||||
setIsInferencing(false);
|
||||
}
|
||||
}, [points, semanticText, modelSize, addMask]);
|
||||
}, [activeClass, activeTemplateId, addMask, aiModel, currentFrame?.height, currentFrame?.id, currentFrame?.width, image?.height, image?.naturalHeight, image?.naturalWidth, image?.width, points, semanticText]);
|
||||
|
||||
const handleStageClick = (e: any) => {
|
||||
if (effectiveTool === 'move') return;
|
||||
@@ -117,17 +164,26 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
|
||||
{/* Model Select */}
|
||||
<div>
|
||||
<h3 className="text-[10px] font-bold text-gray-500 uppercase tracking-widest mb-3">视觉基础模型选型</h3>
|
||||
<div className="bg-[#111] border border-white/5 flex p-1 rounded-lg">
|
||||
{['vit_b', 'vit_l', 'vit_h'].map(m => (
|
||||
<div className="bg-[#111] border border-white/5 grid grid-cols-2 gap-1 p-1 rounded-lg">
|
||||
{(modelStatus?.models || [
|
||||
{ id: 'sam2' as const, label: 'SAM 2', available: true, message: '正在读取 SAM 2 状态' },
|
||||
{ id: 'sam3' as const, label: 'SAM 3', available: false, message: '正在读取 SAM 3 状态' },
|
||||
]).map((m) => (
|
||||
<button
|
||||
key={m}
|
||||
className={cn("flex-1 text-xs py-2 rounded-md transition-colors text-center uppercase tracking-wider font-mono", modelSize === m ? "bg-white/10 text-white font-medium shadow-sm" : "text-gray-500 hover:text-gray-300 hover:bg-white/5")}
|
||||
onClick={() => setModelSize(m)}
|
||||
key={m.id}
|
||||
className={cn("text-xs py-2 rounded-md transition-colors text-center uppercase tracking-wider font-mono", aiModel === m.id ? "bg-white/10 text-white font-medium shadow-sm" : "text-gray-500 hover:text-gray-300 hover:bg-white/5")}
|
||||
onClick={() => setAiModel(m.id)}
|
||||
title={m.message}
|
||||
>
|
||||
{m.split('_')[1]}
|
||||
{m.label.replace(' ', '')}
|
||||
<span className={cn("ml-1", m.available ? "text-emerald-400" : "text-amber-400")}>●</span>
|
||||
</button>
|
||||
))}
|
||||
</div>
|
||||
<div className="mt-2 text-[10px] text-gray-500 leading-relaxed">
|
||||
<div>{selectedModelStatus?.message || '正在读取模型状态...'}</div>
|
||||
<div>GPU: {modelStatus?.gpu.available ? `${modelStatus.gpu.name || 'CUDA'} 可用` : '不可用或未检测到 CUDA'}</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Prompt Tools */}
|
||||
@@ -206,16 +262,16 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
|
||||
<div className="p-6 bg-[#0a0a0a] border-t border-white/5 shrink-0 flex flex-col gap-3">
|
||||
<button
|
||||
onClick={runInference}
|
||||
disabled={isInferencing}
|
||||
disabled={isInferencing || !currentFrame || !modelCanInfer}
|
||||
className={cn(
|
||||
"w-full py-3.5 rounded-lg flex items-center justify-center gap-2 transition-all shadow-lg font-medium tracking-wide text-xs uppercase",
|
||||
isInferencing
|
||||
isInferencing || !currentFrame || !modelCanInfer
|
||||
? "bg-cyan-500/50 text-black/70 cursor-not-allowed"
|
||||
: "bg-cyan-500 hover:bg-cyan-400 text-black shadow-cyan-500/20 hover:shadow-cyan-500/40"
|
||||
)}
|
||||
>
|
||||
{isInferencing ? <Loader2 size={16} className="animate-spin" /> : <Sparkles size={16} />}
|
||||
{isInferencing ? '推理中...' : '执行高精度语义分割'}
|
||||
{isInferencing ? '推理中...' : modelCanInfer ? '执行高精度语义分割' : '当前模型不可用'}
|
||||
</button>
|
||||
<button
|
||||
onClick={onSendToWorkspace}
|
||||
@@ -231,7 +287,7 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
|
||||
<header className="h-16 border-b border-white/5 bg-[#111] flex items-center justify-between px-6 shrink-0">
|
||||
<div className="flex flex-col">
|
||||
<h2 className="text-sm font-semibold tracking-wide text-white">模型端推理侧可视化 (Visualizer)</h2>
|
||||
<span className="text-[10px] text-gray-500 uppercase tracking-widest font-mono">SAM 3 内核级动态即时渲染</span>
|
||||
<span className="text-[10px] text-gray-500 uppercase tracking-widest font-mono">{aiModel.toUpperCase()} 动态推理渲染</span>
|
||||
</div>
|
||||
<div className="flex items-center gap-4">
|
||||
<button className="w-8 h-8 rounded text-gray-400 hover:bg-white/5 hover:text-white flex items-center justify-center transition-colors" title="撤销操作 (Ctrl+Z)">
|
||||
@@ -276,7 +332,7 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
|
||||
)}
|
||||
|
||||
{/* AI Returned Masks */}
|
||||
{masks.map((mask) => (
|
||||
{frameMasks.map((mask) => (
|
||||
<Group key={mask.id} opacity={0.45}>
|
||||
<Path
|
||||
data={mask.pathData}
|
||||
@@ -309,7 +365,7 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
|
||||
<div className="absolute bottom-4 left-4 flex gap-4 text-[10px] font-mono text-gray-500 pointer-events-none">
|
||||
<span>光标坐标: {cursorPos.x.toFixed(2)}, {cursorPos.y.toFixed(2)}</span>
|
||||
<span>缩放比率: {(scale * 100).toFixed(0)}%</span>
|
||||
<span>遮罩数: {masks.length}</span>
|
||||
<span>遮罩数: {frameMasks.length}</span>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
130
src/components/CanvasArea.test.tsx
Normal file
130
src/components/CanvasArea.test.tsx
Normal file
@@ -0,0 +1,130 @@
|
||||
import { fireEvent, render, screen, waitFor } from '@testing-library/react';
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest';
|
||||
import { resetStore } from '../test/storeTestUtils';
|
||||
import { useStore } from '../store/useStore';
|
||||
import { CanvasArea } from './CanvasArea';
|
||||
|
||||
const apiMock = vi.hoisted(() => ({
|
||||
predictMask: vi.fn(),
|
||||
}));
|
||||
|
||||
vi.mock('../lib/api', () => ({
|
||||
predictMask: apiMock.predictMask,
|
||||
}));
|
||||
|
||||
describe('CanvasArea', () => {
|
||||
const frame = { id: 'frame-1', projectId: 'project-1', index: 0, url: '/frame.jpg', width: 640, height: 360 };
|
||||
|
||||
beforeEach(() => {
|
||||
resetStore();
|
||||
vi.clearAllMocks();
|
||||
});
|
||||
|
||||
it('calls AI prediction with the active frame when a point prompt is placed', async () => {
|
||||
useStore.setState({
|
||||
activeTemplateId: '2',
|
||||
activeClass: { id: 'c1', name: '胆囊', color: '#ff0000', zIndex: 20 },
|
||||
activeClassId: 'c1',
|
||||
});
|
||||
apiMock.predictMask.mockResolvedValueOnce({
|
||||
masks: [
|
||||
{
|
||||
id: 'mask-1',
|
||||
pathData: 'M 0 0 L 10 0 L 10 10 Z',
|
||||
label: 'AI Mask',
|
||||
color: '#06b6d4',
|
||||
segmentation: [[0, 0, 10, 0, 10, 10]],
|
||||
bbox: [0, 0, 10, 10],
|
||||
area: 100,
|
||||
},
|
||||
],
|
||||
});
|
||||
|
||||
render(<CanvasArea activeTool="point_pos" frame={frame} />);
|
||||
fireEvent.click(screen.getByTestId('konva-stage'));
|
||||
|
||||
await waitFor(() => expect(apiMock.predictMask).toHaveBeenCalledWith({
|
||||
imageId: 'frame-1',
|
||||
imageWidth: 640,
|
||||
imageHeight: 360,
|
||||
model: 'sam2',
|
||||
points: [{ x: 120, y: 80, type: 'pos' }],
|
||||
box: undefined,
|
||||
}));
|
||||
expect(useStore.getState().masks[0]).toEqual(expect.objectContaining({
|
||||
id: 'mask-1',
|
||||
frameId: 'frame-1',
|
||||
pathData: 'M 0 0 L 10 0 L 10 10 Z',
|
||||
templateId: '2',
|
||||
classId: 'c1',
|
||||
className: '胆囊',
|
||||
classZIndex: 20,
|
||||
label: '胆囊',
|
||||
color: '#ff0000',
|
||||
saveStatus: 'draft',
|
||||
}));
|
||||
});
|
||||
|
||||
it('renders only masks that belong to the current frame', () => {
|
||||
useStore.setState({
|
||||
masks: [
|
||||
{ id: 'm1', frameId: 'frame-1', pathData: 'M 0 0 Z', label: 'A', color: '#fff' },
|
||||
{ id: 'm2', frameId: 'frame-2', pathData: 'M 1 1 Z', label: 'B', color: '#000' },
|
||||
],
|
||||
});
|
||||
|
||||
render(<CanvasArea activeTool="move" frame={frame} />);
|
||||
|
||||
expect(screen.getAllByTestId('konva-path')).toHaveLength(1);
|
||||
expect(screen.getByText('遮罩数: 1')).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it('applies the selected class to current-frame masks and marks saved masks dirty', () => {
|
||||
useStore.setState({
|
||||
activeTemplateId: '2',
|
||||
activeClass: { id: 'c1', name: '胆囊', color: '#ff0000', zIndex: 20 },
|
||||
activeClassId: 'c1',
|
||||
masks: [
|
||||
{
|
||||
id: 'm1',
|
||||
frameId: 'frame-1',
|
||||
annotationId: '99',
|
||||
pathData: 'M 0 0 Z',
|
||||
label: '旧标签',
|
||||
color: '#06b6d4',
|
||||
saved: true,
|
||||
saveStatus: 'saved',
|
||||
},
|
||||
],
|
||||
});
|
||||
|
||||
render(<CanvasArea activeTool="move" frame={frame} />);
|
||||
fireEvent.click(screen.getByRole('button', { name: '应用分类' }));
|
||||
|
||||
expect(useStore.getState().masks[0]).toEqual(expect.objectContaining({
|
||||
templateId: '2',
|
||||
classId: 'c1',
|
||||
className: '胆囊',
|
||||
classZIndex: 20,
|
||||
label: '胆囊',
|
||||
color: '#ff0000',
|
||||
saveStatus: 'dirty',
|
||||
saved: false,
|
||||
}));
|
||||
});
|
||||
|
||||
it('delegates clear to the workspace handler so saved annotations can be deleted', () => {
|
||||
const onClearMasks = vi.fn();
|
||||
useStore.setState({
|
||||
masks: [
|
||||
{ id: 'm1', frameId: 'frame-1', pathData: 'M 0 0 Z', label: 'A', color: '#fff' },
|
||||
],
|
||||
});
|
||||
|
||||
render(<CanvasArea activeTool="move" frame={frame} onClearMasks={onClearMasks} />);
|
||||
fireEvent.click(screen.getByRole('button', { name: '清空遮罩' }));
|
||||
|
||||
expect(onClearMasks).toHaveBeenCalled();
|
||||
expect(useStore.getState().masks).toHaveLength(1);
|
||||
});
|
||||
});
|
||||
@@ -3,14 +3,15 @@ import { Stage, Layer, Image as KonvaImage, Circle, Rect, Path, Group } from 're
|
||||
import useImage from 'use-image';
|
||||
import { useStore } from '../store/useStore';
|
||||
import { predictMask } from '../lib/api';
|
||||
import { cn } from '../lib/utils';
|
||||
import type { Frame } from '../store/useStore';
|
||||
|
||||
interface CanvasAreaProps {
|
||||
activeTool: string;
|
||||
frameUrl: string;
|
||||
frame: Frame | null;
|
||||
onClearMasks?: () => void;
|
||||
}
|
||||
|
||||
export function CanvasArea({ activeTool, frameUrl }: CanvasAreaProps) {
|
||||
export function CanvasArea({ activeTool, frame, onClearMasks }: CanvasAreaProps) {
|
||||
const containerRef = useRef<HTMLDivElement>(null);
|
||||
const [stageSize, setStageSize] = useState({ width: 800, height: 600 });
|
||||
const [scale, setScale] = useState(1);
|
||||
@@ -24,13 +25,20 @@ export function CanvasArea({ activeTool, frameUrl }: CanvasAreaProps) {
|
||||
const masks = useStore((state) => state.masks);
|
||||
const addMask = useStore((state) => state.addMask);
|
||||
const clearMasks = useStore((state) => state.clearMasks);
|
||||
const setMasks = useStore((state) => state.setMasks);
|
||||
const storeActiveTool = useStore((state) => state.activeTool);
|
||||
const setActiveTool = useStore((state) => state.setActiveTool);
|
||||
const aiModel = useStore((state) => state.aiModel);
|
||||
const activeTemplateId = useStore((state) => state.activeTemplateId);
|
||||
const activeClass = useStore((state) => state.activeClass);
|
||||
|
||||
const effectiveTool = activeTool || storeActiveTool;
|
||||
|
||||
// Load the actual frame image
|
||||
const [image] = useImage(frameUrl || '');
|
||||
const [image] = useImage(frame?.url || '');
|
||||
const frameMasks = masks.filter((mask) => mask.frameId === frame?.id);
|
||||
const savedMaskCount = frameMasks.filter((mask) => mask.saveStatus === 'saved' || mask.saved).length;
|
||||
const draftMaskCount = frameMasks.filter((mask) => !mask.annotationId).length;
|
||||
const dirtyMaskCount = frameMasks.filter((mask) => mask.saveStatus === 'dirty').length;
|
||||
|
||||
useEffect(() => {
|
||||
const handleResize = () => {
|
||||
@@ -85,21 +93,44 @@ export function CanvasArea({ activeTool, frameUrl }: CanvasAreaProps) {
|
||||
};
|
||||
|
||||
const runInference = useCallback(async (promptPoints?: typeof points, promptBox?: { x1: number, y1: number, x2: number, y2: number }) => {
|
||||
if (!frame?.id) {
|
||||
console.warn('Inference skipped: no active frame');
|
||||
return;
|
||||
}
|
||||
|
||||
const imageWidth = frame.width || image?.naturalWidth || image?.width || 0;
|
||||
const imageHeight = frame.height || image?.naturalHeight || image?.height || 0;
|
||||
if (imageWidth <= 0 || imageHeight <= 0) {
|
||||
console.warn('Inference skipped: active frame dimensions are unavailable');
|
||||
return;
|
||||
}
|
||||
|
||||
setIsInferencing(true);
|
||||
try {
|
||||
const result = await predictMask({
|
||||
imageUrl: frameUrl || '',
|
||||
imageId: frame.id,
|
||||
imageWidth,
|
||||
imageHeight,
|
||||
model: aiModel,
|
||||
points: promptPoints?.map((p) => ({ x: p.x, y: p.y, type: p.type })),
|
||||
box: promptBox,
|
||||
});
|
||||
|
||||
result.masks.forEach((m) => {
|
||||
const label = activeClass?.name || m.label;
|
||||
const color = activeClass?.color || m.color;
|
||||
addMask({
|
||||
id: m.id,
|
||||
frameId: 'frame-1',
|
||||
frameId: frame.id,
|
||||
templateId: activeTemplateId || undefined,
|
||||
classId: activeClass?.id,
|
||||
className: activeClass?.name,
|
||||
classZIndex: activeClass?.zIndex,
|
||||
saveStatus: 'draft',
|
||||
saved: false,
|
||||
pathData: m.pathData,
|
||||
label: m.label,
|
||||
color: m.color,
|
||||
label,
|
||||
color,
|
||||
segmentation: m.segmentation,
|
||||
bbox: m.bbox,
|
||||
area: m.area,
|
||||
@@ -110,7 +141,33 @@ export function CanvasArea({ activeTool, frameUrl }: CanvasAreaProps) {
|
||||
} finally {
|
||||
setIsInferencing(false);
|
||||
}
|
||||
}, [addMask]);
|
||||
}, [activeClass, activeTemplateId, addMask, aiModel, frame?.height, frame?.id, frame?.width, image?.height, image?.naturalHeight, image?.naturalWidth, image?.width]);
|
||||
|
||||
const handleApplyActiveClass = () => {
|
||||
if (!frame?.id || !activeClass) return;
|
||||
setMasks(masks.map((mask) => {
|
||||
if (mask.frameId !== frame.id) return mask;
|
||||
return {
|
||||
...mask,
|
||||
templateId: activeTemplateId || mask.templateId,
|
||||
classId: activeClass.id,
|
||||
className: activeClass.name,
|
||||
classZIndex: activeClass.zIndex,
|
||||
label: activeClass.name,
|
||||
color: activeClass.color,
|
||||
saveStatus: mask.annotationId ? 'dirty' : 'draft',
|
||||
saved: Boolean(mask.annotationId) ? false : mask.saved,
|
||||
};
|
||||
}));
|
||||
};
|
||||
|
||||
const handleClearMasks = () => {
|
||||
if (onClearMasks) {
|
||||
onClearMasks();
|
||||
return;
|
||||
}
|
||||
clearMasks();
|
||||
};
|
||||
|
||||
const handleStageMouseDown = (e: any) => {
|
||||
if (effectiveTool === 'box_select') {
|
||||
@@ -199,7 +256,7 @@ export function CanvasArea({ activeTool, frameUrl }: CanvasAreaProps) {
|
||||
)}
|
||||
|
||||
{/* AI Returned Masks */}
|
||||
{masks.map((mask) => (
|
||||
{frameMasks.map((mask) => (
|
||||
<Group key={mask.id} opacity={0.5}>
|
||||
<Path
|
||||
data={mask.pathData}
|
||||
@@ -248,16 +305,29 @@ export function CanvasArea({ activeTool, frameUrl }: CanvasAreaProps) {
|
||||
<span>光标: {cursorPos.x.toFixed(2)}, {cursorPos.y.toFixed(2)}</span>
|
||||
<span>当前图层树: OBJECT_VEHICLE_01</span>
|
||||
<span>缩放比: {(scale * 100).toFixed(0)}%</span>
|
||||
<span>遮罩数: {masks.length}</span>
|
||||
<span>遮罩数: {frameMasks.length}</span>
|
||||
<span>已保存: {savedMaskCount}</span>
|
||||
<span>未保存: {draftMaskCount}</span>
|
||||
<span>待更新: {dirtyMaskCount}</span>
|
||||
</div>
|
||||
|
||||
{masks.length > 0 && (
|
||||
<button
|
||||
onClick={clearMasks}
|
||||
className="absolute bottom-4 right-4 text-xs bg-red-500/10 hover:bg-red-500/20 text-red-400 border border-red-500/20 px-3 py-1.5 rounded transition-colors"
|
||||
>
|
||||
清空遮罩
|
||||
</button>
|
||||
{frameMasks.length > 0 && (
|
||||
<div className="absolute bottom-4 right-4 flex gap-2">
|
||||
{activeClass && (
|
||||
<button
|
||||
onClick={handleApplyActiveClass}
|
||||
className="text-xs bg-cyan-500/10 hover:bg-cyan-500/20 text-cyan-300 border border-cyan-500/20 px-3 py-1.5 rounded transition-colors"
|
||||
>
|
||||
应用分类
|
||||
</button>
|
||||
)}
|
||||
<button
|
||||
onClick={handleClearMasks}
|
||||
className="text-xs bg-red-500/10 hover:bg-red-500/20 text-red-400 border border-red-500/20 px-3 py-1.5 rounded transition-colors"
|
||||
>
|
||||
清空遮罩
|
||||
</button>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
|
||||
115
src/components/Dashboard.test.tsx
Normal file
115
src/components/Dashboard.test.tsx
Normal file
@@ -0,0 +1,115 @@
|
||||
import { act, render, screen, waitFor } from '@testing-library/react';
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest';
|
||||
import { Dashboard } from './Dashboard';
|
||||
|
||||
const apiMock = vi.hoisted(() => ({
|
||||
getDashboardOverview: vi.fn(),
|
||||
}));
|
||||
|
||||
const wsMock = vi.hoisted(() => {
|
||||
const state = {
|
||||
callback: undefined as undefined | ((data: any) => void),
|
||||
connected: false,
|
||||
};
|
||||
return {
|
||||
state,
|
||||
progressWS: {
|
||||
connect: vi.fn(() => { state.connected = true; }),
|
||||
disconnect: vi.fn(() => { state.connected = false; }),
|
||||
isConnected: vi.fn(() => state.connected),
|
||||
onProgress: vi.fn((cb: (data: any) => void) => {
|
||||
state.callback = cb;
|
||||
return vi.fn();
|
||||
}),
|
||||
},
|
||||
};
|
||||
});
|
||||
|
||||
vi.mock('../lib/websocket', () => ({
|
||||
progressWS: wsMock.progressWS,
|
||||
}));
|
||||
|
||||
vi.mock('../lib/api', () => ({
|
||||
getDashboardOverview: apiMock.getDashboardOverview,
|
||||
}));
|
||||
|
||||
describe('Dashboard', () => {
|
||||
beforeEach(() => {
|
||||
vi.useRealTimers();
|
||||
vi.clearAllMocks();
|
||||
wsMock.state.connected = false;
|
||||
wsMock.state.callback = undefined;
|
||||
apiMock.getDashboardOverview.mockResolvedValue({
|
||||
summary: {
|
||||
project_count: 2,
|
||||
parsing_task_count: 1,
|
||||
annotation_count: 5,
|
||||
frame_count: 100,
|
||||
template_count: 3,
|
||||
system_load_percent: 12,
|
||||
},
|
||||
tasks: [
|
||||
{
|
||||
id: 'project-1',
|
||||
project_id: 1,
|
||||
name: '真实项目.mp4',
|
||||
progress: 60,
|
||||
status: 'pending',
|
||||
frame_count: 10,
|
||||
updated_at: '2026-05-01T00:00:00Z',
|
||||
},
|
||||
],
|
||||
activity: [
|
||||
{
|
||||
id: 'activity-1',
|
||||
kind: 'project',
|
||||
time: '2026-05-01T00:00:00Z',
|
||||
message: '项目状态: pending',
|
||||
project: '真实项目.mp4',
|
||||
},
|
||||
],
|
||||
});
|
||||
});
|
||||
|
||||
it('loads dashboard stats, tasks, and activity from the backend overview endpoint', async () => {
|
||||
render(<Dashboard />);
|
||||
|
||||
await waitFor(() => expect(apiMock.getDashboardOverview).toHaveBeenCalled());
|
||||
expect(screen.getByText('项目总数')).toBeInTheDocument();
|
||||
expect(screen.getByText('已存标注')).toBeInTheDocument();
|
||||
expect(screen.getByText('真实项目.mp4')).toBeInTheDocument();
|
||||
expect(screen.getByText('项目状态: pending')).toBeInTheDocument();
|
||||
expect(screen.queryByText('City_Driving_Dataset_004.mp4')).not.toBeInTheDocument();
|
||||
});
|
||||
|
||||
it('connects to the progress stream and updates progress tasks', async () => {
|
||||
render(<Dashboard />);
|
||||
|
||||
await waitFor(() => expect(wsMock.progressWS.connect).toHaveBeenCalled());
|
||||
|
||||
act(() => {
|
||||
wsMock.state.callback?.({
|
||||
type: 'progress',
|
||||
taskId: 'task-1',
|
||||
projectName: 'demo.mp4',
|
||||
progress: 44,
|
||||
status: '正在截取帧',
|
||||
});
|
||||
});
|
||||
|
||||
expect(await screen.findByText('demo.mp4')).toBeInTheDocument();
|
||||
expect(screen.getByText('44%')).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it('adds activity logs for complete and status messages', async () => {
|
||||
render(<Dashboard />);
|
||||
|
||||
act(() => {
|
||||
wsMock.state.callback?.({ type: 'status', message: 'Progress stream active' });
|
||||
wsMock.state.callback?.({ type: 'complete', taskId: '1', filename: 'done.mp4' });
|
||||
});
|
||||
|
||||
await waitFor(() => expect(screen.getByText('Progress stream active')).toBeInTheDocument());
|
||||
expect(screen.getByText('解析完成: done.mp4')).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
@@ -2,30 +2,68 @@ import React, { useState, useEffect } from 'react';
|
||||
import { Activity, Clock, Folders, CheckCircle2, Loader2 } from 'lucide-react';
|
||||
import { progressWS, type ProgressMessage } from '../lib/websocket';
|
||||
import { cn } from '../lib/utils';
|
||||
import { getDashboardOverview, type DashboardActivity, type DashboardOverview, type DashboardTask } from '../lib/api';
|
||||
|
||||
interface QueueTask {
|
||||
id: string;
|
||||
name: string;
|
||||
progress: number;
|
||||
status: string;
|
||||
}
|
||||
const emptySummary: DashboardOverview['summary'] = {
|
||||
project_count: 0,
|
||||
parsing_task_count: 0,
|
||||
annotation_count: 0,
|
||||
frame_count: 0,
|
||||
template_count: 0,
|
||||
system_load_percent: 0,
|
||||
};
|
||||
|
||||
export function Dashboard() {
|
||||
const [tasks, setTasks] = useState<QueueTask[]>([
|
||||
{ id: '1', name: 'City_Driving_Dataset_004.mp4', progress: 85, status: '正在截取帧 (30fps)' },
|
||||
{ id: '2', name: 'Pedestrian_Night_Vision_02.mkv', progress: 32, status: '正在截取帧 (60fps)' },
|
||||
{ id: '3', name: 'Drone_Mapping_Sector_7.avi', progress: 0, status: '队列排队等待中' },
|
||||
]);
|
||||
const [summary, setSummary] = useState<DashboardOverview['summary']>(emptySummary);
|
||||
const [tasks, setTasks] = useState<DashboardTask[]>([]);
|
||||
const [isConnected, setIsConnected] = useState(false);
|
||||
const [activityLog, setActivityLog] = useState<Array<{ time: string; message: string; project?: string }>>([
|
||||
{ time: '10 分钟前', message: '语义归档完成 54 帧', project: 'Highway_Data' },
|
||||
{ time: '25 分钟前', message: '项目解析开始', project: 'City_Driving_Dataset_004' },
|
||||
{ time: '1 小时前', message: '模板库更新: Cityscapes_v2', project: '系统' },
|
||||
{ time: '2 小时前', message: 'AI 推理完成 12 个实例', project: 'Nav_Cam_Left' },
|
||||
]);
|
||||
const [activityLog, setActivityLog] = useState<DashboardActivity[]>([]);
|
||||
const [isLoading, setIsLoading] = useState(true);
|
||||
const [loadError, setLoadError] = useState('');
|
||||
|
||||
useEffect(() => {
|
||||
let cancelled = false;
|
||||
|
||||
const loadOverview = () => {
|
||||
getDashboardOverview()
|
||||
.then((overview) => {
|
||||
if (cancelled) return;
|
||||
setSummary(overview.summary);
|
||||
setTasks((prev) => {
|
||||
if (prev.length === 0) return overview.tasks;
|
||||
const overviewIds = new Set(overview.tasks.map((task) => task.id));
|
||||
const wsOnly = prev.filter((task) => !task.id.startsWith('task-') && !overviewIds.has(task.id) && task.progress < 100);
|
||||
return [...overview.tasks, ...wsOnly];
|
||||
});
|
||||
setActivityLog((prev) => {
|
||||
if (prev.length === 0) return overview.activity;
|
||||
const byId = new Map(prev.map((item) => [item.id, item]));
|
||||
overview.activity.forEach((item) => byId.set(item.id, item));
|
||||
return Array.from(byId.values()).slice(0, 10);
|
||||
});
|
||||
setLoadError('');
|
||||
})
|
||||
.catch((err) => {
|
||||
console.error('Failed to load dashboard overview:', err);
|
||||
if (!cancelled) setLoadError('Dashboard 数据加载失败');
|
||||
})
|
||||
.finally(() => {
|
||||
if (!cancelled) setIsLoading(false);
|
||||
});
|
||||
};
|
||||
|
||||
loadOverview();
|
||||
const overviewInterval = setInterval(loadOverview, 5000);
|
||||
|
||||
return () => {
|
||||
cancelled = true;
|
||||
clearInterval(overviewInterval);
|
||||
};
|
||||
}, []);
|
||||
|
||||
useEffect(() => {
|
||||
let mounted = true;
|
||||
const taskTitle = (data: ProgressMessage) => data.filename || data.projectName || data.taskId || '后台任务';
|
||||
const timer = setTimeout(() => {
|
||||
if (mounted) progressWS.connect();
|
||||
}, 500);
|
||||
@@ -34,7 +72,7 @@ export function Dashboard() {
|
||||
if (!mounted) return;
|
||||
setIsConnected(progressWS.isConnected());
|
||||
|
||||
if (data.type === 'progress' && data.taskId && data.filename) {
|
||||
if (data.type === 'progress' && data.taskId) {
|
||||
setTasks((prev) => {
|
||||
const exists = prev.find((t) => t.id === data.taskId);
|
||||
if (exists) {
|
||||
@@ -48,9 +86,12 @@ export function Dashboard() {
|
||||
...prev,
|
||||
{
|
||||
id: data.taskId!,
|
||||
name: data.filename!,
|
||||
project_id: data.project_id ?? Number(data.task_id || 0),
|
||||
name: taskTitle(data),
|
||||
progress: data.progress ?? 0,
|
||||
status: data.status ?? '处理中',
|
||||
frame_count: 0,
|
||||
updated_at: new Date().toISOString(),
|
||||
},
|
||||
];
|
||||
});
|
||||
@@ -63,7 +104,7 @@ export function Dashboard() {
|
||||
)
|
||||
);
|
||||
setActivityLog((prev) => [
|
||||
{ time: '刚刚', message: `解析完成: ${data.filename || data.taskId}`, project: '系统' },
|
||||
{ id: `ws-complete-${Date.now()}`, kind: 'websocket', time: new Date().toISOString(), message: data.message || `解析完成: ${taskTitle(data)}`, project: data.projectName || '系统' },
|
||||
...prev.slice(0, 9),
|
||||
]);
|
||||
}
|
||||
@@ -71,14 +112,18 @@ export function Dashboard() {
|
||||
if (data.type === 'error' && data.taskId) {
|
||||
setTasks((prev) =>
|
||||
prev.map((t) =>
|
||||
t.id === data.taskId ? { ...t, status: `错误: ${data.message || '未知错误'}` } : t
|
||||
t.id === data.taskId ? { ...t, progress: data.progress ?? t.progress, status: `错误: ${data.error || data.message || '未知错误'}` } : t
|
||||
)
|
||||
);
|
||||
setActivityLog((prev) => [
|
||||
{ id: `ws-error-${Date.now()}`, kind: 'websocket', time: new Date().toISOString(), message: data.message || `解析失败: ${taskTitle(data)}`, project: data.projectName || '系统' },
|
||||
...prev.slice(0, 9),
|
||||
]);
|
||||
}
|
||||
|
||||
if (data.type === 'status') {
|
||||
setActivityLog((prev) => [
|
||||
{ time: '刚刚', message: data.message || '状态更新', project: '系统' },
|
||||
{ id: `ws-status-${Date.now()}`, kind: 'websocket', time: new Date().toISOString(), message: data.message || '状态更新', project: '系统' },
|
||||
...prev.slice(0, 9),
|
||||
]);
|
||||
}
|
||||
@@ -97,12 +142,24 @@ export function Dashboard() {
|
||||
}, []);
|
||||
|
||||
const stats = [
|
||||
{ label: '运行中项目', value: '14', icon: Folders, color: 'text-blue-400', bg: 'bg-blue-400/10' },
|
||||
{ label: '排队处理任务', value: tasks.length.toString(), icon: Clock, color: 'text-orange-400', bg: 'bg-orange-400/10' },
|
||||
{ label: '已归档批次', value: '128', icon: CheckCircle2, color: 'text-emerald-400', bg: 'bg-emerald-400/10' },
|
||||
{ label: '系统负载', value: '78%', icon: Activity, color: 'text-cyan-400', bg: 'bg-cyan-400/10' },
|
||||
{ label: '项目总数', value: summary.project_count.toString(), icon: Folders, color: 'text-blue-400', bg: 'bg-blue-400/10' },
|
||||
{ label: '处理任务', value: summary.parsing_task_count.toString(), icon: Clock, color: 'text-orange-400', bg: 'bg-orange-400/10' },
|
||||
{ label: '已存标注', value: summary.annotation_count.toString(), icon: CheckCircle2, color: 'text-emerald-400', bg: 'bg-emerald-400/10' },
|
||||
{ label: '系统负载', value: `${summary.system_load_percent}%`, icon: Activity, color: 'text-cyan-400', bg: 'bg-cyan-400/10' },
|
||||
];
|
||||
|
||||
function formatActivityTime(value: string | null): string {
|
||||
if (!value) return '未知时间';
|
||||
const date = new Date(value);
|
||||
if (Number.isNaN(date.getTime())) return value;
|
||||
return date.toLocaleString('zh-CN', {
|
||||
month: '2-digit',
|
||||
day: '2-digit',
|
||||
hour: '2-digit',
|
||||
minute: '2-digit',
|
||||
});
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="p-8 w-full h-full overflow-y-auto bg-[#0a0a0a]">
|
||||
<header className="mb-8">
|
||||
@@ -119,6 +176,7 @@ export function Dashboard() {
|
||||
</div>
|
||||
</div>
|
||||
<p className="text-gray-400 text-sm mt-1">系统全局数据吞吐状态与所有接入项目进度实时洞察驾驶舱。</p>
|
||||
{loadError && <p className="text-red-400 text-xs mt-2">{loadError}</p>}
|
||||
</header>
|
||||
|
||||
<div className="grid grid-cols-1 md:grid-cols-2 lg:grid-cols-4 gap-4 mb-8">
|
||||
@@ -140,8 +198,11 @@ export function Dashboard() {
|
||||
|
||||
<div className="grid grid-cols-1 lg:grid-cols-3 gap-6">
|
||||
<div className="lg:col-span-2 bg-[#111] border border-white/5 rounded-xl p-6 min-h-[400px]">
|
||||
<h2 className="text-sm font-medium text-gray-400 uppercase tracking-widest mb-6">解析队列 (FFmpeg 挂起任务)</h2>
|
||||
<h2 className="text-sm font-medium text-gray-400 uppercase tracking-widest mb-6">解析队列 (后台任务)</h2>
|
||||
<div className="space-y-4">
|
||||
{isLoading && (
|
||||
<div className="text-sm text-gray-500 text-center py-12">正在读取后端 Dashboard 数据...</div>
|
||||
)}
|
||||
{tasks.map((task) => (
|
||||
<div key={task.id} className="bg-[#0d0d0d] border border-white/5 p-4 rounded-lg">
|
||||
<div className="flex justify-between items-center mb-2">
|
||||
@@ -152,7 +213,7 @@ export function Dashboard() {
|
||||
<div className="h-full bg-gradient-to-r from-cyan-600 to-cyan-400 rounded-full transition-all duration-500" style={{ width: `${task.progress}%` }} />
|
||||
</div>
|
||||
<div className="text-xs text-gray-500 flex items-center gap-2">
|
||||
{task.status === '已完成' ? (
|
||||
{task.status === '已完成' || task.progress >= 100 ? (
|
||||
<CheckCircle2 size={12} className="text-emerald-400" />
|
||||
) : task.status.includes('错误') ? (
|
||||
<span className="text-red-400">●</span>
|
||||
@@ -160,10 +221,11 @@ export function Dashboard() {
|
||||
<Loader2 size={12} className="text-cyan-400 animate-spin" />
|
||||
)}
|
||||
{task.status}
|
||||
<span className="text-gray-600">帧: {task.frame_count}</span>
|
||||
</div>
|
||||
</div>
|
||||
))}
|
||||
{tasks.length === 0 && (
|
||||
{!isLoading && tasks.length === 0 && (
|
||||
<div className="text-sm text-gray-500 text-center py-12">当前无处理任务</div>
|
||||
)}
|
||||
</div>
|
||||
@@ -172,16 +234,22 @@ export function Dashboard() {
|
||||
<div className="bg-[#111] border border-white/5 rounded-xl p-6 min-h-[400px]">
|
||||
<h2 className="text-sm font-medium text-gray-400 uppercase tracking-widest mb-6">近期实时流转记录</h2>
|
||||
<div className="space-y-6 relative before:absolute before:inset-0 before:ml-[11px] before:-translate-x-px md:before:mx-auto md:before:translate-x-0 before:h-full before:w-0.5 before:bg-gradient-to-b before:from-transparent before:via-white/10 before:to-transparent">
|
||||
{activityLog.map((log, i) => (
|
||||
<div key={i} className="relative flex items-center justify-between md:justify-normal md:odd:flex-row-reverse group is-active">
|
||||
{isLoading && (
|
||||
<div className="text-sm text-gray-500 text-center py-12">正在读取近期流转记录...</div>
|
||||
)}
|
||||
{activityLog.map((log) => (
|
||||
<div key={log.id} className="relative flex items-center justify-between md:justify-normal md:odd:flex-row-reverse group is-active">
|
||||
<div className="flex items-center justify-center w-6 h-6 rounded-full border border-white/10 bg-[#111] group-[.is-active]:bg-cyan-500 group-[.is-active]:border-cyan-400 text-slate-500 group-[.is-active]:text-black shadow shrink-0 md:order-1 md:group-odd:-translate-x-1/2 md:group-even:translate-x-1/2 z-10" />
|
||||
<div className="w-[calc(100%-4rem)] md:w-[calc(50%-2.5rem)] bg-[#0d0d0d] p-3 rounded border border-white/5">
|
||||
<div className="text-xs text-gray-400 mb-1">{log.time}</div>
|
||||
<div className="text-xs text-gray-400 mb-1">{formatActivityTime(log.time)}</div>
|
||||
<div className="text-sm font-medium text-gray-200">{log.message}</div>
|
||||
<div className="text-xs text-gray-500">归属项目: {log.project}</div>
|
||||
</div>
|
||||
</div>
|
||||
))}
|
||||
{!isLoading && activityLog.length === 0 && (
|
||||
<div className="text-sm text-gray-500 text-center py-12">暂无近期流转记录</div>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
62
src/components/FrameTimeline.test.tsx
Normal file
62
src/components/FrameTimeline.test.tsx
Normal file
@@ -0,0 +1,62 @@
|
||||
import { act, fireEvent, render, screen } from '@testing-library/react';
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest';
|
||||
import { resetStore } from '../test/storeTestUtils';
|
||||
import { useStore } from '../store/useStore';
|
||||
import { FrameTimeline } from './FrameTimeline';
|
||||
|
||||
describe('FrameTimeline', () => {
|
||||
beforeEach(() => {
|
||||
resetStore();
|
||||
vi.useRealTimers();
|
||||
});
|
||||
|
||||
it('renders empty state when no frames are loaded', () => {
|
||||
render(<FrameTimeline />);
|
||||
|
||||
expect(screen.getByText('暂无帧数据')).toBeInTheDocument();
|
||||
expect(screen.getByText('0')).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it('changes the current frame through thumbnails and range input', () => {
|
||||
useStore.setState({
|
||||
frames: [
|
||||
{ id: 'f1', projectId: 'p1', index: 0, url: '/1.jpg', width: 640, height: 360 },
|
||||
{ id: 'f2', projectId: 'p1', index: 1, url: '/2.jpg', width: 640, height: 360 },
|
||||
{ id: 'f3', projectId: 'p1', index: 2, url: '/3.jpg', width: 640, height: 360 },
|
||||
],
|
||||
});
|
||||
|
||||
render(<FrameTimeline />);
|
||||
fireEvent.click(screen.getByAltText('frame-1'));
|
||||
expect(useStore.getState().currentFrameIndex).toBe(1);
|
||||
|
||||
fireEvent.change(screen.getByRole('slider'), { target: { value: '3' } });
|
||||
expect(useStore.getState().currentFrameIndex).toBe(2);
|
||||
});
|
||||
|
||||
it('plays forward using the project parse fps and stops at the end', () => {
|
||||
vi.useFakeTimers();
|
||||
useStore.setState({
|
||||
currentProject: { id: 'p1', name: 'P', status: 'ready', parse_fps: 10 },
|
||||
frames: [
|
||||
{ id: 'f1', projectId: 'p1', index: 0, url: '/1.jpg', width: 640, height: 360 },
|
||||
{ id: 'f2', projectId: 'p1', index: 1, url: '/2.jpg', width: 640, height: 360 },
|
||||
],
|
||||
});
|
||||
|
||||
const { container } = render(<FrameTimeline />);
|
||||
fireEvent.click(container.querySelector('button')!);
|
||||
|
||||
act(() => {
|
||||
vi.advanceTimersByTime(100);
|
||||
});
|
||||
|
||||
expect(useStore.getState().currentFrameIndex).toBe(1);
|
||||
|
||||
act(() => {
|
||||
vi.advanceTimersByTime(100);
|
||||
});
|
||||
|
||||
expect(screen.getByText('播放序列 (F5)')).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
@@ -1,16 +1,42 @@
|
||||
import React, { useState } from 'react';
|
||||
import React, { useEffect, useMemo, useState } from 'react';
|
||||
import { Play, Pause } from 'lucide-react';
|
||||
import { cn } from '../lib/utils';
|
||||
import { useStore } from '../store/useStore';
|
||||
|
||||
export function FrameTimeline() {
|
||||
const frames = useStore((state) => state.frames);
|
||||
const currentProject = useStore((state) => state.currentProject);
|
||||
const currentFrameIndex = useStore((state) => state.currentFrameIndex);
|
||||
const setCurrentFrame = useStore((state) => state.setCurrentFrame);
|
||||
const [isPlaying, setIsPlaying] = useState(false);
|
||||
|
||||
const totalFrames = frames.length;
|
||||
const currentFrame = totalFrames > 0 ? currentFrameIndex + 1 : 0;
|
||||
const playbackFps = useMemo(() => {
|
||||
const fps = currentProject?.parse_fps || currentProject?.original_fps || 12;
|
||||
return Math.min(Math.max(fps, 1), 30);
|
||||
}, [currentProject?.original_fps, currentProject?.parse_fps]);
|
||||
|
||||
useEffect(() => {
|
||||
if (!isPlaying || totalFrames <= 1) return;
|
||||
|
||||
const timer = window.setTimeout(() => {
|
||||
if (currentFrameIndex >= totalFrames - 1) {
|
||||
setIsPlaying(false);
|
||||
return;
|
||||
}
|
||||
|
||||
setCurrentFrame(currentFrameIndex + 1);
|
||||
}, 1000 / playbackFps);
|
||||
|
||||
return () => window.clearTimeout(timer);
|
||||
}, [currentFrameIndex, isPlaying, playbackFps, setCurrentFrame, totalFrames]);
|
||||
|
||||
useEffect(() => {
|
||||
if (totalFrames === 0) {
|
||||
setIsPlaying(false);
|
||||
}
|
||||
}, [totalFrames]);
|
||||
|
||||
// show frames around current frame
|
||||
const frameWindow = 20;
|
||||
@@ -45,8 +71,14 @@ export function FrameTimeline() {
|
||||
<div className="flex-1 flex items-center px-4 gap-6">
|
||||
<div className="flex flex-col items-center gap-2 px-4 border-r border-white/10 shrink-0">
|
||||
<button
|
||||
className="p-2 rounded-full bg-white/5 text-white hover:bg-white/10"
|
||||
onClick={() => setIsPlaying(!isPlaying)}
|
||||
className="p-2 rounded-full bg-white/5 text-white hover:bg-white/10 disabled:opacity-40 disabled:cursor-not-allowed"
|
||||
disabled={totalFrames <= 1}
|
||||
onClick={() => {
|
||||
if (currentFrameIndex >= totalFrames - 1) {
|
||||
setCurrentFrame(0);
|
||||
}
|
||||
setIsPlaying(!isPlaying);
|
||||
}}
|
||||
>
|
||||
{isPlaying ? <Pause size={20} fill="currentColor" /> : <Play size={20} fill="currentColor" />}
|
||||
</button>
|
||||
|
||||
42
src/components/Login.test.tsx
Normal file
42
src/components/Login.test.tsx
Normal file
@@ -0,0 +1,42 @@
|
||||
import { fireEvent, render, screen, waitFor } from '@testing-library/react';
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest';
|
||||
import { resetStore } from '../test/storeTestUtils';
|
||||
import { useStore } from '../store/useStore';
|
||||
import { Login } from './Login';
|
||||
|
||||
const apiMock = vi.hoisted(() => ({
|
||||
login: vi.fn(),
|
||||
}));
|
||||
|
||||
vi.mock('../lib/api', () => ({
|
||||
login: apiMock.login,
|
||||
}));
|
||||
|
||||
describe('Login', () => {
|
||||
beforeEach(() => {
|
||||
resetStore();
|
||||
vi.clearAllMocks();
|
||||
});
|
||||
|
||||
it('logs in with the development credentials and stores the token', async () => {
|
||||
apiMock.login.mockResolvedValueOnce({ token: 'fake-jwt-token-for-admin' });
|
||||
|
||||
render(<Login />);
|
||||
fireEvent.click(screen.getByRole('button', { name: '安全登录' }));
|
||||
|
||||
await waitFor(() => expect(apiMock.login).toHaveBeenCalledWith('admin', '123456'));
|
||||
expect(useStore.getState().isAuthenticated).toBe(true);
|
||||
expect(localStorage.getItem('token')).toBe('fake-jwt-token-for-admin');
|
||||
});
|
||||
|
||||
it('shows backend login errors', async () => {
|
||||
apiMock.login.mockRejectedValueOnce({ response: { data: { detail: 'Invalid credentials' } } });
|
||||
|
||||
render(<Login />);
|
||||
fireEvent.change(screen.getByDisplayValue('admin'), { target: { value: 'bad' } });
|
||||
fireEvent.click(screen.getByRole('button', { name: '安全登录' }));
|
||||
|
||||
expect(await screen.findByText('Invalid credentials')).toBeInTheDocument();
|
||||
expect(useStore.getState().isAuthenticated).toBe(false);
|
||||
});
|
||||
});
|
||||
45
src/components/ModelStatusBadge.test.tsx
Normal file
45
src/components/ModelStatusBadge.test.tsx
Normal file
@@ -0,0 +1,45 @@
|
||||
import { render, screen, waitFor } from '@testing-library/react';
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest';
|
||||
import { resetStore } from '../test/storeTestUtils';
|
||||
import { useStore } from '../store/useStore';
|
||||
import { ModelStatusBadge } from './ModelStatusBadge';
|
||||
|
||||
const apiMock = vi.hoisted(() => ({
|
||||
getAiModelStatus: vi.fn(),
|
||||
}));
|
||||
|
||||
vi.mock('../lib/api', () => ({
|
||||
getAiModelStatus: apiMock.getAiModelStatus,
|
||||
}));
|
||||
|
||||
describe('ModelStatusBadge', () => {
|
||||
beforeEach(() => {
|
||||
resetStore();
|
||||
vi.clearAllMocks();
|
||||
apiMock.getAiModelStatus.mockResolvedValue({
|
||||
selected_model: 'sam2',
|
||||
gpu: { available: true, device: 'cuda', name: 'RTX 4090', torch_available: true },
|
||||
models: [
|
||||
{ id: 'sam2', label: 'SAM 2', available: true, loaded: false, device: 'cuda', supports: ['point', 'box'], message: 'SAM 2 ready', package_available: true, checkpoint_exists: true, python_ok: true, torch_ok: true, cuda_required: false },
|
||||
{ id: 'sam3', label: 'SAM 3', available: false, loaded: false, device: 'cuda', supports: ['semantic'], message: 'SAM 3 missing runtime', package_available: false, checkpoint_exists: false, python_ok: false, torch_ok: true, cuda_required: true },
|
||||
],
|
||||
});
|
||||
});
|
||||
|
||||
it('loads real model status for the selected model', async () => {
|
||||
render(<ModelStatusBadge />);
|
||||
|
||||
expect(await screen.findByText('SAM 2 可用')).toBeInTheDocument();
|
||||
expect(apiMock.getAiModelStatus).toHaveBeenCalledWith('sam2');
|
||||
});
|
||||
|
||||
it('shows unavailable state when SAM3 is selected but not runnable', async () => {
|
||||
useStore.getState().setAiModel('sam3');
|
||||
|
||||
render(<ModelStatusBadge />);
|
||||
|
||||
await waitFor(() => expect(apiMock.getAiModelStatus).toHaveBeenCalledWith('sam3'));
|
||||
expect(await screen.findByText('SAM 3 不可用')).toBeInTheDocument();
|
||||
expect(screen.getByTitle('SAM 3 missing runtime')).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
56
src/components/ModelStatusBadge.tsx
Normal file
56
src/components/ModelStatusBadge.tsx
Normal file
@@ -0,0 +1,56 @@
|
||||
import React, { useEffect, useState } from 'react';
|
||||
import { Cpu, Loader2 } from 'lucide-react';
|
||||
import { getAiModelStatus, type AiRuntimeStatus } from '../lib/api';
|
||||
import { cn } from '../lib/utils';
|
||||
import { useStore } from '../store/useStore';
|
||||
|
||||
interface ModelStatusBadgeProps {
|
||||
compact?: boolean;
|
||||
}
|
||||
|
||||
export function ModelStatusBadge({ compact = false }: ModelStatusBadgeProps) {
|
||||
const aiModel = useStore((state) => state.aiModel);
|
||||
const [status, setStatus] = useState<AiRuntimeStatus | null>(null);
|
||||
const [isLoading, setIsLoading] = useState(true);
|
||||
|
||||
useEffect(() => {
|
||||
let cancelled = false;
|
||||
setIsLoading(true);
|
||||
getAiModelStatus(aiModel)
|
||||
.then((data) => {
|
||||
if (!cancelled) setStatus(data);
|
||||
})
|
||||
.catch(() => {
|
||||
if (!cancelled) setStatus(null);
|
||||
})
|
||||
.finally(() => {
|
||||
if (!cancelled) setIsLoading(false);
|
||||
});
|
||||
return () => {
|
||||
cancelled = true;
|
||||
};
|
||||
}, [aiModel]);
|
||||
|
||||
const model = status?.models.find((item) => item.id === aiModel);
|
||||
const ready = Boolean(model?.available);
|
||||
const gpuReady = Boolean(status?.gpu.available);
|
||||
const label = compact
|
||||
? (gpuReady ? 'GPU' : 'CPU')
|
||||
: `${model?.label || aiModel.toUpperCase()} ${ready ? '可用' : '不可用'}`;
|
||||
|
||||
return (
|
||||
<div
|
||||
className={cn(
|
||||
"inline-flex items-center gap-1.5 rounded border font-mono uppercase",
|
||||
compact ? "w-8 h-8 justify-center text-[9px]" : "px-2 py-0.5 text-[10px]",
|
||||
ready
|
||||
? "bg-emerald-500/10 text-emerald-400 border-emerald-500/20"
|
||||
: "bg-amber-500/10 text-amber-400 border-amber-500/20"
|
||||
)}
|
||||
title={model?.message || 'AI 模型状态读取中'}
|
||||
>
|
||||
{isLoading ? <Loader2 size={compact ? 12 : 10} className="animate-spin" /> : <Cpu size={compact ? 12 : 10} />}
|
||||
<span>{label}</span>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
60
src/components/OntologyInspector.test.tsx
Normal file
60
src/components/OntologyInspector.test.tsx
Normal file
@@ -0,0 +1,60 @@
|
||||
import { fireEvent, render, screen, within } from '@testing-library/react';
|
||||
import { beforeEach, describe, expect, it } from 'vitest';
|
||||
import { resetStore } from '../test/storeTestUtils';
|
||||
import { useStore } from '../store/useStore';
|
||||
import { OntologyInspector } from './OntologyInspector';
|
||||
|
||||
describe('OntologyInspector', () => {
|
||||
beforeEach(() => {
|
||||
resetStore();
|
||||
useStore.setState({
|
||||
templates: [
|
||||
{
|
||||
id: 't1',
|
||||
name: '腹腔镜模板',
|
||||
classes: [
|
||||
{ id: 'c1', name: '胆囊', color: '#ff0000', zIndex: 20, category: '器官' },
|
||||
{ id: 'c2', name: '肝脏', color: '#00ff00', zIndex: 10, category: '器官' },
|
||||
],
|
||||
rules: [],
|
||||
},
|
||||
],
|
||||
});
|
||||
});
|
||||
|
||||
it('shows template classes and changes the active template', () => {
|
||||
render(<OntologyInspector />);
|
||||
|
||||
fireEvent.change(screen.getByRole('combobox'), { target: { value: 't1' } });
|
||||
|
||||
expect(useStore.getState().activeTemplateId).toBe('t1');
|
||||
expect(screen.getByText('胆囊')).toBeInTheDocument();
|
||||
expect(screen.getByText('肝脏')).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it('selects a concrete class for subsequent masks', () => {
|
||||
render(<OntologyInspector />);
|
||||
|
||||
fireEvent.click(screen.getByText('胆囊'));
|
||||
|
||||
expect(useStore.getState().activeClassId).toBe('c1');
|
||||
expect(useStore.getState().activeClass).toEqual(expect.objectContaining({
|
||||
id: 'c1',
|
||||
name: '胆囊',
|
||||
zIndex: 20,
|
||||
}));
|
||||
});
|
||||
|
||||
it('adds custom classes locally without backend persistence', () => {
|
||||
const { container } = render(<OntologyInspector />);
|
||||
const customSection = screen.getByText('自定义分类').parentElement!;
|
||||
fireEvent.click(within(customSection).getByRole('button'));
|
||||
fireEvent.change(screen.getByPlaceholderText('分类名称'), { target: { value: '新局部分类' } });
|
||||
fireEvent.keyDown(screen.getByPlaceholderText('分类名称'), { key: 'Enter' });
|
||||
|
||||
expect(screen.getAllByText('新局部分类')).toHaveLength(2);
|
||||
expect(useStore.getState().activeClass).toEqual(expect.objectContaining({ name: '新局部分类' }));
|
||||
expect(useStore.getState().templates[0].classes).toHaveLength(2);
|
||||
expect(container).toHaveTextContent('2 个分类来自模板 + 1 个自定义');
|
||||
});
|
||||
});
|
||||
@@ -2,11 +2,16 @@ import React, { useState } from 'react';
|
||||
import { Layers, ChevronDown, Tag, Eye, Plus, X } from 'lucide-react';
|
||||
import { useStore } from '../store/useStore';
|
||||
import type { TemplateClass } from '../store/useStore';
|
||||
import { cn } from '../lib/utils';
|
||||
import { getActiveTemplate } from '../lib/templateSelection';
|
||||
|
||||
export function OntologyInspector() {
|
||||
const templates = useStore((state) => state.templates);
|
||||
const activeTemplateId = useStore((state) => state.activeTemplateId);
|
||||
const activeClassId = useStore((state) => state.activeClassId);
|
||||
const activeClass = useStore((state) => state.activeClass);
|
||||
const setActiveTemplateId = useStore((state) => state.setActiveTemplateId);
|
||||
const setActiveClass = useStore((state) => state.setActiveClass);
|
||||
|
||||
// Project-level custom classes (in addition to template classes)
|
||||
const [customClasses, setCustomClasses] = useState<TemplateClass[]>([]);
|
||||
@@ -14,10 +19,17 @@ export function OntologyInspector() {
|
||||
const [newClassName, setNewClassName] = useState('');
|
||||
const [newClassColor, setNewClassColor] = useState('#06b6d4');
|
||||
|
||||
const activeTemplate = templates.find((t) => t.id === activeTemplateId) || templates[0] || null;
|
||||
const activeTemplate = getActiveTemplate(templates, activeTemplateId);
|
||||
const templateClasses = activeTemplate?.classes || [];
|
||||
const allClasses = [...templateClasses, ...customClasses].sort((a, b) => b.zIndex - a.zIndex);
|
||||
|
||||
const handleSelectClass = (templateClass: TemplateClass) => {
|
||||
if (activeTemplate && !activeTemplateId) {
|
||||
setActiveTemplateId(activeTemplate.id);
|
||||
}
|
||||
setActiveClass(templateClass);
|
||||
};
|
||||
|
||||
const handleAddCustom = () => {
|
||||
if (!newClassName.trim()) return;
|
||||
const maxZ = allClasses.length > 0 ? Math.max(...allClasses.map((c) => c.zIndex)) : 0;
|
||||
@@ -29,6 +41,7 @@ export function OntologyInspector() {
|
||||
category: '自定义',
|
||||
};
|
||||
setCustomClasses([...customClasses, newClass]);
|
||||
handleSelectClass(newClass);
|
||||
setNewClassName('');
|
||||
setShowAddForm(false);
|
||||
};
|
||||
@@ -47,7 +60,10 @@ export function OntologyInspector() {
|
||||
<div className="relative">
|
||||
<select
|
||||
value={activeTemplate?.id || ''}
|
||||
onChange={(e) => setActiveTemplateId(e.target.value || null)}
|
||||
onChange={(e) => {
|
||||
setActiveTemplateId(e.target.value || null);
|
||||
setActiveClass(null);
|
||||
}}
|
||||
className="w-full bg-[#1a1a1a] border border-white/10 rounded-lg px-3 py-2 text-xs text-gray-300 appearance-none cursor-pointer focus:outline-none focus:border-cyan-500/50"
|
||||
>
|
||||
<option value="">-- 选择模板 --</option>
|
||||
@@ -73,7 +89,14 @@ export function OntologyInspector() {
|
||||
<div className="space-y-2">
|
||||
{allClasses.map(cls => (
|
||||
<div key={cls.id} className="flex flex-col gap-1">
|
||||
<div className="flex items-center justify-between p-2 rounded bg-white/5 hover:bg-white/10 cursor-pointer group transition-colors">
|
||||
<button
|
||||
type="button"
|
||||
onClick={() => handleSelectClass(cls)}
|
||||
className={cn(
|
||||
'flex items-center justify-between p-2 rounded bg-white/5 hover:bg-white/10 cursor-pointer group transition-colors text-left border',
|
||||
activeClassId === cls.id ? 'border-cyan-500/50 bg-cyan-500/10' : 'border-transparent',
|
||||
)}
|
||||
>
|
||||
<div className="flex items-center gap-2">
|
||||
<span className="w-2.5 h-2.5 rounded-sm" style={{ backgroundColor: cls.color }} />
|
||||
<span className="text-xs font-medium text-gray-200">{cls.name}</span>
|
||||
@@ -82,7 +105,7 @@ export function OntologyInspector() {
|
||||
<span className="text-[10px] text-gray-500 font-mono">z:{cls.zIndex}</span>
|
||||
<Eye size={14} className="text-gray-500 group-hover:text-gray-300" />
|
||||
</div>
|
||||
</div>
|
||||
</button>
|
||||
</div>
|
||||
))}
|
||||
{allClasses.length === 0 && (
|
||||
@@ -136,7 +159,9 @@ export function OntologyInspector() {
|
||||
<div className="bg-white/5 rounded-lg p-3">
|
||||
<div className="flex items-center gap-2 mb-3">
|
||||
<Tag size={12} className="text-cyan-400" />
|
||||
<span className="text-xs font-semibold text-gray-200">{activeTemplate?.name || '未选择'}</span>
|
||||
<span className="text-xs font-semibold text-gray-200">
|
||||
{activeClass?.name || activeTemplate?.name || '未选择'}
|
||||
</span>
|
||||
</div>
|
||||
<div className="space-y-3">
|
||||
<div className="space-y-1">
|
||||
|
||||
92
src/components/ProjectLibrary.test.tsx
Normal file
92
src/components/ProjectLibrary.test.tsx
Normal file
@@ -0,0 +1,92 @@
|
||||
import { fireEvent, render, screen, waitFor } from '@testing-library/react';
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest';
|
||||
import { resetStore } from '../test/storeTestUtils';
|
||||
import { useStore } from '../store/useStore';
|
||||
import { ProjectLibrary } from './ProjectLibrary';
|
||||
|
||||
const apiMock = vi.hoisted(() => ({
|
||||
getProjects: vi.fn(),
|
||||
createProject: vi.fn(),
|
||||
uploadMedia: vi.fn(),
|
||||
parseMedia: vi.fn(),
|
||||
uploadDicomBatch: vi.fn(),
|
||||
}));
|
||||
|
||||
vi.mock('../lib/api', () => ({
|
||||
getProjects: apiMock.getProjects,
|
||||
createProject: apiMock.createProject,
|
||||
uploadMedia: apiMock.uploadMedia,
|
||||
parseMedia: apiMock.parseMedia,
|
||||
uploadDicomBatch: apiMock.uploadDicomBatch,
|
||||
}));
|
||||
|
||||
describe('ProjectLibrary', () => {
|
||||
beforeEach(() => {
|
||||
resetStore();
|
||||
vi.clearAllMocks();
|
||||
apiMock.getProjects.mockResolvedValue([]);
|
||||
});
|
||||
|
||||
it('loads projects and selects one into the workspace', async () => {
|
||||
const onProjectSelect = vi.fn();
|
||||
apiMock.getProjects.mockResolvedValueOnce([
|
||||
{ id: 'p1', name: 'Demo Project', status: 'ready', frames: 3, fps: '30FPS' },
|
||||
]);
|
||||
|
||||
render(<ProjectLibrary onProjectSelect={onProjectSelect} />);
|
||||
|
||||
fireEvent.click(await screen.findByText('Demo Project'));
|
||||
expect(useStore.getState().currentProject?.id).toBe('p1');
|
||||
expect(onProjectSelect).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('creates a new project from the modal', async () => {
|
||||
apiMock.createProject.mockResolvedValueOnce({ id: 'p2', name: 'New Project', status: 'pending' });
|
||||
|
||||
render(<ProjectLibrary onProjectSelect={vi.fn()} />);
|
||||
fireEvent.click(screen.getByText('新建项目'));
|
||||
fireEvent.change(screen.getByPlaceholderText('输入项目名称'), { target: { value: 'New Project' } });
|
||||
fireEvent.change(screen.getByPlaceholderText('输入项目描述'), { target: { value: 'desc' } });
|
||||
fireEvent.click(screen.getByRole('button', { name: '创建' }));
|
||||
|
||||
await waitFor(() => expect(apiMock.createProject).toHaveBeenCalledWith({
|
||||
name: 'New Project',
|
||||
description: 'desc',
|
||||
}));
|
||||
expect(useStore.getState().projects[0]).toEqual(expect.objectContaining({ id: 'p2' }));
|
||||
});
|
||||
|
||||
it('imports video by creating a project, uploading media, parsing frames and refreshing projects', async () => {
|
||||
apiMock.createProject.mockResolvedValueOnce({ id: 'p3', name: 'clip.mp4', status: 'pending' });
|
||||
apiMock.uploadMedia.mockResolvedValueOnce({ url: 'http://file', id: 'object' });
|
||||
apiMock.parseMedia.mockResolvedValueOnce({ frames_extracted: 1 });
|
||||
apiMock.getProjects.mockResolvedValue([]);
|
||||
|
||||
const { container } = render(<ProjectLibrary onProjectSelect={vi.fn()} />);
|
||||
const input = container.querySelector('input[accept="video/*"]') as HTMLInputElement;
|
||||
const file = new File(['video'], 'clip.mp4', { type: 'video/mp4' });
|
||||
fireEvent.change(input, { target: { files: [file] } });
|
||||
fireEvent.click(await screen.findByRole('button', { name: '开始导入' }));
|
||||
|
||||
await waitFor(() => expect(apiMock.createProject).toHaveBeenCalledWith(expect.objectContaining({
|
||||
name: 'clip.mp4',
|
||||
parse_fps: 30,
|
||||
})));
|
||||
expect(apiMock.uploadMedia).toHaveBeenCalledWith(file, 'p3');
|
||||
expect(apiMock.parseMedia).toHaveBeenCalledWith('p3');
|
||||
});
|
||||
|
||||
it('imports only valid DICOM files and parses the returned project', async () => {
|
||||
apiMock.uploadDicomBatch.mockResolvedValueOnce({ project_id: 77, uploaded_count: 1, message: 'ok' });
|
||||
apiMock.parseMedia.mockResolvedValueOnce({ frames_extracted: 1 });
|
||||
|
||||
const { container } = render(<ProjectLibrary onProjectSelect={vi.fn()} />);
|
||||
const input = container.querySelector('input[accept=".dcm"]') as HTMLInputElement;
|
||||
const dcm = new File(['dcm'], 'scan.dcm', { type: 'application/dicom' });
|
||||
const ignored = new File(['txt'], 'notes.txt', { type: 'text/plain' });
|
||||
fireEvent.change(input, { target: { files: [dcm, ignored] } });
|
||||
|
||||
await waitFor(() => expect(apiMock.uploadDicomBatch).toHaveBeenCalledWith([dcm]));
|
||||
expect(apiMock.parseMedia).toHaveBeenCalledWith('77');
|
||||
});
|
||||
});
|
||||
@@ -212,11 +212,11 @@ export function ProjectLibrary({ onProjectSelect }: ProjectLibraryProps) {
|
||||
{proj.source_type === 'dicom' ? 'DICOM' : (proj.fps || '30FPS')}
|
||||
</span>
|
||||
<span className="backdrop-blur-md bg-black/40 text-gray-200 text-[10px] px-2 py-1 rounded border border-white/10 flex items-center gap-1 uppercase tracking-widest">
|
||||
{proj.status === 'Ready' ? (
|
||||
{proj.status === 'ready' ? (
|
||||
<><div className="w-1.5 h-1.5 bg-emerald-500 rounded-full" /> 已就绪</>
|
||||
) : proj.status === 'Parsing' ? (
|
||||
) : proj.status === 'parsing' ? (
|
||||
<><div className="w-1.5 h-1.5 bg-amber-500 rounded-full animate-pulse" /> 解析拆帧中</>
|
||||
) : proj.status === 'Error' ? (
|
||||
) : proj.status === 'error' ? (
|
||||
<><div className="w-1.5 h-1.5 bg-red-500 rounded-full" /> 异常</>
|
||||
) : (
|
||||
<><div className="w-1.5 h-1.5 bg-blue-500 rounded-full" /> 待处理</>
|
||||
|
||||
@@ -2,6 +2,7 @@ import React from 'react';
|
||||
import { Home, FolderOpen, Edit3, LayoutTemplate, BrainCircuit } from 'lucide-react';
|
||||
import { cn } from '../lib/utils';
|
||||
import type { ActiveModule } from '../App';
|
||||
import { ModelStatusBadge } from './ModelStatusBadge';
|
||||
|
||||
interface SidebarProps {
|
||||
activeModule: ActiveModule;
|
||||
@@ -47,9 +48,7 @@ export function Sidebar({ activeModule, setActiveModule }: SidebarProps) {
|
||||
})}
|
||||
</nav>
|
||||
<div className="mt-auto mb-4 flex flex-col gap-4">
|
||||
<div className="w-8 h-8 rounded-full border border-cyan-500/50 flex items-center justify-center text-[10px] text-cyan-400 font-bold cursor-pointer transition-all hover:bg-cyan-500/10">
|
||||
GPU
|
||||
</div>
|
||||
<ModelStatusBadge compact />
|
||||
</div>
|
||||
</aside>
|
||||
);
|
||||
|
||||
85
src/components/TemplateRegistry.test.tsx
Normal file
85
src/components/TemplateRegistry.test.tsx
Normal file
@@ -0,0 +1,85 @@
|
||||
import { fireEvent, render, screen, waitFor } from '@testing-library/react';
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest';
|
||||
import { resetStore } from '../test/storeTestUtils';
|
||||
import { useStore } from '../store/useStore';
|
||||
import { TemplateRegistry } from './TemplateRegistry';
|
||||
|
||||
const apiMock = vi.hoisted(() => ({
|
||||
getTemplates: vi.fn(),
|
||||
createTemplate: vi.fn(),
|
||||
updateTemplate: vi.fn(),
|
||||
deleteTemplate: vi.fn(),
|
||||
}));
|
||||
|
||||
vi.mock('../lib/api', () => ({
|
||||
getTemplates: apiMock.getTemplates,
|
||||
createTemplate: apiMock.createTemplate,
|
||||
updateTemplate: apiMock.updateTemplate,
|
||||
deleteTemplate: apiMock.deleteTemplate,
|
||||
}));
|
||||
|
||||
describe('TemplateRegistry', () => {
|
||||
beforeEach(() => {
|
||||
resetStore();
|
||||
vi.clearAllMocks();
|
||||
});
|
||||
|
||||
it('loads and displays templates with unpacked classes', async () => {
|
||||
apiMock.getTemplates.mockResolvedValueOnce([
|
||||
{
|
||||
id: 't1',
|
||||
name: '腹腔镜胆囊切除术',
|
||||
description: 'desc',
|
||||
classes: [{ id: 'c1', name: '胆囊', color: '#ff0000', zIndex: 10, category: '器官' }],
|
||||
rules: [],
|
||||
},
|
||||
]);
|
||||
|
||||
render(<TemplateRegistry />);
|
||||
|
||||
expect(await screen.findAllByText('腹腔镜胆囊切除术')).toHaveLength(2);
|
||||
expect(screen.getByText('胆囊')).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it('creates a template and stores it globally', async () => {
|
||||
apiMock.getTemplates.mockResolvedValueOnce([]);
|
||||
apiMock.createTemplate.mockResolvedValueOnce({
|
||||
id: 't2',
|
||||
name: 'New Template',
|
||||
description: 'desc',
|
||||
classes: [],
|
||||
rules: [],
|
||||
});
|
||||
|
||||
render(<TemplateRegistry />);
|
||||
fireEvent.click(screen.getByText('新建方案'));
|
||||
fireEvent.change(screen.getAllByRole('textbox')[0], { target: { value: 'New Template' } });
|
||||
fireEvent.change(screen.getAllByRole('textbox')[1], { target: { value: 'desc' } });
|
||||
fireEvent.click(screen.getByRole('button', { name: '保存' }));
|
||||
|
||||
await waitFor(() => expect(apiMock.createTemplate).toHaveBeenCalledWith(expect.objectContaining({
|
||||
name: 'New Template',
|
||||
description: 'desc',
|
||||
classes: [],
|
||||
rules: [],
|
||||
color: '#06b6d4',
|
||||
z_index: 0,
|
||||
})));
|
||||
expect(useStore.getState().templates[0]).toEqual(expect.objectContaining({ id: 't2' }));
|
||||
});
|
||||
|
||||
it('imports JSON classes into the edit modal before saving', async () => {
|
||||
apiMock.getTemplates.mockResolvedValueOnce([]);
|
||||
|
||||
render(<TemplateRegistry />);
|
||||
fireEvent.click(screen.getByText('新建方案'));
|
||||
fireEvent.change(screen.getAllByRole('textbox')[0], { target: { value: 'With Classes' } });
|
||||
fireEvent.click(screen.getByText('批量导入'));
|
||||
fireEvent.change(screen.getByPlaceholderText('[[[255,0,0], [0,255,0]], ["分类A", "分类B"]]'), {
|
||||
target: { value: '{"colors":[[255,0,0]],"names":["分类A"]}' },
|
||||
});
|
||||
fireEvent.click(screen.getByRole('button', { name: '导入' }));
|
||||
|
||||
expect(screen.getByText('分类A')).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
30
src/components/ToolsPalette.test.tsx
Normal file
30
src/components/ToolsPalette.test.tsx
Normal file
@@ -0,0 +1,30 @@
|
||||
import { fireEvent, render, screen } from '@testing-library/react';
|
||||
import { describe, expect, it, vi } from 'vitest';
|
||||
import { ToolsPalette } from './ToolsPalette';
|
||||
|
||||
describe('ToolsPalette', () => {
|
||||
it('switches tools and exposes UI-only placeholder buttons', () => {
|
||||
const setActiveTool = vi.fn();
|
||||
|
||||
render(<ToolsPalette activeTool="move" setActiveTool={setActiveTool} />);
|
||||
|
||||
fireEvent.click(screen.getByTitle('创建多边形 (P)'));
|
||||
fireEvent.click(screen.getByTitle('正向选点 (SAM)'));
|
||||
|
||||
expect(setActiveTool).toHaveBeenNthCalledWith(1, 'create_polygon');
|
||||
expect(setActiveTool).toHaveBeenNthCalledWith(2, 'point_pos');
|
||||
expect(screen.getByTitle('撤销操作 (Ctrl+Z)')).toBeInTheDocument();
|
||||
expect(screen.getByTitle('重做操作 (Ctrl+Shift+Z)')).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it('switches to SAM trigger and calls the AI navigation hook', () => {
|
||||
const setActiveTool = vi.fn();
|
||||
const onTriggerAI = vi.fn();
|
||||
|
||||
render(<ToolsPalette activeTool="move" setActiveTool={setActiveTool} onTriggerAI={onTriggerAI} />);
|
||||
fireEvent.click(screen.getByTitle('触发 SAM 推理 (Enter)'));
|
||||
|
||||
expect(setActiveTool).toHaveBeenCalledWith('sam_trigger');
|
||||
expect(onTriggerAI).toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
@@ -78,7 +78,7 @@ export function ToolsPalette({ activeTool, setActiveTool, onTriggerAI }: ToolsPa
|
||||
setActiveTool('sam_trigger');
|
||||
if (onTriggerAI) onTriggerAI();
|
||||
}}
|
||||
title="触发 SAM 3 推理 (Enter)"
|
||||
title="触发 SAM 推理 (Enter)"
|
||||
className={cn(
|
||||
"w-10 h-10 rounded-lg flex items-center justify-center transition-all",
|
||||
activeTool === 'sam_trigger'
|
||||
|
||||
259
src/components/VideoWorkspace.test.tsx
Normal file
259
src/components/VideoWorkspace.test.tsx
Normal file
@@ -0,0 +1,259 @@
|
||||
import { act, fireEvent, render, screen, waitFor } from '@testing-library/react';
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest';
|
||||
import { resetStore } from '../test/storeTestUtils';
|
||||
import { useStore } from '../store/useStore';
|
||||
import { VideoWorkspace } from './VideoWorkspace';
|
||||
|
||||
const apiMock = vi.hoisted(() => ({
|
||||
getProjectFrames: vi.fn(),
|
||||
parseMedia: vi.fn(),
|
||||
getTask: vi.fn(),
|
||||
getTemplates: vi.fn(),
|
||||
getProjectAnnotations: vi.fn(),
|
||||
saveAnnotation: vi.fn(),
|
||||
updateAnnotation: vi.fn(),
|
||||
deleteAnnotation: vi.fn(),
|
||||
exportCoco: vi.fn(),
|
||||
annotationToMask: vi.fn(),
|
||||
buildAnnotationPayload: vi.fn(),
|
||||
getAiModelStatus: vi.fn(),
|
||||
}));
|
||||
|
||||
vi.mock('../lib/api', () => ({
|
||||
getProjectFrames: apiMock.getProjectFrames,
|
||||
parseMedia: apiMock.parseMedia,
|
||||
getTask: apiMock.getTask,
|
||||
getTemplates: apiMock.getTemplates,
|
||||
getProjectAnnotations: apiMock.getProjectAnnotations,
|
||||
saveAnnotation: apiMock.saveAnnotation,
|
||||
updateAnnotation: apiMock.updateAnnotation,
|
||||
deleteAnnotation: apiMock.deleteAnnotation,
|
||||
exportCoco: apiMock.exportCoco,
|
||||
annotationToMask: apiMock.annotationToMask,
|
||||
buildAnnotationPayload: apiMock.buildAnnotationPayload,
|
||||
getAiModelStatus: apiMock.getAiModelStatus,
|
||||
}));
|
||||
|
||||
describe('VideoWorkspace', () => {
|
||||
beforeEach(() => {
|
||||
resetStore();
|
||||
vi.clearAllMocks();
|
||||
useStore.setState({ currentProject: { id: '1', name: 'Demo', status: 'ready', video_path: 'uploads/demo.mp4' } });
|
||||
apiMock.getTemplates.mockResolvedValue([]);
|
||||
apiMock.getProjectAnnotations.mockResolvedValue([]);
|
||||
apiMock.annotationToMask.mockReturnValue(null);
|
||||
apiMock.getTask.mockResolvedValue({ id: 1, status: 'success', progress: 100, message: '解析完成' });
|
||||
apiMock.getAiModelStatus.mockResolvedValue({
|
||||
selected_model: 'sam2',
|
||||
gpu: { available: false, device: 'cpu', name: null, torch_available: true },
|
||||
models: [
|
||||
{ id: 'sam2', label: 'SAM 2', available: true, loaded: false, device: 'cpu', supports: [], message: 'ready', package_available: true, checkpoint_exists: true, python_ok: true, torch_ok: true, cuda_required: false },
|
||||
{ id: 'sam3', label: 'SAM 3', available: false, loaded: false, device: 'unavailable', supports: [], message: 'missing', package_available: false, checkpoint_exists: false, python_ok: false, torch_ok: true, cuda_required: true },
|
||||
],
|
||||
});
|
||||
});
|
||||
|
||||
it('loads project frames into the workspace store', async () => {
|
||||
apiMock.getProjectFrames.mockResolvedValueOnce([
|
||||
{ id: 10, project_id: 1, frame_index: 0, image_url: '/frame.jpg', width: 640, height: 360 },
|
||||
]);
|
||||
|
||||
render(<VideoWorkspace />);
|
||||
|
||||
await waitFor(() => expect(useStore.getState().frames).toEqual([
|
||||
{ id: '10', projectId: '1', index: 0, url: '/frame.jpg', width: 640, height: 360 },
|
||||
]));
|
||||
expect(screen.getByText('Demo')).toBeInTheDocument();
|
||||
expect(apiMock.parseMedia).not.toHaveBeenCalled();
|
||||
expect(apiMock.getProjectAnnotations).toHaveBeenCalledWith('1');
|
||||
});
|
||||
|
||||
it('triggers parsing when a media project has no frames yet', async () => {
|
||||
apiMock.getProjectFrames
|
||||
.mockResolvedValueOnce([])
|
||||
.mockResolvedValueOnce([
|
||||
{ id: 11, project_id: 1, frame_index: 0, image_url: '/parsed.jpg', width: 320, height: 240 },
|
||||
]);
|
||||
apiMock.parseMedia.mockResolvedValueOnce({ id: 7, status: 'queued', progress: 0 });
|
||||
apiMock.getTask.mockResolvedValueOnce({ id: 7, status: 'success', progress: 100, message: '解析完成' });
|
||||
|
||||
render(<VideoWorkspace />);
|
||||
|
||||
await waitFor(() => expect(apiMock.parseMedia).toHaveBeenCalledWith('1'));
|
||||
expect(apiMock.getTask).toHaveBeenCalledWith(7);
|
||||
await waitFor(() => expect(useStore.getState().frames[0]).toEqual(expect.objectContaining({
|
||||
id: '11',
|
||||
url: '/parsed.jpg',
|
||||
})));
|
||||
});
|
||||
|
||||
it('hydrates saved annotations after loading frames', async () => {
|
||||
apiMock.getProjectFrames.mockResolvedValueOnce([
|
||||
{ id: 10, project_id: 1, frame_index: 0, image_url: '/frame.jpg', width: 640, height: 360 },
|
||||
]);
|
||||
apiMock.getProjectAnnotations.mockResolvedValueOnce([{ id: 99, frame_id: 10 }]);
|
||||
apiMock.annotationToMask.mockReturnValueOnce({
|
||||
id: 'annotation-99',
|
||||
annotationId: '99',
|
||||
frameId: '10',
|
||||
saved: true,
|
||||
pathData: 'M 0 0 Z',
|
||||
label: 'Saved',
|
||||
color: '#06b6d4',
|
||||
});
|
||||
|
||||
render(<VideoWorkspace />);
|
||||
|
||||
await waitFor(() => expect(useStore.getState().masks).toEqual([
|
||||
expect.objectContaining({ id: 'annotation-99', saved: true }),
|
||||
]));
|
||||
});
|
||||
|
||||
it('saves pending masks through the archive button', async () => {
|
||||
apiMock.getProjectFrames.mockResolvedValueOnce([
|
||||
{ id: 10, project_id: 1, frame_index: 0, image_url: '/frame.jpg', width: 640, height: 360 },
|
||||
]);
|
||||
apiMock.buildAnnotationPayload.mockReturnValueOnce({ project_id: 1, frame_id: 10, mask_data: { polygons: [] } });
|
||||
apiMock.saveAnnotation.mockResolvedValueOnce({ id: 5 });
|
||||
|
||||
render(<VideoWorkspace />);
|
||||
await waitFor(() => expect(useStore.getState().frames).toHaveLength(1));
|
||||
act(() => {
|
||||
useStore.setState({
|
||||
activeTemplateId: '2',
|
||||
masks: [{
|
||||
id: 'mask-1',
|
||||
frameId: '10',
|
||||
pathData: 'M 0 0 Z',
|
||||
label: 'AI Mask',
|
||||
color: '#06b6d4',
|
||||
segmentation: [[0, 0, 10, 0, 10, 10]],
|
||||
bbox: [0, 0, 10, 10],
|
||||
}],
|
||||
});
|
||||
});
|
||||
|
||||
fireEvent.click(screen.getByRole('button', { name: '结构化归档保存' }));
|
||||
|
||||
await waitFor(() => expect(apiMock.saveAnnotation).toHaveBeenCalledWith({
|
||||
project_id: 1,
|
||||
frame_id: 10,
|
||||
mask_data: { polygons: [] },
|
||||
}));
|
||||
expect(apiMock.buildAnnotationPayload).toHaveBeenCalledWith(
|
||||
'1',
|
||||
expect.objectContaining({ id: 'mask-1' }),
|
||||
expect.objectContaining({ id: '10' }),
|
||||
'2',
|
||||
);
|
||||
});
|
||||
|
||||
it('updates dirty saved masks through the archive button', async () => {
|
||||
apiMock.getProjectFrames.mockResolvedValueOnce([
|
||||
{ id: 10, project_id: 1, frame_index: 0, image_url: '/frame.jpg', width: 640, height: 360 },
|
||||
]);
|
||||
apiMock.buildAnnotationPayload.mockReturnValueOnce({
|
||||
project_id: 1,
|
||||
frame_id: 10,
|
||||
template_id: 2,
|
||||
mask_data: { polygons: [], label: '胆囊' },
|
||||
});
|
||||
apiMock.updateAnnotation.mockResolvedValueOnce({ id: 99 });
|
||||
|
||||
render(<VideoWorkspace />);
|
||||
await waitFor(() => expect(useStore.getState().frames).toHaveLength(1));
|
||||
act(() => {
|
||||
useStore.setState({
|
||||
activeTemplateId: '2',
|
||||
masks: [{
|
||||
id: 'annotation-99',
|
||||
annotationId: '99',
|
||||
frameId: '10',
|
||||
pathData: 'M 0 0 Z',
|
||||
label: '胆囊',
|
||||
color: '#ff0000',
|
||||
saveStatus: 'dirty',
|
||||
segmentation: [[0, 0, 10, 0, 10, 10]],
|
||||
bbox: [0, 0, 10, 10],
|
||||
}],
|
||||
});
|
||||
});
|
||||
|
||||
fireEvent.click(screen.getByRole('button', { name: '结构化归档保存' }));
|
||||
|
||||
await waitFor(() => expect(apiMock.updateAnnotation).toHaveBeenCalledWith('99', {
|
||||
template_id: 2,
|
||||
mask_data: { polygons: [], label: '胆囊' },
|
||||
points: undefined,
|
||||
bbox: undefined,
|
||||
}));
|
||||
expect(apiMock.saveAnnotation).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('deletes saved annotations when clearing current-frame masks', async () => {
|
||||
apiMock.getProjectFrames.mockResolvedValueOnce([
|
||||
{ id: 10, project_id: 1, frame_index: 0, image_url: '/frame.jpg', width: 640, height: 360 },
|
||||
]);
|
||||
apiMock.deleteAnnotation.mockResolvedValueOnce(undefined);
|
||||
|
||||
render(<VideoWorkspace />);
|
||||
await waitFor(() => expect(useStore.getState().frames).toHaveLength(1));
|
||||
act(() => {
|
||||
useStore.setState({
|
||||
masks: [
|
||||
{
|
||||
id: 'annotation-99',
|
||||
annotationId: '99',
|
||||
frameId: '10',
|
||||
pathData: 'M 0 0 Z',
|
||||
label: 'Saved',
|
||||
color: '#06b6d4',
|
||||
saved: true,
|
||||
saveStatus: 'saved',
|
||||
},
|
||||
{
|
||||
id: 'draft-1',
|
||||
frameId: '10',
|
||||
pathData: 'M 1 1 Z',
|
||||
label: 'Draft',
|
||||
color: '#ff0000',
|
||||
},
|
||||
],
|
||||
});
|
||||
});
|
||||
|
||||
fireEvent.click(screen.getByRole('button', { name: '清空遮罩' }));
|
||||
|
||||
await waitFor(() => expect(apiMock.deleteAnnotation).toHaveBeenCalledWith('99'));
|
||||
expect(useStore.getState().masks).toEqual([]);
|
||||
});
|
||||
|
||||
it('auto-saves pending masks before exporting COCO', async () => {
|
||||
apiMock.getProjectFrames.mockResolvedValueOnce([
|
||||
{ id: 10, project_id: 1, frame_index: 0, image_url: '/frame.jpg', width: 640, height: 360 },
|
||||
]);
|
||||
apiMock.buildAnnotationPayload.mockReturnValueOnce({ project_id: 1, frame_id: 10, mask_data: { polygons: [] } });
|
||||
apiMock.saveAnnotation.mockResolvedValueOnce({ id: 5 });
|
||||
apiMock.exportCoco.mockResolvedValueOnce(new Blob(['{}'], { type: 'application/json' }));
|
||||
|
||||
render(<VideoWorkspace />);
|
||||
await waitFor(() => expect(useStore.getState().frames).toHaveLength(1));
|
||||
act(() => {
|
||||
useStore.setState({
|
||||
masks: [{
|
||||
id: 'mask-1',
|
||||
frameId: '10',
|
||||
pathData: 'M 0 0 Z',
|
||||
label: 'AI Mask',
|
||||
color: '#06b6d4',
|
||||
segmentation: [[0, 0, 10, 0, 10, 10]],
|
||||
}],
|
||||
});
|
||||
});
|
||||
|
||||
fireEvent.click(screen.getByRole('button', { name: '导出 JSON 标注集' }));
|
||||
|
||||
await waitFor(() => expect(apiMock.saveAnnotation).toHaveBeenCalled());
|
||||
expect(apiMock.exportCoco).toHaveBeenCalledWith('1');
|
||||
});
|
||||
});
|
||||
@@ -1,10 +1,28 @@
|
||||
import React, { useEffect } from 'react';
|
||||
import React, { useCallback, useEffect, useMemo, useState } from 'react';
|
||||
import { useStore } from '../store/useStore';
|
||||
import { getProjectFrames, parseMedia, getTemplates } from '../lib/api';
|
||||
import {
|
||||
annotationToMask,
|
||||
buildAnnotationPayload,
|
||||
deleteAnnotation,
|
||||
exportCoco,
|
||||
getProjectAnnotations,
|
||||
getProjectFrames,
|
||||
getTask,
|
||||
getTemplates,
|
||||
parseMedia,
|
||||
saveAnnotation,
|
||||
updateAnnotation,
|
||||
} from '../lib/api';
|
||||
import { CanvasArea } from './CanvasArea';
|
||||
import { ToolsPalette } from './ToolsPalette';
|
||||
import { OntologyInspector } from './OntologyInspector';
|
||||
import { FrameTimeline } from './FrameTimeline';
|
||||
import { ModelStatusBadge } from './ModelStatusBadge';
|
||||
import type { Frame } from '../store/useStore';
|
||||
|
||||
function sleep(ms: number) {
|
||||
return new Promise((resolve) => setTimeout(resolve, ms));
|
||||
}
|
||||
|
||||
export function VideoWorkspace({ onNavigateToAI }: { onNavigateToAI?: () => void }) {
|
||||
const activeTool = useStore((state) => state.activeTool);
|
||||
@@ -12,8 +30,26 @@ export function VideoWorkspace({ onNavigateToAI }: { onNavigateToAI?: () => void
|
||||
const currentProject = useStore((state) => state.currentProject);
|
||||
const frames = useStore((state) => state.frames);
|
||||
const currentFrameIndex = useStore((state) => state.currentFrameIndex);
|
||||
const masks = useStore((state) => state.masks);
|
||||
const activeTemplateId = useStore((state) => state.activeTemplateId);
|
||||
const setFrames = useStore((state) => state.setFrames);
|
||||
const setCurrentFrame = useStore((state) => state.setCurrentFrame);
|
||||
const setMasks = useStore((state) => state.setMasks);
|
||||
const [isSaving, setIsSaving] = useState(false);
|
||||
const [isExporting, setIsExporting] = useState(false);
|
||||
const [statusMessage, setStatusMessage] = useState('');
|
||||
|
||||
const hydrateSavedAnnotations = useCallback(async (projectId: string, projectFrames: Frame[]) => {
|
||||
const frameById = new Map(projectFrames.map((frame) => [frame.id, frame]));
|
||||
const annotations = await getProjectAnnotations(projectId);
|
||||
const savedMasks = annotations
|
||||
.map((annotation) => {
|
||||
const frame = annotation.frame_id ? frameById.get(String(annotation.frame_id)) : null;
|
||||
return frame ? annotationToMask(annotation, frame) : null;
|
||||
})
|
||||
.filter((mask): mask is NonNullable<typeof mask> => Boolean(mask));
|
||||
setMasks(savedMasks);
|
||||
}, [setMasks]);
|
||||
|
||||
useEffect(() => {
|
||||
if (!currentProject?.id) return;
|
||||
@@ -25,34 +61,58 @@ export function VideoWorkspace({ onNavigateToAI }: { onNavigateToAI?: () => void
|
||||
if (cancelled) return;
|
||||
|
||||
if (data.length === 0 && currentProject.video_path) {
|
||||
// No frames yet but video exists → trigger parsing
|
||||
// No frames yet but video exists -> queue parsing and poll the task.
|
||||
try {
|
||||
await parseMedia(String(currentProject.id));
|
||||
const task = await parseMedia(String(currentProject.id));
|
||||
if (cancelled) return;
|
||||
setStatusMessage(`解析任务已入队 #${task.id}`);
|
||||
let completed = false;
|
||||
for (let attempt = 0; attempt < 60; attempt += 1) {
|
||||
const freshTask = await getTask(task.id);
|
||||
if (cancelled) return;
|
||||
setStatusMessage(freshTask.message || `解析进度 ${freshTask.progress}%`);
|
||||
if (freshTask.status === 'success') {
|
||||
completed = true;
|
||||
break;
|
||||
}
|
||||
if (freshTask.status === 'failed') {
|
||||
setStatusMessage(freshTask.error || '解析任务失败');
|
||||
return;
|
||||
}
|
||||
await sleep(2000);
|
||||
}
|
||||
if (!completed) {
|
||||
setStatusMessage('解析仍在后台运行,可稍后刷新工作区');
|
||||
return;
|
||||
}
|
||||
const fresh = await getProjectFrames(String(currentProject.id));
|
||||
if (cancelled) return;
|
||||
setFrames(fresh.map((f) => ({
|
||||
const mappedFrames = fresh.map((f) => ({
|
||||
id: String(f.id),
|
||||
projectId: String(f.project_id),
|
||||
index: f.frame_index,
|
||||
url: f.image_url,
|
||||
width: f.width ?? 0,
|
||||
height: f.height ?? 0,
|
||||
})));
|
||||
}));
|
||||
setFrames(mappedFrames);
|
||||
setCurrentFrame(0);
|
||||
await hydrateSavedAnnotations(String(currentProject.id), mappedFrames);
|
||||
} catch (err) {
|
||||
console.error('Parse failed:', err);
|
||||
}
|
||||
} else {
|
||||
setFrames(data.map((f) => ({
|
||||
const mappedFrames = data.map((f) => ({
|
||||
id: String(f.id),
|
||||
projectId: String(f.project_id),
|
||||
index: f.frame_index,
|
||||
url: f.image_url,
|
||||
width: f.width ?? 0,
|
||||
height: f.height ?? 0,
|
||||
})));
|
||||
}));
|
||||
setFrames(mappedFrames);
|
||||
setCurrentFrame(0);
|
||||
await hydrateSavedAnnotations(String(currentProject.id), mappedFrames);
|
||||
}
|
||||
} catch (err) {
|
||||
console.error('Failed to load frames:', err);
|
||||
@@ -61,7 +121,7 @@ export function VideoWorkspace({ onNavigateToAI }: { onNavigateToAI?: () => void
|
||||
|
||||
loadFrames();
|
||||
return () => { cancelled = true; };
|
||||
}, [currentProject?.id, setFrames, setCurrentFrame]);
|
||||
}, [currentProject?.id, currentProject?.video_path, hydrateSavedAnnotations, setFrames, setCurrentFrame]);
|
||||
|
||||
const templates = useStore((state) => state.templates);
|
||||
const setTemplates = useStore((state) => state.setTemplates);
|
||||
@@ -72,7 +132,121 @@ export function VideoWorkspace({ onNavigateToAI }: { onNavigateToAI?: () => void
|
||||
}
|
||||
}, [templates.length, setTemplates]);
|
||||
|
||||
const currentFrameUrl = frames[currentFrameIndex]?.url || '';
|
||||
const currentFrame = frames[currentFrameIndex] || null;
|
||||
const frameById = useMemo(() => new Map(frames.map((frame) => [frame.id, frame])), [frames]);
|
||||
const projectFrameIds = useMemo(() => new Set(frames.map((frame) => frame.id)), [frames]);
|
||||
|
||||
const savePendingAnnotations = useCallback(async ({ silent = false } = {}) => {
|
||||
if (!currentProject?.id) return 0;
|
||||
const projectMasks = masks.filter((mask) => projectFrameIds.has(mask.frameId));
|
||||
const pendingMasks = projectMasks.filter((mask) => !mask.annotationId);
|
||||
const dirtyMasks = projectMasks.filter((mask) => mask.annotationId && mask.saveStatus === 'dirty');
|
||||
if (pendingMasks.length === 0 && dirtyMasks.length === 0) {
|
||||
if (!silent) setStatusMessage('没有待保存标注');
|
||||
return 0;
|
||||
}
|
||||
|
||||
setIsSaving(true);
|
||||
setStatusMessage('正在保存标注...');
|
||||
try {
|
||||
const createPayloads = pendingMasks
|
||||
.map((mask) => {
|
||||
const frame = frameById.get(mask.frameId);
|
||||
return frame ? buildAnnotationPayload(currentProject.id, mask, frame, activeTemplateId) : null;
|
||||
})
|
||||
.filter((payload): payload is NonNullable<typeof payload> => Boolean(payload));
|
||||
|
||||
const updatePayloads = dirtyMasks
|
||||
.map((mask) => {
|
||||
const frame = frameById.get(mask.frameId);
|
||||
const payload = frame ? buildAnnotationPayload(currentProject.id, mask, frame, activeTemplateId) : null;
|
||||
if (!payload || !mask.annotationId) return null;
|
||||
const updatePayload = {
|
||||
template_id: payload.template_id,
|
||||
mask_data: payload.mask_data,
|
||||
points: payload.points,
|
||||
bbox: payload.bbox,
|
||||
};
|
||||
return { annotationId: mask.annotationId, payload: updatePayload };
|
||||
})
|
||||
.filter((item): item is NonNullable<typeof item> => Boolean(item));
|
||||
|
||||
if (createPayloads.length === 0 && updatePayloads.length === 0) {
|
||||
setStatusMessage('没有可保存的标注数据');
|
||||
return 0;
|
||||
}
|
||||
|
||||
await Promise.all([
|
||||
...createPayloads.map((payload) => saveAnnotation(payload)),
|
||||
...updatePayloads.map(({ annotationId, payload }) => updateAnnotation(annotationId, payload)),
|
||||
]);
|
||||
await hydrateSavedAnnotations(currentProject.id, frames);
|
||||
const savedCount = createPayloads.length + updatePayloads.length;
|
||||
setStatusMessage(`已保存 ${savedCount} 个标注`);
|
||||
return savedCount;
|
||||
} catch (err) {
|
||||
console.error('Save annotations failed:', err);
|
||||
setStatusMessage('保存失败,请检查后端服务');
|
||||
throw err;
|
||||
} finally {
|
||||
setIsSaving(false);
|
||||
}
|
||||
}, [activeTemplateId, currentProject?.id, frameById, frames, hydrateSavedAnnotations, masks, projectFrameIds]);
|
||||
|
||||
const handleClearCurrentFrameMasks = useCallback(async () => {
|
||||
if (!currentFrame) return;
|
||||
const frameMasks = masks.filter((mask) => mask.frameId === currentFrame.id);
|
||||
const annotationIds = frameMasks
|
||||
.map((mask) => mask.annotationId)
|
||||
.filter((annotationId): annotationId is string => Boolean(annotationId));
|
||||
|
||||
setIsSaving(true);
|
||||
setStatusMessage(annotationIds.length > 0 ? '正在删除已保存标注...' : '正在清空本帧遮罩...');
|
||||
try {
|
||||
await Promise.all(annotationIds.map((annotationId) => deleteAnnotation(annotationId)));
|
||||
setMasks(masks.filter((mask) => mask.frameId !== currentFrame.id));
|
||||
setStatusMessage(annotationIds.length > 0
|
||||
? `已删除 ${annotationIds.length} 个后端标注`
|
||||
: '已清空本帧未保存遮罩');
|
||||
} catch (err) {
|
||||
console.error('Delete annotations failed:', err);
|
||||
setStatusMessage('删除失败,请检查后端服务');
|
||||
} finally {
|
||||
setIsSaving(false);
|
||||
}
|
||||
}, [currentFrame, masks, setMasks]);
|
||||
|
||||
const handleSave = async () => {
|
||||
try {
|
||||
await savePendingAnnotations();
|
||||
} catch {
|
||||
// status message is set in savePendingAnnotations
|
||||
}
|
||||
};
|
||||
|
||||
const handleExport = async () => {
|
||||
if (!currentProject?.id) return;
|
||||
setIsExporting(true);
|
||||
setStatusMessage('正在准备导出...');
|
||||
try {
|
||||
await savePendingAnnotations({ silent: true });
|
||||
const blob = await exportCoco(currentProject.id);
|
||||
const url = URL.createObjectURL(blob);
|
||||
const link = document.createElement('a');
|
||||
link.href = url;
|
||||
link.download = `project_${currentProject.id}_coco.json`;
|
||||
document.body.appendChild(link);
|
||||
link.click();
|
||||
link.remove();
|
||||
URL.revokeObjectURL(url);
|
||||
setStatusMessage('COCO JSON 已导出');
|
||||
} catch (err) {
|
||||
console.error('Export failed:', err);
|
||||
setStatusMessage('导出失败,请检查后端服务');
|
||||
} finally {
|
||||
setIsExporting(false);
|
||||
}
|
||||
};
|
||||
|
||||
return (
|
||||
<div className="w-full h-full flex flex-col bg-[#0a0a0a]">
|
||||
@@ -84,14 +258,25 @@ export function VideoWorkspace({ onNavigateToAI }: { onNavigateToAI?: () => void
|
||||
<span className="text-sm text-white font-mono">{currentProject?.name || '未选择项目'}</span>
|
||||
</div>
|
||||
<div className="flex items-center gap-3">
|
||||
<div className="flex items-center gap-1.5 text-[10px] uppercase font-medium">
|
||||
<span className="px-2 py-0.5 rounded bg-green-500/10 text-green-400 border border-green-500/20">SAM 3 部署就绪</span>
|
||||
</div>
|
||||
<button className="px-4 py-1.5 bg-white/5 hover:bg-white/10 border border-white/10 rounded-md text-xs transition-colors text-white">
|
||||
导出 JSON 标注集
|
||||
{statusMessage && (
|
||||
<span className="text-[10px] text-gray-500 font-mono max-w-48 truncate" title={statusMessage}>
|
||||
{statusMessage}
|
||||
</span>
|
||||
)}
|
||||
<ModelStatusBadge />
|
||||
<button
|
||||
onClick={handleExport}
|
||||
disabled={!currentProject?.id || isExporting || isSaving}
|
||||
className="px-4 py-1.5 bg-white/5 hover:bg-white/10 border border-white/10 rounded-md text-xs transition-colors text-white disabled:opacity-40 disabled:cursor-not-allowed"
|
||||
>
|
||||
{isExporting ? '导出中...' : '导出 JSON 标注集'}
|
||||
</button>
|
||||
<button className="px-4 py-1.5 bg-cyan-600 hover:bg-cyan-500 text-white text-xs font-medium rounded-md transition-shadow shadow-lg shadow-cyan-900/20">
|
||||
结构化归档保存
|
||||
<button
|
||||
onClick={handleSave}
|
||||
disabled={!currentProject?.id || isSaving || isExporting}
|
||||
className="px-4 py-1.5 bg-cyan-600 hover:bg-cyan-500 text-white text-xs font-medium rounded-md transition-shadow shadow-lg shadow-cyan-900/20 disabled:opacity-40 disabled:cursor-not-allowed"
|
||||
>
|
||||
{isSaving ? '保存中...' : '结构化归档保存'}
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
@@ -102,7 +287,7 @@ export function VideoWorkspace({ onNavigateToAI }: { onNavigateToAI?: () => void
|
||||
|
||||
<div className="flex-1 relative flex items-center justify-center p-8 bg-[#151515] overflow-hidden">
|
||||
<div className="relative w-full h-full bg-[#1e1e1e] border border-white/5 shadow-2xl rounded-sm">
|
||||
<CanvasArea activeTool={activeTool} frameUrl={currentFrameUrl} />
|
||||
<CanvasArea activeTool={activeTool} frame={currentFrame} onClearMasks={handleClearCurrentFrameMasks} />
|
||||
</div>
|
||||
</div>
|
||||
|
||||
|
||||
361
src/lib/api.test.ts
Normal file
361
src/lib/api.test.ts
Normal file
@@ -0,0 +1,361 @@
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest';
|
||||
|
||||
const axiosMock = vi.hoisted(() => {
|
||||
const client = {
|
||||
get: vi.fn(),
|
||||
post: vi.fn(),
|
||||
patch: vi.fn(),
|
||||
delete: vi.fn(),
|
||||
interceptors: {
|
||||
request: { use: vi.fn() },
|
||||
response: { use: vi.fn() },
|
||||
},
|
||||
};
|
||||
return { client, create: vi.fn(() => client) };
|
||||
});
|
||||
|
||||
vi.mock('axios', () => ({
|
||||
default: {
|
||||
create: axiosMock.create,
|
||||
},
|
||||
}));
|
||||
|
||||
describe('api client contracts', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks();
|
||||
vi.setSystemTime(new Date('2026-05-01T00:00:00Z'));
|
||||
});
|
||||
|
||||
it('maps backend project fields into frontend project fields', async () => {
|
||||
const { getProjects } = await import('./api');
|
||||
axiosMock.client.get.mockResolvedValueOnce({
|
||||
data: [
|
||||
{
|
||||
id: 7,
|
||||
name: 'Demo',
|
||||
description: 'desc',
|
||||
status: 'ready',
|
||||
frame_count: 12,
|
||||
original_fps: 29.97,
|
||||
parse_fps: 10,
|
||||
thumbnail_url: 'thumb',
|
||||
video_path: 'uploads/demo.mp4',
|
||||
source_type: 'video',
|
||||
created_at: 'created',
|
||||
updated_at: 'updated',
|
||||
},
|
||||
],
|
||||
});
|
||||
|
||||
await expect(getProjects()).resolves.toEqual([
|
||||
expect.objectContaining({
|
||||
id: '7',
|
||||
name: 'Demo',
|
||||
status: 'ready',
|
||||
frames: 12,
|
||||
fps: '30FPS',
|
||||
thumbnail_url: 'thumb',
|
||||
video_path: 'uploads/demo.mp4',
|
||||
source_type: 'video',
|
||||
createdAt: 'created',
|
||||
updatedAt: 'updated',
|
||||
}),
|
||||
]);
|
||||
expect(axiosMock.client.get).toHaveBeenCalledWith('/api/projects');
|
||||
});
|
||||
|
||||
it('updates projects with PATCH instead of the old PUT contract', async () => {
|
||||
const { updateProject } = await import('./api');
|
||||
axiosMock.client.patch.mockResolvedValueOnce({ data: { id: 3, name: 'Renamed', status: 'ready' } });
|
||||
|
||||
await updateProject('3', { name: 'Renamed' } as any);
|
||||
|
||||
expect(axiosMock.client.patch).toHaveBeenCalledWith('/api/projects/3', { name: 'Renamed' });
|
||||
});
|
||||
|
||||
it('normalizes legacy project status values returned by existing databases', async () => {
|
||||
const { getProjects } = await import('./api');
|
||||
axiosMock.client.get.mockResolvedValueOnce({
|
||||
data: [
|
||||
{ id: 1, name: 'Old Ready', status: 'Ready' },
|
||||
{ id: 2, name: 'Old Parsing', status: 'Parsing' },
|
||||
{ id: 3, name: 'Old Error', status: 'Error' },
|
||||
],
|
||||
});
|
||||
|
||||
await expect(getProjects()).resolves.toEqual([
|
||||
expect.objectContaining({ status: 'ready' }),
|
||||
expect.objectContaining({ status: 'parsing' }),
|
||||
expect.objectContaining({ status: 'error' }),
|
||||
]);
|
||||
});
|
||||
|
||||
it('exports COCO from the backend route shape', async () => {
|
||||
const { exportCoco } = await import('./api');
|
||||
const blob = new Blob(['{}'], { type: 'application/json' });
|
||||
axiosMock.client.get.mockResolvedValueOnce({ data: blob });
|
||||
|
||||
await expect(exportCoco('9')).resolves.toBe(blob);
|
||||
expect(axiosMock.client.get).toHaveBeenCalledWith('/api/export/9/coco', {
|
||||
responseType: 'blob',
|
||||
});
|
||||
});
|
||||
|
||||
it('loads dashboard overview from the backend summary endpoint', async () => {
|
||||
const { getDashboardOverview } = await import('./api');
|
||||
const overview = {
|
||||
summary: {
|
||||
project_count: 2,
|
||||
parsing_task_count: 1,
|
||||
annotation_count: 5,
|
||||
frame_count: 100,
|
||||
template_count: 3,
|
||||
system_load_percent: 12,
|
||||
},
|
||||
tasks: [
|
||||
{ id: 'project-1', project_id: 1, name: 'Demo', progress: 60, status: 'pending', frame_count: 10, updated_at: 'now' },
|
||||
],
|
||||
activity: [
|
||||
{ id: 'project-1', kind: 'project', time: 'now', message: '项目状态: pending', project: 'Demo' },
|
||||
],
|
||||
};
|
||||
axiosMock.client.get.mockResolvedValueOnce({ data: overview });
|
||||
|
||||
await expect(getDashboardOverview()).resolves.toEqual(overview);
|
||||
expect(axiosMock.client.get).toHaveBeenCalledWith('/api/dashboard/overview');
|
||||
});
|
||||
|
||||
it('queues media parsing and reads processing task status', async () => {
|
||||
const { getTask, parseMedia } = await import('./api');
|
||||
const task = {
|
||||
id: 12,
|
||||
task_type: 'parse_video',
|
||||
status: 'queued',
|
||||
progress: 0,
|
||||
message: '解析任务已入队',
|
||||
project_id: 9,
|
||||
celery_task_id: 'celery-12',
|
||||
payload: { source_type: 'video' },
|
||||
result: null,
|
||||
error: null,
|
||||
created_at: 'created',
|
||||
started_at: null,
|
||||
finished_at: null,
|
||||
updated_at: 'updated',
|
||||
};
|
||||
axiosMock.client.post.mockResolvedValueOnce({ data: task });
|
||||
axiosMock.client.get.mockResolvedValueOnce({ data: { ...task, status: 'success', progress: 100 } });
|
||||
|
||||
await expect(parseMedia('9')).resolves.toEqual(task);
|
||||
expect(axiosMock.client.post).toHaveBeenCalledWith('/api/media/parse', null, {
|
||||
params: { project_id: '9' },
|
||||
});
|
||||
|
||||
await expect(getTask(12)).resolves.toEqual(expect.objectContaining({ status: 'success', progress: 100 }));
|
||||
expect(axiosMock.client.get).toHaveBeenCalledWith('/api/tasks/12');
|
||||
});
|
||||
|
||||
it('lists, saves, updates, and deletes annotations with the backend annotation contract', async () => {
|
||||
const { deleteAnnotation, getProjectAnnotations, saveAnnotation, updateAnnotation } = await import('./api');
|
||||
const saved = {
|
||||
id: 1,
|
||||
project_id: 9,
|
||||
frame_id: 5,
|
||||
template_id: 2,
|
||||
mask_data: { polygons: [[[0, 0], [1, 0], [1, 1]]] },
|
||||
points: null,
|
||||
bbox: null,
|
||||
created_at: 'created',
|
||||
updated_at: 'updated',
|
||||
};
|
||||
axiosMock.client.get.mockResolvedValueOnce({ data: [saved] });
|
||||
axiosMock.client.post.mockResolvedValueOnce({ data: saved });
|
||||
|
||||
await expect(getProjectAnnotations('9', '5')).resolves.toEqual([saved]);
|
||||
expect(axiosMock.client.get).toHaveBeenCalledWith('/api/ai/annotations', {
|
||||
params: { project_id: 9, frame_id: 5 },
|
||||
});
|
||||
|
||||
await expect(saveAnnotation({
|
||||
project_id: 9,
|
||||
frame_id: 5,
|
||||
template_id: 2,
|
||||
mask_data: { polygons: [[[0, 0], [1, 0], [1, 1]]], label: 'mask' },
|
||||
})).resolves.toEqual(saved);
|
||||
expect(axiosMock.client.post).toHaveBeenCalledWith('/api/ai/annotate', {
|
||||
project_id: 9,
|
||||
frame_id: 5,
|
||||
template_id: 2,
|
||||
mask_data: { polygons: [[[0, 0], [1, 0], [1, 1]]], label: 'mask' },
|
||||
});
|
||||
|
||||
axiosMock.client.patch.mockResolvedValueOnce({ data: { ...saved, mask_data: { ...saved.mask_data, label: 'updated' } } });
|
||||
await expect(updateAnnotation('1', {
|
||||
template_id: 2,
|
||||
mask_data: { polygons: [[[0, 0], [1, 0], [1, 1]]], label: 'updated' },
|
||||
})).resolves.toEqual(expect.objectContaining({ mask_data: expect.objectContaining({ label: 'updated' }) }));
|
||||
expect(axiosMock.client.patch).toHaveBeenCalledWith('/api/ai/annotations/1', {
|
||||
template_id: 2,
|
||||
mask_data: { polygons: [[[0, 0], [1, 0], [1, 1]]], label: 'updated' },
|
||||
});
|
||||
|
||||
axiosMock.client.delete.mockResolvedValueOnce({ data: null });
|
||||
await expect(deleteAnnotation('1')).resolves.toBeUndefined();
|
||||
expect(axiosMock.client.delete).toHaveBeenCalledWith('/api/ai/annotations/1');
|
||||
});
|
||||
|
||||
it('builds annotation payloads from frontend masks and restores saved annotations to masks', async () => {
|
||||
const { annotationToMask, buildAnnotationPayload } = await import('./api');
|
||||
const frame = { id: '5', projectId: '9', index: 0, url: '/frame.jpg', width: 100, height: 50 };
|
||||
const payload = buildAnnotationPayload('9', {
|
||||
id: 'm1',
|
||||
frameId: '5',
|
||||
pathData: 'M 10 10 L 90 10 L 90 40 Z',
|
||||
label: '胆囊',
|
||||
color: '#ff0000',
|
||||
classId: 'c1',
|
||||
className: '胆囊',
|
||||
classZIndex: 20,
|
||||
segmentation: [[10, 10, 90, 10, 90, 40]],
|
||||
bbox: [10, 10, 80, 30],
|
||||
}, frame, '2');
|
||||
|
||||
expect(payload).toEqual({
|
||||
project_id: 9,
|
||||
frame_id: 5,
|
||||
template_id: 2,
|
||||
mask_data: {
|
||||
polygons: [[[0.1, 0.2], [0.9, 0.2], [0.9, 0.8]]],
|
||||
label: '胆囊',
|
||||
color: '#ff0000',
|
||||
class: { id: 'c1', name: '胆囊', color: '#ff0000', zIndex: 20 },
|
||||
},
|
||||
bbox: [0.1, 0.2, 0.8, 0.6],
|
||||
});
|
||||
|
||||
expect(annotationToMask({
|
||||
id: 3,
|
||||
project_id: 9,
|
||||
frame_id: 5,
|
||||
template_id: 2,
|
||||
mask_data: {
|
||||
polygons: [[[0.1, 0.2], [0.9, 0.2], [0.9, 0.8]]],
|
||||
label: '旧标签',
|
||||
color: '#06b6d4',
|
||||
class: { id: 'c1', name: '胆囊', color: '#ff0000', zIndex: 20 },
|
||||
},
|
||||
points: null,
|
||||
bbox: null,
|
||||
created_at: 'created',
|
||||
updated_at: 'updated',
|
||||
}, frame)).toEqual(expect.objectContaining({
|
||||
id: 'annotation-3',
|
||||
annotationId: '3',
|
||||
frameId: '5',
|
||||
templateId: '2',
|
||||
classId: 'c1',
|
||||
className: '胆囊',
|
||||
classZIndex: 20,
|
||||
label: '胆囊',
|
||||
color: '#ff0000',
|
||||
saveStatus: 'saved',
|
||||
saved: true,
|
||||
pathData: 'M 10 10 L 90 10 L 90 40 Z',
|
||||
bbox: [10, 10, 80, 30],
|
||||
}));
|
||||
});
|
||||
|
||||
it('normalizes positive and negative point prompts for AI prediction', async () => {
|
||||
const { predictMask } = await import('./api');
|
||||
axiosMock.client.post.mockResolvedValueOnce({
|
||||
data: {
|
||||
polygons: [[[0.25, 0.25], [0.75, 0.25], [0.75, 0.75], [0.25, 0.75]]],
|
||||
scores: [0.9],
|
||||
},
|
||||
});
|
||||
|
||||
const result = await predictMask({
|
||||
imageId: '42',
|
||||
imageWidth: 400,
|
||||
imageHeight: 200,
|
||||
points: [
|
||||
{ x: 200, y: 100, type: 'pos' },
|
||||
{ x: 40, y: 20, type: 'neg' },
|
||||
],
|
||||
});
|
||||
|
||||
expect(axiosMock.client.post).toHaveBeenCalledWith('/api/ai/predict', {
|
||||
image_id: 42,
|
||||
prompt_type: 'point',
|
||||
prompt_data: {
|
||||
points: [[0.5, 0.5], [0.1, 0.1]],
|
||||
labels: [1, 0],
|
||||
},
|
||||
model: 'sam2',
|
||||
});
|
||||
expect(result.masks[0]).toEqual(expect.objectContaining({
|
||||
pathData: 'M 100 50 L 300 50 L 300 150 L 100 150 Z',
|
||||
segmentation: [[100, 50, 300, 50, 300, 150, 100, 150]],
|
||||
bbox: [100, 50, 200, 100],
|
||||
area: 20000,
|
||||
confidence: 0.9,
|
||||
}));
|
||||
});
|
||||
|
||||
it('normalizes box prompts for AI prediction', async () => {
|
||||
const { predictMask } = await import('./api');
|
||||
axiosMock.client.post.mockResolvedValueOnce({ data: { polygons: [], scores: [] } });
|
||||
|
||||
await predictMask({
|
||||
imageId: '5',
|
||||
imageWidth: 640,
|
||||
imageHeight: 320,
|
||||
box: { x1: 64, y1: 32, x2: 320, y2: 160 },
|
||||
});
|
||||
|
||||
expect(axiosMock.client.post).toHaveBeenCalledWith('/api/ai/predict', {
|
||||
image_id: 5,
|
||||
prompt_type: 'box',
|
||||
prompt_data: [0.1, 0.1, 0.5, 0.5],
|
||||
model: 'sam2',
|
||||
});
|
||||
});
|
||||
|
||||
it('uses semantic prompt type for text-only AI prediction', async () => {
|
||||
const { predictMask } = await import('./api');
|
||||
axiosMock.client.post.mockResolvedValueOnce({ data: { polygons: [], scores: [] } });
|
||||
|
||||
await predictMask({
|
||||
imageId: '6',
|
||||
imageWidth: 640,
|
||||
imageHeight: 360,
|
||||
model: 'sam3',
|
||||
text: '分割胆囊',
|
||||
});
|
||||
|
||||
expect(axiosMock.client.post).toHaveBeenCalledWith('/api/ai/predict', {
|
||||
image_id: 6,
|
||||
prompt_type: 'semantic',
|
||||
prompt_data: '分割胆囊',
|
||||
model: 'sam3',
|
||||
});
|
||||
});
|
||||
|
||||
it('loads AI model and GPU runtime status', async () => {
|
||||
const { getAiModelStatus } = await import('./api');
|
||||
const status = {
|
||||
selected_model: 'sam2',
|
||||
gpu: { available: false, device: 'cpu', name: null, torch_available: true, torch_version: '2.x', cuda_version: null },
|
||||
models: [
|
||||
{ id: 'sam2', label: 'SAM 2', available: true, loaded: false, device: 'cpu', supports: ['point'], message: 'ready', package_available: true, checkpoint_exists: true, checkpoint_path: 'model.pt', python_ok: true, torch_ok: true, cuda_required: false },
|
||||
{ id: 'sam3', label: 'SAM 3', available: false, loaded: false, device: 'unavailable', supports: ['semantic'], message: 'missing runtime', package_available: false, checkpoint_exists: false, checkpoint_path: null, python_ok: false, torch_ok: true, cuda_required: true },
|
||||
],
|
||||
};
|
||||
axiosMock.client.get.mockResolvedValueOnce({ data: status });
|
||||
|
||||
await expect(getAiModelStatus('sam3')).resolves.toEqual(status);
|
||||
expect(axiosMock.client.get).toHaveBeenCalledWith('/api/ai/models/status', {
|
||||
params: { selected_model: 'sam3' },
|
||||
});
|
||||
});
|
||||
});
|
||||
409
src/lib/api.ts
409
src/lib/api.ts
@@ -1,8 +1,9 @@
|
||||
import axios, { AxiosError } from 'axios';
|
||||
import type { Project, Template } from '../store/useStore';
|
||||
import type { AiModelId, Frame, Mask, Project, Template } from '../store/useStore';
|
||||
import { API_BASE_URL } from './config';
|
||||
|
||||
const apiClient = axios.create({
|
||||
baseURL: 'http://192.168.3.11:8000',
|
||||
baseURL: API_BASE_URL,
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
},
|
||||
@@ -40,37 +41,20 @@ export async function login(username: string, password: string): Promise<{ token
|
||||
}
|
||||
|
||||
// Projects
|
||||
export async function getProjects(): Promise<Project[]> {
|
||||
const response = await apiClient.get('/api/projects');
|
||||
return response.data.map((p: any) => ({
|
||||
id: String(p.id),
|
||||
name: p.name,
|
||||
description: p.description,
|
||||
status: p.status,
|
||||
frames: p.frame_count ?? 0,
|
||||
fps: p.original_fps ? `${Math.round(p.original_fps)}FPS` : '30FPS',
|
||||
thumbnail_url: p.thumbnail_url,
|
||||
video_path: p.video_path,
|
||||
source_type: p.source_type,
|
||||
original_fps: p.original_fps,
|
||||
parse_fps: p.parse_fps,
|
||||
createdAt: p.created_at,
|
||||
updatedAt: p.updated_at,
|
||||
}));
|
||||
function normalizeProjectStatus(status?: string): Project['status'] {
|
||||
const value = (status || 'pending').toLowerCase();
|
||||
if (value === 'ready') return 'ready';
|
||||
if (value === 'parsing' || value === 'queued' || value === 'running') return 'parsing';
|
||||
if (value === 'error' || value === 'failed') return 'error';
|
||||
return 'pending';
|
||||
}
|
||||
|
||||
export async function createProject(payload: {
|
||||
name: string;
|
||||
description?: string;
|
||||
parse_fps?: number;
|
||||
}): Promise<Project> {
|
||||
const response = await apiClient.post('/api/projects', payload);
|
||||
const p = response.data;
|
||||
function mapProject(p: any): Project {
|
||||
return {
|
||||
id: String(p.id),
|
||||
name: p.name,
|
||||
description: p.description,
|
||||
status: p.status,
|
||||
status: normalizeProjectStatus(p.status),
|
||||
frames: p.frame_count ?? 0,
|
||||
fps: p.original_fps ? `${Math.round(p.original_fps)}FPS` : '30FPS',
|
||||
thumbnail_url: p.thumbnail_url,
|
||||
@@ -83,9 +67,23 @@ export async function createProject(payload: {
|
||||
};
|
||||
}
|
||||
|
||||
export async function getProjects(): Promise<Project[]> {
|
||||
const response = await apiClient.get('/api/projects');
|
||||
return response.data.map(mapProject);
|
||||
}
|
||||
|
||||
export async function createProject(payload: {
|
||||
name: string;
|
||||
description?: string;
|
||||
parse_fps?: number;
|
||||
}): Promise<Project> {
|
||||
const response = await apiClient.post('/api/projects', payload);
|
||||
return mapProject(response.data);
|
||||
}
|
||||
|
||||
export async function updateProject(id: string, payload: Partial<Project>): Promise<Project> {
|
||||
const response = await apiClient.put(`/api/projects/${id}`, payload);
|
||||
return response.data;
|
||||
const response = await apiClient.patch(`/api/projects/${id}`, payload);
|
||||
return mapProject(response.data);
|
||||
}
|
||||
|
||||
export async function deleteProject(id: string): Promise<void> {
|
||||
@@ -170,26 +168,46 @@ export async function uploadDicomBatch(files: File[], projectId?: string): Promi
|
||||
return response.data;
|
||||
}
|
||||
|
||||
export async function parseMedia(projectId: string): Promise<{
|
||||
project_id: number;
|
||||
frames_extracted: number;
|
||||
status: string;
|
||||
message: string;
|
||||
}> {
|
||||
export interface ProcessingTask {
|
||||
id: number;
|
||||
task_type: string;
|
||||
status: 'queued' | 'running' | 'success' | 'failed' | string;
|
||||
progress: number;
|
||||
message?: string | null;
|
||||
project_id?: number | null;
|
||||
celery_task_id?: string | null;
|
||||
payload?: Record<string, unknown> | null;
|
||||
result?: Record<string, unknown> | null;
|
||||
error?: string | null;
|
||||
created_at: string;
|
||||
started_at?: string | null;
|
||||
finished_at?: string | null;
|
||||
updated_at: string;
|
||||
}
|
||||
|
||||
export async function parseMedia(projectId: string): Promise<ProcessingTask> {
|
||||
const response = await apiClient.post('/api/media/parse', null, {
|
||||
params: { project_id: projectId },
|
||||
});
|
||||
return response.data;
|
||||
}
|
||||
|
||||
// AI Prediction
|
||||
export async function predictMask(payload: {
|
||||
imageUrl: string;
|
||||
export async function getTask(taskId: string | number): Promise<ProcessingTask> {
|
||||
const response = await apiClient.get(`/api/tasks/${taskId}`);
|
||||
return response.data;
|
||||
}
|
||||
|
||||
interface PredictMaskPayload {
|
||||
imageId: string;
|
||||
imageWidth: number;
|
||||
imageHeight: number;
|
||||
model?: AiModelId;
|
||||
points?: { x: number; y: number; type: 'pos' | 'neg' }[];
|
||||
box?: { x1: number; y1: number; x2: number; y2: number };
|
||||
text?: string;
|
||||
modelSize?: string;
|
||||
}): Promise<{
|
||||
}
|
||||
|
||||
interface PredictMaskResult {
|
||||
masks: Array<{
|
||||
id: string;
|
||||
pathData: string;
|
||||
@@ -200,14 +218,319 @@ export async function predictMask(payload: {
|
||||
area: number;
|
||||
confidence: number;
|
||||
}>;
|
||||
}> {
|
||||
const response = await apiClient.post('/api/ai/predict', payload);
|
||||
}
|
||||
|
||||
export interface AiModelStatus {
|
||||
id: AiModelId;
|
||||
label: string;
|
||||
available: boolean;
|
||||
loaded: boolean;
|
||||
device: string;
|
||||
supports: string[];
|
||||
message: string;
|
||||
package_available: boolean;
|
||||
checkpoint_exists: boolean;
|
||||
checkpoint_path?: string | null;
|
||||
python_ok: boolean;
|
||||
torch_ok: boolean;
|
||||
cuda_required: boolean;
|
||||
}
|
||||
|
||||
export interface AiRuntimeStatus {
|
||||
selected_model: AiModelId;
|
||||
gpu: {
|
||||
available: boolean;
|
||||
device: string;
|
||||
name?: string | null;
|
||||
torch_available: boolean;
|
||||
torch_version?: string | null;
|
||||
cuda_version?: string | null;
|
||||
};
|
||||
models: AiModelStatus[];
|
||||
}
|
||||
|
||||
export interface SavedAnnotation {
|
||||
id: number;
|
||||
project_id: number;
|
||||
frame_id: number | null;
|
||||
template_id: number | null;
|
||||
mask_data: {
|
||||
polygons?: number[][][];
|
||||
label?: string;
|
||||
color?: string;
|
||||
class?: {
|
||||
id?: string;
|
||||
name?: string;
|
||||
color?: string;
|
||||
zIndex?: number;
|
||||
category?: string;
|
||||
};
|
||||
} | null;
|
||||
points: number[][] | null;
|
||||
bbox: number[] | null;
|
||||
created_at: string;
|
||||
updated_at: string;
|
||||
}
|
||||
|
||||
export interface SaveAnnotationPayload {
|
||||
project_id: number;
|
||||
frame_id?: number;
|
||||
template_id?: number;
|
||||
mask_data?: {
|
||||
polygons: number[][][];
|
||||
label?: string;
|
||||
color?: string;
|
||||
class?: {
|
||||
id?: string;
|
||||
name?: string;
|
||||
color?: string;
|
||||
zIndex?: number;
|
||||
category?: string;
|
||||
};
|
||||
};
|
||||
points?: number[][];
|
||||
bbox?: number[];
|
||||
}
|
||||
|
||||
export type UpdateAnnotationPayload = Omit<SaveAnnotationPayload, 'project_id' | 'frame_id'>;
|
||||
|
||||
export interface DashboardTask {
|
||||
id: string;
|
||||
task_id?: number;
|
||||
project_id: number;
|
||||
name: string;
|
||||
progress: number;
|
||||
status: string;
|
||||
frame_count: number;
|
||||
updated_at: string | null;
|
||||
}
|
||||
|
||||
export interface DashboardActivity {
|
||||
id: string;
|
||||
kind: 'project' | 'annotation' | 'template' | string;
|
||||
time: string | null;
|
||||
message: string;
|
||||
project: string;
|
||||
}
|
||||
|
||||
export interface DashboardOverview {
|
||||
summary: {
|
||||
project_count: number;
|
||||
parsing_task_count: number;
|
||||
annotation_count: number;
|
||||
frame_count: number;
|
||||
template_count: number;
|
||||
system_load_percent: number;
|
||||
};
|
||||
tasks: DashboardTask[];
|
||||
activity: DashboardActivity[];
|
||||
}
|
||||
|
||||
function clamp01(value: number): number {
|
||||
return Math.min(Math.max(value, 0), 1);
|
||||
}
|
||||
|
||||
function normalizePoint(point: { x: number; y: number }, width: number, height: number): [number, number] {
|
||||
return [
|
||||
clamp01(point.x / Math.max(width, 1)),
|
||||
clamp01(point.y / Math.max(height, 1)),
|
||||
];
|
||||
}
|
||||
|
||||
function polygonToPath(points: number[][], width: number, height: number): string {
|
||||
if (points.length === 0) return '';
|
||||
return points
|
||||
.map(([x, y], index) => {
|
||||
const px = x * width;
|
||||
const py = y * height;
|
||||
return `${index === 0 ? 'M' : 'L'} ${px} ${py}`;
|
||||
})
|
||||
.join(' ')
|
||||
.concat(' Z');
|
||||
}
|
||||
|
||||
function polygonToBbox(points: number[][], width: number, height: number): [number, number, number, number] {
|
||||
const xs = points.map(([x]) => x * width);
|
||||
const ys = points.map(([, y]) => y * height);
|
||||
const minX = Math.min(...xs);
|
||||
const minY = Math.min(...ys);
|
||||
const maxX = Math.max(...xs);
|
||||
const maxY = Math.max(...ys);
|
||||
return [minX, minY, maxX - minX, maxY - minY];
|
||||
}
|
||||
|
||||
function pixelSegmentationToNormalizedPolygons(
|
||||
segmentation: number[][] | undefined,
|
||||
width: number,
|
||||
height: number,
|
||||
): number[][][] {
|
||||
if (!segmentation) return [];
|
||||
return segmentation
|
||||
.map((poly) => {
|
||||
const points: number[][] = [];
|
||||
for (let i = 0; i < poly.length - 1; i += 2) {
|
||||
points.push([
|
||||
clamp01(poly[i] / Math.max(width, 1)),
|
||||
clamp01(poly[i + 1] / Math.max(height, 1)),
|
||||
]);
|
||||
}
|
||||
return points;
|
||||
})
|
||||
.filter((points) => points.length > 0);
|
||||
}
|
||||
|
||||
export function buildAnnotationPayload(
|
||||
projectId: string,
|
||||
mask: Mask,
|
||||
frame: Frame,
|
||||
templateId?: string | null,
|
||||
): SaveAnnotationPayload | null {
|
||||
const polygons = pixelSegmentationToNormalizedPolygons(mask.segmentation, frame.width, frame.height);
|
||||
if (polygons.length === 0) return null;
|
||||
const effectiveTemplateId = mask.templateId || templateId || undefined;
|
||||
const classMetadata = mask.classId || mask.className || mask.classZIndex !== undefined
|
||||
? {
|
||||
id: mask.classId,
|
||||
name: mask.className || mask.label,
|
||||
color: mask.color,
|
||||
zIndex: mask.classZIndex,
|
||||
}
|
||||
: undefined;
|
||||
|
||||
return {
|
||||
project_id: Number(projectId),
|
||||
frame_id: Number(frame.id),
|
||||
template_id: effectiveTemplateId ? Number(effectiveTemplateId) : undefined,
|
||||
mask_data: {
|
||||
polygons,
|
||||
label: mask.label,
|
||||
color: mask.color,
|
||||
...(classMetadata ? { class: classMetadata } : {}),
|
||||
},
|
||||
bbox: mask.bbox
|
||||
? [
|
||||
clamp01(mask.bbox[0] / Math.max(frame.width, 1)),
|
||||
clamp01(mask.bbox[1] / Math.max(frame.height, 1)),
|
||||
clamp01(mask.bbox[2] / Math.max(frame.width, 1)),
|
||||
clamp01(mask.bbox[3] / Math.max(frame.height, 1)),
|
||||
]
|
||||
: undefined,
|
||||
};
|
||||
}
|
||||
|
||||
export function annotationToMask(annotation: SavedAnnotation, frame: Frame): Mask | null {
|
||||
const polygons = annotation.mask_data?.polygons || [];
|
||||
const firstPolygon = polygons[0];
|
||||
if (!firstPolygon || firstPolygon.length === 0) return null;
|
||||
const bbox = polygonToBbox(firstPolygon, frame.width, frame.height);
|
||||
const classMetadata = annotation.mask_data?.class;
|
||||
return {
|
||||
id: `annotation-${annotation.id}`,
|
||||
annotationId: String(annotation.id),
|
||||
frameId: String(annotation.frame_id),
|
||||
templateId: annotation.template_id ? String(annotation.template_id) : undefined,
|
||||
classId: classMetadata?.id,
|
||||
className: classMetadata?.name,
|
||||
classZIndex: classMetadata?.zIndex,
|
||||
saveStatus: 'saved',
|
||||
saved: true,
|
||||
pathData: polygonToPath(firstPolygon, frame.width, frame.height),
|
||||
label: classMetadata?.name || annotation.mask_data?.label || `Annotation ${annotation.id}`,
|
||||
color: classMetadata?.color || annotation.mask_data?.color || '#06b6d4',
|
||||
segmentation: polygons.map((polygon) => polygon.flatMap(([x, y]) => [x * frame.width, y * frame.height])),
|
||||
bbox,
|
||||
area: bbox[2] * bbox[3],
|
||||
};
|
||||
}
|
||||
|
||||
export async function predictMask(payload: PredictMaskPayload): Promise<PredictMaskResult> {
|
||||
let prompt_type: 'point' | 'box' | 'semantic';
|
||||
let prompt_data: unknown;
|
||||
|
||||
if (payload.box) {
|
||||
prompt_type = 'box';
|
||||
prompt_data = [
|
||||
clamp01(payload.box.x1 / Math.max(payload.imageWidth, 1)),
|
||||
clamp01(payload.box.y1 / Math.max(payload.imageHeight, 1)),
|
||||
clamp01(payload.box.x2 / Math.max(payload.imageWidth, 1)),
|
||||
clamp01(payload.box.y2 / Math.max(payload.imageHeight, 1)),
|
||||
];
|
||||
} else if (payload.points && payload.points.length > 0) {
|
||||
prompt_type = 'point';
|
||||
prompt_data = {
|
||||
points: payload.points.map((point) => normalizePoint(point, payload.imageWidth, payload.imageHeight)),
|
||||
labels: payload.points.map((point) => (point.type === 'neg' ? 0 : 1)),
|
||||
};
|
||||
} else {
|
||||
prompt_type = 'semantic';
|
||||
prompt_data = payload.text?.trim() || '';
|
||||
}
|
||||
|
||||
const response = await apiClient.post('/api/ai/predict', {
|
||||
image_id: Number(payload.imageId),
|
||||
prompt_type,
|
||||
prompt_data,
|
||||
model: payload.model || 'sam2',
|
||||
});
|
||||
|
||||
const polygons: number[][][] = response.data.polygons || [];
|
||||
const scores: number[] = response.data.scores || [];
|
||||
return {
|
||||
masks: polygons.map((polygon, index) => {
|
||||
const bbox = polygonToBbox(polygon, payload.imageWidth, payload.imageHeight);
|
||||
return {
|
||||
id: `mask-${payload.imageId}-${Date.now()}-${index}`,
|
||||
pathData: polygonToPath(polygon, payload.imageWidth, payload.imageHeight),
|
||||
label: prompt_type === 'semantic' ? (payload.text?.trim() || 'AI Mask') : 'AI Mask',
|
||||
color: '#06b6d4',
|
||||
segmentation: [polygon.flatMap(([x, y]) => [x * payload.imageWidth, y * payload.imageHeight])],
|
||||
bbox,
|
||||
area: bbox[2] * bbox[3],
|
||||
confidence: scores[index] ?? 0,
|
||||
};
|
||||
}),
|
||||
};
|
||||
}
|
||||
|
||||
export async function getAiModelStatus(selectedModel?: AiModelId): Promise<AiRuntimeStatus> {
|
||||
const response = await apiClient.get('/api/ai/models/status', {
|
||||
params: selectedModel ? { selected_model: selectedModel } : undefined,
|
||||
});
|
||||
return response.data;
|
||||
}
|
||||
|
||||
export async function getProjectAnnotations(projectId: string, frameId?: string): Promise<SavedAnnotation[]> {
|
||||
const response = await apiClient.get('/api/ai/annotations', {
|
||||
params: {
|
||||
project_id: Number(projectId),
|
||||
...(frameId ? { frame_id: Number(frameId) } : {}),
|
||||
},
|
||||
});
|
||||
return response.data;
|
||||
}
|
||||
|
||||
export async function saveAnnotation(payload: SaveAnnotationPayload): Promise<SavedAnnotation> {
|
||||
const response = await apiClient.post('/api/ai/annotate', payload);
|
||||
return response.data;
|
||||
}
|
||||
|
||||
export async function updateAnnotation(annotationId: string, payload: UpdateAnnotationPayload): Promise<SavedAnnotation> {
|
||||
const response = await apiClient.patch(`/api/ai/annotations/${annotationId}`, payload);
|
||||
return response.data;
|
||||
}
|
||||
|
||||
export async function deleteAnnotation(annotationId: string): Promise<void> {
|
||||
await apiClient.delete(`/api/ai/annotations/${annotationId}`);
|
||||
}
|
||||
|
||||
export async function getDashboardOverview(): Promise<DashboardOverview> {
|
||||
const response = await apiClient.get('/api/dashboard/overview');
|
||||
return response.data;
|
||||
}
|
||||
|
||||
// Export
|
||||
export async function exportCoco(projectId: string): Promise<Blob> {
|
||||
const response = await apiClient.get(`/api/export/coco/${projectId}`, {
|
||||
const response = await apiClient.get(`/api/export/${projectId}/coco`, {
|
||||
responseType: 'blob',
|
||||
});
|
||||
return response.data;
|
||||
|
||||
38
src/lib/config.test.ts
Normal file
38
src/lib/config.test.ts
Normal file
@@ -0,0 +1,38 @@
|
||||
import { afterEach, describe, expect, it, vi } from 'vitest';
|
||||
|
||||
describe('frontend runtime config', () => {
|
||||
afterEach(() => {
|
||||
vi.unstubAllEnvs();
|
||||
vi.resetModules();
|
||||
});
|
||||
|
||||
it('prefers explicit VITE_API_BASE_URL and trims trailing slashes', async () => {
|
||||
vi.stubEnv('VITE_API_BASE_URL', 'http://api.example.test:8000///');
|
||||
|
||||
const config = await import('./config');
|
||||
|
||||
expect(config.API_BASE_URL).toBe('http://api.example.test:8000');
|
||||
});
|
||||
|
||||
it('infers the API host from the current browser hostname', async () => {
|
||||
const config = await import('./config');
|
||||
|
||||
expect(config.API_BASE_URL).toBe('http://seg.local:8000');
|
||||
});
|
||||
|
||||
it('derives websocket URL from API URL unless explicitly configured', async () => {
|
||||
vi.stubEnv('VITE_API_BASE_URL', 'https://seg.example.test');
|
||||
|
||||
const config = await import('./config');
|
||||
|
||||
expect(config.WS_PROGRESS_URL).toBe('wss://seg.example.test/ws/progress');
|
||||
});
|
||||
|
||||
it('prefers explicit VITE_WS_PROGRESS_URL', async () => {
|
||||
vi.stubEnv('VITE_WS_PROGRESS_URL', 'ws://custom/ws/progress');
|
||||
|
||||
const config = await import('./config');
|
||||
|
||||
expect(config.WS_PROGRESS_URL).toBe('ws://custom/ws/progress');
|
||||
});
|
||||
});
|
||||
29
src/lib/config.ts
Normal file
29
src/lib/config.ts
Normal file
@@ -0,0 +1,29 @@
|
||||
const DEFAULT_API_BASE_URL = 'http://192.168.3.11:8000';
|
||||
|
||||
function trimTrailingSlash(value: string): string {
|
||||
return value.replace(/\/+$/, '');
|
||||
}
|
||||
|
||||
function inferApiBaseUrl(): string {
|
||||
const envUrl = import.meta.env.VITE_API_BASE_URL;
|
||||
if (envUrl) return trimTrailingSlash(envUrl);
|
||||
|
||||
if (typeof window !== 'undefined' && window.location.hostname) {
|
||||
return `${window.location.protocol}//${window.location.hostname}:8000`;
|
||||
}
|
||||
|
||||
return DEFAULT_API_BASE_URL;
|
||||
}
|
||||
|
||||
export const API_BASE_URL = inferApiBaseUrl();
|
||||
|
||||
function inferWsProgressUrl(): string {
|
||||
const envUrl = import.meta.env.VITE_WS_PROGRESS_URL;
|
||||
if (envUrl) return envUrl;
|
||||
|
||||
const url = new URL('/ws/progress', API_BASE_URL);
|
||||
url.protocol = url.protocol === 'https:' ? 'wss:' : 'ws:';
|
||||
return url.toString();
|
||||
}
|
||||
|
||||
export const WS_PROGRESS_URL = inferWsProgressUrl();
|
||||
15
src/lib/templateSelection.ts
Normal file
15
src/lib/templateSelection.ts
Normal file
@@ -0,0 +1,15 @@
|
||||
import type { Template, TemplateClass } from '../store/useStore';
|
||||
|
||||
export function getActiveTemplate(templates: Template[], activeTemplateId: string | null): Template | null {
|
||||
return templates.find((template) => template.id === activeTemplateId) || templates[0] || null;
|
||||
}
|
||||
|
||||
export function getActiveClass(
|
||||
templates: Template[],
|
||||
activeTemplateId: string | null,
|
||||
activeClassId: string | null,
|
||||
): TemplateClass | null {
|
||||
const template = getActiveTemplate(templates, activeTemplateId);
|
||||
if (!template) return null;
|
||||
return template.classes.find((templateClass) => templateClass.id === activeClassId) || null;
|
||||
}
|
||||
46
src/lib/websocket.test.ts
Normal file
46
src/lib/websocket.test.ts
Normal file
@@ -0,0 +1,46 @@
|
||||
import { afterEach, describe, expect, it, vi } from 'vitest';
|
||||
|
||||
describe('progress websocket client', () => {
|
||||
afterEach(() => {
|
||||
vi.restoreAllMocks();
|
||||
vi.resetModules();
|
||||
vi.unstubAllGlobals();
|
||||
});
|
||||
|
||||
it('connects using the configured URL and reports open state', async () => {
|
||||
const instances: any[] = [];
|
||||
class FakeWebSocket {
|
||||
static CONNECTING = 0;
|
||||
static OPEN = 1;
|
||||
readyState = FakeWebSocket.OPEN;
|
||||
onopen?: () => void;
|
||||
onmessage?: (event: MessageEvent) => void;
|
||||
onclose?: () => void;
|
||||
onerror?: () => void;
|
||||
constructor(public url: string) {
|
||||
instances.push(this);
|
||||
}
|
||||
close = vi.fn();
|
||||
}
|
||||
vi.stubGlobal('WebSocket', FakeWebSocket);
|
||||
|
||||
const { progressWS } = await import('./websocket');
|
||||
progressWS.connect();
|
||||
|
||||
expect(instances[0].url).toContain('/ws/progress');
|
||||
expect(progressWS.isConnected()).toBe(true);
|
||||
});
|
||||
|
||||
it('subscribes and unsubscribes progress callbacks', async () => {
|
||||
const { progressWS } = await import('./websocket');
|
||||
const callback = vi.fn();
|
||||
|
||||
const unsubscribe = progressWS.onProgress(callback);
|
||||
(progressWS as any).callbacks.forEach((cb: any) => cb({ type: 'status', message: 'ok' }));
|
||||
unsubscribe();
|
||||
(progressWS as any).callbacks.forEach((cb: any) => cb({ type: 'status', message: 'again' }));
|
||||
|
||||
expect(callback).toHaveBeenCalledTimes(1);
|
||||
expect(callback).toHaveBeenCalledWith({ type: 'status', message: 'ok' });
|
||||
});
|
||||
});
|
||||
@@ -1,12 +1,18 @@
|
||||
import { WS_PROGRESS_URL } from './config';
|
||||
|
||||
type ProgressCallback = (data: ProgressMessage) => void;
|
||||
|
||||
interface ProgressMessage {
|
||||
type: 'progress' | 'status' | 'error' | 'complete';
|
||||
taskId?: string;
|
||||
task_id?: number;
|
||||
project_id?: number;
|
||||
projectName?: string;
|
||||
filename?: string;
|
||||
progress?: number;
|
||||
status?: string;
|
||||
message?: string;
|
||||
error?: string;
|
||||
timestamp?: string;
|
||||
}
|
||||
|
||||
@@ -21,7 +27,7 @@ class ProgressWebSocket {
|
||||
private shouldCloseAfterOpen = false;
|
||||
private currentInterval = 3000;
|
||||
|
||||
constructor(url = 'ws://192.168.3.11:8000/ws/progress') {
|
||||
constructor(url = WS_PROGRESS_URL) {
|
||||
this.url = url;
|
||||
}
|
||||
|
||||
|
||||
56
src/store/useStore.test.ts
Normal file
56
src/store/useStore.test.ts
Normal file
@@ -0,0 +1,56 @@
|
||||
import { beforeEach, describe, expect, it } from 'vitest';
|
||||
import { resetStore } from '../test/storeTestUtils';
|
||||
import { useStore } from './useStore';
|
||||
|
||||
describe('useStore', () => {
|
||||
beforeEach(() => {
|
||||
resetStore();
|
||||
});
|
||||
|
||||
it('stores and clears auth state with localStorage', () => {
|
||||
useStore.getState().login('token-1');
|
||||
|
||||
expect(useStore.getState().isAuthenticated).toBe(true);
|
||||
expect(useStore.getState().token).toBe('token-1');
|
||||
expect(localStorage.getItem('token')).toBe('token-1');
|
||||
|
||||
useStore.getState().logout();
|
||||
|
||||
expect(useStore.getState().isAuthenticated).toBe(false);
|
||||
expect(useStore.getState().projects).toEqual([]);
|
||||
expect(useStore.getState().frames).toEqual([]);
|
||||
expect(localStorage.getItem('token')).toBeNull();
|
||||
});
|
||||
|
||||
it('manages projects, frames, masks, annotations and templates', () => {
|
||||
const project = { id: '1', name: 'Project', status: 'ready' as const };
|
||||
useStore.getState().addProject(project);
|
||||
useStore.getState().updateProject({ ...project, name: 'Updated' });
|
||||
useStore.getState().setCurrentProject(project);
|
||||
useStore.getState().setFrames([{ id: 'f1', projectId: '1', index: 0, url: '/f1.jpg', width: 640, height: 360 }]);
|
||||
useStore.getState().setCurrentFrame(0);
|
||||
useStore.getState().addMask({ id: 'm1', frameId: 'f1', pathData: 'M 0 0 Z', label: 'mask', color: '#fff' });
|
||||
useStore.getState().updateMask('m1', { label: 'updated mask', saveStatus: 'dirty' });
|
||||
useStore.getState().addAnnotation({ id: 'a1', frameId: 'f1', type: 'mask', points: [], label: 'ann', color: '#fff' });
|
||||
useStore.getState().addTemplate({ id: 't1', name: 'Template', classes: [], rules: [] });
|
||||
useStore.getState().updateTemplate({ id: 't1', name: 'Template 2', classes: [], rules: [] });
|
||||
useStore.getState().setActiveClass({ id: 'c1', name: '胆囊', color: '#ff0000', zIndex: 10 });
|
||||
|
||||
expect(useStore.getState().projects[0].name).toBe('Updated');
|
||||
expect(useStore.getState().currentProject?.id).toBe('1');
|
||||
expect(useStore.getState().frames).toHaveLength(1);
|
||||
expect(useStore.getState().currentFrameIndex).toBe(0);
|
||||
expect(useStore.getState().masks[0]).toEqual(expect.objectContaining({ label: 'updated mask', saveStatus: 'dirty' }));
|
||||
expect(useStore.getState().annotations).toHaveLength(1);
|
||||
expect(useStore.getState().templates[0].name).toBe('Template 2');
|
||||
expect(useStore.getState().activeClassId).toBe('c1');
|
||||
|
||||
useStore.getState().removeAnnotation('a1');
|
||||
useStore.getState().clearMasks();
|
||||
useStore.getState().removeTemplate('t1');
|
||||
|
||||
expect(useStore.getState().annotations).toEqual([]);
|
||||
expect(useStore.getState().masks).toEqual([]);
|
||||
expect(useStore.getState().templates).toEqual([]);
|
||||
});
|
||||
});
|
||||
@@ -4,7 +4,7 @@ export interface Project {
|
||||
id: string;
|
||||
name: string;
|
||||
description?: string;
|
||||
status: 'Ready' | 'Parsing' | 'Error';
|
||||
status: 'pending' | 'parsing' | 'ready' | 'error';
|
||||
fps?: string;
|
||||
frames?: number;
|
||||
thumbnail?: string;
|
||||
@@ -17,6 +17,8 @@ export interface Project {
|
||||
updatedAt?: string;
|
||||
}
|
||||
|
||||
export type AiModelId = 'sam2' | 'sam3';
|
||||
|
||||
export interface Frame {
|
||||
id: string;
|
||||
projectId: string;
|
||||
@@ -42,6 +44,13 @@ export interface Annotation {
|
||||
export interface Mask {
|
||||
id: string;
|
||||
frameId: string;
|
||||
annotationId?: string;
|
||||
templateId?: string;
|
||||
classId?: string;
|
||||
className?: string;
|
||||
classZIndex?: number;
|
||||
saveStatus?: 'draft' | 'saved' | 'dirty' | 'saving' | 'error';
|
||||
saved?: boolean;
|
||||
pathData: string;
|
||||
label: string;
|
||||
color: string;
|
||||
@@ -96,24 +105,32 @@ export interface AppState {
|
||||
// Workspace
|
||||
activeModule: string;
|
||||
activeTool: string;
|
||||
aiModel: AiModelId;
|
||||
frames: Frame[];
|
||||
currentFrameIndex: number;
|
||||
annotations: Annotation[];
|
||||
masks: Mask[];
|
||||
setActiveModule: (module: string) => void;
|
||||
setActiveTool: (tool: string) => void;
|
||||
setAiModel: (model: AiModelId) => void;
|
||||
setFrames: (frames: Frame[]) => void;
|
||||
setCurrentFrame: (index: number) => void;
|
||||
addAnnotation: (annotation: Annotation) => void;
|
||||
addMask: (mask: Mask) => void;
|
||||
updateMask: (id: string, updates: Partial<Mask>) => void;
|
||||
setMasks: (masks: Mask[]) => void;
|
||||
clearMasks: () => void;
|
||||
removeAnnotation: (id: string) => void;
|
||||
|
||||
// Templates
|
||||
templates: Template[];
|
||||
activeTemplateId: string | null;
|
||||
activeClassId: string | null;
|
||||
activeClass: TemplateClass | null;
|
||||
setTemplates: (templates: Template[]) => void;
|
||||
setActiveTemplateId: (id: string | null) => void;
|
||||
setActiveClassId: (id: string | null) => void;
|
||||
setActiveClass: (templateClass: TemplateClass | null) => void;
|
||||
addTemplate: (template: Template) => void;
|
||||
updateTemplate: (template: Template) => void;
|
||||
removeTemplate: (id: string) => void;
|
||||
@@ -144,6 +161,9 @@ export const useStore = create<AppState>((set) => ({
|
||||
frames: [],
|
||||
annotations: [],
|
||||
masks: [],
|
||||
activeTemplateId: null,
|
||||
activeClassId: null,
|
||||
activeClass: null,
|
||||
});
|
||||
},
|
||||
|
||||
@@ -162,18 +182,25 @@ export const useStore = create<AppState>((set) => ({
|
||||
// Workspace
|
||||
activeModule: 'workspace',
|
||||
activeTool: 'move',
|
||||
aiModel: 'sam2',
|
||||
frames: [],
|
||||
currentFrameIndex: 0,
|
||||
annotations: [],
|
||||
masks: [],
|
||||
setActiveModule: (activeModule: string) => set({ activeModule }),
|
||||
setActiveTool: (activeTool: string) => set({ activeTool }),
|
||||
setAiModel: (aiModel: AiModelId) => set({ aiModel }),
|
||||
setFrames: (frames: Frame[]) => set({ frames }),
|
||||
setCurrentFrame: (currentFrameIndex: number) => set({ currentFrameIndex }),
|
||||
addAnnotation: (annotation: Annotation) =>
|
||||
set((state) => ({ annotations: [...state.annotations, annotation] })),
|
||||
addMask: (mask: Mask) =>
|
||||
set((state) => ({ masks: [...state.masks, mask] })),
|
||||
updateMask: (id: string, updates: Partial<Mask>) =>
|
||||
set((state) => ({
|
||||
masks: state.masks.map((mask) => (mask.id === id ? { ...mask, ...updates } : mask)),
|
||||
})),
|
||||
setMasks: (masks: Mask[]) => set({ masks }),
|
||||
clearMasks: () => set({ masks: [] }),
|
||||
removeAnnotation: (id: string) =>
|
||||
set((state) => ({
|
||||
@@ -183,8 +210,15 @@ export const useStore = create<AppState>((set) => ({
|
||||
// Templates
|
||||
templates: [],
|
||||
activeTemplateId: null,
|
||||
activeClassId: null,
|
||||
activeClass: null,
|
||||
setTemplates: (templates: Template[]) => set({ templates }),
|
||||
setActiveTemplateId: (activeTemplateId: string | null) => set({ activeTemplateId }),
|
||||
setActiveClassId: (activeClassId: string | null) => set({ activeClassId }),
|
||||
setActiveClass: (activeClass: TemplateClass | null) => set({
|
||||
activeClass,
|
||||
activeClassId: activeClass?.id || null,
|
||||
}),
|
||||
addTemplate: (template: Template) =>
|
||||
set((state) => ({ templates: [...state.templates, template] })),
|
||||
updateTemplate: (template: Template) =>
|
||||
|
||||
66
src/test/setup.tsx
Normal file
66
src/test/setup.tsx
Normal file
@@ -0,0 +1,66 @@
|
||||
import React from 'react';
|
||||
import { afterEach, vi } from 'vitest';
|
||||
import { cleanup } from '@testing-library/react';
|
||||
import '@testing-library/jest-dom/vitest';
|
||||
|
||||
afterEach(() => {
|
||||
cleanup();
|
||||
localStorage.clear();
|
||||
});
|
||||
|
||||
vi.stubGlobal('alert', vi.fn());
|
||||
vi.stubGlobal('confirm', vi.fn(() => true));
|
||||
URL.createObjectURL = vi.fn(() => 'blob:mock-url');
|
||||
URL.revokeObjectURL = vi.fn();
|
||||
HTMLAnchorElement.prototype.click = vi.fn();
|
||||
|
||||
function makeStageEvent(x = 120, y = 80) {
|
||||
const stage = {
|
||||
getPointerPosition: () => ({ x, y }),
|
||||
getRelativePointerPosition: () => ({ x, y }),
|
||||
scaleX: () => 1,
|
||||
x: () => 0,
|
||||
y: () => 0,
|
||||
};
|
||||
|
||||
return {
|
||||
evt: { preventDefault: vi.fn(), deltaY: -1 },
|
||||
target: {
|
||||
getStage: () => stage,
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
vi.mock('react-konva', () => ({
|
||||
Stage: ({ children, onClick, onMouseDown, onMouseUp, onMouseMove, onWheel }: any) => (
|
||||
<div
|
||||
data-testid="konva-stage"
|
||||
onClick={() => onClick?.(makeStageEvent())}
|
||||
onMouseDown={() => onMouseDown?.(makeStageEvent())}
|
||||
onMouseUp={() => onMouseUp?.(makeStageEvent(260, 200))}
|
||||
onMouseMove={() => onMouseMove?.(makeStageEvent(180, 120))}
|
||||
onWheel={() => onWheel?.(makeStageEvent())}
|
||||
>
|
||||
{children}
|
||||
</div>
|
||||
),
|
||||
Layer: ({ children }: any) => <div data-testid="konva-layer">{children}</div>,
|
||||
Group: ({ children }: any) => <div data-testid="konva-group">{children}</div>,
|
||||
Image: ({ image }: any) => <img data-testid="konva-image" alt="" src={image?.src || ''} />,
|
||||
Circle: (props: any) => <span data-testid="konva-circle" data-fill={props.fill} />,
|
||||
Rect: (props: any) => <span data-testid="konva-rect" data-width={props.width} />,
|
||||
Path: (props: any) => <span data-testid="konva-path" data-path={props.data} data-fill={props.fill} />,
|
||||
}));
|
||||
|
||||
vi.mock('use-image', () => ({
|
||||
default: (src: string) => [
|
||||
{
|
||||
src,
|
||||
width: 640,
|
||||
height: 360,
|
||||
naturalWidth: 640,
|
||||
naturalHeight: 360,
|
||||
},
|
||||
'loaded',
|
||||
],
|
||||
}));
|
||||
23
src/test/storeTestUtils.ts
Normal file
23
src/test/storeTestUtils.ts
Normal file
@@ -0,0 +1,23 @@
|
||||
import { useStore } from '../store/useStore';
|
||||
|
||||
export function resetStore() {
|
||||
useStore.setState({
|
||||
isAuthenticated: false,
|
||||
token: null,
|
||||
projects: [],
|
||||
currentProject: null,
|
||||
activeModule: 'workspace',
|
||||
activeTool: 'move',
|
||||
aiModel: 'sam2',
|
||||
frames: [],
|
||||
currentFrameIndex: 0,
|
||||
annotations: [],
|
||||
masks: [],
|
||||
templates: [],
|
||||
activeTemplateId: null,
|
||||
activeClassId: null,
|
||||
activeClass: null,
|
||||
isLoading: false,
|
||||
error: null,
|
||||
});
|
||||
}
|
||||
6
src/vite-env.d.ts
vendored
Normal file
6
src/vite-env.d.ts
vendored
Normal file
@@ -0,0 +1,6 @@
|
||||
/// <reference types="vite/client" />
|
||||
|
||||
interface ImportMetaEnv {
|
||||
readonly VITE_API_BASE_URL?: string;
|
||||
readonly VITE_WS_PROGRESS_URL?: string;
|
||||
}
|
||||
Reference in New Issue
Block a user