Files
Pre_Seg_Server/src/components/AISegmentation.test.tsx
admin c8c59f7ede feat: 完善分割工作区传播与交互闭环
功能增加:新增后端传播任务执行器,支持异步自动传播、传播进度、结果统计、取消/重试状态同步。

功能增加:传播请求支持指定 SAM2.1 tiny/small/base+/large 权重,并记录 seed mask、source annotation 和传播范围。

功能增加:传播逻辑增加 seed 签名,未变化的 mask 二次传播会跳过,已变化的 mask 会先清理旧自动传播结果再重新生成,避免重复重叠。

功能增加:工作区增加传播范围二次选择、传播进度提示、人工/AI 标注帧红色标识、自动传播帧蓝色标识和当前帧双层边框。

功能增加:新增临时提示组件,让工具操作提示自动消失且不阻塞后续操作。

功能增加:补充项目删除、模板删除、任务失败详情、任务取消/重试等前后端联动状态。

功能增加:新增安装部署文档,补充当前需求冻结、设计冻结、接口契约、测试计划和 AGENTS/README 项目说明。

Bugfix:修复自动传播接口 404、传播后看不到任务进度、传播结果重复堆叠和已编辑帧提示不清晰的问题。

Bugfix:修复 AI 分割框选/点选交互、单候选 mask、删除选点、工作区保存与候选 mask 推送相关问题。

Bugfix:修复 Canvas 多边形顶点拖动告警、工具栏提示缺失、项目库 FPS 展示和若干 UI 文案/可用性问题。

测试:补充 AI 分割、Canvas、Dashboard、FrameTimeline、ProjectLibrary、TemplateRegistry、ToolsPalette、VideoWorkspace、API 和后端任务/AI/dashboard 测试。

验证:npm run lint;npm run test:run;python -m pytest backend/tests -q。
2026-05-02 05:17:18 +08:00

589 lines
22 KiB
TypeScript

import { fireEvent, render, screen, waitFor } from '@testing-library/react';
import { beforeEach, describe, expect, it, vi } from 'vitest';
import { resetStore } from '../test/storeTestUtils';
import { useStore } from '../store/useStore';
import { AISegmentation } from './AISegmentation';
// Mock handles for the API module. They are built through vi.hoisted so they
// can be referenced from inside the vi.mock factory below.
const apiMock = vi.hoisted(() => {
  return {
    getAiModelStatus: vi.fn(),
    predictMask: vi.fn(),
  };
});

