feat: 完善视频传播、标注编辑和拆帧闭环

- 接入 SAM2 视频传播能力:新增 /api/ai/propagate,支持用当前帧 mask/polygon/bbox 作为 seed,通过 SAM2 video predictor 向前、向后或双向传播,并可保存为真实 annotation。
- 接入 SAM3 video tracker:通过独立 Python 3.12 external worker 调用 SAM3 video predictor/tracker,使用本地 checkpoint 与 bbox seed 执行视频级跟踪,并在模型状态中标记 video_track 能力。
- 完善 SAM 模型分发:sam_registry 按 model_id 明确区分 sam2 propagation 与 sam3 video_track,避免两个模型链路混用。
- 打通前端“传播片段”:VideoWorkspace 使用当前选中 mask 和当前 AI 模型调用后端传播接口,传播结果回写并刷新工作区已保存标注。
- 增强 SAM3 本地 checkpoint 配置:新增 sam3_checkpoint_path 配置和 .env.example 示例,状态检查改为基于本地 checkpoint/独立环境/模型包可用性。
- 完善视频拆帧参数:/api/media/parse 支持 parse_fps、max_frames、target_width,后端任务保存帧时间戳、源帧号和 frame_sequence 元数据。
- 增加运行时 schema 兼容处理:启动时为旧 frames 表补充 timestamp_ms 和 source_frame_number 列,避免旧库升级后缺字段。
- 强化 Canvas 标注编辑:补齐多边形闭合、点工具、顶点拖拽、边中点插入、Delete/Backspace 删除、区域合并和重叠去除等交互。
- 增强语义分类联动:选中 mask 后可通过右侧语义分类树更新标签、颜色和 class metadata,并同步到保存/导出链路。
- 增加关键帧时间轴体验:FrameTimeline 显示具体时间信息,并支持键盘左右方向键切换关键帧。
- 完善 AI 交互分割参数:前端保留正向点、反向点、框选和 interactive prompt 的调用状态,支持 SAM2 细化候选区域与 SAM3 bbox 入口。
- 扩展后端/前端 API 类型:新增 propagateMasks、传播请求/响应 schema,并补齐 annotation、导出、模型状态和任务接口的测试覆盖。
- 更新项目文档:同步 README、AGENTS、接口契约、需求冻结、设计冻结、前端元素审计、实施计划和测试计划,标明真实功能边界与剩余风险。
- 增加测试覆盖:补充 SAM2/SAM3 传播、SAM3 状态、媒体拆帧参数、Canvas 编辑、语义标签切换、时间轴、工作区传播和 API 合约测试。
- 加强仓库安全边界:将 sam3权重/ 加入 .gitignore,避免本地模型权重被误提交。

验证:npm run test:run;pytest backend/tests;npm run lint;npm run build;python -m py_compile;git diff --check。
This commit is contained in:
2026-05-01 20:27:33 +08:00
parent 689a9ba283
commit 5ab4602535
43 changed files with 2722 additions and 216 deletions

View File

@@ -1,4 +1,4 @@
import { fireEvent, render, screen } from '@testing-library/react';
import { fireEvent, render, screen, waitFor } from '@testing-library/react';
import { beforeEach, describe, expect, it, vi } from 'vitest';
import { resetStore } from '../test/storeTestUtils';
import { useStore } from '../store/useStore';
@@ -62,4 +62,112 @@ describe('AISegmentation', () => {
},
}));
});
it('prompts for semantic text before running SAM3 inference', async () => {
apiMock.getAiModelStatus.mockResolvedValue({
selected_model: 'sam3',
gpu: { available: true, device: 'cuda', name: 'RTX 4090', torch_available: true },
models: [
{ id: 'sam2', label: 'SAM 2', available: true, loaded: false, device: 'cuda', supports: ['point', 'box'], message: 'SAM 2 ready', package_available: true, checkpoint_exists: true, python_ok: true, torch_ok: true, cuda_required: false },
{ id: 'sam3', label: 'SAM 3', available: true, loaded: true, device: 'cuda', supports: ['semantic'], message: 'SAM 3 ready', package_available: true, checkpoint_exists: true, python_ok: true, torch_ok: true, cuda_required: true },
],
});
render(<AISegmentation onSendToWorkspace={vi.fn()} />);
const sam3Button = (await screen.findByText('SAM3')).closest('button')!;
fireEvent.click(sam3Button);
fireEvent.click(screen.getByText('执行高精度语义分割'));
expect(apiMock.predictMask).not.toHaveBeenCalled();
expect(await screen.findByText('SAM3 当前使用文本语义提示,请先输入要分割的目标描述。')).toBeInTheDocument();
});
it('shows feedback when SAM3 semantic inference returns no masks', async () => {
apiMock.getAiModelStatus.mockResolvedValue({
selected_model: 'sam3',
gpu: { available: true, device: 'cuda', name: 'RTX 4090', torch_available: true },
models: [
{ id: 'sam2', label: 'SAM 2', available: true, loaded: false, device: 'cuda', supports: ['point', 'box'], message: 'SAM 2 ready', package_available: true, checkpoint_exists: true, python_ok: true, torch_ok: true, cuda_required: false },
{ id: 'sam3', label: 'SAM 3', available: true, loaded: true, device: 'cuda', supports: ['semantic'], message: 'SAM 3 ready', package_available: true, checkpoint_exists: true, python_ok: true, torch_ok: true, cuda_required: true },
],
});
apiMock.predictMask.mockResolvedValueOnce({ masks: [] });
render(<AISegmentation onSendToWorkspace={vi.fn()} />);
const sam3Button = (await screen.findByText('SAM3')).closest('button')!;
fireEvent.click(sam3Button);
fireEvent.change(screen.getByPlaceholderText("例如:'分割出左侧车道上行驶的所有红色汽车'..."), {
target: { value: '胆囊' },
});
fireEvent.click(screen.getByText('执行高精度语义分割'));
await waitFor(() => expect(apiMock.predictMask).toHaveBeenCalledWith(expect.objectContaining({
model: 'sam3',
points: undefined,
text: '胆囊',
})));
expect(await screen.findByText('模型没有返回可用区域,请换一个更具体的描述或调整提示。')).toBeInTheDocument();
});
it('runs SAM3 semantic text inference and assigns the active class to returned masks', async () => {
apiMock.getAiModelStatus.mockResolvedValue({
selected_model: 'sam3',
gpu: { available: true, device: 'cuda', name: 'RTX 4090', torch_available: true },
models: [
{ id: 'sam2', label: 'SAM 2', available: true, loaded: false, device: 'cuda', supports: ['point', 'box'], message: 'SAM 2 ready', package_available: true, checkpoint_exists: true, python_ok: true, torch_ok: true, cuda_required: false },
{ id: 'sam3', label: 'SAM 3', available: true, loaded: true, device: 'cuda', supports: ['semantic'], message: 'SAM 3 ready', package_available: true, checkpoint_exists: true, python_ok: true, torch_ok: true, cuda_required: true },
],
});
apiMock.predictMask.mockResolvedValueOnce({
masks: [
{
id: 'semantic-1',
pathData: 'M 10 10 L 40 10 L 40 40 Z',
label: 'semantic result',
color: '#06b6d4',
segmentation: [[10, 10, 40, 10, 40, 40]],
bbox: [10, 10, 30, 30],
area: 900,
},
],
});
useStore.setState({
activeTemplateId: 'template-1',
activeClassId: 'class-1',
activeClass: { id: 'class-1', name: '胆囊', color: '#ff0000', zIndex: 30 },
});
render(<AISegmentation onSendToWorkspace={vi.fn()} />);
const sam3Button = (await screen.findByText('SAM3')).closest('button')!;
fireEvent.click(sam3Button);
fireEvent.change(screen.getByPlaceholderText("例如:'分割出左侧车道上行驶的所有红色汽车'..."), {
target: { value: '胆囊' },
});
fireEvent.click(screen.getByText('执行高精度语义分割'));
await waitFor(() => expect(apiMock.predictMask).toHaveBeenCalledWith(expect.objectContaining({
imageId: 'frame-1',
imageWidth: 640,
imageHeight: 360,
model: 'sam3',
points: undefined,
text: '胆囊',
options: {
crop_to_prompt: false,
auto_filter_background: true,
min_score: 0.05,
},
})));
expect(useStore.getState().masks[0]).toEqual(expect.objectContaining({
id: 'semantic-1',
frameId: 'frame-1',
templateId: 'template-1',
classId: 'class-1',
className: '胆囊',
classZIndex: 30,
label: '胆囊',
color: '#ff0000',
saveStatus: 'draft',
}));
});
});

View File

