feat: 打通全栈标注闭环、异步拆帧与模型状态

后端能力:

- 新增 Celery app、worker task、ProcessingTask 模型、/api/tasks 查询接口和 media_task_runner,将 /api/media/parse 改为创建后台任务并由 worker 执行 FFmpeg/OpenCV/pydicom 拆帧。

- 新增 Redis 进度事件模块和 FastAPI Redis pub/sub 订阅,将 worker 任务进度广播到 /ws/progress;Dashboard 后端概览接口改为聚合 projects/frames/annotations/templates/processing_tasks。

- 统一项目状态为 pending/parsing/ready/error,新增共享 status 常量,并让前端兼容归一化旧状态值。

- 扩展 AI 后端:新增 SAM registry、SAM2 真实运行状态、SAM3 状态检测与文本语义推理适配入口,以及 /api/ai/models/status GPU/模型状态接口。

- 补齐标注保存/更新/删除、COCO/PNG mask 导出相关后端契约和模板 mapping_rules 打包/解包行为。

前端能力:

- 新增运行时 API/WS 地址推导配置,前端 API 封装对齐 FastAPI 路由、字段映射、任务轮询、标注归档、导出下载和 AI 预测响应转换。

- Dashboard 改为读取 /api/dashboard/overview,并订阅 WebSocket progress/complete/error/status 更新解析队列和实时流转记录。

- 项目库导入视频/DICOM 后创建项目、上传媒体、触发异步解析并刷新真实项目列表。

- 工作区加载真实帧、无帧时触发解析任务、回显已保存标注、保存未归档 mask、更新 dirty mask、清空当前帧后端标注、导出 COCO JSON。

- Canvas 支持当前帧点/框提示调用后端 AI、渲染推理/已保存 mask、应用模板分类并维护保存状态计数;时间轴按项目 fps 播放。

- AI 页面新增 SAM2/SAM3 模型选择,预测请求携带 model;侧边栏和工作区新增真实 GPU/SAM 状态徽标。

- 模板库和本体面板接入真实模板 CRUD、分类编辑、拖拽排序、JSON 导入、默认腹腔镜分类和本地自定义分类选择。

测试与文档:

- 新增 Vitest 配置、前端测试 setup、API/config/websocket/store/组件测试,覆盖登录、项目库、Dashboard、Canvas、工作区、模型状态、时间轴、本体和模板库。

- 新增 pytest 后端测试夹具和 auth/projects/templates/media/AI/export/dashboard/tasks/progress 测试,使用 SQLite、fake MinIO、fake SAM registry 和 Redis monkeypatch 隔离外部服务。

- 新增 doc/ 文档结构,冻结当前需求、设计、接口契约、测试计划、前端逐元素审计、实现地图和后续实施计划,并同步更新 README 与 AGENTS。

验证:

- conda run -n seg_server pytest backend/tests:27 passed。

- npm run test:run:54 passed。

- npm run lint、npm run build、compileall、git diff --check 均通过;Vite 仅提示大 chunk 警告。
This commit is contained in:
2026-05-01 13:29:14 +08:00
parent 4d65c37c73
commit f020ff3b4f
78 changed files with 7089 additions and 456 deletions

View File

@@ -0,0 +1,43 @@
import { fireEvent, render, screen } from '@testing-library/react';
import { beforeEach, describe, expect, it, vi } from 'vitest';
import { resetStore } from '../test/storeTestUtils';
import { useStore } from '../store/useStore';
import { AISegmentation } from './AISegmentation';
const apiMock = vi.hoisted(() => ({
getAiModelStatus: vi.fn(),
predictMask: vi.fn(),
}));
vi.mock('../lib/api', () => ({
getAiModelStatus: apiMock.getAiModelStatus,
predictMask: apiMock.predictMask,
}));
describe('AISegmentation', () => {
beforeEach(() => {
resetStore();
vi.clearAllMocks();
useStore.setState({
frames: [{ id: 'frame-1', projectId: 'project-1', index: 0, url: '/frame.jpg', width: 640, height: 360 }],
});
apiMock.getAiModelStatus.mockResolvedValue({
selected_model: 'sam2',
gpu: { available: true, device: 'cuda', name: 'RTX 4090', torch_available: true },
models: [
{ id: 'sam2', label: 'SAM 2', available: true, loaded: false, device: 'cuda', supports: ['point', 'box'], message: 'SAM 2 ready', package_available: true, checkpoint_exists: true, python_ok: true, torch_ok: true, cuda_required: false },
{ id: 'sam3', label: 'SAM 3', available: false, loaded: false, device: 'cuda', supports: ['semantic'], message: 'SAM 3 missing runtime', package_available: false, checkpoint_exists: false, python_ok: false, torch_ok: true, cuda_required: true },
],
});
});
it('lets the user choose SAM3 for subsequent predictions', async () => {
render(<AISegmentation onSendToWorkspace={vi.fn()} />);
const sam3Button = (await screen.findByText('SAM3')).closest('button')!;
fireEvent.click(sam3Button);
expect(useStore.getState().aiModel).toBe('sam3');
expect(await screen.findByText('SAM 3 missing runtime')).toBeInTheDocument();
});
});

View File

