feat: 完善分割工作区导入导出与管理流程

- 新增基于 JWT 当前用户的登录恢复、角色权限、用户管理、审计日志和演示出厂重置后台接口与前端管理页。

- 重串 GT_label 导出和 GT Mask 导入逻辑:导出保留类别真实 maskid,导入仅接受灰度或 RGB 等通道 maskid 图,支持未知 maskid 策略、尺寸最近邻拉伸和导入预览。

- 统一分割结果导出体验:默认当前帧,按项目抽帧顺序和 XhXXmXXsXXXms 时间戳命名 ZIP 与图片,补齐 GT/Pro/Mix/分开 Mask 输出和映射 JSON。

- 调整工作区左侧工具栏:移除创建点/线段入口,新增画笔、橡皮擦及尺寸控制,并按绘制、布尔、导入/AI 工具分组分隔。

- 扩展 Canvas 编辑能力:画笔按语义分类绘制并可自动并入连通选中 mask,橡皮擦对选中区域扣除,优化布尔操作、选区、撤销重做和保存状态联动。

- 优化自动传播时间轴显示:同一蓝色系按传播新旧递进变暗,老传播记录达到阈值后统一旧记录色,并维护范围选择与清空后的历史显示。

- 将 AI 智能分割入口替换为更明确的 AI 元素图标,并同步侧栏、工作区和 AI 页面入口表现。

- 完善模板分类、maskid 工具函数、分类树联动、遮罩透明度、边缘平滑和传播链同步相关前端状态。

- 扩展后端项目、媒体、任务、Dashboard、模板和传播 runner 的用户隔离、任务控制、进度事件与兼容处理。

- 补充前后端测试,覆盖用户管理、GT_label 往返导入导出、GT Mask 校验和预览、画笔/橡皮擦、时间轴传播历史、导出范围、WebSocket 与 API 封装。

- 更新 AGENTS、README 和 doc 文档,记录当前接口契约、实现状态、测试计划、安装说明和 maskid/GT_label 规则。
This commit is contained in:
2026-05-03 03:52:32 +08:00
parent 4c1d3dba73
commit afcddfaeb9
62 changed files with 6572 additions and 849 deletions

View File