// Replace the API module's two functions with the mock handles above.
vi.mock('../lib/api', () => {
  const { getAiModelStatus, predictMask } = apiMock;
  return { getAiModelStatus, predictMask };
});
describe('AISegmentation', () => {
// Shared setup for every test: reset the store and the API mocks, seed a
// single 640x360 frame, and make the model-status call resolve to a ready
// "SAM 2.1 Tiny" model.
beforeEach(() => {
  resetStore();
  vi.clearAllMocks();
  // clearAllMocks only wipes recorded calls. mockReset additionally drops any
  // queued mockResolvedValueOnce responses that an earlier (e.g. failed) test
  // never consumed, so they cannot leak into the next test.
  apiMock.getAiModelStatus.mockReset();
  apiMock.predictMask.mockReset();
  useStore.setState({
    frames: [{ id: 'frame-1', projectId: 'project-1', index: 0, url: '/frame.jpg', width: 640, height: 360 }],
  });
  apiMock.getAiModelStatus.mockResolvedValue({
    selected_model: 'sam2.1_hiera_tiny',
    gpu: { available: true, device: 'cuda', name: 'RTX 4090', torch_available: true },
    models: [
      {
        id: 'sam2.1_hiera_tiny',
        label: 'SAM 2.1 Tiny',
        available: true,
        loaded: false,
        device: 'cuda',
        supports: ['point', 'box'],
        message: 'SAM 2.1 Tiny ready',
        package_available: true,
        checkpoint_exists: true,
        python_ok: true,
        torch_ok: true,
        cuda_required: false,
      },
    ],
  });
});
it('shows the SAM2.1 variant selector without exposing SAM3', async () => {
  const onSendToWorkspace = vi.fn();
  render(<AISegmentation onSendToWorkspace={onSendToWorkspace} />);

  // Wait until the model label appears on screen.
  const modelLabel = await screen.findByText('SAM 2.1 Tiny');
  expect(modelLabel).toBeInTheDocument();

  // All four variant options must be offered.
  for (const variant of ['tiny', 'small', 'base+', 'large']) {
    expect(screen.getByText(variant)).toBeInTheDocument();
  }

  // Nothing labelled SAM3 may be rendered.
  const sam3Entry = screen.queryByText('SAM3');
  expect(sam3Entry).not.toBeInTheDocument();

  // The status request is made for the tiny variant.
  expect(apiMock.getAiModelStatus).toHaveBeenCalledWith('sam2.1_hiera_tiny');
});
it('does not render the legacy upload-replace-background mock button', () => {
  const onSendToWorkspace = vi.fn();
  render(<AISegmentation onSendToWorkspace={onSendToWorkspace} />);

  // The old mock button text must be absent from the page.
  const legacyButton = screen.queryByText('上传替换底图');
  expect(legacyButton).not.toBeInTheDocument();
});
it('shows an empty state instead of a demo image when no project frame is selected', () => {
  // Remove the frame seeded by beforeEach so the store has no frames.
  useStore.setState({ frames: [] });

  const onSendToWorkspace = vi.fn();
  render(<AISegmentation onSendToWorkspace={onSendToWorkspace} />);

  const emptyStateHint = screen.getByText('请先在项目库选择项目并生成帧');
  expect(emptyStateHint).toBeInTheDocument();
});
it('shows contextual guidance for prompt tools', () => {
  render(<AISegmentation onSendToWorkspace={vi.fn()} />);

  // Tool button label paired with the hint text expected after clicking it.
  const guidanceCases: Array<[string, RegExp]> = [
    ['正向选点', /点击目标内部添加正向点/],
    ['边界框选', /按住并拖拽建立框选区域/],
  ];

  for (const [toolLabel, hintPattern] of guidanceCases) {
    fireEvent.click(screen.getByText(toolLabel));
    expect(screen.getByText(hintPattern)).toBeInTheDocument();
  }
});
it('passes enabled inference parameters to the backend', async () => {
  apiMock.predictMask.mockResolvedValueOnce({ masks: [] });
  render(<AISegmentation onSendToWorkspace={vi.fn()} />);

  // Both inference option labels are visible.
  const optionLabels = ['局部专注模式(自动裁剪无锚区域)', '严格除杂模式(自动清理干涉点)'];
  optionLabels.forEach((label) => {
    expect(screen.getByText(label)).toBeInTheDocument();
  });

  // Place one positive point on the stage, then run the segmentation.
  fireEvent.click(screen.getByText('正向选点'));
  fireEvent.click(screen.getByTestId('konva-stage'));
  const runButton = await screen.findByText('执行高精度语义分割');
  fireEvent.click(runButton);

  // The request must carry the frame, the model, the point and the options.
  const expectedOptions = {
    crop_to_prompt: false,
    auto_filter_background: true,
    min_score: 0.05,
  };
  const expectedRequest = expect.objectContaining({
    imageId: 'frame-1',
    imageWidth: 640,
    imageHeight: 360,
    model: 'sam2.1_hiera_tiny',
    points: [{ x: 120, y: 80, type: 'pos' }],
    options: expectedOptions,
  });
  expect(apiMock.predictMask).toHaveBeenCalledWith(expectedRequest);
});
it('sends the selected SAM2.1 variant to prediction', async () => {
  apiMock.predictMask.mockResolvedValueOnce({ masks: [] });
  render(<AISegmentation onSendToWorkspace={vi.fn()} />);

  // Switch to the "small" variant.
  const smallVariant = await screen.findByText('small');
  fireEvent.click(smallVariant);

  // Add a positive point and run the segmentation.
  fireEvent.click(screen.getByText('正向选点'));
  fireEvent.click(screen.getByTestId('konva-stage'));
  const runButton = await screen.findByText('执行高精度语义分割');
  fireEvent.click(runButton);

  // Both the status lookup and the prediction must use the small model id.
  const smallModelId = 'sam2.1_hiera_small';
  expect(apiMock.getAiModelStatus).toHaveBeenCalledWith(smallModelId);
  expect(apiMock.predictMask).toHaveBeenCalledWith(
    expect.objectContaining({ model: smallModelId }),
  );
});
it('does not render masks that were created in the workspace', async () => {
  // Seed a manually created mask and mark it as selected.
  useStore.setState({
    masks: [
      {
        id: 'workspace-mask',
        frameId: 'frame-1',
        pathData: 'M 0 0 L 10 0 L 10 10 Z',
        label: 'Manual Mask',
        color: '#ff0000',
        segmentation: [[0, 0, 10, 0, 10, 10]],
        metadata: { source: 'manual' },
      },
    ],
    selectedMaskIds: ['workspace-mask'],
  });

  render(<AISegmentation onSendToWorkspace={vi.fn()} />);

  // No mask path is drawn for the workspace mask.
  const renderedPaths = screen.queryAllByTestId('konva-path');
  expect(renderedPaths).toHaveLength(0);

  // The selection in the store ends up empty.
  await waitFor(() => {
    expect(useStore.getState().selectedMaskIds).toEqual([]);
  });
});
it('requires point prompts before running SAM2 inference', async () => {
  render(<AISegmentation onSendToWorkspace={vi.fn()} />);

  // Run the segmentation without placing any point or box.
  const runButton = await screen.findByText('执行高精度语义分割');
  fireEvent.click(runButton);

  // No backend call is made and a hint asks for a prompt first.
  expect(apiMock.predictMask).not.toHaveBeenCalled();
  const promptHint = await screen.findByText('请先放置正/反向提示点或框选区域。');
  expect(promptHint).toBeInTheDocument();
});
it('uses a dragged box prompt for AI page inference without adding a point on click', async () => {
  apiMock.predictMask.mockResolvedValueOnce({ masks: [] });
  render(<AISegmentation onSendToWorkspace={vi.fn()} />);

  // Drag a box from (120, 80) to (260, 200) with the box tool.
  fireEvent.click(screen.getByText('边界框选'));
  const stage = screen.getByTestId('konva-stage');
  const dragStart = { clientX: 120, clientY: 80 };
  const dragEnd = { clientX: 260, clientY: 200 };
  fireEvent.mouseDown(stage, dragStart);
  fireEvent.mouseMove(stage, dragEnd);
  fireEvent.mouseUp(stage, dragEnd);

  // The rectangle is 140 wide, a hint confirms the box, and neither a point
  // marker nor a backend call has been produced yet.
  expect(screen.getByTestId('konva-rect')).toHaveAttribute('data-width', '140');
  const boxHint = await screen.findByText('已框选区域,可执行分割,或继续添加正/反向点细化。');
  expect(boxHint).toBeInTheDocument();
  expect(screen.queryAllByTestId('konva-circle')).toHaveLength(0);
  expect(apiMock.predictMask).not.toHaveBeenCalled();

  // Running the segmentation sends the box and no points.
  const runButton = await screen.findByText('执行高精度语义分割');
  fireEvent.click(runButton);

  const expectedRequest = expect.objectContaining({
    imageId: 'frame-1',
    imageWidth: 640,
    imageHeight: 360,
    model: 'sam2.1_hiera_tiny',
    points: undefined,
    box: { x1: 120, y1: 80, x2: 260, y2: 200 },
    options: {
      crop_to_prompt: false,
      auto_filter_background: true,
      min_score: 0.05,
    },
  });
  expect(apiMock.predictMask).toHaveBeenCalledWith(expectedRequest);
});
it('handles stage drag end for move-tool canvas panning', () => {
  const onSendToWorkspace = vi.fn();
  render(<AISegmentation onSendToWorkspace={onSendToWorkspace} />);

  // The stage element must report that a drag-end handler is attached.
  const stage = screen.getByTestId('konva-stage');
  expect(stage).toHaveAttribute('data-has-drag-end', 'true');
});
it('centers the active frame with a large default fit inside the AI canvas', async () => {
  // Give every HTMLElement a 1000x700 client size for this test only. The
  // previous own-property descriptors (if any) are remembered so the prototype
  // can be put back afterwards; otherwise the fake size would leak into every
  // later test in this file.
  const proto = HTMLElement.prototype;
  const originalWidth = Object.getOwnPropertyDescriptor(proto, 'clientWidth');
  const originalHeight = Object.getOwnPropertyDescriptor(proto, 'clientHeight');
  const restoreSize = (key: 'clientWidth' | 'clientHeight', original: PropertyDescriptor | undefined) => {
    if (original) {
      Object.defineProperty(proto, key, original);
    } else {
      // There was no own property before the override: remove ours so lookups
      // fall through to whatever the prototype chain provided originally.
      Reflect.deleteProperty(proto, key);
    }
  };

  Object.defineProperty(proto, 'clientWidth', { configurable: true, get: () => 1000 });
  Object.defineProperty(proto, 'clientHeight', { configurable: true, get: () => 700 });

  try {
    render(<AISegmentation onSendToWorkspace={vi.fn()} />);

    // Expected stage transform for the 640x360 frame in the faked container.
    await waitFor(() => {
      const stage = screen.getByTestId('konva-stage');
      expect(Number(stage.getAttribute('data-scale-x'))).toBeCloseTo(1.34375, 4);
      expect(Number(stage.getAttribute('data-x'))).toBeCloseTo(70, 0);
      expect(Number(stage.getAttribute('data-y'))).toBeCloseTo(108, 0);
    });
  } finally {
    // Always undo the override, even when an assertion above fails.
    restoreSize('clientWidth', originalWidth);
    restoreSize('clientHeight', originalHeight);
  }
});
it('combines the AI page box prompt with later positive and negative refinement points', async () => {
  apiMock.predictMask.mockResolvedValueOnce({ masks: [] });
  render(<AISegmentation onSendToWorkspace={vi.fn()} />);

  // Drag a box from (100, 60) to (300, 180).
  fireEvent.click(screen.getByText('边界框选'));
  const stage = screen.getByTestId('konva-stage');
  const boxEnd = { clientX: 300, clientY: 180 };
  fireEvent.mouseDown(stage, { clientX: 100, clientY: 60 });
  fireEvent.mouseMove(stage, boxEnd);
  fireEvent.mouseUp(stage, boxEnd);

  // Add one positive and one negative refinement point.
  fireEvent.click(screen.getByText('正向选点'));
  fireEvent.click(stage, { clientX: 160, clientY: 100 });
  fireEvent.click(screen.getByText('反向选点'));
  fireEvent.click(stage, { clientX: 260, clientY: 150 });

  const runButton = await screen.findByText('执行高精度语义分割');
  fireEvent.click(runButton);

  // The request carries both points, in click order, together with the box.
  const expectedPoints = [
    { x: 160, y: 100, type: 'pos' },
    { x: 260, y: 150, type: 'neg' },
  ];
  const expectedBox = { x1: 100, y1: 60, x2: 300, y2: 180 };
  expect(apiMock.predictMask).toHaveBeenCalledWith(
    expect.objectContaining({ points: expectedPoints, box: expectedBox }),
  );
});
it('replaces the previous AI page candidate when running the same box prompt again', async () => {
  // Backend responses for the first and the second run.
  const firstResponse = {
    masks: [
      {
        id: 'sam2-first',
        pathData: 'M 10 10 L 40 10 L 40 40 Z',
        label: 'AI Mask',
        color: '#06b6d4',
        segmentation: [[10, 10, 40, 10, 40, 40]],
        bbox: [10, 10, 30, 30],
        area: 900,
      },
    ],
  };
  const secondResponse = {
    masks: [
      {
        id: 'sam2-second',
        pathData: 'M 20 20 L 50 20 L 50 50 Z',
        label: 'AI Mask',
        color: '#06b6d4',
        segmentation: [[20, 20, 50, 20, 50, 50]],
        bbox: [20, 20, 30, 30],
        area: 900,
      },
    ],
  };
  const storedMaskIds = () => useStore.getState().masks.map(({ id }) => id);

  // A manually created workspace mask already lives in the store.
  useStore.setState({
    masks: [
      {
        id: 'workspace-mask',
        frameId: 'frame-1',
        pathData: 'M 0 0 L 10 0 L 10 10 Z',
        label: 'Manual Mask',
        color: '#ff0000',
        segmentation: [[0, 0, 10, 0, 10, 10]],
        metadata: { source: 'manual' },
      },
    ],
  });
  apiMock.predictMask.mockResolvedValueOnce(firstResponse);
  apiMock.predictMask.mockResolvedValueOnce(secondResponse);

  render(<AISegmentation onSendToWorkspace={vi.fn()} />);

  // Drag a box prompt and run the first segmentation.
  fireEvent.click(screen.getByText('边界框选'));
  const stage = screen.getByTestId('konva-stage');
  const dragEnd = { clientX: 260, clientY: 200 };
  fireEvent.mouseDown(stage, { clientX: 120, clientY: 80 });
  fireEvent.mouseMove(stage, dragEnd);
  fireEvent.mouseUp(stage, dragEnd);
  const runButton = await screen.findByText('执行高精度语义分割');
  fireEvent.click(runButton);

  await waitFor(() => {
    expect(storedMaskIds()).toEqual(['workspace-mask', 'sam2-first']);
  });
  expect(useStore.getState().selectedMaskIds).toEqual(['sam2-first']);

  // Run again with the same prompt: the second candidate replaces the first.
  fireEvent.click(screen.getByText('执行高精度语义分割'));

  await waitFor(() => {
    expect(storedMaskIds()).toEqual(['workspace-mask', 'sam2-second']);
  });
  expect(useStore.getState().selectedMaskIds).toEqual(['sam2-second']);
  expect(screen.getAllByTestId('konva-path')).toHaveLength(1);
});
it('deletes prompt points individually and can remove the latest point', async () => {
  apiMock.predictMask.mockResolvedValueOnce({ masks: [] });
  render(<AISegmentation onSendToWorkspace={vi.fn()} />);

  // Place one positive point and one negative point.
  fireEvent.click(screen.getByText('正向选点'));
  fireEvent.click(screen.getByTestId('konva-stage'), { clientX: 120, clientY: 80 });
  fireEvent.click(screen.getByText('反向选点'));
  fireEvent.click(screen.getByTestId('konva-stage'), { clientX: 220, clientY: 140 });

  // The test expects four circle nodes for the two points.
  await waitFor(() => {
    expect(screen.getAllByTestId('konva-circle')).toHaveLength(4);
  });

  // Clicking the first circle leaves two circle nodes.
  const [firstCircle] = screen.getAllByTestId('konva-circle');
  fireEvent.click(firstCircle);
  await waitFor(() => {
    expect(screen.getAllByTestId('konva-circle')).toHaveLength(2);
  });

  // Only the negative point is sent to the backend.
  const runButton = await screen.findByText('执行高精度语义分割');
  fireEvent.click(runButton);
  const remainingPoints = [{ x: 220, y: 140, type: 'neg' }];
  expect(apiMock.predictMask).toHaveBeenCalledWith(
    expect.objectContaining({ points: remainingPoints }),
  );

  // The "delete latest anchor" control removes the last remaining point.
  fireEvent.click(screen.getByLabelText('删除最近锚点'));
  await waitFor(() => {
    expect(screen.queryAllByTestId('konva-circle')).toHaveLength(0);
  });
});
it('keeps only the best SAM2 candidate when the backend returns overlapping alternatives', async () => {
  // The backend answers with two overlapping candidates.
  const bestCandidate = {
    id: 'sam2-best',
    pathData: 'M 0 0 L 10 0 L 10 10 Z',
    label: 'AI Mask',
    color: '#06b6d4',
    segmentation: [[0, 0, 10, 0, 10, 10]],
    bbox: [0, 0, 10, 10],
    area: 100,
  };
  const alternativeCandidate = {
    id: 'sam2-alt',
    pathData: 'M 1 1 L 11 1 L 11 11 Z',
    label: 'AI Mask',
    color: '#06b6d4',
    segmentation: [[1, 1, 11, 1, 11, 11]],
    bbox: [1, 1, 10, 10],
    area: 100,
  };
  apiMock.predictMask.mockResolvedValueOnce({ masks: [bestCandidate, alternativeCandidate] });

  render(<AISegmentation onSendToWorkspace={vi.fn()} />);
  fireEvent.click(screen.getByText('正向选点'));
  fireEvent.click(screen.getByTestId('konva-stage'));
  const runButton = await screen.findByText('执行高精度语义分割');
  fireEvent.click(runButton);

  // Only the first candidate is stored, tagged as AI output and selected.
  await waitFor(() => {
    expect(useStore.getState().masks).toHaveLength(1);
  });
  expect(useStore.getState().masks[0].id).toBe('sam2-best');
  expect(useStore.getState().masks[0].metadata).toEqual(
    expect.objectContaining({ source: 'ai_segmentation' }),
  );
  expect(useStore.getState().selectedMaskIds).toEqual(['sam2-best']);

  // A status message reports that two candidates came back.
  const summary = await screen.findByText('SAM 2.1 Tiny 返回 2 个候选,已采用最高分区域。');
  expect(summary).toBeInTheDocument();
});
it('adjusts the AI mask preview opacity without changing mask data', async () => {
  const candidate = {
    id: 'sam2-mask',
    pathData: 'M 10 10 L 40 10 L 40 40 Z',
    label: 'AI Mask',
    color: '#06b6d4',
    segmentation: [[10, 10, 40, 10, 40, 40]],
    bbox: [10, 10, 30, 30],
    area: 900,
  };
  apiMock.predictMask.mockResolvedValueOnce({ masks: [candidate] });

  render(<AISegmentation onSendToWorkspace={vi.fn()} />);
  fireEvent.click(screen.getByText('正向选点'));
  fireEvent.click(screen.getByTestId('konva-stage'));
  const runButton = await screen.findByText('执行高精度语义分割');
  fireEvent.click(runButton);
  await waitFor(() => {
    expect(screen.getByTestId('konva-path')).toBeInTheDocument();
  });

  // Looks up, on every call, the first group that carries an opacity attribute.
  const findOpacityGroup = () => {
    const groups = screen.getAllByTestId('konva-group');
    return groups.find((node) => node.getAttribute('data-opacity'));
  };

  // Opacity starts at 0.72 and follows the slider (35 -> 0.35).
  expect(findOpacityGroup()).toHaveAttribute('data-opacity', '0.72');
  const opacitySlider = screen.getByLabelText('遮罩清晰度');
  fireEvent.change(opacitySlider, { target: { value: '35' } });
  expect(findOpacityGroup()).toHaveAttribute('data-opacity', '0.35');

  // The stored polygon is unchanged by the slider.
  const [storedMask] = useStore.getState().masks;
  expect(storedMask.segmentation).toEqual([[10, 10, 40, 10, 40, 40]]);
});
it('lets positive and negative prompt points be added on top of an AI mask', async () => {
  // The first run returns one mask, the second run returns none.
  const firstResponse = {
    masks: [
      {
        id: 'sam2-mask',
        pathData: 'M 10 10 L 40 10 L 40 40 Z',
        label: 'AI Mask',
        color: '#06b6d4',
        segmentation: [[10, 10, 40, 10, 40, 40]],
        bbox: [10, 10, 30, 30],
        area: 900,
      },
    ],
  };
  apiMock.predictMask.mockResolvedValueOnce(firstResponse);
  apiMock.predictMask.mockResolvedValueOnce({ masks: [] });

  render(<AISegmentation onSendToWorkspace={vi.fn()} />);

  // The first run uses a single positive point.
  fireEvent.click(screen.getByText('正向选点'));
  fireEvent.click(screen.getByTestId('konva-stage'), { clientX: 120, clientY: 80 });
  const runButton = await screen.findByText('执行高精度语义分割');
  fireEvent.click(runButton);
  await waitFor(() => {
    expect(screen.getByTestId('konva-path')).toBeInTheDocument();
  });

  // Click directly on the rendered mask path with the negative tool.
  fireEvent.click(screen.getByText('反向选点'));
  fireEvent.click(screen.getByTestId('konva-path'), { clientX: 220, clientY: 140 });
  await waitFor(() => {
    expect(screen.getAllByTestId('konva-circle')).toHaveLength(4);
  });

  // The second run sends both points; checked right after the click, the
  // first mask is still the selected one.
  fireEvent.click(screen.getByText('执行高精度语义分割'));
  const expectedPoints = [
    { x: 120, y: 80, type: 'pos' },
    { x: 220, y: 140, type: 'neg' },
  ];
  expect(apiMock.predictMask).toHaveBeenLastCalledWith(
    expect.objectContaining({ points: expectedPoints }),
  );
  expect(useStore.getState().selectedMaskIds).toEqual(['sam2-mask']);
});
it('clears only AI page candidates and keeps workspace masks in the store', async () => {
  const aiCandidate = {
    id: 'sam2-mask',
    pathData: 'M 10 10 L 40 10 L 40 40 Z',
    label: 'AI Mask',
    color: '#06b6d4',
    segmentation: [[10, 10, 40, 10, 40, 40]],
    bbox: [10, 10, 30, 30],
    area: 900,
  };
  const storedMaskIds = () => useStore.getState().masks.map(({ id }) => id);

  // A manually created workspace mask already lives in the store.
  useStore.setState({
    masks: [
      {
        id: 'workspace-mask',
        frameId: 'frame-1',
        pathData: 'M 0 0 L 10 0 L 10 10 Z',
        label: 'Manual Mask',
        color: '#ff0000',
        segmentation: [[0, 0, 10, 0, 10, 10]],
        metadata: { source: 'manual' },
      },
    ],
  });
  apiMock.predictMask.mockResolvedValueOnce({ masks: [aiCandidate] });

  render(<AISegmentation onSendToWorkspace={vi.fn()} />);

  // Place a positive point and wait for its two circle nodes.
  fireEvent.click(screen.getByText('正向选点'));
  fireEvent.click(screen.getByTestId('konva-stage'));
  await waitFor(() => {
    expect(screen.getAllByTestId('konva-circle')).toHaveLength(2);
  });

  // Run the segmentation: the AI candidate is appended after the workspace mask.
  const runButton = await screen.findByText('执行高精度语义分割');
  fireEvent.click(runButton);
  await waitFor(() => {
    expect(storedMaskIds()).toEqual(['workspace-mask', 'sam2-mask']);
  });

  // "Clear all anchors" drops the candidate and the selection only.
  fireEvent.click(screen.getByText('清空全体锚点'));
  expect(storedMaskIds()).toEqual(['workspace-mask']);
  expect(useStore.getState().selectedMaskIds).toEqual([]);
});
it('deletes only the selected AI candidate and preserves workspace masks', async () => {
  const aiCandidate = {
    id: 'sam2-mask',
    pathData: 'M 10 10 L 40 10 L 40 40 Z',
    label: 'AI Mask',
    color: '#06b6d4',
    segmentation: [[10, 10, 40, 10, 40, 40]],
    bbox: [10, 10, 30, 30],
    area: 900,
  };

  // A manually created workspace mask already lives in the store.
  useStore.setState({
    masks: [
      {
        id: 'workspace-mask',
        frameId: 'frame-1',
        pathData: 'M 0 0 L 10 0 L 10 10 Z',
        label: 'Manual Mask',
        color: '#ff0000',
        segmentation: [[0, 0, 10, 0, 10, 10]],
        metadata: { source: 'manual' },
      },
    ],
  });
  apiMock.predictMask.mockResolvedValueOnce({ masks: [aiCandidate] });

  render(<AISegmentation onSendToWorkspace={vi.fn()} />);
  fireEvent.click(screen.getByText('正向选点'));
  fireEvent.click(screen.getByTestId('konva-stage'));
  const runButton = await screen.findByText('执行高精度语义分割');
  fireEvent.click(runButton);

  // The new candidate becomes the selection.
  await waitFor(() => {
    expect(useStore.getState().selectedMaskIds).toEqual(['sam2-mask']);
  });

  // Deleting the selected candidate leaves only the workspace mask.
  fireEvent.click(screen.getByLabelText('删除选中候选'));
  await waitFor(() => {
    const remainingIds = useStore.getState().masks.map(({ id }) => id);
    expect(remainingIds).toEqual(['workspace-mask']);
  });
  expect(useStore.getState().selectedMaskIds).toEqual([]);
});
it('lets Delete remove the selected AI candidate after a mask click selects it', async () => {
  const aiCandidate = {
    id: 'sam2-mask',
    pathData: 'M 10 10 L 40 10 L 40 40 Z',
    label: 'AI Mask',
    color: '#06b6d4',
    segmentation: [[10, 10, 40, 10, 40, 40]],
    bbox: [10, 10, 30, 30],
    area: 900,
  };
  apiMock.predictMask.mockResolvedValueOnce({ masks: [aiCandidate] });

  render(<AISegmentation onSendToWorkspace={vi.fn()} />);
  fireEvent.click(screen.getByText('正向选点'));
  fireEvent.click(screen.getByTestId('konva-stage'));
  const runButton = await screen.findByText('执行高精度语义分割');
  fireEvent.click(runButton);
  await waitFor(() => {
    expect(screen.getByTestId('konva-path')).toBeInTheDocument();
  });

  // Switch to the viewport tool and click the mask path.
  fireEvent.click(screen.getByText('视口控制'));
  fireEvent.click(screen.getByTestId('konva-path'));

  // Pressing Delete on the window empties the mask list.
  fireEvent.keyDown(window, { key: 'Delete' });
  await waitFor(() => {
    expect(useStore.getState().masks).toEqual([]);
  });
});
it('lets a SAM2 result be selected and relabeled from the ontology panel', async () => {
  const aiCandidate = {
    id: 'sam2-mask',
    pathData: 'M 10 10 L 40 10 L 40 40 Z',
    label: 'AI Mask',
    color: '#06b6d4',
    segmentation: [[10, 10, 40, 10, 40, 40]],
    bbox: [10, 10, 30, 30],
    area: 900,
  };

  // Provide one template with two classes.
  useStore.setState({
    templates: [
      {
        id: 'template-1',
        name: '腹腔镜模板',
        classes: [
          { id: 'class-1', name: '胆囊', color: '#ff0000', zIndex: 30 },
          { id: 'class-2', name: '肝脏', color: '#00ff00', zIndex: 20 },
        ],
        rules: [],
      },
    ],
  });
  apiMock.predictMask.mockResolvedValueOnce({ masks: [aiCandidate] });

  render(<AISegmentation onSendToWorkspace={vi.fn()} />);
  fireEvent.click(screen.getByText('正向选点'));
  fireEvent.click(screen.getByTestId('konva-stage'));
  const runButton = await screen.findByText('执行高精度语义分割');
  fireEvent.click(runButton);
  await waitFor(() => {
    expect(useStore.getState().selectedMaskIds).toEqual(['sam2-mask']);
  });

  // Clicking the second class entry relabels the selected mask.
  fireEvent.click(screen.getByText('肝脏'));

  const expectedFields = expect.objectContaining({
    templateId: 'template-1',
    classId: 'class-2',
    className: '肝脏',
    classZIndex: 20,
    label: '肝脏',
    color: '#00ff00',
    saveStatus: 'draft',
  });
  expect(useStore.getState().masks[0]).toEqual(expectedFields);
});
it('keeps the generated SAM2 mask selected when sending it to the workspace editor', async () => {
  const handleSend = vi.fn();
  const aiCandidate = {
    id: 'sam2-mask',
    pathData: 'M 10 10 L 40 10 L 40 40 Z',
    label: 'AI Mask',
    color: '#06b6d4',
    segmentation: [[10, 10, 40, 10, 40, 40]],
    bbox: [10, 10, 30, 30],
    area: 900,
  };
  apiMock.predictMask.mockResolvedValueOnce({ masks: [aiCandidate] });

  render(<AISegmentation onSendToWorkspace={handleSend} />);
  fireEvent.click(screen.getByText('正向选点'));
  fireEvent.click(screen.getByTestId('konva-stage'));
  const runButton = await screen.findByText('执行高精度语义分割');
  fireEvent.click(runButton);
  await waitFor(() => {
    expect(useStore.getState().selectedMaskIds).toEqual(['sam2-mask']);
  });

  // Push the candidate to the workspace editor.
  fireEvent.click(screen.getByText('推送至工作区编辑'));

  // The edit-polygon tool is active, the mask stays selected, and the
  // callback was invoked.
  const { activeTool, selectedMaskIds } = useStore.getState();
  expect(activeTool).toBe('edit_polygon');
  expect(selectedMaskIds).toEqual(['sam2-mask']);
  expect(handleSend).toHaveBeenCalled();
});
});