@@ -1,11 +1,11 @@
import React, { useState, useCallback } from 'react';
import React, { useState, useCallback, useEffect } from 'react';
import { Target, PlusCircle, MinusCircle, SquareDashed, Sparkles, SendToBack, Image as ImageIcon, Undo, Redo, Loader2 } from 'lucide-react';
import { cn } from '../lib/utils';
import { Stage, Layer, Image as KonvaImage, Circle, Path, Group } from 'react-konva';
import useImage from 'use-image';
import { OntologyInspector } from './OntologyInspector';
import { useStore } from '../store/useStore';
import { predictMask } from '../lib/api';
import { getAiModelStatus, predictMask, type AiRuntimeStatus } from '../lib/api';
interface AISegmentationProps {
onSendToWorkspace: () => void;
@@ -17,9 +17,15 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
const masks = useStore((state) => state.masks);
const addMask = useStore((state) => state.addMask);
const clearMasks = useStore((state) => state.clearMasks);
const frames = useStore((state) => state.frames);
const currentFrameIndex = useStore((state) => state.currentFrameIndex);
const activeTemplateId = useStore((state) => state.activeTemplateId);
const activeClass = useStore((state) => state.activeClass);
const aiModel = useStore((state) => state.aiModel);
const setAiModel = useStore((state) => state.setAiModel);
const [modelSize, setModelSize] = useState('vit_l');
const [semanticText, setSemanticText] = useState('');
const [modelStatus, setModelStatus] = useState<AiRuntimeStatus | null>(null);
const [autoDeleteBg, setAutoDeleteBg] = useState(true);
const [cropMode, setCropMode] = useState(false);
const [isInferencing, setIsInferencing] = useState(false);
@@ -29,10 +35,29 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
const [position, setPosition] = useState({ x: 0, y: 0 });
const [points, setPoints] = useState<{ x: number, y: number, type: 'pos'|'neg' }[]>([]);
const [cursorPos, setCursorPos] = useState({ x: 0, y: 0 });
const [image] = useImage('https://images.unsplash.com/photo-1549317661-bd32c8ce0be2?q=80&w=2070&auto=format&fit=crop');
const currentFrame = frames[currentFrameIndex] || null;
const previewUrl = currentFrame?.url || 'https://images.unsplash.com/photo-1549317661-bd32c8ce0be2?q=80&w=2070&auto=format&fit=crop';
const [image] = useImage(previewUrl);
const frameMasks = currentFrame ? masks.filter((mask) => mask.frameId === currentFrame.id) : masks;
const selectedModelStatus = modelStatus?.models.find((model) => model.id === aiModel);
const modelCanInfer = selectedModelStatus?.available ?? true;
const effectiveTool = storeActiveTool;
useEffect(() => {
let cancelled = false;
getAiModelStatus(aiModel)
.then((status) => {
if (!cancelled) setModelStatus(status);
})
.catch(() => {
if (!cancelled) setModelStatus(null);
});
return () => {
cancelled = true;
};
}, [aiModel]);
const handleWheel = (e: any) => {
e.evt.preventDefault();
const scaleBy = 1.1;
@@ -63,22 +88,44 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
const runInference = useCallback(async () => {
if (points.length === 0 && !semanticText.trim()) return;
if (!currentFrame?.id) {
console.warn('AI inference skipped: no project frame is selected');
return;
}
const imageWidth = currentFrame.width || image?.naturalWidth || image?.width || 0;
const imageHeight = currentFrame.height || image?.naturalHeight || image?.height || 0;
if (imageWidth <= 0 || imageHeight <= 0) {
console.warn('AI inference skipped: active frame dimensions are unavailable');
return;
}
setIsInferencing(true);
try {
const result = await predictMask({
imageUrl: 'https://images.unsplash.com/photo-1549317661-bd32c8ce0be2?q=80&w=2070&auto=format&fit=crop',
imageId: currentFrame.id,
imageWidth,
imageHeight,
model: aiModel,
points: points.map((p) => ({ x: p.x, y: p.y, type: p.type })),
text: semanticText.trim() || undefined,
modelSize,
});
result.masks.forEach((m) => {
const label = activeClass?.name || m.label;
const color = activeClass?.color || m.color;
addMask({
id: m.id,
frameId: 'frame-ai-1',
frameId: currentFrame.id,
templateId: activeTemplateId || undefined,
classId: activeClass?.id,
className: activeClass?.name,
classZIndex: activeClass?.zIndex,
saveStatus: 'draft',
saved: false,
pathData: m.pathData,
label: m.label,
color: m.color,
label,
color,
segmentation: m.segmentation,
bbox: m.bbox,
area: m.area,
@@ -89,7 +136,7 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
} finally {
setIsInferencing(false);
}
}, [points, semanticText, modelSize, addMask]);
}, [activeClass, activeTemplateId, addMask, aiModel, currentFrame?.height, currentFrame?.id, currentFrame?.width, image?.height, image?.naturalHeight, image?.naturalWidth, image?.width, points, semanticText]);
const handleStageClick = (e: any) => {
if (effectiveTool === 'move') return;
@@ -117,17 +164,26 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
{/* Model Select */}
<div>
<h3 className="text-[10px] font-bold text-gray-500 uppercase tracking-widest mb-3"></h3>
<div className="bg-[#111] border border-white/5 flex p-1 rounded-lg">
{['vit_b', 'vit_l', 'vit_h'].map(m => (
<div className="bg-[#111] border border-white/5 grid grid-cols-2 gap-1 p-1 rounded-lg">
{(modelStatus?.models || [
{ id: 'sam2' as const, label: 'SAM 2', available: true, message: '正在读取 SAM 2 状态' },
{ id: 'sam3' as const, label: 'SAM 3', available: false, message: '正在读取 SAM 3 状态' },
]).map((m) => (
<button
key={m}
className={cn("flex-1 text-xs py-2 rounded-md transition-colors text-center uppercase tracking-wider font-mono", modelSize === m ? "bg-white/10 text-white font-medium shadow-sm" : "text-gray-500 hover:text-gray-300 hover:bg-white/5")}
onClick={() => setModelSize(m)}
key={m.id}
className={cn("text-xs py-2 rounded-md transition-colors text-center uppercase tracking-wider font-mono", aiModel === m.id ? "bg-white/10 text-white font-medium shadow-sm" : "text-gray-500 hover:text-gray-300 hover:bg-white/5")}
onClick={() => setAiModel(m.id)}
title={m.message}
>
{m.split('_')[1]}
{m.label.replace(' ', '')}
<span className={cn("ml-1", m.available ? "text-emerald-400" : "text-amber-400")}></span>
</button>
))}
</div>
<div className="mt-2 text-[10px] text-gray-500 leading-relaxed">
<div>{selectedModelStatus?.message || '正在读取模型状态...'}</div>
<div>GPU: {modelStatus?.gpu.available ? `${modelStatus.gpu.name || 'CUDA'} 可用` : '不可用或未检测到 CUDA'}</div>
</div>
</div>
{/* Prompt Tools */}
@@ -206,16 +262,16 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
<div className="p-6 bg-[#0a0a0a] border-t border-white/5 shrink-0 flex flex-col gap-3">
<button
onClick={runInference}
disabled={isInferencing}
disabled={isInferencing || !currentFrame || !modelCanInfer}
className={cn(
"w-full py-3.5 rounded-lg flex items-center justify-center gap-2 transition-all shadow-lg font-medium tracking-wide text-xs uppercase",
isInferencing
isInferencing || !currentFrame || !modelCanInfer
? "bg-cyan-500/50 text-black/70 cursor-not-allowed"
: "bg-cyan-500 hover:bg-cyan-400 text-black shadow-cyan-500/20 hover:shadow-cyan-500/40"
)}
>
{isInferencing ? <Loader2 size={16} className="animate-spin" /> : <Sparkles size={16} />}
{isInferencing ? '推理中...' : '执行高精度语义分割'}
{isInferencing ? '推理中...' : modelCanInfer ? '执行高精度语义分割' : '当前模型不可用'}
</button>
<button
onClick={onSendToWorkspace}
@@ -231,7 +287,7 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
<header className="h-16 border-b border-white/5 bg-[#111] flex items-center justify-between px-6 shrink-0">
<div className="flex flex-col">
<h2 className="text-sm font-semibold tracking-wide text-white"> (Visualizer)</h2>
<span className="text-[10px] text-gray-500 uppercase tracking-widest font-mono">SAM 3 </span>
<span className="text-[10px] text-gray-500 uppercase tracking-widest font-mono">{aiModel.toUpperCase()} </span>
</div>
<div className="flex items-center gap-4">
<button className="w-8 h-8 rounded text-gray-400 hover:bg-white/5 hover:text-white flex items-center justify-center transition-colors" title="撤销操作 (Ctrl+Z)">
@@ -276,7 +332,7 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
)}
{/* AI Returned Masks */}
{masks.map((mask) => (
{frameMasks.map((mask) => (
<Group key={mask.id} opacity={0.45}>
<Path
data={mask.pathData}
@@ -309,7 +365,7 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
<div className="absolute bottom-4 left-4 flex gap-4 text-[10px] font-mono text-gray-500 pointer-events-none">
<span>: {cursorPos.x.toFixed(2)}, {cursorPos.y.toFixed(2)}</span>
<span>: {(scale * 100).toFixed(0)}%</span>
<span>: {masks.length}</span>
<span>: {frameMasks.length}</span>
</div>
</div>
</div>

View File

@@ -0,0 +1,130 @@
import { fireEvent, render, screen, waitFor } from '@testing-library/react';
import { beforeEach, describe, expect, it, vi } from 'vitest';
import { resetStore } from '../test/storeTestUtils';
import { useStore } from '../store/useStore';
import { CanvasArea } from './CanvasArea';
const apiMock = vi.hoisted(() => ({
predictMask: vi.fn(),
}));
vi.mock('../lib/api', () => ({
predictMask: apiMock.predictMask,
}));
describe('CanvasArea', () => {
const frame = { id: 'frame-1', projectId: 'project-1', index: 0, url: '/frame.jpg', width: 640, height: 360 };
beforeEach(() => {
resetStore();
vi.clearAllMocks();
});
it('calls AI prediction with the active frame when a point prompt is placed', async () => {
useStore.setState({
activeTemplateId: '2',
activeClass: { id: 'c1', name: '胆囊', color: '#ff0000', zIndex: 20 },
activeClassId: 'c1',
});
apiMock.predictMask.mockResolvedValueOnce({
masks: [
{
id: 'mask-1',
pathData: 'M 0 0 L 10 0 L 10 10 Z',
label: 'AI Mask',
color: '#06b6d4',
segmentation: [[0, 0, 10, 0, 10, 10]],
bbox: [0, 0, 10, 10],
area: 100,
},
],
});
render(<CanvasArea activeTool="point_pos" frame={frame} />);
fireEvent.click(screen.getByTestId('konva-stage'));
await waitFor(() => expect(apiMock.predictMask).toHaveBeenCalledWith({
imageId: 'frame-1',
imageWidth: 640,
imageHeight: 360,
model: 'sam2',
points: [{ x: 120, y: 80, type: 'pos' }],
box: undefined,
}));
expect(useStore.getState().masks[0]).toEqual(expect.objectContaining({
id: 'mask-1',
frameId: 'frame-1',
pathData: 'M 0 0 L 10 0 L 10 10 Z',
templateId: '2',
classId: 'c1',
className: '胆囊',
classZIndex: 20,
label: '胆囊',
color: '#ff0000',
saveStatus: 'draft',
}));
});
it('renders only masks that belong to the current frame', () => {
useStore.setState({
masks: [
{ id: 'm1', frameId: 'frame-1', pathData: 'M 0 0 Z', label: 'A', color: '#fff' },
{ id: 'm2', frameId: 'frame-2', pathData: 'M 1 1 Z', label: 'B', color: '#000' },
],
});
render(<CanvasArea activeTool="move" frame={frame} />);
expect(screen.getAllByTestId('konva-path')).toHaveLength(1);
expect(screen.getByText('遮罩数: 1')).toBeInTheDocument();
});
it('applies the selected class to current-frame masks and marks saved masks dirty', () => {
useStore.setState({
activeTemplateId: '2',
activeClass: { id: 'c1', name: '胆囊', color: '#ff0000', zIndex: 20 },
activeClassId: 'c1',
masks: [
{
id: 'm1',
frameId: 'frame-1',
annotationId: '99',
pathData: 'M 0 0 Z',
label: '旧标签',
color: '#06b6d4',
saved: true,
saveStatus: 'saved',
},
],
});
render(<CanvasArea activeTool="move" frame={frame} />);
fireEvent.click(screen.getByRole('button', { name: '应用分类' }));
expect(useStore.getState().masks[0]).toEqual(expect.objectContaining({
templateId: '2',
classId: 'c1',
className: '胆囊',
classZIndex: 20,
label: '胆囊',
color: '#ff0000',
saveStatus: 'dirty',
saved: false,
}));
});
it('delegates clear to the workspace handler so saved annotations can be deleted', () => {
const onClearMasks = vi.fn();
useStore.setState({
masks: [
{ id: 'm1', frameId: 'frame-1', pathData: 'M 0 0 Z', label: 'A', color: '#fff' },
],
});
render(<CanvasArea activeTool="move" frame={frame} onClearMasks={onClearMasks} />);
fireEvent.click(screen.getByRole('button', { name: '清空遮罩' }));
expect(onClearMasks).toHaveBeenCalled();
expect(useStore.getState().masks).toHaveLength(1);
});
});

View File

@@ -3,14 +3,15 @@ import { Stage, Layer, Image as KonvaImage, Circle, Rect, Path, Group } from 're
import useImage from 'use-image';
import { useStore } from '../store/useStore';
import { predictMask } from '../lib/api';
import { cn } from '../lib/utils';
import type { Frame } from '../store/useStore';
interface CanvasAreaProps {
activeTool: string;
frameUrl: string;
frame: Frame | null;
onClearMasks?: () => void;
}
export function CanvasArea({ activeTool, frameUrl }: CanvasAreaProps) {
export function CanvasArea({ activeTool, frame, onClearMasks }: CanvasAreaProps) {
const containerRef = useRef<HTMLDivElement>(null);
const [stageSize, setStageSize] = useState({ width: 800, height: 600 });
const [scale, setScale] = useState(1);
@@ -24,13 +25,20 @@ export function CanvasArea({ activeTool, frameUrl }: CanvasAreaProps) {
const masks = useStore((state) => state.masks);
const addMask = useStore((state) => state.addMask);
const clearMasks = useStore((state) => state.clearMasks);
const setMasks = useStore((state) => state.setMasks);
const storeActiveTool = useStore((state) => state.activeTool);
const setActiveTool = useStore((state) => state.setActiveTool);
const aiModel = useStore((state) => state.aiModel);
const activeTemplateId = useStore((state) => state.activeTemplateId);
const activeClass = useStore((state) => state.activeClass);
const effectiveTool = activeTool || storeActiveTool;
// Load the actual frame image
const [image] = useImage(frameUrl || '');
const [image] = useImage(frame?.url || '');
const frameMasks = masks.filter((mask) => mask.frameId === frame?.id);
const savedMaskCount = frameMasks.filter((mask) => mask.saveStatus === 'saved' || mask.saved).length;
const draftMaskCount = frameMasks.filter((mask) => !mask.annotationId).length;
const dirtyMaskCount = frameMasks.filter((mask) => mask.saveStatus === 'dirty').length;
useEffect(() => {
const handleResize = () => {
@@ -85,21 +93,44 @@ export function CanvasArea({ activeTool, frameUrl }: CanvasAreaProps) {
};
const runInference = useCallback(async (promptPoints?: typeof points, promptBox?: { x1: number, y1: number, x2: number, y2: number }) => {
if (!frame?.id) {
console.warn('Inference skipped: no active frame');
return;
}
const imageWidth = frame.width || image?.naturalWidth || image?.width || 0;
const imageHeight = frame.height || image?.naturalHeight || image?.height || 0;
if (imageWidth <= 0 || imageHeight <= 0) {
console.warn('Inference skipped: active frame dimensions are unavailable');
return;
}
setIsInferencing(true);
try {
const result = await predictMask({
imageUrl: frameUrl || '',
imageId: frame.id,
imageWidth,
imageHeight,
model: aiModel,
points: promptPoints?.map((p) => ({ x: p.x, y: p.y, type: p.type })),
box: promptBox,
});
result.masks.forEach((m) => {
const label = activeClass?.name || m.label;
const color = activeClass?.color || m.color;
addMask({
id: m.id,
frameId: 'frame-1',
frameId: frame.id,
templateId: activeTemplateId || undefined,
classId: activeClass?.id,
className: activeClass?.name,
classZIndex: activeClass?.zIndex,
saveStatus: 'draft',
saved: false,
pathData: m.pathData,
label: m.label,
color: m.color,
label,
color,
segmentation: m.segmentation,
bbox: m.bbox,
area: m.area,
@@ -110,7 +141,33 @@ export function CanvasArea({ activeTool, frameUrl }: CanvasAreaProps) {
} finally {
setIsInferencing(false);
}
}, [addMask]);
}, [activeClass, activeTemplateId, addMask, aiModel, frame?.height, frame?.id, frame?.width, image?.height, image?.naturalHeight, image?.naturalWidth, image?.width]);
const handleApplyActiveClass = () => {
if (!frame?.id || !activeClass) return;
setMasks(masks.map((mask) => {
if (mask.frameId !== frame.id) return mask;
return {
...mask,
templateId: activeTemplateId || mask.templateId,
classId: activeClass.id,
className: activeClass.name,
classZIndex: activeClass.zIndex,
label: activeClass.name,
color: activeClass.color,
saveStatus: mask.annotationId ? 'dirty' : 'draft',
saved: Boolean(mask.annotationId) ? false : mask.saved,
};
}));
};
const handleClearMasks = () => {
if (onClearMasks) {
onClearMasks();
return;
}
clearMasks();
};
const handleStageMouseDown = (e: any) => {
if (effectiveTool === 'box_select') {
@@ -199,7 +256,7 @@ export function CanvasArea({ activeTool, frameUrl }: CanvasAreaProps) {
)}
{/* AI Returned Masks */}
{masks.map((mask) => (
{frameMasks.map((mask) => (
<Group key={mask.id} opacity={0.5}>
<Path
data={mask.pathData}
@@ -248,16 +305,29 @@ export function CanvasArea({ activeTool, frameUrl }: CanvasAreaProps) {
<span>: {cursorPos.x.toFixed(2)}, {cursorPos.y.toFixed(2)}</span>
<span>当前图层树: OBJECT_VEHICLE_01</span>
<span>: {(scale * 100).toFixed(0)}%</span>
<span>: {masks.length}</span>
<span>: {frameMasks.length}</span>
<span>: {savedMaskCount}</span>
<span>: {draftMaskCount}</span>
<span>: {dirtyMaskCount}</span>
</div>
{masks.length > 0 && (
<button
onClick={clearMasks}
className="absolute bottom-4 right-4 text-xs bg-red-500/10 hover:bg-red-500/20 text-red-400 border border-red-500/20 px-3 py-1.5 rounded transition-colors"
>
</button>
{frameMasks.length > 0 && (
<div className="absolute bottom-4 right-4 flex gap-2">
{activeClass && (
<button
onClick={handleApplyActiveClass}
className="text-xs bg-cyan-500/10 hover:bg-cyan-500/20 text-cyan-300 border border-cyan-500/20 px-3 py-1.5 rounded transition-colors"
>
</button>
)}
<button
onClick={handleClearMasks}
className="text-xs bg-red-500/10 hover:bg-red-500/20 text-red-400 border border-red-500/20 px-3 py-1.5 rounded transition-colors"
>
</button>
</div>
)}
</div>
);

View File

@@ -0,0 +1,115 @@
import { act, render, screen, waitFor } from '@testing-library/react';
import { beforeEach, describe, expect, it, vi } from 'vitest';
import { Dashboard } from './Dashboard';
const apiMock = vi.hoisted(() => ({
getDashboardOverview: vi.fn(),
}));
const wsMock = vi.hoisted(() => {
const state = {
callback: undefined as undefined | ((data: any) => void),
connected: false,
};
return {
state,
progressWS: {
connect: vi.fn(() => { state.connected = true; }),
disconnect: vi.fn(() => { state.connected = false; }),
isConnected: vi.fn(() => state.connected),
onProgress: vi.fn((cb: (data: any) => void) => {
state.callback = cb;
return vi.fn();
}),
},
};
});
vi.mock('../lib/websocket', () => ({
progressWS: wsMock.progressWS,
}));
vi.mock('../lib/api', () => ({
getDashboardOverview: apiMock.getDashboardOverview,
}));
describe('Dashboard', () => {
beforeEach(() => {
vi.useRealTimers();
vi.clearAllMocks();
wsMock.state.connected = false;
wsMock.state.callback = undefined;
apiMock.getDashboardOverview.mockResolvedValue({
summary: {
project_count: 2,
parsing_task_count: 1,
annotation_count: 5,
frame_count: 100,
template_count: 3,
system_load_percent: 12,
},
tasks: [
{
id: 'project-1',
project_id: 1,
name: '真实项目.mp4',
progress: 60,
status: 'pending',
frame_count: 10,
updated_at: '2026-05-01T00:00:00Z',
},
],
activity: [
{
id: 'activity-1',
kind: 'project',
time: '2026-05-01T00:00:00Z',
message: '项目状态: pending',
project: '真实项目.mp4',
},
],
});
});
it('loads dashboard stats, tasks, and activity from the backend overview endpoint', async () => {
render(<Dashboard />);
await waitFor(() => expect(apiMock.getDashboardOverview).toHaveBeenCalled());
expect(screen.getByText('项目总数')).toBeInTheDocument();
expect(screen.getByText('已存标注')).toBeInTheDocument();
expect(screen.getByText('真实项目.mp4')).toBeInTheDocument();
expect(screen.getByText('项目状态: pending')).toBeInTheDocument();
expect(screen.queryByText('City_Driving_Dataset_004.mp4')).not.toBeInTheDocument();
});
it('connects to the progress stream and updates progress tasks', async () => {
render(<Dashboard />);
await waitFor(() => expect(wsMock.progressWS.connect).toHaveBeenCalled());
act(() => {
wsMock.state.callback?.({
type: 'progress',
taskId: 'task-1',
projectName: 'demo.mp4',
progress: 44,
status: '正在截取帧',
});
});
expect(await screen.findByText('demo.mp4')).toBeInTheDocument();
expect(screen.getByText('44%')).toBeInTheDocument();
});
it('adds activity logs for complete and status messages', async () => {
render(<Dashboard />);
act(() => {
wsMock.state.callback?.({ type: 'status', message: 'Progress stream active' });
wsMock.state.callback?.({ type: 'complete', taskId: '1', filename: 'done.mp4' });
});
await waitFor(() => expect(screen.getByText('Progress stream active')).toBeInTheDocument());
expect(screen.getByText('解析完成: done.mp4')).toBeInTheDocument();
});
});

View File

@@ -2,30 +2,68 @@ import React, { useState, useEffect } from 'react';
import { Activity, Clock, Folders, CheckCircle2, Loader2 } from 'lucide-react';
import { progressWS, type ProgressMessage } from '../lib/websocket';
import { cn } from '../lib/utils';
import { getDashboardOverview, type DashboardActivity, type DashboardOverview, type DashboardTask } from '../lib/api';
interface QueueTask {
id: string;
name: string;
progress: number;
status: string;
}
const emptySummary: DashboardOverview['summary'] = {
project_count: 0,
parsing_task_count: 0,
annotation_count: 0,
frame_count: 0,
template_count: 0,
system_load_percent: 0,
};
export function Dashboard() {
const [tasks, setTasks] = useState<QueueTask[]>([
{ id: '1', name: 'City_Driving_Dataset_004.mp4', progress: 85, status: '正在截取帧 (30fps)' },
{ id: '2', name: 'Pedestrian_Night_Vision_02.mkv', progress: 32, status: '正在截取帧 (60fps)' },
{ id: '3', name: 'Drone_Mapping_Sector_7.avi', progress: 0, status: '队列排队等待中' },
]);
const [summary, setSummary] = useState<DashboardOverview['summary']>(emptySummary);
const [tasks, setTasks] = useState<DashboardTask[]>([]);
const [isConnected, setIsConnected] = useState(false);
const [activityLog, setActivityLog] = useState<Array<{ time: string; message: string; project?: string }>>([
{ time: '10 分钟前', message: '语义归档完成 54 帧', project: 'Highway_Data' },
{ time: '25 分钟前', message: '项目解析开始', project: 'City_Driving_Dataset_004' },
{ time: '1 小时前', message: '模板库更新: Cityscapes_v2', project: '系统' },
{ time: '2 小时前', message: 'AI 推理完成 12 个实例', project: 'Nav_Cam_Left' },
]);
const [activityLog, setActivityLog] = useState<DashboardActivity[]>([]);
const [isLoading, setIsLoading] = useState(true);
const [loadError, setLoadError] = useState('');
useEffect(() => {
let cancelled = false;
const loadOverview = () => {
getDashboardOverview()
.then((overview) => {
if (cancelled) return;
setSummary(overview.summary);
setTasks((prev) => {
if (prev.length === 0) return overview.tasks;
const overviewIds = new Set(overview.tasks.map((task) => task.id));
const wsOnly = prev.filter((task) => !task.id.startsWith('task-') && !overviewIds.has(task.id) && task.progress < 100);
return [...overview.tasks, ...wsOnly];
});
setActivityLog((prev) => {
if (prev.length === 0) return overview.activity;
const byId = new Map(prev.map((item) => [item.id, item]));
overview.activity.forEach((item) => byId.set(item.id, item));
return Array.from(byId.values()).slice(0, 10);
});
setLoadError('');
})
.catch((err) => {
console.error('Failed to load dashboard overview:', err);
if (!cancelled) setLoadError('Dashboard 数据加载失败');
})
.finally(() => {
if (!cancelled) setIsLoading(false);
});
};
loadOverview();
const overviewInterval = setInterval(loadOverview, 5000);
return () => {
cancelled = true;
clearInterval(overviewInterval);
};
}, []);
useEffect(() => {
let mounted = true;
const taskTitle = (data: ProgressMessage) => data.filename || data.projectName || data.taskId || '后台任务';
const timer = setTimeout(() => {
if (mounted) progressWS.connect();
}, 500);
@@ -34,7 +72,7 @@ export function Dashboard() {
if (!mounted) return;
setIsConnected(progressWS.isConnected());
if (data.type === 'progress' && data.taskId && data.filename) {
if (data.type === 'progress' && data.taskId) {
setTasks((prev) => {
const exists = prev.find((t) => t.id === data.taskId);
if (exists) {
@@ -48,9 +86,12 @@ export function Dashboard() {
...prev,
{
id: data.taskId!,
name: data.filename!,
project_id: data.project_id ?? Number(data.task_id || 0),
name: taskTitle(data),
progress: data.progress ?? 0,
status: data.status ?? '处理中',
frame_count: 0,
updated_at: new Date().toISOString(),
},
];
});
@@ -63,7 +104,7 @@ export function Dashboard() {
)
);
setActivityLog((prev) => [
{ time: '刚刚', message: `解析完成: ${data.filename || data.taskId}`, project: '系统' },
{ id: `ws-complete-${Date.now()}`, kind: 'websocket', time: new Date().toISOString(), message: data.message || `解析完成: ${taskTitle(data)}`, project: data.projectName || '系统' },
...prev.slice(0, 9),
]);
}
@@ -71,14 +112,18 @@ export function Dashboard() {
if (data.type === 'error' && data.taskId) {
setTasks((prev) =>
prev.map((t) =>
t.id === data.taskId ? { ...t, status: `错误: ${data.message || '未知错误'}` } : t
t.id === data.taskId ? { ...t, progress: data.progress ?? t.progress, status: `错误: ${data.error || data.message || '未知错误'}` } : t
)
);
setActivityLog((prev) => [
{ id: `ws-error-${Date.now()}`, kind: 'websocket', time: new Date().toISOString(), message: data.message || `解析失败: ${taskTitle(data)}`, project: data.projectName || '系统' },
...prev.slice(0, 9),
]);
}
if (data.type === 'status') {
setActivityLog((prev) => [
{ time: '刚刚', message: data.message || '状态更新', project: '系统' },
{ id: `ws-status-${Date.now()}`, kind: 'websocket', time: new Date().toISOString(), message: data.message || '状态更新', project: '系统' },
...prev.slice(0, 9),
]);
}
@@ -97,12 +142,24 @@ export function Dashboard() {
}, []);
const stats = [
{ label: '运行中项目', value: '14', icon: Folders, color: 'text-blue-400', bg: 'bg-blue-400/10' },
{ label: '排队处理任务', value: tasks.length.toString(), icon: Clock, color: 'text-orange-400', bg: 'bg-orange-400/10' },
{ label: '已归档批次', value: '128', icon: CheckCircle2, color: 'text-emerald-400', bg: 'bg-emerald-400/10' },
{ label: '系统负载', value: '78%', icon: Activity, color: 'text-cyan-400', bg: 'bg-cyan-400/10' },
{ label: '项目总数', value: summary.project_count.toString(), icon: Folders, color: 'text-blue-400', bg: 'bg-blue-400/10' },
{ label: '处理任务', value: summary.parsing_task_count.toString(), icon: Clock, color: 'text-orange-400', bg: 'bg-orange-400/10' },
{ label: '已存标注', value: summary.annotation_count.toString(), icon: CheckCircle2, color: 'text-emerald-400', bg: 'bg-emerald-400/10' },
{ label: '系统负载', value: `${summary.system_load_percent}%`, icon: Activity, color: 'text-cyan-400', bg: 'bg-cyan-400/10' },
];
function formatActivityTime(value: string | null): string {
if (!value) return '未知时间';
const date = new Date(value);
if (Number.isNaN(date.getTime())) return value;
return date.toLocaleString('zh-CN', {
month: '2-digit',
day: '2-digit',
hour: '2-digit',
minute: '2-digit',
});
}
return (
<div className="p-8 w-full h-full overflow-y-auto bg-[#0a0a0a]">
<header className="mb-8">
@@ -119,6 +176,7 @@ export function Dashboard() {
</div>
</div>
<p className="text-gray-400 text-sm mt-1"></p>
{loadError && <p className="text-red-400 text-xs mt-2">{loadError}</p>}
</header>
<div className="grid grid-cols-1 md:grid-cols-2 lg:grid-cols-4 gap-4 mb-8">
@@ -140,8 +198,11 @@ export function Dashboard() {
<div className="grid grid-cols-1 lg:grid-cols-3 gap-6">
<div className="lg:col-span-2 bg-[#111] border border-white/5 rounded-xl p-6 min-h-[400px]">
<h2 className="text-sm font-medium text-gray-400 uppercase tracking-widest mb-6"> (FFmpeg )</h2>
<h2 className="text-sm font-medium text-gray-400 uppercase tracking-widest mb-6"> ()</h2>
<div className="space-y-4">
{isLoading && (
<div className="text-sm text-gray-500 text-center py-12"> Dashboard ...</div>
)}
{tasks.map((task) => (
<div key={task.id} className="bg-[#0d0d0d] border border-white/5 p-4 rounded-lg">
<div className="flex justify-between items-center mb-2">
@@ -152,7 +213,7 @@ export function Dashboard() {
<div className="h-full bg-gradient-to-r from-cyan-600 to-cyan-400 rounded-full transition-all duration-500" style={{ width: `${task.progress}%` }} />
</div>
<div className="text-xs text-gray-500 flex items-center gap-2">
{task.status === '已完成' ? (
{task.status === '已完成' || task.progress >= 100 ? (
<CheckCircle2 size={12} className="text-emerald-400" />
) : task.status.includes('错误') ? (
<span className="text-red-400"></span>
@@ -160,10 +221,11 @@ export function Dashboard() {
<Loader2 size={12} className="text-cyan-400 animate-spin" />
)}
{task.status}
<span className="text-gray-600">: {task.frame_count}</span>
</div>
</div>
))}
{tasks.length === 0 && (
{!isLoading && tasks.length === 0 && (
<div className="text-sm text-gray-500 text-center py-12"></div>
)}
</div>
@@ -172,16 +234,22 @@ export function Dashboard() {
<div className="bg-[#111] border border-white/5 rounded-xl p-6 min-h-[400px]">
<h2 className="text-sm font-medium text-gray-400 uppercase tracking-widest mb-6"></h2>
<div className="space-y-6 relative before:absolute before:inset-0 before:ml-[11px] before:-translate-x-px md:before:mx-auto md:before:translate-x-0 before:h-full before:w-0.5 before:bg-gradient-to-b before:from-transparent before:via-white/10 before:to-transparent">
{activityLog.map((log, i) => (
<div key={i} className="relative flex items-center justify-between md:justify-normal md:odd:flex-row-reverse group is-active">
{isLoading && (
<div className="text-sm text-gray-500 text-center py-12">...</div>
)}
{activityLog.map((log) => (
<div key={log.id} className="relative flex items-center justify-between md:justify-normal md:odd:flex-row-reverse group is-active">
<div className="flex items-center justify-center w-6 h-6 rounded-full border border-white/10 bg-[#111] group-[.is-active]:bg-cyan-500 group-[.is-active]:border-cyan-400 text-slate-500 group-[.is-active]:text-black shadow shrink-0 md:order-1 md:group-odd:-translate-x-1/2 md:group-even:translate-x-1/2 z-10" />
<div className="w-[calc(100%-4rem)] md:w-[calc(50%-2.5rem)] bg-[#0d0d0d] p-3 rounded border border-white/5">
<div className="text-xs text-gray-400 mb-1">{log.time}</div>
<div className="text-xs text-gray-400 mb-1">{formatActivityTime(log.time)}</div>
<div className="text-sm font-medium text-gray-200">{log.message}</div>
<div className="text-xs text-gray-500">: {log.project}</div>
</div>
</div>
))}
{!isLoading && activityLog.length === 0 && (
<div className="text-sm text-gray-500 text-center py-12"></div>
)}
</div>
</div>
</div>

View File

@@ -0,0 +1,62 @@
import { act, fireEvent, render, screen } from '@testing-library/react';
import { beforeEach, describe, expect, it, vi } from 'vitest';
import { resetStore } from '../test/storeTestUtils';
import { useStore } from '../store/useStore';
import { FrameTimeline } from './FrameTimeline';
describe('FrameTimeline', () => {
beforeEach(() => {
resetStore();
vi.useRealTimers();
});
it('renders empty state when no frames are loaded', () => {
render(<FrameTimeline />);
expect(screen.getByText('暂无帧数据')).toBeInTheDocument();
expect(screen.getByText('0')).toBeInTheDocument();
});
it('changes the current frame through thumbnails and range input', () => {
useStore.setState({
frames: [
{ id: 'f1', projectId: 'p1', index: 0, url: '/1.jpg', width: 640, height: 360 },
{ id: 'f2', projectId: 'p1', index: 1, url: '/2.jpg', width: 640, height: 360 },
{ id: 'f3', projectId: 'p1', index: 2, url: '/3.jpg', width: 640, height: 360 },
],
});
render(<FrameTimeline />);
fireEvent.click(screen.getByAltText('frame-1'));
expect(useStore.getState().currentFrameIndex).toBe(1);
fireEvent.change(screen.getByRole('slider'), { target: { value: '3' } });
expect(useStore.getState().currentFrameIndex).toBe(2);
});
it('plays forward using the project parse fps and stops at the end', () => {
vi.useFakeTimers();
useStore.setState({
currentProject: { id: 'p1', name: 'P', status: 'ready', parse_fps: 10 },
frames: [
{ id: 'f1', projectId: 'p1', index: 0, url: '/1.jpg', width: 640, height: 360 },
{ id: 'f2', projectId: 'p1', index: 1, url: '/2.jpg', width: 640, height: 360 },
],
});
const { container } = render(<FrameTimeline />);
fireEvent.click(container.querySelector('button')!);
act(() => {
vi.advanceTimersByTime(100);
});
expect(useStore.getState().currentFrameIndex).toBe(1);
act(() => {
vi.advanceTimersByTime(100);
});
expect(screen.getByText('播放序列 (F5)')).toBeInTheDocument();
});
});

View File

@@ -1,16 +1,42 @@
import React, { useState } from 'react';
import React, { useEffect, useMemo, useState } from 'react';
import { Play, Pause } from 'lucide-react';
import { cn } from '../lib/utils';
import { useStore } from '../store/useStore';
export function FrameTimeline() {
const frames = useStore((state) => state.frames);
const currentProject = useStore((state) => state.currentProject);
const currentFrameIndex = useStore((state) => state.currentFrameIndex);
const setCurrentFrame = useStore((state) => state.setCurrentFrame);
const [isPlaying, setIsPlaying] = useState(false);
const totalFrames = frames.length;
const currentFrame = totalFrames > 0 ? currentFrameIndex + 1 : 0;
const playbackFps = useMemo(() => {
const fps = currentProject?.parse_fps || currentProject?.original_fps || 12;
return Math.min(Math.max(fps, 1), 30);
}, [currentProject?.original_fps, currentProject?.parse_fps]);
useEffect(() => {
if (!isPlaying || totalFrames <= 1) return;
const timer = window.setTimeout(() => {
if (currentFrameIndex >= totalFrames - 1) {
setIsPlaying(false);
return;
}
setCurrentFrame(currentFrameIndex + 1);
}, 1000 / playbackFps);
return () => window.clearTimeout(timer);
}, [currentFrameIndex, isPlaying, playbackFps, setCurrentFrame, totalFrames]);
useEffect(() => {
if (totalFrames === 0) {
setIsPlaying(false);
}
}, [totalFrames]);
// show frames around current frame
const frameWindow = 20;
@@ -45,8 +71,14 @@ export function FrameTimeline() {
<div className="flex-1 flex items-center px-4 gap-6">
<div className="flex flex-col items-center gap-2 px-4 border-r border-white/10 shrink-0">
<button
className="p-2 rounded-full bg-white/5 text-white hover:bg-white/10"
onClick={() => setIsPlaying(!isPlaying)}
className="p-2 rounded-full bg-white/5 text-white hover:bg-white/10 disabled:opacity-40 disabled:cursor-not-allowed"
disabled={totalFrames <= 1}
onClick={() => {
if (currentFrameIndex >= totalFrames - 1) {
setCurrentFrame(0);
}
setIsPlaying(!isPlaying);
}}
>
{isPlaying ? <Pause size={20} fill="currentColor" /> : <Play size={20} fill="currentColor" />}
</button>

View File

@@ -0,0 +1,42 @@
import { fireEvent, render, screen, waitFor } from '@testing-library/react';
import { beforeEach, describe, expect, it, vi } from 'vitest';
import { resetStore } from '../test/storeTestUtils';
import { useStore } from '../store/useStore';
import { Login } from './Login';
const apiMock = vi.hoisted(() => ({
login: vi.fn(),
}));
vi.mock('../lib/api', () => ({
login: apiMock.login,
}));
describe('Login', () => {
beforeEach(() => {
resetStore();
vi.clearAllMocks();
});
it('logs in with the development credentials and stores the token', async () => {
apiMock.login.mockResolvedValueOnce({ token: 'fake-jwt-token-for-admin' });
render(<Login />);
fireEvent.click(screen.getByRole('button', { name: '安全登录' }));
await waitFor(() => expect(apiMock.login).toHaveBeenCalledWith('admin', '123456'));
expect(useStore.getState().isAuthenticated).toBe(true);
expect(localStorage.getItem('token')).toBe('fake-jwt-token-for-admin');
});
it('shows backend login errors', async () => {
apiMock.login.mockRejectedValueOnce({ response: { data: { detail: 'Invalid credentials' } } });
render(<Login />);
fireEvent.change(screen.getByDisplayValue('admin'), { target: { value: 'bad' } });
fireEvent.click(screen.getByRole('button', { name: '安全登录' }));
expect(await screen.findByText('Invalid credentials')).toBeInTheDocument();
expect(useStore.getState().isAuthenticated).toBe(false);
});
});

View File

@@ -0,0 +1,45 @@
import { render, screen, waitFor } from '@testing-library/react';
import { beforeEach, describe, expect, it, vi } from 'vitest';
import { resetStore } from '../test/storeTestUtils';
import { useStore } from '../store/useStore';
import { ModelStatusBadge } from './ModelStatusBadge';
const apiMock = vi.hoisted(() => ({
getAiModelStatus: vi.fn(),
}));
vi.mock('../lib/api', () => ({
getAiModelStatus: apiMock.getAiModelStatus,
}));
describe('ModelStatusBadge', () => {
beforeEach(() => {
resetStore();
vi.clearAllMocks();
apiMock.getAiModelStatus.mockResolvedValue({
selected_model: 'sam2',
gpu: { available: true, device: 'cuda', name: 'RTX 4090', torch_available: true },
models: [
{ id: 'sam2', label: 'SAM 2', available: true, loaded: false, device: 'cuda', supports: ['point', 'box'], message: 'SAM 2 ready', package_available: true, checkpoint_exists: true, python_ok: true, torch_ok: true, cuda_required: false },
{ id: 'sam3', label: 'SAM 3', available: false, loaded: false, device: 'cuda', supports: ['semantic'], message: 'SAM 3 missing runtime', package_available: false, checkpoint_exists: false, python_ok: false, torch_ok: true, cuda_required: true },
],
});
});
it('loads real model status for the selected model', async () => {
render(<ModelStatusBadge />);
expect(await screen.findByText('SAM 2 可用')).toBeInTheDocument();
expect(apiMock.getAiModelStatus).toHaveBeenCalledWith('sam2');
});
it('shows unavailable state when SAM3 is selected but not runnable', async () => {
useStore.getState().setAiModel('sam3');
render(<ModelStatusBadge />);
await waitFor(() => expect(apiMock.getAiModelStatus).toHaveBeenCalledWith('sam3'));
expect(await screen.findByText('SAM 3 不可用')).toBeInTheDocument();
expect(screen.getByTitle('SAM 3 missing runtime')).toBeInTheDocument();
});
});

View File

@@ -0,0 +1,56 @@
import React, { useEffect, useState } from 'react';
import { Cpu, Loader2 } from 'lucide-react';
import { getAiModelStatus, type AiRuntimeStatus } from '../lib/api';
import { cn } from '../lib/utils';
import { useStore } from '../store/useStore';
interface ModelStatusBadgeProps {
compact?: boolean;
}
export function ModelStatusBadge({ compact = false }: ModelStatusBadgeProps) {
const aiModel = useStore((state) => state.aiModel);
const [status, setStatus] = useState<AiRuntimeStatus | null>(null);
const [isLoading, setIsLoading] = useState(true);
useEffect(() => {
let cancelled = false;
setIsLoading(true);
getAiModelStatus(aiModel)
.then((data) => {
if (!cancelled) setStatus(data);
})
.catch(() => {
if (!cancelled) setStatus(null);
})
.finally(() => {
if (!cancelled) setIsLoading(false);
});
return () => {
cancelled = true;
};
}, [aiModel]);
const model = status?.models.find((item) => item.id === aiModel);
const ready = Boolean(model?.available);
const gpuReady = Boolean(status?.gpu.available);
const label = compact
? (gpuReady ? 'GPU' : 'CPU')
: `${model?.label || aiModel.toUpperCase()} ${ready ? '可用' : '不可用'}`;
return (
<div
className={cn(
"inline-flex items-center gap-1.5 rounded border font-mono uppercase",
compact ? "w-8 h-8 justify-center text-[9px]" : "px-2 py-0.5 text-[10px]",
ready
? "bg-emerald-500/10 text-emerald-400 border-emerald-500/20"
: "bg-amber-500/10 text-amber-400 border-amber-500/20"
)}
title={model?.message || 'AI 模型状态读取中'}
>
{isLoading ? <Loader2 size={compact ? 12 : 10} className="animate-spin" /> : <Cpu size={compact ? 12 : 10} />}
<span>{label}</span>
</div>
);
}

View File

@@ -0,0 +1,60 @@
import { fireEvent, render, screen, within } from '@testing-library/react';
import { beforeEach, describe, expect, it } from 'vitest';
import { resetStore } from '../test/storeTestUtils';
import { useStore } from '../store/useStore';
import { OntologyInspector } from './OntologyInspector';
describe('OntologyInspector', () => {
beforeEach(() => {
resetStore();
useStore.setState({
templates: [
{
id: 't1',
name: '腹腔镜模板',
classes: [
{ id: 'c1', name: '胆囊', color: '#ff0000', zIndex: 20, category: '器官' },
{ id: 'c2', name: '肝脏', color: '#00ff00', zIndex: 10, category: '器官' },
],
rules: [],
},
],
});
});
it('shows template classes and changes the active template', () => {
render(<OntologyInspector />);
fireEvent.change(screen.getByRole('combobox'), { target: { value: 't1' } });
expect(useStore.getState().activeTemplateId).toBe('t1');
expect(screen.getByText('胆囊')).toBeInTheDocument();
expect(screen.getByText('肝脏')).toBeInTheDocument();
});
it('selects a concrete class for subsequent masks', () => {
render(<OntologyInspector />);
fireEvent.click(screen.getByText('胆囊'));
expect(useStore.getState().activeClassId).toBe('c1');
expect(useStore.getState().activeClass).toEqual(expect.objectContaining({
id: 'c1',
name: '胆囊',
zIndex: 20,
}));
});
it('adds custom classes locally without backend persistence', () => {
const { container } = render(<OntologyInspector />);
const customSection = screen.getByText('自定义分类').parentElement!;
fireEvent.click(within(customSection).getByRole('button'));
fireEvent.change(screen.getByPlaceholderText('分类名称'), { target: { value: '新局部分类' } });
fireEvent.keyDown(screen.getByPlaceholderText('分类名称'), { key: 'Enter' });
expect(screen.getAllByText('新局部分类')).toHaveLength(2);
expect(useStore.getState().activeClass).toEqual(expect.objectContaining({ name: '新局部分类' }));
expect(useStore.getState().templates[0].classes).toHaveLength(2);
expect(container).toHaveTextContent('2 个分类来自模板 + 1 个自定义');
});
});

View File

@@ -2,11 +2,16 @@ import React, { useState } from 'react';
import { Layers, ChevronDown, Tag, Eye, Plus, X } from 'lucide-react';
import { useStore } from '../store/useStore';
import type { TemplateClass } from '../store/useStore';
import { cn } from '../lib/utils';
import { getActiveTemplate } from '../lib/templateSelection';
export function OntologyInspector() {
const templates = useStore((state) => state.templates);
const activeTemplateId = useStore((state) => state.activeTemplateId);
const activeClassId = useStore((state) => state.activeClassId);
const activeClass = useStore((state) => state.activeClass);
const setActiveTemplateId = useStore((state) => state.setActiveTemplateId);
const setActiveClass = useStore((state) => state.setActiveClass);
// Project-level custom classes (in addition to template classes)
const [customClasses, setCustomClasses] = useState<TemplateClass[]>([]);
@@ -14,10 +19,17 @@ export function OntologyInspector() {
const [newClassName, setNewClassName] = useState('');
const [newClassColor, setNewClassColor] = useState('#06b6d4');
const activeTemplate = templates.find((t) => t.id === activeTemplateId) || templates[0] || null;
const activeTemplate = getActiveTemplate(templates, activeTemplateId);
const templateClasses = activeTemplate?.classes || [];
const allClasses = [...templateClasses, ...customClasses].sort((a, b) => b.zIndex - a.zIndex);
const handleSelectClass = (templateClass: TemplateClass) => {
if (activeTemplate && !activeTemplateId) {
setActiveTemplateId(activeTemplate.id);
}
setActiveClass(templateClass);
};
const handleAddCustom = () => {
if (!newClassName.trim()) return;
const maxZ = allClasses.length > 0 ? Math.max(...allClasses.map((c) => c.zIndex)) : 0;
@@ -29,6 +41,7 @@ export function OntologyInspector() {
category: '自定义',
};
setCustomClasses([...customClasses, newClass]);
handleSelectClass(newClass);
setNewClassName('');
setShowAddForm(false);
};
@@ -47,7 +60,10 @@ export function OntologyInspector() {
<div className="relative">
<select
value={activeTemplate?.id || ''}
onChange={(e) => setActiveTemplateId(e.target.value || null)}
onChange={(e) => {
setActiveTemplateId(e.target.value || null);
setActiveClass(null);
}}
className="w-full bg-[#1a1a1a] border border-white/10 rounded-lg px-3 py-2 text-xs text-gray-300 appearance-none cursor-pointer focus:outline-none focus:border-cyan-500/50"
>
<option value="">-- --</option>
@@ -73,7 +89,14 @@ export function OntologyInspector() {
<div className="space-y-2">
{allClasses.map(cls => (
<div key={cls.id} className="flex flex-col gap-1">
<div className="flex items-center justify-between p-2 rounded bg-white/5 hover:bg-white/10 cursor-pointer group transition-colors">
<button
type="button"
onClick={() => handleSelectClass(cls)}
className={cn(
'flex items-center justify-between p-2 rounded bg-white/5 hover:bg-white/10 cursor-pointer group transition-colors text-left border',
activeClassId === cls.id ? 'border-cyan-500/50 bg-cyan-500/10' : 'border-transparent',
)}
>
<div className="flex items-center gap-2">
<span className="w-2.5 h-2.5 rounded-sm" style={{ backgroundColor: cls.color }} />
<span className="text-xs font-medium text-gray-200">{cls.name}</span>
@@ -82,7 +105,7 @@ export function OntologyInspector() {
<span className="text-[10px] text-gray-500 font-mono">z:{cls.zIndex}</span>
<Eye size={14} className="text-gray-500 group-hover:text-gray-300" />
</div>
</div>
</button>
</div>
))}
{allClasses.length === 0 && (
@@ -136,7 +159,9 @@ export function OntologyInspector() {
<div className="bg-white/5 rounded-lg p-3">
<div className="flex items-center gap-2 mb-3">
<Tag size={12} className="text-cyan-400" />
<span className="text-xs font-semibold text-gray-200">{activeTemplate?.name || '未选择'}</span>
<span className="text-xs font-semibold text-gray-200">
{activeClass?.name || activeTemplate?.name || '未选择'}
</span>
</div>
<div className="space-y-3">
<div className="space-y-1">

View File

@@ -0,0 +1,92 @@
import { fireEvent, render, screen, waitFor } from '@testing-library/react';
import { beforeEach, describe, expect, it, vi } from 'vitest';
import { resetStore } from '../test/storeTestUtils';
import { useStore } from '../store/useStore';
import { ProjectLibrary } from './ProjectLibrary';
const apiMock = vi.hoisted(() => ({
getProjects: vi.fn(),
createProject: vi.fn(),
uploadMedia: vi.fn(),
parseMedia: vi.fn(),
uploadDicomBatch: vi.fn(),
}));
vi.mock('../lib/api', () => ({
getProjects: apiMock.getProjects,
createProject: apiMock.createProject,
uploadMedia: apiMock.uploadMedia,
parseMedia: apiMock.parseMedia,
uploadDicomBatch: apiMock.uploadDicomBatch,
}));
describe('ProjectLibrary', () => {
beforeEach(() => {
resetStore();
vi.clearAllMocks();
apiMock.getProjects.mockResolvedValue([]);
});
it('loads projects and selects one into the workspace', async () => {
const onProjectSelect = vi.fn();
apiMock.getProjects.mockResolvedValueOnce([
{ id: 'p1', name: 'Demo Project', status: 'ready', frames: 3, fps: '30FPS' },
]);
render(<ProjectLibrary onProjectSelect={onProjectSelect} />);
fireEvent.click(await screen.findByText('Demo Project'));
expect(useStore.getState().currentProject?.id).toBe('p1');
expect(onProjectSelect).toHaveBeenCalled();
});
it('creates a new project from the modal', async () => {
apiMock.createProject.mockResolvedValueOnce({ id: 'p2', name: 'New Project', status: 'pending' });
render(<ProjectLibrary onProjectSelect={vi.fn()} />);
fireEvent.click(screen.getByText('新建项目'));
fireEvent.change(screen.getByPlaceholderText('输入项目名称'), { target: { value: 'New Project' } });
fireEvent.change(screen.getByPlaceholderText('输入项目描述'), { target: { value: 'desc' } });
fireEvent.click(screen.getByRole('button', { name: '创建' }));
await waitFor(() => expect(apiMock.createProject).toHaveBeenCalledWith({
name: 'New Project',
description: 'desc',
}));
expect(useStore.getState().projects[0]).toEqual(expect.objectContaining({ id: 'p2' }));
});
it('imports video by creating a project, uploading media, parsing frames and refreshing projects', async () => {
apiMock.createProject.mockResolvedValueOnce({ id: 'p3', name: 'clip.mp4', status: 'pending' });
apiMock.uploadMedia.mockResolvedValueOnce({ url: 'http://file', id: 'object' });
apiMock.parseMedia.mockResolvedValueOnce({ frames_extracted: 1 });
apiMock.getProjects.mockResolvedValue([]);
const { container } = render(<ProjectLibrary onProjectSelect={vi.fn()} />);
const input = container.querySelector('input[accept="video/*"]') as HTMLInputElement;
const file = new File(['video'], 'clip.mp4', { type: 'video/mp4' });
fireEvent.change(input, { target: { files: [file] } });
fireEvent.click(await screen.findByRole('button', { name: '开始导入' }));
await waitFor(() => expect(apiMock.createProject).toHaveBeenCalledWith(expect.objectContaining({
name: 'clip.mp4',
parse_fps: 30,
})));
expect(apiMock.uploadMedia).toHaveBeenCalledWith(file, 'p3');
expect(apiMock.parseMedia).toHaveBeenCalledWith('p3');
});
it('imports only valid DICOM files and parses the returned project', async () => {
apiMock.uploadDicomBatch.mockResolvedValueOnce({ project_id: 77, uploaded_count: 1, message: 'ok' });
apiMock.parseMedia.mockResolvedValueOnce({ frames_extracted: 1 });
const { container } = render(<ProjectLibrary onProjectSelect={vi.fn()} />);
const input = container.querySelector('input[accept=".dcm"]') as HTMLInputElement;
const dcm = new File(['dcm'], 'scan.dcm', { type: 'application/dicom' });
const ignored = new File(['txt'], 'notes.txt', { type: 'text/plain' });
fireEvent.change(input, { target: { files: [dcm, ignored] } });
await waitFor(() => expect(apiMock.uploadDicomBatch).toHaveBeenCalledWith([dcm]));
expect(apiMock.parseMedia).toHaveBeenCalledWith('77');
});
});

View File

@@ -212,11 +212,11 @@ export function ProjectLibrary({ onProjectSelect }: ProjectLibraryProps) {
{proj.source_type === 'dicom' ? 'DICOM' : (proj.fps || '30FPS')}
</span>
<span className="backdrop-blur-md bg-black/40 text-gray-200 text-[10px] px-2 py-1 rounded border border-white/10 flex items-center gap-1 uppercase tracking-widest">
{proj.status === 'Ready' ? (
{proj.status === 'ready' ? (
<><div className="w-1.5 h-1.5 bg-emerald-500 rounded-full" /> </>
) : proj.status === 'Parsing' ? (
) : proj.status === 'parsing' ? (
<><div className="w-1.5 h-1.5 bg-amber-500 rounded-full animate-pulse" /> </>
) : proj.status === 'Error' ? (
) : proj.status === 'error' ? (
<><div className="w-1.5 h-1.5 bg-red-500 rounded-full" /> </>
) : (
<><div className="w-1.5 h-1.5 bg-blue-500 rounded-full" /> </>

View File

@@ -2,6 +2,7 @@ import React from 'react';
import { Home, FolderOpen, Edit3, LayoutTemplate, BrainCircuit } from 'lucide-react';
import { cn } from '../lib/utils';
import type { ActiveModule } from '../App';
import { ModelStatusBadge } from './ModelStatusBadge';
interface SidebarProps {
activeModule: ActiveModule;
@@ -47,9 +48,7 @@ export function Sidebar({ activeModule, setActiveModule }: SidebarProps) {
})}
</nav>
<div className="mt-auto mb-4 flex flex-col gap-4">
<div className="w-8 h-8 rounded-full border border-cyan-500/50 flex items-center justify-center text-[10px] text-cyan-400 font-bold cursor-pointer transition-all hover:bg-cyan-500/10">
GPU
</div>
<ModelStatusBadge compact />
</div>
</aside>
);

View File

@@ -0,0 +1,85 @@
import { fireEvent, render, screen, waitFor } from '@testing-library/react';
import { beforeEach, describe, expect, it, vi } from 'vitest';
import { resetStore } from '../test/storeTestUtils';
import { useStore } from '../store/useStore';
import { TemplateRegistry } from './TemplateRegistry';
const apiMock = vi.hoisted(() => ({
getTemplates: vi.fn(),
createTemplate: vi.fn(),
updateTemplate: vi.fn(),
deleteTemplate: vi.fn(),
}));
vi.mock('../lib/api', () => ({
getTemplates: apiMock.getTemplates,
createTemplate: apiMock.createTemplate,
updateTemplate: apiMock.updateTemplate,
deleteTemplate: apiMock.deleteTemplate,
}));
describe('TemplateRegistry', () => {
beforeEach(() => {
resetStore();
vi.clearAllMocks();
});
it('loads and displays templates with unpacked classes', async () => {
apiMock.getTemplates.mockResolvedValueOnce([
{
id: 't1',
name: '腹腔镜胆囊切除术',
description: 'desc',
classes: [{ id: 'c1', name: '胆囊', color: '#ff0000', zIndex: 10, category: '器官' }],
rules: [],
},
]);
render(<TemplateRegistry />);
expect(await screen.findAllByText('腹腔镜胆囊切除术')).toHaveLength(2);
expect(screen.getByText('胆囊')).toBeInTheDocument();
});
it('creates a template and stores it globally', async () => {
apiMock.getTemplates.mockResolvedValueOnce([]);
apiMock.createTemplate.mockResolvedValueOnce({
id: 't2',
name: 'New Template',
description: 'desc',
classes: [],
rules: [],
});
render(<TemplateRegistry />);
fireEvent.click(screen.getByText('新建方案'));
fireEvent.change(screen.getAllByRole('textbox')[0], { target: { value: 'New Template' } });
fireEvent.change(screen.getAllByRole('textbox')[1], { target: { value: 'desc' } });
fireEvent.click(screen.getByRole('button', { name: '保存' }));
await waitFor(() => expect(apiMock.createTemplate).toHaveBeenCalledWith(expect.objectContaining({
name: 'New Template',
description: 'desc',
classes: [],
rules: [],
color: '#06b6d4',
z_index: 0,
})));
expect(useStore.getState().templates[0]).toEqual(expect.objectContaining({ id: 't2' }));
});
it('imports JSON classes into the edit modal before saving', async () => {
apiMock.getTemplates.mockResolvedValueOnce([]);
render(<TemplateRegistry />);
fireEvent.click(screen.getByText('新建方案'));
fireEvent.change(screen.getAllByRole('textbox')[0], { target: { value: 'With Classes' } });
fireEvent.click(screen.getByText('批量导入'));
fireEvent.change(screen.getByPlaceholderText('[[[255,0,0], [0,255,0]], ["分类A", "分类B"]]'), {
target: { value: '{"colors":[[255,0,0]],"names":["分类A"]}' },
});
fireEvent.click(screen.getByRole('button', { name: '导入' }));
expect(screen.getByText('分类A')).toBeInTheDocument();
});
});

View File

@@ -0,0 +1,30 @@
import { fireEvent, render, screen } from '@testing-library/react';
import { describe, expect, it, vi } from 'vitest';
import { ToolsPalette } from './ToolsPalette';
describe('ToolsPalette', () => {
it('switches tools and exposes UI-only placeholder buttons', () => {
const setActiveTool = vi.fn();
render(<ToolsPalette activeTool="move" setActiveTool={setActiveTool} />);
fireEvent.click(screen.getByTitle('创建多边形 (P)'));
fireEvent.click(screen.getByTitle('正向选点 (SAM)'));
expect(setActiveTool).toHaveBeenNthCalledWith(1, 'create_polygon');
expect(setActiveTool).toHaveBeenNthCalledWith(2, 'point_pos');
expect(screen.getByTitle('撤销操作 (Ctrl+Z)')).toBeInTheDocument();
expect(screen.getByTitle('重做操作 (Ctrl+Shift+Z)')).toBeInTheDocument();
});
it('switches to SAM trigger and calls the AI navigation hook', () => {
const setActiveTool = vi.fn();
const onTriggerAI = vi.fn();
render(<ToolsPalette activeTool="move" setActiveTool={setActiveTool} onTriggerAI={onTriggerAI} />);
fireEvent.click(screen.getByTitle('触发 SAM 推理 (Enter)'));
expect(setActiveTool).toHaveBeenCalledWith('sam_trigger');
expect(onTriggerAI).toHaveBeenCalled();
});
});

View File

@@ -78,7 +78,7 @@ export function ToolsPalette({ activeTool, setActiveTool, onTriggerAI }: ToolsPa
setActiveTool('sam_trigger');
if (onTriggerAI) onTriggerAI();
}}
title="触发 SAM 3 推理 (Enter)"
title="触发 SAM 推理 (Enter)"
className={cn(
"w-10 h-10 rounded-lg flex items-center justify-center transition-all",
activeTool === 'sam_trigger'

View File

@@ -0,0 +1,259 @@
import { act, fireEvent, render, screen, waitFor } from '@testing-library/react';
import { beforeEach, describe, expect, it, vi } from 'vitest';
import { resetStore } from '../test/storeTestUtils';
import { useStore } from '../store/useStore';
import { VideoWorkspace } from './VideoWorkspace';
const apiMock = vi.hoisted(() => ({
getProjectFrames: vi.fn(),
parseMedia: vi.fn(),
getTask: vi.fn(),
getTemplates: vi.fn(),
getProjectAnnotations: vi.fn(),
saveAnnotation: vi.fn(),
updateAnnotation: vi.fn(),
deleteAnnotation: vi.fn(),
exportCoco: vi.fn(),
annotationToMask: vi.fn(),
buildAnnotationPayload: vi.fn(),
getAiModelStatus: vi.fn(),
}));
vi.mock('../lib/api', () => ({
getProjectFrames: apiMock.getProjectFrames,
parseMedia: apiMock.parseMedia,
getTask: apiMock.getTask,
getTemplates: apiMock.getTemplates,
getProjectAnnotations: apiMock.getProjectAnnotations,
saveAnnotation: apiMock.saveAnnotation,
updateAnnotation: apiMock.updateAnnotation,
deleteAnnotation: apiMock.deleteAnnotation,
exportCoco: apiMock.exportCoco,
annotationToMask: apiMock.annotationToMask,
buildAnnotationPayload: apiMock.buildAnnotationPayload,
getAiModelStatus: apiMock.getAiModelStatus,
}));
describe('VideoWorkspace', () => {
beforeEach(() => {
resetStore();
vi.clearAllMocks();
useStore.setState({ currentProject: { id: '1', name: 'Demo', status: 'ready', video_path: 'uploads/demo.mp4' } });
apiMock.getTemplates.mockResolvedValue([]);
apiMock.getProjectAnnotations.mockResolvedValue([]);
apiMock.annotationToMask.mockReturnValue(null);
apiMock.getTask.mockResolvedValue({ id: 1, status: 'success', progress: 100, message: '解析完成' });
apiMock.getAiModelStatus.mockResolvedValue({
selected_model: 'sam2',
gpu: { available: false, device: 'cpu', name: null, torch_available: true },
models: [
{ id: 'sam2', label: 'SAM 2', available: true, loaded: false, device: 'cpu', supports: [], message: 'ready', package_available: true, checkpoint_exists: true, python_ok: true, torch_ok: true, cuda_required: false },
{ id: 'sam3', label: 'SAM 3', available: false, loaded: false, device: 'unavailable', supports: [], message: 'missing', package_available: false, checkpoint_exists: false, python_ok: false, torch_ok: true, cuda_required: true },
],
});
});
it('loads project frames into the workspace store', async () => {
apiMock.getProjectFrames.mockResolvedValueOnce([
{ id: 10, project_id: 1, frame_index: 0, image_url: '/frame.jpg', width: 640, height: 360 },
]);
render(<VideoWorkspace />);
await waitFor(() => expect(useStore.getState().frames).toEqual([
{ id: '10', projectId: '1', index: 0, url: '/frame.jpg', width: 640, height: 360 },
]));
expect(screen.getByText('Demo')).toBeInTheDocument();
expect(apiMock.parseMedia).not.toHaveBeenCalled();
expect(apiMock.getProjectAnnotations).toHaveBeenCalledWith('1');
});
it('triggers parsing when a media project has no frames yet', async () => {
apiMock.getProjectFrames
.mockResolvedValueOnce([])
.mockResolvedValueOnce([
{ id: 11, project_id: 1, frame_index: 0, image_url: '/parsed.jpg', width: 320, height: 240 },
]);
apiMock.parseMedia.mockResolvedValueOnce({ id: 7, status: 'queued', progress: 0 });
apiMock.getTask.mockResolvedValueOnce({ id: 7, status: 'success', progress: 100, message: '解析完成' });
render(<VideoWorkspace />);
await waitFor(() => expect(apiMock.parseMedia).toHaveBeenCalledWith('1'));
expect(apiMock.getTask).toHaveBeenCalledWith(7);
await waitFor(() => expect(useStore.getState().frames[0]).toEqual(expect.objectContaining({
id: '11',
url: '/parsed.jpg',
})));
});
it('hydrates saved annotations after loading frames', async () => {
apiMock.getProjectFrames.mockResolvedValueOnce([
{ id: 10, project_id: 1, frame_index: 0, image_url: '/frame.jpg', width: 640, height: 360 },
]);
apiMock.getProjectAnnotations.mockResolvedValueOnce([{ id: 99, frame_id: 10 }]);
apiMock.annotationToMask.mockReturnValueOnce({
id: 'annotation-99',
annotationId: '99',
frameId: '10',
saved: true,
pathData: 'M 0 0 Z',
label: 'Saved',
color: '#06b6d4',
});
render(<VideoWorkspace />);
await waitFor(() => expect(useStore.getState().masks).toEqual([
expect.objectContaining({ id: 'annotation-99', saved: true }),
]));
});
it('saves pending masks through the archive button', async () => {
apiMock.getProjectFrames.mockResolvedValueOnce([
{ id: 10, project_id: 1, frame_index: 0, image_url: '/frame.jpg', width: 640, height: 360 },
]);
apiMock.buildAnnotationPayload.mockReturnValueOnce({ project_id: 1, frame_id: 10, mask_data: { polygons: [] } });
apiMock.saveAnnotation.mockResolvedValueOnce({ id: 5 });
render(<VideoWorkspace />);
await waitFor(() => expect(useStore.getState().frames).toHaveLength(1));
act(() => {
useStore.setState({
activeTemplateId: '2',
masks: [{
id: 'mask-1',
frameId: '10',
pathData: 'M 0 0 Z',
label: 'AI Mask',
color: '#06b6d4',
segmentation: [[0, 0, 10, 0, 10, 10]],
bbox: [0, 0, 10, 10],
}],
});
});
fireEvent.click(screen.getByRole('button', { name: '结构化归档保存' }));
await waitFor(() => expect(apiMock.saveAnnotation).toHaveBeenCalledWith({
project_id: 1,
frame_id: 10,
mask_data: { polygons: [] },
}));
expect(apiMock.buildAnnotationPayload).toHaveBeenCalledWith(
'1',
expect.objectContaining({ id: 'mask-1' }),
expect.objectContaining({ id: '10' }),
'2',
);
});
it('updates dirty saved masks through the archive button', async () => {
apiMock.getProjectFrames.mockResolvedValueOnce([
{ id: 10, project_id: 1, frame_index: 0, image_url: '/frame.jpg', width: 640, height: 360 },
]);
apiMock.buildAnnotationPayload.mockReturnValueOnce({
project_id: 1,
frame_id: 10,
template_id: 2,
mask_data: { polygons: [], label: '胆囊' },
});
apiMock.updateAnnotation.mockResolvedValueOnce({ id: 99 });
render(<VideoWorkspace />);
await waitFor(() => expect(useStore.getState().frames).toHaveLength(1));
act(() => {
useStore.setState({
activeTemplateId: '2',
masks: [{
id: 'annotation-99',
annotationId: '99',
frameId: '10',
pathData: 'M 0 0 Z',
label: '胆囊',
color: '#ff0000',
saveStatus: 'dirty',
segmentation: [[0, 0, 10, 0, 10, 10]],
bbox: [0, 0, 10, 10],
}],
});
});
fireEvent.click(screen.getByRole('button', { name: '结构化归档保存' }));
await waitFor(() => expect(apiMock.updateAnnotation).toHaveBeenCalledWith('99', {
template_id: 2,
mask_data: { polygons: [], label: '胆囊' },
points: undefined,
bbox: undefined,
}));
expect(apiMock.saveAnnotation).not.toHaveBeenCalled();
});
it('deletes saved annotations when clearing current-frame masks', async () => {
apiMock.getProjectFrames.mockResolvedValueOnce([
{ id: 10, project_id: 1, frame_index: 0, image_url: '/frame.jpg', width: 640, height: 360 },
]);
apiMock.deleteAnnotation.mockResolvedValueOnce(undefined);
render(<VideoWorkspace />);
await waitFor(() => expect(useStore.getState().frames).toHaveLength(1));
act(() => {
useStore.setState({
masks: [
{
id: 'annotation-99',
annotationId: '99',
frameId: '10',
pathData: 'M 0 0 Z',
label: 'Saved',
color: '#06b6d4',
saved: true,
saveStatus: 'saved',
},
{
id: 'draft-1',
frameId: '10',
pathData: 'M 1 1 Z',
label: 'Draft',
color: '#ff0000',
},
],
});
});
fireEvent.click(screen.getByRole('button', { name: '清空遮罩' }));
await waitFor(() => expect(apiMock.deleteAnnotation).toHaveBeenCalledWith('99'));
expect(useStore.getState().masks).toEqual([]);
});
it('auto-saves pending masks before exporting COCO', async () => {
apiMock.getProjectFrames.mockResolvedValueOnce([
{ id: 10, project_id: 1, frame_index: 0, image_url: '/frame.jpg', width: 640, height: 360 },
]);
apiMock.buildAnnotationPayload.mockReturnValueOnce({ project_id: 1, frame_id: 10, mask_data: { polygons: [] } });
apiMock.saveAnnotation.mockResolvedValueOnce({ id: 5 });
apiMock.exportCoco.mockResolvedValueOnce(new Blob(['{}'], { type: 'application/json' }));
render(<VideoWorkspace />);
await waitFor(() => expect(useStore.getState().frames).toHaveLength(1));
act(() => {
useStore.setState({
masks: [{
id: 'mask-1',
frameId: '10',
pathData: 'M 0 0 Z',
label: 'AI Mask',
color: '#06b6d4',
segmentation: [[0, 0, 10, 0, 10, 10]],
}],
});
});
fireEvent.click(screen.getByRole('button', { name: '导出 JSON 标注集' }));
await waitFor(() => expect(apiMock.saveAnnotation).toHaveBeenCalled());
expect(apiMock.exportCoco).toHaveBeenCalledWith('1');
});
});

View File

@@ -1,10 +1,28 @@
import React, { useEffect } from 'react';
import React, { useCallback, useEffect, useMemo, useState } from 'react';
import { useStore } from '../store/useStore';
import { getProjectFrames, parseMedia, getTemplates } from '../lib/api';
import {
annotationToMask,
buildAnnotationPayload,
deleteAnnotation,
exportCoco,
getProjectAnnotations,
getProjectFrames,
getTask,
getTemplates,
parseMedia,
saveAnnotation,
updateAnnotation,
} from '../lib/api';
import { CanvasArea } from './CanvasArea';
import { ToolsPalette } from './ToolsPalette';
import { OntologyInspector } from './OntologyInspector';
import { FrameTimeline } from './FrameTimeline';
import { ModelStatusBadge } from './ModelStatusBadge';
import type { Frame } from '../store/useStore';
function sleep(ms: number) {
return new Promise((resolve) => setTimeout(resolve, ms));
}
export function VideoWorkspace({ onNavigateToAI }: { onNavigateToAI?: () => void }) {
const activeTool = useStore((state) => state.activeTool);
@@ -12,8 +30,26 @@ export function VideoWorkspace({ onNavigateToAI }: { onNavigateToAI?: () => void
const currentProject = useStore((state) => state.currentProject);
const frames = useStore((state) => state.frames);
const currentFrameIndex = useStore((state) => state.currentFrameIndex);
const masks = useStore((state) => state.masks);
const activeTemplateId = useStore((state) => state.activeTemplateId);
const setFrames = useStore((state) => state.setFrames);
const setCurrentFrame = useStore((state) => state.setCurrentFrame);
const setMasks = useStore((state) => state.setMasks);
const [isSaving, setIsSaving] = useState(false);
const [isExporting, setIsExporting] = useState(false);
const [statusMessage, setStatusMessage] = useState('');
const hydrateSavedAnnotations = useCallback(async (projectId: string, projectFrames: Frame[]) => {
const frameById = new Map(projectFrames.map((frame) => [frame.id, frame]));
const annotations = await getProjectAnnotations(projectId);
const savedMasks = annotations
.map((annotation) => {
const frame = annotation.frame_id ? frameById.get(String(annotation.frame_id)) : null;
return frame ? annotationToMask(annotation, frame) : null;
})
.filter((mask): mask is NonNullable<typeof mask> => Boolean(mask));
setMasks(savedMasks);
}, [setMasks]);
useEffect(() => {
if (!currentProject?.id) return;
@@ -25,34 +61,58 @@ export function VideoWorkspace({ onNavigateToAI }: { onNavigateToAI?: () => void
if (cancelled) return;
if (data.length === 0 && currentProject.video_path) {
// No frames yet but video exists → trigger parsing
// No frames yet but video exists -> queue parsing and poll the task.
try {
await parseMedia(String(currentProject.id));
const task = await parseMedia(String(currentProject.id));
if (cancelled) return;
setStatusMessage(`解析任务已入队 #${task.id}`);
let completed = false;
for (let attempt = 0; attempt < 60; attempt += 1) {
const freshTask = await getTask(task.id);
if (cancelled) return;
setStatusMessage(freshTask.message || `解析进度 ${freshTask.progress}%`);
if (freshTask.status === 'success') {
completed = true;
break;
}
if (freshTask.status === 'failed') {
setStatusMessage(freshTask.error || '解析任务失败');
return;
}
await sleep(2000);
}
if (!completed) {
setStatusMessage('解析仍在后台运行,可稍后刷新工作区');
return;
}
const fresh = await getProjectFrames(String(currentProject.id));
if (cancelled) return;
setFrames(fresh.map((f) => ({
const mappedFrames = fresh.map((f) => ({
id: String(f.id),
projectId: String(f.project_id),
index: f.frame_index,
url: f.image_url,
width: f.width ?? 0,
height: f.height ?? 0,
})));
}));
setFrames(mappedFrames);
setCurrentFrame(0);
await hydrateSavedAnnotations(String(currentProject.id), mappedFrames);
} catch (err) {
console.error('Parse failed:', err);
}
} else {
setFrames(data.map((f) => ({
const mappedFrames = data.map((f) => ({
id: String(f.id),
projectId: String(f.project_id),
index: f.frame_index,
url: f.image_url,
width: f.width ?? 0,
height: f.height ?? 0,
})));
}));
setFrames(mappedFrames);
setCurrentFrame(0);
await hydrateSavedAnnotations(String(currentProject.id), mappedFrames);
}
} catch (err) {
console.error('Failed to load frames:', err);
@@ -61,7 +121,7 @@ export function VideoWorkspace({ onNavigateToAI }: { onNavigateToAI?: () => void
loadFrames();
return () => { cancelled = true; };
}, [currentProject?.id, setFrames, setCurrentFrame]);
}, [currentProject?.id, currentProject?.video_path, hydrateSavedAnnotations, setFrames, setCurrentFrame]);
const templates = useStore((state) => state.templates);
const setTemplates = useStore((state) => state.setTemplates);
@@ -72,7 +132,121 @@ export function VideoWorkspace({ onNavigateToAI }: { onNavigateToAI?: () => void
}
}, [templates.length, setTemplates]);
const currentFrameUrl = frames[currentFrameIndex]?.url || '';
const currentFrame = frames[currentFrameIndex] || null;
const frameById = useMemo(() => new Map(frames.map((frame) => [frame.id, frame])), [frames]);
const projectFrameIds = useMemo(() => new Set(frames.map((frame) => frame.id)), [frames]);
const savePendingAnnotations = useCallback(async ({ silent = false } = {}) => {
if (!currentProject?.id) return 0;
const projectMasks = masks.filter((mask) => projectFrameIds.has(mask.frameId));
const pendingMasks = projectMasks.filter((mask) => !mask.annotationId);
const dirtyMasks = projectMasks.filter((mask) => mask.annotationId && mask.saveStatus === 'dirty');
if (pendingMasks.length === 0 && dirtyMasks.length === 0) {
if (!silent) setStatusMessage('没有待保存标注');
return 0;
}
setIsSaving(true);
setStatusMessage('正在保存标注...');
try {
const createPayloads = pendingMasks
.map((mask) => {
const frame = frameById.get(mask.frameId);
return frame ? buildAnnotationPayload(currentProject.id, mask, frame, activeTemplateId) : null;
})
.filter((payload): payload is NonNullable<typeof payload> => Boolean(payload));
const updatePayloads = dirtyMasks
.map((mask) => {
const frame = frameById.get(mask.frameId);
const payload = frame ? buildAnnotationPayload(currentProject.id, mask, frame, activeTemplateId) : null;
if (!payload || !mask.annotationId) return null;
const updatePayload = {
template_id: payload.template_id,
mask_data: payload.mask_data,
points: payload.points,
bbox: payload.bbox,
};
return { annotationId: mask.annotationId, payload: updatePayload };
})
.filter((item): item is NonNullable<typeof item> => Boolean(item));
if (createPayloads.length === 0 && updatePayloads.length === 0) {
setStatusMessage('没有可保存的标注数据');
return 0;
}
await Promise.all([
...createPayloads.map((payload) => saveAnnotation(payload)),
...updatePayloads.map(({ annotationId, payload }) => updateAnnotation(annotationId, payload)),
]);
await hydrateSavedAnnotations(currentProject.id, frames);
const savedCount = createPayloads.length + updatePayloads.length;
setStatusMessage(`已保存 ${savedCount} 个标注`);
return savedCount;
} catch (err) {
console.error('Save annotations failed:', err);
setStatusMessage('保存失败,请检查后端服务');
throw err;
} finally {
setIsSaving(false);
}
}, [activeTemplateId, currentProject?.id, frameById, frames, hydrateSavedAnnotations, masks, projectFrameIds]);
const handleClearCurrentFrameMasks = useCallback(async () => {
if (!currentFrame) return;
const frameMasks = masks.filter((mask) => mask.frameId === currentFrame.id);
const annotationIds = frameMasks
.map((mask) => mask.annotationId)
.filter((annotationId): annotationId is string => Boolean(annotationId));
setIsSaving(true);
setStatusMessage(annotationIds.length > 0 ? '正在删除已保存标注...' : '正在清空本帧遮罩...');
try {
await Promise.all(annotationIds.map((annotationId) => deleteAnnotation(annotationId)));
setMasks(masks.filter((mask) => mask.frameId !== currentFrame.id));
setStatusMessage(annotationIds.length > 0
? `已删除 ${annotationIds.length} 个后端标注`
: '已清空本帧未保存遮罩');
} catch (err) {
console.error('Delete annotations failed:', err);
setStatusMessage('删除失败,请检查后端服务');
} finally {
setIsSaving(false);
}
}, [currentFrame, masks, setMasks]);
const handleSave = async () => {
try {
await savePendingAnnotations();
} catch {
// status message is set in savePendingAnnotations
}
};
const handleExport = async () => {
if (!currentProject?.id) return;
setIsExporting(true);
setStatusMessage('正在准备导出...');
try {
await savePendingAnnotations({ silent: true });
const blob = await exportCoco(currentProject.id);
const url = URL.createObjectURL(blob);
const link = document.createElement('a');
link.href = url;
link.download = `project_${currentProject.id}_coco.json`;
document.body.appendChild(link);
link.click();
link.remove();
URL.revokeObjectURL(url);
setStatusMessage('COCO JSON 已导出');
} catch (err) {
console.error('Export failed:', err);
setStatusMessage('导出失败,请检查后端服务');
} finally {
setIsExporting(false);
}
};
return (
<div className="w-full h-full flex flex-col bg-[#0a0a0a]">
@@ -84,14 +258,25 @@ export function VideoWorkspace({ onNavigateToAI }: { onNavigateToAI?: () => void
<span className="text-sm text-white font-mono">{currentProject?.name || '未选择项目'}</span>
</div>
<div className="flex items-center gap-3">
<div className="flex items-center gap-1.5 text-[10px] uppercase font-medium">
<span className="px-2 py-0.5 rounded bg-green-500/10 text-green-400 border border-green-500/20">SAM 3 </span>
</div>
<button className="px-4 py-1.5 bg-white/5 hover:bg-white/10 border border-white/10 rounded-md text-xs transition-colors text-white">
JSON
{statusMessage && (
<span className="text-[10px] text-gray-500 font-mono max-w-48 truncate" title={statusMessage}>
{statusMessage}
</span>
)}
<ModelStatusBadge />
<button
onClick={handleExport}
disabled={!currentProject?.id || isExporting || isSaving}
className="px-4 py-1.5 bg-white/5 hover:bg-white/10 border border-white/10 rounded-md text-xs transition-colors text-white disabled:opacity-40 disabled:cursor-not-allowed"
>
{isExporting ? '导出中...' : '导出 JSON 标注集'}
</button>
<button className="px-4 py-1.5 bg-cyan-600 hover:bg-cyan-500 text-white text-xs font-medium rounded-md transition-shadow shadow-lg shadow-cyan-900/20">
<button
onClick={handleSave}
disabled={!currentProject?.id || isSaving || isExporting}
className="px-4 py-1.5 bg-cyan-600 hover:bg-cyan-500 text-white text-xs font-medium rounded-md transition-shadow shadow-lg shadow-cyan-900/20 disabled:opacity-40 disabled:cursor-not-allowed"
>
{isSaving ? '保存中...' : '结构化归档保存'}
</button>
</div>
</div>
@@ -102,7 +287,7 @@ export function VideoWorkspace({ onNavigateToAI }: { onNavigateToAI?: () => void
<div className="flex-1 relative flex items-center justify-center p-8 bg-[#151515] overflow-hidden">
<div className="relative w-full h-full bg-[#1e1e1e] border border-white/5 shadow-2xl rounded-sm">
<CanvasArea activeTool={activeTool} frameUrl={currentFrameUrl} />
<CanvasArea activeTool={activeTool} frame={currentFrame} onClearMasks={handleClearCurrentFrameMasks} />
</div>
</div>

361
src/lib/api.test.ts Normal file
View File

@@ -0,0 +1,361 @@
import { beforeEach, describe, expect, it, vi } from 'vitest';
const axiosMock = vi.hoisted(() => {
const client = {
get: vi.fn(),
post: vi.fn(),
patch: vi.fn(),
delete: vi.fn(),
interceptors: {
request: { use: vi.fn() },
response: { use: vi.fn() },
},
};
return { client, create: vi.fn(() => client) };
});
vi.mock('axios', () => ({
default: {
create: axiosMock.create,
},
}));
describe('api client contracts', () => {
beforeEach(() => {
vi.clearAllMocks();
vi.setSystemTime(new Date('2026-05-01T00:00:00Z'));
});
it('maps backend project fields into frontend project fields', async () => {
const { getProjects } = await import('./api');
axiosMock.client.get.mockResolvedValueOnce({
data: [
{
id: 7,
name: 'Demo',
description: 'desc',
status: 'ready',
frame_count: 12,
original_fps: 29.97,
parse_fps: 10,
thumbnail_url: 'thumb',
video_path: 'uploads/demo.mp4',
source_type: 'video',
created_at: 'created',
updated_at: 'updated',
},
],
});
await expect(getProjects()).resolves.toEqual([
expect.objectContaining({
id: '7',
name: 'Demo',
status: 'ready',
frames: 12,
fps: '30FPS',
thumbnail_url: 'thumb',
video_path: 'uploads/demo.mp4',
source_type: 'video',
createdAt: 'created',
updatedAt: 'updated',
}),
]);
expect(axiosMock.client.get).toHaveBeenCalledWith('/api/projects');
});
it('updates projects with PATCH instead of the old PUT contract', async () => {
const { updateProject } = await import('./api');
axiosMock.client.patch.mockResolvedValueOnce({ data: { id: 3, name: 'Renamed', status: 'ready' } });
await updateProject('3', { name: 'Renamed' } as any);
expect(axiosMock.client.patch).toHaveBeenCalledWith('/api/projects/3', { name: 'Renamed' });
});
it('normalizes legacy project status values returned by existing databases', async () => {
const { getProjects } = await import('./api');
axiosMock.client.get.mockResolvedValueOnce({
data: [
{ id: 1, name: 'Old Ready', status: 'Ready' },
{ id: 2, name: 'Old Parsing', status: 'Parsing' },
{ id: 3, name: 'Old Error', status: 'Error' },
],
});
await expect(getProjects()).resolves.toEqual([
expect.objectContaining({ status: 'ready' }),
expect.objectContaining({ status: 'parsing' }),
expect.objectContaining({ status: 'error' }),
]);
});
it('exports COCO from the backend route shape', async () => {
const { exportCoco } = await import('./api');
const blob = new Blob(['{}'], { type: 'application/json' });
axiosMock.client.get.mockResolvedValueOnce({ data: blob });
await expect(exportCoco('9')).resolves.toBe(blob);
expect(axiosMock.client.get).toHaveBeenCalledWith('/api/export/9/coco', {
responseType: 'blob',
});
});
it('loads dashboard overview from the backend summary endpoint', async () => {
const { getDashboardOverview } = await import('./api');
const overview = {
summary: {
project_count: 2,
parsing_task_count: 1,
annotation_count: 5,
frame_count: 100,
template_count: 3,
system_load_percent: 12,
},
tasks: [
{ id: 'project-1', project_id: 1, name: 'Demo', progress: 60, status: 'pending', frame_count: 10, updated_at: 'now' },
],
activity: [
{ id: 'project-1', kind: 'project', time: 'now', message: '项目状态: pending', project: 'Demo' },
],
};
axiosMock.client.get.mockResolvedValueOnce({ data: overview });
await expect(getDashboardOverview()).resolves.toEqual(overview);
expect(axiosMock.client.get).toHaveBeenCalledWith('/api/dashboard/overview');
});
it('queues media parsing and reads processing task status', async () => {
const { getTask, parseMedia } = await import('./api');
const task = {
id: 12,
task_type: 'parse_video',
status: 'queued',
progress: 0,
message: '解析任务已入队',
project_id: 9,
celery_task_id: 'celery-12',
payload: { source_type: 'video' },
result: null,
error: null,
created_at: 'created',
started_at: null,
finished_at: null,
updated_at: 'updated',
};
axiosMock.client.post.mockResolvedValueOnce({ data: task });
axiosMock.client.get.mockResolvedValueOnce({ data: { ...task, status: 'success', progress: 100 } });
await expect(parseMedia('9')).resolves.toEqual(task);
expect(axiosMock.client.post).toHaveBeenCalledWith('/api/media/parse', null, {
params: { project_id: '9' },
});
await expect(getTask(12)).resolves.toEqual(expect.objectContaining({ status: 'success', progress: 100 }));
expect(axiosMock.client.get).toHaveBeenCalledWith('/api/tasks/12');
});
it('lists, saves, updates, and deletes annotations with the backend annotation contract', async () => {
const { deleteAnnotation, getProjectAnnotations, saveAnnotation, updateAnnotation } = await import('./api');
const saved = {
id: 1,
project_id: 9,
frame_id: 5,
template_id: 2,
mask_data: { polygons: [[[0, 0], [1, 0], [1, 1]]] },
points: null,
bbox: null,
created_at: 'created',
updated_at: 'updated',
};
axiosMock.client.get.mockResolvedValueOnce({ data: [saved] });
axiosMock.client.post.mockResolvedValueOnce({ data: saved });
await expect(getProjectAnnotations('9', '5')).resolves.toEqual([saved]);
expect(axiosMock.client.get).toHaveBeenCalledWith('/api/ai/annotations', {
params: { project_id: 9, frame_id: 5 },
});
await expect(saveAnnotation({
project_id: 9,
frame_id: 5,
template_id: 2,
mask_data: { polygons: [[[0, 0], [1, 0], [1, 1]]], label: 'mask' },
})).resolves.toEqual(saved);
expect(axiosMock.client.post).toHaveBeenCalledWith('/api/ai/annotate', {
project_id: 9,
frame_id: 5,
template_id: 2,
mask_data: { polygons: [[[0, 0], [1, 0], [1, 1]]], label: 'mask' },
});
axiosMock.client.patch.mockResolvedValueOnce({ data: { ...saved, mask_data: { ...saved.mask_data, label: 'updated' } } });
await expect(updateAnnotation('1', {
template_id: 2,
mask_data: { polygons: [[[0, 0], [1, 0], [1, 1]]], label: 'updated' },
})).resolves.toEqual(expect.objectContaining({ mask_data: expect.objectContaining({ label: 'updated' }) }));
expect(axiosMock.client.patch).toHaveBeenCalledWith('/api/ai/annotations/1', {
template_id: 2,
mask_data: { polygons: [[[0, 0], [1, 0], [1, 1]]], label: 'updated' },
});
axiosMock.client.delete.mockResolvedValueOnce({ data: null });
await expect(deleteAnnotation('1')).resolves.toBeUndefined();
expect(axiosMock.client.delete).toHaveBeenCalledWith('/api/ai/annotations/1');
});
it('builds annotation payloads from frontend masks and restores saved annotations to masks', async () => {
const { annotationToMask, buildAnnotationPayload } = await import('./api');
const frame = { id: '5', projectId: '9', index: 0, url: '/frame.jpg', width: 100, height: 50 };
const payload = buildAnnotationPayload('9', {
id: 'm1',
frameId: '5',
pathData: 'M 10 10 L 90 10 L 90 40 Z',
label: '胆囊',
color: '#ff0000',
classId: 'c1',
className: '胆囊',
classZIndex: 20,
segmentation: [[10, 10, 90, 10, 90, 40]],
bbox: [10, 10, 80, 30],
}, frame, '2');
expect(payload).toEqual({
project_id: 9,
frame_id: 5,
template_id: 2,
mask_data: {
polygons: [[[0.1, 0.2], [0.9, 0.2], [0.9, 0.8]]],
label: '胆囊',
color: '#ff0000',
class: { id: 'c1', name: '胆囊', color: '#ff0000', zIndex: 20 },
},
bbox: [0.1, 0.2, 0.8, 0.6],
});
expect(annotationToMask({
id: 3,
project_id: 9,
frame_id: 5,
template_id: 2,
mask_data: {
polygons: [[[0.1, 0.2], [0.9, 0.2], [0.9, 0.8]]],
label: '旧标签',
color: '#06b6d4',
class: { id: 'c1', name: '胆囊', color: '#ff0000', zIndex: 20 },
},
points: null,
bbox: null,
created_at: 'created',
updated_at: 'updated',
}, frame)).toEqual(expect.objectContaining({
id: 'annotation-3',
annotationId: '3',
frameId: '5',
templateId: '2',
classId: 'c1',
className: '胆囊',
classZIndex: 20,
label: '胆囊',
color: '#ff0000',
saveStatus: 'saved',
saved: true,
pathData: 'M 10 10 L 90 10 L 90 40 Z',
bbox: [10, 10, 80, 30],
}));
});
it('normalizes positive and negative point prompts for AI prediction', async () => {
const { predictMask } = await import('./api');
axiosMock.client.post.mockResolvedValueOnce({
data: {
polygons: [[[0.25, 0.25], [0.75, 0.25], [0.75, 0.75], [0.25, 0.75]]],
scores: [0.9],
},
});
const result = await predictMask({
imageId: '42',
imageWidth: 400,
imageHeight: 200,
points: [
{ x: 200, y: 100, type: 'pos' },
{ x: 40, y: 20, type: 'neg' },
],
});
expect(axiosMock.client.post).toHaveBeenCalledWith('/api/ai/predict', {
image_id: 42,
prompt_type: 'point',
prompt_data: {
points: [[0.5, 0.5], [0.1, 0.1]],
labels: [1, 0],
},
model: 'sam2',
});
expect(result.masks[0]).toEqual(expect.objectContaining({
pathData: 'M 100 50 L 300 50 L 300 150 L 100 150 Z',
segmentation: [[100, 50, 300, 50, 300, 150, 100, 150]],
bbox: [100, 50, 200, 100],
area: 20000,
confidence: 0.9,
}));
});
it('normalizes box prompts for AI prediction', async () => {
const { predictMask } = await import('./api');
axiosMock.client.post.mockResolvedValueOnce({ data: { polygons: [], scores: [] } });
await predictMask({
imageId: '5',
imageWidth: 640,
imageHeight: 320,
box: { x1: 64, y1: 32, x2: 320, y2: 160 },
});
expect(axiosMock.client.post).toHaveBeenCalledWith('/api/ai/predict', {
image_id: 5,
prompt_type: 'box',
prompt_data: [0.1, 0.1, 0.5, 0.5],
model: 'sam2',
});
});
it('uses semantic prompt type for text-only AI prediction', async () => {
const { predictMask } = await import('./api');
axiosMock.client.post.mockResolvedValueOnce({ data: { polygons: [], scores: [] } });
await predictMask({
imageId: '6',
imageWidth: 640,
imageHeight: 360,
model: 'sam3',
text: '分割胆囊',
});
expect(axiosMock.client.post).toHaveBeenCalledWith('/api/ai/predict', {
image_id: 6,
prompt_type: 'semantic',
prompt_data: '分割胆囊',
model: 'sam3',
});
});
it('loads AI model and GPU runtime status', async () => {
const { getAiModelStatus } = await import('./api');
const status = {
selected_model: 'sam2',
gpu: { available: false, device: 'cpu', name: null, torch_available: true, torch_version: '2.x', cuda_version: null },
models: [
{ id: 'sam2', label: 'SAM 2', available: true, loaded: false, device: 'cpu', supports: ['point'], message: 'ready', package_available: true, checkpoint_exists: true, checkpoint_path: 'model.pt', python_ok: true, torch_ok: true, cuda_required: false },
{ id: 'sam3', label: 'SAM 3', available: false, loaded: false, device: 'unavailable', supports: ['semantic'], message: 'missing runtime', package_available: false, checkpoint_exists: false, checkpoint_path: null, python_ok: false, torch_ok: true, cuda_required: true },
],
};
axiosMock.client.get.mockResolvedValueOnce({ data: status });
await expect(getAiModelStatus('sam3')).resolves.toEqual(status);
expect(axiosMock.client.get).toHaveBeenCalledWith('/api/ai/models/status', {
params: { selected_model: 'sam3' },
});
});
});

View File

@@ -1,8 +1,9 @@
import axios, { AxiosError } from 'axios';
import type { Project, Template } from '../store/useStore';
import type { AiModelId, Frame, Mask, Project, Template } from '../store/useStore';
import { API_BASE_URL } from './config';
const apiClient = axios.create({
baseURL: 'http://192.168.3.11:8000',
baseURL: API_BASE_URL,
headers: {
'Content-Type': 'application/json',
},
@@ -40,37 +41,20 @@ export async function login(username: string, password: string): Promise<{ token
}
// Projects
export async function getProjects(): Promise<Project[]> {
const response = await apiClient.get('/api/projects');
return response.data.map((p: any) => ({
id: String(p.id),
name: p.name,
description: p.description,
status: p.status,
frames: p.frame_count ?? 0,
fps: p.original_fps ? `${Math.round(p.original_fps)}FPS` : '30FPS',
thumbnail_url: p.thumbnail_url,
video_path: p.video_path,
source_type: p.source_type,
original_fps: p.original_fps,
parse_fps: p.parse_fps,
createdAt: p.created_at,
updatedAt: p.updated_at,
}));
function normalizeProjectStatus(status?: string): Project['status'] {
const value = (status || 'pending').toLowerCase();
if (value === 'ready') return 'ready';
if (value === 'parsing' || value === 'queued' || value === 'running') return 'parsing';
if (value === 'error' || value === 'failed') return 'error';
return 'pending';
}
export async function createProject(payload: {
name: string;
description?: string;
parse_fps?: number;
}): Promise<Project> {
const response = await apiClient.post('/api/projects', payload);
const p = response.data;
function mapProject(p: any): Project {
return {
id: String(p.id),
name: p.name,
description: p.description,
status: p.status,
status: normalizeProjectStatus(p.status),
frames: p.frame_count ?? 0,
fps: p.original_fps ? `${Math.round(p.original_fps)}FPS` : '30FPS',
thumbnail_url: p.thumbnail_url,
@@ -83,9 +67,23 @@ export async function createProject(payload: {
};
}
export async function getProjects(): Promise<Project[]> {
const response = await apiClient.get('/api/projects');
return response.data.map(mapProject);
}
export async function createProject(payload: {
name: string;
description?: string;
parse_fps?: number;
}): Promise<Project> {
const response = await apiClient.post('/api/projects', payload);
return mapProject(response.data);
}
export async function updateProject(id: string, payload: Partial<Project>): Promise<Project> {
const response = await apiClient.put(`/api/projects/${id}`, payload);
return response.data;
const response = await apiClient.patch(`/api/projects/${id}`, payload);
return mapProject(response.data);
}
export async function deleteProject(id: string): Promise<void> {
@@ -170,26 +168,46 @@ export async function uploadDicomBatch(files: File[], projectId?: string): Promi
return response.data;
}
export async function parseMedia(projectId: string): Promise<{
project_id: number;
frames_extracted: number;
status: string;
message: string;
}> {
export interface ProcessingTask {
id: number;
task_type: string;
status: 'queued' | 'running' | 'success' | 'failed' | string;
progress: number;
message?: string | null;
project_id?: number | null;
celery_task_id?: string | null;
payload?: Record<string, unknown> | null;
result?: Record<string, unknown> | null;
error?: string | null;
created_at: string;
started_at?: string | null;
finished_at?: string | null;
updated_at: string;
}
export async function parseMedia(projectId: string): Promise<ProcessingTask> {
const response = await apiClient.post('/api/media/parse', null, {
params: { project_id: projectId },
});
return response.data;
}
// AI Prediction
export async function predictMask(payload: {
imageUrl: string;
export async function getTask(taskId: string | number): Promise<ProcessingTask> {
const response = await apiClient.get(`/api/tasks/${taskId}`);
return response.data;
}
interface PredictMaskPayload {
imageId: string;
imageWidth: number;
imageHeight: number;
model?: AiModelId;
points?: { x: number; y: number; type: 'pos' | 'neg' }[];
box?: { x1: number; y1: number; x2: number; y2: number };
text?: string;
modelSize?: string;
}): Promise<{
}
interface PredictMaskResult {
masks: Array<{
id: string;
pathData: string;
@@ -200,14 +218,319 @@ export async function predictMask(payload: {
area: number;
confidence: number;
}>;
}> {
const response = await apiClient.post('/api/ai/predict', payload);
}
export interface AiModelStatus {
id: AiModelId;
label: string;
available: boolean;
loaded: boolean;
device: string;
supports: string[];
message: string;
package_available: boolean;
checkpoint_exists: boolean;
checkpoint_path?: string | null;
python_ok: boolean;
torch_ok: boolean;
cuda_required: boolean;
}
export interface AiRuntimeStatus {
selected_model: AiModelId;
gpu: {
available: boolean;
device: string;
name?: string | null;
torch_available: boolean;
torch_version?: string | null;
cuda_version?: string | null;
};
models: AiModelStatus[];
}
export interface SavedAnnotation {
id: number;
project_id: number;
frame_id: number | null;
template_id: number | null;
mask_data: {
polygons?: number[][][];
label?: string;
color?: string;
class?: {
id?: string;
name?: string;
color?: string;
zIndex?: number;
category?: string;
};
} | null;
points: number[][] | null;
bbox: number[] | null;
created_at: string;
updated_at: string;
}
export interface SaveAnnotationPayload {
project_id: number;
frame_id?: number;
template_id?: number;
mask_data?: {
polygons: number[][][];
label?: string;
color?: string;
class?: {
id?: string;
name?: string;
color?: string;
zIndex?: number;
category?: string;
};
};
points?: number[][];
bbox?: number[];
}
export type UpdateAnnotationPayload = Omit<SaveAnnotationPayload, 'project_id' | 'frame_id'>;
export interface DashboardTask {
id: string;
task_id?: number;
project_id: number;
name: string;
progress: number;
status: string;
frame_count: number;
updated_at: string | null;
}
export interface DashboardActivity {
id: string;
kind: 'project' | 'annotation' | 'template' | string;
time: string | null;
message: string;
project: string;
}
export interface DashboardOverview {
summary: {
project_count: number;
parsing_task_count: number;
annotation_count: number;
frame_count: number;
template_count: number;
system_load_percent: number;
};
tasks: DashboardTask[];
activity: DashboardActivity[];
}
function clamp01(value: number): number {
return Math.min(Math.max(value, 0), 1);
}
function normalizePoint(point: { x: number; y: number }, width: number, height: number): [number, number] {
return [
clamp01(point.x / Math.max(width, 1)),
clamp01(point.y / Math.max(height, 1)),
];
}
function polygonToPath(points: number[][], width: number, height: number): string {
if (points.length === 0) return '';
return points
.map(([x, y], index) => {
const px = x * width;
const py = y * height;
return `${index === 0 ? 'M' : 'L'} ${px} ${py}`;
})
.join(' ')
.concat(' Z');
}
function polygonToBbox(points: number[][], width: number, height: number): [number, number, number, number] {
const xs = points.map(([x]) => x * width);
const ys = points.map(([, y]) => y * height);
const minX = Math.min(...xs);
const minY = Math.min(...ys);
const maxX = Math.max(...xs);
const maxY = Math.max(...ys);
return [minX, minY, maxX - minX, maxY - minY];
}
function pixelSegmentationToNormalizedPolygons(
segmentation: number[][] | undefined,
width: number,
height: number,
): number[][][] {
if (!segmentation) return [];
return segmentation
.map((poly) => {
const points: number[][] = [];
for (let i = 0; i < poly.length - 1; i += 2) {
points.push([
clamp01(poly[i] / Math.max(width, 1)),
clamp01(poly[i + 1] / Math.max(height, 1)),
]);
}
return points;
})
.filter((points) => points.length > 0);
}
export function buildAnnotationPayload(
projectId: string,
mask: Mask,
frame: Frame,
templateId?: string | null,
): SaveAnnotationPayload | null {
const polygons = pixelSegmentationToNormalizedPolygons(mask.segmentation, frame.width, frame.height);
if (polygons.length === 0) return null;
const effectiveTemplateId = mask.templateId || templateId || undefined;
const classMetadata = mask.classId || mask.className || mask.classZIndex !== undefined
? {
id: mask.classId,
name: mask.className || mask.label,
color: mask.color,
zIndex: mask.classZIndex,
}
: undefined;
return {
project_id: Number(projectId),
frame_id: Number(frame.id),
template_id: effectiveTemplateId ? Number(effectiveTemplateId) : undefined,
mask_data: {
polygons,
label: mask.label,
color: mask.color,
...(classMetadata ? { class: classMetadata } : {}),
},
bbox: mask.bbox
? [
clamp01(mask.bbox[0] / Math.max(frame.width, 1)),
clamp01(mask.bbox[1] / Math.max(frame.height, 1)),
clamp01(mask.bbox[2] / Math.max(frame.width, 1)),
clamp01(mask.bbox[3] / Math.max(frame.height, 1)),
]
: undefined,
};
}
export function annotationToMask(annotation: SavedAnnotation, frame: Frame): Mask | null {
const polygons = annotation.mask_data?.polygons || [];
const firstPolygon = polygons[0];
if (!firstPolygon || firstPolygon.length === 0) return null;
const bbox = polygonToBbox(firstPolygon, frame.width, frame.height);
const classMetadata = annotation.mask_data?.class;
return {
id: `annotation-${annotation.id}`,
annotationId: String(annotation.id),
frameId: String(annotation.frame_id),
templateId: annotation.template_id ? String(annotation.template_id) : undefined,
classId: classMetadata?.id,
className: classMetadata?.name,
classZIndex: classMetadata?.zIndex,
saveStatus: 'saved',
saved: true,
pathData: polygonToPath(firstPolygon, frame.width, frame.height),
label: classMetadata?.name || annotation.mask_data?.label || `Annotation ${annotation.id}`,
color: classMetadata?.color || annotation.mask_data?.color || '#06b6d4',
segmentation: polygons.map((polygon) => polygon.flatMap(([x, y]) => [x * frame.width, y * frame.height])),
bbox,
area: bbox[2] * bbox[3],
};
}
export async function predictMask(payload: PredictMaskPayload): Promise<PredictMaskResult> {
let prompt_type: 'point' | 'box' | 'semantic';
let prompt_data: unknown;
if (payload.box) {
prompt_type = 'box';
prompt_data = [
clamp01(payload.box.x1 / Math.max(payload.imageWidth, 1)),
clamp01(payload.box.y1 / Math.max(payload.imageHeight, 1)),
clamp01(payload.box.x2 / Math.max(payload.imageWidth, 1)),
clamp01(payload.box.y2 / Math.max(payload.imageHeight, 1)),
];
} else if (payload.points && payload.points.length > 0) {
prompt_type = 'point';
prompt_data = {
points: payload.points.map((point) => normalizePoint(point, payload.imageWidth, payload.imageHeight)),
labels: payload.points.map((point) => (point.type === 'neg' ? 0 : 1)),
};
} else {
prompt_type = 'semantic';
prompt_data = payload.text?.trim() || '';
}
const response = await apiClient.post('/api/ai/predict', {
image_id: Number(payload.imageId),
prompt_type,
prompt_data,
model: payload.model || 'sam2',
});
const polygons: number[][][] = response.data.polygons || [];
const scores: number[] = response.data.scores || [];
return {
masks: polygons.map((polygon, index) => {
const bbox = polygonToBbox(polygon, payload.imageWidth, payload.imageHeight);
return {
id: `mask-${payload.imageId}-${Date.now()}-${index}`,
pathData: polygonToPath(polygon, payload.imageWidth, payload.imageHeight),
label: prompt_type === 'semantic' ? (payload.text?.trim() || 'AI Mask') : 'AI Mask',
color: '#06b6d4',
segmentation: [polygon.flatMap(([x, y]) => [x * payload.imageWidth, y * payload.imageHeight])],
bbox,
area: bbox[2] * bbox[3],
confidence: scores[index] ?? 0,
};
}),
};
}
export async function getAiModelStatus(selectedModel?: AiModelId): Promise<AiRuntimeStatus> {
const response = await apiClient.get('/api/ai/models/status', {
params: selectedModel ? { selected_model: selectedModel } : undefined,
});
return response.data;
}
export async function getProjectAnnotations(projectId: string, frameId?: string): Promise<SavedAnnotation[]> {
const response = await apiClient.get('/api/ai/annotations', {
params: {
project_id: Number(projectId),
...(frameId ? { frame_id: Number(frameId) } : {}),
},
});
return response.data;
}
export async function saveAnnotation(payload: SaveAnnotationPayload): Promise<SavedAnnotation> {
const response = await apiClient.post('/api/ai/annotate', payload);
return response.data;
}
export async function updateAnnotation(annotationId: string, payload: UpdateAnnotationPayload): Promise<SavedAnnotation> {
const response = await apiClient.patch(`/api/ai/annotations/${annotationId}`, payload);
return response.data;
}
export async function deleteAnnotation(annotationId: string): Promise<void> {
await apiClient.delete(`/api/ai/annotations/${annotationId}`);
}
export async function getDashboardOverview(): Promise<DashboardOverview> {
const response = await apiClient.get('/api/dashboard/overview');
return response.data;
}
// Export
export async function exportCoco(projectId: string): Promise<Blob> {
const response = await apiClient.get(`/api/export/coco/${projectId}`, {
const response = await apiClient.get(`/api/export/${projectId}/coco`, {
responseType: 'blob',
});
return response.data;

38
src/lib/config.test.ts Normal file
View File

@@ -0,0 +1,38 @@
import { afterEach, describe, expect, it, vi } from 'vitest';
describe('frontend runtime config', () => {
afterEach(() => {
vi.unstubAllEnvs();
vi.resetModules();
});
it('prefers explicit VITE_API_BASE_URL and trims trailing slashes', async () => {
vi.stubEnv('VITE_API_BASE_URL', 'http://api.example.test:8000///');
const config = await import('./config');
expect(config.API_BASE_URL).toBe('http://api.example.test:8000');
});
it('infers the API host from the current browser hostname', async () => {
const config = await import('./config');
expect(config.API_BASE_URL).toBe('http://seg.local:8000');
});
it('derives websocket URL from API URL unless explicitly configured', async () => {
vi.stubEnv('VITE_API_BASE_URL', 'https://seg.example.test');
const config = await import('./config');
expect(config.WS_PROGRESS_URL).toBe('wss://seg.example.test/ws/progress');
});
it('prefers explicit VITE_WS_PROGRESS_URL', async () => {
vi.stubEnv('VITE_WS_PROGRESS_URL', 'ws://custom/ws/progress');
const config = await import('./config');
expect(config.WS_PROGRESS_URL).toBe('ws://custom/ws/progress');
});
});

29
src/lib/config.ts Normal file
View File

@@ -0,0 +1,29 @@
const DEFAULT_API_BASE_URL = 'http://192.168.3.11:8000';
function trimTrailingSlash(value: string): string {
return value.replace(/\/+$/, '');
}
function inferApiBaseUrl(): string {
const envUrl = import.meta.env.VITE_API_BASE_URL;
if (envUrl) return trimTrailingSlash(envUrl);
if (typeof window !== 'undefined' && window.location.hostname) {
return `${window.location.protocol}//${window.location.hostname}:8000`;
}
return DEFAULT_API_BASE_URL;
}
export const API_BASE_URL = inferApiBaseUrl();
function inferWsProgressUrl(): string {
const envUrl = import.meta.env.VITE_WS_PROGRESS_URL;
if (envUrl) return envUrl;
const url = new URL('/ws/progress', API_BASE_URL);
url.protocol = url.protocol === 'https:' ? 'wss:' : 'ws:';
return url.toString();
}
export const WS_PROGRESS_URL = inferWsProgressUrl();

View File

@@ -0,0 +1,15 @@
import type { Template, TemplateClass } from '../store/useStore';
export function getActiveTemplate(templates: Template[], activeTemplateId: string | null): Template | null {
return templates.find((template) => template.id === activeTemplateId) || templates[0] || null;
}
export function getActiveClass(
templates: Template[],
activeTemplateId: string | null,
activeClassId: string | null,
): TemplateClass | null {
const template = getActiveTemplate(templates, activeTemplateId);
if (!template) return null;
return template.classes.find((templateClass) => templateClass.id === activeClassId) || null;
}

46
src/lib/websocket.test.ts Normal file
View File

@@ -0,0 +1,46 @@
import { afterEach, describe, expect, it, vi } from 'vitest';
describe('progress websocket client', () => {
afterEach(() => {
vi.restoreAllMocks();
vi.resetModules();
vi.unstubAllGlobals();
});
it('connects using the configured URL and reports open state', async () => {
const instances: any[] = [];
class FakeWebSocket {
static CONNECTING = 0;
static OPEN = 1;
readyState = FakeWebSocket.OPEN;
onopen?: () => void;
onmessage?: (event: MessageEvent) => void;
onclose?: () => void;
onerror?: () => void;
constructor(public url: string) {
instances.push(this);
}
close = vi.fn();
}
vi.stubGlobal('WebSocket', FakeWebSocket);
const { progressWS } = await import('./websocket');
progressWS.connect();
expect(instances[0].url).toContain('/ws/progress');
expect(progressWS.isConnected()).toBe(true);
});
it('subscribes and unsubscribes progress callbacks', async () => {
const { progressWS } = await import('./websocket');
const callback = vi.fn();
const unsubscribe = progressWS.onProgress(callback);
(progressWS as any).callbacks.forEach((cb: any) => cb({ type: 'status', message: 'ok' }));
unsubscribe();
(progressWS as any).callbacks.forEach((cb: any) => cb({ type: 'status', message: 'again' }));
expect(callback).toHaveBeenCalledTimes(1);
expect(callback).toHaveBeenCalledWith({ type: 'status', message: 'ok' });
});
});

View File

@@ -1,12 +1,18 @@
import { WS_PROGRESS_URL } from './config';
type ProgressCallback = (data: ProgressMessage) => void;
interface ProgressMessage {
type: 'progress' | 'status' | 'error' | 'complete';
taskId?: string;
task_id?: number;
project_id?: number;
projectName?: string;
filename?: string;
progress?: number;
status?: string;
message?: string;
error?: string;
timestamp?: string;
}
@@ -21,7 +27,7 @@ class ProgressWebSocket {
private shouldCloseAfterOpen = false;
private currentInterval = 3000;
constructor(url = 'ws://192.168.3.11:8000/ws/progress') {
constructor(url = WS_PROGRESS_URL) {
this.url = url;
}

View File

@@ -0,0 +1,56 @@
import { beforeEach, describe, expect, it } from 'vitest';
import { resetStore } from '../test/storeTestUtils';
import { useStore } from './useStore';
describe('useStore', () => {
beforeEach(() => {
resetStore();
});
it('stores and clears auth state with localStorage', () => {
useStore.getState().login('token-1');
expect(useStore.getState().isAuthenticated).toBe(true);
expect(useStore.getState().token).toBe('token-1');
expect(localStorage.getItem('token')).toBe('token-1');
useStore.getState().logout();
expect(useStore.getState().isAuthenticated).toBe(false);
expect(useStore.getState().projects).toEqual([]);
expect(useStore.getState().frames).toEqual([]);
expect(localStorage.getItem('token')).toBeNull();
});
it('manages projects, frames, masks, annotations and templates', () => {
const project = { id: '1', name: 'Project', status: 'ready' as const };
useStore.getState().addProject(project);
useStore.getState().updateProject({ ...project, name: 'Updated' });
useStore.getState().setCurrentProject(project);
useStore.getState().setFrames([{ id: 'f1', projectId: '1', index: 0, url: '/f1.jpg', width: 640, height: 360 }]);
useStore.getState().setCurrentFrame(0);
useStore.getState().addMask({ id: 'm1', frameId: 'f1', pathData: 'M 0 0 Z', label: 'mask', color: '#fff' });
useStore.getState().updateMask('m1', { label: 'updated mask', saveStatus: 'dirty' });
useStore.getState().addAnnotation({ id: 'a1', frameId: 'f1', type: 'mask', points: [], label: 'ann', color: '#fff' });
useStore.getState().addTemplate({ id: 't1', name: 'Template', classes: [], rules: [] });
useStore.getState().updateTemplate({ id: 't1', name: 'Template 2', classes: [], rules: [] });
useStore.getState().setActiveClass({ id: 'c1', name: '胆囊', color: '#ff0000', zIndex: 10 });
expect(useStore.getState().projects[0].name).toBe('Updated');
expect(useStore.getState().currentProject?.id).toBe('1');
expect(useStore.getState().frames).toHaveLength(1);
expect(useStore.getState().currentFrameIndex).toBe(0);
expect(useStore.getState().masks[0]).toEqual(expect.objectContaining({ label: 'updated mask', saveStatus: 'dirty' }));
expect(useStore.getState().annotations).toHaveLength(1);
expect(useStore.getState().templates[0].name).toBe('Template 2');
expect(useStore.getState().activeClassId).toBe('c1');
useStore.getState().removeAnnotation('a1');
useStore.getState().clearMasks();
useStore.getState().removeTemplate('t1');
expect(useStore.getState().annotations).toEqual([]);
expect(useStore.getState().masks).toEqual([]);
expect(useStore.getState().templates).toEqual([]);
});
});

View File

@@ -4,7 +4,7 @@ export interface Project {
id: string;
name: string;
description?: string;
status: 'Ready' | 'Parsing' | 'Error';
status: 'pending' | 'parsing' | 'ready' | 'error';
fps?: string;
frames?: number;
thumbnail?: string;
@@ -17,6 +17,8 @@ export interface Project {
updatedAt?: string;
}
export type AiModelId = 'sam2' | 'sam3';
export interface Frame {
id: string;
projectId: string;
@@ -42,6 +44,13 @@ export interface Annotation {
export interface Mask {
id: string;
frameId: string;
annotationId?: string;
templateId?: string;
classId?: string;
className?: string;
classZIndex?: number;
saveStatus?: 'draft' | 'saved' | 'dirty' | 'saving' | 'error';
saved?: boolean;
pathData: string;
label: string;
color: string;
@@ -96,24 +105,32 @@ export interface AppState {
// Workspace
activeModule: string;
activeTool: string;
aiModel: AiModelId;
frames: Frame[];
currentFrameIndex: number;
annotations: Annotation[];
masks: Mask[];
setActiveModule: (module: string) => void;
setActiveTool: (tool: string) => void;
setAiModel: (model: AiModelId) => void;
setFrames: (frames: Frame[]) => void;
setCurrentFrame: (index: number) => void;
addAnnotation: (annotation: Annotation) => void;
addMask: (mask: Mask) => void;
updateMask: (id: string, updates: Partial<Mask>) => void;
setMasks: (masks: Mask[]) => void;
clearMasks: () => void;
removeAnnotation: (id: string) => void;
// Templates
templates: Template[];
activeTemplateId: string | null;
activeClassId: string | null;
activeClass: TemplateClass | null;
setTemplates: (templates: Template[]) => void;
setActiveTemplateId: (id: string | null) => void;
setActiveClassId: (id: string | null) => void;
setActiveClass: (templateClass: TemplateClass | null) => void;
addTemplate: (template: Template) => void;
updateTemplate: (template: Template) => void;
removeTemplate: (id: string) => void;
@@ -144,6 +161,9 @@ export const useStore = create<AppState>((set) => ({
frames: [],
annotations: [],
masks: [],
activeTemplateId: null,
activeClassId: null,
activeClass: null,
});
},
@@ -162,18 +182,25 @@ export const useStore = create<AppState>((set) => ({
// Workspace
activeModule: 'workspace',
activeTool: 'move',
aiModel: 'sam2',
frames: [],
currentFrameIndex: 0,
annotations: [],
masks: [],
setActiveModule: (activeModule: string) => set({ activeModule }),
setActiveTool: (activeTool: string) => set({ activeTool }),
setAiModel: (aiModel: AiModelId) => set({ aiModel }),
setFrames: (frames: Frame[]) => set({ frames }),
setCurrentFrame: (currentFrameIndex: number) => set({ currentFrameIndex }),
addAnnotation: (annotation: Annotation) =>
set((state) => ({ annotations: [...state.annotations, annotation] })),
addMask: (mask: Mask) =>
set((state) => ({ masks: [...state.masks, mask] })),
updateMask: (id: string, updates: Partial<Mask>) =>
set((state) => ({
masks: state.masks.map((mask) => (mask.id === id ? { ...mask, ...updates } : mask)),
})),
setMasks: (masks: Mask[]) => set({ masks }),
clearMasks: () => set({ masks: [] }),
removeAnnotation: (id: string) =>
set((state) => ({
@@ -183,8 +210,15 @@ export const useStore = create<AppState>((set) => ({
// Templates
templates: [],
activeTemplateId: null,
activeClassId: null,
activeClass: null,
setTemplates: (templates: Template[]) => set({ templates }),
setActiveTemplateId: (activeTemplateId: string | null) => set({ activeTemplateId }),
setActiveClassId: (activeClassId: string | null) => set({ activeClassId }),
setActiveClass: (activeClass: TemplateClass | null) => set({
activeClass,
activeClassId: activeClass?.id || null,
}),
addTemplate: (template: Template) =>
set((state) => ({ templates: [...state.templates, template] })),
updateTemplate: (template: Template) =>

66
src/test/setup.tsx Normal file
View File

@@ -0,0 +1,66 @@
import React from 'react';
import { afterEach, vi } from 'vitest';
import { cleanup } from '@testing-library/react';
import '@testing-library/jest-dom/vitest';
afterEach(() => {
cleanup();
localStorage.clear();
});
vi.stubGlobal('alert', vi.fn());
vi.stubGlobal('confirm', vi.fn(() => true));
URL.createObjectURL = vi.fn(() => 'blob:mock-url');
URL.revokeObjectURL = vi.fn();
HTMLAnchorElement.prototype.click = vi.fn();
function makeStageEvent(x = 120, y = 80) {
const stage = {
getPointerPosition: () => ({ x, y }),
getRelativePointerPosition: () => ({ x, y }),
scaleX: () => 1,
x: () => 0,
y: () => 0,
};
return {
evt: { preventDefault: vi.fn(), deltaY: -1 },
target: {
getStage: () => stage,
},
};
}
vi.mock('react-konva', () => ({
Stage: ({ children, onClick, onMouseDown, onMouseUp, onMouseMove, onWheel }: any) => (
<div
data-testid="konva-stage"
onClick={() => onClick?.(makeStageEvent())}
onMouseDown={() => onMouseDown?.(makeStageEvent())}
onMouseUp={() => onMouseUp?.(makeStageEvent(260, 200))}
onMouseMove={() => onMouseMove?.(makeStageEvent(180, 120))}
onWheel={() => onWheel?.(makeStageEvent())}
>
{children}
</div>
),
Layer: ({ children }: any) => <div data-testid="konva-layer">{children}</div>,
Group: ({ children }: any) => <div data-testid="konva-group">{children}</div>,
Image: ({ image }: any) => <img data-testid="konva-image" alt="" src={image?.src || ''} />,
Circle: (props: any) => <span data-testid="konva-circle" data-fill={props.fill} />,
Rect: (props: any) => <span data-testid="konva-rect" data-width={props.width} />,
Path: (props: any) => <span data-testid="konva-path" data-path={props.data} data-fill={props.fill} />,
}));
vi.mock('use-image', () => ({
default: (src: string) => [
{
src,
width: 640,
height: 360,
naturalWidth: 640,
naturalHeight: 360,
},
'loaded',
],
}));

View File

@@ -0,0 +1,23 @@
import { useStore } from '../store/useStore';
export function resetStore() {
useStore.setState({
isAuthenticated: false,
token: null,
projects: [],
currentProject: null,
activeModule: 'workspace',
activeTool: 'move',
aiModel: 'sam2',
frames: [],
currentFrameIndex: 0,
annotations: [],
masks: [],
templates: [],
activeTemplateId: null,
activeClassId: null,
activeClass: null,
isLoading: false,
error: null,
});
}

6
src/vite-env.d.ts vendored Normal file
View File

@@ -0,0 +1,6 @@
/// <reference types="vite/client" />
interface ImportMetaEnv {
readonly VITE_API_BASE_URL?: string;
readonly VITE_WS_PROGRESS_URL?: string;
}