@@ -1,6 +1,6 @@
import React, { useEffect } from 'react';
import { useStore } from './store/useStore';
import { getProjects } from './lib/api';
import { getCurrentUser, getProjects } from './lib/api';
import { Sidebar } from './components/Sidebar';
import { Dashboard } from './components/Dashboard';
import { ProjectLibrary } from './components/ProjectLibrary';
@@ -8,8 +8,9 @@ import { VideoWorkspace } from './components/VideoWorkspace';
import { TemplateRegistry } from './components/TemplateRegistry';
import { AISegmentation } from './components/AISegmentation';
import { Login } from './components/Login';
import { UserAdmin } from './components/UserAdmin';
export type ActiveModule = 'dashboard' | 'projects' | 'ai' | 'workspace' | 'templates';
export type ActiveModule = 'dashboard' | 'projects' | 'ai' | 'workspace' | 'templates' | 'admin';
export default function App() {
const isAuthenticated = useStore((state) => state.isAuthenticated);
@@ -17,17 +18,27 @@ export default function App() {
const setActiveModule = useStore((state) => state.setActiveModule);
const setProjects = useStore((state) => state.setProjects);
const setError = useStore((state) => state.setError);
const setCurrentUser = useStore((state) => state.setCurrentUser);
const logout = useStore((state) => state.logout);
const currentUser = useStore((state) => state.currentUser);
useEffect(() => {
if (isAuthenticated) {
getProjects()
.then((data) => setProjects(data))
Promise.all([getCurrentUser(), getProjects()])
.then(([user, projects]) => {
setCurrentUser(user);
setProjects(projects);
})
.catch((err) => {
console.error('Failed to fetch projects:', err);
if (err?.response?.status === 401) {
logout();
return;
}
setError('获取项目列表失败');
});
}
}, [isAuthenticated, setProjects, setError]);
}, [isAuthenticated, logout, setCurrentUser, setProjects, setError]);
if (!isAuthenticated) {
return <Login />;
@@ -42,6 +53,7 @@ export default function App() {
{activeModule === 'ai' && <AISegmentation onSendToWorkspace={() => setActiveModule('workspace')} />}
{activeModule === 'workspace' && <VideoWorkspace onNavigateToAI={() => setActiveModule('ai')} />}
{activeModule === 'templates' && <TemplateRegistry />}
{activeModule === 'admin' && currentUser?.role === 'admin' && <UserAdmin />}
</main>
</div>
);

View File

@@ -356,13 +356,41 @@ describe('AISegmentation', () => {
await waitFor(() => expect(screen.getByTestId('konva-path')).toBeInTheDocument());
const maskGroup = () => screen.getAllByTestId('konva-group').find((group) => group.getAttribute('data-opacity'));
expect(maskGroup()).toHaveAttribute('data-opacity', '0.72');
fireEvent.change(screen.getByLabelText('遮罩清晰度'), { target: { value: '35' } });
expect(maskGroup()).toHaveAttribute('data-opacity', '0.5');
fireEvent.change(screen.getByLabelText('AI 遮罩透明度'), { target: { value: '35' } });
expect(maskGroup()).toHaveAttribute('data-opacity', '0.35');
expect(useStore.getState().maskPreviewOpacity).toBe(35);
expect(useStore.getState().masks[0].segmentation).toEqual([[10, 10, 40, 10, 40, 40]]);
});
it('updates AI candidate opacity when the shared ontology opacity slider changes', async () => {
apiMock.predictMask.mockResolvedValueOnce({
masks: [
{
id: 'sam2-mask',
pathData: 'M 10 10 L 40 10 L 40 40 Z',
label: 'AI Mask',
color: '#06b6d4',
segmentation: [[10, 10, 40, 10, 40, 40]],
bbox: [10, 10, 30, 30],
area: 900,
},
],
});
render(<AISegmentation onSendToWorkspace={vi.fn()} />);
fireEvent.click(screen.getByText('正向选点'));
fireEvent.click(screen.getByTestId('konva-stage'));
fireEvent.click(await screen.findByText('执行高精度语义分割'));
await waitFor(() => expect(screen.getByTestId('konva-path')).toBeInTheDocument());
const maskGroup = () => screen.getAllByTestId('konva-group').find((group) => group.getAttribute('data-opacity'));
fireEvent.change(screen.getByLabelText('遮罩透明度'), { target: { value: '80' } });
expect(maskGroup()).toHaveAttribute('data-opacity', '0.8');
});
it('lets positive and negative prompt points be added on top of an AI mask', async () => {
apiMock.predictMask
.mockResolvedValueOnce({
@@ -558,6 +586,11 @@ describe('AISegmentation', () => {
it('keeps the generated SAM2 mask selected when sending it to the workspace editor', async () => {
const onSendToWorkspace = vi.fn();
useStore.setState({
activeTemplateId: 'template-1',
activeClass: { id: 'class-1', name: '胆囊', color: '#ff0000', zIndex: 30 },
activeClassId: 'class-1',
});
apiMock.predictMask.mockResolvedValueOnce({
masks: [
{
@@ -585,4 +618,94 @@ describe('AISegmentation', () => {
expect(onSendToWorkspace).toHaveBeenCalled();
});
it('blocks sending an AI candidate to the workspace until a semantic class is selected', async () => {
const onSendToWorkspace = vi.fn();
apiMock.predictMask.mockResolvedValueOnce({
masks: [
{
id: 'sam2-mask',
pathData: 'M 10 10 L 40 10 L 40 40 Z',
label: 'AI Mask',
color: '#06b6d4',
segmentation: [[10, 10, 40, 10, 40, 40]],
bbox: [10, 10, 30, 30],
area: 900,
},
],
});
render(<AISegmentation onSendToWorkspace={onSendToWorkspace} />);
fireEvent.click(screen.getByText('正向选点'));
fireEvent.click(screen.getByTestId('konva-stage'));
fireEvent.click(await screen.findByText('执行高精度语义分割'));
await waitFor(() => expect(useStore.getState().selectedMaskIds).toEqual(['sam2-mask']));
fireEvent.click(screen.getByText('推送至工作区编辑'));
const toast = screen.getByRole('status');
expect(toast).toHaveTextContent('请先在右侧语义分类树为 AI 候选区域选择语义分类,再推送至工作区。');
expect(toast.className).toContain('bg-red-950');
expect(useStore.getState().activeTool).toBe('point_pos');
expect(onSendToWorkspace).not.toHaveBeenCalled();
});
it('removes unclassified AI candidates when leaving the AI page', async () => {
apiMock.predictMask.mockResolvedValueOnce({
masks: [
{
id: 'sam2-mask',
pathData: 'M 10 10 L 40 10 L 40 40 Z',
label: 'AI Mask',
color: '#06b6d4',
segmentation: [[10, 10, 40, 10, 40, 40]],
bbox: [10, 10, 30, 30],
area: 900,
},
],
});
const { unmount } = render(<AISegmentation onSendToWorkspace={vi.fn()} />);
fireEvent.click(screen.getByText('正向选点'));
fireEvent.click(screen.getByTestId('konva-stage'));
fireEvent.click(await screen.findByText('执行高精度语义分割'));
await waitFor(() => expect(useStore.getState().masks).toHaveLength(1));
unmount();
expect(useStore.getState().masks).toEqual([]);
expect(useStore.getState().selectedMaskIds).toEqual([]);
});
it('keeps classified AI candidates when leaving the AI page', async () => {
useStore.setState({
activeTemplateId: 'template-1',
activeClass: { id: 'class-1', name: '胆囊', color: '#ff0000', zIndex: 30 },
activeClassId: 'class-1',
});
apiMock.predictMask.mockResolvedValueOnce({
masks: [
{
id: 'sam2-mask',
pathData: 'M 10 10 L 40 10 L 40 40 Z',
label: 'AI Mask',
color: '#06b6d4',
segmentation: [[10, 10, 40, 10, 40, 40]],
bbox: [10, 10, 30, 30],
area: 900,
},
],
});
const { unmount } = render(<AISegmentation onSendToWorkspace={vi.fn()} />);
fireEvent.click(screen.getByText('正向选点'));
fireEvent.click(screen.getByTestId('konva-stage'));
fireEvent.click(await screen.findByText('执行高精度语义分割'));
await waitFor(() => expect(useStore.getState().masks[0]?.classId).toBe('class-1'));
unmount();
expect(useStore.getState().masks).toHaveLength(1);
expect(useStore.getState().selectedMaskIds).toEqual(['sam2-mask']);
});
});

View File

@@ -4,6 +4,7 @@ import { cn } from '../lib/utils';
import { Stage, Layer, Image as KonvaImage, Circle, Path, Group, Rect } from 'react-konva';
import useImage from 'use-image';
import { OntologyInspector } from './OntologyInspector';
import { TransientNotice, type NoticeState } from './TransientNotice';
import { SAM2_MODEL_OPTIONS, useStore, type Mask } from '../store/useStore';
import { getAiModelStatus, predictMask, type AiRuntimeStatus } from '../lib/api';
@@ -34,13 +35,15 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
const setAiModel = useStore((state) => state.setAiModel);
const activeTemplateId = useStore((state) => state.activeTemplateId);
const activeClass = useStore((state) => state.activeClass);
const maskPreviewOpacity = useStore((state) => state.maskPreviewOpacity);
const setMaskPreviewOpacity = useStore((state) => state.setMaskPreviewOpacity);
const [modelStatus, setModelStatus] = useState<AiRuntimeStatus | null>(null);
const [autoDeleteBg, setAutoDeleteBg] = useState(true);
const [cropMode, setCropMode] = useState(false);
const [maskOpacity, setMaskOpacity] = useState(72);
const [isInferencing, setIsInferencing] = useState(false);
const [inferenceMessage, setInferenceMessage] = useState('');
const [notice, setNotice] = useState<NoticeState | null>(null);
const [aiMaskIds, setAiMaskIds] = useState<string[]>([]);
// Canvas state
@@ -59,11 +62,32 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
const frameMasks = currentFrame
? masks.filter((mask) => mask.frameId === currentFrame.id && aiMaskIdSet.has(mask.id))
: masks.filter((mask) => aiMaskIdSet.has(mask.id));
const selectedAiMasks = frameMasks.filter((mask) => selectedMaskIds.includes(mask.id));
const aiMasksToSend = selectedAiMasks.length > 0 ? selectedAiMasks : frameMasks;
const selectedModelStatus = modelStatus?.models.find((model) => model.id === aiModel);
const modelCanInfer = selectedModelStatus?.available ?? true;
const effectiveTool = storeActiveTool;
useEffect(() => {
return () => {
if (aiMaskIds.length === 0) return;
const state = useStore.getState();
const aiIds = new Set(aiMaskIds);
const unclassifiedAiIds = new Set(
state.masks
.filter((mask) => aiIds.has(mask.id) && !mask.classId && !mask.className)
.map((mask) => mask.id),
);
if (unclassifiedAiIds.size === 0) return;
useStore.setState({
masks: state.masks.filter((mask) => !unclassifiedAiIds.has(mask.id)),
selectedMaskIds: state.selectedMaskIds.filter((id) => !unclassifiedAiIds.has(id)),
});
};
}, [aiMaskIds]);
useEffect(() => {
const handleResize = () => {
if (!canvasContainerRef.current) return;
@@ -266,6 +290,7 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
classId: activeClass?.id,
className: activeClass?.name,
classZIndex: activeClass?.zIndex,
classMaskId: activeClass?.maskId,
saveStatus: 'draft',
saved: false,
pathData: m.pathData,
@@ -329,6 +354,27 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
deleteAiMasksById(selectedMaskIds);
}, [deleteAiMasksById, selectedMaskIds]);
const handleSendToWorkspace = useCallback(() => {
if (aiMasksToSend.length === 0) {
setInferenceMessage('请先执行分割并选择一个 AI 候选区域。');
return;
}
const hasMissingSemantic = aiMasksToSend.some((mask) => !mask.classId && !mask.className);
if (hasMissingSemantic) {
setInferenceMessage('');
setNotice({
id: Date.now(),
message: '请先在右侧语义分类树为 AI 候选区域选择语义分类,再推送至工作区。',
tone: 'error',
});
return;
}
setInferenceMessage('');
setActiveTool('edit_polygon');
onSendToWorkspace();
}, [aiMasksToSend, onSendToWorkspace, setActiveTool]);
const removePromptPoint = useCallback((pointIndex: number) => {
setPoints((currentPoints) => currentPoints.filter((_, index) => index !== pointIndex));
}, []);
@@ -398,6 +444,7 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
return (
<div className="w-full h-full flex bg-[#0a0a0a]">
<TransientNotice notice={notice} onDismiss={() => setNotice(null)} />
{/* Left AI Controller Panel */}
<aside className="w-80 bg-[#0d0d0d] flex flex-col border-r border-white/5 shrink-0 z-10 overflow-hidden">
<div className="h-16 border-b border-white/5 flex items-center px-6 shrink-0 justify-between">
@@ -506,17 +553,18 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
<div className="space-y-2">
<div className="flex items-center justify-between">
<label htmlFor="ai-mask-opacity" className="text-[11px] text-gray-400 uppercase tracking-wider font-medium"></label>
<span className="text-[10px] font-mono text-cyan-400">{maskOpacity}%</span>
<label htmlFor="ai-mask-opacity" className="text-[11px] text-gray-400 uppercase tracking-wider font-medium">AI </label>
<span className="text-[10px] font-mono text-cyan-400">{maskPreviewOpacity}%</span>
</div>
<input
id="ai-mask-opacity"
aria-label="AI 遮罩透明度"
type="range"
min="20"
min="10"
max="100"
step="5"
value={maskOpacity}
onChange={(event) => setMaskOpacity(Number(event.target.value))}
value={maskPreviewOpacity}
onChange={(event) => setMaskPreviewOpacity(Number(event.target.value))}
className="w-full accent-cyan-400"
/>
</div>
@@ -544,10 +592,8 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
</div>
)}
<button
onClick={() => {
setActiveTool('edit_polygon');
onSendToWorkspace();
}}
onClick={handleSendToWorkspace}
title="AI 候选区域必须先选择语义分类,才能推送到工作区"
className="w-full py-3.5 rounded-lg flex items-center justify-center gap-2 transition-all font-medium tracking-wide text-xs uppercase bg-white/5 hover:bg-white/10 text-gray-300 border border-white/5 hover:border-white/10"
>
<SendToBack size={16} />
@@ -659,9 +705,10 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
{/* AI Returned Masks */}
{frameMasks.map((mask) => {
const isSelected = selectedMaskIds.includes(mask.id);
const baseOpacity = Math.min(Math.max(maskPreviewOpacity / 100, 0.1), 1);
const previewOpacity = isSelected
? maskOpacity / 100
: Math.max(0.18, (maskOpacity / 100) * 0.62);
? baseOpacity
: Math.max(0.12, baseOpacity * 0.62);
return (
<Group key={mask.id} opacity={previewOpacity}>
<Path

View File

@@ -0,0 +1,25 @@
import React from 'react';
import { Bot, Sparkles } from 'lucide-react';
interface AiSegmentationIconProps {
size?: number;
strokeWidth?: number;
}
export function AiSegmentationIcon({ size = 20, strokeWidth = 2 }: AiSegmentationIconProps) {
const sparkleSize = Math.max(9, Math.round(size * 0.48));
return (
<span
data-testid="ai-segmentation-icon"
className="relative inline-flex items-center justify-center"
style={{ width: size, height: size }}
>
<Bot size={size} strokeWidth={strokeWidth} />
<Sparkles
size={sparkleSize}
strokeWidth={Math.max(strokeWidth, 2.2)}
className="absolute -right-1 -top-1 text-cyan-300 drop-shadow-[0_0_4px_rgba(34,211,238,0.75)]"
/>
</span>
);
}

View File

@@ -934,69 +934,102 @@ describe('CanvasArea', () => {
expect(useStore.getState().masks[0].segmentation?.[0]).toHaveLength(64);
});
it('creates a manual line region from a drag gesture', () => {
render(<CanvasArea activeTool="create_line" frame={frame} />);
it('creates a brush mask when a semantic class is selected', () => {
useStore.setState({
activeTemplateId: '2',
activeClass: { id: 'c1', name: '胆囊', color: '#ff0000', zIndex: 20, maskId: 1 },
activeClassId: 'c1',
});
render(<CanvasArea activeTool="brush" frame={frame} />);
const stage = screen.getByTestId('konva-stage');
fireEvent.mouseDown(stage, { clientX: 120, clientY: 80 });
fireEvent.mouseMove(stage, { clientX: 260, clientY: 200 });
fireEvent.mouseMove(stage, { clientX: 180, clientY: 120 });
fireEvent.mouseUp(stage, { clientX: 260, clientY: 200 });
expect(useStore.getState().masks).toHaveLength(1);
expect(useStore.getState().masks[0]).toEqual(expect.objectContaining({
frameId: 'frame-1',
label: '手工线段',
color: '#06b6d4',
label: '胆囊',
color: '#ff0000',
classId: 'c1',
classMaskId: 1,
saveStatus: 'draft',
metadata: expect.objectContaining({
source: 'manual',
shape: '线段',
shape: '画笔',
}),
}));
expect(useStore.getState().masks[0].segmentation?.[0]).toHaveLength(8);
expect(useStore.getState().masks[0].segmentation?.length).toBeGreaterThan(0);
expect(useStore.getState().masks[0].area).toBeGreaterThan(1000);
});
it('creates an editable point region on click', () => {
render(<CanvasArea activeTool="create_point" frame={frame} />);
fireEvent.click(screen.getByTestId('konva-stage'), { clientX: 120, clientY: 80 });
expect(useStore.getState().masks).toHaveLength(1);
expect(useStore.getState().masks[0]).toEqual(expect.objectContaining({
frameId: 'frame-1',
label: '手工点区域',
color: '#06b6d4',
saveStatus: 'draft',
points: [[120, 80]],
bbox: expect.arrayContaining([115, 75]),
metadata: expect.objectContaining({
source: 'manual',
shape: '点区域',
}),
}));
});
it('creates a point region when clicking over an existing mask', () => {
it('merges a connected brush stroke into the selected mask', () => {
useStore.setState({
activeTemplateId: '2',
activeClass: { id: 'c1', name: '胆囊', color: '#ff0000', zIndex: 20 },
activeClassId: 'c1',
selectedMaskIds: ['m1'],
masks: [
{
id: 'm1',
frameId: 'frame-1',
pathData: 'M 10 10 L 200 10 L 200 200 Z',
label: 'Existing',
color: '#06b6d4',
segmentation: [[10, 10, 200, 10, 200, 200]],
pathData: 'M 100 70 L 150 70 L 150 120 L 100 120 Z',
label: '胆囊',
color: '#ff0000',
classId: 'c1',
segmentation: [[100, 70, 150, 70, 150, 120, 100, 120]],
area: 2500,
},
],
});
render(<CanvasArea activeTool="create_point" frame={frame} />);
fireEvent.click(screen.getByTestId('konva-path'), { clientX: 120, clientY: 80 });
render(<CanvasArea activeTool="brush" frame={frame} />);
const stage = screen.getByTestId('konva-stage');
fireEvent.mouseDown(stage, { clientX: 130, clientY: 90 });
fireEvent.mouseMove(stage, { clientX: 170, clientY: 100 });
fireEvent.mouseUp(stage, { clientX: 210, clientY: 110 });
expect(useStore.getState().masks).toHaveLength(2);
expect(useStore.getState().masks[1]).toEqual(expect.objectContaining({
metadata: expect.objectContaining({ shape: '点区域' }),
points: [[120, 80]],
expect(useStore.getState().masks).toHaveLength(1);
expect(useStore.getState().masks[0]).toEqual(expect.objectContaining({
id: 'm1',
label: '胆囊',
color: '#ff0000',
saveStatus: 'draft',
}));
expect(useStore.getState().masks[0].area).toBeGreaterThan(2500);
expect(useStore.getState().selectedMaskIds).toEqual(['m1']);
});
it('erases from the selected mask with a sampled stroke', () => {
useStore.setState({
selectedMaskIds: ['m1'],
masks: [
{
id: 'm1',
frameId: 'frame-1',
pathData: 'M 10 10 L 300 10 L 300 220 L 10 220 Z',
label: 'Existing',
color: '#06b6d4',
segmentation: [[10, 10, 300, 10, 300, 220, 10, 220]],
area: 60900,
},
],
});
render(<CanvasArea activeTool="eraser" frame={frame} />);
const stage = screen.getByTestId('konva-stage');
fireEvent.mouseDown(stage, { clientX: 120, clientY: 80 });
fireEvent.mouseMove(stage, { clientX: 180, clientY: 120 });
fireEvent.mouseUp(stage, { clientX: 260, clientY: 200 });
expect(useStore.getState().masks).toHaveLength(1);
expect(useStore.getState().masks[0]).toEqual(expect.objectContaining({
id: 'm1',
saveStatus: 'draft',
}));
expect(useStore.getState().masks[0].area).toBeLessThan(60900);
expect(useStore.getState().selectedMaskIds).toEqual(['m1']);
});
it('finalizes a clicked polygon with Enter', () => {
@@ -1082,10 +1115,10 @@ describe('CanvasArea', () => {
vi.useRealTimers();
});
it('applies the selected class to current-frame masks and marks saved masks dirty', () => {
it('applies the selected class to current-frame masks and linked propagation masks', () => {
useStore.setState({
activeTemplateId: '2',
activeClass: { id: 'c1', name: '胆囊', color: '#ff0000', zIndex: 20 },
activeClass: { id: 'c1', name: '胆囊', color: '#ff0000', zIndex: 20, maskId: 1 },
activeClassId: 'c1',
masks: [
{
@@ -1098,6 +1131,28 @@ describe('CanvasArea', () => {
saved: true,
saveStatus: 'saved',
},
{
id: 'm2',
frameId: 'frame-2',
annotationId: '100',
pathData: 'M 1 1 Z',
label: '旧传播标签',
color: '#06b6d4',
metadata: { source_annotation_id: 99, source_mask_id: 'annotation-99' },
saved: true,
saveStatus: 'saved',
},
{
id: 'm3',
frameId: 'frame-2',
annotationId: '101',
pathData: 'M 2 2 Z',
label: '无关区域',
color: '#ffffff',
metadata: { source_annotation_id: 101 },
saved: true,
saveStatus: 'saved',
},
],
});
@@ -1109,11 +1164,56 @@ describe('CanvasArea', () => {
classId: 'c1',
className: '胆囊',
classZIndex: 20,
classMaskId: 1,
label: '胆囊',
color: '#ff0000',
saveStatus: 'dirty',
saved: false,
}));
expect(useStore.getState().masks[1]).toEqual(expect.objectContaining({
classId: 'c1',
className: '胆囊',
classMaskId: 1,
label: '胆囊',
color: '#ff0000',
saveStatus: 'dirty',
saved: false,
}));
expect(useStore.getState().masks[2]).toEqual(expect.objectContaining({
label: '无关区域',
color: '#ffffff',
saveStatus: 'saved',
saved: true,
}));
});
it('renders unselected masks by semantic tree layer priority', () => {
useStore.setState({
selectedMaskIds: [],
masks: [
{
id: 'high',
frameId: 'frame-1',
pathData: 'M 0 0 Z',
label: '高优先级',
color: '#ef4444',
classZIndex: 30,
},
{
id: 'low',
frameId: 'frame-1',
pathData: 'M 1 1 Z',
label: '低优先级',
color: '#22c55e',
classZIndex: 10,
},
],
});
render(<CanvasArea activeTool="move" frame={frame} />);
const paths = screen.getAllByTestId('konva-path');
expect(paths.map((path) => path.getAttribute('data-fill'))).toEqual(['#22c55e', '#ef4444']);
});
it('delegates clear to the workspace handler so saved annotations can be deleted', () => {

View File

@@ -18,14 +18,18 @@ type PromptPoint = CanvasPoint & { type: 'pos' | 'neg' };
type PromptBox = { x1: number; y1: number; x2: number; y2: number };
type ToolHint = { title: string; body: string };
const DRAG_MANUAL_TOOLS = new Set(['create_rectangle', 'create_circle', 'create_line']);
const DRAG_MANUAL_TOOLS = new Set(['create_rectangle', 'create_circle']);
const POLYGON_TOOL = 'create_polygon';
const EDIT_POLYGON_TOOL = 'edit_polygon';
const POINT_TOOL = 'create_point';
const BRUSH_TOOL = 'brush';
const ERASER_TOOL = 'eraser';
const PAINT_TOOLS = new Set([BRUSH_TOOL, ERASER_TOOL]);
const BOOLEAN_TOOLS = new Set(['area_merge', 'area_remove']);
const POLYGON_CLOSE_RADIUS = 8;
const DEFAULT_IMAGE_FIT_RATIO = 0.86;
const TOOL_HINT_TTL_MS = 3600;
const PAINT_STAMP_SEGMENTS = 16;
const MAX_PAINT_STROKE_POINTS = 128;
function clamp(value: number, min: number, max: number): number {
return Math.min(Math.max(value, min), max);
@@ -97,6 +101,31 @@ function findLinkedMasksOnFrame(selectedIds: string[], allMasks: Mask[], targetF
.map((mask) => mask.id);
}
function findPropagationChainMaskIds(selectedIds: string[], allMasks: Mask[]): Set<string> {
const selectedMasks = selectedIds
.map((id) => allMasks.find((mask) => mask.id === id))
.filter((mask): mask is Mask => Boolean(mask));
const selectedTokens = new Set<string>();
selectedMasks.forEach((mask) => {
propagationLineageTokens(mask).forEach((token) => selectedTokens.add(token));
});
if (selectedTokens.size === 0) return new Set(selectedIds);
return new Set(
allMasks
.filter((mask) => {
const candidateTokens = propagationLineageTokens(mask);
return [...candidateTokens].some((token) => selectedTokens.has(token));
})
.map((mask) => mask.id),
);
}
function maskLayerPriority(mask: Mask): number {
const parsed = Number(mask.classZIndex ?? mask.metadata?.classZIndex ?? 0);
return Number.isFinite(parsed) ? parsed : 0;
}
function polygonPath(points: CanvasPoint[]): string {
if (points.length === 0) return '';
return points
@@ -165,6 +194,29 @@ function pointDistance(a: CanvasPoint, b: CanvasPoint): number {
return Math.hypot(a.x - b.x, a.y - b.y);
}
function extendStrokePoints(
current: CanvasPoint[],
nextPoint: CanvasPoint,
spacing: number,
maxPoints = MAX_PAINT_STROKE_POINTS,
): CanvasPoint[] {
const previous = current[current.length - 1];
if (!previous) return [nextPoint];
const distance = pointDistance(previous, nextPoint);
if (distance < spacing) return current;
const steps = Math.max(1, Math.floor(distance / spacing));
const additions: CanvasPoint[] = [];
for (let step = 1; step <= steps; step += 1) {
if (current.length + additions.length >= maxPoints) break;
const ratio = step / steps;
additions.push({
x: previous.x + (nextPoint.x - previous.x) * ratio,
y: previous.y + (nextPoint.y - previous.y) * ratio,
});
}
return [...current, ...additions];
}
function distanceToSegmentSquared(point: CanvasPoint, start: CanvasPoint, end: CanvasPoint): number {
const dx = end.x - start.x;
const dy = end.y - start.y;
@@ -218,6 +270,13 @@ function maskToMultiPolygon(mask: Mask): MultiPolygon | null {
return polygons.length > 0 ? polygons : null;
}
function polygonsToMultiPolygon(polygons: CanvasPoint[][]): MultiPolygon | null {
const geometry = polygons
.filter((points) => points.length >= 3)
.map((points) => [closeRing(points)]);
return geometry.length > 0 ? geometry : null;
}
function openRingPoints(ring: Pair[]): CanvasPoint[] {
const openRing = ring.length > 1
&& ring[0][0] === ring[ring.length - 1][0]
@@ -247,6 +306,27 @@ function multiPolygonHasHoles(geometry: MultiPolygon): boolean {
return geometry.some((polygon) => polygon.length > 1);
}
function maskWithSegmentation(
mask: Mask,
segmentation: number[][],
options: { area?: number; hasHoles?: boolean } = {},
): Mask {
const bbox = segmentationBbox(segmentation);
const metadata = { ...(mask.metadata || {}) };
if (options.hasHoles === true) metadata.hasHoles = true;
if (options.hasHoles === false) delete metadata.hasHoles;
return {
...mask,
pathData: segmentationPath(segmentation),
segmentation,
bbox,
area: options.area ?? segmentationArea(segmentation),
metadata,
saveStatus: mask.annotationId ? 'dirty' : 'draft',
saved: mask.annotationId ? false : mask.saved,
};
}
function rectanglePoints(start: CanvasPoint, end: CanvasPoint): CanvasPoint[] {
const x1 = Math.min(start.x, end.x);
const y1 = Math.min(start.y, end.y);
@@ -271,25 +351,26 @@ function circlePoints(start: CanvasPoint, end: CanvasPoint): CanvasPoint[] {
});
}
function pointRegion(point: CanvasPoint, radius = 5): CanvasPoint[] {
return Array.from({ length: 12 }, (_, index) => {
const angle = (Math.PI * 2 * index) / 12;
return { x: point.x + Math.cos(angle) * radius, y: point.y + Math.sin(angle) * radius };
function circleStampPoints(center: CanvasPoint, radius: number, segments = PAINT_STAMP_SEGMENTS): CanvasPoint[] {
return Array.from({ length: segments }, (_, index) => {
const angle = (Math.PI * 2 * index) / segments;
return { x: center.x + Math.cos(angle) * radius, y: center.y + Math.sin(angle) * radius };
});
}
function lineRegion(start: CanvasPoint, end: CanvasPoint, halfWidth = 4): CanvasPoint[] {
const dx = end.x - start.x;
const dy = end.y - start.y;
const length = Math.hypot(dx, dy) || 1;
const nx = (-dy / length) * halfWidth;
const ny = (dx / length) * halfWidth;
return [
{ x: start.x + nx, y: start.y + ny },
{ x: end.x + nx, y: end.y + ny },
{ x: end.x - nx, y: end.y - ny },
{ x: start.x - nx, y: start.y - ny },
];
function paintStrokeToGeometry(strokePoints: CanvasPoint[], radius: number): MultiPolygon | null {
const geometries = strokePoints
.map((point) => polygonsToMultiPolygon([circleStampPoints(point, radius)]))
.filter((geometry): geometry is MultiPolygon => Boolean(geometry));
if (geometries.length === 0) return null;
const [firstGeometry, ...restGeometries] = geometries;
return restGeometries.length === 0
? firstGeometry
: polygonClipping.union(firstGeometry, ...restGeometries);
}
function geometriesOverlap(first: MultiPolygon, second: MultiPolygon): boolean {
return polygonClipping.intersection(first, second).length > 0;
}
export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnotations }: CanvasAreaProps) {
@@ -305,6 +386,7 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
const [samCandidateMaskId, setSamCandidateMaskId] = useState<string | null>(null);
const [manualStart, setManualStart] = useState<CanvasPoint | null>(null);
const [manualCurrent, setManualCurrent] = useState<CanvasPoint | null>(null);
const [paintStrokePoints, setPaintStrokePointsState] = useState<CanvasPoint[]>([]);
const [polygonPoints, setPolygonPoints] = useState<CanvasPoint[]>([]);
const [selectedMaskId, setSelectedMaskId] = useState<string | null>(() => useStore.getState().selectedMaskIds[0] || null);
const [selectedMaskIds, setSelectedMaskIds] = useState<string[]>(() => useStore.getState().selectedMaskIds);
@@ -315,6 +397,9 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
const [inferenceMessage, setInferenceMessage] = useState('');
const [isToolHintVisible, setIsToolHintVisible] = useState(false);
const lastAutoFitKeyRef = useRef('');
const paintStrokeRef = useRef<CanvasPoint[]>([]);
const paintToolRef = useRef<string | null>(null);
const lastPaintPointRef = useRef<CanvasPoint | null>(null);
const masks = useStore((state) => state.masks);
const addMask = useStore((state) => state.addMask);
@@ -323,6 +408,8 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
const setMasks = useStore((state) => state.setMasks);
const setGlobalSelectedMaskIds = useStore((state) => state.setSelectedMaskIds);
const maskPreviewOpacity = useStore((state) => state.maskPreviewOpacity);
const brushSize = useStore((state) => state.brushSize);
const eraserSize = useStore((state) => state.eraserSize);
const storeActiveTool = useStore((state) => state.activeTool);
const aiModel = useStore((state) => state.aiModel);
const activeTemplateId = useStore((state) => state.activeTemplateId);
@@ -333,6 +420,16 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
// Load the actual frame image
const [image] = useImage(frame?.url || '');
const frameMasks = masks.filter((mask) => mask.frameId === frame?.id);
const displayFrameMasks = React.useMemo(() => {
if (selectedMaskIds.length > 0) return frameMasks;
return frameMasks
.map((mask, index) => ({ mask, index }))
.sort((a, b) => {
const priorityDiff = maskLayerPriority(a.mask) - maskLayerPriority(b.mask);
return priorityDiff === 0 ? a.index - b.index : priorityDiff;
})
.map((item) => item.mask);
}, [frameMasks, selectedMaskIds.length]);
const selectedMask = React.useMemo(
() => frameMasks.find((mask) => mask.id === selectedMaskId) || null,
[frameMasks, selectedMaskId],
@@ -351,7 +448,14 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
const draftMaskCount = frameMasks.filter((mask) => !mask.annotationId).length;
const dirtyMaskCount = frameMasks.filter((mask) => mask.saveStatus === 'dirty').length;
const isBooleanTool = BOOLEAN_TOOLS.has(effectiveTool);
const isPaintTool = PAINT_TOOLS.has(effectiveTool);
const isPolygonEditTool = effectiveTool === 'move' || effectiveTool === EDIT_POLYGON_TOOL;
const activePaintSize = effectiveTool === ERASER_TOOL ? eraserSize : brushSize;
const activePaintRadius = Math.max(2, activePaintSize / 2);
const setPaintStrokePoints = useCallback((nextPoints: CanvasPoint[]) => {
paintStrokeRef.current = nextPoints;
setPaintStrokePointsState(nextPoints);
}, []);
const currentLayerLabel = selectedMask
? `${selectedMask.className || selectedMask.label}${selectedMask.annotationId ? ` #${selectedMask.annotationId}` : ' (未保存)'}`
: '未选择';
@@ -381,11 +485,21 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
if (effectiveTool === 'create_circle') {
return { title: '创建圆形', body: '按住并拖拽确定外接范围,松开鼠标后生成椭圆 mask。' };
}
if (effectiveTool === 'create_line') {
return { title: '创建线段', body: '按住并拖拽画出线段,松开后生成有宽度的线状 mask。' };
if (effectiveTool === BRUSH_TOOL) {
return {
title: '画笔',
body: activeClass
? '按住并拖动画出连续区域;若与当前选中 mask 连通,会自动合并到该 mask。'
: '先在右侧语义分类树选择类别,然后按住并拖动画出连续区域。',
};
}
if (effectiveTool === POINT_TOOL) {
return { title: '创建点区域', body: '点击画布创建一个小型点区域;也可以在已有 mask 上继续落点。' };
if (effectiveTool === ERASER_TOOL) {
return {
title: '橡皮擦',
body: selectedMask
? '按住并拖动,从当前选中 mask 中扣除经过的区域。'
: '先选择一个 mask然后按住并拖动擦除区域。',
};
}
if (effectiveTool === 'box_select') {
return {
@@ -426,7 +540,7 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
};
}
return null;
}, [booleanSelectedMasks.length, effectiveTool, frame, polygonPoints.length, samPromptBox, selectedMask]);
}, [activeClass, booleanSelectedMasks.length, effectiveTool, frame, polygonPoints.length, samPromptBox, selectedMask]);
useEffect(() => {
if (!toolHint) {
@@ -479,14 +593,17 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
useEffect(() => {
setManualStart(null);
setManualCurrent(null);
setPaintStrokePoints([]);
paintToolRef.current = null;
lastPaintPointRef.current = null;
setPolygonPoints([]);
setSelectedVertexIndex(null);
if (!isPolygonEditTool && !isBooleanTool) {
if (!isPolygonEditTool && !isBooleanTool && !isPaintTool) {
setSelectedMaskId(null);
setSelectedMaskIds([]);
setSelectedPolygonIndex(0);
}
}, [effectiveTool, isBooleanTool, isPolygonEditTool]);
}, [effectiveTool, isBooleanTool, isPaintTool, isPolygonEditTool, setPaintStrokePoints]);
useEffect(() => {
if (previousFrameIdRef.current === frame?.id) return;
@@ -617,18 +734,13 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
classId: activeClass?.id,
className: activeClass?.name,
classZIndex: activeClass?.zIndex,
classMaskId: activeClass?.maskId,
saveStatus: 'draft',
saved: false,
pathData: polygonPath(polygon),
label,
color,
segmentation: polygonSegmentation(polygon),
points: shape === '点区域'
? [[
polygon.reduce((sum, point) => sum + point.x, 0) / polygon.length,
polygon.reduce((sum, point) => sum + point.y, 0) / polygon.length,
]]
: undefined,
bbox: polygonBbox(polygon),
area,
metadata: { source: 'manual', shape },
@@ -636,6 +748,38 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
addMask(mask);
}, [activeClass, activeTemplateId, addMask, frame?.id]);
const createManualMaskFromGeometry = useCallback((shape: string, geometry: MultiPolygon): Mask | null => {
if (!frame?.id || !activeClass) return null;
const segmentation = multiPolygonToSegmentation(geometry);
if (segmentation.length === 0) return null;
const area = multiPolygonArea(geometry);
if (area <= 1) return null;
const mask: Mask = {
id: `manual-${frame.id}-${shape}-${Date.now()}`,
frameId: frame.id,
templateId: activeTemplateId || undefined,
classId: activeClass.id,
className: activeClass.name,
classZIndex: activeClass.zIndex,
classMaskId: activeClass.maskId,
saveStatus: 'draft',
saved: false,
pathData: segmentationPath(segmentation),
label: activeClass.name,
color: activeClass.color,
segmentation,
bbox: segmentationBbox(segmentation),
area,
metadata: {
source: 'manual',
shape,
...(multiPolygonHasHoles(geometry) ? { hasHoles: true } : {}),
},
};
addMask(mask);
return mask;
}, [activeClass, activeTemplateId, addMask, frame?.id]);
const finishPolygon = useCallback(() => {
if (polygonPoints.length < 3) return;
createManualMask('多边形', polygonPoints);
@@ -665,6 +809,20 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
setManualCurrent({ x: pos.x, y: pos.y });
}
}
if (paintToolRef.current && PAINT_TOOLS.has(effectiveTool)) {
const pos = stagePoint(e);
const previous = lastPaintPointRef.current;
if (!pos || !previous) return;
const radius = Math.max(2, (paintToolRef.current === ERASER_TOOL ? eraserSize : brushSize) / 2);
const minDistance = Math.max(3, radius * 0.55);
if (pointDistance(previous, pos) < minDistance) return;
const currentStroke = paintStrokeRef.current;
if (currentStroke.length >= MAX_PAINT_STROKE_POINTS) return;
const nextStroke = extendStrokePoints(currentStroke, pos, minDistance);
lastPaintPointRef.current = nextStroke[nextStroke.length - 1] || pos;
setPaintStrokePoints(nextStroke);
}
};
const runInference = useCallback(async (
@@ -721,6 +879,7 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
classId: activeClass?.id || existingCandidate?.classId,
className: activeClass?.name || existingCandidate?.className,
classZIndex: activeClass?.zIndex ?? existingCandidate?.classZIndex,
classMaskId: activeClass?.maskId ?? existingCandidate?.classMaskId,
saveStatus: existingCandidate?.annotationId ? 'dirty' as const : 'draft' as const,
saved: false,
pathData: m.pathData,
@@ -768,14 +927,19 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
const handleApplyActiveClass = () => {
if (!frame?.id || !activeClass) return;
const seedIds = selectedMaskIds.length > 0
? selectedMaskIds
: frameMasks.map((mask) => mask.id);
const targetIds = findPropagationChainMaskIds(seedIds, masks);
setMasks(masks.map((mask) => {
if (mask.frameId !== frame.id) return mask;
if (!targetIds.has(mask.id)) return mask;
return {
...mask,
templateId: activeTemplateId || mask.templateId,
classId: activeClass.id,
className: activeClass.name,
classZIndex: activeClass.zIndex,
classMaskId: activeClass.maskId,
label: activeClass.name,
color: activeClass.color,
saveStatus: mask.annotationId ? 'dirty' : 'draft',
@@ -815,7 +979,102 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
setSelectedVertexIndex(null);
}, [masks, onDeleteMaskAnnotations, samCandidateMaskId, setMasks]);
const applyPaintStroke = useCallback((tool: string | null, strokePoints: CanvasPoint[]) => {
if (!frame?.id || strokePoints.length === 0) return;
const radius = Math.max(2, (tool === ERASER_TOOL ? eraserSize : brushSize) / 2);
const strokeGeometry = paintStrokeToGeometry(strokePoints, radius);
if (!strokeGeometry) return;
if (tool === BRUSH_TOOL) {
if (!activeClass) {
setInferenceMessage('请先在右侧语义分类树选择类别,再使用画笔。');
return;
}
const targetGeometry = selectedMask ? maskToMultiPolygon(selectedMask) : null;
const shouldMerge = Boolean(targetGeometry && geometriesOverlap(targetGeometry, strokeGeometry));
if (selectedMask && targetGeometry && shouldMerge) {
const resultGeometry = polygonClipping.union(targetGeometry, strokeGeometry);
const resultSegmentation = multiPolygonToSegmentation(resultGeometry);
if (resultSegmentation.length === 0) return;
const nextMask = {
...maskWithSegmentation(selectedMask, resultSegmentation, {
area: multiPolygonArea(resultGeometry),
hasHoles: multiPolygonHasHoles(resultGeometry),
}),
templateId: activeTemplateId || selectedMask.templateId,
classId: activeClass.id,
className: activeClass.name,
classZIndex: activeClass.zIndex,
classMaskId: activeClass.maskId,
label: activeClass.name,
color: activeClass.color,
};
setMasks(masks.map((mask) => (mask.id === selectedMask.id ? nextMask : mask)));
setSelectedMaskId(selectedMask.id);
setSelectedMaskIds([selectedMask.id]);
setSelectedVertexIndex(null);
return;
}
const nextMask = createManualMaskFromGeometry('画笔', strokeGeometry);
if (nextMask) {
setSelectedMaskId(nextMask.id);
setSelectedMaskIds([nextMask.id]);
setSelectedPolygonIndex(0);
setSelectedVertexIndex(null);
}
return;
}
if (tool === ERASER_TOOL) {
if (!selectedMask) {
setInferenceMessage('请先选择一个 mask再使用橡皮擦。');
return;
}
const targetGeometry = maskToMultiPolygon(selectedMask);
if (!targetGeometry) return;
const resultGeometry = polygonClipping.difference(targetGeometry, strokeGeometry);
const resultSegmentation = multiPolygonToSegmentation(resultGeometry);
if (resultSegmentation.length === 0) {
deleteMasksById([selectedMask.id]);
return;
}
const nextMask = maskWithSegmentation(selectedMask, resultSegmentation, {
area: multiPolygonArea(resultGeometry),
hasHoles: multiPolygonHasHoles(resultGeometry),
});
setMasks(masks.map((mask) => (mask.id === selectedMask.id ? nextMask : mask)));
setSelectedMaskId(selectedMask.id);
setSelectedMaskIds([selectedMask.id]);
setSelectedVertexIndex(null);
}
}, [
activeClass,
activeTemplateId,
brushSize,
createManualMaskFromGeometry,
deleteMasksById,
eraserSize,
frame?.id,
masks,
selectedMask,
setMasks,
]);
const handleStageMouseDown = (e: any) => {
if (PAINT_TOOLS.has(effectiveTool)) {
const canStart = effectiveTool === BRUSH_TOOL ? Boolean(activeClass) : Boolean(selectedMask);
if (!canStart) return;
const pos = stagePoint(e);
if (pos) {
paintToolRef.current = effectiveTool;
lastPaintPointRef.current = pos;
setPaintStrokePoints([pos]);
}
return;
}
if (DRAG_MANUAL_TOOLS.has(effectiveTool)) {
const pos = stagePoint(e);
if (pos) {
@@ -836,11 +1095,28 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
};
const handleStageMouseUp = (e: any) => {
if (paintToolRef.current && PAINT_TOOLS.has(effectiveTool)) {
const finalPoint = stagePoint(e);
const currentStroke = paintStrokeRef.current;
const spacing = Math.max(3, activePaintRadius * 0.55);
const nextStroke = finalPoint
&& currentStroke.length > 0
&& pointDistance(currentStroke[currentStroke.length - 1], finalPoint) >= spacing
&& currentStroke.length < MAX_PAINT_STROKE_POINTS
? extendStrokePoints(currentStroke, finalPoint, spacing)
: currentStroke;
const tool = paintToolRef.current;
setPaintStrokePoints([]);
paintToolRef.current = null;
lastPaintPointRef.current = null;
applyPaintStroke(tool, nextStroke);
return;
}
if (DRAG_MANUAL_TOOLS.has(effectiveTool) && manualStart) {
const end = stagePoint(e) || manualCurrent || manualStart;
const width = Math.abs(end.x - manualStart.x);
const height = Math.abs(end.y - manualStart.y);
const distance = Math.hypot(width, height);
if (effectiveTool === 'create_rectangle' && width > 4 && height > 4) {
createManualMask('矩形', rectanglePoints(manualStart, end));
@@ -848,9 +1124,6 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
if (effectiveTool === 'create_circle' && width > 4 && height > 4) {
createManualMask('圆形', circlePoints(manualStart, end));
}
if (effectiveTool === 'create_line' && distance > 4) {
createManualMask('线段', lineRegion(manualStart, end));
}
setManualStart(null);
setManualCurrent(null);
@@ -880,14 +1153,7 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
if (isPolygonEditTool) return;
if (effectiveTool === 'box_select') return; // handled by mouseup
if (DRAG_MANUAL_TOOLS.has(effectiveTool)) return;
if (effectiveTool === POINT_TOOL) {
const pos = stagePoint(e);
if (pos) {
createManualMask('点区域', pointRegion(pos));
}
return;
}
if (PAINT_TOOLS.has(effectiveTool)) return;
if (effectiveTool === POLYGON_TOOL) {
const pos = stagePoint(e);
@@ -955,20 +1221,7 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
segmentation: number[][],
options: { area?: number; hasHoles?: boolean } = {},
): Mask => {
const bbox = segmentationBbox(segmentation);
const metadata = { ...(mask.metadata || {}) };
if (options.hasHoles === true) metadata.hasHoles = true;
if (options.hasHoles === false) delete metadata.hasHoles;
return {
...mask,
pathData: segmentationPath(segmentation),
segmentation,
bbox,
area: options.area ?? segmentationArea(segmentation),
metadata,
saveStatus: mask.annotationId ? 'dirty' : 'draft',
saved: mask.annotationId ? false : mask.saved,
};
return maskWithSegmentation(mask, segmentation, options);
}, []);
useEffect(() => {
@@ -1017,7 +1270,6 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
if (manualStart && manualCurrent) {
if (effectiveTool === 'create_rectangle') return polygonPath(rectanglePoints(manualStart, manualCurrent));
if (effectiveTool === 'create_circle') return polygonPath(circlePoints(manualStart, manualCurrent));
if (effectiveTool === 'create_line') return polygonPath(lineRegion(manualStart, manualCurrent));
}
if (effectiveTool === POLYGON_TOOL && polygonPoints.length > 0) {
const previewPoints = [...polygonPoints, cursorPos];
@@ -1217,7 +1469,7 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
)}
{/* AI Returned Masks */}
{frameMasks.map((mask) => {
{displayFrameMasks.map((mask) => {
const selectedIndex = selectedMaskIds.indexOf(mask.id);
const isMaskSelected = selectedIndex >= 0;
const isBooleanPrimary = isBooleanTool && selectedIndex === 0;
@@ -1282,6 +1534,34 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
/>
)}
{paintStrokePoints.length > 0 && (
<Group opacity={effectiveTool === ERASER_TOOL ? 0.28 : 0.22}>
{paintStrokePoints.map((point, index) => (
<Circle
key={`paint-stroke-${index}`}
x={point.x}
y={point.y}
radius={activePaintRadius}
fill={effectiveTool === ERASER_TOOL ? '#ef4444' : activeClass?.color || '#22d3ee'}
stroke={effectiveTool === ERASER_TOOL ? '#fecaca' : '#ffffff'}
strokeWidth={1 / scale}
/>
))}
</Group>
)}
{isPaintTool && (effectiveTool === BRUSH_TOOL ? activeClass : selectedMask) && paintStrokePoints.length === 0 && (
<Circle
x={cursorPos.x}
y={cursorPos.y}
radius={activePaintRadius}
fill="rgba(255,255,255,0.02)"
stroke={effectiveTool === ERASER_TOOL ? '#f87171' : activeClass?.color || '#22d3ee'}
strokeWidth={1.5 / scale}
dash={[4 / scale, 4 / scale]}
/>
)}
{polygonPoints.map((point, index) => (
<Circle
key={`poly-point-${index}`}

View File

@@ -188,6 +188,7 @@ export function Dashboard() {
return () => {
mounted = false;
clearTimeout(timer);
unsubscribe();
unsubscribeStatus();
clearInterval(checkConnection);

View File

@@ -78,7 +78,8 @@ describe('FrameTimeline', () => {
expect(screen.getByLabelText('视频处理进度条')).toBeInTheDocument();
expect(screen.getByText('人工/AI 1 帧 · 自动传播 1 帧')).toBeInTheDocument();
expect(screen.queryByTestId('current-frame-line')).not.toBeInTheDocument();
expect(screen.getByTestId('current-frame-line')).toHaveStyle({ left: '50%' });
expect(screen.getByTestId('current-frame-line').className).toContain('bg-white');
expect(screen.getAllByTestId('propagated-frame-segment')).toHaveLength(1);
expect(screen.getByTestId('propagated-frame-segment').className).toContain('bg-blue-500');
expect(screen.getAllByTestId('annotated-frame-marker')).toHaveLength(1);
@@ -86,31 +87,50 @@ describe('FrameTimeline', () => {
expect(screen.queryByLabelText('跳转到已编辑帧 3')).not.toBeInTheDocument();
});
it('renders recent propagation history segments with distinct gradient colors', () => {
it('renders propagation history with newest bright and old records capped to one blue threshold', () => {
useStore.setState({
frames: [
{ id: 'f1', projectId: 'p1', index: 0, url: '/1.jpg', width: 640, height: 360 },
{ id: 'f2', projectId: 'p1', index: 1, url: '/2.jpg', width: 640, height: 360 },
{ id: 'f3', projectId: 'p1', index: 2, url: '/3.jpg', width: 640, height: 360 },
{ id: 'f4', projectId: 'p1', index: 3, url: '/4.jpg', width: 640, height: 360 },
{ id: 'f5', projectId: 'p1', index: 4, url: '/5.jpg', width: 640, height: 360 },
{ id: 'f6', projectId: 'p1', index: 5, url: '/6.jpg', width: 640, height: 360 },
{ id: 'f7', projectId: 'p1', index: 6, url: '/7.jpg', width: 640, height: 360 },
],
});
render(
<FrameTimeline
propagationHistory={[
{ id: 'history-1', startFrame: 1, endFrame: 2, colorIndex: 0, label: '第一次传播' },
{ id: 'history-2', startFrame: 3, endFrame: 4, colorIndex: 1, label: '第二次传播' },
{ id: 'history-1', startFrame: 1, endFrame: 1, colorIndex: 0, label: '第一次传播' },
{ id: 'history-2', startFrame: 2, endFrame: 2, colorIndex: 1, label: '第二次传播' },
{ id: 'history-3', startFrame: 3, endFrame: 3, colorIndex: 2, label: '第三次传播' },
{ id: 'history-4', startFrame: 4, endFrame: 4, colorIndex: 3, label: '第四次传播' },
{ id: 'history-5', startFrame: 5, endFrame: 5, colorIndex: 4, label: '第五次传播' },
{ id: 'history-6', startFrame: 6, endFrame: 6, colorIndex: 5, label: '第六次传播' },
{ id: 'history-7', startFrame: 7, endFrame: 7, colorIndex: 6, label: '第七次传播' },
]}
/>,
);
const segments = screen.getAllByTestId('propagation-history-segment');
expect(segments).toHaveLength(2);
expect(segments).toHaveLength(7);
expect(segments[0]).toHaveAttribute('title', '第一次传播');
expect(segments[0]).toHaveStyle({ left: '0%', width: '50%' });
expect(segments[0].getAttribute('style')).toContain('linear-gradient');
expect(segments[1].getAttribute('style')).toContain('124, 58, 237');
expect(segments[0]).toHaveStyle({ left: '0%' });
expect(segments[0]).toHaveAttribute('data-recency-level', '4');
expect(segments[1]).toHaveAttribute('data-recency-level', '4');
expect(segments[2]).toHaveAttribute('data-recency-level', '4');
expect(segments[3]).toHaveAttribute('data-recency-level', '3');
expect(segments[4]).toHaveAttribute('data-recency-level', '2');
expect(segments[5]).toHaveAttribute('data-recency-level', '1');
expect(segments[6]).toHaveAttribute('data-recency-level', '0');
const oldestStyle = segments[0].getAttribute('style') || '';
const newestStyle = segments[6].getAttribute('style') || '';
expect(oldestStyle).not.toContain('linear-gradient');
expect(newestStyle).not.toContain('linear-gradient');
expect(segments[0].style.backgroundColor).toBe(segments[1].style.backgroundColor);
expect(segments[6].style.backgroundColor).not.toBe(segments[0].style.backgroundColor);
});
it('jumps from the processing progress bar and frame status markers', () => {
@@ -180,6 +200,7 @@ describe('FrameTimeline', () => {
{ id: 'f2', projectId: 'p1', index: 1, url: '/2.jpg', width: 640, height: 360 },
{ id: 'f3', projectId: 'p1', index: 2, url: '/3.jpg', width: 640, height: 360 },
{ id: 'f4', projectId: 'p1', index: 3, url: '/4.jpg', width: 640, height: 360 },
{ id: 'f5', projectId: 'p1', index: 4, url: '/5.jpg', width: 640, height: 360 },
],
masks: [
{ id: 'm1', frameId: 'f2', pathData: 'M 0 0 Z', label: 'Draft', color: '#ef4444' },
@@ -200,6 +221,14 @@ describe('FrameTimeline', () => {
color: '#3b82f6',
metadata: { source: 'sam2.1_hiera_tiny_propagation' },
},
{
id: 'm5',
frameId: 'f5',
pathData: 'M 3 3 Z',
label: 'Tracked after smoothing',
color: '#3b82f6',
metadata: { source_annotation_id: 7, source_mask_id: 'annotation-7' },
},
],
});
@@ -211,6 +240,8 @@ describe('FrameTimeline', () => {
const manuallyAdjustedPropagatedTile = screen.getByAltText('frame-3').closest('div');
expect(manuallyAdjustedPropagatedTile?.className).toContain('border-red-500');
expect(manuallyAdjustedPropagatedTile?.className).toContain('inset_0_0_0_2px_rgba(59,130,246,0.85)');
expect(screen.getByAltText('frame-4').closest('div')?.className).toContain('border-blue-500');
expect(screen.getByAltText('frame-4').closest('div')?.className).not.toContain('border-red-500');
});
it('keeps the current frame blue border while showing an inner red ring for annotated frames', () => {
@@ -278,6 +309,12 @@ describe('FrameTimeline', () => {
expect(onPropagationRangeChange).toHaveBeenLastCalledWith(2, 4);
expect(screen.getAllByTestId('propagation-range-overlay')).toHaveLength(2);
const boundaryLines = screen.getAllByTestId('range-boundary-line');
expect(boundaryLines).toHaveLength(2);
expect(boundaryLines[0]).toHaveStyle({ left: '25%' });
expect(boundaryLines[0].className).toContain('bg-fuchsia-400');
expect(boundaryLines[1]).toHaveStyle({ left: '75%' });
expect(boundaryLines[1].className).toContain('bg-lime-300');
});
it('changes frames with left and right arrow keys without leaving bounds', () => {

View File

@@ -49,7 +49,11 @@ export function FrameTimeline({
const totalSeconds = totalFrames > 0 ? Math.max(totalFrames - 1, 0) / timeBaseFps : 0;
const isPropagatedMask = (mask: (typeof masks)[number]) => {
const source = typeof mask.metadata?.source === 'string' ? mask.metadata.source : '';
return source.includes('_propagation') || mask.metadata?.propagated_from_frame_id !== undefined;
return source.includes('_propagation')
|| mask.metadata?.propagated_from_frame_id !== undefined
|| mask.metadata?.source_annotation_id !== undefined
|| mask.metadata?.source_mask_id !== undefined
|| mask.metadata?.propagation_seed_key !== undefined;
};
const propagatedFrameMarkers = useMemo(() => {
const frameIds = new Set(frames.map((frame) => frame.id));
@@ -105,18 +109,29 @@ export function FrameTimeline({
const rangeWidth = visibleSelectedRange && totalFrames > 0
? ((visibleSelectedRange.endFrame - visibleSelectedRange.startFrame + 1) / totalFrames) * 100
: 0;
const propagationHistoryColors = [
{ dark: 'rgba(8, 145, 178, 0.68)', light: 'rgba(103, 232, 249, 0.9)', glow: 'rgba(34, 211, 238, 0.38)' },
{ dark: 'rgba(124, 58, 237, 0.66)', light: 'rgba(196, 181, 253, 0.9)', glow: 'rgba(167, 139, 250, 0.34)' },
{ dark: 'rgba(5, 150, 105, 0.66)', light: 'rgba(110, 231, 183, 0.9)', glow: 'rgba(52, 211, 153, 0.34)' },
{ dark: 'rgba(217, 119, 6, 0.66)', light: 'rgba(253, 186, 116, 0.9)', glow: 'rgba(251, 146, 60, 0.34)' },
{ dark: 'rgba(219, 39, 119, 0.66)', light: 'rgba(251, 113, 133, 0.9)', glow: 'rgba(244, 114, 182, 0.34)' },
];
const frameLineLeft = (frame: number) => {
if (totalFrames <= 1) return 0;
return ((clampFrame(frame) - 1) / (totalFrames - 1)) * 100;
};
const currentFrameLineLeft = totalFrames > 0 ? frameLineLeft(currentFrame) : 0;
const rangeStartLineLeft = visibleSelectedRange ? frameLineLeft(visibleSelectedRange.startFrame) : 0;
const rangeEndLineLeft = visibleSelectedRange ? frameLineLeft(visibleSelectedRange.endFrame) : 0;
const propagationHistoryColor = (ageFromNewest: number) => {
const step = Math.min(Math.max(ageFromNewest, 0), 4);
const lightness = 58 - step * 7;
const alpha = 0.88 - step * 0.085;
return {
fill: `hsla(212, 88%, ${lightness}%, ${Math.max(alpha, 0.52)})`,
glow: `hsla(212, 88%, ${Math.min(lightness + 10, 76)}%, ${0.38 - step * 0.045})`,
border: `hsla(212, 90%, ${Math.min(lightness + 18, 84)}%, ${0.72 - step * 0.045})`,
};
};
const visiblePropagationHistory = useMemo(() => (
propagationHistory
.map((segment, order) => {
const range = normalizeRange(segment.startFrame, segment.endFrame);
return { ...segment, ...range, order };
const ageFromNewest = Math.min(Math.max(propagationHistory.length - 1 - order, 0), 4);
return { ...segment, ...range, order, ageFromNewest };
})
.filter((segment) => totalFrames > 0 && segment.endFrame >= 1 && segment.startFrame <= totalFrames)
), [propagationHistory, totalFrames]);
@@ -282,6 +297,32 @@ export function FrameTimeline({
{formatTime(currentSeconds)}
</div>
</div>
{totalFrames > 0 && (
<div
data-testid="current-frame-line"
aria-hidden="true"
className="pointer-events-none absolute top-[18px] bottom-[8px] z-[60] w-[2px] -translate-x-1/2 rounded-full bg-white shadow-[0_0_10px_rgba(255,255,255,0.85)]"
style={{ left: `${currentFrameLineLeft}%` }}
/>
)}
{visibleSelectedRange && (
<>
<div
data-testid="range-boundary-line"
aria-hidden="true"
title={`范围开始帧 ${visibleSelectedRange.startFrame}`}
className="pointer-events-none absolute top-[16px] bottom-[7px] z-[65] w-[2px] -translate-x-1/2 rounded-full bg-fuchsia-400 shadow-[0_0_12px_rgba(244,114,182,0.9)]"
style={{ left: `${rangeStartLineLeft}%` }}
/>
<div
data-testid="range-boundary-line"
aria-hidden="true"
title={`范围结束帧 ${visibleSelectedRange.endFrame}`}
className="pointer-events-none absolute top-[16px] bottom-[7px] z-[65] w-[2px] -translate-x-1/2 rounded-full bg-lime-300 shadow-[0_0_12px_rgba(190,242,100,0.9)]"
style={{ left: `${rangeEndLineLeft}%` }}
/>
</>
)}
<div
className={cn(
"mt-2 h-2.5 w-full relative bg-zinc-700/80 border-y border-white/10 shadow-inner",
@@ -321,21 +362,21 @@ export function FrameTimeline({
);
})}
{visiblePropagationHistory.map((segment) => {
const color = propagationHistoryColors[segment.colorIndex % propagationHistoryColors.length];
const color = propagationHistoryColor(segment.ageFromNewest);
const left = totalFrames > 0 ? ((segment.startFrame - 1) / totalFrames) * 100 : 0;
const width = totalFrames > 0 ? ((segment.endFrame - segment.startFrame + 1) / totalFrames) * 100 : 0;
const opacity = Math.max(0.48, 0.92 - (visiblePropagationHistory.length - 1 - segment.order) * 0.12);
return (
<div
key={segment.id}
data-testid="propagation-history-segment"
data-recency-level={segment.ageFromNewest}
title={segment.label || `自动传播记录:第 ${segment.startFrame}-${segment.endFrame}`}
className="pointer-events-none absolute inset-y-0 z-[15] rounded-[2px] border-x border-white/25"
className="pointer-events-none absolute inset-y-0 z-[15] rounded-[2px] border-x"
style={{
left: `${left}%`,
width: `${width}%`,
opacity,
background: `linear-gradient(to right, ${color.dark}, ${color.light})`,
backgroundColor: color.fill,
borderColor: color.border,
boxShadow: `0 0 10px ${color.glow}`,
}}
/>

View File

@@ -19,14 +19,19 @@ describe('Login', () => {
});
it('logs in with the development credentials and stores the token', async () => {
apiMock.login.mockResolvedValueOnce({ token: 'fake-jwt-token-for-admin' });
apiMock.login.mockResolvedValueOnce({
token: 'jwt-token',
username: 'admin',
user: { id: 1, username: 'admin', role: 'admin' },
});
render(<Login />);
fireEvent.click(screen.getByRole('button', { name: '安全登录' }));
await waitFor(() => expect(apiMock.login).toHaveBeenCalledWith('admin', '123456'));
expect(useStore.getState().isAuthenticated).toBe(true);
expect(localStorage.getItem('token')).toBe('fake-jwt-token-for-admin');
expect(useStore.getState().currentUser?.username).toBe('admin');
expect(localStorage.getItem('token')).toBe('jwt-token');
});
it('shows backend login errors', async () => {
@@ -39,4 +44,11 @@ describe('Login', () => {
expect(await screen.findByText('Invalid credentials')).toBeInTheDocument();
expect(useStore.getState().isAuthenticated).toBe(false);
});
it('marks login fields with browser autocomplete hints', () => {
render(<Login />);
expect(screen.getByDisplayValue('admin')).toHaveAttribute('autocomplete', 'username');
expect(screen.getByDisplayValue('123456')).toHaveAttribute('autocomplete', 'current-password');
});
});

View File

@@ -18,7 +18,7 @@ export function Login() {
try {
const data = await loginApi(username, password);
storeLogin(data.token);
storeLogin(data.token, data.user);
} catch (err: any) {
const msg = err?.response?.data?.detail || err?.response?.data?.error || '登录失败,请检查网络或凭证';
setError(msg);
@@ -47,6 +47,7 @@ export function Login() {
type="text"
value={username}
onChange={(e) => setUsername(e.target.value)}
autoComplete="username"
className="w-full bg-[#1a1a1a] border border-white/10 rounded-lg px-4 py-3 text-sm focus:outline-none focus:border-cyan-500/50 focus:ring-1 focus:ring-cyan-500/50 transition-all font-mono"
placeholder="输入账号"
/>
@@ -58,6 +59,7 @@ export function Login() {
type="password"
value={password}
onChange={(e) => setPassword(e.target.value)}
autoComplete="current-password"
className="w-full bg-[#1a1a1a] border border-white/10 rounded-lg px-4 py-3 text-sm focus:outline-none focus:border-cyan-500/50 focus:ring-1 focus:ring-cyan-500/50 transition-all font-mono"
placeholder="输入密码"
/>

View File

@@ -1,4 +1,4 @@
import { fireEvent, render, screen, waitFor, within } from '@testing-library/react';
import { act, fireEvent, render, screen, waitFor, within } from '@testing-library/react';
import { beforeEach, describe, expect, it, vi } from 'vitest';
import { resetStore } from '../test/storeTestUtils';
import { useStore } from '../store/useStore';
@@ -67,6 +67,9 @@ describe('OntologyInspector', () => {
expect(useStore.getState().activeTemplateId).toBe('t1');
expect(screen.getByText('胆囊')).toBeInTheDocument();
expect(screen.getByText('肝脏')).toBeInTheDocument();
expect(screen.getByText('maskid:1')).toBeInTheDocument();
expect(screen.getByText('maskid:2')).toBeInTheDocument();
expect(screen.queryByText(/z:/)).not.toBeInTheDocument();
});
it('adjusts workspace mask opacity from above the semantic tree', () => {
@@ -151,13 +154,86 @@ describe('OntologyInspector', () => {
classId: 'c2',
className: '肝脏',
classZIndex: 10,
classMaskId: 2,
label: '肝脏',
color: '#00ff00',
saveStatus: 'dirty',
saved: false,
}));
expect(screen.getByText('当前选中区域:')).toBeInTheDocument();
expect(screen.getByText('1')).toBeInTheDocument();
expect(screen.queryByText('当前选中区域:')).not.toBeInTheDocument();
});
it('applies class changes to the same propagation chain across frames', () => {
useStore.setState({
selectedMaskIds: ['annotation-10'],
masks: [
{
id: 'annotation-10',
annotationId: '10',
frameId: 'frame-1',
pathData: 'M 0 0 Z',
label: '旧标签',
color: '#06b6d4',
saveStatus: 'saved',
saved: true,
},
{
id: 'annotation-11',
annotationId: '11',
frameId: 'frame-2',
pathData: 'M 1 1 Z',
label: '旧传播标签',
color: '#06b6d4',
metadata: {
source_annotation_id: 10,
source_mask_id: 'annotation-10',
propagation_seed_key: 'annotation:10',
},
saveStatus: 'saved',
saved: true,
},
{
id: 'annotation-99',
annotationId: '99',
frameId: 'frame-3',
pathData: 'M 2 2 Z',
label: '无关区域',
color: '#ffffff',
metadata: { source_annotation_id: 99 },
saveStatus: 'saved',
saved: true,
},
],
});
render(<OntologyInspector />);
fireEvent.click(screen.getByText('肝脏'));
const updated = useStore.getState().masks;
expect(updated.find((mask) => mask.id === 'annotation-10')).toEqual(expect.objectContaining({
classId: 'c2',
className: '肝脏',
classMaskId: 2,
label: '肝脏',
color: '#00ff00',
saveStatus: 'dirty',
saved: false,
}));
expect(updated.find((mask) => mask.id === 'annotation-11')).toEqual(expect.objectContaining({
classId: 'c2',
className: '肝脏',
classMaskId: 2,
label: '肝脏',
color: '#00ff00',
saveStatus: 'dirty',
saved: false,
}));
expect(updated.find((mask) => mask.id === 'annotation-99')).toEqual(expect.objectContaining({
label: '无关区域',
color: '#ffffff',
saveStatus: 'saved',
saved: true,
}));
});
it('persists custom classes to the active backend template', async () => {
@@ -187,6 +263,59 @@ describe('OntologyInspector', () => {
expect(useStore.getState().templates[0].classes).toHaveLength(3);
});
it('persists dragged semantic class order as layer priority without changing maskid', async () => {
apiMock.updateTemplate.mockResolvedValueOnce({
id: 't1',
name: '腹腔镜模板',
classes: [
{ id: 'c2', name: '肝脏', color: '#00ff00', zIndex: 20, maskId: 2, category: '器官' },
{ id: 'c1', name: '胆囊', color: '#ff0000', zIndex: 10, maskId: 1, category: '器官' },
],
rules: [],
});
useStore.setState({
masks: [{
id: 'm-liver',
annotationId: '42',
frameId: 'frame-1',
classId: 'c2',
className: '肝脏',
classZIndex: 10,
pathData: 'M 0 0 Z',
label: '肝脏',
color: '#00ff00',
saveStatus: 'saved',
saved: true,
}],
});
render(<OntologyInspector />);
const liverButton = screen.getByRole('button', { name: /肝脏/ });
const gallbladderButton = screen.getByRole('button', { name: /胆囊/ });
const dataTransfer = {
effectAllowed: '',
dropEffect: '',
setData: vi.fn(),
getData: vi.fn(() => 'c2'),
};
fireEvent.dragStart(liverButton, { dataTransfer });
fireEvent.dragOver(gallbladderButton, { dataTransfer });
fireEvent.drop(gallbladderButton, { dataTransfer });
await waitFor(() => expect(apiMock.updateTemplate).toHaveBeenCalledWith('t1', expect.objectContaining({
classes: [
expect.objectContaining({ id: 'c2', zIndex: 20, maskId: 2 }),
expect.objectContaining({ id: 'c1', zIndex: 10, maskId: 1 }),
],
})));
expect(useStore.getState().masks[0]).toEqual(expect.objectContaining({
classZIndex: 20,
saveStatus: 'dirty',
saved: false,
}));
});
it('loads selected mask properties from the backend analyzer', async () => {
useStore.setState({
frames: [{ id: 'frame-1', projectId: 'p1', index: 0, url: '/1.jpg', width: 100, height: 100 }],
@@ -214,15 +343,40 @@ describe('OntologyInspector', () => {
expect(screen.queryByText('后端模型置信度')).not.toBeInTheDocument();
expect(screen.queryByText('0.8200')).not.toBeInTheDocument();
expect(screen.getByText('4 节点')).toBeInTheDocument();
fireEvent.click(screen.getByRole('button', { name: '重新提取拓扑锚点' }));
expect(screen.queryByRole('button', { name: '重新提取拓扑锚点' })).not.toBeInTheDocument();
expect(apiMock.analyzeMask).toHaveBeenLastCalledWith(
expect.objectContaining({ id: 'm1' }),
expect.objectContaining({ id: 'frame-1' }),
{ extractSkeleton: true },
);
});
it('applies backend edge smoothing to the selected mask and marks it dirty', async () => {
it('ignores aborted mask analysis requests without showing an error', async () => {
const consoleError = vi.spyOn(console, 'error').mockImplementation(() => {});
apiMock.analyzeMask.mockRejectedValueOnce({ code: 'ECONNABORTED', message: 'Request aborted' });
useStore.setState({
frames: [{ id: 'frame-1', projectId: 'p1', index: 0, url: '/1.jpg', width: 100, height: 100 }],
selectedMaskIds: ['m1'],
masks: [
{
id: 'm1',
frameId: 'frame-1',
pathData: 'M 0 0 Z',
label: '胆囊',
color: '#ff0000',
segmentation: [[10, 10, 20, 10, 20, 20]],
},
],
});
render(<OntologyInspector />);
await waitFor(() => expect(apiMock.analyzeMask).toHaveBeenCalled());
await waitFor(() => expect(screen.queryByText('后端属性读取失败')).not.toBeInTheDocument());
expect(consoleError).not.toHaveBeenCalled();
consoleError.mockRestore();
});
it('previews backend edge smoothing while moving the slider without marking the mask dirty', async () => {
useStore.setState({
frames: [{ id: 'frame-1', projectId: 'p1', index: 0, url: '/1.jpg', width: 100, height: 100 }],
selectedMaskIds: ['m1'],
@@ -244,13 +398,100 @@ describe('OntologyInspector', () => {
render(<OntologyInspector />);
fireEvent.change(screen.getByLabelText('边缘平滑强度'), { target: { value: '35' } });
fireEvent.click(screen.getByRole('button', { name: '应用边缘平滑' }));
await waitFor(() => expect(apiMock.smoothMaskGeometry).toHaveBeenCalledWith(
expect.objectContaining({ id: 'm1' }),
expect.objectContaining({ id: 'frame-1' }),
35,
));
await waitFor(() => expect(useStore.getState().masks[0]).toEqual(expect.objectContaining({
pathData: 'M 12 12 L 28 12 L 28 28 L 12 28 Z',
segmentation: [[12, 12, 28, 12, 28, 28, 12, 28]],
bbox: [12, 12, 16, 16],
area: 256,
saveStatus: 'saved',
saved: true,
metadata: { geometry_smoothing_preview: { strength: 35, method: 'chaikin' } },
})));
expect(screen.getByText('已应用边缘平滑强度 35预览中点击应用后写入当前 mask。')).toBeInTheDocument();
});
it('debounces backend edge smoothing preview while dragging the slider', async () => {
vi.useFakeTimers();
try {
useStore.setState({
frames: [{ id: 'frame-1', projectId: 'p1', index: 0, url: '/1.jpg', width: 100, height: 100 }],
selectedMaskIds: ['m1'],
masks: [
{
id: 'm1',
annotationId: '10',
frameId: 'frame-1',
pathData: 'M 10 10 L 30 10 L 30 30 Z',
label: '胆囊',
color: '#ff0000',
segmentation: [[10, 10, 30, 10, 30, 30]],
saveStatus: 'saved',
saved: true,
},
],
});
render(<OntologyInspector />);
fireEvent.change(screen.getByLabelText('边缘平滑强度'), { target: { value: '15' } });
fireEvent.change(screen.getByLabelText('边缘平滑强度'), { target: { value: '25' } });
fireEvent.change(screen.getByLabelText('边缘平滑强度'), { target: { value: '35' } });
expect(screen.getByText('正在等待停止拖动后生成边缘平滑预览...')).toBeInTheDocument();
expect(apiMock.smoothMaskGeometry).not.toHaveBeenCalled();
act(() => {
vi.advanceTimersByTime(219);
});
expect(apiMock.smoothMaskGeometry).not.toHaveBeenCalled();
await act(async () => {
vi.advanceTimersByTime(1);
await Promise.resolve();
});
expect(apiMock.smoothMaskGeometry).toHaveBeenCalledTimes(1);
expect(apiMock.smoothMaskGeometry).toHaveBeenCalledWith(
expect.objectContaining({ id: 'm1' }),
expect.objectContaining({ id: 'frame-1' }),
35,
);
} finally {
vi.useRealTimers();
}
});
it('applies a previewed edge smoothing result to the selected mask and marks it dirty', async () => {
useStore.setState({
frames: [{ id: 'frame-1', projectId: 'p1', index: 0, url: '/1.jpg', width: 100, height: 100 }],
selectedMaskIds: ['m1'],
masks: [
{
id: 'm1',
annotationId: '10',
frameId: 'frame-1',
pathData: 'M 10 10 L 30 10 L 30 30 Z',
label: '胆囊',
color: '#ff0000',
segmentation: [[10, 10, 30, 10, 30, 30]],
saveStatus: 'saved',
saved: true,
},
],
});
render(<OntologyInspector />);
fireEvent.change(screen.getByLabelText('边缘平滑强度'), { target: { value: '35' } });
await waitFor(() => expect(screen.getByRole('button', { name: '应用边缘平滑' })).not.toBeDisabled());
fireEvent.click(screen.getByRole('button', { name: '应用边缘平滑' }));
await waitFor(() => expect(useStore.getState().masks[0]).toEqual(expect.objectContaining({
pathData: 'M 12 12 L 28 12 L 28 28 L 12 28 Z',
segmentation: [[12, 12, 28, 12, 28, 28, 12, 28]],
@@ -258,8 +499,87 @@ describe('OntologyInspector', () => {
area: 256,
saveStatus: 'dirty',
saved: false,
metadata: { geometry_smoothing: { strength: 35, method: 'chaikin' } },
})));
expect(screen.getByText('已应用边缘平滑强度 35请保存后生效')).toBeInTheDocument();
expect(useStore.getState().masks[0].metadata?.geometry_smoothing).toBeUndefined();
expect(apiMock.smoothMaskGeometry).toHaveBeenCalledTimes(1);
expect(screen.getByText('0%')).toBeInTheDocument();
expect(screen.getByText('已应用边缘平滑强度 35已变为新的 mask强度已重置为 0请保存后生效')).toBeInTheDocument();
});
it('applies smoothing to linked propagation masks as one undoable geometry edit', async () => {
useStore.setState({
frames: [
{ id: 'frame-0', projectId: 'p1', index: 0, url: '/0.jpg', width: 100, height: 100 },
{ id: 'frame-1', projectId: 'p1', index: 1, url: '/1.jpg', width: 100, height: 100 },
{ id: 'frame-2', projectId: 'p1', index: 2, url: '/2.jpg', width: 100, height: 100 },
],
selectedMaskIds: ['seed-mask'],
masks: [
{
id: 'seed-mask',
annotationId: '10',
frameId: 'frame-1',
pathData: 'M 10 10 L 30 10 L 30 30 Z',
label: '胆囊',
color: '#ff0000',
segmentation: [[10, 10, 30, 10, 30, 30]],
saveStatus: 'saved',
saved: true,
},
{
id: 'prop-backward',
annotationId: '11',
frameId: 'frame-0',
pathData: 'M 11 11 L 31 11 L 31 31 Z',
label: '胆囊',
color: '#ff0000',
segmentation: [[11, 11, 31, 11, 31, 31]],
saveStatus: 'saved',
saved: true,
metadata: { source_annotation_id: 10, source_mask_id: 'annotation-10', propagated_from_frame_id: 10 },
},
{
id: 'prop-forward',
annotationId: '12',
frameId: 'frame-2',
pathData: 'M 12 12 L 32 12 L 32 32 Z',
label: '胆囊',
color: '#ff0000',
segmentation: [[12, 12, 32, 12, 32, 32]],
saveStatus: 'saved',
saved: true,
metadata: { source_annotation_id: 10, source_mask_id: 'annotation-10', propagated_from_frame_id: 10 },
},
],
});
render(<OntologyInspector />);
fireEvent.change(screen.getByLabelText('边缘平滑强度'), { target: { value: '35' } });
await waitFor(() => expect(screen.getByRole('button', { name: '应用边缘平滑' })).not.toBeDisabled());
fireEvent.click(screen.getByRole('button', { name: '应用边缘平滑' }));
await waitFor(() => expect(apiMock.smoothMaskGeometry).toHaveBeenCalledTimes(3));
await waitFor(() => expect(useStore.getState().masks).toEqual([
expect.objectContaining({ id: 'seed-mask', pathData: 'M 12 12 L 28 12 L 28 28 L 12 28 Z', saveStatus: 'dirty', saved: false }),
expect.objectContaining({ id: 'prop-backward', pathData: 'M 12 12 L 28 12 L 28 28 L 12 28 Z', saveStatus: 'dirty', saved: false }),
expect.objectContaining({ id: 'prop-forward', pathData: 'M 12 12 L 28 12 L 28 28 L 12 28 Z', saveStatus: 'dirty', saved: false }),
]));
expect(useStore.getState().masks.every((mask) => !mask.metadata?.geometry_smoothing)).toBe(true);
expect(screen.getByText('已应用边缘平滑强度 35已同步应用到传播链 3 个对应 mask强度已重置为 0请保存后生效')).toBeInTheDocument();
act(() => {
useStore.getState().undoMasks();
});
expect(useStore.getState().masks.map((mask) => mask.pathData)).toEqual([
'M 10 10 L 30 10 L 30 30 Z',
'M 11 11 L 31 11 L 31 31 Z',
'M 12 12 L 32 12 L 32 32 Z',
]);
act(() => {
useStore.getState().redoMasks();
});
expect(useStore.getState().masks.every((mask) => mask.pathData === 'M 12 12 L 28 12 L 28 28 L 12 28 Z')).toBe(true);
});
});

View File

@@ -1,10 +1,63 @@
import React, { useEffect, useMemo, useRef, useState } from 'react';
import { ChevronDown, Tag, Eye, Plus, X, Loader2 } from 'lucide-react';
import { ChevronDown, Tag, Eye, Plus, X, Loader2, GripVertical } from 'lucide-react';
import { useStore } from '../store/useStore';
import type { TemplateClass } from '../store/useStore';
import type { Mask, TemplateClass } from '../store/useStore';
import { cn } from '../lib/utils';
import { getActiveTemplate } from '../lib/templateSelection';
import { analyzeMask, smoothMaskGeometry, updateTemplate, type MaskAnalysisResult } from '../lib/api';
import { analyzeMask, smoothMaskGeometry, updateTemplate, type MaskAnalysisResult, type SmoothMaskGeometryResult } from '../lib/api';
import { nextClassMaskId, normalizeClassMaskIds } from '../lib/maskIds';
const SMOOTHING_PREVIEW_DEBOUNCE_MS = 220;
const isRequestAbortError = (err: unknown) => {
const error = err as { code?: string; message?: string; name?: string } | null;
const message = error?.message || '';
return error?.code === 'ERR_CANCELED'
|| error?.code === 'ECONNABORTED'
|| error?.name === 'AbortError'
|| /request aborted|aborted|cancell?ed/i.test(message);
};
function metadataNumber(value: unknown): number | null {
const parsed = Number(value);
return Number.isFinite(parsed) && parsed > 0 ? parsed : null;
}
function propagationSourceMaskTokens(value: unknown): string[] {
if (typeof value !== 'string' || value.length === 0) return [];
const tokens = [`mask:${value}`];
const annotationMatch = value.match(/^annotation-(\d+)$/);
if (annotationMatch) {
tokens.push(`annotation:${annotationMatch[1]}`);
}
return tokens;
}
function propagationLineageTokens(mask: { id: string; annotationId?: string; metadata?: Record<string, unknown> }): Set<string> {
const metadata = mask.metadata || {};
const tokens = new Set<string>([`mask:${mask.id}`]);
if (mask.annotationId) {
tokens.add(`annotation:${mask.annotationId}`);
}
const sourceAnnotationId = metadataNumber(metadata.source_annotation_id);
if (sourceAnnotationId !== null) {
tokens.add(`annotation:${sourceAnnotationId}`);
}
propagationSourceMaskTokens(metadata.source_mask_id).forEach((token) => tokens.add(token));
if (typeof metadata.propagation_seed_key === 'string' && metadata.propagation_seed_key.length > 0) {
tokens.add(`seed-key:${metadata.propagation_seed_key}`);
}
return tokens;
}
function findPropagationChainMaskIds(selectedMask: Pick<Mask, 'id' | 'annotationId' | 'metadata'>, masks: Mask[]): Set<string> {
const selectedTokens = propagationLineageTokens(selectedMask);
return new Set(
masks
.filter((mask) => Array.from(selectedTokens).some((token) => propagationLineageTokens(mask).has(token)))
.map((mask) => mask.id),
);
}
export function OntologyInspector() {
const templates = useStore((state) => state.templates);
@@ -27,20 +80,38 @@ export function OntologyInspector() {
const [newClassColor, setNewClassColor] = useState('#06b6d4');
const [isSavingClass, setIsSavingClass] = useState(false);
const [classSaveMessage, setClassSaveMessage] = useState('');
const [dragClassId, setDragClassId] = useState<string | null>(null);
const [maskAnalysis, setMaskAnalysis] = useState<MaskAnalysisResult | null>(null);
const [isAnalyzingMask, setIsAnalyzingMask] = useState(false);
const [analysisMessage, setAnalysisMessage] = useState('');
const [smoothingStrength, setSmoothingStrength] = useState(0);
const [isPreviewingSmoothing, setIsPreviewingSmoothing] = useState(false);
const [isSmoothingMask, setIsSmoothingMask] = useState(false);
const activeTemplate = getActiveTemplate(templates, activeTemplateId);
const templateClasses = activeTemplate?.classes || [];
const templateClasses = normalizeClassMaskIds(activeTemplate?.classes || []);
const allClasses = [...templateClasses].sort((a, b) => b.zIndex - a.zIndex);
const selectedMask = masks.find((mask) => selectedMaskIds.includes(mask.id)) || null;
const selectedMaskLabel = selectedMask?.className || selectedMask?.label || '未选择';
const currentFrame = frames[currentFrameIndex] || null;
const classButtonRefs = useRef(new Map<string, HTMLButtonElement>());
const skipNextAutoAnalysisRef = useRef(false);
const analysisRequestIdRef = useRef(0);
const smoothingPreviewRef = useRef<{
maskId: string;
baseMask: NonNullable<typeof selectedMask>;
strength: number;
result: SmoothMaskGeometryResult | null;
applied: boolean;
requestId: number;
} | null>(null);
const smoothingRequestIdRef = useRef(0);
const smoothingPreviewTimerRef = useRef<number | null>(null);
const clearSmoothingPreviewTimer = React.useCallback(() => {
if (smoothingPreviewTimerRef.current === null) return;
window.clearTimeout(smoothingPreviewTimerRef.current);
smoothingPreviewTimerRef.current = null;
}, []);
const selectedMaskClass = useMemo(() => {
if (!selectedMask) return null;
@@ -78,14 +149,21 @@ export function OntologyInspector() {
if (!hasSelectedMasks) return;
const templateId = activeTemplate?.id || activeTemplateId || undefined;
const targetIdSet = new Set<string>();
masks
.filter((mask) => selectedIdSet.has(mask.id))
.forEach((mask) => {
findPropagationChainMaskIds(mask, masks).forEach((maskId) => targetIdSet.add(maskId));
});
const updatedMasks = masks.map((mask) => {
if (!selectedIdSet.has(mask.id)) return mask;
if (!targetIdSet.has(mask.id)) return mask;
return {
...mask,
templateId: templateId || mask.templateId,
classId: templateClass.id,
className: templateClass.name,
classZIndex: templateClass.zIndex,
classMaskId: templateClass.maskId,
label: templateClass.name,
color: templateClass.color,
saveStatus: mask.annotationId ? 'dirty' as const : 'draft' as const,
@@ -101,33 +179,63 @@ export function OntologyInspector() {
]);
};
const refreshMaskAnalysis = async (extractSkeleton = false) => {
const refreshMaskAnalysis = async () => {
const requestId = analysisRequestIdRef.current + 1;
analysisRequestIdRef.current = requestId;
if (!selectedMask || !currentFrame) {
setMaskAnalysis(null);
setAnalysisMessage(selectedMask ? '当前帧信息不可用,无法读取后端属性' : '请选择一个 mask 查看后端属性');
return;
}
setIsAnalyzingMask(true);
setAnalysisMessage('');
try {
const result = await analyzeMask(selectedMask, currentFrame, { extractSkeleton });
const result = await analyzeMask(selectedMask, currentFrame);
if (analysisRequestIdRef.current !== requestId) return;
setMaskAnalysis(result);
setAnalysisMessage(result.message);
} catch (err) {
if (analysisRequestIdRef.current !== requestId || isRequestAbortError(err)) return;
console.error('Mask analysis failed:', err);
setMaskAnalysis(null);
setAnalysisMessage('后端属性读取失败');
} finally {
setIsAnalyzingMask(false);
}
};
const restoreSmoothingPreview = React.useCallback(() => {
const preview = smoothingPreviewRef.current;
if (!preview || preview.applied) {
smoothingPreviewRef.current = null;
return;
}
const state = useStore.getState();
useStore.setState({
masks: state.masks.map((mask) => (mask.id === preview.maskId ? preview.baseMask : mask)),
selectedMaskIds: state.selectedMaskIds,
});
smoothingPreviewRef.current = null;
}, []);
React.useEffect(() => {
return () => {
analysisRequestIdRef.current += 1;
clearSmoothingPreviewTimer();
restoreSmoothingPreview();
};
}, [clearSmoothingPreviewTimer, restoreSmoothingPreview]);
React.useEffect(() => {
const preview = smoothingPreviewRef.current;
if (preview && preview.maskId !== selectedMask?.id) {
restoreSmoothingPreview();
}
}, [restoreSmoothingPreview, selectedMask?.id]);
React.useEffect(() => {
if (skipNextAutoAnalysisRef.current) {
skipNextAutoAnalysisRef.current = false;
return;
}
void refreshMaskAnalysis(false);
void refreshMaskAnalysis();
// selectedMask is intentionally tracked by id and geometry fields to avoid
// re-running analysis for unrelated store changes.
}, [selectedMask?.id, selectedMask?.segmentation, selectedMask?.points, currentFrame?.id]);
@@ -140,43 +248,202 @@ export function OntologyInspector() {
setSmoothingStrength(Number.isFinite(strength) ? Math.min(Math.max(strength, 0), 100) : 0);
}, [selectedMask?.id]);
const applySmoothingResultToMask = React.useCallback((
mask: Mask,
result: SmoothMaskGeometryResult,
options: { commit: boolean },
): Mask => {
const metadata = { ...(mask.metadata || {}) };
delete metadata.geometry_smoothing_preview;
if (options.commit) {
delete metadata.geometry_smoothing;
} else {
metadata.geometry_smoothing_preview = result.smoothing;
}
return {
...mask,
pathData: result.pathData,
segmentation: result.segmentation,
bbox: result.bbox,
area: result.area,
metadata,
...(options.commit
? {
saveStatus: mask.annotationId ? 'dirty' as const : 'draft' as const,
saved: mask.annotationId ? false : mask.saved,
}
: {}),
};
}, []);
const updateMaskWithSmoothingResult = React.useCallback((
maskId: string,
result: SmoothMaskGeometryResult,
options: { commit: boolean },
) => {
const state = useStore.getState();
const nextMasks = state.masks.map((mask) => (
mask.id === maskId ? applySmoothingResultToMask(mask, result, options) : mask
));
if (options.commit) {
setMasks(nextMasks);
} else {
useStore.setState({ masks: nextMasks });
}
}, [applySmoothingResultToMask, setMasks]);
const applySmoothingResultToAnalysis = React.useCallback((
result: SmoothMaskGeometryResult,
sourceMask: NonNullable<typeof selectedMask>,
suffix: string,
) => {
setMaskAnalysis({
confidence: null,
confidence_source: 'manual_or_imported',
topology_anchor_count: result.topology_anchor_count,
topology_anchors: result.topology_anchors,
area: result.area,
bbox: result.bbox,
source: sourceMask.metadata?.source as string | undefined,
message: result.message,
});
setAnalysisMessage(`${result.message}${suffix}`);
}, []);
const runSmoothingPreview = React.useCallback(async (nextStrength: number) => {
if (!selectedMask || !currentFrame) return;
const existingPreview = smoothingPreviewRef.current?.maskId === selectedMask.id
? smoothingPreviewRef.current
: null;
const baseMask = existingPreview?.baseMask || selectedMask;
const requestId = smoothingRequestIdRef.current + 1;
smoothingRequestIdRef.current = requestId;
if (nextStrength <= 0) {
clearSmoothingPreviewTimer();
smoothingPreviewRef.current = {
maskId: selectedMask.id,
baseMask,
strength: 0,
result: null,
applied: false,
requestId,
};
skipNextAutoAnalysisRef.current = true;
useStore.setState({
masks: useStore.getState().masks.map((mask) => (mask.id === selectedMask.id ? baseMask : mask)),
});
setAnalysisMessage('已预览恢复原始边缘,点击应用后写入当前 mask。');
setIsPreviewingSmoothing(false);
return;
}
setAnalysisMessage('正在生成边缘平滑预览...');
try {
const result = await smoothMaskGeometry(baseMask, currentFrame, nextStrength);
if (smoothingRequestIdRef.current !== requestId) return;
smoothingPreviewRef.current = {
maskId: selectedMask.id,
baseMask,
strength: nextStrength,
result,
applied: false,
requestId,
};
skipNextAutoAnalysisRef.current = true;
updateMaskWithSmoothingResult(selectedMask.id, result, { commit: false });
applySmoothingResultToAnalysis(result, baseMask, ',预览中,点击应用后写入当前 mask。');
} catch (err) {
if (smoothingRequestIdRef.current !== requestId) return;
console.error('Mask smoothing preview failed:', err);
setAnalysisMessage('边缘平滑预览失败,请检查后端服务');
} finally {
if (smoothingRequestIdRef.current === requestId) {
setIsPreviewingSmoothing(false);
}
}
}, [applySmoothingResultToAnalysis, clearSmoothingPreviewTimer, currentFrame, selectedMask, updateMaskWithSmoothingResult]);
const previewSmoothing = React.useCallback((nextStrength: number) => {
setSmoothingStrength(nextStrength);
clearSmoothingPreviewTimer();
if (!selectedMask || !currentFrame) return;
if (nextStrength <= 0) {
void runSmoothingPreview(nextStrength);
return;
}
setIsPreviewingSmoothing(true);
setAnalysisMessage('正在等待停止拖动后生成边缘平滑预览...');
smoothingPreviewTimerRef.current = window.setTimeout(() => {
smoothingPreviewTimerRef.current = null;
void runSmoothingPreview(nextStrength);
}, SMOOTHING_PREVIEW_DEBOUNCE_MS);
}, [clearSmoothingPreviewTimer, currentFrame, runSmoothingPreview, selectedMask]);
const handleApplySmoothing = async () => {
if (!selectedMask || !currentFrame) {
setAnalysisMessage('请选择一个 mask 后再应用边缘平滑');
return;
}
clearSmoothingPreviewTimer();
smoothingRequestIdRef.current += 1;
setIsSmoothingMask(true);
setAnalysisMessage('');
try {
const result = await smoothMaskGeometry(selectedMask, currentFrame, smoothingStrength);
skipNextAutoAnalysisRef.current = true;
setMasks(masks.map((mask) => {
if (mask.id !== selectedMask.id) return mask;
return {
...mask,
pathData: result.pathData,
segmentation: result.segmentation,
bbox: result.bbox,
area: result.area,
metadata: {
...(mask.metadata || {}),
geometry_smoothing: result.smoothing,
},
saveStatus: mask.annotationId ? 'dirty' as const : 'draft' as const,
saved: mask.annotationId ? false : mask.saved,
};
}));
setMaskAnalysis({
confidence: null,
confidence_source: 'manual_or_imported',
topology_anchor_count: result.topology_anchor_count,
topology_anchors: result.topology_anchors,
area: result.area,
bbox: result.bbox,
source: selectedMask.metadata?.source as string | undefined,
message: result.message,
const existingPreview = smoothingPreviewRef.current?.maskId === selectedMask.id
&& smoothingPreviewRef.current.strength === smoothingStrength
? smoothingPreviewRef.current
: null;
const baseMask = existingPreview?.baseMask || selectedMask;
if (smoothingStrength <= 0) {
smoothingPreviewRef.current = null;
setSmoothingStrength(0);
setAnalysisMessage('边缘平滑强度为 0当前 mask 保持原始边缘。');
return;
}
const state = useStore.getState();
const frameById = new Map(state.frames.map((frame) => [String(frame.id), frame]));
const chainMaskIds = findPropagationChainMaskIds(baseMask, state.masks);
chainMaskIds.add(selectedMask.id);
const selectedResult = existingPreview?.result || await smoothMaskGeometry(baseMask, currentFrame, smoothingStrength);
const resultEntries = new Map<string, SmoothMaskGeometryResult>();
resultEntries.set(selectedMask.id, selectedResult);
await Promise.all(
Array.from(chainMaskIds)
.filter((maskId) => maskId !== selectedMask.id)
.map(async (maskId) => {
const mask = state.masks.find((item) => item.id === maskId);
const frame = mask ? frameById.get(String(mask.frameId)) : null;
if (!mask || !frame) return;
resultEntries.set(maskId, await smoothMaskGeometry(mask, frame, smoothingStrength));
}),
);
const latestMasks = useStore.getState().masks;
const historyBaseMasks = latestMasks.map((mask) => (mask.id === selectedMask.id ? baseMask : mask));
useStore.setState({ masks: historyBaseMasks });
const nextMasks = historyBaseMasks.map((mask) => {
const result = resultEntries.get(mask.id);
if (!result) return mask;
return applySmoothingResultToMask(mask, result, { commit: true });
});
setAnalysisMessage(`${result.message},请保存后生效`);
skipNextAutoAnalysisRef.current = true;
setMasks(nextMasks);
if (smoothingPreviewRef.current) {
smoothingPreviewRef.current.applied = true;
}
smoothingPreviewRef.current = null;
setSmoothingStrength(0);
applySmoothingResultToAnalysis(
selectedResult,
baseMask,
resultEntries.size > 1
? `,已同步应用到传播链 ${resultEntries.size} 个对应 mask强度已重置为 0请保存后生效`
: ',已变为新的 mask强度已重置为 0请保存后生效',
);
} catch (err) {
console.error('Mask smoothing failed:', err);
setAnalysisMessage('边缘平滑失败,请检查后端服务');
@@ -197,6 +464,7 @@ export function OntologyInspector() {
name: newClassName.trim(),
color: newClassColor,
zIndex: maxZ + 10,
maskId: nextClassMaskId(templateClasses),
category: '自定义',
};
setIsSavingClass(true);
@@ -205,7 +473,7 @@ export function OntologyInspector() {
const updated = await updateTemplate(activeTemplate.id, {
name: activeTemplate.name,
description: activeTemplate.description,
classes: [...templateClasses, newClass],
classes: normalizeClassMaskIds([...templateClasses, newClass]),
rules: activeTemplate.rules || [],
});
updateTemplateStore(updated);
@@ -222,6 +490,62 @@ export function OntologyInspector() {
}
};
const handleReorderClass = async (sourceClassId: string, targetClassId: string) => {
if (!activeTemplate || sourceClassId === targetClassId) {
setDragClassId(null);
return;
}
const sourceIndex = allClasses.findIndex((item) => item.id === sourceClassId);
const targetIndex = allClasses.findIndex((item) => item.id === targetClassId);
if (sourceIndex < 0 || targetIndex < 0) {
setDragClassId(null);
return;
}
const reordered = [...allClasses];
const [source] = reordered.splice(sourceIndex, 1);
reordered.splice(targetIndex, 0, source);
const nextClasses = normalizeClassMaskIds(
reordered.map((item, index) => ({
...item,
zIndex: (reordered.length - index) * 10,
})),
);
setIsSavingClass(true);
setClassSaveMessage('正在保存分类覆盖顺序...');
try {
const updated = await updateTemplate(activeTemplate.id, {
name: activeTemplate.name,
description: activeTemplate.description,
classes: nextClasses,
rules: activeTemplate.rules || [],
});
updateTemplateStore(updated);
setActiveTemplateId(updated.id);
const zIndexByClassId = new Map(nextClasses.map((item) => [item.id, item.zIndex]));
setMasks(useStore.getState().masks.map((mask) => (
mask.classId && zIndexByClassId.has(mask.classId)
? {
...mask,
classZIndex: zIndexByClassId.get(mask.classId),
saveStatus: mask.annotationId ? 'dirty' as const : mask.saveStatus,
saved: mask.annotationId ? false : mask.saved,
}
: mask
)));
const nextActiveClass = nextClasses.find((item) => item.id === activeClassId);
if (nextActiveClass) setActiveClass(nextActiveClass);
setClassSaveMessage('分类覆盖顺序已保存');
} catch (err) {
console.error('Reorder class failed:', err);
setClassSaveMessage('分类覆盖顺序保存失败');
} finally {
setIsSavingClass(false);
setDragClassId(null);
}
};
return (
<div className="w-60 bg-[#0d0d0d] flex flex-col border-l border-white/5 shrink-0 z-10 overflow-hidden">
<div className="flex-1 overflow-y-auto seg-scrollbar p-4 flex flex-col gap-6">
@@ -275,13 +599,14 @@ export function OntologyInspector() {
{/* Semantic Classification Tree */}
<div>
<h3 className="text-[10px] font-bold text-gray-500 uppercase tracking-widest mb-3 flex justify-between items-center">
<span> (/Z-Index)</span>
<span></span>
</h3>
<div className="space-y-2">
{allClasses.map(cls => (
<div key={cls.id} className="flex flex-col gap-1">
<button
type="button"
draggable={Boolean(activeTemplate) && !isSavingClass}
ref={(node) => {
if (node) {
classButtonRefs.current.set(cls.id, node);
@@ -290,18 +615,36 @@ export function OntologyInspector() {
}
}}
onClick={() => handleSelectClass(cls)}
onDragStart={(event) => {
setDragClassId(cls.id);
event.dataTransfer.effectAllowed = 'move';
event.dataTransfer.setData('text/plain', cls.id);
}}
onDragOver={(event) => {
if (!dragClassId || dragClassId === cls.id) return;
event.preventDefault();
event.dataTransfer.dropEffect = 'move';
}}
onDrop={(event) => {
event.preventDefault();
const sourceId = event.dataTransfer.getData('text/plain') || dragClassId;
if (sourceId) void handleReorderClass(sourceId, cls.id);
}}
onDragEnd={() => setDragClassId(null)}
aria-current={activeClassId === cls.id ? 'true' : undefined}
className={cn(
'flex items-center justify-between p-2 rounded bg-white/5 hover:bg-white/10 cursor-pointer group transition-colors text-left border',
activeClassId === cls.id ? 'border-cyan-500/50 bg-cyan-500/10' : 'border-transparent',
dragClassId === cls.id && 'opacity-50',
)}
>
<div className="flex items-center gap-2">
<GripVertical size={13} className="text-gray-600 group-hover:text-gray-400" aria-hidden="true" />
<span className="w-2.5 h-2.5 rounded-sm" style={{ backgroundColor: cls.color }} />
<span className="text-xs font-medium text-gray-200">{cls.name}</span>
</div>
<div className="flex items-center gap-3">
<span className="text-[10px] text-gray-500 font-mono">z:{cls.zIndex}</span>
<span className="text-[10px] text-gray-500 font-mono">maskid:{cls.maskId}</span>
<Eye size={14} className="text-gray-500 group-hover:text-gray-300" />
</div>
</button>
@@ -366,10 +709,6 @@ export function OntologyInspector() {
</span>
</div>
<div className="space-y-3">
<div className="flex items-center justify-between">
<span className="text-[10px] text-gray-500 uppercase">:</span>
<span className="text-xs font-mono text-gray-300">{selectedMaskIds.length}</span>
</div>
<div className="flex items-center justify-between">
<span className="text-[10px] text-gray-500 uppercase">:</span>
<span className="text-xs font-mono text-gray-300">{maskAnalysis?.topology_anchor_count ?? 0} </span>
@@ -387,28 +726,21 @@ export function OntologyInspector() {
max={100}
step={5}
value={smoothingStrength}
onChange={(event) => setSmoothingStrength(Number(event.target.value))}
onChange={(event) => void previewSmoothing(Number(event.target.value))}
disabled={!selectedMask || isSmoothingMask}
className="w-full accent-cyan-500 disabled:opacity-40"
/>
<button
onClick={handleApplySmoothing}
disabled={!selectedMask || !currentFrame || isSmoothingMask}
disabled={!selectedMask || !currentFrame || isSmoothingMask || isPreviewingSmoothing}
className="mt-2 w-full bg-cyan-500/10 hover:bg-cyan-500/20 border border-cyan-500/20 text-xs text-cyan-100 py-1.5 rounded transition-colors disabled:opacity-40 disabled:cursor-not-allowed"
>
{isSmoothingMask ? '平滑中...' : '应用边缘平滑'}
{isSmoothingMask ? '平滑中...' : isPreviewingSmoothing ? '预览中...' : '应用边缘平滑'}
</button>
</div>
{analysisMessage && (
<div className="text-[10px] leading-relaxed text-gray-500">{analysisMessage}</div>
)}
<button
onClick={() => refreshMaskAnalysis(true)}
disabled={!selectedMask || isAnalyzingMask}
className="w-full mt-2 bg-white/5 hover:bg-white/10 border border-white/10 text-xs text-gray-300 py-1.5 rounded transition-colors disabled:opacity-40 disabled:cursor-not-allowed"
>
{isAnalyzingMask ? '提取中...' : '重新提取拓扑锚点'}
</button>
</div>
</div>
</div>

View File

@@ -0,0 +1,41 @@
import { fireEvent, render, screen } from '@testing-library/react';
import { beforeEach, describe, expect, it, vi } from 'vitest';
import { resetStore } from '../test/storeTestUtils';
import { useStore } from '../store/useStore';
import { Sidebar } from './Sidebar';
vi.mock('./ModelStatusBadge', () => ({
ModelStatusBadge: () => <div></div>,
}));
describe('Sidebar', () => {
beforeEach(() => {
resetStore();
});
it('shows admin user management only for admin users', () => {
const setActiveModule = vi.fn();
useStore.setState({ currentUser: { id: 1, username: 'admin', role: 'admin' } });
render(<Sidebar activeModule="dashboard" setActiveModule={setActiveModule} />);
fireEvent.click(screen.getByTitle('用户管理'));
expect(setActiveModule).toHaveBeenCalledWith('admin');
});
it('hides admin user management for non-admin users', () => {
useStore.setState({ currentUser: { id: 2, username: 'doctor', role: 'annotator' } });
render(<Sidebar activeModule="dashboard" setActiveModule={vi.fn()} />);
expect(screen.queryByTitle('用户管理')).not.toBeInTheDocument();
});
it('uses an explicit AI-styled icon for AI segmentation', () => {
useStore.setState({ currentUser: { id: 2, username: 'doctor', role: 'annotator' } });
render(<Sidebar activeModule="dashboard" setActiveModule={vi.fn()} />);
expect(screen.getByTitle('AI智能分割').querySelector('[data-testid="ai-segmentation-icon"]')).toBeInTheDocument();
});
});

View File

@@ -1,8 +1,10 @@
import React from 'react';
import { Home, FolderOpen, Edit3, LayoutTemplate, BrainCircuit } from 'lucide-react';
import { Home, FolderOpen, Edit3, LayoutTemplate, LogOut, UserCircle, ShieldCheck } from 'lucide-react';
import { cn } from '../lib/utils';
import type { ActiveModule } from '../App';
import { ModelStatusBadge } from './ModelStatusBadge';
import { useStore } from '../store/useStore';
import { AiSegmentationIcon } from './AiSegmentationIcon';
interface SidebarProps {
activeModule: ActiveModule;
@@ -10,12 +12,15 @@ interface SidebarProps {
}
export function Sidebar({ activeModule, setActiveModule }: SidebarProps) {
const currentUser = useStore((state) => state.currentUser);
const logout = useStore((state) => state.logout);
const navItems = [
{ id: 'dashboard', icon: Home, label: '总体概况' },
{ id: 'projects', icon: FolderOpen, label: '项目库' },
{ id: 'workspace', icon: Edit3, label: '分割工作区' },
{ id: 'ai', icon: BrainCircuit, label: 'AI智能分割' },
{ id: 'ai', icon: AiSegmentationIcon, label: 'AI智能分割' },
{ id: 'templates', icon: LayoutTemplate, label: '模板库' },
...(currentUser?.role === 'admin' ? [{ id: 'admin', icon: ShieldCheck, label: '用户管理' }] : []),
] as const;
return (
@@ -49,6 +54,17 @@ export function Sidebar({ activeModule, setActiveModule }: SidebarProps) {
</nav>
<div className="mt-auto mb-4 flex flex-col gap-4">
<ModelStatusBadge compact />
<button
type="button"
title={currentUser ? `当前用户:${currentUser.username},点击退出` : '退出登录'}
onClick={logout}
className="group relative flex h-9 w-9 items-center justify-center rounded-lg border border-white/10 bg-white/5 text-gray-400 transition-colors hover:border-red-400/40 hover:bg-red-500/10 hover:text-red-200"
>
{currentUser ? <UserCircle size={20} /> : <LogOut size={20} />}
<span className="absolute left-full ml-2 whitespace-nowrap rounded border border-[#333] bg-[#222] px-2 py-1 text-xs text-gray-200 opacity-0 shadow-xl transition-all group-hover:opacity-100">
{currentUser ? `${currentUser.username} / 退出` : '退出登录'}
</span>
</button>
</div>
</aside>
);

View File

@@ -39,6 +39,8 @@ describe('TemplateRegistry', () => {
expect(await screen.findAllByText('腹腔镜胆囊切除术')).toHaveLength(2);
expect(screen.getByText('胆囊')).toBeInTheDocument();
expect(screen.getAllByText(/maskid: ?1/).length).toBeGreaterThan(0);
expect(screen.queryByText(/Z-Level/)).not.toBeInTheDocument();
});
it('creates a template and stores it globally', async () => {

View File

@@ -3,6 +3,7 @@ import { Settings, Database, Trash2, Edit3, Plus, Loader2, X, GripVertical, Impo
import { cn } from '../lib/utils';
import { useStore } from '../store/useStore';
import { getTemplates, createTemplate, updateTemplate, deleteTemplate } from '../lib/api';
import { nextClassMaskId, normalizeClassMaskIds } from '../lib/maskIds';
import type { Template, TemplateClass } from '../store/useStore';
import { TransientNotice, type NoticeState, type NoticeTone } from './TransientNotice';
@@ -86,7 +87,7 @@ export function TemplateRegistry() {
setSelectedTemplate(template);
setEditName(template.name);
setEditDesc(template.description || '');
setEditClasses(template.classes ? [...template.classes] : []);
setEditClasses(normalizeClassMaskIds(template.classes ? [...template.classes] : []));
setShowModal(true);
};
@@ -97,7 +98,7 @@ export function TemplateRegistry() {
const basePayload = {
name: editName.trim(),
description: editDesc.trim() || undefined,
classes: editClasses,
classes: normalizeClassMaskIds(editClasses),
rules: [],
color: selectedTemplate ? (selectedTemplate as any).color || '#06b6d4' : '#06b6d4',
z_index: selectedTemplate ? (selectedTemplate as any).z_index ?? 0 : 0,
@@ -138,6 +139,7 @@ export function TemplateRegistry() {
name: '新类别',
color: generateColor(editClasses.length, Math.max(editClasses.length + 1, 8)),
zIndex: editClasses.length > 0 ? Math.max(...editClasses.map((c) => c.zIndex)) + 10 : 10,
maskId: nextClassMaskId(editClasses),
category: '未分类',
};
setEditClasses([...editClasses, newClass]);
@@ -179,6 +181,7 @@ export function TemplateRegistry() {
return;
}
const firstMaskId = nextClassMaskId(editClasses);
const imported: TemplateClass[] = names.map((name: string, i: number) => {
const rgb = colors[i] || [100, 100, 100];
const hex = `#${rgb[0].toString(16).padStart(2, '0')}${rgb[1].toString(16).padStart(2, '0')}${rgb[2].toString(16).padStart(2, '0')}`;
@@ -187,6 +190,7 @@ export function TemplateRegistry() {
name,
color: hex,
zIndex: (names.length - i) * 10,
maskId: firstMaskId + i,
category: '批量导入',
};
});
@@ -208,6 +212,7 @@ export function TemplateRegistry() {
name,
color: hex,
zIndex: (LAPAROSCOPIC_NAMES.length - i) * 10,
maskId: i + 1,
category: '腹腔镜胆囊切除术',
};
});
@@ -308,13 +313,13 @@ export function TemplateRegistry() {
(Painter's Algorithm Weight)
</h3>
<div className="space-y-2">
{(activeTemplate.classes || []).sort((a, b) => b.zIndex - a.zIndex).map((cls) => (
{normalizeClassMaskIds(activeTemplate.classes || []).sort((a, b) => b.zIndex - a.zIndex).map((cls) => (
<div key={cls.id} className="grid grid-cols-4 gap-4 p-3 bg-[#0d0d0d] border border-white/5 rounded items-center">
<div className="col-span-1 flex items-center gap-2">
<div className="w-3 h-3 rounded" style={{ backgroundColor: cls.color }}></div>
<span className="font-medium text-sm text-gray-300">{cls.name}</span>
</div>
<div className="col-span-1 font-mono text-xs text-gray-500">优先级 Z-Level: {cls.zIndex}</div>
<div className="col-span-1 font-mono text-xs text-gray-500">maskid: {cls.maskId}</div>
<div className="col-span-2 flex justify-end">
<span className="bg-white/5 text-gray-400 text-xs px-2 py-1 rounded border border-white/10">{cls.category || ''}</span>
</div>
@@ -445,7 +450,7 @@ export function TemplateRegistry() {
>
{cls.name}
</span>
<span className="w-16 text-sm text-gray-500 font-mono text-right">z:{cls.zIndex}</span>
<span className="w-24 text-sm text-gray-500 font-mono text-right">maskid:{cls.maskId}</span>
</>
)}
<button onClick={() => removeClass(cls.id)} className="text-gray-500 hover:text-red-400 transition-colors">

View File

@@ -1,35 +1,91 @@
import { fireEvent, render, screen } from '@testing-library/react';
import { describe, expect, it, vi } from 'vitest';
import { beforeEach, describe, expect, it, vi } from 'vitest';
import { useStore } from '../store/useStore';
import { resetStore } from '../test/storeTestUtils';
import { ToolsPalette } from './ToolsPalette';
describe('ToolsPalette', () => {
it('switches tools and dispatches undo/redo actions when available', () => {
beforeEach(() => {
resetStore();
});
it('switches workspace editing tools without showing AI prompt or duplicate undo tools', () => {
const setActiveTool = vi.fn();
const onUndo = vi.fn();
const onRedo = vi.fn();
render(
<ToolsPalette
activeTool="move"
setActiveTool={setActiveTool}
onUndo={onUndo}
onRedo={onRedo}
canUndo
canRedo
/>,
);
fireEvent.click(screen.getByTitle('创建多边形 (P)'));
fireEvent.click(screen.getByTitle('调整多边形 (E)'));
fireEvent.click(screen.getByTitle('正向选点 (SAM)'));
fireEvent.click(screen.getByTitle('撤销操作 (Ctrl+Z)'));
fireEvent.click(screen.getByTitle('重做操作 (Ctrl+Shift+Z)'));
fireEvent.click(screen.getByTitle('画笔 (B)'));
fireEvent.click(screen.getByTitle('橡皮擦 (X)'));
expect(setActiveTool).toHaveBeenNthCalledWith(1, 'create_polygon');
expect(setActiveTool).toHaveBeenNthCalledWith(2, 'edit_polygon');
expect(setActiveTool).toHaveBeenNthCalledWith(3, 'point_pos');
expect(onUndo).toHaveBeenCalled();
expect(onRedo).toHaveBeenCalled();
expect(setActiveTool).toHaveBeenNthCalledWith(3, 'brush');
expect(setActiveTool).toHaveBeenNthCalledWith(4, 'eraser');
expect(screen.queryByTitle('正向选点 (SAM)')).not.toBeInTheDocument();
expect(screen.queryByTitle('反向选点 (SAM)')).not.toBeInTheDocument();
expect(screen.queryByTitle('边界框选 (SAM)')).not.toBeInTheDocument();
expect(screen.queryByTitle('撤销操作 (Ctrl+Z)')).not.toBeInTheDocument();
expect(screen.queryByTitle('重做操作 (Ctrl+Shift+Z)')).not.toBeInTheDocument();
expect(screen.queryByTitle('创建点 (C)')).not.toBeInTheDocument();
expect(screen.queryByTitle('创建线段 (L)')).not.toBeInTheDocument();
});
it('shows size controls for brush and eraser tools', () => {
const { rerender } = render(<ToolsPalette activeTool="brush" setActiveTool={vi.fn()} />);
const brushSize = screen.getByLabelText('画笔大小');
fireEvent.change(brushSize, { target: { value: '36' } });
expect(useStore.getState().brushSize).toBe(36);
rerender(<ToolsPalette activeTool="eraser" setActiveTool={vi.fn()} />);
const eraserSize = screen.getByLabelText('橡皮擦大小');
fireEvent.change(eraserSize, { target: { value: '48' } });
expect(useStore.getState().eraserSize).toBe(48);
});
it('places GT mask import after overlap removal with a distinct violet style', () => {
const onImportGtMask = vi.fn();
render(
<ToolsPalette
activeTool="move"
setActiveTool={vi.fn()}
onImportGtMask={onImportGtMask}
canImportGtMask
/>,
);
const overlapButton = screen.getByTitle('重叠区域去除 (-)');
const importButton = screen.getByTitle('导入 GT Mask');
fireEvent.click(importButton);
expect(onImportGtMask).toHaveBeenCalled();
expect(importButton).toHaveClass('bg-violet-500/10');
expect(overlapButton.compareDocumentPosition(importButton) & Node.DOCUMENT_POSITION_FOLLOWING).toBeTruthy();
});
it('separates drawing, editing, and external action tool groups', () => {
render(<ToolsPalette activeTool="move" setActiveTool={vi.fn()} canImportGtMask />);
const separators = screen.getAllByTestId('tool-group-separator');
const circleButton = screen.getByTitle('创建圆 (O)');
const brushButton = screen.getByTitle('画笔 (B)');
const removeButton = screen.getByTitle('重叠区域去除 (-)');
const importButton = screen.getByTitle('导入 GT Mask');
expect(separators).toHaveLength(2);
expect(circleButton.compareDocumentPosition(separators[0]) & Node.DOCUMENT_POSITION_FOLLOWING).toBeTruthy();
expect(separators[0].compareDocumentPosition(brushButton) & Node.DOCUMENT_POSITION_FOLLOWING).toBeTruthy();
expect(removeButton.compareDocumentPosition(separators[1]) & Node.DOCUMENT_POSITION_FOLLOWING).toBeTruthy();
expect(separators[1].compareDocumentPosition(importButton) & Node.DOCUMENT_POSITION_FOLLOWING).toBeTruthy();
separators.forEach((separator) => {
expect(separator).toHaveClass('bg-white/15');
});
});
it('switches to SAM trigger and calls the AI navigation hook', () => {
@@ -37,7 +93,9 @@ describe('ToolsPalette', () => {
const onTriggerAI = vi.fn();
render(<ToolsPalette activeTool="move" setActiveTool={setActiveTool} onTriggerAI={onTriggerAI} />);
fireEvent.click(screen.getByTitle('打开 AI 智能分割'));
const aiButton = screen.getByTitle('打开 AI 智能分割');
expect(aiButton.querySelector('[data-testid="ai-segmentation-icon"]')).toBeInTheDocument();
fireEvent.click(aiButton);
expect(setActiveTool).toHaveBeenCalledWith('sam_trigger');
expect(onTriggerAI).toHaveBeenCalled();

View File

@@ -1,44 +1,59 @@
import React from 'react';
import { MousePointer2, PencilLine, Hexagon, Square, Circle, Minus, Combine, Scissors, Wand2, Undo, Redo, Crosshair, PlusCircle, MinusCircle, SquareDashed } from 'lucide-react';
import { MousePointer2, PencilLine, Hexagon, Square, Circle, Brush, Eraser, Combine, Scissors, FileUp } from 'lucide-react';
import { cn } from '../lib/utils';
import { AiSegmentationIcon } from './AiSegmentationIcon';
import { useStore } from '../store/useStore';
interface ToolsPaletteProps {
activeTool: string;
setActiveTool: (tool: string) => void;
onTriggerAI?: () => void;
onUndo?: () => void;
onRedo?: () => void;
canUndo?: boolean;
canRedo?: boolean;
onImportGtMask?: () => void;
canImportGtMask?: boolean;
isImportingGtMask?: boolean;
}
export function ToolsPalette({
activeTool,
setActiveTool,
onTriggerAI,
onUndo,
onRedo,
canUndo = false,
canRedo = false,
onImportGtMask,
canImportGtMask = false,
isImportingGtMask = false,
}: ToolsPaletteProps) {
const brushSize = useStore((state) => state.brushSize);
const eraserSize = useStore((state) => state.eraserSize);
const setBrushSize = useStore((state) => state.setBrushSize);
const setEraserSize = useStore((state) => state.setEraserSize);
const sizeControl = activeTool === 'brush'
? {
label: '画笔大小',
value: brushSize,
min: 4,
max: 96,
onChange: setBrushSize,
}
: activeTool === 'eraser'
? {
label: '橡皮擦大小',
value: eraserSize,
min: 4,
max: 128,
onChange: setEraserSize,
}
: null;
const tools = [
{ id: 'move', icon: MousePointer2, label: '拖拽 / 选择 (V)' },
{ id: 'edit_polygon', icon: PencilLine, label: '调整多边形 (E)' },
{ id: 'create_polygon', icon: Hexagon, label: '创建多边形 (P)' },
{ id: 'create_rectangle', icon: Square, label: '创建矩形 (R)' },
{ id: 'create_circle', icon: Circle, label: '创建圆 (O)' },
{ id: 'create_point', icon: Crosshair, label: '创建点 (C)' },
{ id: 'create_line', icon: Minus, label: '创建线段 (L)' },
{ id: 'brush', icon: Brush, label: '画笔 (B)' },
{ id: 'eraser', icon: Eraser, label: '橡皮擦 (X)' },
{ id: 'area_merge', icon: Combine, label: '区域合并 (+)' },
{ id: 'area_remove', icon: Scissors, label: '重叠区域去除 (-)' },
];
const aiTools = [
{ id: 'point_pos', icon: PlusCircle, label: '正向选点 (SAM)', color: 'text-green-400', bg: 'bg-green-500/10', border: 'border-green-500/30' },
{ id: 'point_neg', icon: MinusCircle, label: '反向选点 (SAM)', color: 'text-red-400', bg: 'bg-red-500/10', border: 'border-red-500/30' },
{ id: 'box_select', icon: SquareDashed, label: '边界框选 (SAM)', color: 'text-blue-400', bg: 'bg-blue-500/10', border: 'border-blue-500/30' },
];
return (
<div className="h-full w-14 bg-[#0d0d0d] border-r border-white/5 flex flex-col items-start py-2 shrink-0 z-10 overflow-y-auto overflow-x-hidden overscroll-contain seg-scrollbar">
<div className="flex flex-col gap-1.5 w-12 shrink-0 px-1.5">
@@ -46,45 +61,52 @@ export function ToolsPalette({
const Icon = tool.icon;
const isActive = activeTool === tool.id;
return (
<button
key={tool.id}
onClick={() => setActiveTool(tool.id)}
title={tool.label}
className={cn(
"w-9 h-9 rounded-md flex items-center justify-center transition-all p-1.5",
isActive
? (tool.id.includes('remove') ? "bg-red-500/10 text-red-500"
: tool.id.includes('merge') ? "bg-green-500/10 text-green-500"
: "bg-white/10 text-white")
: "text-gray-500 hover:bg-white/5 hover:text-white"
<React.Fragment key={tool.id}>
<button
onClick={() => setActiveTool(tool.id)}
title={tool.label}
className={cn(
"w-9 h-9 rounded-md flex items-center justify-center transition-all p-1.5",
isActive
? (tool.id.includes('remove') ? "bg-red-500/10 text-red-500"
: tool.id.includes('merge') ? "bg-green-500/10 text-green-500"
: "bg-white/10 text-white")
: "text-gray-500 hover:bg-white/5 hover:text-white"
)}
>
<Icon size={16} strokeWidth={isActive ? 2.5 : 2} />
</button>
{tool.id === 'eraser' && sizeControl && (
<div className="w-9 rounded-md border border-white/10 bg-white/[0.03] px-1 py-2 text-center">
<label htmlFor={`${activeTool}-size`} className="sr-only">{sizeControl.label}</label>
<input
id={`${activeTool}-size`}
aria-label={sizeControl.label}
type="range"
min={sizeControl.min}
max={sizeControl.max}
value={sizeControl.value}
onChange={(event) => sizeControl.onChange(Number(event.target.value))}
className="h-20 w-7 accent-cyan-400 [writing-mode:vertical-rl]"
/>
<div className="mt-1 text-[10px] leading-none text-gray-400">{sizeControl.value}</div>
</div>
)}
>
<Icon size={16} strokeWidth={isActive ? 2.5 : 2} />
</button>
{(tool.id === 'create_circle' || tool.id === 'area_remove') && (
<div data-testid="tool-group-separator" className="my-1 h-px w-9 bg-white/15" />
)}
</React.Fragment>
)
})}
<div className="w-full h-px bg-white/10 my-0.5" />
{aiTools.map(tool => {
const Icon = tool.icon;
const isActive = activeTool === tool.id;
return (
<button
key={tool.id}
onClick={() => setActiveTool(tool.id)}
title={tool.label}
className={cn(
"w-9 h-9 rounded-md flex items-center justify-center transition-all p-1.5 border",
isActive
? `${tool.bg} ${tool.color} ${tool.border} shadow-[0_0_10px_rgba(255,255,255,0.05)]`
: "text-gray-500 hover:bg-white/5 hover:text-white border-transparent"
)}
>
<Icon size={16} strokeWidth={isActive ? 2.5 : 2} />
</button>
)
})}
<button
onClick={onImportGtMask}
disabled={!canImportGtMask || isImportingGtMask}
title={isImportingGtMask ? '正在导入 GT Mask' : '导入 GT Mask'}
className="w-9 h-9 rounded-md flex items-center justify-center transition-all p-1.5 border border-violet-500/30 bg-violet-500/10 text-violet-200 hover:bg-violet-500/20 hover:text-white disabled:opacity-35 disabled:hover:bg-violet-500/10 disabled:hover:text-violet-200 disabled:cursor-not-allowed"
>
<FileUp size={16} strokeWidth={2.2} />
</button>
<button
onClick={() => {
@@ -99,26 +121,7 @@ export function ToolsPalette({
: "text-gray-500 hover:bg-white/5"
)}
>
<Wand2 size={16} strokeWidth={2} />
</button>
<div className="w-full h-px bg-white/10 my-0.5" />
<button
onClick={onUndo}
disabled={!canUndo}
className="w-9 h-9 rounded text-gray-500 hover:bg-white/5 hover:text-white flex items-center justify-center transition-colors disabled:opacity-30 disabled:hover:bg-transparent disabled:hover:text-gray-500 disabled:cursor-not-allowed"
title="撤销操作 (Ctrl+Z)"
>
<Undo size={16} />
</button>
<button
onClick={onRedo}
disabled={!canRedo}
className="w-9 h-9 rounded text-gray-500 hover:bg-white/5 hover:text-white flex items-center justify-center transition-colors disabled:opacity-30 disabled:hover:bg-transparent disabled:hover:text-gray-500 disabled:cursor-not-allowed"
title="重做操作 (Ctrl+Shift+Z)"
>
<Redo size={16} />
<AiSegmentationIcon size={17} strokeWidth={2.1} />
</button>
</div>

View File

@@ -0,0 +1,154 @@
import { fireEvent, render, screen, waitFor } from '@testing-library/react';
import { beforeEach, describe, expect, it, vi } from 'vitest';
import { resetStore } from '../test/storeTestUtils';
import { useStore } from '../store/useStore';
import { UserAdmin } from './UserAdmin';
const apiMock = vi.hoisted(() => ({
getAdminUsers: vi.fn(),
getAuditLogs: vi.fn(),
createAdminUser: vi.fn(),
updateAdminUser: vi.fn(),
deleteAdminUser: vi.fn(),
resetDemoFactory: vi.fn(),
}));
vi.mock('../lib/api', () => ({
getAdminUsers: apiMock.getAdminUsers,
getAuditLogs: apiMock.getAuditLogs,
createAdminUser: apiMock.createAdminUser,
updateAdminUser: apiMock.updateAdminUser,
deleteAdminUser: apiMock.deleteAdminUser,
resetDemoFactory: apiMock.resetDemoFactory,
}));
describe('UserAdmin', () => {
beforeEach(() => {
resetStore();
vi.clearAllMocks();
useStore.setState({ currentUser: { id: 1, username: 'admin', role: 'admin' } });
apiMock.getAdminUsers.mockResolvedValue([
{ id: 1, username: 'admin', role: 'admin', is_active: 1 },
{ id: 2, username: 'doctor', role: 'annotator', is_active: 1 },
]);
apiMock.getAuditLogs.mockResolvedValue([
{
id: 1,
actor_user_id: 1,
action: 'admin.user_created',
target_type: 'user',
target_id: '2',
detail: { username: 'doctor' },
created_at: '2026-05-02T00:00:00Z',
},
]);
});
it('loads users and audit logs', async () => {
render(<UserAdmin />);
expect(await screen.findByText('doctor')).toBeInTheDocument();
expect(screen.getByText('admin.user_created')).toBeInTheDocument();
expect(screen.getByText('当前管理员admin')).toBeInTheDocument();
});
it('creates a user with role and password', async () => {
apiMock.createAdminUser.mockResolvedValueOnce({
id: 3,
username: 'nurse',
role: 'viewer',
is_active: 1,
});
render(<UserAdmin />);
await screen.findByText('doctor');
fireEvent.change(screen.getByPlaceholderText('用户名'), { target: { value: 'nurse' } });
fireEvent.change(screen.getByPlaceholderText('初始密码'), { target: { value: 'secret123' } });
fireEvent.change(screen.getAllByDisplayValue('标注员')[0], { target: { value: 'viewer' } });
fireEvent.click(screen.getByRole('button', { name: /新增用户/ }));
await waitFor(() => expect(apiMock.createAdminUser).toHaveBeenCalledWith({
username: 'nurse',
password: 'secret123',
role: 'viewer',
is_active: true,
}));
expect(await screen.findByText('用户已创建')).toBeInTheDocument();
});
it('updates role, status and password, and deletes users', async () => {
apiMock.updateAdminUser.mockResolvedValueOnce({ id: 2, username: 'doctor', role: 'viewer', is_active: 1 });
apiMock.updateAdminUser.mockResolvedValueOnce({ id: 2, username: 'doctor', role: 'viewer', is_active: 0 });
apiMock.updateAdminUser.mockResolvedValueOnce({ id: 2, username: 'doctor', role: 'viewer', is_active: 0 });
apiMock.deleteAdminUser.mockResolvedValueOnce(undefined);
vi.spyOn(window, 'prompt').mockReturnValueOnce('nextsecret');
vi.spyOn(window, 'confirm').mockReturnValueOnce(true);
render(<UserAdmin />);
await screen.findByText('doctor');
const roleSelects = screen.getAllByDisplayValue('标注员');
fireEvent.change(roleSelects[1], { target: { value: 'viewer' } });
await waitFor(() => expect(apiMock.updateAdminUser).toHaveBeenCalledWith(2, { role: 'viewer' }));
fireEvent.click(screen.getAllByRole('button', { name: '启用' })[1]);
await waitFor(() => expect(apiMock.updateAdminUser).toHaveBeenCalledWith(2, { is_active: false }));
fireEvent.click(screen.getAllByTitle('修改密码')[1]);
await waitFor(() => expect(apiMock.updateAdminUser).toHaveBeenCalledWith(2, { password: 'nextsecret' }));
fireEvent.click(screen.getAllByTitle('删除用户')[1]);
await waitFor(() => expect(apiMock.deleteAdminUser).toHaveBeenCalledWith(2));
});
it('requires two confirmations before resetting demo factory data', async () => {
apiMock.resetDemoFactory.mockResolvedValueOnce({
admin_user: { id: 1, username: 'admin', role: 'admin', is_active: 1 },
project: {
id: '8',
name: 'Data_MyVideo_1',
status: 'pending',
frames: 0,
fps: '30FPS',
video_path: 'uploads/8/Data_MyVideo_1.mp4',
},
deleted_counts: { users: 1 },
message: '演示环境已恢复出厂设置',
});
apiMock.getAuditLogs.mockResolvedValueOnce([
{
id: 2,
actor_user_id: 1,
action: 'admin.demo_factory_reset',
target_type: 'project',
target_id: '8',
detail: {},
created_at: '2026-05-02T00:00:00Z',
},
]);
vi.spyOn(window, 'confirm').mockReturnValueOnce(true);
vi.spyOn(window, 'prompt').mockReturnValueOnce('RESET_DEMO_FACTORY');
render(<UserAdmin />);
await screen.findByText('doctor');
fireEvent.click(screen.getByRole('button', { name: '恢复演示出厂设置' }));
await waitFor(() => expect(apiMock.resetDemoFactory).toHaveBeenCalledWith('RESET_DEMO_FACTORY'));
expect(await screen.findByText('演示环境已恢复出厂设置')).toBeInTheDocument();
expect(useStore.getState().projects).toEqual([expect.objectContaining({ name: 'Data_MyVideo_1' })]);
expect(useStore.getState().frames).toEqual([]);
expect(useStore.getState().masks).toEqual([]);
});
it('does not reset demo data when confirmation text does not match', async () => {
vi.spyOn(window, 'confirm').mockReturnValueOnce(true);
vi.spyOn(window, 'prompt').mockReturnValueOnce('wrong');
render(<UserAdmin />);
await screen.findByText('doctor');
fireEvent.click(screen.getByRole('button', { name: '恢复演示出厂设置' }));
expect(apiMock.resetDemoFactory).not.toHaveBeenCalled();
expect(await screen.findByText('确认文本不匹配,未执行恢复出厂设置')).toBeInTheDocument();
});
});

View File

@@ -0,0 +1,352 @@
import React, { useEffect, useMemo, useState } from 'react';
import { KeyRound, Loader2, Plus, ShieldCheck, Trash2, UserCog } from 'lucide-react';
import {
createAdminUser,
deleteAdminUser,
getAdminUsers,
getAuditLogs,
resetDemoFactory,
updateAdminUser,
type AdminUser,
type AuditLog,
} from '../lib/api';
import { cn } from '../lib/utils';
import { useStore } from '../store/useStore';
import { TransientNotice, type NoticeState, type NoticeTone } from './TransientNotice';
const roleLabels: Record<string, string> = {
admin: '管理员',
annotator: '标注员',
viewer: '观察员',
};
function formatTime(value: string): string {
const date = new Date(value);
if (Number.isNaN(date.getTime())) return value;
return date.toLocaleString('zh-CN', {
month: '2-digit',
day: '2-digit',
hour: '2-digit',
minute: '2-digit',
});
}
export function UserAdmin() {
const currentUser = useStore((state) => state.currentUser);
const setProjects = useStore((state) => state.setProjects);
const setCurrentProject = useStore((state) => state.setCurrentProject);
const setFrames = useStore((state) => state.setFrames);
const setCurrentFrame = useStore((state) => state.setCurrentFrame);
const setMasks = useStore((state) => state.setMasks);
const setSelectedMaskIds = useStore((state) => state.setSelectedMaskIds);
const [users, setUsers] = useState<AdminUser[]>([]);
const [auditLogs, setAuditLogs] = useState<AuditLog[]>([]);
const [isLoading, setIsLoading] = useState(true);
const [isSaving, setIsSaving] = useState(false);
const [isResetting, setIsResetting] = useState(false);
const [notice, setNotice] = useState<NoticeState | null>(null);
const [newUsername, setNewUsername] = useState('');
const [newPassword, setNewPassword] = useState('');
const [newRole, setNewRole] = useState('annotator');
const activeCount = useMemo(() => users.filter((user) => user.is_active).length, [users]);
const showNotice = (message: string, tone: NoticeTone = 'info') => {
setNotice({ id: Date.now(), message, tone });
};
const loadAdminData = async () => {
setIsLoading(true);
try {
const [nextUsers, nextLogs] = await Promise.all([getAdminUsers(), getAuditLogs(100)]);
setUsers(nextUsers);
setAuditLogs(nextLogs);
} catch (err) {
console.error('Failed to load admin data:', err);
showNotice('用户管理数据加载失败', 'error');
} finally {
setIsLoading(false);
}
};
useEffect(() => {
void loadAdminData();
}, []);
const handleCreateUser = async (event: React.FormEvent) => {
event.preventDefault();
if (!newUsername.trim() || newPassword.length < 6) {
showNotice('请输入用户名,并设置至少 6 位密码', 'error');
return;
}
setIsSaving(true);
try {
const created = await createAdminUser({
username: newUsername.trim(),
password: newPassword,
role: newRole,
is_active: true,
});
setUsers((prev) => [...prev, created]);
setNewUsername('');
setNewPassword('');
setNewRole('annotator');
showNotice('用户已创建', 'success');
setAuditLogs(await getAuditLogs(100));
} catch (err: any) {
showNotice(err?.response?.data?.detail || '创建用户失败', 'error');
} finally {
setIsSaving(false);
}
};
const handlePatchUser = async (user: AdminUser, patch: Parameters<typeof updateAdminUser>[1]) => {
setIsSaving(true);
try {
const updated = await updateAdminUser(user.id, patch);
setUsers((prev) => prev.map((item) => (item.id === user.id ? updated : item)));
showNotice('用户已更新', 'success');
setAuditLogs(await getAuditLogs(100));
} catch (err: any) {
showNotice(err?.response?.data?.detail || '更新用户失败', 'error');
} finally {
setIsSaving(false);
}
};
const handleChangePassword = async (user: AdminUser) => {
const password = window.prompt(`${user.username} 设置新密码(至少 6 位)`);
if (password === null) return;
await handlePatchUser(user, { password });
};
const handleDeleteUser = async (user: AdminUser) => {
if (!window.confirm(`确定删除用户 ${user.username} 吗?已有项目的用户建议先停用。`)) return;
setIsSaving(true);
try {
await deleteAdminUser(user.id);
setUsers((prev) => prev.filter((item) => item.id !== user.id));
showNotice('用户已删除', 'success');
setAuditLogs(await getAuditLogs(100));
} catch (err: any) {
showNotice(err?.response?.data?.detail || '删除用户失败', 'error');
} finally {
setIsSaving(false);
}
};
const handleFactoryReset = async () => {
const firstConfirmed = window.confirm(
'恢复演示出厂设置会删除除默认 admin 外的所有用户、项目帧、标注、任务和私有模板,只保留一个未生成帧的演示视频项目。确定继续吗?',
);
if (!firstConfirmed) return;
const typed = window.prompt('请输入 RESET_DEMO_FACTORY 以确认恢复演示出厂设置');
if (typed === null) return;
if (typed !== 'RESET_DEMO_FACTORY') {
showNotice('确认文本不匹配,未执行恢复出厂设置', 'error');
return;
}
setIsResetting(true);
try {
const result = await resetDemoFactory(typed);
setUsers([result.admin_user]);
setProjects([result.project]);
setCurrentProject(null);
setFrames([]);
setCurrentFrame(0);
setMasks([]);
setSelectedMaskIds([]);
setAuditLogs(await getAuditLogs(100));
showNotice(result.message || '演示环境已恢复出厂设置', 'success');
} catch (err: any) {
showNotice(err?.response?.data?.detail || '恢复演示出厂设置失败', 'error');
} finally {
setIsResetting(false);
}
};
return (
<div className="flex h-full flex-col overflow-hidden bg-[#0a0a0a] text-gray-200">
<TransientNotice notice={notice} onDismiss={() => setNotice(null)} />
<header className="border-b border-white/10 bg-[#0d0d0d] px-6 py-4">
<div className="flex items-center justify-between gap-4">
<div>
<h1 className="text-xl font-semibold text-white"></h1>
<p className="mt-1 text-xs text-gray-500"></p>
</div>
<div className="flex items-center gap-3 text-xs text-gray-400">
<span className="rounded border border-cyan-400/20 bg-cyan-400/10 px-3 py-1 text-cyan-100">
{currentUser?.username || 'admin'}
</span>
<span className="rounded border border-white/10 bg-white/5 px-3 py-1"> {activeCount}</span>
</div>
</div>
</header>
<main className="grid min-h-0 flex-1 grid-cols-[minmax(0,1.15fr)_minmax(360px,0.85fr)] gap-4 overflow-hidden p-4">
<section className="flex min-h-0 flex-col overflow-hidden rounded-lg border border-white/10 bg-[#111]">
<div className="flex items-center justify-between border-b border-white/10 px-4 py-3">
<div className="flex items-center gap-2 text-sm font-medium text-white">
<UserCog size={18} className="text-cyan-300" />
</div>
{isLoading && <Loader2 size={16} className="animate-spin text-cyan-300" />}
</div>
<form onSubmit={handleCreateUser} className="grid grid-cols-[1fr_1fr_150px_auto] gap-2 border-b border-white/10 p-4">
<input
value={newUsername}
onChange={(event) => setNewUsername(event.target.value)}
placeholder="用户名"
autoComplete="off"
className="rounded border border-white/10 bg-[#181818] px-3 py-2 text-sm text-white outline-none focus:border-cyan-400/50"
/>
<input
value={newPassword}
type="password"
onChange={(event) => setNewPassword(event.target.value)}
placeholder="初始密码"
autoComplete="new-password"
className="rounded border border-white/10 bg-[#181818] px-3 py-2 text-sm text-white outline-none focus:border-cyan-400/50"
/>
<select
value={newRole}
onChange={(event) => setNewRole(event.target.value)}
className="rounded border border-white/10 bg-[#181818] px-3 py-2 text-sm text-white outline-none focus:border-cyan-400/50"
>
<option value="annotator"></option>
<option value="viewer"></option>
<option value="admin"></option>
</select>
<button
type="submit"
disabled={isSaving}
className="inline-flex items-center gap-2 rounded bg-cyan-500 px-4 py-2 text-sm font-semibold text-black transition-colors hover:bg-cyan-400 disabled:opacity-50"
>
<Plus size={16} />
</button>
</form>
<div className="min-h-0 flex-1 overflow-auto">
<table className="w-full text-left text-sm">
<thead className="sticky top-0 bg-[#151515] text-xs uppercase text-gray-500">
<tr>
<th className="px-4 py-3"></th>
<th className="px-4 py-3"></th>
<th className="px-4 py-3"></th>
<th className="px-4 py-3 text-right"></th>
</tr>
</thead>
<tbody className="divide-y divide-white/5">
{users.map((user) => (
<tr key={user.id} className="hover:bg-white/[0.03]">
<td className="px-4 py-3">
<div className="font-medium text-white">{user.username}</div>
<div className="text-xs text-gray-500">ID {user.id}</div>
</td>
<td className="px-4 py-3">
<select
value={user.role}
onChange={(event) => void handlePatchUser(user, { role: event.target.value })}
disabled={isSaving}
className="rounded border border-white/10 bg-[#181818] px-2 py-1 text-xs text-cyan-100"
>
<option value="admin"></option>
<option value="annotator"></option>
<option value="viewer"></option>
</select>
</td>
<td className="px-4 py-3">
<button
type="button"
onClick={() => void handlePatchUser(user, { is_active: !user.is_active })}
disabled={isSaving}
className={cn(
'rounded-full border px-3 py-1 text-xs',
user.is_active
? 'border-emerald-400/30 bg-emerald-400/10 text-emerald-200'
: 'border-gray-500/30 bg-gray-500/10 text-gray-300',
)}
>
{user.is_active ? '启用' : '停用'}
</button>
</td>
<td className="px-4 py-3">
<div className="flex justify-end gap-2">
<button
type="button"
onClick={() => void handleChangePassword(user)}
className="rounded border border-white/10 p-2 text-gray-300 hover:border-cyan-400/40 hover:text-cyan-200"
title="修改密码"
>
<KeyRound size={15} />
</button>
<button
type="button"
onClick={() => void handleDeleteUser(user)}
disabled={user.id === currentUser?.id}
className="rounded border border-white/10 p-2 text-gray-300 hover:border-red-400/40 hover:text-red-200 disabled:cursor-not-allowed disabled:opacity-40"
title="删除用户"
>
<Trash2 size={15} />
</button>
</div>
</td>
</tr>
))}
</tbody>
</table>
</div>
</section>
<section className="flex min-h-0 flex-col overflow-hidden rounded-lg border border-white/10 bg-[#111]">
<div className="flex items-center gap-2 border-b border-white/10 px-4 py-3 text-sm font-medium text-white">
<ShieldCheck size={18} className="text-emerald-300" />
</div>
<div className="min-h-0 flex-1 overflow-auto p-3">
<div className="space-y-2">
{auditLogs.map((log) => (
<div key={log.id} className="rounded border border-white/10 bg-black/20 p-3">
<div className="flex items-center justify-between gap-2">
<span className="text-xs font-medium text-cyan-100">{log.action}</span>
<span className="text-[10px] text-gray-500">{formatTime(log.created_at)}</span>
</div>
<div className="mt-1 text-[11px] text-gray-400">
actor #{log.actor_user_id ?? 'system'} {'->'} {log.target_type || 'target'} #{log.target_id || '-'}
</div>
{log.detail && Object.keys(log.detail).length > 0 && (
<pre className="mt-2 max-h-24 overflow-auto rounded bg-black/30 p-2 text-[10px] leading-relaxed text-gray-500">
{JSON.stringify(log.detail, null, 2)}
</pre>
)}
</div>
))}
{!auditLogs.length && !isLoading && (
<div className="py-10 text-center text-sm text-gray-500"></div>
)}
</div>
</div>
<div className="border-t border-red-400/20 bg-red-950/10 p-4">
<div className="flex items-start justify-between gap-3">
<div>
<div className="text-sm font-semibold text-red-100"></div>
<p className="mt-1 text-xs leading-relaxed text-red-200/70">
admin
</p>
</div>
<button
type="button"
onClick={() => void handleFactoryReset()}
disabled={isResetting || isSaving}
className="shrink-0 rounded border border-red-400/40 bg-red-500/15 px-3 py-2 text-xs font-semibold text-red-100 transition-colors hover:bg-red-500/25 disabled:cursor-wait disabled:opacity-50"
>
{isResetting ? '恢复中...' : '恢复演示出厂设置'}
</button>
</div>
</div>
</section>
</main>
</div>
);
}

View File

@@ -18,6 +18,7 @@ const apiMock = vi.hoisted(() => ({
deleteAnnotation: vi.fn(),
exportCoco: vi.fn(),
exportMasks: vi.fn(),
exportSegmentationResults: vi.fn(),
importGtMask: vi.fn(),
annotationToMask: vi.fn(),
buildAnnotationPayload: vi.fn(),
@@ -39,6 +40,7 @@ vi.mock('../lib/api', () => ({
deleteAnnotation: apiMock.deleteAnnotation,
exportCoco: apiMock.exportCoco,
exportMasks: apiMock.exportMasks,
exportSegmentationResults: apiMock.exportSegmentationResults,
importGtMask: apiMock.importGtMask,
annotationToMask: apiMock.annotationToMask,
buildAnnotationPayload: apiMock.buildAnnotationPayload,
@@ -121,11 +123,16 @@ describe('VideoWorkspace', () => {
});
render(<VideoWorkspace />);
fireEvent.click(screen.getByRole('button', { name: '撤销操作' }));
const undoButton = screen.getByRole('button', { name: '撤销操作' });
const redoButton = screen.getByRole('button', { name: '重做操作' });
expect(undoButton.querySelector('svg')).toHaveClass('text-amber-300');
expect(redoButton.querySelector('svg')).toHaveClass('text-indigo-300');
fireEvent.click(undoButton);
expect(useStore.getState().masks).toEqual([]);
fireEvent.click(screen.getByRole('button', { name: '重做操作' }));
fireEvent.click(redoButton);
expect(useStore.getState().masks).toEqual([mask]);
fireEvent.keyDown(window, { key: 'z', ctrlKey: true });
@@ -147,7 +154,7 @@ describe('VideoWorkspace', () => {
await waitFor(() => expect(useStore.getState().frames).toHaveLength(1));
vi.useFakeTimers();
fireEvent.click(screen.getByRole('button', { name: '结构化归档保存' }));
fireEvent.click(screen.getByRole('button', { name: '已全部保存' }));
expect(screen.getByText('没有待保存标注')).toBeInTheDocument();
act(() => {
@@ -155,7 +162,7 @@ describe('VideoWorkspace', () => {
});
expect(screen.queryByText('没有待保存标注')).not.toBeInTheDocument();
expect(screen.getByRole('button', { name: '结构化归档保存' })).not.toBeDisabled();
expect(screen.getByRole('button', { name: '已全部保存' })).not.toBeDisabled();
vi.useRealTimers();
});
@@ -305,7 +312,8 @@ describe('VideoWorkspace', () => {
});
});
fireEvent.click(screen.getByRole('button', { name: '结构化归档保存' }));
expect(screen.getByRole('button', { name: '保存 1 个改动' })).toBeInTheDocument();
fireEvent.click(screen.getByRole('button', { name: '保存 1 个改动' }));
await waitFor(() => expect(apiMock.saveAnnotation).toHaveBeenCalledWith({
project_id: 1,
@@ -322,6 +330,7 @@ describe('VideoWorkspace', () => {
expect.objectContaining({ id: 'annotation-5', saved: true, saveStatus: 'saved' }),
]));
expect(useStore.getState().masks.some((mask) => mask.id === 'mask-1')).toBe(false);
expect(screen.getByRole('button', { name: '已全部保存' })).toBeInTheDocument();
});
it('updates dirty saved masks through the archive button', async () => {
@@ -360,7 +369,8 @@ describe('VideoWorkspace', () => {
});
});
fireEvent.click(screen.getByRole('button', { name: '结构化归档保存' }));
expect(screen.getByRole('button', { name: '保存 1 个改动' })).toBeInTheDocument();
fireEvent.click(screen.getByRole('button', { name: '保存 1 个改动' }));
await waitFor(() => expect(apiMock.updateAnnotation).toHaveBeenCalledWith('99', {
template_id: 2,
@@ -415,6 +425,7 @@ describe('VideoWorkspace', () => {
});
it('clears masks across the selected frame range', async () => {
const confirmSpy = vi.spyOn(window, 'confirm').mockReturnValue(true);
apiMock.getProjectFrames.mockResolvedValueOnce([
{ id: 10, project_id: 1, frame_index: 0, image_url: '/frame-0.jpg', width: 640, height: 360 },
{ id: 11, project_id: 1, frame_index: 1, image_url: '/frame-1.jpg', width: 640, height: 360 },
@@ -436,7 +447,9 @@ describe('VideoWorkspace', () => {
});
fireEvent.click(screen.getByRole('button', { name: '清空片段遮罩' }));
expect(screen.getByText('请在播放进度条或视频处理进度条上点击/拖拽选择清空起止帧,再点击“确认清空”')).toBeInTheDocument();
expect(screen.getByText('请选择清空模式,并在播放进度条或视频处理进度条上点击/拖拽选择清空起止帧,再点击“确认清空”')).toBeInTheDocument();
expect(screen.getByRole('button', { name: '清空全部' })).toHaveAttribute('aria-pressed', 'true');
expect(screen.getByRole('button', { name: '保留人工/AI' })).toBeInTheDocument();
const processingBar = screen.getByLabelText('视频处理进度条');
vi.spyOn(processingBar, 'getBoundingClientRect').mockReturnValue({
@@ -458,20 +471,129 @@ describe('VideoWorkspace', () => {
fireEvent.click(screen.getByRole('button', { name: '确认清空' }));
expect(confirmSpy).toHaveBeenCalledWith(expect.stringContaining('是否清除“人工/AI标注帧”'));
await waitFor(() => expect(apiMock.deleteAnnotation).toHaveBeenCalledWith('99'));
expect(apiMock.deleteAnnotation).not.toHaveBeenCalledWith('100');
expect(useStore.getState().masks.map((mask) => mask.id)).toEqual(['annotation-100']);
expect(useStore.getState().selectedMaskIds).not.toContain('draft-1');
expect(screen.getByText('已清空第 1-2 帧的 2 个遮罩,其中后端标注 1 个')).toBeInTheDocument();
confirmSpy.mockRestore();
});
it('auto-saves pending masks before exporting COCO', async () => {
it('can clear only propagated masks while preserving manual or AI annotated frames', async () => {
const confirmSpy = vi.spyOn(window, 'confirm').mockReturnValue(true);
apiMock.getProjectFrames.mockResolvedValueOnce([
{ id: 10, project_id: 1, frame_index: 0, image_url: '/frame-0.jpg', width: 640, height: 360 },
{ id: 11, project_id: 1, frame_index: 1, image_url: '/frame-1.jpg', width: 640, height: 360 },
]);
apiMock.deleteAnnotation.mockResolvedValue(undefined);
render(<VideoWorkspace />);
await waitFor(() => expect(useStore.getState().frames).toHaveLength(2));
act(() => {
useStore.setState({
masks: [
{ id: 'manual-1', annotationId: '98', frameId: '10', pathData: 'M 0 0 Z', label: 'Manual', color: '#ef4444', saved: true, saveStatus: 'saved' },
{
id: 'propagated-1',
annotationId: '99',
frameId: '11',
pathData: 'M 1 1 Z',
label: 'Tracked',
color: '#3b82f6',
saved: true,
saveStatus: 'saved',
metadata: { source_annotation_id: 7, source_mask_id: 'annotation-7' },
},
],
selectedMaskIds: ['manual-1', 'propagated-1'],
});
});
fireEvent.click(screen.getByRole('button', { name: '清空片段遮罩' }));
fireEvent.click(screen.getByRole('button', { name: '保留人工/AI' }));
expect(screen.getByRole('button', { name: '保留人工/AI' })).toHaveAttribute('aria-pressed', 'true');
fireEvent.click(screen.getByRole('button', { name: '确认清空' }));
expect(confirmSpy).not.toHaveBeenCalled();
await waitFor(() => expect(apiMock.deleteAnnotation).toHaveBeenCalledWith('99'));
expect(apiMock.deleteAnnotation).not.toHaveBeenCalledWith('98');
expect(useStore.getState().masks.map((mask) => mask.id)).toEqual(['manual-1']);
expect(useStore.getState().selectedMaskIds).toEqual(['manual-1']);
expect(screen.getByText('已清空第 1-2 帧的 1 个自动传播遮罩,其中后端标注 1 个,人工/AI 标注帧已保留')).toBeInTheDocument();
confirmSpy.mockRestore();
});
it('cancels range clearing when manual or AI annotated frames are not confirmed', async () => {
const confirmSpy = vi.spyOn(window, 'confirm').mockReturnValue(false);
apiMock.getProjectFrames.mockResolvedValueOnce([
{ id: 10, project_id: 1, frame_index: 0, image_url: '/frame-0.jpg', width: 640, height: 360 },
{ id: 11, project_id: 1, frame_index: 1, image_url: '/frame-1.jpg', width: 640, height: 360 },
]);
render(<VideoWorkspace />);
await waitFor(() => expect(useStore.getState().frames).toHaveLength(2));
act(() => {
useStore.setState({
masks: [
{ id: 'annotation-99', annotationId: '99', frameId: '10', pathData: 'M 0 0 Z', label: 'Manual', color: '#06b6d4', saved: true, saveStatus: 'saved' },
],
});
});
fireEvent.click(screen.getByRole('button', { name: '清空片段遮罩' }));
fireEvent.click(screen.getByRole('button', { name: '确认清空' }));
expect(confirmSpy).toHaveBeenCalledWith(expect.stringContaining('是否清除“人工/AI标注帧”'));
expect(apiMock.deleteAnnotation).not.toHaveBeenCalled();
expect(useStore.getState().masks.map((mask) => mask.id)).toEqual(['annotation-99']);
expect(screen.getByText('已取消清空片段遮罩')).toBeInTheDocument();
confirmSpy.mockRestore();
});
it('does not ask for manual-frame confirmation when clearing propagated-only frames', async () => {
const confirmSpy = vi.spyOn(window, 'confirm').mockReturnValue(true);
apiMock.getProjectFrames.mockResolvedValueOnce([
{ id: 10, project_id: 1, frame_index: 0, image_url: '/frame-0.jpg', width: 640, height: 360 },
{ id: 11, project_id: 1, frame_index: 1, image_url: '/frame-1.jpg', width: 640, height: 360 },
]);
apiMock.deleteAnnotation.mockResolvedValue(undefined);
render(<VideoWorkspace />);
await waitFor(() => expect(useStore.getState().frames).toHaveLength(2));
act(() => {
useStore.setState({
masks: [
{
id: 'annotation-99',
annotationId: '99',
frameId: '10',
pathData: 'M 0 0 Z',
label: 'Propagated',
color: '#06b6d4',
saved: true,
saveStatus: 'saved',
metadata: { source: 'sam2_propagation', propagated_from_frame_id: 1 },
},
],
});
});
fireEvent.click(screen.getByRole('button', { name: '清空片段遮罩' }));
fireEvent.click(screen.getByRole('button', { name: '确认清空' }));
expect(confirmSpy).not.toHaveBeenCalled();
await waitFor(() => expect(apiMock.deleteAnnotation).toHaveBeenCalledWith('99'));
confirmSpy.mockRestore();
});
it('auto-saves pending masks before exporting segmentation results', async () => {
apiMock.getProjectFrames.mockResolvedValueOnce([
{ id: 10, project_id: 1, frame_index: 0, image_url: '/frame.jpg', width: 640, height: 360 },
]);
apiMock.buildAnnotationPayload.mockReturnValueOnce({ project_id: 1, frame_id: 10, mask_data: { polygons: [] } });
apiMock.saveAnnotation.mockResolvedValueOnce({ id: 5 });
apiMock.exportCoco.mockResolvedValueOnce(new Blob(['{}'], { type: 'application/json' }));
apiMock.exportSegmentationResults.mockResolvedValueOnce(new Blob(['zip'], { type: 'application/zip' }));
render(<VideoWorkspace />);
await waitFor(() => expect(useStore.getState().frames).toHaveLength(1));
@@ -488,39 +610,167 @@ describe('VideoWorkspace', () => {
});
});
fireEvent.click(screen.getByRole('button', { name: '导出 JSON 标注集' }));
fireEvent.click(screen.getByRole('button', { name: '分割结果导出' }));
fireEvent.change(screen.getByLabelText('Mix_label 遮罩透明度'), { target: { value: '0.45' } });
fireEvent.click(screen.getByRole('button', { name: '开始导出' }));
await waitFor(() => expect(apiMock.saveAnnotation).toHaveBeenCalled());
expect(apiMock.exportCoco).toHaveBeenCalledWith('1');
expect(apiMock.exportSegmentationResults).toHaveBeenCalledWith('1', {
scope: 'current',
outputs: ['separate', 'gt_label', 'pro_label', 'mix_label'],
mixOpacity: 0.45,
startFrame: undefined,
endFrame: undefined,
frameId: '10',
});
});
it('auto-saves pending masks before exporting PNG masks', async () => {
it('exports a selected frame range with GT label masks', async () => {
apiMock.getProjectFrames.mockResolvedValueOnce([
{ id: 10, project_id: 1, frame_index: 0, image_url: '/frame.jpg', width: 640, height: 360 },
{ id: 10, project_id: 1, frame_index: 0, image_url: '/frame-0.jpg', width: 640, height: 360, timestamp_ms: 0, source_frame_number: 0 },
{ id: 11, project_id: 1, frame_index: 1, image_url: '/frame-1.jpg', width: 640, height: 360, timestamp_ms: 500, source_frame_number: 15 },
{ id: 12, project_id: 1, frame_index: 2, image_url: '/frame-2.jpg', width: 640, height: 360, timestamp_ms: 1000, source_frame_number: 30 },
]);
apiMock.buildAnnotationPayload.mockReturnValueOnce({ project_id: 1, frame_id: 10, mask_data: { polygons: [] } });
apiMock.saveAnnotation.mockResolvedValueOnce({ id: 5 });
apiMock.exportMasks.mockResolvedValueOnce(new Blob(['zip'], { type: 'application/zip' }));
apiMock.exportSegmentationResults.mockResolvedValueOnce(new Blob(['zip'], { type: 'application/zip' }));
const downloads: string[] = [];
const clickSpy = vi.spyOn(HTMLAnchorElement.prototype, 'click').mockImplementation(function mockClick(this: HTMLAnchorElement) {
downloads.push(this.download);
});
useStore.setState({ currentProject: { id: '1', name: '病例 A/1', status: 'ready', video_path: 'uploads/demo.mp4' } });
render(<VideoWorkspace />);
await waitFor(() => expect(useStore.getState().frames).toHaveLength(1));
await waitFor(() => expect(useStore.getState().frames).toHaveLength(3));
fireEvent.click(screen.getByRole('button', { name: '分割结果导出' }));
fireEvent.click(screen.getByRole('button', { name: '特定范围帧' }));
fireEvent.change(screen.getByLabelText('导出起始帧'), { target: { value: '2' } });
fireEvent.change(screen.getByLabelText('导出结束帧'), { target: { value: '3' } });
fireEvent.click(screen.getByRole('button', { name: '分开 Mask' }));
fireEvent.click(screen.getByRole('button', { name: 'Pro_label 彩色' }));
fireEvent.click(screen.getByRole('button', { name: 'Mix_label 叠加' }));
fireEvent.click(screen.getByRole('button', { name: '开始导出' }));
await waitFor(() => expect(apiMock.exportSegmentationResults).toHaveBeenCalledWith('1', {
scope: 'range',
outputs: ['gt_label'],
mixOpacity: 0.3,
startFrame: 2,
endFrame: 3,
frameId: undefined,
}));
expect(downloads[0]).toBe('病例_A_1_seg_T_0h00m00s500ms-0h00m01s000ms_P_2-3.zip');
clickSpy.mockRestore();
});
it('lets the timeline range picker update selected frame export bounds', async () => {
apiMock.getProjectFrames.mockResolvedValueOnce([
{ id: 10, project_id: 1, frame_index: 0, image_url: '/frame-0.jpg', width: 640, height: 360 },
{ id: 11, project_id: 1, frame_index: 1, image_url: '/frame-1.jpg', width: 640, height: 360 },
{ id: 12, project_id: 1, frame_index: 2, image_url: '/frame-2.jpg', width: 640, height: 360 },
{ id: 13, project_id: 1, frame_index: 3, image_url: '/frame-3.jpg', width: 640, height: 360 },
{ id: 14, project_id: 1, frame_index: 4, image_url: '/frame-4.jpg', width: 640, height: 360 },
]);
apiMock.exportSegmentationResults.mockResolvedValueOnce(new Blob(['zip'], { type: 'application/zip' }));
render(<VideoWorkspace />);
await waitFor(() => expect(useStore.getState().frames).toHaveLength(5));
fireEvent.click(screen.getByRole('button', { name: '分割结果导出' }));
fireEvent.click(screen.getByRole('button', { name: '特定范围帧' }));
expect(screen.getByText('请在播放进度条或视频处理进度条上点击/拖拽选择导出起止帧,也可直接修改导出范围')).toBeInTheDocument();
const processingBar = screen.getByLabelText('视频处理进度条');
vi.spyOn(processingBar, 'getBoundingClientRect').mockReturnValue({
left: 0,
right: 100,
top: 0,
bottom: 10,
width: 100,
height: 10,
x: 0,
y: 0,
toJSON: () => ({}),
});
fireEvent.pointerDown(processingBar, { clientX: 25, pointerId: 1 });
fireEvent.pointerMove(processingBar, { clientX: 100, pointerId: 1 });
fireEvent.pointerUp(processingBar, { clientX: 100, pointerId: 1 });
expect(screen.getByLabelText('导出起始帧')).toHaveValue(2);
expect(screen.getByLabelText('导出结束帧')).toHaveValue(5);
fireEvent.click(screen.getByRole('button', { name: '开始导出' }));
await waitFor(() => expect(apiMock.exportSegmentationResults).toHaveBeenCalledWith('1', {
scope: 'range',
outputs: ['separate', 'gt_label', 'pro_label', 'mix_label'],
mixOpacity: 0.3,
startFrame: 2,
endFrame: 5,
frameId: undefined,
}));
});
it('switches from export range selection to propagation range selection without starting propagation immediately', async () => {
apiMock.getProjectFrames.mockResolvedValueOnce([
{ id: 10, project_id: 1, frame_index: 0, image_url: '/frame-0.jpg', width: 640, height: 360 },
{ id: 11, project_id: 1, frame_index: 1, image_url: '/frame-1.jpg', width: 640, height: 360 },
{ id: 12, project_id: 1, frame_index: 2, image_url: '/frame-2.jpg', width: 640, height: 360 },
]);
render(<VideoWorkspace />);
await waitFor(() => expect(useStore.getState().frames).toHaveLength(3));
act(() => {
useStore.setState({
masks: [{
id: 'mask-1',
id: 'annotation-8',
annotationId: '8',
frameId: '10',
pathData: 'M 0 0 Z',
label: 'AI Mask',
color: '#06b6d4',
segmentation: [[0, 0, 10, 0, 10, 10]],
label: '胆囊',
color: '#ff0000',
segmentation: [[64, 36, 192, 36, 192, 108]],
bbox: [64, 36, 128, 72],
saveStatus: 'saved',
}],
});
});
fireEvent.click(screen.getByRole('button', { name: '导出 PNG Mask ZIP' }));
fireEvent.click(screen.getByRole('button', { name: '分割结果导出' }));
fireEvent.click(screen.getByRole('button', { name: '特定范围帧' }));
fireEvent.click(screen.getByRole('button', { name: '自动传播' }));
await waitFor(() => expect(apiMock.saveAnnotation).toHaveBeenCalled());
expect(apiMock.exportMasks).toHaveBeenCalledWith('1');
expect(screen.getByText('请在播放进度条或视频处理进度条上点击/拖拽选择传播起止帧,再点击“开始传播”')).toBeInTheDocument();
expect(apiMock.queuePropagationTask).not.toHaveBeenCalled();
});
it('exports only the current frame when current image scope is selected', async () => {
apiMock.getProjectFrames.mockResolvedValueOnce([
{ id: 10, project_id: 1, frame_index: 0, image_url: '/frame-0.jpg', width: 640, height: 360 },
{ id: 11, project_id: 1, frame_index: 1, image_url: '/frame-1.jpg', width: 640, height: 360 },
]);
apiMock.exportSegmentationResults.mockResolvedValueOnce(new Blob(['zip'], { type: 'application/zip' }));
render(<VideoWorkspace />);
await waitFor(() => expect(useStore.getState().frames).toHaveLength(2));
act(() => {
useStore.setState({ currentFrameIndex: 1 });
});
fireEvent.click(screen.getByRole('button', { name: '分割结果导出' }));
fireEvent.click(screen.getByRole('button', { name: '当前图片' }));
fireEvent.click(screen.getByRole('button', { name: 'GT_label 黑白' }));
fireEvent.click(screen.getByRole('button', { name: 'Pro_label 彩色' }));
fireEvent.click(screen.getByRole('button', { name: 'Mix_label 叠加' }));
fireEvent.click(screen.getByRole('button', { name: '开始导出' }));
await waitFor(() => expect(apiMock.exportSegmentationResults).toHaveBeenCalledWith('1', {
scope: 'current',
outputs: ['separate'],
mixOpacity: 0.3,
startFrame: undefined,
endFrame: undefined,
frameId: '11',
}));
});
it('imports a GT mask for the current frame and hydrates saved annotations', async () => {
@@ -547,13 +797,41 @@ describe('VideoWorkspace', () => {
const fileInput = document.querySelector('input[type="file"]') as HTMLInputElement;
const file = new File(['mask'], 'mask.png', { type: 'image/png' });
fireEvent.change(fileInput, { target: { files: [file] } });
expect(screen.getByText('导入结果预览')).toBeInTheDocument();
await waitFor(() => expect(screen.getByRole('button', { name: '导入为未定义' })).not.toBeDisabled());
fireEvent.click(screen.getByRole('button', { name: '导入为未定义' }));
await waitFor(() => expect(apiMock.importGtMask).toHaveBeenCalledWith(file, '1', '10'));
await waitFor(() => expect(apiMock.importGtMask).toHaveBeenCalledWith(file, '1', '10', null, {
unknownColorPolicy: 'undefined',
}));
await waitFor(() => expect(useStore.getState().masks).toEqual([
expect.objectContaining({ id: 'annotation-88', label: 'GT Mask' }),
]));
});
it('lets users discard unknown GT mask classes before importing', async () => {
apiMock.getProjectFrames.mockResolvedValueOnce([
{ id: 10, project_id: 1, frame_index: 0, image_url: '/frame.jpg', width: 640, height: 360 },
]);
apiMock.importGtMask.mockResolvedValueOnce([]);
apiMock.getProjectAnnotations.mockResolvedValueOnce([]).mockResolvedValueOnce([]);
useStore.setState({ activeTemplateId: '2' });
render(<VideoWorkspace />);
await waitFor(() => expect(useStore.getState().frames).toHaveLength(1));
const fileInput = document.querySelector('input[type="file"]') as HTMLInputElement;
const file = new File(['mask'], 'color-mask.png', { type: 'image/png' });
fireEvent.change(fileInput, { target: { files: [file] } });
expect(screen.getByText('导入结果预览')).toBeInTheDocument();
await waitFor(() => expect(screen.getByRole('button', { name: '舍弃未知类别' })).not.toBeDisabled());
fireEvent.click(screen.getByRole('button', { name: '舍弃未知类别' }));
await waitFor(() => expect(apiMock.importGtMask).toHaveBeenCalledWith(file, '1', '10', '2', {
unknownColorPolicy: 'discard',
}));
});
it('auto-propagates reference-frame masks through the configured frame range', async () => {
apiMock.getProjectFrames.mockResolvedValueOnce([
{ id: 10, project_id: 1, frame_index: 0, image_url: '/frame.jpg', width: 640, height: 360 },
@@ -823,6 +1101,80 @@ describe('VideoWorkspace', () => {
})));
});
it('removes propagation history bars when clearing the same frame range', async () => {
apiMock.getProjectFrames.mockResolvedValueOnce([
{ id: 10, project_id: 1, frame_index: 0, image_url: '/frame-0.jpg', width: 640, height: 360 },
{ id: 11, project_id: 1, frame_index: 1, image_url: '/frame-1.jpg', width: 640, height: 360 },
{ id: 12, project_id: 1, frame_index: 2, image_url: '/frame-2.jpg', width: 640, height: 360 },
{ id: 13, project_id: 1, frame_index: 3, image_url: '/frame-3.jpg', width: 640, height: 360 },
{ id: 14, project_id: 1, frame_index: 4, image_url: '/frame-4.jpg', width: 640, height: 360 },
]);
apiMock.buildAnnotationPayload.mockReturnValue({
project_id: 1,
frame_id: 10,
mask_data: {
polygons: [[[0.1, 0.1], [0.3, 0.1], [0.3, 0.3]]],
label: '胆囊',
color: '#ff0000',
},
bbox: [0.1, 0.1, 0.2, 0.2],
});
apiMock.deleteAnnotation.mockResolvedValue(undefined);
render(<VideoWorkspace />);
await waitFor(() => expect(useStore.getState().frames).toHaveLength(5));
act(() => {
useStore.setState({
masks: [{
id: 'annotation-8',
annotationId: '8',
frameId: '10',
pathData: 'M 0 0 Z',
label: '胆囊',
color: '#ff0000',
segmentation: [[64, 36, 192, 36, 192, 108]],
bbox: [64, 36, 128, 72],
}],
});
});
fireEvent.click(screen.getByRole('button', { name: '自动传播' }));
const processingBar = screen.getByLabelText('视频处理进度条');
vi.spyOn(processingBar, 'getBoundingClientRect').mockReturnValue({
left: 0,
right: 100,
top: 0,
bottom: 10,
width: 100,
height: 10,
x: 0,
y: 0,
toJSON: () => ({}),
});
fireEvent.pointerDown(processingBar, { clientX: 25, pointerId: 1 });
fireEvent.pointerMove(processingBar, { clientX: 100, pointerId: 1 });
fireEvent.pointerUp(processingBar, { clientX: 100, pointerId: 1 });
fireEvent.click(screen.getByRole('button', { name: '开始传播' }));
expect(await screen.findByTestId('propagation-history-segment')).toBeInTheDocument();
act(() => {
useStore.setState({
masks: [
{ id: 'annotation-101', annotationId: '101', frameId: '11', pathData: 'M 1 1 Z', label: 'Propagated 1', color: '#ff0000', saved: true, saveStatus: 'saved', metadata: { source: 'sam2_propagation', propagated_from_frame_id: 10 } },
{ id: 'annotation-102', annotationId: '102', frameId: '12', pathData: 'M 2 2 Z', label: 'Propagated 2', color: '#ff0000', saved: true, saveStatus: 'saved', metadata: { source: 'sam2_propagation', propagated_from_frame_id: 10 } },
],
});
});
fireEvent.click(screen.getByRole('button', { name: '清空片段遮罩' }));
fireEvent.click(screen.getByRole('button', { name: '确认清空' }));
await waitFor(() => expect(screen.queryByTestId('propagation-history-segment')).not.toBeInTheDocument());
expect(apiMock.deleteAnnotation).toHaveBeenCalledWith('101');
expect(apiMock.deleteAnnotation).toHaveBeenCalledWith('102');
});
it('auto-propagates all reference-frame masks in both directions inside the selected range', async () => {
apiMock.getProjectFrames.mockResolvedValueOnce([
{ id: 10, project_id: 1, frame_index: 0, image_url: '/frame-0.jpg', width: 640, height: 360 },

File diff suppressed because it is too large Load Diff

View File

@@ -84,6 +84,84 @@ describe('api client contracts', () => {
expect(axiosMock.client.delete).toHaveBeenCalledWith('/api/projects/3');
});
it('normalizes missing template class maskids without using priority as the public id', async () => {
const { getTemplates } = await import('./api');
axiosMock.client.get.mockResolvedValueOnce({
data: [{
id: 2,
name: 'Template',
mapping_rules: {
classes: [
{ id: 'c1', name: 'A', color: '#ff0000', zIndex: 100 },
{ id: 'c2', name: 'B', color: '#00ff00', zIndex: 10, maskId: 7 },
{ id: 'c3', name: 'C', color: '#0000ff', zIndex: 50 },
],
rules: [],
},
}],
});
await expect(getTemplates()).resolves.toEqual([
expect.objectContaining({
classes: [
expect.objectContaining({ id: 'c1', maskId: 1, zIndex: 100 }),
expect.objectContaining({ id: 'c2', maskId: 7, zIndex: 10 }),
expect.objectContaining({ id: 'c3', maskId: 2, zIndex: 50 }),
],
}),
]);
});
it('calls admin user management and audit endpoints', async () => {
const {
getAdminUsers,
createAdminUser,
updateAdminUser,
deleteAdminUser,
getAuditLogs,
resetDemoFactory,
} = await import('./api');
axiosMock.client.get
.mockResolvedValueOnce({ data: [{ id: 1, username: 'admin', role: 'admin', is_active: 1 }] })
.mockResolvedValueOnce({ data: [{ id: 9, action: 'admin.user_created', created_at: 'now' }] });
axiosMock.client.post.mockResolvedValueOnce({ data: { id: 2, username: 'doctor', role: 'annotator', is_active: 1 } });
axiosMock.client.patch.mockResolvedValueOnce({ data: { id: 2, username: 'doctor', role: 'viewer', is_active: 1 } });
axiosMock.client.delete.mockResolvedValueOnce({ data: null });
await expect(getAdminUsers()).resolves.toEqual([expect.objectContaining({ username: 'admin' })]);
await createAdminUser({ username: 'doctor', password: 'secret123', role: 'annotator', is_active: true });
await updateAdminUser(2, { role: 'viewer' });
await deleteAdminUser(2);
await expect(getAuditLogs(50)).resolves.toEqual([expect.objectContaining({ action: 'admin.user_created' })]);
expect(axiosMock.client.get).toHaveBeenNthCalledWith(1, '/api/admin/users');
expect(axiosMock.client.post).toHaveBeenCalledWith('/api/admin/users', {
username: 'doctor',
password: 'secret123',
role: 'annotator',
is_active: true,
});
expect(axiosMock.client.patch).toHaveBeenCalledWith('/api/admin/users/2', { role: 'viewer' });
expect(axiosMock.client.delete).toHaveBeenCalledWith('/api/admin/users/2');
expect(axiosMock.client.get).toHaveBeenNthCalledWith(2, '/api/admin/audit-logs', { params: { limit: 50 } });
axiosMock.client.post.mockResolvedValueOnce({
data: {
admin_user: { id: 1, username: 'admin', role: 'admin', is_active: 1 },
project: { id: 8, name: 'Data_MyVideo_1', status: 'pending', frame_count: 0, video_path: 'uploads/8/Data_MyVideo_1.mp4' },
deleted_counts: { users: 1 },
message: '演示环境已恢复出厂设置',
},
});
await expect(resetDemoFactory('RESET_DEMO_FACTORY')).resolves.toEqual(expect.objectContaining({
admin_user: expect.objectContaining({ username: 'admin' }),
project: expect.objectContaining({ id: '8', name: 'Data_MyVideo_1', frames: 0 }),
}));
expect(axiosMock.client.post).toHaveBeenLastCalledWith('/api/admin/demo-factory-reset', {
confirmation: 'RESET_DEMO_FACTORY',
});
});
it('normalizes legacy project status values returned by existing databases', async () => {
const { getProjects } = await import('./api');
axiosMock.client.get.mockResolvedValueOnce({
@@ -123,6 +201,33 @@ describe('api client contracts', () => {
});
});
it('exports combined segmentation results with scope, outputs, and mix opacity params', async () => {
const { exportSegmentationResults } = await import('./api');
const blob = new Blob(['zip'], { type: 'application/zip' });
axiosMock.client.get.mockResolvedValueOnce({ data: blob });
await expect(exportSegmentationResults('9', {
scope: 'range',
outputs: ['gt_label', 'pro_label', 'mix_label'],
mixOpacity: 0.45,
startFrame: 2,
endFrame: 5,
frameId: '12',
})).resolves.toBe(blob);
expect(axiosMock.client.get).toHaveBeenCalledWith('/api/export/9/results', {
params: {
scope: 'range',
mask_type: undefined,
outputs: 'gt_label,pro_label,mix_label',
mix_opacity: 0.45,
start_frame: 2,
end_frame: 5,
frame_id: 12,
},
responseType: 'blob',
});
});
it('loads dashboard overview from the backend summary endpoint', async () => {
const { getDashboardOverview } = await import('./api');
const overview = {
@@ -319,7 +424,7 @@ describe('api client contracts', () => {
const saved = [{ id: 1, project_id: 9, frame_id: 5, template_id: null, mask_data: null, points: null, bbox: null }];
axiosMock.client.post.mockResolvedValueOnce({ data: saved });
await expect(importGtMask(file, '9', '5', '2')).resolves.toEqual(saved);
await expect(importGtMask(file, '9', '5', '2', { unknownColorPolicy: 'discard' })).resolves.toEqual(saved);
expect(axiosMock.client.post).toHaveBeenCalledWith(
'/api/ai/import-gt-mask',
expect.any(FormData),
@@ -330,6 +435,7 @@ describe('api client contracts', () => {
expect(form.get('project_id')).toBe('9');
expect(form.get('frame_id')).toBe('5');
expect(form.get('template_id')).toBe('2');
expect(form.get('unknown_color_policy')).toBe('discard');
});
it('builds annotation payloads from frontend masks and restores saved annotations to masks', async () => {
@@ -344,6 +450,7 @@ describe('api client contracts', () => {
classId: 'c1',
className: '胆囊',
classZIndex: 20,
classMaskId: 7,
segmentation: [[10, 10, 90, 10, 90, 40]],
bbox: [10, 10, 80, 30],
metadata: { geometry_smoothing: { strength: 35, method: 'chaikin' } },
@@ -357,7 +464,7 @@ describe('api client contracts', () => {
polygons: [[[0.1, 0.2], [0.9, 0.2], [0.9, 0.8]]],
label: '胆囊',
color: '#ff0000',
class: { id: 'c1', name: '胆囊', color: '#ff0000', zIndex: 20 },
class: { id: 'c1', name: '胆囊', color: '#ff0000', zIndex: 20, maskId: 7 },
geometry_smoothing: { strength: 35, method: 'chaikin' },
},
bbox: [0.1, 0.2, 0.8, 0.6],
@@ -372,7 +479,7 @@ describe('api client contracts', () => {
polygons: [[[0.1, 0.2], [0.9, 0.2], [0.9, 0.8]]],
label: '旧标签',
color: '#06b6d4',
class: { id: 'c1', name: '胆囊', color: '#ff0000', zIndex: 20 },
class: { id: 'c1', name: '胆囊', color: '#ff0000', zIndex: 20, maskId: 7 },
source: 'sam2.1_hiera_tiny_propagation',
propagated_from_frame_id: 4,
geometry_smoothing: { strength: 35, method: 'chaikin' },
@@ -389,6 +496,7 @@ describe('api client contracts', () => {
classId: 'c1',
className: '胆囊',
classZIndex: 20,
classMaskId: 7,
label: '胆囊',
color: '#ff0000',
saveStatus: 'saved',
@@ -466,6 +574,45 @@ describe('api client contracts', () => {
}));
});
it('preserves propagation metadata when saving edited geometry without persisting preview-only smoothing fields', async () => {
const { buildAnnotationPayload } = await import('./api');
const frame = { id: '5', projectId: '9', index: 0, url: '/frame.jpg', width: 100, height: 50 };
expect(buildAnnotationPayload('9', {
id: 'm1',
frameId: '5',
pathData: 'M 10 10 L 90 10 L 90 40 Z',
label: 'Tracked',
color: '#22c55e',
segmentation: [[10, 10, 90, 10, 90, 40]],
metadata: {
source: 'sam2_propagation',
propagated_from_frame_id: 1,
source_annotation_id: 7,
source_mask_id: 'annotation-7',
propagation_seed_key: 'annotation:7',
geometry_smoothing_preview: { strength: 35, method: 'chaikin' },
},
}, frame)).toEqual(expect.objectContaining({
mask_data: expect.objectContaining({
source: 'sam2_propagation',
propagated_from_frame_id: 1,
source_annotation_id: 7,
source_mask_id: 'annotation-7',
propagation_seed_key: 'annotation:7',
}),
}));
expect(buildAnnotationPayload('9', {
id: 'm1',
frameId: '5',
pathData: 'M 10 10 L 90 10 L 90 40 Z',
label: 'Tracked',
color: '#22c55e',
segmentation: [[10, 10, 90, 10, 90, 40]],
metadata: { geometry_smoothing_preview: { strength: 35, method: 'chaikin' } },
}, frame)?.mask_data).not.toHaveProperty('geometry_smoothing_preview');
});
it('normalizes positive and negative point prompts for AI prediction', async () => {
const { predictMask } = await import('./api');
axiosMock.client.post.mockResolvedValueOnce({

View File

@@ -1,6 +1,7 @@
import axios, { AxiosError } from 'axios';
import { DEFAULT_AI_MODEL_ID, type AiModelId, type Frame, type Mask, type Project, type Template } from '../store/useStore';
import { DEFAULT_AI_MODEL_ID, type AiModelId, type Frame, type Mask, type Project, type Template, type UserProfile } from '../store/useStore';
import { API_BASE_URL } from './config';
import { normalizeClassMaskIds } from './maskIds';
const apiClient = axios.create({
baseURL: API_BASE_URL,
@@ -28,18 +29,88 @@ apiClient.interceptors.response.use(
(error: AxiosError) => {
if (error.response?.status === 401) {
localStorage.removeItem('token');
window.location.reload();
if (!error.config?.url?.includes('/api/auth/login')) {
window.location.reload();
}
}
return Promise.reject(error);
}
);
// Auth
export async function login(username: string, password: string): Promise<{ token: string }> {
export async function login(username: string, password: string): Promise<{ token: string; username: string; user: UserProfile }> {
const response = await apiClient.post('/api/auth/login', { username, password });
return response.data;
}
export async function getCurrentUser(): Promise<UserProfile> {
const response = await apiClient.get('/api/auth/me');
return response.data;
}
export interface AdminUser extends UserProfile {
is_active: number;
}
export interface AuditLog {
id: number;
actor_user_id?: number | null;
action: string;
target_type?: string | null;
target_id?: string | null;
detail?: Record<string, any> | null;
created_at: string;
}
export interface DemoFactoryResetResult {
admin_user: AdminUser;
project: Project;
deleted_counts: Record<string, number>;
message: string;
}
export async function getAdminUsers(): Promise<AdminUser[]> {
const response = await apiClient.get('/api/admin/users');
return response.data;
}
export async function createAdminUser(payload: {
username: string;
password: string;
role: string;
is_active: boolean;
}): Promise<AdminUser> {
const response = await apiClient.post('/api/admin/users', payload);
return response.data;
}
export async function updateAdminUser(id: number, payload: {
username?: string;
password?: string;
role?: string;
is_active?: boolean;
}): Promise<AdminUser> {
const response = await apiClient.patch(`/api/admin/users/${id}`, payload);
return response.data;
}
export async function deleteAdminUser(id: number): Promise<void> {
await apiClient.delete(`/api/admin/users/${id}`);
}
export async function getAuditLogs(limit = 100): Promise<AuditLog[]> {
const response = await apiClient.get('/api/admin/audit-logs', { params: { limit } });
return response.data;
}
export async function resetDemoFactory(confirmation: string): Promise<DemoFactoryResetResult> {
const response = await apiClient.post('/api/admin/demo-factory-reset', { confirmation });
return {
...response.data,
project: mapProject(response.data.project),
};
}
// Projects
function normalizeProjectStatus(status?: string): Project['status'] {
const value = (status || 'pending').toLowerCase();
@@ -103,7 +174,7 @@ function _mapTemplate(t: any): Template {
id: String(t.id),
name: t.name,
description: t.description,
classes: mapping.classes || [],
classes: normalizeClassMaskIds(mapping.classes || []),
rules: mapping.rules || [],
createdAt: t.created_at,
updatedAt: t.updated_at,
@@ -120,7 +191,7 @@ export async function createTemplate(payload: {
description?: string;
color: string;
z_index: number;
classes?: { name: string; color: string; zIndex: number; category?: string }[];
classes?: { name: string; color: string; zIndex: number; maskId?: number; category?: string }[];
rules?: any[];
}): Promise<Template> {
const response = await apiClient.post('/api/templates', payload);
@@ -298,6 +369,7 @@ export interface SavedAnnotation {
name?: string;
color?: string;
zIndex?: number;
maskId?: number;
category?: string;
};
source?: string;
@@ -326,6 +398,7 @@ export interface SaveAnnotationPayload {
name?: string;
color?: string;
zIndex?: number;
maskId?: number;
category?: string;
};
geometry_smoothing?: GeometrySmoothingOptions;
@@ -505,6 +578,16 @@ function normalizeGeometrySmoothing(value: unknown): GeometrySmoothingOptions |
};
}
function persistableMaskMetadata(metadata?: Record<string, unknown>): Record<string, unknown> {
if (!metadata) return {};
const {
geometry_smoothing: _geometrySmoothing,
geometry_smoothing_preview: _geometrySmoothingPreview,
...rest
} = metadata;
return rest;
}
function pixelSegmentationToNormalizedPolygons(
segmentation: number[][] | undefined,
width: number,
@@ -534,21 +617,24 @@ export function buildAnnotationPayload(
const polygons = pixelSegmentationToNormalizedPolygons(mask.segmentation, frame.width, frame.height);
if (polygons.length === 0) return null;
const effectiveTemplateId = mask.templateId || templateId || undefined;
const classMetadata = mask.classId || mask.className || mask.classZIndex !== undefined
const classMetadata = mask.classId || mask.className || mask.classZIndex !== undefined || mask.classMaskId !== undefined
? {
id: mask.classId,
name: mask.className || mask.label,
color: mask.color,
zIndex: mask.classZIndex,
maskId: mask.classMaskId,
}
: undefined;
const geometrySmoothing = normalizeGeometrySmoothing(mask.metadata?.geometry_smoothing);
const metadata = persistableMaskMetadata(mask.metadata);
const payload: SaveAnnotationPayload = {
project_id: Number(projectId),
frame_id: Number(frame.id),
template_id: effectiveTemplateId ? Number(effectiveTemplateId) : undefined,
mask_data: {
...metadata,
polygons,
label: mask.label,
color: mask.color,
@@ -591,6 +677,7 @@ export function annotationToMask(annotation: SavedAnnotation, frame: Frame): Mas
classId: classMetadata?.id,
className: classMetadata?.name,
classZIndex: classMetadata?.zIndex,
classMaskId: classMetadata?.maskId,
saveStatus: 'saved',
saved: true,
pathData: polygonToPath(firstPolygon, frame.width, frame.height),
@@ -785,12 +872,14 @@ export async function importGtMask(
projectId: string,
frameId: string,
templateId?: string | null,
options: { unknownColorPolicy?: 'discard' | 'undefined' } = {},
): Promise<SavedAnnotation[]> {
const formData = new FormData();
formData.append('file', file);
formData.append('project_id', projectId);
formData.append('frame_id', frameId);
if (templateId) formData.append('template_id', templateId);
if (options.unknownColorPolicy) formData.append('unknown_color_policy', options.unknownColorPolicy);
const response = await apiClient.post('/api/ai/import-gt-mask', formData, {
headers: { 'Content-Type': 'multipart/form-data' },
});
@@ -817,4 +906,37 @@ export async function exportMasks(projectId: string): Promise<Blob> {
return response.data;
}
export type SegmentationExportScope = 'all' | 'range' | 'current';
export type SegmentationMaskType = 'separate' | 'gt_label' | 'both';
export type SegmentationExportOutput = 'separate' | 'gt_label' | 'pro_label' | 'mix_label';
export interface ExportSegmentationResultsOptions {
scope: SegmentationExportScope;
maskType?: SegmentationMaskType;
outputs?: SegmentationExportOutput[];
mixOpacity?: number;
startFrame?: number;
endFrame?: number;
frameId?: string;
}
export async function exportSegmentationResults(
projectId: string,
options: ExportSegmentationResultsOptions,
): Promise<Blob> {
const response = await apiClient.get(`/api/export/${projectId}/results`, {
params: {
scope: options.scope,
mask_type: options.maskType,
outputs: options.outputs?.join(','),
mix_opacity: options.mixOpacity,
start_frame: options.startFrame,
end_frame: options.endFrame,
frame_id: options.frameId ? Number(options.frameId) : undefined,
},
responseType: 'blob',
});
return response.data;
}
export default apiClient;

34
src/lib/maskIds.ts Normal file
View File

@@ -0,0 +1,34 @@
import type { TemplateClass } from '../store/useStore';
export function normalizeClassMaskIds(classes: TemplateClass[] = []): TemplateClass[] {
const used = new Set<number>();
let nextMaskId = 1;
const nextAvailableMaskId = () => {
while (used.has(nextMaskId)) nextMaskId += 1;
const value = nextMaskId;
used.add(value);
nextMaskId += 1;
return value;
};
return classes.map((templateClass) => {
const parsed = Number(templateClass.maskId);
if (Number.isInteger(parsed) && parsed > 0 && !used.has(parsed)) {
used.add(parsed);
return { ...templateClass, maskId: parsed };
}
return { ...templateClass, maskId: nextAvailableMaskId() };
});
}
export function nextClassMaskId(classes: TemplateClass[] = []): number {
const used = new Set(
classes
.map((templateClass) => Number(templateClass.maskId))
.filter((value) => Number.isInteger(value) && value > 0),
);
let value = 1;
while (used.has(value)) value += 1;
return value;
}

View File

@@ -94,4 +94,34 @@ describe('progress websocket client', () => {
unsubscribeStatus();
progressWS.disconnect();
});
it('does not reconnect after an intentional disconnect', async () => {
vi.useFakeTimers();
const instances: any[] = [];
class FakeWebSocket {
static CONNECTING = 0;
static OPEN = 1;
readyState = FakeWebSocket.OPEN;
onopen?: () => void;
onmessage?: (event: MessageEvent) => void;
onclose?: () => void;
onerror?: () => void;
constructor(public url: string) {
instances.push(this);
}
close = vi.fn();
send = vi.fn();
}
vi.stubGlobal('WebSocket', FakeWebSocket);
const { progressWS } = await import('./websocket');
progressWS.connect();
instances[0].onopen?.();
progressWS.disconnect();
instances[0].onclose?.();
vi.advanceTimersByTime(30000);
expect(instances).toHaveLength(1);
});
});

View File

@@ -30,6 +30,7 @@ class ProgressWebSocket {
private heartbeatInterval = 15000;
private shouldReconnect = false;
private shouldCloseAfterOpen = false;
private manualDisconnect = false;
private currentInterval = 3000;
constructor(url = WS_PROGRESS_URL) {
@@ -43,6 +44,7 @@ class ProgressWebSocket {
this.shouldReconnect = true;
this.shouldCloseAfterOpen = false;
this.manualDisconnect = false;
this.notifyStatus('connecting');
try {
@@ -71,7 +73,9 @@ class ProgressWebSocket {
};
this.ws.onclose = () => {
console.log('[WebSocket] Connection closed');
if (!this.manualDisconnect) {
console.log('[WebSocket] Connection closed');
}
this.stopHeartbeat();
this.ws = null;
this.notifyStatus('disconnected');
@@ -97,6 +101,7 @@ class ProgressWebSocket {
disconnect() {
this.shouldReconnect = false;
this.manualDisconnect = true;
this.stopHeartbeat();
if (this.reconnectTimer) {
clearTimeout(this.reconnectTimer);

View File

@@ -8,15 +8,17 @@ describe('useStore', () => {
});
it('stores and clears auth state with localStorage', () => {
useStore.getState().login('token-1');
useStore.getState().login('token-1', { id: 1, username: 'admin', role: 'admin' });
expect(useStore.getState().isAuthenticated).toBe(true);
expect(useStore.getState().token).toBe('token-1');
expect(useStore.getState().currentUser?.username).toBe('admin');
expect(localStorage.getItem('token')).toBe('token-1');
useStore.getState().logout();
expect(useStore.getState().isAuthenticated).toBe(false);
expect(useStore.getState().currentUser).toBeNull();
expect(useStore.getState().projects).toEqual([]);
expect(useStore.getState().frames).toEqual([]);
expect(localStorage.getItem('token')).toBeNull();
@@ -32,6 +34,8 @@ describe('useStore', () => {
useStore.getState().addMask({ id: 'm1', frameId: 'f1', pathData: 'M 0 0 Z', label: 'mask', color: '#fff' });
useStore.getState().setSelectedMaskIds(['m1']);
useStore.getState().setMaskPreviewOpacity(35);
useStore.getState().setBrushSize(36);
useStore.getState().setEraserSize(44);
useStore.getState().updateMask('m1', { label: 'updated mask', saveStatus: 'dirty' });
useStore.getState().addAnnotation({ id: 'a1', frameId: 'f1', type: 'mask', points: [], label: 'ann', color: '#fff' });
useStore.getState().addTemplate({ id: 't1', name: 'Template', classes: [], rules: [] });
@@ -44,6 +48,8 @@ describe('useStore', () => {
expect(useStore.getState().currentFrameIndex).toBe(0);
expect(useStore.getState().selectedMaskIds).toEqual(['m1']);
expect(useStore.getState().maskPreviewOpacity).toBe(35);
expect(useStore.getState().brushSize).toBe(36);
expect(useStore.getState().eraserSize).toBe(44);
expect(useStore.getState().masks[0]).toEqual(expect.objectContaining({ label: 'updated mask', saveStatus: 'dirty' }));
expect(useStore.getState().annotations).toHaveLength(1);
expect(useStore.getState().templates[0].name).toBe('Template 2');

View File

@@ -24,6 +24,8 @@ export type AiModelId =
| 'sam2.1_hiera_large';
export const DEFAULT_AI_MODEL_ID: AiModelId = 'sam2.1_hiera_tiny';
export const DEFAULT_BRUSH_SIZE = 24;
export const DEFAULT_ERASER_SIZE = 28;
export const SAM2_MODEL_OPTIONS: Array<{ id: AiModelId; label: string; shortLabel: string }> = [
{ id: 'sam2.1_hiera_tiny', label: 'SAM 2.1 Tiny', shortLabel: 'tiny' },
@@ -64,6 +66,7 @@ export interface Mask {
classId?: string;
className?: string;
classZIndex?: number;
classMaskId?: number;
saveStatus?: 'draft' | 'saved' | 'dirty' | 'saving' | 'error';
saved?: boolean;
pathData: string;
@@ -92,6 +95,7 @@ export interface TemplateClass {
name: string;
color: string;
zIndex: number;
maskId?: number;
category?: string;
description?: string;
}
@@ -104,11 +108,20 @@ export interface TemplateRule {
operation: string;
}
export interface UserProfile {
id: number;
username: string;
role: string;
is_active?: number;
}
export interface AppState {
// Auth
isAuthenticated: boolean;
token: string | null;
login: (token: string) => void;
currentUser: UserProfile | null;
login: (token: string, user?: UserProfile | null) => void;
setCurrentUser: (user: UserProfile | null) => void;
logout: () => void;
// Projects
@@ -129,6 +142,8 @@ export interface AppState {
masks: Mask[];
selectedMaskIds: string[];
maskPreviewOpacity: number;
brushSize: number;
eraserSize: number;
maskHistory: Mask[][];
maskFuture: Mask[][];
setActiveModule: (module: string) => void;
@@ -142,6 +157,8 @@ export interface AppState {
setMasks: (masks: Mask[]) => void;
setSelectedMaskIds: (ids: string[]) => void;
setMaskPreviewOpacity: (opacity: number) => void;
setBrushSize: (size: number) => void;
setEraserSize: (size: number) => void;
clearMasks: () => void;
undoMasks: () => void;
redoMasks: () => void;
@@ -169,17 +186,20 @@ export interface AppState {
export const useStore = create<AppState>((set) => ({
// Auth
isAuthenticated: false,
token: null,
login: (token: string) => {
isAuthenticated: Boolean(localStorage.getItem('token')),
token: localStorage.getItem('token'),
currentUser: null,
login: (token: string, user: UserProfile | null = null) => {
localStorage.setItem('token', token);
set({ isAuthenticated: true, token });
set({ isAuthenticated: true, token, currentUser: user });
},
setCurrentUser: (currentUser: UserProfile | null) => set({ currentUser }),
logout: () => {
localStorage.removeItem('token');
set({
isAuthenticated: false,
token: null,
currentUser: null,
currentProject: null,
projects: [],
templates: [],
@@ -188,6 +208,8 @@ export const useStore = create<AppState>((set) => ({
masks: [],
selectedMaskIds: [],
maskPreviewOpacity: 50,
brushSize: DEFAULT_BRUSH_SIZE,
eraserSize: DEFAULT_ERASER_SIZE,
maskHistory: [],
maskFuture: [],
activeTemplateId: null,
@@ -218,6 +240,8 @@ export const useStore = create<AppState>((set) => ({
masks: [],
selectedMaskIds: [],
maskPreviewOpacity: 50,
brushSize: DEFAULT_BRUSH_SIZE,
eraserSize: DEFAULT_ERASER_SIZE,
maskHistory: [],
maskFuture: [],
setActiveModule: (activeModule: string) => set({ activeModule }),
@@ -254,6 +278,12 @@ export const useStore = create<AppState>((set) => ({
setMaskPreviewOpacity: (maskPreviewOpacity: number) => set({
maskPreviewOpacity: Math.min(Math.max(maskPreviewOpacity, 10), 100),
}),
setBrushSize: (brushSize: number) => set({
brushSize: Math.round(Math.min(Math.max(brushSize, 4), 96)),
}),
setEraserSize: (eraserSize: number) => set({
eraserSize: Math.round(Math.min(Math.max(eraserSize, 4), 128)),
}),
clearMasks: () =>
set((state) => ({
masks: [],

View File

@@ -1,9 +1,10 @@
import { DEFAULT_AI_MODEL_ID, useStore } from '../store/useStore';
import { DEFAULT_AI_MODEL_ID, DEFAULT_BRUSH_SIZE, DEFAULT_ERASER_SIZE, useStore } from '../store/useStore';
export function resetStore() {
useStore.setState({
isAuthenticated: false,
token: null,
currentUser: null,
projects: [],
currentProject: null,
activeModule: 'workspace',
@@ -15,6 +16,8 @@ export function resetStore() {
masks: [],
selectedMaskIds: [],
maskPreviewOpacity: 50,
brushSize: DEFAULT_BRUSH_SIZE,
eraserSize: DEFAULT_ERASER_SIZE,
maskHistory: [],
maskFuture: [],
templates: [],