@@ -33,6 +33,7 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
const [autoDeleteBg, setAutoDeleteBg] = useState(true);
const [cropMode, setCropMode] = useState(false);
const [isInferencing, setIsInferencing] = useState(false);
const [inferenceMessage, setInferenceMessage] = useState('');
// Canvas state
const [scale, setScale] = useState(1);
@@ -91,9 +92,18 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
};
const runInference = useCallback(async () => {
if (points.length === 0 && !semanticText.trim()) return;
const textPrompt = semanticText.trim();
if (aiModel === 'sam3' && !textPrompt) {
setInferenceMessage('SAM3 当前使用文本语义提示,请先输入要分割的目标描述。');
return;
}
if (points.length === 0 && !textPrompt) {
setInferenceMessage('请先放置正/反向提示点,或输入语义描述。');
return;
}
if (!currentFrame?.id) {
console.warn('AI inference skipped: no project frame is selected');
setInferenceMessage('请先在项目工作区选择一帧图像。');
return;
}
@@ -101,18 +111,20 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
const imageHeight = currentFrame.height || image?.naturalHeight || image?.height || 0;
if (imageWidth <= 0 || imageHeight <= 0) {
console.warn('AI inference skipped: active frame dimensions are unavailable');
setInferenceMessage('当前帧缺少宽高信息,无法推理。');
return;
}
setIsInferencing(true);
setInferenceMessage('');
try {
const result = await predictMask({
imageId: currentFrame.id,
imageWidth,
imageHeight,
model: aiModel,
points: points.map((p) => ({ x: p.x, y: p.y, type: p.type })),
text: semanticText.trim() || undefined,
points: aiModel === 'sam3' ? undefined : points.map((p) => ({ x: p.x, y: p.y, type: p.type })),
text: textPrompt || undefined,
options: {
crop_to_prompt: cropMode,
auto_filter_background: autoDeleteBg,
@@ -120,6 +132,11 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
},
});
if (result.masks.length === 0) {
setInferenceMessage('模型没有返回可用区域,请换一个更具体的描述或调整提示。');
} else {
setInferenceMessage(`已生成 ${result.masks.length} 个候选区域。`);
}
result.masks.forEach((m) => {
const label = activeClass?.name || m.label;
const color = activeClass?.color || m.color;
@@ -142,6 +159,8 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
});
} catch (err) {
console.error('AI inference failed:', err);
const detail = (err as any)?.response?.data?.detail;
setInferenceMessage(detail || 'AI 推理失败,请查看模型状态或后端日志。');
} finally {
setIsInferencing(false);
}
@@ -282,6 +301,11 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
{isInferencing ? <Loader2 size={16} className="animate-spin" /> : <Sparkles size={16} />}
{isInferencing ? '推理中...' : modelCanInfer ? '执行高精度语义分割' : '当前模型不可用'}
</button>
{inferenceMessage && (
<div className="rounded border border-white/10 bg-white/5 px-3 py-2 text-[11px] leading-relaxed text-gray-300">
{inferenceMessage}
</div>
)}
<button
onClick={onSendToWorkspace}
className="w-full py-3.5 rounded-lg flex items-center justify-center gap-2 transition-all font-medium tracking-wide text-xs uppercase bg-white/5 hover:bg-white/10 text-gray-300 border border-white/5 hover:border-white/10"

View File

@@ -65,6 +65,157 @@ describe('CanvasArea', () => {
}));
});
it('explains that SAM3 point prompts are not supported in the workspace', async () => {
useStore.setState({ aiModel: 'sam3' });
render(<CanvasArea activeTool="point_pos" frame={frame} />);
fireEvent.click(screen.getByTestId('konva-stage'));
expect(apiMock.predictMask).not.toHaveBeenCalled();
expect(await screen.findByText(/SAM3 当前工作区只支持框选提示/)).toBeInTheDocument();
});
it('calls SAM3 prediction with a box prompt from the workspace', async () => {
useStore.setState({ aiModel: 'sam3' });
apiMock.predictMask.mockResolvedValueOnce({
masks: [
{
id: 'sam3-box-mask',
pathData: 'M 20 20 L 80 20 L 80 80 Z',
label: 'AI Mask',
color: '#06b6d4',
segmentation: [[20, 20, 80, 20, 80, 80]],
bbox: [20, 20, 60, 60],
area: 3600,
},
],
});
render(<CanvasArea activeTool="box_select" frame={frame} />);
const stage = screen.getByTestId('konva-stage');
fireEvent.mouseDown(stage, { clientX: 120, clientY: 80 });
fireEvent.mouseMove(stage, { clientX: 260, clientY: 200 });
fireEvent.mouseUp(stage, { clientX: 260, clientY: 200 });
await waitFor(() => expect(apiMock.predictMask).toHaveBeenCalledWith({
imageId: 'frame-1',
imageWidth: 640,
imageHeight: 360,
model: 'sam3',
points: undefined,
box: { x1: 120, y1: 80, x2: 260, y2: 200 },
}));
expect(useStore.getState().masks[0]).toEqual(expect.objectContaining({
id: 'sam3-box-mask',
metadata: expect.objectContaining({
source: 'sam3_box',
promptBox: { x1: 120, y1: 80, x2: 260, y2: 200 },
}),
}));
});
it('refines one SAM2 candidate mask from an initial box with positive and negative points', async () => {
apiMock.predictMask
.mockResolvedValueOnce({
masks: [
{
id: 'mask-box',
pathData: 'M 10 10 L 90 10 L 90 90 Z',
label: 'AI Mask',
color: '#06b6d4',
segmentation: [[10, 10, 90, 10, 90, 90]],
bbox: [10, 10, 80, 80],
area: 6400,
},
],
})
.mockResolvedValueOnce({
masks: [
{
id: 'mask-refined-pos',
pathData: 'M 20 20 L 80 20 L 80 80 Z',
label: 'AI Mask',
color: '#06b6d4',
segmentation: [[20, 20, 80, 20, 80, 80]],
bbox: [20, 20, 60, 60],
area: 3600,
},
],
})
.mockResolvedValueOnce({
masks: [
{
id: 'mask-refined-neg',
pathData: 'M 30 30 L 70 30 L 70 70 Z',
label: 'AI Mask',
color: '#06b6d4',
segmentation: [[30, 30, 70, 30, 70, 70]],
bbox: [30, 30, 40, 40],
area: 1600,
},
],
});
const { rerender } = render(<CanvasArea activeTool="box_select" frame={frame} />);
const stage = screen.getByTestId('konva-stage');
fireEvent.mouseDown(stage, { clientX: 120, clientY: 80 });
fireEvent.mouseMove(stage, { clientX: 260, clientY: 200 });
fireEvent.mouseUp(stage, { clientX: 260, clientY: 200 });
await waitFor(() => expect(apiMock.predictMask).toHaveBeenNthCalledWith(1, {
imageId: 'frame-1',
imageWidth: 640,
imageHeight: 360,
model: 'sam2',
points: undefined,
box: { x1: 120, y1: 80, x2: 260, y2: 200 },
}));
await waitFor(() => expect(useStore.getState().masks).toHaveLength(1));
rerender(<CanvasArea activeTool="point_pos" frame={frame} />);
fireEvent.click(stage, { clientX: 150, clientY: 100 });
await waitFor(() => expect(apiMock.predictMask).toHaveBeenNthCalledWith(2, {
imageId: 'frame-1',
imageWidth: 640,
imageHeight: 360,
model: 'sam2',
points: [{ x: 150, y: 100, type: 'pos' }],
box: { x1: 120, y1: 80, x2: 260, y2: 200 },
}));
expect(useStore.getState().masks).toHaveLength(1);
expect(useStore.getState().masks[0]).toEqual(expect.objectContaining({
id: 'mask-box',
segmentation: [[20, 20, 80, 20, 80, 80]],
metadata: expect.objectContaining({
source: 'sam2_interactive',
promptPointCount: 1,
}),
}));
rerender(<CanvasArea activeTool="point_neg" frame={frame} />);
fireEvent.click(stage, { clientX: 300, clientY: 150 });
await waitFor(() => expect(apiMock.predictMask).toHaveBeenNthCalledWith(3, {
imageId: 'frame-1',
imageWidth: 640,
imageHeight: 360,
model: 'sam2',
points: [
{ x: 150, y: 100, type: 'pos' },
{ x: 300, y: 150, type: 'neg' },
],
box: { x1: 120, y1: 80, x2: 260, y2: 200 },
}));
expect(useStore.getState().masks).toHaveLength(1);
expect(useStore.getState().masks[0]).toEqual(expect.objectContaining({
id: 'mask-box',
segmentation: [[30, 30, 70, 30, 70, 70]],
points: [[150, 100]],
metadata: expect.objectContaining({ promptPointCount: 2 }),
}));
});
it('renders only masks that belong to the current frame', () => {
useStore.setState({
masks: [
@@ -79,6 +230,26 @@ describe('CanvasArea', () => {
expect(screen.getByText('遮罩数: 1')).toBeInTheDocument();
});
it('publishes the selected mask ids for the ontology panel', async () => {
useStore.setState({
masks: [
{
id: 'm1',
frameId: 'frame-1',
pathData: 'M 0 0 L 10 0 L 10 10 Z',
label: 'A',
color: '#fff',
segmentation: [[0, 0, 10, 0, 10, 10]],
},
],
});
render(<CanvasArea activeTool="move" frame={frame} />);
fireEvent.click(screen.getByTestId('konva-path'));
await waitFor(() => expect(useStore.getState().selectedMaskIds).toEqual(['m1']));
});
it('renders imported GT seed points for editable point regions', () => {
useStore.setState({
masks: [
@@ -164,6 +335,57 @@ describe('CanvasArea', () => {
}));
});
it('deletes the selected draft mask with Delete when no vertex is selected', () => {
useStore.setState({
masks: [
{
id: 'draft-1',
frameId: 'frame-1',
pathData: 'M 10 10 L 90 10 L 90 40 Z',
label: 'Draft',
color: '#06b6d4',
saveStatus: 'draft',
segmentation: [[10, 10, 90, 10, 90, 40]],
},
],
});
render(<CanvasArea activeTool="move" frame={frame} />);
fireEvent.click(screen.getByTestId('konva-path'));
fireEvent.keyDown(window, { key: 'Delete' });
expect(useStore.getState().masks).toEqual([]);
expect(useStore.getState().maskHistory.at(-1)).toEqual([
expect.objectContaining({ id: 'draft-1' }),
]);
});
it('deletes the selected saved mask locally and notifies the backend deletion callback', () => {
const onDeleteMaskAnnotations = vi.fn();
useStore.setState({
masks: [
{
id: 'annotation-99',
annotationId: '99',
frameId: 'frame-1',
pathData: 'M 10 10 L 90 10 L 90 40 Z',
label: 'Saved',
color: '#06b6d4',
saveStatus: 'saved',
saved: true,
segmentation: [[10, 10, 90, 10, 90, 40]],
},
],
});
render(<CanvasArea activeTool="move" frame={frame} onDeleteMaskAnnotations={onDeleteMaskAnnotations} />);
fireEvent.click(screen.getByTestId('konva-path'));
fireEvent.keyDown(window, { key: 'Backspace' });
expect(useStore.getState().masks).toEqual([]);
expect(onDeleteMaskAnnotations).toHaveBeenCalledWith(['99']);
});
it('inserts a polygon vertex from an edge midpoint handle', () => {
useStore.setState({
masks: [
@@ -248,9 +470,13 @@ describe('CanvasArea', () => {
});
render(<CanvasArea activeTool="area_merge" frame={frame} />);
expect(screen.getByText('已选 0')).toBeInTheDocument();
const paths = screen.getAllByTestId('konva-path');
fireEvent.click(paths[0]);
expect(screen.getByText('已选 1')).toBeInTheDocument();
expect(screen.queryAllByTestId('konva-circle')).toHaveLength(0);
fireEvent.click(paths[1]);
expect(screen.getByText('已选 2')).toBeInTheDocument();
fireEvent.click(screen.getByRole('button', { name: '合并选中' }));
expect(useStore.getState().masks).toHaveLength(1);
@@ -300,6 +526,45 @@ describe('CanvasArea', () => {
expect(useStore.getState().masks[1].id).toBe('m2');
});
it('renders inner overlap removal as a hole in the primary mask', () => {
useStore.setState({
masks: [
{
id: 'm1',
frameId: 'frame-1',
pathData: 'M 10 10 L 110 10 L 110 110 L 10 110 Z',
label: 'A',
color: '#06b6d4',
segmentation: [[10, 10, 110, 10, 110, 110, 10, 110]],
},
{
id: 'm2',
frameId: 'frame-1',
pathData: 'M 40 40 L 80 40 L 80 80 L 40 80 Z',
label: 'B',
color: '#ff0000',
segmentation: [[40, 40, 80, 40, 80, 80, 40, 80]],
},
],
});
render(<CanvasArea activeTool="area_remove" frame={frame} />);
const paths = screen.getAllByTestId('konva-path');
fireEvent.click(paths[0]);
fireEvent.click(paths[1]);
fireEvent.click(screen.getByRole('button', { name: '从主区域去除' }));
const [primary] = useStore.getState().masks;
expect(primary).toEqual(expect.objectContaining({
id: 'm1',
area: 8400,
bbox: [10, 10, 100, 100],
metadata: expect.objectContaining({ hasHoles: true }),
}));
expect(primary.segmentation).toHaveLength(2);
expect(screen.getAllByTestId('konva-path')[0]).toHaveAttribute('data-fill-rule', 'evenodd');
});
it('creates a manual rectangle mask that can be undone and redone', () => {
useStore.setState({
activeTemplateId: '2',
@@ -329,6 +594,93 @@ describe('CanvasArea', () => {
expect(useStore.getState().masks).toHaveLength(1);
});
it('creates a manual circle mask from a drag gesture', () => {
render(<CanvasArea activeTool="create_circle" frame={frame} />);
const stage = screen.getByTestId('konva-stage');
fireEvent.mouseDown(stage, { clientX: 120, clientY: 80 });
fireEvent.mouseMove(stage, { clientX: 260, clientY: 200 });
fireEvent.mouseUp(stage, { clientX: 260, clientY: 200 });
expect(useStore.getState().masks).toHaveLength(1);
expect(useStore.getState().masks[0]).toEqual(expect.objectContaining({
frameId: 'frame-1',
label: '手工圆形',
color: '#06b6d4',
saveStatus: 'draft',
bbox: [120, 80, 140, 120],
metadata: expect.objectContaining({
source: 'manual',
shape: '圆形',
}),
}));
expect(useStore.getState().masks[0].segmentation?.[0]).toHaveLength(64);
});
it('creates a manual line region from a drag gesture', () => {
render(<CanvasArea activeTool="create_line" frame={frame} />);
const stage = screen.getByTestId('konva-stage');
fireEvent.mouseDown(stage, { clientX: 120, clientY: 80 });
fireEvent.mouseMove(stage, { clientX: 260, clientY: 200 });
fireEvent.mouseUp(stage, { clientX: 260, clientY: 200 });
expect(useStore.getState().masks).toHaveLength(1);
expect(useStore.getState().masks[0]).toEqual(expect.objectContaining({
frameId: 'frame-1',
label: '手工线段',
color: '#06b6d4',
saveStatus: 'draft',
metadata: expect.objectContaining({
source: 'manual',
shape: '线段',
}),
}));
expect(useStore.getState().masks[0].segmentation?.[0]).toHaveLength(8);
expect(useStore.getState().masks[0].area).toBeGreaterThan(1000);
});
it('creates an editable point region on click', () => {
render(<CanvasArea activeTool="create_point" frame={frame} />);
fireEvent.click(screen.getByTestId('konva-stage'), { clientX: 120, clientY: 80 });
expect(useStore.getState().masks).toHaveLength(1);
expect(useStore.getState().masks[0]).toEqual(expect.objectContaining({
frameId: 'frame-1',
label: '手工点区域',
color: '#06b6d4',
saveStatus: 'draft',
points: [[120, 80]],
bbox: expect.arrayContaining([115, 75]),
metadata: expect.objectContaining({
source: 'manual',
shape: '点区域',
}),
}));
});
it('creates a point region when clicking over an existing mask', () => {
useStore.setState({
masks: [
{
id: 'm1',
frameId: 'frame-1',
pathData: 'M 10 10 L 200 10 L 200 200 Z',
label: 'Existing',
color: '#06b6d4',
segmentation: [[10, 10, 200, 10, 200, 200]],
},
],
});
render(<CanvasArea activeTool="create_point" frame={frame} />);
fireEvent.click(screen.getByTestId('konva-path'), { clientX: 120, clientY: 80 });
expect(useStore.getState().masks).toHaveLength(2);
expect(useStore.getState().masks[1]).toEqual(expect.objectContaining({
metadata: expect.objectContaining({ shape: '点区域' }),
points: [[120, 80]],
}));
});
it('finalizes a clicked polygon with Enter', () => {
render(<CanvasArea activeTool="create_polygon" frame={frame} />);
const stage = screen.getByTestId('konva-stage');
@@ -344,6 +696,29 @@ describe('CanvasArea', () => {
}));
});
it('closes a clicked polygon by clicking the first node again', () => {
render(<CanvasArea activeTool="create_polygon" frame={frame} />);
const stage = screen.getByTestId('konva-stage');
fireEvent.click(stage, { clientX: 120, clientY: 80 });
fireEvent.click(stage, { clientX: 220, clientY: 80 });
fireEvent.click(stage, { clientX: 180, clientY: 160 });
const handles = screen.getAllByTestId('konva-circle');
expect(handles[0]).toHaveAttribute('data-fill', '#facc15');
fireEvent.click(handles[0]);
expect(useStore.getState().masks).toHaveLength(1);
expect(useStore.getState().masks[0]).toEqual(expect.objectContaining({
pathData: 'M 120 80 L 220 80 L 180 160 Z',
segmentation: [[120, 80, 220, 80, 180, 160]],
metadata: expect.objectContaining({
source: 'manual',
shape: '多边形',
}),
}));
expect(screen.queryAllByTestId('konva-circle')).toHaveLength(0);
});
it('applies the selected class to current-frame masks and marks saved masks dirty', () => {
useStore.setState({
activeTemplateId: '2',

View File

@@ -14,11 +14,14 @@ interface CanvasAreaProps {
}
type CanvasPoint = { x: number; y: number };
type PromptPoint = CanvasPoint & { type: 'pos' | 'neg' };
type PromptBox = { x1: number; y1: number; x2: number; y2: number };
const DRAG_MANUAL_TOOLS = new Set(['create_rectangle', 'create_circle', 'create_line']);
const POLYGON_TOOL = 'create_polygon';
const POINT_TOOL = 'create_point';
const BOOLEAN_TOOLS = new Set(['area_merge', 'area_remove']);
const POLYGON_CLOSE_RADIUS = 8;
function clamp(value: number, min: number, max: number): number {
return Math.min(Math.max(value, min), max);
@@ -88,6 +91,10 @@ function polygonArea(points: CanvasPoint[]): number {
return Math.abs(sum) / 2;
}
function pointDistance(a: CanvasPoint, b: CanvasPoint): number {
return Math.hypot(a.x - b.x, a.y - b.y);
}
function segmentationArea(segmentation?: number[][]): number {
return (segmentation || []).reduce((sum, polygon) => sum + polygonArea(flatPolygonToPoints(polygon)), 0);
}
@@ -115,20 +122,35 @@ function maskToMultiPolygon(mask: Mask): MultiPolygon | null {
return polygons.length > 0 ? polygons : null;
}
function openRingPoints(ring: Pair[]): CanvasPoint[] {
const openRing = ring.length > 1
&& ring[0][0] === ring[ring.length - 1][0]
&& ring[0][1] === ring[ring.length - 1][1]
? ring.slice(0, -1)
: ring;
return openRing.map(([x, y]) => ({ x, y }));
}
function multiPolygonToSegmentation(geometry: MultiPolygon): number[][] {
return geometry
.map((polygon) => polygon[0] || [])
.map((ring) => {
const openRing = ring.length > 1
&& ring[0][0] === ring[ring.length - 1][0]
&& ring[0][1] === ring[ring.length - 1][1]
? ring.slice(0, -1)
: ring;
return openRing.flatMap(([x, y]) => [x, y]);
})
.flatMap((polygon) => polygon)
.map((ring) => openRingPoints(ring).flatMap(({ x, y }) => [x, y]))
.filter((polygon) => polygon.length >= 6);
}
function multiPolygonArea(geometry: MultiPolygon): number {
return geometry.reduce((sum, polygon) => {
const [outerRing, ...holeRings] = polygon;
const outerArea = outerRing ? polygonArea(openRingPoints(outerRing)) : 0;
const holesArea = holeRings.reduce((holeSum, ring) => holeSum + polygonArea(openRingPoints(ring)), 0);
return sum + Math.max(outerArea - holesArea, 0);
}, 0);
}
function multiPolygonHasHoles(geometry: MultiPolygon): boolean {
return geometry.some((polygon) => polygon.length > 1);
}
function rectanglePoints(start: CanvasPoint, end: CanvasPoint): CanvasPoint[] {
const x1 = Math.min(start.x, end.x);
const y1 = Math.min(start.y, end.y);
@@ -179,10 +201,12 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
const [stageSize, setStageSize] = useState({ width: 800, height: 600 });
const [scale, setScale] = useState(1);
const [position, setPosition] = useState({ x: 0, y: 0 });
const [points, setPoints] = useState<{ x: number, y: number, type: 'pos'|'neg' }[]>([]);
const [points, setPoints] = useState<PromptPoint[]>([]);
const [cursorPos, setCursorPos] = useState({ x: 0, y: 0 });
const [boxStart, setBoxStart] = useState<{ x: number, y: number } | null>(null);
const [boxCurrent, setBoxCurrent] = useState<{ x: number, y: number } | null>(null);
const [samPromptBox, setSamPromptBox] = useState<PromptBox | null>(null);
const [samCandidateMaskId, setSamCandidateMaskId] = useState<string | null>(null);
const [manualStart, setManualStart] = useState<CanvasPoint | null>(null);
const [manualCurrent, setManualCurrent] = useState<CanvasPoint | null>(null);
const [polygonPoints, setPolygonPoints] = useState<CanvasPoint[]>([]);
@@ -191,12 +215,14 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
const [selectedPolygonIndex, setSelectedPolygonIndex] = useState(0);
const [selectedVertexIndex, setSelectedVertexIndex] = useState<number | null>(null);
const [isInferencing, setIsInferencing] = useState(false);
const [inferenceMessage, setInferenceMessage] = useState('');
const masks = useStore((state) => state.masks);
const addMask = useStore((state) => state.addMask);
const updateMask = useStore((state) => state.updateMask);
const clearMasks = useStore((state) => state.clearMasks);
const setMasks = useStore((state) => state.setMasks);
const setGlobalSelectedMaskIds = useStore((state) => state.setSelectedMaskIds);
const storeActiveTool = useStore((state) => state.activeTool);
const aiModel = useStore((state) => state.aiModel);
const activeTemplateId = useStore((state) => state.activeTemplateId);
@@ -226,6 +252,7 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
const savedMaskCount = frameMasks.filter((mask) => mask.saveStatus === 'saved' || mask.saved).length;
const draftMaskCount = frameMasks.filter((mask) => !mask.annotationId).length;
const dirtyMaskCount = frameMasks.filter((mask) => mask.saveStatus === 'dirty').length;
const isBooleanTool = BOOLEAN_TOOLS.has(effectiveTool);
useEffect(() => {
const handleResize = () => {
@@ -252,6 +279,18 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
setSelectedVertexIndex(null);
}, [effectiveTool, frame?.id]);
useEffect(() => {
setPoints([]);
setSamPromptBox(null);
setSamCandidateMaskId(null);
}, [frame?.id]);
useEffect(() => {
setGlobalSelectedMaskIds(selectedMaskIds);
}, [selectedMaskIds, setGlobalSelectedMaskIds]);
useEffect(() => () => setGlobalSelectedMaskIds([]), [setGlobalSelectedMaskIds]);
useEffect(() => {
if (selectedMaskId && !frameMasks.some((mask) => mask.id === selectedMaskId)) {
setSelectedMaskId(null);
@@ -324,6 +363,12 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
addMask(mask);
}, [activeClass, activeTemplateId, addMask, frame?.id]);
const finishPolygon = useCallback(() => {
if (polygonPoints.length < 3) return;
createManualMask('多边形', polygonPoints);
setPolygonPoints([]);
}, [createManualMask, polygonPoints]);
const handleMouseMove = (e: any) => {
const stage = e.target.getStage();
if (!stage) return;
@@ -349,9 +394,18 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
}
};
const runInference = useCallback(async (promptPoints?: typeof points, promptBox?: { x1: number, y1: number, x2: number, y2: number }) => {
const runInference = useCallback(async (
promptPoints?: PromptPoint[],
promptBox?: PromptBox,
options: { resetCandidate?: boolean } = {},
) => {
if (!frame?.id) {
console.warn('Inference skipped: no active frame');
setInferenceMessage('请先选择一帧图像。');
return;
}
if (aiModel === 'sam3' && (!promptBox || (promptPoints?.length ?? 0) > 0)) {
setInferenceMessage('SAM3 当前工作区只支持框选提示;正/反点修正请切回 SAM2。');
return;
}
@@ -359,31 +413,44 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
const imageHeight = frame.height || image?.naturalHeight || image?.height || 0;
if (imageWidth <= 0 || imageHeight <= 0) {
console.warn('Inference skipped: active frame dimensions are unavailable');
setInferenceMessage('当前帧缺少宽高信息,无法推理。');
return;
}
setIsInferencing(true);
setInferenceMessage('');
try {
const result = await predictMask({
imageId: frame.id,
imageWidth,
imageHeight,
model: aiModel,
points: promptPoints?.map((p) => ({ x: p.x, y: p.y, type: p.type })),
points: promptPoints && promptPoints.length > 0
? promptPoints.map((p) => ({ x: p.x, y: p.y, type: p.type }))
: undefined,
box: promptBox,
});
result.masks.forEach((m) => {
const label = activeClass?.name || m.label;
const color = activeClass?.color || m.color;
addMask({
id: m.id,
const [m] = result.masks;
if (m) {
const existingCandidate = !options.resetCandidate && samCandidateMaskId
? masks.find((mask) => mask.id === samCandidateMaskId)
: null;
const label = activeClass?.name || existingCandidate?.label || m.label;
const color = activeClass?.color || existingCandidate?.color || m.color;
const metadata = {
...(existingCandidate?.metadata || {}),
source: aiModel === 'sam3' ? 'sam3_box' : 'sam2_interactive',
promptBox: promptBox || null,
promptPointCount: promptPoints?.length || 0,
};
const nextMask = {
frameId: frame.id,
templateId: activeTemplateId || undefined,
classId: activeClass?.id,
className: activeClass?.name,
classZIndex: activeClass?.zIndex,
saveStatus: 'draft',
templateId: activeTemplateId || existingCandidate?.templateId || undefined,
classId: activeClass?.id || existingCandidate?.classId,
className: activeClass?.name || existingCandidate?.className,
classZIndex: activeClass?.zIndex ?? existingCandidate?.classZIndex,
saveStatus: existingCandidate?.annotationId ? 'dirty' as const : 'draft' as const,
saved: false,
pathData: m.pathData,
label,
@@ -392,14 +459,33 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
points: promptPoints?.filter((p) => p.type === 'pos').map((p) => [p.x, p.y]),
bbox: m.bbox,
area: m.area,
});
});
metadata,
};
if (existingCandidate) {
updateMask(existingCandidate.id, nextMask);
setSelectedMaskId(existingCandidate.id);
setSelectedMaskIds([existingCandidate.id]);
} else {
const id = m.id;
setSamCandidateMaskId(id);
setSelectedMaskId(id);
setSelectedMaskIds([id]);
addMask({
id,
...nextMask,
});
}
} else {
setInferenceMessage('模型没有返回可用区域,请调整点/框提示后重试。');
}
} catch (err) {
console.error('Inference failed:', err);
const detail = (err as any)?.response?.data?.detail;
setInferenceMessage(detail || 'AI 推理失败,请查看模型状态或后端日志。');
} finally {
setIsInferencing(false);
}
}, [activeClass, activeTemplateId, addMask, aiModel, frame?.height, frame?.id, frame?.width, image?.height, image?.naturalHeight, image?.naturalWidth, image?.width]);
}, [activeClass, activeTemplateId, addMask, aiModel, frame?.height, frame?.id, frame?.width, image?.height, image?.naturalHeight, image?.naturalWidth, image?.width, masks, samCandidateMaskId, updateMask]);
const handleApplyActiveClass = () => {
if (!frame?.id || !activeClass) return;
@@ -427,6 +513,29 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
clearMasks();
};
const deleteMasksById = useCallback((maskIds: string[]) => {
if (maskIds.length === 0) return;
const idSet = new Set(maskIds);
const deletingMasks = masks.filter((mask) => idSet.has(mask.id));
if (deletingMasks.length === 0) return;
setMasks(masks.filter((mask) => !idSet.has(mask.id)));
const annotationIds = deletingMasks
.map((mask) => mask.annotationId)
.filter((annotationId): annotationId is string => Boolean(annotationId));
if (annotationIds.length > 0) {
void onDeleteMaskAnnotations?.(annotationIds);
}
if (samCandidateMaskId && idSet.has(samCandidateMaskId)) {
setSamCandidateMaskId(null);
setSamPromptBox(null);
setPoints([]);
}
setSelectedMaskId(null);
setSelectedMaskIds([]);
setSelectedPolygonIndex(0);
setSelectedVertexIndex(null);
}, [masks, onDeleteMaskAnnotations, samCandidateMaskId, setMasks]);
const handleStageMouseDown = (e: any) => {
if (DRAG_MANUAL_TOOLS.has(effectiveTool)) {
const pos = stagePoint(e);
@@ -476,7 +585,11 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
const y2 = Math.max(boxStart.y, boxCurrent.y);
if (Math.abs(x2 - x1) > 5 && Math.abs(y2 - y1) > 5) {
runInference(undefined, { x1, y1, x2, y2 });
const nextBox = { x1, y1, x2, y2 };
setPoints([]);
setSamPromptBox(nextBox);
setSamCandidateMaskId(null);
runInference([], nextBox, { resetCandidate: true });
}
setBoxStart(null);
@@ -500,6 +613,11 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
if (effectiveTool === POLYGON_TOOL) {
const pos = stagePoint(e);
if (pos) {
const closeRadius = POLYGON_CLOSE_RADIUS / Math.max(scale, 0.1);
if (polygonPoints.length >= 3 && pointDistance(pos, polygonPoints[0]) <= closeRadius) {
finishPolygon();
return;
}
setPolygonPoints((current) => [...current, pos]);
}
return;
@@ -514,8 +632,7 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
{ x: pos.x, y: pos.y, type: (effectiveTool === 'point_pos' ? 'pos' : 'neg') as 'pos' | 'neg' },
];
setPoints(newPoints);
// Auto-trigger inference after point selection
runInference(newPoints);
runInference(newPoints, samPromptBox || undefined);
}
}
};
@@ -535,14 +652,22 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
});
}, [updateMask]);
const updateMaskFromSegmentation = useCallback((mask: Mask, segmentation: number[][]): Mask => {
const updateMaskFromSegmentation = useCallback((
mask: Mask,
segmentation: number[][],
options: { area?: number; hasHoles?: boolean } = {},
): Mask => {
const bbox = segmentationBbox(segmentation);
const metadata = { ...(mask.metadata || {}) };
if (options.hasHoles === true) metadata.hasHoles = true;
if (options.hasHoles === false) delete metadata.hasHoles;
return {
...mask,
pathData: segmentationPath(segmentation),
segmentation,
bbox,
area: segmentationArea(segmentation),
area: options.area ?? segmentationArea(segmentation),
metadata,
saveStatus: mask.annotationId ? 'dirty' : 'draft',
saved: mask.annotationId ? false : mask.saved,
};
@@ -572,11 +697,16 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
}
return;
}
if ((event.key === 'Delete' || event.key === 'Backspace') && selectedMask) {
event.preventDefault();
const ids = selectedMaskIds.length > 0 ? selectedMaskIds : [selectedMask.id];
deleteMasksById(ids);
return;
}
if (effectiveTool !== POLYGON_TOOL) return;
if (event.key === 'Enter' && polygonPoints.length >= 3) {
event.preventDefault();
createManualMask('多边形', polygonPoints);
setPolygonPoints([]);
finishPolygon();
}
if (event.key === 'Escape') {
event.preventDefault();
@@ -586,7 +716,7 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
window.addEventListener('keydown', handleKeyDown);
return () => window.removeEventListener('keydown', handleKeyDown);
}, [createManualMask, effectiveTool, polygonPoints, redoMasks, selectedMask, selectedPolygonIndex, selectedVertexIndex, undoMasks, updatePolygonMask]);
}, [deleteMasksById, effectiveTool, finishPolygon, polygonPoints, redoMasks, selectedMask, selectedMaskIds, selectedPolygonIndex, selectedVertexIndex, undoMasks, updatePolygonMask]);
const boxRect = React.useMemo(() => {
if (!boxStart || !boxCurrent) return null;
@@ -623,8 +753,9 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
};
const handleMaskSelect = (mask: Mask, event: any, polygonIndex = 0) => {
if (effectiveTool !== 'move' && !isBooleanTool) return;
event.cancelBubble = true;
if (BOOLEAN_TOOLS.has(effectiveTool)) {
if (isBooleanTool) {
setSelectedMaskIds((current) => (
current.includes(mask.id)
? current.filter((id) => id !== mask.id)
@@ -703,7 +834,10 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
return;
}
const nextPrimary = updateMaskFromSegmentation(primary, resultSegmentation);
const nextPrimary = updateMaskFromSegmentation(primary, resultSegmentation, {
area: multiPolygonArea(resultGeometry),
hasHoles: multiPolygonHasHoles(resultGeometry),
});
const secondaryIds = effectiveTool === 'area_merge'
? new Set(booleanSelectedMasks.slice(1).map((mask) => mask.id))
: new Set<string>();
@@ -731,6 +865,11 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
<span className="text-xs text-cyan-400 font-mono">AI ...</span>
</div>
)}
{!isInferencing && inferenceMessage && (
<div className="absolute top-4 right-4 z-20 max-w-xs bg-[#111] border border-white/10 px-3 py-2 rounded-lg shadow-xl text-xs leading-relaxed text-gray-300">
{inferenceMessage}
</div>
)}
<Stage
width={stageSize.width}
@@ -758,21 +897,32 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
)}
{/* AI Returned Masks */}
{frameMasks.map((mask) => (
<Group key={mask.id} opacity={selectedMaskIds.includes(mask.id) ? 0.65 : 0.5}>
{(mask.segmentation && mask.segmentation.length > 0 ? mask.segmentation : [undefined]).map((_, polygonIndex) => (
<Path
key={`${mask.id}-polygon-${polygonIndex}`}
data={mask.segmentation ? segmentationPolygonPath(mask.segmentation, polygonIndex) : mask.pathData}
fill={mask.color}
stroke={mask.color}
strokeWidth={(selectedMaskIds.includes(mask.id) ? 2 : 1) / scale}
onClick={(event: any) => handleMaskSelect(mask, event, polygonIndex)}
onTap={(event: any) => handleMaskSelect(mask, event, polygonIndex)}
/>
))}
</Group>
))}
{frameMasks.map((mask) => {
const hasHoles = Boolean(mask.metadata?.hasHoles);
const paths = hasHoles
? [{ data: segmentationPath(mask.segmentation), polygonIndex: 0, fillRule: 'evenodd' }]
: (mask.segmentation && mask.segmentation.length > 0 ? mask.segmentation : [undefined]).map((_, polygonIndex) => ({
data: mask.segmentation ? segmentationPolygonPath(mask.segmentation, polygonIndex) : mask.pathData,
polygonIndex,
fillRule: undefined,
}));
return (
<Group key={mask.id} opacity={selectedMaskIds.includes(mask.id) ? 0.65 : 0.5}>
{paths.map(({ data, polygonIndex, fillRule }) => (
<Path
key={`${mask.id}-polygon-${polygonIndex}`}
data={data}
fill={mask.color}
fillRule={fillRule}
stroke={mask.color}
strokeWidth={(selectedMaskIds.includes(mask.id) ? 2 : 1) / scale}
onClick={(event: any) => handleMaskSelect(mask, event, polygonIndex)}
onTap={(event: any) => handleMaskSelect(mask, event, polygonIndex)}
/>
))}
</Group>
);
})}
{/* Box selection preview */}
{boxRect && effectiveTool === 'box_select' && (
@@ -804,10 +954,20 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
key={`poly-point-${index}`}
x={point.x}
y={point.y}
radius={4 / scale}
fill="#22d3ee"
stroke="#ffffff"
radius={(index === 0 && polygonPoints.length >= 3 ? 6 : 4) / scale}
fill={index === 0 && polygonPoints.length >= 3 ? '#facc15' : '#22d3ee'}
stroke={index === 0 && polygonPoints.length >= 3 ? '#fef3c7' : '#ffffff'}
strokeWidth={1 / scale}
onClick={(event: any) => {
if (index !== 0 || polygonPoints.length < 3) return;
event.cancelBubble = true;
finishPolygon();
}}
onTap={(event: any) => {
if (index !== 0 || polygonPoints.length < 3) return;
event.cancelBubble = true;
finishPolygon();
}}
/>
))}
@@ -827,7 +987,7 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
)))}
{/* Polygon edge insertion handles */}
{selectedMask && selectedMaskPoints.map((point, index) => {
{!isBooleanTool && selectedMask && selectedMaskPoints.map((point, index) => {
const next = selectedMaskPoints[(index + 1) % selectedMaskPoints.length];
if (!next) return null;
return (
@@ -846,7 +1006,7 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
})}
{/* Polygon vertex editor */}
{selectedMask && selectedMaskPoints.map((point, index) => (
{!isBooleanTool && selectedMask && selectedMaskPoints.map((point, index) => (
<Circle
key={`${selectedMask.id}-vertex-${selectedPolygonIndex}-${index}`}
x={point.x}
@@ -900,13 +1060,19 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
{frameMasks.length > 0 && (
<div className="absolute bottom-4 right-4 flex gap-2">
{BOOLEAN_TOOLS.has(effectiveTool) && booleanSelectedMasks.length >= 2 && (
<button
onClick={handleBooleanOperation}
className="text-xs bg-emerald-500/10 hover:bg-emerald-500/20 text-emerald-300 border border-emerald-500/20 px-3 py-1.5 rounded transition-colors"
>
{effectiveTool === 'area_merge' ? '合并选中' : '从主区域去除'}
</button>
{isBooleanTool && (
<div className="flex items-center gap-2">
<span className="text-xs bg-white/5 text-gray-300 border border-white/10 px-2.5 py-1.5 rounded">
{booleanSelectedMasks.length}
</span>
<button
onClick={handleBooleanOperation}
disabled={booleanSelectedMasks.length < 2}
className="text-xs bg-emerald-500/10 hover:bg-emerald-500/20 text-emerald-300 border border-emerald-500/20 px-3 py-1.5 rounded transition-colors disabled:opacity-40 disabled:cursor-not-allowed disabled:hover:bg-emerald-500/10"
>
{effectiveTool === 'area_merge' ? '合并选中' : '从主区域去除'}
</button>
</div>
)}
{activeClass && (
<button

View File

@@ -34,6 +34,65 @@ describe('FrameTimeline', () => {
expect(useStore.getState().currentFrameIndex).toBe(2);
});
it('shows current and total timeline time based on project fps', () => {
useStore.setState({
currentProject: { id: 'p1', name: 'P', status: 'ready', parse_fps: 10 },
currentFrameIndex: 1,
frames: [
{ id: 'f1', projectId: 'p1', index: 0, url: '/1.jpg', width: 640, height: 360 },
{ id: 'f2', projectId: 'p1', index: 1, url: '/2.jpg', width: 640, height: 360 },
{ id: 'f3', projectId: 'p1', index: 2, url: '/3.jpg', width: 640, height: 360 },
],
});
render(<FrameTimeline />);
expect(screen.getAllByText('00:00.10').length).toBeGreaterThan(0);
expect(screen.getAllByText('00:00.20').length).toBeGreaterThan(0);
});
it('changes frames with left and right arrow keys without leaving bounds', () => {
useStore.setState({
currentFrameIndex: 1,
frames: [
{ id: 'f1', projectId: 'p1', index: 0, url: '/1.jpg', width: 640, height: 360 },
{ id: 'f2', projectId: 'p1', index: 1, url: '/2.jpg', width: 640, height: 360 },
{ id: 'f3', projectId: 'p1', index: 2, url: '/3.jpg', width: 640, height: 360 },
],
});
render(<FrameTimeline />);
fireEvent.keyDown(window, { key: 'ArrowRight' });
expect(useStore.getState().currentFrameIndex).toBe(2);
fireEvent.keyDown(window, { key: 'ArrowRight' });
expect(useStore.getState().currentFrameIndex).toBe(2);
fireEvent.keyDown(window, { key: 'ArrowLeft' });
expect(useStore.getState().currentFrameIndex).toBe(1);
});
it('does not change frames while typing in editable fields', () => {
useStore.setState({
currentFrameIndex: 1,
frames: [
{ id: 'f1', projectId: 'p1', index: 0, url: '/1.jpg', width: 640, height: 360 },
{ id: 'f2', projectId: 'p1', index: 1, url: '/2.jpg', width: 640, height: 360 },
{ id: 'f3', projectId: 'p1', index: 2, url: '/3.jpg', width: 640, height: 360 },
],
});
render(
<>
<input aria-label="annotation-name" />
<FrameTimeline />
</>,
);
fireEvent.keyDown(screen.getByLabelText('annotation-name'), { key: 'ArrowRight' });
expect(useStore.getState().currentFrameIndex).toBe(1);
});
it('plays forward using the project parse fps and stops at the end', () => {
vi.useFakeTimers();
useStore.setState({

View File

@@ -16,6 +16,20 @@ export function FrameTimeline() {
const fps = currentProject?.parse_fps || currentProject?.original_fps || 12;
return Math.min(Math.max(fps, 1), 30);
}, [currentProject?.original_fps, currentProject?.parse_fps]);
const timeBaseFps = useMemo(() => {
const fps = currentProject?.parse_fps || currentProject?.original_fps || 12;
return Math.max(fps, 1);
}, [currentProject?.original_fps, currentProject?.parse_fps]);
const currentSeconds = totalFrames > 0 ? currentFrameIndex / timeBaseFps : 0;
const totalSeconds = totalFrames > 0 ? Math.max(totalFrames - 1, 0) / timeBaseFps : 0;
const formatTime = (seconds: number) => {
const safeSeconds = Math.max(0, seconds);
const minutes = Math.floor(safeSeconds / 60);
const wholeSeconds = Math.floor(safeSeconds % 60);
const centiseconds = Math.floor((safeSeconds % 1) * 100);
return `${minutes.toString().padStart(2, '0')}:${wholeSeconds.toString().padStart(2, '0')}.${centiseconds.toString().padStart(2, '0')}`;
};
useEffect(() => {
if (!isPlaying || totalFrames <= 1) return;
@@ -38,6 +52,30 @@ export function FrameTimeline() {
}
}, [totalFrames]);
useEffect(() => {
const isEditableTarget = (target: EventTarget | null) => {
if (!(target instanceof HTMLElement)) return false;
const tagName = target.tagName.toLowerCase();
return target.isContentEditable || ['input', 'textarea', 'select'].includes(tagName);
};
const handleKeyDown = (event: KeyboardEvent) => {
if (isEditableTarget(event.target) || totalFrames <= 1) return;
if (event.key !== 'ArrowLeft' && event.key !== 'ArrowRight') return;
event.preventDefault();
setIsPlaying(false);
const direction = event.key === 'ArrowRight' ? 1 : -1;
const nextIndex = Math.min(Math.max(currentFrameIndex + direction, 0), totalFrames - 1);
if (nextIndex !== currentFrameIndex) {
setCurrentFrame(nextIndex);
}
};
window.addEventListener('keydown', handleKeyDown);
return () => window.removeEventListener('keydown', handleKeyDown);
}, [currentFrameIndex, setCurrentFrame, totalFrames]);
// show frames around current frame
const frameWindow = 20;
const displayIndices = totalFrames > 0
@@ -47,6 +85,12 @@ export function FrameTimeline() {
return (
<div className="h-32 bg-[#111] border-t border-white/5 flex flex-col shrink-0 z-20">
<div className="h-4 bg-[#0d0d0d] flex items-center group relative">
<div className="absolute left-3 -top-5 text-[10px] font-mono text-gray-500 pointer-events-none">
{formatTime(currentSeconds)}
</div>
<div className="absolute right-3 -top-5 text-[10px] font-mono text-gray-500 pointer-events-none">
{formatTime(totalSeconds)}
</div>
<input
type="range"
min="1"
@@ -65,6 +109,12 @@ export function FrameTimeline() {
className="w-3 h-3 bg-white rounded-full absolute top-1/2 -translate-y-1/2 -ml-1.5 shadow-sm transform scale-0 group-hover:scale-100 transition-transform shadow-cyan-500/50"
style={{ left: `${totalFrames > 0 ? (currentFrame / totalFrames) * 100 : 0}%` }}
/>
<div
className="absolute -top-7 -translate-x-1/2 rounded bg-black/80 border border-white/10 px-2 py-0.5 text-[10px] font-mono text-cyan-300 opacity-0 group-hover:opacity-100 transition-opacity pointer-events-none"
style={{ left: `${totalFrames > 0 ? (currentFrame / totalFrames) * 100 : 0}%` }}
>
{formatTime(currentSeconds)}
</div>
</div>
</div>
@@ -129,6 +179,9 @@ export function FrameTimeline() {
<div className="w-48 text-right shrink-0">
<div className="text-2xl font-mono text-white">{currentFrame}<span className="text-xs text-gray-500"> / {totalFrames}</span></div>
<div className="text-xs font-mono text-cyan-300 mt-1">
{formatTime(currentSeconds)} <span className="text-gray-600">/</span> {formatTime(totalSeconds)}
</div>
<div className="text-[10px] text-gray-500 uppercase tracking-widest mt-1"></div>
</div>
</div>

View File

@@ -45,6 +45,41 @@ describe('OntologyInspector', () => {
}));
});
it('applies the selected class to currently selected masks', () => {
useStore.setState({
selectedMaskIds: ['m1'],
masks: [
{
id: 'm1',
annotationId: '99',
frameId: 'frame-1',
pathData: 'M 0 0 Z',
label: '旧标签',
color: '#06b6d4',
saveStatus: 'saved',
saved: true,
},
],
});
render(<OntologyInspector />);
fireEvent.click(screen.getByText('肝脏'));
expect(useStore.getState().activeClassId).toBe('c2');
expect(useStore.getState().masks[0]).toEqual(expect.objectContaining({
templateId: 't1',
classId: 'c2',
className: '肝脏',
classZIndex: 10,
label: '肝脏',
color: '#00ff00',
saveStatus: 'dirty',
saved: false,
}));
expect(screen.getByText('当前选中区域:')).toBeInTheDocument();
expect(screen.getByText('1')).toBeInTheDocument();
});
it('adds custom classes locally without backend persistence', () => {
const { container } = render(<OntologyInspector />);
const customSection = screen.getByText('自定义分类').parentElement!;

View File

@@ -10,6 +10,9 @@ export function OntologyInspector() {
const activeTemplateId = useStore((state) => state.activeTemplateId);
const activeClassId = useStore((state) => state.activeClassId);
const activeClass = useStore((state) => state.activeClass);
const masks = useStore((state) => state.masks);
const selectedMaskIds = useStore((state) => state.selectedMaskIds);
const setMasks = useStore((state) => state.setMasks);
const setActiveTemplateId = useStore((state) => state.setActiveTemplateId);
const setActiveClass = useStore((state) => state.setActiveClass);
@@ -28,6 +31,25 @@ export function OntologyInspector() {
setActiveTemplateId(activeTemplate.id);
}
setActiveClass(templateClass);
const selectedIdSet = new Set(selectedMaskIds);
const hasSelectedMasks = masks.some((mask) => selectedIdSet.has(mask.id));
if (!hasSelectedMasks) return;
const templateId = activeTemplate?.id || activeTemplateId || undefined;
setMasks(masks.map((mask) => {
if (!selectedIdSet.has(mask.id)) return mask;
return {
...mask,
templateId: templateId || mask.templateId,
classId: templateClass.id,
className: templateClass.name,
classZIndex: templateClass.zIndex,
label: templateClass.name,
color: templateClass.color,
saveStatus: mask.annotationId ? 'dirty' : 'draft',
saved: mask.annotationId ? false : mask.saved,
};
}));
};
const handleAddCustom = () => {
@@ -164,6 +186,10 @@ export function OntologyInspector() {
</span>
</div>
<div className="space-y-3">
<div className="flex items-center justify-between">
<span className="text-[10px] text-gray-500 uppercase">:</span>
<span className="text-xs font-mono text-gray-300">{selectedMaskIds.length}</span>
</div>
<div className="space-y-1">
<label className="text-[10px] text-gray-500 uppercase"></label>
<div className="h-1.5 w-full bg-white/10 rounded-full overflow-hidden">

View File

@@ -82,4 +82,65 @@ describe('TemplateRegistry', () => {
expect(screen.getByText('分类A')).toBeInTheDocument();
});
it('edits an existing template through the backend and store', async () => {
apiMock.getTemplates.mockResolvedValueOnce([
{
id: 't1',
name: '旧模板',
description: 'old desc',
classes: [{ id: 'c1', name: '胆囊', color: '#ff0000', zIndex: 10, category: '器官' }],
rules: [],
color: '#06b6d4',
z_index: 3,
},
]);
apiMock.updateTemplate.mockResolvedValueOnce({
id: 't1',
name: '新模板',
description: 'new desc',
classes: [{ id: 'c1', name: '胆囊', color: '#ff0000', zIndex: 10, category: '器官' }],
rules: [],
});
render(<TemplateRegistry />);
fireEvent.click(await screen.findByRole('button', { name: /修改库视图结构/ }));
fireEvent.change(screen.getAllByRole('textbox')[0], { target: { value: '新模板' } });
fireEvent.change(screen.getAllByRole('textbox')[1], { target: { value: 'new desc' } });
fireEvent.click(screen.getByRole('button', { name: '保存' }));
await waitFor(() => expect(apiMock.updateTemplate).toHaveBeenCalledWith('t1', expect.objectContaining({
name: '新模板',
description: 'new desc',
classes: [expect.objectContaining({ id: 'c1', name: '胆囊' })],
rules: [],
color: '#06b6d4',
z_index: 3,
})));
expect(useStore.getState().templates[0]).toEqual(expect.objectContaining({
id: 't1',
name: '新模板',
}));
});
it('deletes an existing template after confirmation', async () => {
apiMock.getTemplates.mockResolvedValueOnce([
{
id: 't1',
name: '待删除模板',
description: 'desc',
classes: [],
rules: [],
},
]);
apiMock.deleteTemplate.mockResolvedValueOnce(undefined);
const { container } = render(<TemplateRegistry />);
await screen.findAllByText('待删除模板');
const buttons = Array.from(container.querySelectorAll('button'));
fireEvent.click(buttons[2]);
await waitFor(() => expect(apiMock.deleteTemplate).toHaveBeenCalledWith('t1'));
expect(useStore.getState().templates).toEqual([]);
});
});

View File

@@ -7,6 +7,7 @@ import { VideoWorkspace } from './VideoWorkspace';
const apiMock = vi.hoisted(() => ({
getProjectFrames: vi.fn(),
parseMedia: vi.fn(),
propagateMasks: vi.fn(),
getTask: vi.fn(),
getTemplates: vi.fn(),
getProjectAnnotations: vi.fn(),
@@ -24,6 +25,7 @@ const apiMock = vi.hoisted(() => ({
vi.mock('../lib/api', () => ({
getProjectFrames: apiMock.getProjectFrames,
parseMedia: apiMock.parseMedia,
propagateMasks: apiMock.propagateMasks,
getTask: apiMock.getTask,
getTemplates: apiMock.getTemplates,
getProjectAnnotations: apiMock.getProjectAnnotations,
@@ -47,6 +49,14 @@ describe('VideoWorkspace', () => {
apiMock.getProjectAnnotations.mockResolvedValue([]);
apiMock.annotationToMask.mockReturnValue(null);
apiMock.getTask.mockResolvedValue({ id: 1, status: 'success', progress: 100, message: '解析完成' });
apiMock.propagateMasks.mockResolvedValue({
model: 'sam2',
direction: 'forward',
source_frame_id: 10,
processed_frame_count: 3,
created_annotation_count: 2,
annotations: [],
});
apiMock.getAiModelStatus.mockResolvedValue({
selected_model: 'sam2',
gpu: { available: false, device: 'cpu', name: null, torch_available: true },
@@ -320,4 +330,64 @@ describe('VideoWorkspace', () => {
expect.objectContaining({ id: 'annotation-88', label: 'GT Mask' }),
]));
});
it('propagates the selected current-frame mask through the backend video tracker', async () => {
apiMock.getProjectFrames.mockResolvedValueOnce([
{ id: 10, project_id: 1, frame_index: 0, image_url: '/frame.jpg', width: 640, height: 360 },
{ id: 11, project_id: 1, frame_index: 1, image_url: '/frame-1.jpg', width: 640, height: 360 },
]);
apiMock.buildAnnotationPayload.mockReturnValueOnce({
project_id: 1,
frame_id: 10,
template_id: 2,
mask_data: {
polygons: [[[0.1, 0.1], [0.3, 0.1], [0.3, 0.3]]],
label: '胆囊',
color: '#ff0000',
class: { id: 'c1', name: '胆囊', color: '#ff0000', zIndex: 20 },
},
bbox: [0.1, 0.1, 0.2, 0.2],
});
render(<VideoWorkspace />);
await waitFor(() => expect(useStore.getState().frames).toHaveLength(2));
act(() => {
useStore.setState({
aiModel: 'sam2',
activeTemplateId: '2',
selectedMaskIds: ['mask-1'],
masks: [{
id: 'mask-1',
frameId: '10',
pathData: 'M 0 0 Z',
label: '胆囊',
color: '#ff0000',
segmentation: [[64, 36, 192, 36, 192, 108]],
bbox: [64, 36, 128, 72],
}],
});
});
fireEvent.click(screen.getByRole('button', { name: '传播片段' }));
await waitFor(() => expect(apiMock.propagateMasks).toHaveBeenCalledWith({
project_id: 1,
frame_id: 10,
model: 'sam2',
direction: 'forward',
max_frames: 30,
include_source: false,
save_annotations: true,
seed: {
polygons: [[[0.1, 0.1], [0.3, 0.1], [0.3, 0.3]]],
bbox: [0.1, 0.1, 0.2, 0.2],
points: undefined,
label: '胆囊',
color: '#ff0000',
class_metadata: { id: 'c1', name: '胆囊', color: '#ff0000', zIndex: 20 },
template_id: 2,
},
}));
await waitFor(() => expect(screen.getByText('已传播并保存 2 个区域')).toBeInTheDocument());
});
});

View File

@@ -12,6 +12,7 @@ import {
getTemplates,
importGtMask,
parseMedia,
propagateMasks,
saveAnnotation,
updateAnnotation,
} from '../lib/api';
@@ -37,6 +38,8 @@ export function VideoWorkspace({ onNavigateToAI }: { onNavigateToAI?: () => void
const maskHistory = useStore((state) => state.maskHistory);
const maskFuture = useStore((state) => state.maskFuture);
const activeTemplateId = useStore((state) => state.activeTemplateId);
const aiModel = useStore((state) => state.aiModel);
const selectedMaskIds = useStore((state) => state.selectedMaskIds);
const setFrames = useStore((state) => state.setFrames);
const setCurrentFrame = useStore((state) => state.setCurrentFrame);
const setMasks = useStore((state) => state.setMasks);
@@ -45,6 +48,7 @@ export function VideoWorkspace({ onNavigateToAI }: { onNavigateToAI?: () => void
const [isSaving, setIsSaving] = useState(false);
const [isExporting, setIsExporting] = useState(false);
const [isImportingGt, setIsImportingGt] = useState(false);
const [isPropagating, setIsPropagating] = useState(false);
const [statusMessage, setStatusMessage] = useState('');
const hydrateSavedAnnotations = useCallback(async (projectId: string, projectFrames: Frame[]) => {
@@ -102,6 +106,8 @@ export function VideoWorkspace({ onNavigateToAI }: { onNavigateToAI?: () => void
url: f.image_url,
width: f.width ?? 0,
height: f.height ?? 0,
timestampMs: f.timestamp_ms ?? undefined,
sourceFrameNumber: f.source_frame_number ?? undefined,
}));
setFrames(mappedFrames);
setCurrentFrame(0);
@@ -117,6 +123,8 @@ export function VideoWorkspace({ onNavigateToAI }: { onNavigateToAI?: () => void
url: f.image_url,
width: f.width ?? 0,
height: f.height ?? 0,
timestampMs: f.timestamp_ms ?? undefined,
sourceFrameNumber: f.source_frame_number ?? undefined,
}));
setFrames(mappedFrames);
setCurrentFrame(0);
@@ -314,6 +322,55 @@ export function VideoWorkspace({ onNavigateToAI }: { onNavigateToAI?: () => void
}
};
const handlePropagateSegment = async () => {
if (!currentProject?.id || !currentFrame?.id) return;
const currentFrameMasks = masks.filter((mask) => mask.frameId === currentFrame.id);
const selectedMask = selectedMaskIds
.map((id) => currentFrameMasks.find((mask) => mask.id === id))
.find((mask): mask is NonNullable<typeof mask> => Boolean(mask));
const seedMask = selectedMask || currentFrameMasks[0];
if (!seedMask) {
setStatusMessage('请先选择或创建一个当前帧区域');
return;
}
const seedPayload = buildAnnotationPayload(currentProject.id, seedMask, currentFrame, activeTemplateId);
if (!seedPayload?.mask_data?.polygons?.length && !seedPayload?.bbox) {
setStatusMessage('当前区域缺少可传播的 polygon 或 bbox');
return;
}
setIsPropagating(true);
setStatusMessage(`${aiModel.toUpperCase()} 正在传播当前区域...`);
try {
const result = await propagateMasks({
project_id: Number(currentProject.id),
frame_id: Number(currentFrame.id),
model: aiModel,
direction: 'forward',
max_frames: 30,
include_source: false,
save_annotations: true,
seed: {
polygons: seedPayload.mask_data?.polygons,
bbox: seedPayload.bbox,
points: seedPayload.points,
label: seedPayload.mask_data?.label,
color: seedPayload.mask_data?.color,
class_metadata: seedPayload.mask_data?.class,
template_id: seedPayload.template_id,
},
});
await hydrateSavedAnnotations(currentProject.id, frames);
setStatusMessage(`已传播并保存 ${result.created_annotation_count} 个区域`);
} catch (err) {
console.error('Propagation failed:', err);
setStatusMessage('传播失败,请检查模型状态或后端日志');
} finally {
setIsPropagating(false);
}
};
return (
<div className="w-full h-full flex flex-col bg-[#0a0a0a]">
{/* Top Header / Status bar */}
@@ -339,28 +396,35 @@ export function VideoWorkspace({ onNavigateToAI }: { onNavigateToAI?: () => void
/>
<button
onClick={() => gtMaskInputRef.current?.click()}
disabled={!currentProject?.id || !currentFrame?.id || isImportingGt || isSaving || isExporting}
disabled={!currentProject?.id || !currentFrame?.id || isImportingGt || isSaving || isExporting || isPropagating}
className="px-4 py-1.5 bg-white/5 hover:bg-white/10 border border-white/10 rounded-md text-xs transition-colors text-white disabled:opacity-40 disabled:cursor-not-allowed"
>
{isImportingGt ? '导入中...' : '导入 GT Mask'}
</button>
<button
onClick={handlePropagateSegment}
disabled={!currentProject?.id || !currentFrame?.id || isSaving || isExporting || isImportingGt || isPropagating}
className="px-4 py-1.5 bg-white/5 hover:bg-white/10 border border-white/10 rounded-md text-xs transition-colors text-white disabled:opacity-40 disabled:cursor-not-allowed"
>
{isPropagating ? '传播中...' : '传播片段'}
</button>
<button
onClick={handleExportMasks}
disabled={!currentProject?.id || isExporting || isSaving}
disabled={!currentProject?.id || isExporting || isSaving || isPropagating}
className="px-4 py-1.5 bg-white/5 hover:bg-white/10 border border-white/10 rounded-md text-xs transition-colors text-white disabled:opacity-40 disabled:cursor-not-allowed"
>
{isExporting ? '导出中...' : '导出 PNG Mask ZIP'}
</button>
<button
onClick={handleExport}
disabled={!currentProject?.id || isExporting || isSaving}
disabled={!currentProject?.id || isExporting || isSaving || isPropagating}
className="px-4 py-1.5 bg-white/5 hover:bg-white/10 border border-white/10 rounded-md text-xs transition-colors text-white disabled:opacity-40 disabled:cursor-not-allowed"
>
{isExporting ? '导出中...' : '导出 JSON 标注集'}
</button>
<button
onClick={handleSave}
disabled={!currentProject?.id || isSaving || isExporting}
disabled={!currentProject?.id || isSaving || isExporting || isPropagating}
className="px-4 py-1.5 bg-cyan-600 hover:bg-cyan-500 text-white text-xs font-medium rounded-md transition-shadow shadow-lg shadow-cyan-900/20 disabled:opacity-40 disabled:cursor-not-allowed"
>
{isSaving ? '保存中...' : '结构化归档保存'}

View File

@@ -159,9 +159,9 @@ describe('api client contracts', () => {
axiosMock.client.post.mockResolvedValueOnce({ data: { ...task, status: 'cancelled', progress: 100 } });
axiosMock.client.post.mockResolvedValueOnce({ data: { ...task, id: 13, status: 'queued', progress: 0 } });
await expect(parseMedia('9')).resolves.toEqual(task);
await expect(parseMedia('9', { parseFps: 15, maxFrames: 120, targetWidth: 960 })).resolves.toEqual(task);
expect(axiosMock.client.post).toHaveBeenCalledWith('/api/media/parse', null, {
params: { project_id: '9' },
params: { project_id: '9', parse_fps: 15, max_frames: 120, target_width: 960 },
});
await expect(getTask(12)).resolves.toEqual(expect.objectContaining({ status: 'success', progress: 100 }));
@@ -175,7 +175,7 @@ describe('api client contracts', () => {
});
it('lists, saves, updates, and deletes annotations with the backend annotation contract', async () => {
const { deleteAnnotation, getProjectAnnotations, saveAnnotation, updateAnnotation } = await import('./api');
const { deleteAnnotation, getProjectAnnotations, propagateMasks, saveAnnotation, updateAnnotation } = await import('./api');
const saved = {
id: 1,
project_id: 9,
@@ -221,6 +221,43 @@ describe('api client contracts', () => {
axiosMock.client.delete.mockResolvedValueOnce({ data: null });
await expect(deleteAnnotation('1')).resolves.toBeUndefined();
expect(axiosMock.client.delete).toHaveBeenCalledWith('/api/ai/annotations/1');
axiosMock.client.post.mockResolvedValueOnce({
data: {
model: 'sam2',
direction: 'forward',
source_frame_id: 5,
processed_frame_count: 3,
created_annotation_count: 2,
annotations: [saved],
},
});
await expect(propagateMasks({
project_id: 9,
frame_id: 5,
model: 'sam2',
seed: {
polygons: [[[0, 0], [1, 0], [1, 1]]],
label: 'mask',
color: '#06b6d4',
},
direction: 'forward',
max_frames: 30,
})).resolves.toEqual(expect.objectContaining({ created_annotation_count: 2 }));
expect(axiosMock.client.post).toHaveBeenCalledWith('/api/ai/propagate', {
project_id: 9,
frame_id: 5,
model: 'sam2',
seed: {
polygons: [[[0, 0], [1, 0], [1, 1]]],
label: 'mask',
color: '#06b6d4',
},
direction: 'forward',
max_frames: 30,
}, {
timeout: 600000,
});
});
it('imports GT masks through multipart form data', async () => {
@@ -377,6 +414,33 @@ describe('api client contracts', () => {
});
});
it('normalizes combined box and point prompts for interactive SAM2 refinement', async () => {
const { predictMask } = await import('./api');
axiosMock.client.post.mockResolvedValueOnce({ data: { polygons: [], scores: [] } });
await predictMask({
imageId: '5',
imageWidth: 640,
imageHeight: 320,
box: { x1: 64, y1: 32, x2: 320, y2: 160 },
points: [
{ x: 128, y: 64, type: 'pos' },
{ x: 256, y: 128, type: 'neg' },
],
});
expect(axiosMock.client.post).toHaveBeenCalledWith('/api/ai/predict', {
image_id: 5,
prompt_type: 'interactive',
prompt_data: {
box: [0.1, 0.1, 0.5, 0.5],
points: [[0.2, 0.2], [0.4, 0.4]],
labels: [1, 0],
},
model: 'sam2',
});
});
it('uses semantic prompt type for text-only AI prediction', async () => {
const { predictMask } = await import('./api');
axiosMock.client.post.mockResolvedValueOnce({ data: { polygons: [], scores: [] } });

View File

@@ -153,6 +153,8 @@ export async function getProjectFrames(projectId: string): Promise<Array<{
image_url: string;
width: number | null;
height: number | null;
timestamp_ms?: number | null;
source_frame_number?: number | null;
}>> {
const response = await apiClient.get(`/api/projects/${projectId}/frames`);
return response.data;
@@ -185,9 +187,18 @@ export interface ProcessingTask {
updated_at: string;
}
export async function parseMedia(projectId: string): Promise<ProcessingTask> {
export async function parseMedia(projectId: string, options: {
parseFps?: number;
maxFrames?: number;
targetWidth?: number;
} = {}): Promise<ProcessingTask> {
const response = await apiClient.post('/api/media/parse', null, {
params: { project_id: projectId },
params: {
project_id: projectId,
...(options.parseFps ? { parse_fps: options.parseFps } : {}),
...(options.maxFrames ? { max_frames: options.maxFrames } : {}),
...(options.targetWidth ? { target_width: options.targetWidth } : {}),
},
});
return response.data;
}
@@ -312,6 +323,40 @@ export interface SaveAnnotationPayload {
export type UpdateAnnotationPayload = Omit<SaveAnnotationPayload, 'project_id' | 'frame_id'>;
export interface PropagateMasksPayload {
project_id: number;
frame_id: number;
model?: AiModelId;
seed: {
polygons?: number[][][];
bbox?: number[];
points?: number[][];
label?: string;
color?: string;
class_metadata?: {
id?: string;
name?: string;
color?: string;
zIndex?: number;
category?: string;
};
template_id?: number;
};
direction?: 'forward' | 'backward' | 'both';
max_frames?: number;
include_source?: boolean;
save_annotations?: boolean;
}
export interface PropagateMasksResult {
model: AiModelId;
direction: string;
source_frame_id: number;
processed_frame_count: number;
created_annotation_count: number;
annotations: SavedAnnotation[];
}
export interface DashboardTask {
id: string;
task_id?: number;
@@ -474,10 +519,22 @@ export function annotationToMask(annotation: SavedAnnotation, frame: Frame): Mas
}
export async function predictMask(payload: PredictMaskPayload): Promise<PredictMaskResult> {
let prompt_type: 'point' | 'box' | 'semantic';
let prompt_type: 'point' | 'box' | 'semantic' | 'interactive';
let prompt_data: unknown;
if (payload.box) {
if (payload.box && payload.points && payload.points.length > 0) {
prompt_type = 'interactive';
prompt_data = {
box: [
clamp01(payload.box.x1 / Math.max(payload.imageWidth, 1)),
clamp01(payload.box.y1 / Math.max(payload.imageHeight, 1)),
clamp01(payload.box.x2 / Math.max(payload.imageWidth, 1)),
clamp01(payload.box.y2 / Math.max(payload.imageHeight, 1)),
],
points: payload.points.map((point) => normalizePoint(point, payload.imageWidth, payload.imageHeight)),
labels: payload.points.map((point) => (point.type === 'neg' ? 0 : 1)),
};
} else if (payload.box) {
prompt_type = 'box';
prompt_data = [
clamp01(payload.box.x1 / Math.max(payload.imageWidth, 1)),
@@ -540,6 +597,13 @@ export async function getProjectAnnotations(projectId: string, frameId?: string)
return response.data;
}
export async function propagateMasks(payload: PropagateMasksPayload): Promise<PropagateMasksResult> {
const response = await apiClient.post('/api/ai/propagate', payload, {
timeout: 600000,
});
return response.data;
}
export async function saveAnnotation(payload: SaveAnnotationPayload): Promise<SavedAnnotation> {
const response = await apiClient.post('/api/ai/annotate', payload);
return response.data;

View File

@@ -30,6 +30,7 @@ describe('useStore', () => {
useStore.getState().setFrames([{ id: 'f1', projectId: '1', index: 0, url: '/f1.jpg', width: 640, height: 360 }]);
useStore.getState().setCurrentFrame(0);
useStore.getState().addMask({ id: 'm1', frameId: 'f1', pathData: 'M 0 0 Z', label: 'mask', color: '#fff' });
useStore.getState().setSelectedMaskIds(['m1']);
useStore.getState().updateMask('m1', { label: 'updated mask', saveStatus: 'dirty' });
useStore.getState().addAnnotation({ id: 'a1', frameId: 'f1', type: 'mask', points: [], label: 'ann', color: '#fff' });
useStore.getState().addTemplate({ id: 't1', name: 'Template', classes: [], rules: [] });
@@ -40,6 +41,7 @@ describe('useStore', () => {
expect(useStore.getState().currentProject?.id).toBe('1');
expect(useStore.getState().frames).toHaveLength(1);
expect(useStore.getState().currentFrameIndex).toBe(0);
expect(useStore.getState().selectedMaskIds).toEqual(['m1']);
expect(useStore.getState().masks[0]).toEqual(expect.objectContaining({ label: 'updated mask', saveStatus: 'dirty' }));
expect(useStore.getState().annotations).toHaveLength(1);
expect(useStore.getState().templates[0].name).toBe('Template 2');
@@ -51,6 +53,7 @@ describe('useStore', () => {
expect(useStore.getState().annotations).toEqual([]);
expect(useStore.getState().masks).toEqual([]);
expect(useStore.getState().selectedMaskIds).toEqual([]);
expect(useStore.getState().templates).toEqual([]);
});

View File

@@ -27,6 +27,8 @@ export interface Frame {
width: number;
height: number;
timestamp?: string;
timestampMs?: number;
sourceFrameNumber?: number;
}
export interface Annotation {
@@ -112,6 +114,7 @@ export interface AppState {
currentFrameIndex: number;
annotations: Annotation[];
masks: Mask[];
selectedMaskIds: string[];
maskHistory: Mask[][];
maskFuture: Mask[][];
setActiveModule: (module: string) => void;
@@ -123,6 +126,7 @@ export interface AppState {
addMask: (mask: Mask) => void;
updateMask: (id: string, updates: Partial<Mask>) => void;
setMasks: (masks: Mask[]) => void;
setSelectedMaskIds: (ids: string[]) => void;
clearMasks: () => void;
undoMasks: () => void;
redoMasks: () => void;
@@ -167,6 +171,7 @@ export const useStore = create<AppState>((set) => ({
frames: [],
annotations: [],
masks: [],
selectedMaskIds: [],
maskHistory: [],
maskFuture: [],
activeTemplateId: null,
@@ -195,6 +200,7 @@ export const useStore = create<AppState>((set) => ({
currentFrameIndex: 0,
annotations: [],
masks: [],
selectedMaskIds: [],
maskHistory: [],
maskFuture: [],
setActiveModule: (activeModule: string) => set({ activeModule }),
@@ -227,9 +233,11 @@ export const useStore = create<AppState>((set) => ({
maskFuture: [],
};
}),
setSelectedMaskIds: (selectedMaskIds: string[]) => set({ selectedMaskIds }),
clearMasks: () =>
set((state) => ({
masks: [],
selectedMaskIds: [],
maskHistory: [...state.maskHistory, state.masks],
maskFuture: [],
})),

View File

@@ -71,7 +71,11 @@ vi.mock('react-konva', () => ({
data-fill={props.fill}
data-x={props.x}
data-y={props.y}
onClick={() => props.onClick?.({ cancelBubble: false })}
onClick={(event) => {
const konvaEvent = { cancelBubble: false };
props.onClick?.(konvaEvent);
if (konvaEvent.cancelBubble) event.stopPropagation();
}}
onMouseUp={(event: React.MouseEvent<HTMLSpanElement>) => props.onDragEnd?.({
target: {
x: () => event.clientX || props.x || 0,
@@ -92,7 +96,12 @@ vi.mock('react-konva', () => ({
data-testid="konva-path"
data-path={props.data}
data-fill={props.fill}
onClick={() => props.onClick?.({ cancelBubble: false })}
data-fill-rule={props.fillRule}
onClick={(event) => {
const konvaEvent = { cancelBubble: false };
props.onClick?.(konvaEvent);
if (konvaEvent.cancelBubble) event.stopPropagation();
}}
/>
),
}));

View File

@@ -13,6 +13,7 @@ export function resetStore() {
currentFrameIndex: 0,
annotations: [],
masks: [],
selectedMaskIds: [],
maskHistory: [],
maskFuture: [],
templates: [],