feat: 完善视频传播、标注编辑和拆帧闭环
- 接入 SAM2 视频传播能力:新增 /api/ai/propagate,支持用当前帧 mask/polygon/bbox 作为 seed,通过 SAM2 video predictor 向前、向后或双向传播,并可保存为真实 annotation。 - 接入 SAM3 video tracker:通过独立 Python 3.12 external worker 调用 SAM3 video predictor/tracker,使用本地 checkpoint 与 bbox seed 执行视频级跟踪,并在模型状态中标记 video_track 能力。 - 完善 SAM 模型分发:sam_registry 按 model_id 明确区分 sam2 propagation 与 sam3 video_track,避免两个模型链路混用。 - 打通前端“传播片段”:VideoWorkspace 使用当前选中 mask 和当前 AI 模型调用后端传播接口,传播结果回写并刷新工作区已保存标注。 - 增强 SAM3 本地 checkpoint 配置:新增 sam3_checkpoint_path 配置和 .env.example 示例,状态检查改为基于本地 checkpoint/独立环境/模型包可用性。 - 完善视频拆帧参数:/api/media/parse 支持 parse_fps、max_frames、target_width,后端任务保存帧时间戳、源帧号和 frame_sequence 元数据。 - 增加运行时 schema 兼容处理:启动时为旧 frames 表补充 timestamp_ms 和 source_frame_number 列,避免旧库升级后缺字段。 - 强化 Canvas 标注编辑:补齐多边形闭合、点工具、顶点拖拽、边中点插入、Delete/Backspace 删除、区域合并和重叠去除等交互。 - 增强语义分类联动:选中 mask 后可通过右侧语义分类树更新标签、颜色和 class metadata,并同步到保存/导出链路。 - 增加关键帧时间轴体验:FrameTimeline 显示具体时间信息,并支持键盘左右方向键切换关键帧。 - 完善 AI 交互分割参数:前端保留正向点、反向点、框选和 interactive prompt 的调用状态,支持 SAM2 细化候选区域与 SAM3 bbox 入口。 - 扩展后端/前端 API 类型:新增 propagateMasks、传播请求/响应 schema,并补齐 annotation、导出、模型状态和任务接口的测试覆盖。 - 更新项目文档:同步 README、AGENTS、接口契约、需求冻结、设计冻结、前端元素审计、实施计划和测试计划,标明真实功能边界与剩余风险。 - 增加测试覆盖:补充 SAM2/SAM3 传播、SAM3 状态、媒体拆帧参数、Canvas 编辑、语义标签切换、时间轴、工作区传播和 API 合约测试。 - 加强仓库安全边界:将 sam3权重/ 加入 .gitignore,避免本地模型权重被误提交。 验证:npm run test:run;pytest backend/tests;npm run lint;npm run build;python -m py_compile;git diff --check。
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
import { fireEvent, render, screen } from '@testing-library/react';
|
||||
import { fireEvent, render, screen, waitFor } from '@testing-library/react';
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest';
|
||||
import { resetStore } from '../test/storeTestUtils';
|
||||
import { useStore } from '../store/useStore';
|
||||
@@ -62,4 +62,112 @@ describe('AISegmentation', () => {
|
||||
},
|
||||
}));
|
||||
});
|
||||
|
||||
it('prompts for semantic text before running SAM3 inference', async () => {
|
||||
apiMock.getAiModelStatus.mockResolvedValue({
|
||||
selected_model: 'sam3',
|
||||
gpu: { available: true, device: 'cuda', name: 'RTX 4090', torch_available: true },
|
||||
models: [
|
||||
{ id: 'sam2', label: 'SAM 2', available: true, loaded: false, device: 'cuda', supports: ['point', 'box'], message: 'SAM 2 ready', package_available: true, checkpoint_exists: true, python_ok: true, torch_ok: true, cuda_required: false },
|
||||
{ id: 'sam3', label: 'SAM 3', available: true, loaded: true, device: 'cuda', supports: ['semantic'], message: 'SAM 3 ready', package_available: true, checkpoint_exists: true, python_ok: true, torch_ok: true, cuda_required: true },
|
||||
],
|
||||
});
|
||||
|
||||
render(<AISegmentation onSendToWorkspace={vi.fn()} />);
|
||||
const sam3Button = (await screen.findByText('SAM3')).closest('button')!;
|
||||
fireEvent.click(sam3Button);
|
||||
fireEvent.click(screen.getByText('执行高精度语义分割'));
|
||||
|
||||
expect(apiMock.predictMask).not.toHaveBeenCalled();
|
||||
expect(await screen.findByText('SAM3 当前使用文本语义提示,请先输入要分割的目标描述。')).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it('shows feedback when SAM3 semantic inference returns no masks', async () => {
|
||||
apiMock.getAiModelStatus.mockResolvedValue({
|
||||
selected_model: 'sam3',
|
||||
gpu: { available: true, device: 'cuda', name: 'RTX 4090', torch_available: true },
|
||||
models: [
|
||||
{ id: 'sam2', label: 'SAM 2', available: true, loaded: false, device: 'cuda', supports: ['point', 'box'], message: 'SAM 2 ready', package_available: true, checkpoint_exists: true, python_ok: true, torch_ok: true, cuda_required: false },
|
||||
{ id: 'sam3', label: 'SAM 3', available: true, loaded: true, device: 'cuda', supports: ['semantic'], message: 'SAM 3 ready', package_available: true, checkpoint_exists: true, python_ok: true, torch_ok: true, cuda_required: true },
|
||||
],
|
||||
});
|
||||
apiMock.predictMask.mockResolvedValueOnce({ masks: [] });
|
||||
|
||||
render(<AISegmentation onSendToWorkspace={vi.fn()} />);
|
||||
const sam3Button = (await screen.findByText('SAM3')).closest('button')!;
|
||||
fireEvent.click(sam3Button);
|
||||
fireEvent.change(screen.getByPlaceholderText("例如:'分割出左侧车道上行驶的所有红色汽车'..."), {
|
||||
target: { value: '胆囊' },
|
||||
});
|
||||
fireEvent.click(screen.getByText('执行高精度语义分割'));
|
||||
|
||||
await waitFor(() => expect(apiMock.predictMask).toHaveBeenCalledWith(expect.objectContaining({
|
||||
model: 'sam3',
|
||||
points: undefined,
|
||||
text: '胆囊',
|
||||
})));
|
||||
expect(await screen.findByText('模型没有返回可用区域,请换一个更具体的描述或调整提示。')).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it('runs SAM3 semantic text inference and assigns the active class to returned masks', async () => {
|
||||
apiMock.getAiModelStatus.mockResolvedValue({
|
||||
selected_model: 'sam3',
|
||||
gpu: { available: true, device: 'cuda', name: 'RTX 4090', torch_available: true },
|
||||
models: [
|
||||
{ id: 'sam2', label: 'SAM 2', available: true, loaded: false, device: 'cuda', supports: ['point', 'box'], message: 'SAM 2 ready', package_available: true, checkpoint_exists: true, python_ok: true, torch_ok: true, cuda_required: false },
|
||||
{ id: 'sam3', label: 'SAM 3', available: true, loaded: true, device: 'cuda', supports: ['semantic'], message: 'SAM 3 ready', package_available: true, checkpoint_exists: true, python_ok: true, torch_ok: true, cuda_required: true },
|
||||
],
|
||||
});
|
||||
apiMock.predictMask.mockResolvedValueOnce({
|
||||
masks: [
|
||||
{
|
||||
id: 'semantic-1',
|
||||
pathData: 'M 10 10 L 40 10 L 40 40 Z',
|
||||
label: 'semantic result',
|
||||
color: '#06b6d4',
|
||||
segmentation: [[10, 10, 40, 10, 40, 40]],
|
||||
bbox: [10, 10, 30, 30],
|
||||
area: 900,
|
||||
},
|
||||
],
|
||||
});
|
||||
useStore.setState({
|
||||
activeTemplateId: 'template-1',
|
||||
activeClassId: 'class-1',
|
||||
activeClass: { id: 'class-1', name: '胆囊', color: '#ff0000', zIndex: 30 },
|
||||
});
|
||||
|
||||
render(<AISegmentation onSendToWorkspace={vi.fn()} />);
|
||||
const sam3Button = (await screen.findByText('SAM3')).closest('button')!;
|
||||
fireEvent.click(sam3Button);
|
||||
fireEvent.change(screen.getByPlaceholderText("例如:'分割出左侧车道上行驶的所有红色汽车'..."), {
|
||||
target: { value: '胆囊' },
|
||||
});
|
||||
fireEvent.click(screen.getByText('执行高精度语义分割'));
|
||||
|
||||
await waitFor(() => expect(apiMock.predictMask).toHaveBeenCalledWith(expect.objectContaining({
|
||||
imageId: 'frame-1',
|
||||
imageWidth: 640,
|
||||
imageHeight: 360,
|
||||
model: 'sam3',
|
||||
points: undefined,
|
||||
text: '胆囊',
|
||||
options: {
|
||||
crop_to_prompt: false,
|
||||
auto_filter_background: true,
|
||||
min_score: 0.05,
|
||||
},
|
||||
})));
|
||||
expect(useStore.getState().masks[0]).toEqual(expect.objectContaining({
|
||||
id: 'semantic-1',
|
||||
frameId: 'frame-1',
|
||||
templateId: 'template-1',
|
||||
classId: 'class-1',
|
||||
className: '胆囊',
|
||||
classZIndex: 30,
|
||||
label: '胆囊',
|
||||
color: '#ff0000',
|
||||
saveStatus: 'draft',
|
||||
}));
|
||||
});
|
||||
});
|
||||
|
||||
@@ -33,6 +33,7 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
|
||||
const [autoDeleteBg, setAutoDeleteBg] = useState(true);
|
||||
const [cropMode, setCropMode] = useState(false);
|
||||
const [isInferencing, setIsInferencing] = useState(false);
|
||||
const [inferenceMessage, setInferenceMessage] = useState('');
|
||||
|
||||
// Canvas state
|
||||
const [scale, setScale] = useState(1);
|
||||
@@ -91,9 +92,18 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
|
||||
};
|
||||
|
||||
const runInference = useCallback(async () => {
|
||||
if (points.length === 0 && !semanticText.trim()) return;
|
||||
const textPrompt = semanticText.trim();
|
||||
if (aiModel === 'sam3' && !textPrompt) {
|
||||
setInferenceMessage('SAM3 当前使用文本语义提示,请先输入要分割的目标描述。');
|
||||
return;
|
||||
}
|
||||
if (points.length === 0 && !textPrompt) {
|
||||
setInferenceMessage('请先放置正/反向提示点,或输入语义描述。');
|
||||
return;
|
||||
}
|
||||
if (!currentFrame?.id) {
|
||||
console.warn('AI inference skipped: no project frame is selected');
|
||||
setInferenceMessage('请先在项目工作区选择一帧图像。');
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -101,18 +111,20 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
|
||||
const imageHeight = currentFrame.height || image?.naturalHeight || image?.height || 0;
|
||||
if (imageWidth <= 0 || imageHeight <= 0) {
|
||||
console.warn('AI inference skipped: active frame dimensions are unavailable');
|
||||
setInferenceMessage('当前帧缺少宽高信息,无法推理。');
|
||||
return;
|
||||
}
|
||||
|
||||
setIsInferencing(true);
|
||||
setInferenceMessage('');
|
||||
try {
|
||||
const result = await predictMask({
|
||||
imageId: currentFrame.id,
|
||||
imageWidth,
|
||||
imageHeight,
|
||||
model: aiModel,
|
||||
points: points.map((p) => ({ x: p.x, y: p.y, type: p.type })),
|
||||
text: semanticText.trim() || undefined,
|
||||
points: aiModel === 'sam3' ? undefined : points.map((p) => ({ x: p.x, y: p.y, type: p.type })),
|
||||
text: textPrompt || undefined,
|
||||
options: {
|
||||
crop_to_prompt: cropMode,
|
||||
auto_filter_background: autoDeleteBg,
|
||||
@@ -120,6 +132,11 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
|
||||
},
|
||||
});
|
||||
|
||||
if (result.masks.length === 0) {
|
||||
setInferenceMessage('模型没有返回可用区域,请换一个更具体的描述或调整提示。');
|
||||
} else {
|
||||
setInferenceMessage(`已生成 ${result.masks.length} 个候选区域。`);
|
||||
}
|
||||
result.masks.forEach((m) => {
|
||||
const label = activeClass?.name || m.label;
|
||||
const color = activeClass?.color || m.color;
|
||||
@@ -142,6 +159,8 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
|
||||
});
|
||||
} catch (err) {
|
||||
console.error('AI inference failed:', err);
|
||||
const detail = (err as any)?.response?.data?.detail;
|
||||
setInferenceMessage(detail || 'AI 推理失败,请查看模型状态或后端日志。');
|
||||
} finally {
|
||||
setIsInferencing(false);
|
||||
}
|
||||
@@ -282,6 +301,11 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
|
||||
{isInferencing ? <Loader2 size={16} className="animate-spin" /> : <Sparkles size={16} />}
|
||||
{isInferencing ? '推理中...' : modelCanInfer ? '执行高精度语义分割' : '当前模型不可用'}
|
||||
</button>
|
||||
{inferenceMessage && (
|
||||
<div className="rounded border border-white/10 bg-white/5 px-3 py-2 text-[11px] leading-relaxed text-gray-300">
|
||||
{inferenceMessage}
|
||||
</div>
|
||||
)}
|
||||
<button
|
||||
onClick={onSendToWorkspace}
|
||||
className="w-full py-3.5 rounded-lg flex items-center justify-center gap-2 transition-all font-medium tracking-wide text-xs uppercase bg-white/5 hover:bg-white/10 text-gray-300 border border-white/5 hover:border-white/10"
|
||||
|
||||
@@ -65,6 +65,157 @@ describe('CanvasArea', () => {
|
||||
}));
|
||||
});
|
||||
|
||||
it('explains that SAM3 point prompts are not supported in the workspace', async () => {
|
||||
useStore.setState({ aiModel: 'sam3' });
|
||||
|
||||
render(<CanvasArea activeTool="point_pos" frame={frame} />);
|
||||
fireEvent.click(screen.getByTestId('konva-stage'));
|
||||
|
||||
expect(apiMock.predictMask).not.toHaveBeenCalled();
|
||||
expect(await screen.findByText(/SAM3 当前工作区只支持框选提示/)).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it('calls SAM3 prediction with a box prompt from the workspace', async () => {
|
||||
useStore.setState({ aiModel: 'sam3' });
|
||||
apiMock.predictMask.mockResolvedValueOnce({
|
||||
masks: [
|
||||
{
|
||||
id: 'sam3-box-mask',
|
||||
pathData: 'M 20 20 L 80 20 L 80 80 Z',
|
||||
label: 'AI Mask',
|
||||
color: '#06b6d4',
|
||||
segmentation: [[20, 20, 80, 20, 80, 80]],
|
||||
bbox: [20, 20, 60, 60],
|
||||
area: 3600,
|
||||
},
|
||||
],
|
||||
});
|
||||
|
||||
render(<CanvasArea activeTool="box_select" frame={frame} />);
|
||||
const stage = screen.getByTestId('konva-stage');
|
||||
fireEvent.mouseDown(stage, { clientX: 120, clientY: 80 });
|
||||
fireEvent.mouseMove(stage, { clientX: 260, clientY: 200 });
|
||||
fireEvent.mouseUp(stage, { clientX: 260, clientY: 200 });
|
||||
|
||||
await waitFor(() => expect(apiMock.predictMask).toHaveBeenCalledWith({
|
||||
imageId: 'frame-1',
|
||||
imageWidth: 640,
|
||||
imageHeight: 360,
|
||||
model: 'sam3',
|
||||
points: undefined,
|
||||
box: { x1: 120, y1: 80, x2: 260, y2: 200 },
|
||||
}));
|
||||
expect(useStore.getState().masks[0]).toEqual(expect.objectContaining({
|
||||
id: 'sam3-box-mask',
|
||||
metadata: expect.objectContaining({
|
||||
source: 'sam3_box',
|
||||
promptBox: { x1: 120, y1: 80, x2: 260, y2: 200 },
|
||||
}),
|
||||
}));
|
||||
});
|
||||
|
||||
it('refines one SAM2 candidate mask from an initial box with positive and negative points', async () => {
|
||||
apiMock.predictMask
|
||||
.mockResolvedValueOnce({
|
||||
masks: [
|
||||
{
|
||||
id: 'mask-box',
|
||||
pathData: 'M 10 10 L 90 10 L 90 90 Z',
|
||||
label: 'AI Mask',
|
||||
color: '#06b6d4',
|
||||
segmentation: [[10, 10, 90, 10, 90, 90]],
|
||||
bbox: [10, 10, 80, 80],
|
||||
area: 6400,
|
||||
},
|
||||
],
|
||||
})
|
||||
.mockResolvedValueOnce({
|
||||
masks: [
|
||||
{
|
||||
id: 'mask-refined-pos',
|
||||
pathData: 'M 20 20 L 80 20 L 80 80 Z',
|
||||
label: 'AI Mask',
|
||||
color: '#06b6d4',
|
||||
segmentation: [[20, 20, 80, 20, 80, 80]],
|
||||
bbox: [20, 20, 60, 60],
|
||||
area: 3600,
|
||||
},
|
||||
],
|
||||
})
|
||||
.mockResolvedValueOnce({
|
||||
masks: [
|
||||
{
|
||||
id: 'mask-refined-neg',
|
||||
pathData: 'M 30 30 L 70 30 L 70 70 Z',
|
||||
label: 'AI Mask',
|
||||
color: '#06b6d4',
|
||||
segmentation: [[30, 30, 70, 30, 70, 70]],
|
||||
bbox: [30, 30, 40, 40],
|
||||
area: 1600,
|
||||
},
|
||||
],
|
||||
});
|
||||
|
||||
const { rerender } = render(<CanvasArea activeTool="box_select" frame={frame} />);
|
||||
const stage = screen.getByTestId('konva-stage');
|
||||
fireEvent.mouseDown(stage, { clientX: 120, clientY: 80 });
|
||||
fireEvent.mouseMove(stage, { clientX: 260, clientY: 200 });
|
||||
fireEvent.mouseUp(stage, { clientX: 260, clientY: 200 });
|
||||
|
||||
await waitFor(() => expect(apiMock.predictMask).toHaveBeenNthCalledWith(1, {
|
||||
imageId: 'frame-1',
|
||||
imageWidth: 640,
|
||||
imageHeight: 360,
|
||||
model: 'sam2',
|
||||
points: undefined,
|
||||
box: { x1: 120, y1: 80, x2: 260, y2: 200 },
|
||||
}));
|
||||
await waitFor(() => expect(useStore.getState().masks).toHaveLength(1));
|
||||
|
||||
rerender(<CanvasArea activeTool="point_pos" frame={frame} />);
|
||||
fireEvent.click(stage, { clientX: 150, clientY: 100 });
|
||||
|
||||
await waitFor(() => expect(apiMock.predictMask).toHaveBeenNthCalledWith(2, {
|
||||
imageId: 'frame-1',
|
||||
imageWidth: 640,
|
||||
imageHeight: 360,
|
||||
model: 'sam2',
|
||||
points: [{ x: 150, y: 100, type: 'pos' }],
|
||||
box: { x1: 120, y1: 80, x2: 260, y2: 200 },
|
||||
}));
|
||||
expect(useStore.getState().masks).toHaveLength(1);
|
||||
expect(useStore.getState().masks[0]).toEqual(expect.objectContaining({
|
||||
id: 'mask-box',
|
||||
segmentation: [[20, 20, 80, 20, 80, 80]],
|
||||
metadata: expect.objectContaining({
|
||||
source: 'sam2_interactive',
|
||||
promptPointCount: 1,
|
||||
}),
|
||||
}));
|
||||
|
||||
rerender(<CanvasArea activeTool="point_neg" frame={frame} />);
|
||||
fireEvent.click(stage, { clientX: 300, clientY: 150 });
|
||||
|
||||
await waitFor(() => expect(apiMock.predictMask).toHaveBeenNthCalledWith(3, {
|
||||
imageId: 'frame-1',
|
||||
imageWidth: 640,
|
||||
imageHeight: 360,
|
||||
model: 'sam2',
|
||||
points: [
|
||||
{ x: 150, y: 100, type: 'pos' },
|
||||
{ x: 300, y: 150, type: 'neg' },
|
||||
],
|
||||
box: { x1: 120, y1: 80, x2: 260, y2: 200 },
|
||||
}));
|
||||
expect(useStore.getState().masks).toHaveLength(1);
|
||||
expect(useStore.getState().masks[0]).toEqual(expect.objectContaining({
|
||||
id: 'mask-box',
|
||||
segmentation: [[30, 30, 70, 30, 70, 70]],
|
||||
points: [[150, 100]],
|
||||
metadata: expect.objectContaining({ promptPointCount: 2 }),
|
||||
}));
|
||||
});
|
||||
|
||||
it('renders only masks that belong to the current frame', () => {
|
||||
useStore.setState({
|
||||
masks: [
|
||||
@@ -79,6 +230,26 @@ describe('CanvasArea', () => {
|
||||
expect(screen.getByText('遮罩数: 1')).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it('publishes the selected mask ids for the ontology panel', async () => {
|
||||
useStore.setState({
|
||||
masks: [
|
||||
{
|
||||
id: 'm1',
|
||||
frameId: 'frame-1',
|
||||
pathData: 'M 0 0 L 10 0 L 10 10 Z',
|
||||
label: 'A',
|
||||
color: '#fff',
|
||||
segmentation: [[0, 0, 10, 0, 10, 10]],
|
||||
},
|
||||
],
|
||||
});
|
||||
|
||||
render(<CanvasArea activeTool="move" frame={frame} />);
|
||||
fireEvent.click(screen.getByTestId('konva-path'));
|
||||
|
||||
await waitFor(() => expect(useStore.getState().selectedMaskIds).toEqual(['m1']));
|
||||
});
|
||||
|
||||
it('renders imported GT seed points for editable point regions', () => {
|
||||
useStore.setState({
|
||||
masks: [
|
||||
@@ -164,6 +335,57 @@ describe('CanvasArea', () => {
|
||||
}));
|
||||
});
|
||||
|
||||
it('deletes the selected draft mask with Delete when no vertex is selected', () => {
|
||||
useStore.setState({
|
||||
masks: [
|
||||
{
|
||||
id: 'draft-1',
|
||||
frameId: 'frame-1',
|
||||
pathData: 'M 10 10 L 90 10 L 90 40 Z',
|
||||
label: 'Draft',
|
||||
color: '#06b6d4',
|
||||
saveStatus: 'draft',
|
||||
segmentation: [[10, 10, 90, 10, 90, 40]],
|
||||
},
|
||||
],
|
||||
});
|
||||
|
||||
render(<CanvasArea activeTool="move" frame={frame} />);
|
||||
fireEvent.click(screen.getByTestId('konva-path'));
|
||||
fireEvent.keyDown(window, { key: 'Delete' });
|
||||
|
||||
expect(useStore.getState().masks).toEqual([]);
|
||||
expect(useStore.getState().maskHistory.at(-1)).toEqual([
|
||||
expect.objectContaining({ id: 'draft-1' }),
|
||||
]);
|
||||
});
|
||||
|
||||
it('deletes the selected saved mask locally and notifies the backend deletion callback', () => {
|
||||
const onDeleteMaskAnnotations = vi.fn();
|
||||
useStore.setState({
|
||||
masks: [
|
||||
{
|
||||
id: 'annotation-99',
|
||||
annotationId: '99',
|
||||
frameId: 'frame-1',
|
||||
pathData: 'M 10 10 L 90 10 L 90 40 Z',
|
||||
label: 'Saved',
|
||||
color: '#06b6d4',
|
||||
saveStatus: 'saved',
|
||||
saved: true,
|
||||
segmentation: [[10, 10, 90, 10, 90, 40]],
|
||||
},
|
||||
],
|
||||
});
|
||||
|
||||
render(<CanvasArea activeTool="move" frame={frame} onDeleteMaskAnnotations={onDeleteMaskAnnotations} />);
|
||||
fireEvent.click(screen.getByTestId('konva-path'));
|
||||
fireEvent.keyDown(window, { key: 'Backspace' });
|
||||
|
||||
expect(useStore.getState().masks).toEqual([]);
|
||||
expect(onDeleteMaskAnnotations).toHaveBeenCalledWith(['99']);
|
||||
});
|
||||
|
||||
it('inserts a polygon vertex from an edge midpoint handle', () => {
|
||||
useStore.setState({
|
||||
masks: [
|
||||
@@ -248,9 +470,13 @@ describe('CanvasArea', () => {
|
||||
});
|
||||
|
||||
render(<CanvasArea activeTool="area_merge" frame={frame} />);
|
||||
expect(screen.getByText('已选 0')).toBeInTheDocument();
|
||||
const paths = screen.getAllByTestId('konva-path');
|
||||
fireEvent.click(paths[0]);
|
||||
expect(screen.getByText('已选 1')).toBeInTheDocument();
|
||||
expect(screen.queryAllByTestId('konva-circle')).toHaveLength(0);
|
||||
fireEvent.click(paths[1]);
|
||||
expect(screen.getByText('已选 2')).toBeInTheDocument();
|
||||
fireEvent.click(screen.getByRole('button', { name: '合并选中' }));
|
||||
|
||||
expect(useStore.getState().masks).toHaveLength(1);
|
||||
@@ -300,6 +526,45 @@ describe('CanvasArea', () => {
|
||||
expect(useStore.getState().masks[1].id).toBe('m2');
|
||||
});
|
||||
|
||||
it('renders inner overlap removal as a hole in the primary mask', () => {
|
||||
useStore.setState({
|
||||
masks: [
|
||||
{
|
||||
id: 'm1',
|
||||
frameId: 'frame-1',
|
||||
pathData: 'M 10 10 L 110 10 L 110 110 L 10 110 Z',
|
||||
label: 'A',
|
||||
color: '#06b6d4',
|
||||
segmentation: [[10, 10, 110, 10, 110, 110, 10, 110]],
|
||||
},
|
||||
{
|
||||
id: 'm2',
|
||||
frameId: 'frame-1',
|
||||
pathData: 'M 40 40 L 80 40 L 80 80 L 40 80 Z',
|
||||
label: 'B',
|
||||
color: '#ff0000',
|
||||
segmentation: [[40, 40, 80, 40, 80, 80, 40, 80]],
|
||||
},
|
||||
],
|
||||
});
|
||||
|
||||
render(<CanvasArea activeTool="area_remove" frame={frame} />);
|
||||
const paths = screen.getAllByTestId('konva-path');
|
||||
fireEvent.click(paths[0]);
|
||||
fireEvent.click(paths[1]);
|
||||
fireEvent.click(screen.getByRole('button', { name: '从主区域去除' }));
|
||||
|
||||
const [primary] = useStore.getState().masks;
|
||||
expect(primary).toEqual(expect.objectContaining({
|
||||
id: 'm1',
|
||||
area: 8400,
|
||||
bbox: [10, 10, 100, 100],
|
||||
metadata: expect.objectContaining({ hasHoles: true }),
|
||||
}));
|
||||
expect(primary.segmentation).toHaveLength(2);
|
||||
expect(screen.getAllByTestId('konva-path')[0]).toHaveAttribute('data-fill-rule', 'evenodd');
|
||||
});
|
||||
|
||||
it('creates a manual rectangle mask that can be undone and redone', () => {
|
||||
useStore.setState({
|
||||
activeTemplateId: '2',
|
||||
@@ -329,6 +594,93 @@ describe('CanvasArea', () => {
|
||||
expect(useStore.getState().masks).toHaveLength(1);
|
||||
});
|
||||
|
||||
it('creates a manual circle mask from a drag gesture', () => {
|
||||
render(<CanvasArea activeTool="create_circle" frame={frame} />);
|
||||
const stage = screen.getByTestId('konva-stage');
|
||||
fireEvent.mouseDown(stage, { clientX: 120, clientY: 80 });
|
||||
fireEvent.mouseMove(stage, { clientX: 260, clientY: 200 });
|
||||
fireEvent.mouseUp(stage, { clientX: 260, clientY: 200 });
|
||||
|
||||
expect(useStore.getState().masks).toHaveLength(1);
|
||||
expect(useStore.getState().masks[0]).toEqual(expect.objectContaining({
|
||||
frameId: 'frame-1',
|
||||
label: '手工圆形',
|
||||
color: '#06b6d4',
|
||||
saveStatus: 'draft',
|
||||
bbox: [120, 80, 140, 120],
|
||||
metadata: expect.objectContaining({
|
||||
source: 'manual',
|
||||
shape: '圆形',
|
||||
}),
|
||||
}));
|
||||
expect(useStore.getState().masks[0].segmentation?.[0]).toHaveLength(64);
|
||||
});
|
||||
|
||||
it('creates a manual line region from a drag gesture', () => {
|
||||
render(<CanvasArea activeTool="create_line" frame={frame} />);
|
||||
const stage = screen.getByTestId('konva-stage');
|
||||
fireEvent.mouseDown(stage, { clientX: 120, clientY: 80 });
|
||||
fireEvent.mouseMove(stage, { clientX: 260, clientY: 200 });
|
||||
fireEvent.mouseUp(stage, { clientX: 260, clientY: 200 });
|
||||
|
||||
expect(useStore.getState().masks).toHaveLength(1);
|
||||
expect(useStore.getState().masks[0]).toEqual(expect.objectContaining({
|
||||
frameId: 'frame-1',
|
||||
label: '手工线段',
|
||||
color: '#06b6d4',
|
||||
saveStatus: 'draft',
|
||||
metadata: expect.objectContaining({
|
||||
source: 'manual',
|
||||
shape: '线段',
|
||||
}),
|
||||
}));
|
||||
expect(useStore.getState().masks[0].segmentation?.[0]).toHaveLength(8);
|
||||
expect(useStore.getState().masks[0].area).toBeGreaterThan(1000);
|
||||
});
|
||||
|
||||
it('creates an editable point region on click', () => {
|
||||
render(<CanvasArea activeTool="create_point" frame={frame} />);
|
||||
fireEvent.click(screen.getByTestId('konva-stage'), { clientX: 120, clientY: 80 });
|
||||
|
||||
expect(useStore.getState().masks).toHaveLength(1);
|
||||
expect(useStore.getState().masks[0]).toEqual(expect.objectContaining({
|
||||
frameId: 'frame-1',
|
||||
label: '手工点区域',
|
||||
color: '#06b6d4',
|
||||
saveStatus: 'draft',
|
||||
points: [[120, 80]],
|
||||
bbox: expect.arrayContaining([115, 75]),
|
||||
metadata: expect.objectContaining({
|
||||
source: 'manual',
|
||||
shape: '点区域',
|
||||
}),
|
||||
}));
|
||||
});
|
||||
|
||||
it('creates a point region when clicking over an existing mask', () => {
|
||||
useStore.setState({
|
||||
masks: [
|
||||
{
|
||||
id: 'm1',
|
||||
frameId: 'frame-1',
|
||||
pathData: 'M 10 10 L 200 10 L 200 200 Z',
|
||||
label: 'Existing',
|
||||
color: '#06b6d4',
|
||||
segmentation: [[10, 10, 200, 10, 200, 200]],
|
||||
},
|
||||
],
|
||||
});
|
||||
|
||||
render(<CanvasArea activeTool="create_point" frame={frame} />);
|
||||
fireEvent.click(screen.getByTestId('konva-path'), { clientX: 120, clientY: 80 });
|
||||
|
||||
expect(useStore.getState().masks).toHaveLength(2);
|
||||
expect(useStore.getState().masks[1]).toEqual(expect.objectContaining({
|
||||
metadata: expect.objectContaining({ shape: '点区域' }),
|
||||
points: [[120, 80]],
|
||||
}));
|
||||
});
|
||||
|
||||
it('finalizes a clicked polygon with Enter', () => {
|
||||
render(<CanvasArea activeTool="create_polygon" frame={frame} />);
|
||||
const stage = screen.getByTestId('konva-stage');
|
||||
@@ -344,6 +696,29 @@ describe('CanvasArea', () => {
|
||||
}));
|
||||
});
|
||||
|
||||
it('closes a clicked polygon by clicking the first node again', () => {
|
||||
render(<CanvasArea activeTool="create_polygon" frame={frame} />);
|
||||
const stage = screen.getByTestId('konva-stage');
|
||||
fireEvent.click(stage, { clientX: 120, clientY: 80 });
|
||||
fireEvent.click(stage, { clientX: 220, clientY: 80 });
|
||||
fireEvent.click(stage, { clientX: 180, clientY: 160 });
|
||||
|
||||
const handles = screen.getAllByTestId('konva-circle');
|
||||
expect(handles[0]).toHaveAttribute('data-fill', '#facc15');
|
||||
fireEvent.click(handles[0]);
|
||||
|
||||
expect(useStore.getState().masks).toHaveLength(1);
|
||||
expect(useStore.getState().masks[0]).toEqual(expect.objectContaining({
|
||||
pathData: 'M 120 80 L 220 80 L 180 160 Z',
|
||||
segmentation: [[120, 80, 220, 80, 180, 160]],
|
||||
metadata: expect.objectContaining({
|
||||
source: 'manual',
|
||||
shape: '多边形',
|
||||
}),
|
||||
}));
|
||||
expect(screen.queryAllByTestId('konva-circle')).toHaveLength(0);
|
||||
});
|
||||
|
||||
it('applies the selected class to current-frame masks and marks saved masks dirty', () => {
|
||||
useStore.setState({
|
||||
activeTemplateId: '2',
|
||||
|
||||
@@ -14,11 +14,14 @@ interface CanvasAreaProps {
|
||||
}
|
||||
|
||||
type CanvasPoint = { x: number; y: number };
|
||||
type PromptPoint = CanvasPoint & { type: 'pos' | 'neg' };
|
||||
type PromptBox = { x1: number; y1: number; x2: number; y2: number };
|
||||
|
||||
const DRAG_MANUAL_TOOLS = new Set(['create_rectangle', 'create_circle', 'create_line']);
|
||||
const POLYGON_TOOL = 'create_polygon';
|
||||
const POINT_TOOL = 'create_point';
|
||||
const BOOLEAN_TOOLS = new Set(['area_merge', 'area_remove']);
|
||||
const POLYGON_CLOSE_RADIUS = 8;
|
||||
|
||||
function clamp(value: number, min: number, max: number): number {
|
||||
return Math.min(Math.max(value, min), max);
|
||||
@@ -88,6 +91,10 @@ function polygonArea(points: CanvasPoint[]): number {
|
||||
return Math.abs(sum) / 2;
|
||||
}
|
||||
|
||||
function pointDistance(a: CanvasPoint, b: CanvasPoint): number {
|
||||
return Math.hypot(a.x - b.x, a.y - b.y);
|
||||
}
|
||||
|
||||
function segmentationArea(segmentation?: number[][]): number {
|
||||
return (segmentation || []).reduce((sum, polygon) => sum + polygonArea(flatPolygonToPoints(polygon)), 0);
|
||||
}
|
||||
@@ -115,20 +122,35 @@ function maskToMultiPolygon(mask: Mask): MultiPolygon | null {
|
||||
return polygons.length > 0 ? polygons : null;
|
||||
}
|
||||
|
||||
function openRingPoints(ring: Pair[]): CanvasPoint[] {
|
||||
const openRing = ring.length > 1
|
||||
&& ring[0][0] === ring[ring.length - 1][0]
|
||||
&& ring[0][1] === ring[ring.length - 1][1]
|
||||
? ring.slice(0, -1)
|
||||
: ring;
|
||||
return openRing.map(([x, y]) => ({ x, y }));
|
||||
}
|
||||
|
||||
function multiPolygonToSegmentation(geometry: MultiPolygon): number[][] {
|
||||
return geometry
|
||||
.map((polygon) => polygon[0] || [])
|
||||
.map((ring) => {
|
||||
const openRing = ring.length > 1
|
||||
&& ring[0][0] === ring[ring.length - 1][0]
|
||||
&& ring[0][1] === ring[ring.length - 1][1]
|
||||
? ring.slice(0, -1)
|
||||
: ring;
|
||||
return openRing.flatMap(([x, y]) => [x, y]);
|
||||
})
|
||||
.flatMap((polygon) => polygon)
|
||||
.map((ring) => openRingPoints(ring).flatMap(({ x, y }) => [x, y]))
|
||||
.filter((polygon) => polygon.length >= 6);
|
||||
}
|
||||
|
||||
function multiPolygonArea(geometry: MultiPolygon): number {
|
||||
return geometry.reduce((sum, polygon) => {
|
||||
const [outerRing, ...holeRings] = polygon;
|
||||
const outerArea = outerRing ? polygonArea(openRingPoints(outerRing)) : 0;
|
||||
const holesArea = holeRings.reduce((holeSum, ring) => holeSum + polygonArea(openRingPoints(ring)), 0);
|
||||
return sum + Math.max(outerArea - holesArea, 0);
|
||||
}, 0);
|
||||
}
|
||||
|
||||
function multiPolygonHasHoles(geometry: MultiPolygon): boolean {
|
||||
return geometry.some((polygon) => polygon.length > 1);
|
||||
}
|
||||
|
||||
function rectanglePoints(start: CanvasPoint, end: CanvasPoint): CanvasPoint[] {
|
||||
const x1 = Math.min(start.x, end.x);
|
||||
const y1 = Math.min(start.y, end.y);
|
||||
@@ -179,10 +201,12 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
|
||||
const [stageSize, setStageSize] = useState({ width: 800, height: 600 });
|
||||
const [scale, setScale] = useState(1);
|
||||
const [position, setPosition] = useState({ x: 0, y: 0 });
|
||||
const [points, setPoints] = useState<{ x: number, y: number, type: 'pos'|'neg' }[]>([]);
|
||||
const [points, setPoints] = useState<PromptPoint[]>([]);
|
||||
const [cursorPos, setCursorPos] = useState({ x: 0, y: 0 });
|
||||
const [boxStart, setBoxStart] = useState<{ x: number, y: number } | null>(null);
|
||||
const [boxCurrent, setBoxCurrent] = useState<{ x: number, y: number } | null>(null);
|
||||
const [samPromptBox, setSamPromptBox] = useState<PromptBox | null>(null);
|
||||
const [samCandidateMaskId, setSamCandidateMaskId] = useState<string | null>(null);
|
||||
const [manualStart, setManualStart] = useState<CanvasPoint | null>(null);
|
||||
const [manualCurrent, setManualCurrent] = useState<CanvasPoint | null>(null);
|
||||
const [polygonPoints, setPolygonPoints] = useState<CanvasPoint[]>([]);
|
||||
@@ -191,12 +215,14 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
|
||||
const [selectedPolygonIndex, setSelectedPolygonIndex] = useState(0);
|
||||
const [selectedVertexIndex, setSelectedVertexIndex] = useState<number | null>(null);
|
||||
const [isInferencing, setIsInferencing] = useState(false);
|
||||
const [inferenceMessage, setInferenceMessage] = useState('');
|
||||
|
||||
const masks = useStore((state) => state.masks);
|
||||
const addMask = useStore((state) => state.addMask);
|
||||
const updateMask = useStore((state) => state.updateMask);
|
||||
const clearMasks = useStore((state) => state.clearMasks);
|
||||
const setMasks = useStore((state) => state.setMasks);
|
||||
const setGlobalSelectedMaskIds = useStore((state) => state.setSelectedMaskIds);
|
||||
const storeActiveTool = useStore((state) => state.activeTool);
|
||||
const aiModel = useStore((state) => state.aiModel);
|
||||
const activeTemplateId = useStore((state) => state.activeTemplateId);
|
||||
@@ -226,6 +252,7 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
|
||||
const savedMaskCount = frameMasks.filter((mask) => mask.saveStatus === 'saved' || mask.saved).length;
|
||||
const draftMaskCount = frameMasks.filter((mask) => !mask.annotationId).length;
|
||||
const dirtyMaskCount = frameMasks.filter((mask) => mask.saveStatus === 'dirty').length;
|
||||
const isBooleanTool = BOOLEAN_TOOLS.has(effectiveTool);
|
||||
|
||||
useEffect(() => {
|
||||
const handleResize = () => {
|
||||
@@ -252,6 +279,18 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
|
||||
setSelectedVertexIndex(null);
|
||||
}, [effectiveTool, frame?.id]);
|
||||
|
||||
useEffect(() => {
|
||||
setPoints([]);
|
||||
setSamPromptBox(null);
|
||||
setSamCandidateMaskId(null);
|
||||
}, [frame?.id]);
|
||||
|
||||
useEffect(() => {
|
||||
setGlobalSelectedMaskIds(selectedMaskIds);
|
||||
}, [selectedMaskIds, setGlobalSelectedMaskIds]);
|
||||
|
||||
useEffect(() => () => setGlobalSelectedMaskIds([]), [setGlobalSelectedMaskIds]);
|
||||
|
||||
useEffect(() => {
|
||||
if (selectedMaskId && !frameMasks.some((mask) => mask.id === selectedMaskId)) {
|
||||
setSelectedMaskId(null);
|
||||
@@ -324,6 +363,12 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
|
||||
addMask(mask);
|
||||
}, [activeClass, activeTemplateId, addMask, frame?.id]);
|
||||
|
||||
const finishPolygon = useCallback(() => {
|
||||
if (polygonPoints.length < 3) return;
|
||||
createManualMask('多边形', polygonPoints);
|
||||
setPolygonPoints([]);
|
||||
}, [createManualMask, polygonPoints]);
|
||||
|
||||
const handleMouseMove = (e: any) => {
|
||||
const stage = e.target.getStage();
|
||||
if (!stage) return;
|
||||
@@ -349,9 +394,18 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
|
||||
}
|
||||
};
|
||||
|
||||
const runInference = useCallback(async (promptPoints?: typeof points, promptBox?: { x1: number, y1: number, x2: number, y2: number }) => {
|
||||
const runInference = useCallback(async (
|
||||
promptPoints?: PromptPoint[],
|
||||
promptBox?: PromptBox,
|
||||
options: { resetCandidate?: boolean } = {},
|
||||
) => {
|
||||
if (!frame?.id) {
|
||||
console.warn('Inference skipped: no active frame');
|
||||
setInferenceMessage('请先选择一帧图像。');
|
||||
return;
|
||||
}
|
||||
if (aiModel === 'sam3' && (!promptBox || (promptPoints?.length ?? 0) > 0)) {
|
||||
setInferenceMessage('SAM3 当前工作区只支持框选提示;正/反点修正请切回 SAM2。');
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -359,31 +413,44 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
|
||||
const imageHeight = frame.height || image?.naturalHeight || image?.height || 0;
|
||||
if (imageWidth <= 0 || imageHeight <= 0) {
|
||||
console.warn('Inference skipped: active frame dimensions are unavailable');
|
||||
setInferenceMessage('当前帧缺少宽高信息,无法推理。');
|
||||
return;
|
||||
}
|
||||
|
||||
setIsInferencing(true);
|
||||
setInferenceMessage('');
|
||||
try {
|
||||
const result = await predictMask({
|
||||
imageId: frame.id,
|
||||
imageWidth,
|
||||
imageHeight,
|
||||
model: aiModel,
|
||||
points: promptPoints?.map((p) => ({ x: p.x, y: p.y, type: p.type })),
|
||||
points: promptPoints && promptPoints.length > 0
|
||||
? promptPoints.map((p) => ({ x: p.x, y: p.y, type: p.type }))
|
||||
: undefined,
|
||||
box: promptBox,
|
||||
});
|
||||
|
||||
result.masks.forEach((m) => {
|
||||
const label = activeClass?.name || m.label;
|
||||
const color = activeClass?.color || m.color;
|
||||
addMask({
|
||||
id: m.id,
|
||||
const [m] = result.masks;
|
||||
if (m) {
|
||||
const existingCandidate = !options.resetCandidate && samCandidateMaskId
|
||||
? masks.find((mask) => mask.id === samCandidateMaskId)
|
||||
: null;
|
||||
const label = activeClass?.name || existingCandidate?.label || m.label;
|
||||
const color = activeClass?.color || existingCandidate?.color || m.color;
|
||||
const metadata = {
|
||||
...(existingCandidate?.metadata || {}),
|
||||
source: aiModel === 'sam3' ? 'sam3_box' : 'sam2_interactive',
|
||||
promptBox: promptBox || null,
|
||||
promptPointCount: promptPoints?.length || 0,
|
||||
};
|
||||
const nextMask = {
|
||||
frameId: frame.id,
|
||||
templateId: activeTemplateId || undefined,
|
||||
classId: activeClass?.id,
|
||||
className: activeClass?.name,
|
||||
classZIndex: activeClass?.zIndex,
|
||||
saveStatus: 'draft',
|
||||
templateId: activeTemplateId || existingCandidate?.templateId || undefined,
|
||||
classId: activeClass?.id || existingCandidate?.classId,
|
||||
className: activeClass?.name || existingCandidate?.className,
|
||||
classZIndex: activeClass?.zIndex ?? existingCandidate?.classZIndex,
|
||||
saveStatus: existingCandidate?.annotationId ? 'dirty' as const : 'draft' as const,
|
||||
saved: false,
|
||||
pathData: m.pathData,
|
||||
label,
|
||||
@@ -392,14 +459,33 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
|
||||
points: promptPoints?.filter((p) => p.type === 'pos').map((p) => [p.x, p.y]),
|
||||
bbox: m.bbox,
|
||||
area: m.area,
|
||||
});
|
||||
});
|
||||
metadata,
|
||||
};
|
||||
if (existingCandidate) {
|
||||
updateMask(existingCandidate.id, nextMask);
|
||||
setSelectedMaskId(existingCandidate.id);
|
||||
setSelectedMaskIds([existingCandidate.id]);
|
||||
} else {
|
||||
const id = m.id;
|
||||
setSamCandidateMaskId(id);
|
||||
setSelectedMaskId(id);
|
||||
setSelectedMaskIds([id]);
|
||||
addMask({
|
||||
id,
|
||||
...nextMask,
|
||||
});
|
||||
}
|
||||
} else {
|
||||
setInferenceMessage('模型没有返回可用区域,请调整点/框提示后重试。');
|
||||
}
|
||||
} catch (err) {
|
||||
console.error('Inference failed:', err);
|
||||
const detail = (err as any)?.response?.data?.detail;
|
||||
setInferenceMessage(detail || 'AI 推理失败,请查看模型状态或后端日志。');
|
||||
} finally {
|
||||
setIsInferencing(false);
|
||||
}
|
||||
}, [activeClass, activeTemplateId, addMask, aiModel, frame?.height, frame?.id, frame?.width, image?.height, image?.naturalHeight, image?.naturalWidth, image?.width]);
|
||||
}, [activeClass, activeTemplateId, addMask, aiModel, frame?.height, frame?.id, frame?.width, image?.height, image?.naturalHeight, image?.naturalWidth, image?.width, masks, samCandidateMaskId, updateMask]);
|
||||
|
||||
const handleApplyActiveClass = () => {
|
||||
if (!frame?.id || !activeClass) return;
|
||||
@@ -427,6 +513,29 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
|
||||
clearMasks();
|
||||
};
|
||||
|
||||
const deleteMasksById = useCallback((maskIds: string[]) => {
|
||||
if (maskIds.length === 0) return;
|
||||
const idSet = new Set(maskIds);
|
||||
const deletingMasks = masks.filter((mask) => idSet.has(mask.id));
|
||||
if (deletingMasks.length === 0) return;
|
||||
setMasks(masks.filter((mask) => !idSet.has(mask.id)));
|
||||
const annotationIds = deletingMasks
|
||||
.map((mask) => mask.annotationId)
|
||||
.filter((annotationId): annotationId is string => Boolean(annotationId));
|
||||
if (annotationIds.length > 0) {
|
||||
void onDeleteMaskAnnotations?.(annotationIds);
|
||||
}
|
||||
if (samCandidateMaskId && idSet.has(samCandidateMaskId)) {
|
||||
setSamCandidateMaskId(null);
|
||||
setSamPromptBox(null);
|
||||
setPoints([]);
|
||||
}
|
||||
setSelectedMaskId(null);
|
||||
setSelectedMaskIds([]);
|
||||
setSelectedPolygonIndex(0);
|
||||
setSelectedVertexIndex(null);
|
||||
}, [masks, onDeleteMaskAnnotations, samCandidateMaskId, setMasks]);
|
||||
|
||||
const handleStageMouseDown = (e: any) => {
|
||||
if (DRAG_MANUAL_TOOLS.has(effectiveTool)) {
|
||||
const pos = stagePoint(e);
|
||||
@@ -476,7 +585,11 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
|
||||
const y2 = Math.max(boxStart.y, boxCurrent.y);
|
||||
|
||||
if (Math.abs(x2 - x1) > 5 && Math.abs(y2 - y1) > 5) {
|
||||
runInference(undefined, { x1, y1, x2, y2 });
|
||||
const nextBox = { x1, y1, x2, y2 };
|
||||
setPoints([]);
|
||||
setSamPromptBox(nextBox);
|
||||
setSamCandidateMaskId(null);
|
||||
runInference([], nextBox, { resetCandidate: true });
|
||||
}
|
||||
|
||||
setBoxStart(null);
|
||||
@@ -500,6 +613,11 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
|
||||
if (effectiveTool === POLYGON_TOOL) {
|
||||
const pos = stagePoint(e);
|
||||
if (pos) {
|
||||
const closeRadius = POLYGON_CLOSE_RADIUS / Math.max(scale, 0.1);
|
||||
if (polygonPoints.length >= 3 && pointDistance(pos, polygonPoints[0]) <= closeRadius) {
|
||||
finishPolygon();
|
||||
return;
|
||||
}
|
||||
setPolygonPoints((current) => [...current, pos]);
|
||||
}
|
||||
return;
|
||||
@@ -514,8 +632,7 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
|
||||
{ x: pos.x, y: pos.y, type: (effectiveTool === 'point_pos' ? 'pos' : 'neg') as 'pos' | 'neg' },
|
||||
];
|
||||
setPoints(newPoints);
|
||||
// Auto-trigger inference after point selection
|
||||
runInference(newPoints);
|
||||
runInference(newPoints, samPromptBox || undefined);
|
||||
}
|
||||
}
|
||||
};
|
||||
@@ -535,14 +652,22 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
|
||||
});
|
||||
}, [updateMask]);
|
||||
|
||||
const updateMaskFromSegmentation = useCallback((mask: Mask, segmentation: number[][]): Mask => {
|
||||
const updateMaskFromSegmentation = useCallback((
|
||||
mask: Mask,
|
||||
segmentation: number[][],
|
||||
options: { area?: number; hasHoles?: boolean } = {},
|
||||
): Mask => {
|
||||
const bbox = segmentationBbox(segmentation);
|
||||
const metadata = { ...(mask.metadata || {}) };
|
||||
if (options.hasHoles === true) metadata.hasHoles = true;
|
||||
if (options.hasHoles === false) delete metadata.hasHoles;
|
||||
return {
|
||||
...mask,
|
||||
pathData: segmentationPath(segmentation),
|
||||
segmentation,
|
||||
bbox,
|
||||
area: segmentationArea(segmentation),
|
||||
area: options.area ?? segmentationArea(segmentation),
|
||||
metadata,
|
||||
saveStatus: mask.annotationId ? 'dirty' : 'draft',
|
||||
saved: mask.annotationId ? false : mask.saved,
|
||||
};
|
||||
@@ -572,11 +697,16 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
|
||||
}
|
||||
return;
|
||||
}
|
||||
if ((event.key === 'Delete' || event.key === 'Backspace') && selectedMask) {
|
||||
event.preventDefault();
|
||||
const ids = selectedMaskIds.length > 0 ? selectedMaskIds : [selectedMask.id];
|
||||
deleteMasksById(ids);
|
||||
return;
|
||||
}
|
||||
if (effectiveTool !== POLYGON_TOOL) return;
|
||||
if (event.key === 'Enter' && polygonPoints.length >= 3) {
|
||||
event.preventDefault();
|
||||
createManualMask('多边形', polygonPoints);
|
||||
setPolygonPoints([]);
|
||||
finishPolygon();
|
||||
}
|
||||
if (event.key === 'Escape') {
|
||||
event.preventDefault();
|
||||
@@ -586,7 +716,7 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
|
||||
|
||||
window.addEventListener('keydown', handleKeyDown);
|
||||
return () => window.removeEventListener('keydown', handleKeyDown);
|
||||
}, [createManualMask, effectiveTool, polygonPoints, redoMasks, selectedMask, selectedPolygonIndex, selectedVertexIndex, undoMasks, updatePolygonMask]);
|
||||
}, [deleteMasksById, effectiveTool, finishPolygon, polygonPoints, redoMasks, selectedMask, selectedMaskIds, selectedPolygonIndex, selectedVertexIndex, undoMasks, updatePolygonMask]);
|
||||
|
||||
const boxRect = React.useMemo(() => {
|
||||
if (!boxStart || !boxCurrent) return null;
|
||||
@@ -623,8 +753,9 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
|
||||
};
|
||||
|
||||
const handleMaskSelect = (mask: Mask, event: any, polygonIndex = 0) => {
|
||||
if (effectiveTool !== 'move' && !isBooleanTool) return;
|
||||
event.cancelBubble = true;
|
||||
if (BOOLEAN_TOOLS.has(effectiveTool)) {
|
||||
if (isBooleanTool) {
|
||||
setSelectedMaskIds((current) => (
|
||||
current.includes(mask.id)
|
||||
? current.filter((id) => id !== mask.id)
|
||||
@@ -703,7 +834,10 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
|
||||
return;
|
||||
}
|
||||
|
||||
const nextPrimary = updateMaskFromSegmentation(primary, resultSegmentation);
|
||||
const nextPrimary = updateMaskFromSegmentation(primary, resultSegmentation, {
|
||||
area: multiPolygonArea(resultGeometry),
|
||||
hasHoles: multiPolygonHasHoles(resultGeometry),
|
||||
});
|
||||
const secondaryIds = effectiveTool === 'area_merge'
|
||||
? new Set(booleanSelectedMasks.slice(1).map((mask) => mask.id))
|
||||
: new Set<string>();
|
||||
@@ -731,6 +865,11 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
|
||||
<span className="text-xs text-cyan-400 font-mono">AI 推理中...</span>
|
||||
</div>
|
||||
)}
|
||||
{!isInferencing && inferenceMessage && (
|
||||
<div className="absolute top-4 right-4 z-20 max-w-xs bg-[#111] border border-white/10 px-3 py-2 rounded-lg shadow-xl text-xs leading-relaxed text-gray-300">
|
||||
{inferenceMessage}
|
||||
</div>
|
||||
)}
|
||||
|
||||
<Stage
|
||||
width={stageSize.width}
|
||||
@@ -758,21 +897,32 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
|
||||
)}
|
||||
|
||||
{/* AI Returned Masks */}
|
||||
{frameMasks.map((mask) => (
|
||||
<Group key={mask.id} opacity={selectedMaskIds.includes(mask.id) ? 0.65 : 0.5}>
|
||||
{(mask.segmentation && mask.segmentation.length > 0 ? mask.segmentation : [undefined]).map((_, polygonIndex) => (
|
||||
<Path
|
||||
key={`${mask.id}-polygon-${polygonIndex}`}
|
||||
data={mask.segmentation ? segmentationPolygonPath(mask.segmentation, polygonIndex) : mask.pathData}
|
||||
fill={mask.color}
|
||||
stroke={mask.color}
|
||||
strokeWidth={(selectedMaskIds.includes(mask.id) ? 2 : 1) / scale}
|
||||
onClick={(event: any) => handleMaskSelect(mask, event, polygonIndex)}
|
||||
onTap={(event: any) => handleMaskSelect(mask, event, polygonIndex)}
|
||||
/>
|
||||
))}
|
||||
</Group>
|
||||
))}
|
||||
{frameMasks.map((mask) => {
|
||||
const hasHoles = Boolean(mask.metadata?.hasHoles);
|
||||
const paths = hasHoles
|
||||
? [{ data: segmentationPath(mask.segmentation), polygonIndex: 0, fillRule: 'evenodd' }]
|
||||
: (mask.segmentation && mask.segmentation.length > 0 ? mask.segmentation : [undefined]).map((_, polygonIndex) => ({
|
||||
data: mask.segmentation ? segmentationPolygonPath(mask.segmentation, polygonIndex) : mask.pathData,
|
||||
polygonIndex,
|
||||
fillRule: undefined,
|
||||
}));
|
||||
return (
|
||||
<Group key={mask.id} opacity={selectedMaskIds.includes(mask.id) ? 0.65 : 0.5}>
|
||||
{paths.map(({ data, polygonIndex, fillRule }) => (
|
||||
<Path
|
||||
key={`${mask.id}-polygon-${polygonIndex}`}
|
||||
data={data}
|
||||
fill={mask.color}
|
||||
fillRule={fillRule}
|
||||
stroke={mask.color}
|
||||
strokeWidth={(selectedMaskIds.includes(mask.id) ? 2 : 1) / scale}
|
||||
onClick={(event: any) => handleMaskSelect(mask, event, polygonIndex)}
|
||||
onTap={(event: any) => handleMaskSelect(mask, event, polygonIndex)}
|
||||
/>
|
||||
))}
|
||||
</Group>
|
||||
);
|
||||
})}
|
||||
|
||||
{/* Box selection preview */}
|
||||
{boxRect && effectiveTool === 'box_select' && (
|
||||
@@ -804,10 +954,20 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
|
||||
key={`poly-point-${index}`}
|
||||
x={point.x}
|
||||
y={point.y}
|
||||
radius={4 / scale}
|
||||
fill="#22d3ee"
|
||||
stroke="#ffffff"
|
||||
radius={(index === 0 && polygonPoints.length >= 3 ? 6 : 4) / scale}
|
||||
fill={index === 0 && polygonPoints.length >= 3 ? '#facc15' : '#22d3ee'}
|
||||
stroke={index === 0 && polygonPoints.length >= 3 ? '#fef3c7' : '#ffffff'}
|
||||
strokeWidth={1 / scale}
|
||||
onClick={(event: any) => {
|
||||
if (index !== 0 || polygonPoints.length < 3) return;
|
||||
event.cancelBubble = true;
|
||||
finishPolygon();
|
||||
}}
|
||||
onTap={(event: any) => {
|
||||
if (index !== 0 || polygonPoints.length < 3) return;
|
||||
event.cancelBubble = true;
|
||||
finishPolygon();
|
||||
}}
|
||||
/>
|
||||
))}
|
||||
|
||||
@@ -827,7 +987,7 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
|
||||
)))}
|
||||
|
||||
{/* Polygon edge insertion handles */}
|
||||
{selectedMask && selectedMaskPoints.map((point, index) => {
|
||||
{!isBooleanTool && selectedMask && selectedMaskPoints.map((point, index) => {
|
||||
const next = selectedMaskPoints[(index + 1) % selectedMaskPoints.length];
|
||||
if (!next) return null;
|
||||
return (
|
||||
@@ -846,7 +1006,7 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
|
||||
})}
|
||||
|
||||
{/* Polygon vertex editor */}
|
||||
{selectedMask && selectedMaskPoints.map((point, index) => (
|
||||
{!isBooleanTool && selectedMask && selectedMaskPoints.map((point, index) => (
|
||||
<Circle
|
||||
key={`${selectedMask.id}-vertex-${selectedPolygonIndex}-${index}`}
|
||||
x={point.x}
|
||||
@@ -900,13 +1060,19 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
|
||||
|
||||
{frameMasks.length > 0 && (
|
||||
<div className="absolute bottom-4 right-4 flex gap-2">
|
||||
{BOOLEAN_TOOLS.has(effectiveTool) && booleanSelectedMasks.length >= 2 && (
|
||||
<button
|
||||
onClick={handleBooleanOperation}
|
||||
className="text-xs bg-emerald-500/10 hover:bg-emerald-500/20 text-emerald-300 border border-emerald-500/20 px-3 py-1.5 rounded transition-colors"
|
||||
>
|
||||
{effectiveTool === 'area_merge' ? '合并选中' : '从主区域去除'}
|
||||
</button>
|
||||
{isBooleanTool && (
|
||||
<div className="flex items-center gap-2">
|
||||
<span className="text-xs bg-white/5 text-gray-300 border border-white/10 px-2.5 py-1.5 rounded">
|
||||
已选 {booleanSelectedMasks.length}
|
||||
</span>
|
||||
<button
|
||||
onClick={handleBooleanOperation}
|
||||
disabled={booleanSelectedMasks.length < 2}
|
||||
className="text-xs bg-emerald-500/10 hover:bg-emerald-500/20 text-emerald-300 border border-emerald-500/20 px-3 py-1.5 rounded transition-colors disabled:opacity-40 disabled:cursor-not-allowed disabled:hover:bg-emerald-500/10"
|
||||
>
|
||||
{effectiveTool === 'area_merge' ? '合并选中' : '从主区域去除'}
|
||||
</button>
|
||||
</div>
|
||||
)}
|
||||
{activeClass && (
|
||||
<button
|
||||
|
||||
@@ -34,6 +34,65 @@ describe('FrameTimeline', () => {
|
||||
expect(useStore.getState().currentFrameIndex).toBe(2);
|
||||
});
|
||||
|
||||
it('shows current and total timeline time based on project fps', () => {
|
||||
useStore.setState({
|
||||
currentProject: { id: 'p1', name: 'P', status: 'ready', parse_fps: 10 },
|
||||
currentFrameIndex: 1,
|
||||
frames: [
|
||||
{ id: 'f1', projectId: 'p1', index: 0, url: '/1.jpg', width: 640, height: 360 },
|
||||
{ id: 'f2', projectId: 'p1', index: 1, url: '/2.jpg', width: 640, height: 360 },
|
||||
{ id: 'f3', projectId: 'p1', index: 2, url: '/3.jpg', width: 640, height: 360 },
|
||||
],
|
||||
});
|
||||
|
||||
render(<FrameTimeline />);
|
||||
|
||||
expect(screen.getAllByText('00:00.10').length).toBeGreaterThan(0);
|
||||
expect(screen.getAllByText('00:00.20').length).toBeGreaterThan(0);
|
||||
});
|
||||
|
||||
it('changes frames with left and right arrow keys without leaving bounds', () => {
|
||||
useStore.setState({
|
||||
currentFrameIndex: 1,
|
||||
frames: [
|
||||
{ id: 'f1', projectId: 'p1', index: 0, url: '/1.jpg', width: 640, height: 360 },
|
||||
{ id: 'f2', projectId: 'p1', index: 1, url: '/2.jpg', width: 640, height: 360 },
|
||||
{ id: 'f3', projectId: 'p1', index: 2, url: '/3.jpg', width: 640, height: 360 },
|
||||
],
|
||||
});
|
||||
|
||||
render(<FrameTimeline />);
|
||||
fireEvent.keyDown(window, { key: 'ArrowRight' });
|
||||
expect(useStore.getState().currentFrameIndex).toBe(2);
|
||||
|
||||
fireEvent.keyDown(window, { key: 'ArrowRight' });
|
||||
expect(useStore.getState().currentFrameIndex).toBe(2);
|
||||
|
||||
fireEvent.keyDown(window, { key: 'ArrowLeft' });
|
||||
expect(useStore.getState().currentFrameIndex).toBe(1);
|
||||
});
|
||||
|
||||
it('does not change frames while typing in editable fields', () => {
|
||||
useStore.setState({
|
||||
currentFrameIndex: 1,
|
||||
frames: [
|
||||
{ id: 'f1', projectId: 'p1', index: 0, url: '/1.jpg', width: 640, height: 360 },
|
||||
{ id: 'f2', projectId: 'p1', index: 1, url: '/2.jpg', width: 640, height: 360 },
|
||||
{ id: 'f3', projectId: 'p1', index: 2, url: '/3.jpg', width: 640, height: 360 },
|
||||
],
|
||||
});
|
||||
|
||||
render(
|
||||
<>
|
||||
<input aria-label="annotation-name" />
|
||||
<FrameTimeline />
|
||||
</>,
|
||||
);
|
||||
fireEvent.keyDown(screen.getByLabelText('annotation-name'), { key: 'ArrowRight' });
|
||||
|
||||
expect(useStore.getState().currentFrameIndex).toBe(1);
|
||||
});
|
||||
|
||||
it('plays forward using the project parse fps and stops at the end', () => {
|
||||
vi.useFakeTimers();
|
||||
useStore.setState({
|
||||
|
||||
@@ -16,6 +16,20 @@ export function FrameTimeline() {
|
||||
const fps = currentProject?.parse_fps || currentProject?.original_fps || 12;
|
||||
return Math.min(Math.max(fps, 1), 30);
|
||||
}, [currentProject?.original_fps, currentProject?.parse_fps]);
|
||||
const timeBaseFps = useMemo(() => {
|
||||
const fps = currentProject?.parse_fps || currentProject?.original_fps || 12;
|
||||
return Math.max(fps, 1);
|
||||
}, [currentProject?.original_fps, currentProject?.parse_fps]);
|
||||
const currentSeconds = totalFrames > 0 ? currentFrameIndex / timeBaseFps : 0;
|
||||
const totalSeconds = totalFrames > 0 ? Math.max(totalFrames - 1, 0) / timeBaseFps : 0;
|
||||
|
||||
const formatTime = (seconds: number) => {
|
||||
const safeSeconds = Math.max(0, seconds);
|
||||
const minutes = Math.floor(safeSeconds / 60);
|
||||
const wholeSeconds = Math.floor(safeSeconds % 60);
|
||||
const centiseconds = Math.floor((safeSeconds % 1) * 100);
|
||||
return `${minutes.toString().padStart(2, '0')}:${wholeSeconds.toString().padStart(2, '0')}.${centiseconds.toString().padStart(2, '0')}`;
|
||||
};
|
||||
|
||||
useEffect(() => {
|
||||
if (!isPlaying || totalFrames <= 1) return;
|
||||
@@ -38,6 +52,30 @@ export function FrameTimeline() {
|
||||
}
|
||||
}, [totalFrames]);
|
||||
|
||||
useEffect(() => {
|
||||
const isEditableTarget = (target: EventTarget | null) => {
|
||||
if (!(target instanceof HTMLElement)) return false;
|
||||
const tagName = target.tagName.toLowerCase();
|
||||
return target.isContentEditable || ['input', 'textarea', 'select'].includes(tagName);
|
||||
};
|
||||
|
||||
const handleKeyDown = (event: KeyboardEvent) => {
|
||||
if (isEditableTarget(event.target) || totalFrames <= 1) return;
|
||||
if (event.key !== 'ArrowLeft' && event.key !== 'ArrowRight') return;
|
||||
|
||||
event.preventDefault();
|
||||
setIsPlaying(false);
|
||||
const direction = event.key === 'ArrowRight' ? 1 : -1;
|
||||
const nextIndex = Math.min(Math.max(currentFrameIndex + direction, 0), totalFrames - 1);
|
||||
if (nextIndex !== currentFrameIndex) {
|
||||
setCurrentFrame(nextIndex);
|
||||
}
|
||||
};
|
||||
|
||||
window.addEventListener('keydown', handleKeyDown);
|
||||
return () => window.removeEventListener('keydown', handleKeyDown);
|
||||
}, [currentFrameIndex, setCurrentFrame, totalFrames]);
|
||||
|
||||
// show frames around current frame
|
||||
const frameWindow = 20;
|
||||
const displayIndices = totalFrames > 0
|
||||
@@ -47,6 +85,12 @@ export function FrameTimeline() {
|
||||
return (
|
||||
<div className="h-32 bg-[#111] border-t border-white/5 flex flex-col shrink-0 z-20">
|
||||
<div className="h-4 bg-[#0d0d0d] flex items-center group relative">
|
||||
<div className="absolute left-3 -top-5 text-[10px] font-mono text-gray-500 pointer-events-none">
|
||||
{formatTime(currentSeconds)}
|
||||
</div>
|
||||
<div className="absolute right-3 -top-5 text-[10px] font-mono text-gray-500 pointer-events-none">
|
||||
{formatTime(totalSeconds)}
|
||||
</div>
|
||||
<input
|
||||
type="range"
|
||||
min="1"
|
||||
@@ -65,6 +109,12 @@ export function FrameTimeline() {
|
||||
className="w-3 h-3 bg-white rounded-full absolute top-1/2 -translate-y-1/2 -ml-1.5 shadow-sm transform scale-0 group-hover:scale-100 transition-transform shadow-cyan-500/50"
|
||||
style={{ left: `${totalFrames > 0 ? (currentFrame / totalFrames) * 100 : 0}%` }}
|
||||
/>
|
||||
<div
|
||||
className="absolute -top-7 -translate-x-1/2 rounded bg-black/80 border border-white/10 px-2 py-0.5 text-[10px] font-mono text-cyan-300 opacity-0 group-hover:opacity-100 transition-opacity pointer-events-none"
|
||||
style={{ left: `${totalFrames > 0 ? (currentFrame / totalFrames) * 100 : 0}%` }}
|
||||
>
|
||||
{formatTime(currentSeconds)}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
@@ -129,6 +179,9 @@ export function FrameTimeline() {
|
||||
|
||||
<div className="w-48 text-right shrink-0">
|
||||
<div className="text-2xl font-mono text-white">{currentFrame}<span className="text-xs text-gray-500"> / {totalFrames}</span></div>
|
||||
<div className="text-xs font-mono text-cyan-300 mt-1">
|
||||
{formatTime(currentSeconds)} <span className="text-gray-600">/</span> {formatTime(totalSeconds)}
|
||||
</div>
|
||||
<div className="text-[10px] text-gray-500 uppercase tracking-widest mt-1">底层时序视频图层截帧导航轴</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
@@ -45,6 +45,41 @@ describe('OntologyInspector', () => {
|
||||
}));
|
||||
});
|
||||
|
||||
it('applies the selected class to currently selected masks', () => {
|
||||
useStore.setState({
|
||||
selectedMaskIds: ['m1'],
|
||||
masks: [
|
||||
{
|
||||
id: 'm1',
|
||||
annotationId: '99',
|
||||
frameId: 'frame-1',
|
||||
pathData: 'M 0 0 Z',
|
||||
label: '旧标签',
|
||||
color: '#06b6d4',
|
||||
saveStatus: 'saved',
|
||||
saved: true,
|
||||
},
|
||||
],
|
||||
});
|
||||
|
||||
render(<OntologyInspector />);
|
||||
fireEvent.click(screen.getByText('肝脏'));
|
||||
|
||||
expect(useStore.getState().activeClassId).toBe('c2');
|
||||
expect(useStore.getState().masks[0]).toEqual(expect.objectContaining({
|
||||
templateId: 't1',
|
||||
classId: 'c2',
|
||||
className: '肝脏',
|
||||
classZIndex: 10,
|
||||
label: '肝脏',
|
||||
color: '#00ff00',
|
||||
saveStatus: 'dirty',
|
||||
saved: false,
|
||||
}));
|
||||
expect(screen.getByText('当前选中区域:')).toBeInTheDocument();
|
||||
expect(screen.getByText('1')).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it('adds custom classes locally without backend persistence', () => {
|
||||
const { container } = render(<OntologyInspector />);
|
||||
const customSection = screen.getByText('自定义分类').parentElement!;
|
||||
|
||||
@@ -10,6 +10,9 @@ export function OntologyInspector() {
|
||||
const activeTemplateId = useStore((state) => state.activeTemplateId);
|
||||
const activeClassId = useStore((state) => state.activeClassId);
|
||||
const activeClass = useStore((state) => state.activeClass);
|
||||
const masks = useStore((state) => state.masks);
|
||||
const selectedMaskIds = useStore((state) => state.selectedMaskIds);
|
||||
const setMasks = useStore((state) => state.setMasks);
|
||||
const setActiveTemplateId = useStore((state) => state.setActiveTemplateId);
|
||||
const setActiveClass = useStore((state) => state.setActiveClass);
|
||||
|
||||
@@ -28,6 +31,25 @@ export function OntologyInspector() {
|
||||
setActiveTemplateId(activeTemplate.id);
|
||||
}
|
||||
setActiveClass(templateClass);
|
||||
const selectedIdSet = new Set(selectedMaskIds);
|
||||
const hasSelectedMasks = masks.some((mask) => selectedIdSet.has(mask.id));
|
||||
if (!hasSelectedMasks) return;
|
||||
|
||||
const templateId = activeTemplate?.id || activeTemplateId || undefined;
|
||||
setMasks(masks.map((mask) => {
|
||||
if (!selectedIdSet.has(mask.id)) return mask;
|
||||
return {
|
||||
...mask,
|
||||
templateId: templateId || mask.templateId,
|
||||
classId: templateClass.id,
|
||||
className: templateClass.name,
|
||||
classZIndex: templateClass.zIndex,
|
||||
label: templateClass.name,
|
||||
color: templateClass.color,
|
||||
saveStatus: mask.annotationId ? 'dirty' : 'draft',
|
||||
saved: mask.annotationId ? false : mask.saved,
|
||||
};
|
||||
}));
|
||||
};
|
||||
|
||||
const handleAddCustom = () => {
|
||||
@@ -164,6 +186,10 @@ export function OntologyInspector() {
|
||||
</span>
|
||||
</div>
|
||||
<div className="space-y-3">
|
||||
<div className="flex items-center justify-between">
|
||||
<span className="text-[10px] text-gray-500 uppercase">当前选中区域:</span>
|
||||
<span className="text-xs font-mono text-gray-300">{selectedMaskIds.length}</span>
|
||||
</div>
|
||||
<div className="space-y-1">
|
||||
<label className="text-[10px] text-gray-500 uppercase">感知算法置信度</label>
|
||||
<div className="h-1.5 w-full bg-white/10 rounded-full overflow-hidden">
|
||||
|
||||
@@ -82,4 +82,65 @@ describe('TemplateRegistry', () => {
|
||||
|
||||
expect(screen.getByText('分类A')).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it('edits an existing template through the backend and store', async () => {
|
||||
apiMock.getTemplates.mockResolvedValueOnce([
|
||||
{
|
||||
id: 't1',
|
||||
name: '旧模板',
|
||||
description: 'old desc',
|
||||
classes: [{ id: 'c1', name: '胆囊', color: '#ff0000', zIndex: 10, category: '器官' }],
|
||||
rules: [],
|
||||
color: '#06b6d4',
|
||||
z_index: 3,
|
||||
},
|
||||
]);
|
||||
apiMock.updateTemplate.mockResolvedValueOnce({
|
||||
id: 't1',
|
||||
name: '新模板',
|
||||
description: 'new desc',
|
||||
classes: [{ id: 'c1', name: '胆囊', color: '#ff0000', zIndex: 10, category: '器官' }],
|
||||
rules: [],
|
||||
});
|
||||
|
||||
render(<TemplateRegistry />);
|
||||
fireEvent.click(await screen.findByRole('button', { name: /修改库视图结构/ }));
|
||||
fireEvent.change(screen.getAllByRole('textbox')[0], { target: { value: '新模板' } });
|
||||
fireEvent.change(screen.getAllByRole('textbox')[1], { target: { value: 'new desc' } });
|
||||
fireEvent.click(screen.getByRole('button', { name: '保存' }));
|
||||
|
||||
await waitFor(() => expect(apiMock.updateTemplate).toHaveBeenCalledWith('t1', expect.objectContaining({
|
||||
name: '新模板',
|
||||
description: 'new desc',
|
||||
classes: [expect.objectContaining({ id: 'c1', name: '胆囊' })],
|
||||
rules: [],
|
||||
color: '#06b6d4',
|
||||
z_index: 3,
|
||||
})));
|
||||
expect(useStore.getState().templates[0]).toEqual(expect.objectContaining({
|
||||
id: 't1',
|
||||
name: '新模板',
|
||||
}));
|
||||
});
|
||||
|
||||
it('deletes an existing template after confirmation', async () => {
|
||||
apiMock.getTemplates.mockResolvedValueOnce([
|
||||
{
|
||||
id: 't1',
|
||||
name: '待删除模板',
|
||||
description: 'desc',
|
||||
classes: [],
|
||||
rules: [],
|
||||
},
|
||||
]);
|
||||
apiMock.deleteTemplate.mockResolvedValueOnce(undefined);
|
||||
const { container } = render(<TemplateRegistry />);
|
||||
|
||||
await screen.findAllByText('待删除模板');
|
||||
const buttons = Array.from(container.querySelectorAll('button'));
|
||||
fireEvent.click(buttons[2]);
|
||||
|
||||
await waitFor(() => expect(apiMock.deleteTemplate).toHaveBeenCalledWith('t1'));
|
||||
expect(useStore.getState().templates).toEqual([]);
|
||||
});
|
||||
});
|
||||
|
||||
@@ -7,6 +7,7 @@ import { VideoWorkspace } from './VideoWorkspace';
|
||||
const apiMock = vi.hoisted(() => ({
|
||||
getProjectFrames: vi.fn(),
|
||||
parseMedia: vi.fn(),
|
||||
propagateMasks: vi.fn(),
|
||||
getTask: vi.fn(),
|
||||
getTemplates: vi.fn(),
|
||||
getProjectAnnotations: vi.fn(),
|
||||
@@ -24,6 +25,7 @@ const apiMock = vi.hoisted(() => ({
|
||||
vi.mock('../lib/api', () => ({
|
||||
getProjectFrames: apiMock.getProjectFrames,
|
||||
parseMedia: apiMock.parseMedia,
|
||||
propagateMasks: apiMock.propagateMasks,
|
||||
getTask: apiMock.getTask,
|
||||
getTemplates: apiMock.getTemplates,
|
||||
getProjectAnnotations: apiMock.getProjectAnnotations,
|
||||
@@ -47,6 +49,14 @@ describe('VideoWorkspace', () => {
|
||||
apiMock.getProjectAnnotations.mockResolvedValue([]);
|
||||
apiMock.annotationToMask.mockReturnValue(null);
|
||||
apiMock.getTask.mockResolvedValue({ id: 1, status: 'success', progress: 100, message: '解析完成' });
|
||||
apiMock.propagateMasks.mockResolvedValue({
|
||||
model: 'sam2',
|
||||
direction: 'forward',
|
||||
source_frame_id: 10,
|
||||
processed_frame_count: 3,
|
||||
created_annotation_count: 2,
|
||||
annotations: [],
|
||||
});
|
||||
apiMock.getAiModelStatus.mockResolvedValue({
|
||||
selected_model: 'sam2',
|
||||
gpu: { available: false, device: 'cpu', name: null, torch_available: true },
|
||||
@@ -320,4 +330,64 @@ describe('VideoWorkspace', () => {
|
||||
expect.objectContaining({ id: 'annotation-88', label: 'GT Mask' }),
|
||||
]));
|
||||
});
|
||||
|
||||
it('propagates the selected current-frame mask through the backend video tracker', async () => {
|
||||
apiMock.getProjectFrames.mockResolvedValueOnce([
|
||||
{ id: 10, project_id: 1, frame_index: 0, image_url: '/frame.jpg', width: 640, height: 360 },
|
||||
{ id: 11, project_id: 1, frame_index: 1, image_url: '/frame-1.jpg', width: 640, height: 360 },
|
||||
]);
|
||||
apiMock.buildAnnotationPayload.mockReturnValueOnce({
|
||||
project_id: 1,
|
||||
frame_id: 10,
|
||||
template_id: 2,
|
||||
mask_data: {
|
||||
polygons: [[[0.1, 0.1], [0.3, 0.1], [0.3, 0.3]]],
|
||||
label: '胆囊',
|
||||
color: '#ff0000',
|
||||
class: { id: 'c1', name: '胆囊', color: '#ff0000', zIndex: 20 },
|
||||
},
|
||||
bbox: [0.1, 0.1, 0.2, 0.2],
|
||||
});
|
||||
|
||||
render(<VideoWorkspace />);
|
||||
await waitFor(() => expect(useStore.getState().frames).toHaveLength(2));
|
||||
act(() => {
|
||||
useStore.setState({
|
||||
aiModel: 'sam2',
|
||||
activeTemplateId: '2',
|
||||
selectedMaskIds: ['mask-1'],
|
||||
masks: [{
|
||||
id: 'mask-1',
|
||||
frameId: '10',
|
||||
pathData: 'M 0 0 Z',
|
||||
label: '胆囊',
|
||||
color: '#ff0000',
|
||||
segmentation: [[64, 36, 192, 36, 192, 108]],
|
||||
bbox: [64, 36, 128, 72],
|
||||
}],
|
||||
});
|
||||
});
|
||||
|
||||
fireEvent.click(screen.getByRole('button', { name: '传播片段' }));
|
||||
|
||||
await waitFor(() => expect(apiMock.propagateMasks).toHaveBeenCalledWith({
|
||||
project_id: 1,
|
||||
frame_id: 10,
|
||||
model: 'sam2',
|
||||
direction: 'forward',
|
||||
max_frames: 30,
|
||||
include_source: false,
|
||||
save_annotations: true,
|
||||
seed: {
|
||||
polygons: [[[0.1, 0.1], [0.3, 0.1], [0.3, 0.3]]],
|
||||
bbox: [0.1, 0.1, 0.2, 0.2],
|
||||
points: undefined,
|
||||
label: '胆囊',
|
||||
color: '#ff0000',
|
||||
class_metadata: { id: 'c1', name: '胆囊', color: '#ff0000', zIndex: 20 },
|
||||
template_id: 2,
|
||||
},
|
||||
}));
|
||||
await waitFor(() => expect(screen.getByText('已传播并保存 2 个区域')).toBeInTheDocument());
|
||||
});
|
||||
});
|
||||
|
||||
@@ -12,6 +12,7 @@ import {
|
||||
getTemplates,
|
||||
importGtMask,
|
||||
parseMedia,
|
||||
propagateMasks,
|
||||
saveAnnotation,
|
||||
updateAnnotation,
|
||||
} from '../lib/api';
|
||||
@@ -37,6 +38,8 @@ export function VideoWorkspace({ onNavigateToAI }: { onNavigateToAI?: () => void
|
||||
const maskHistory = useStore((state) => state.maskHistory);
|
||||
const maskFuture = useStore((state) => state.maskFuture);
|
||||
const activeTemplateId = useStore((state) => state.activeTemplateId);
|
||||
const aiModel = useStore((state) => state.aiModel);
|
||||
const selectedMaskIds = useStore((state) => state.selectedMaskIds);
|
||||
const setFrames = useStore((state) => state.setFrames);
|
||||
const setCurrentFrame = useStore((state) => state.setCurrentFrame);
|
||||
const setMasks = useStore((state) => state.setMasks);
|
||||
@@ -45,6 +48,7 @@ export function VideoWorkspace({ onNavigateToAI }: { onNavigateToAI?: () => void
|
||||
const [isSaving, setIsSaving] = useState(false);
|
||||
const [isExporting, setIsExporting] = useState(false);
|
||||
const [isImportingGt, setIsImportingGt] = useState(false);
|
||||
const [isPropagating, setIsPropagating] = useState(false);
|
||||
const [statusMessage, setStatusMessage] = useState('');
|
||||
|
||||
const hydrateSavedAnnotations = useCallback(async (projectId: string, projectFrames: Frame[]) => {
|
||||
@@ -102,6 +106,8 @@ export function VideoWorkspace({ onNavigateToAI }: { onNavigateToAI?: () => void
|
||||
url: f.image_url,
|
||||
width: f.width ?? 0,
|
||||
height: f.height ?? 0,
|
||||
timestampMs: f.timestamp_ms ?? undefined,
|
||||
sourceFrameNumber: f.source_frame_number ?? undefined,
|
||||
}));
|
||||
setFrames(mappedFrames);
|
||||
setCurrentFrame(0);
|
||||
@@ -117,6 +123,8 @@ export function VideoWorkspace({ onNavigateToAI }: { onNavigateToAI?: () => void
|
||||
url: f.image_url,
|
||||
width: f.width ?? 0,
|
||||
height: f.height ?? 0,
|
||||
timestampMs: f.timestamp_ms ?? undefined,
|
||||
sourceFrameNumber: f.source_frame_number ?? undefined,
|
||||
}));
|
||||
setFrames(mappedFrames);
|
||||
setCurrentFrame(0);
|
||||
@@ -314,6 +322,55 @@ export function VideoWorkspace({ onNavigateToAI }: { onNavigateToAI?: () => void
|
||||
}
|
||||
};
|
||||
|
||||
const handlePropagateSegment = async () => {
|
||||
if (!currentProject?.id || !currentFrame?.id) return;
|
||||
const currentFrameMasks = masks.filter((mask) => mask.frameId === currentFrame.id);
|
||||
const selectedMask = selectedMaskIds
|
||||
.map((id) => currentFrameMasks.find((mask) => mask.id === id))
|
||||
.find((mask): mask is NonNullable<typeof mask> => Boolean(mask));
|
||||
const seedMask = selectedMask || currentFrameMasks[0];
|
||||
if (!seedMask) {
|
||||
setStatusMessage('请先选择或创建一个当前帧区域');
|
||||
return;
|
||||
}
|
||||
|
||||
const seedPayload = buildAnnotationPayload(currentProject.id, seedMask, currentFrame, activeTemplateId);
|
||||
if (!seedPayload?.mask_data?.polygons?.length && !seedPayload?.bbox) {
|
||||
setStatusMessage('当前区域缺少可传播的 polygon 或 bbox');
|
||||
return;
|
||||
}
|
||||
|
||||
setIsPropagating(true);
|
||||
setStatusMessage(`${aiModel.toUpperCase()} 正在传播当前区域...`);
|
||||
try {
|
||||
const result = await propagateMasks({
|
||||
project_id: Number(currentProject.id),
|
||||
frame_id: Number(currentFrame.id),
|
||||
model: aiModel,
|
||||
direction: 'forward',
|
||||
max_frames: 30,
|
||||
include_source: false,
|
||||
save_annotations: true,
|
||||
seed: {
|
||||
polygons: seedPayload.mask_data?.polygons,
|
||||
bbox: seedPayload.bbox,
|
||||
points: seedPayload.points,
|
||||
label: seedPayload.mask_data?.label,
|
||||
color: seedPayload.mask_data?.color,
|
||||
class_metadata: seedPayload.mask_data?.class,
|
||||
template_id: seedPayload.template_id,
|
||||
},
|
||||
});
|
||||
await hydrateSavedAnnotations(currentProject.id, frames);
|
||||
setStatusMessage(`已传播并保存 ${result.created_annotation_count} 个区域`);
|
||||
} catch (err) {
|
||||
console.error('Propagation failed:', err);
|
||||
setStatusMessage('传播失败,请检查模型状态或后端日志');
|
||||
} finally {
|
||||
setIsPropagating(false);
|
||||
}
|
||||
};
|
||||
|
||||
return (
|
||||
<div className="w-full h-full flex flex-col bg-[#0a0a0a]">
|
||||
{/* Top Header / Status bar */}
|
||||
@@ -339,28 +396,35 @@ export function VideoWorkspace({ onNavigateToAI }: { onNavigateToAI?: () => void
|
||||
/>
|
||||
<button
|
||||
onClick={() => gtMaskInputRef.current?.click()}
|
||||
disabled={!currentProject?.id || !currentFrame?.id || isImportingGt || isSaving || isExporting}
|
||||
disabled={!currentProject?.id || !currentFrame?.id || isImportingGt || isSaving || isExporting || isPropagating}
|
||||
className="px-4 py-1.5 bg-white/5 hover:bg-white/10 border border-white/10 rounded-md text-xs transition-colors text-white disabled:opacity-40 disabled:cursor-not-allowed"
|
||||
>
|
||||
{isImportingGt ? '导入中...' : '导入 GT Mask'}
|
||||
</button>
|
||||
<button
|
||||
onClick={handlePropagateSegment}
|
||||
disabled={!currentProject?.id || !currentFrame?.id || isSaving || isExporting || isImportingGt || isPropagating}
|
||||
className="px-4 py-1.5 bg-white/5 hover:bg-white/10 border border-white/10 rounded-md text-xs transition-colors text-white disabled:opacity-40 disabled:cursor-not-allowed"
|
||||
>
|
||||
{isPropagating ? '传播中...' : '传播片段'}
|
||||
</button>
|
||||
<button
|
||||
onClick={handleExportMasks}
|
||||
disabled={!currentProject?.id || isExporting || isSaving}
|
||||
disabled={!currentProject?.id || isExporting || isSaving || isPropagating}
|
||||
className="px-4 py-1.5 bg-white/5 hover:bg-white/10 border border-white/10 rounded-md text-xs transition-colors text-white disabled:opacity-40 disabled:cursor-not-allowed"
|
||||
>
|
||||
{isExporting ? '导出中...' : '导出 PNG Mask ZIP'}
|
||||
</button>
|
||||
<button
|
||||
onClick={handleExport}
|
||||
disabled={!currentProject?.id || isExporting || isSaving}
|
||||
disabled={!currentProject?.id || isExporting || isSaving || isPropagating}
|
||||
className="px-4 py-1.5 bg-white/5 hover:bg-white/10 border border-white/10 rounded-md text-xs transition-colors text-white disabled:opacity-40 disabled:cursor-not-allowed"
|
||||
>
|
||||
{isExporting ? '导出中...' : '导出 JSON 标注集'}
|
||||
</button>
|
||||
<button
|
||||
onClick={handleSave}
|
||||
disabled={!currentProject?.id || isSaving || isExporting}
|
||||
disabled={!currentProject?.id || isSaving || isExporting || isPropagating}
|
||||
className="px-4 py-1.5 bg-cyan-600 hover:bg-cyan-500 text-white text-xs font-medium rounded-md transition-shadow shadow-lg shadow-cyan-900/20 disabled:opacity-40 disabled:cursor-not-allowed"
|
||||
>
|
||||
{isSaving ? '保存中...' : '结构化归档保存'}
|
||||
|
||||
@@ -159,9 +159,9 @@ describe('api client contracts', () => {
|
||||
axiosMock.client.post.mockResolvedValueOnce({ data: { ...task, status: 'cancelled', progress: 100 } });
|
||||
axiosMock.client.post.mockResolvedValueOnce({ data: { ...task, id: 13, status: 'queued', progress: 0 } });
|
||||
|
||||
await expect(parseMedia('9')).resolves.toEqual(task);
|
||||
await expect(parseMedia('9', { parseFps: 15, maxFrames: 120, targetWidth: 960 })).resolves.toEqual(task);
|
||||
expect(axiosMock.client.post).toHaveBeenCalledWith('/api/media/parse', null, {
|
||||
params: { project_id: '9' },
|
||||
params: { project_id: '9', parse_fps: 15, max_frames: 120, target_width: 960 },
|
||||
});
|
||||
|
||||
await expect(getTask(12)).resolves.toEqual(expect.objectContaining({ status: 'success', progress: 100 }));
|
||||
@@ -175,7 +175,7 @@ describe('api client contracts', () => {
|
||||
});
|
||||
|
||||
it('lists, saves, updates, and deletes annotations with the backend annotation contract', async () => {
|
||||
const { deleteAnnotation, getProjectAnnotations, saveAnnotation, updateAnnotation } = await import('./api');
|
||||
const { deleteAnnotation, getProjectAnnotations, propagateMasks, saveAnnotation, updateAnnotation } = await import('./api');
|
||||
const saved = {
|
||||
id: 1,
|
||||
project_id: 9,
|
||||
@@ -221,6 +221,43 @@ describe('api client contracts', () => {
|
||||
axiosMock.client.delete.mockResolvedValueOnce({ data: null });
|
||||
await expect(deleteAnnotation('1')).resolves.toBeUndefined();
|
||||
expect(axiosMock.client.delete).toHaveBeenCalledWith('/api/ai/annotations/1');
|
||||
|
||||
axiosMock.client.post.mockResolvedValueOnce({
|
||||
data: {
|
||||
model: 'sam2',
|
||||
direction: 'forward',
|
||||
source_frame_id: 5,
|
||||
processed_frame_count: 3,
|
||||
created_annotation_count: 2,
|
||||
annotations: [saved],
|
||||
},
|
||||
});
|
||||
await expect(propagateMasks({
|
||||
project_id: 9,
|
||||
frame_id: 5,
|
||||
model: 'sam2',
|
||||
seed: {
|
||||
polygons: [[[0, 0], [1, 0], [1, 1]]],
|
||||
label: 'mask',
|
||||
color: '#06b6d4',
|
||||
},
|
||||
direction: 'forward',
|
||||
max_frames: 30,
|
||||
})).resolves.toEqual(expect.objectContaining({ created_annotation_count: 2 }));
|
||||
expect(axiosMock.client.post).toHaveBeenCalledWith('/api/ai/propagate', {
|
||||
project_id: 9,
|
||||
frame_id: 5,
|
||||
model: 'sam2',
|
||||
seed: {
|
||||
polygons: [[[0, 0], [1, 0], [1, 1]]],
|
||||
label: 'mask',
|
||||
color: '#06b6d4',
|
||||
},
|
||||
direction: 'forward',
|
||||
max_frames: 30,
|
||||
}, {
|
||||
timeout: 600000,
|
||||
});
|
||||
});
|
||||
|
||||
it('imports GT masks through multipart form data', async () => {
|
||||
@@ -377,6 +414,33 @@ describe('api client contracts', () => {
|
||||
});
|
||||
});
|
||||
|
||||
it('normalizes combined box and point prompts for interactive SAM2 refinement', async () => {
|
||||
const { predictMask } = await import('./api');
|
||||
axiosMock.client.post.mockResolvedValueOnce({ data: { polygons: [], scores: [] } });
|
||||
|
||||
await predictMask({
|
||||
imageId: '5',
|
||||
imageWidth: 640,
|
||||
imageHeight: 320,
|
||||
box: { x1: 64, y1: 32, x2: 320, y2: 160 },
|
||||
points: [
|
||||
{ x: 128, y: 64, type: 'pos' },
|
||||
{ x: 256, y: 128, type: 'neg' },
|
||||
],
|
||||
});
|
||||
|
||||
expect(axiosMock.client.post).toHaveBeenCalledWith('/api/ai/predict', {
|
||||
image_id: 5,
|
||||
prompt_type: 'interactive',
|
||||
prompt_data: {
|
||||
box: [0.1, 0.1, 0.5, 0.5],
|
||||
points: [[0.2, 0.2], [0.4, 0.4]],
|
||||
labels: [1, 0],
|
||||
},
|
||||
model: 'sam2',
|
||||
});
|
||||
});
|
||||
|
||||
it('uses semantic prompt type for text-only AI prediction', async () => {
|
||||
const { predictMask } = await import('./api');
|
||||
axiosMock.client.post.mockResolvedValueOnce({ data: { polygons: [], scores: [] } });
|
||||
|
||||
@@ -153,6 +153,8 @@ export async function getProjectFrames(projectId: string): Promise<Array<{
|
||||
image_url: string;
|
||||
width: number | null;
|
||||
height: number | null;
|
||||
timestamp_ms?: number | null;
|
||||
source_frame_number?: number | null;
|
||||
}>> {
|
||||
const response = await apiClient.get(`/api/projects/${projectId}/frames`);
|
||||
return response.data;
|
||||
@@ -185,9 +187,18 @@ export interface ProcessingTask {
|
||||
updated_at: string;
|
||||
}
|
||||
|
||||
export async function parseMedia(projectId: string): Promise<ProcessingTask> {
|
||||
export async function parseMedia(projectId: string, options: {
|
||||
parseFps?: number;
|
||||
maxFrames?: number;
|
||||
targetWidth?: number;
|
||||
} = {}): Promise<ProcessingTask> {
|
||||
const response = await apiClient.post('/api/media/parse', null, {
|
||||
params: { project_id: projectId },
|
||||
params: {
|
||||
project_id: projectId,
|
||||
...(options.parseFps ? { parse_fps: options.parseFps } : {}),
|
||||
...(options.maxFrames ? { max_frames: options.maxFrames } : {}),
|
||||
...(options.targetWidth ? { target_width: options.targetWidth } : {}),
|
||||
},
|
||||
});
|
||||
return response.data;
|
||||
}
|
||||
@@ -312,6 +323,40 @@ export interface SaveAnnotationPayload {
|
||||
|
||||
export type UpdateAnnotationPayload = Omit<SaveAnnotationPayload, 'project_id' | 'frame_id'>;
|
||||
|
||||
export interface PropagateMasksPayload {
|
||||
project_id: number;
|
||||
frame_id: number;
|
||||
model?: AiModelId;
|
||||
seed: {
|
||||
polygons?: number[][][];
|
||||
bbox?: number[];
|
||||
points?: number[][];
|
||||
label?: string;
|
||||
color?: string;
|
||||
class_metadata?: {
|
||||
id?: string;
|
||||
name?: string;
|
||||
color?: string;
|
||||
zIndex?: number;
|
||||
category?: string;
|
||||
};
|
||||
template_id?: number;
|
||||
};
|
||||
direction?: 'forward' | 'backward' | 'both';
|
||||
max_frames?: number;
|
||||
include_source?: boolean;
|
||||
save_annotations?: boolean;
|
||||
}
|
||||
|
||||
export interface PropagateMasksResult {
|
||||
model: AiModelId;
|
||||
direction: string;
|
||||
source_frame_id: number;
|
||||
processed_frame_count: number;
|
||||
created_annotation_count: number;
|
||||
annotations: SavedAnnotation[];
|
||||
}
|
||||
|
||||
export interface DashboardTask {
|
||||
id: string;
|
||||
task_id?: number;
|
||||
@@ -474,10 +519,22 @@ export function annotationToMask(annotation: SavedAnnotation, frame: Frame): Mas
|
||||
}
|
||||
|
||||
export async function predictMask(payload: PredictMaskPayload): Promise<PredictMaskResult> {
|
||||
let prompt_type: 'point' | 'box' | 'semantic';
|
||||
let prompt_type: 'point' | 'box' | 'semantic' | 'interactive';
|
||||
let prompt_data: unknown;
|
||||
|
||||
if (payload.box) {
|
||||
if (payload.box && payload.points && payload.points.length > 0) {
|
||||
prompt_type = 'interactive';
|
||||
prompt_data = {
|
||||
box: [
|
||||
clamp01(payload.box.x1 / Math.max(payload.imageWidth, 1)),
|
||||
clamp01(payload.box.y1 / Math.max(payload.imageHeight, 1)),
|
||||
clamp01(payload.box.x2 / Math.max(payload.imageWidth, 1)),
|
||||
clamp01(payload.box.y2 / Math.max(payload.imageHeight, 1)),
|
||||
],
|
||||
points: payload.points.map((point) => normalizePoint(point, payload.imageWidth, payload.imageHeight)),
|
||||
labels: payload.points.map((point) => (point.type === 'neg' ? 0 : 1)),
|
||||
};
|
||||
} else if (payload.box) {
|
||||
prompt_type = 'box';
|
||||
prompt_data = [
|
||||
clamp01(payload.box.x1 / Math.max(payload.imageWidth, 1)),
|
||||
@@ -540,6 +597,13 @@ export async function getProjectAnnotations(projectId: string, frameId?: string)
|
||||
return response.data;
|
||||
}
|
||||
|
||||
export async function propagateMasks(payload: PropagateMasksPayload): Promise<PropagateMasksResult> {
|
||||
const response = await apiClient.post('/api/ai/propagate', payload, {
|
||||
timeout: 600000,
|
||||
});
|
||||
return response.data;
|
||||
}
|
||||
|
||||
export async function saveAnnotation(payload: SaveAnnotationPayload): Promise<SavedAnnotation> {
|
||||
const response = await apiClient.post('/api/ai/annotate', payload);
|
||||
return response.data;
|
||||
|
||||
@@ -30,6 +30,7 @@ describe('useStore', () => {
|
||||
useStore.getState().setFrames([{ id: 'f1', projectId: '1', index: 0, url: '/f1.jpg', width: 640, height: 360 }]);
|
||||
useStore.getState().setCurrentFrame(0);
|
||||
useStore.getState().addMask({ id: 'm1', frameId: 'f1', pathData: 'M 0 0 Z', label: 'mask', color: '#fff' });
|
||||
useStore.getState().setSelectedMaskIds(['m1']);
|
||||
useStore.getState().updateMask('m1', { label: 'updated mask', saveStatus: 'dirty' });
|
||||
useStore.getState().addAnnotation({ id: 'a1', frameId: 'f1', type: 'mask', points: [], label: 'ann', color: '#fff' });
|
||||
useStore.getState().addTemplate({ id: 't1', name: 'Template', classes: [], rules: [] });
|
||||
@@ -40,6 +41,7 @@ describe('useStore', () => {
|
||||
expect(useStore.getState().currentProject?.id).toBe('1');
|
||||
expect(useStore.getState().frames).toHaveLength(1);
|
||||
expect(useStore.getState().currentFrameIndex).toBe(0);
|
||||
expect(useStore.getState().selectedMaskIds).toEqual(['m1']);
|
||||
expect(useStore.getState().masks[0]).toEqual(expect.objectContaining({ label: 'updated mask', saveStatus: 'dirty' }));
|
||||
expect(useStore.getState().annotations).toHaveLength(1);
|
||||
expect(useStore.getState().templates[0].name).toBe('Template 2');
|
||||
@@ -51,6 +53,7 @@ describe('useStore', () => {
|
||||
|
||||
expect(useStore.getState().annotations).toEqual([]);
|
||||
expect(useStore.getState().masks).toEqual([]);
|
||||
expect(useStore.getState().selectedMaskIds).toEqual([]);
|
||||
expect(useStore.getState().templates).toEqual([]);
|
||||
});
|
||||
|
||||
|
||||
@@ -27,6 +27,8 @@ export interface Frame {
|
||||
width: number;
|
||||
height: number;
|
||||
timestamp?: string;
|
||||
timestampMs?: number;
|
||||
sourceFrameNumber?: number;
|
||||
}
|
||||
|
||||
export interface Annotation {
|
||||
@@ -112,6 +114,7 @@ export interface AppState {
|
||||
currentFrameIndex: number;
|
||||
annotations: Annotation[];
|
||||
masks: Mask[];
|
||||
selectedMaskIds: string[];
|
||||
maskHistory: Mask[][];
|
||||
maskFuture: Mask[][];
|
||||
setActiveModule: (module: string) => void;
|
||||
@@ -123,6 +126,7 @@ export interface AppState {
|
||||
addMask: (mask: Mask) => void;
|
||||
updateMask: (id: string, updates: Partial<Mask>) => void;
|
||||
setMasks: (masks: Mask[]) => void;
|
||||
setSelectedMaskIds: (ids: string[]) => void;
|
||||
clearMasks: () => void;
|
||||
undoMasks: () => void;
|
||||
redoMasks: () => void;
|
||||
@@ -167,6 +171,7 @@ export const useStore = create<AppState>((set) => ({
|
||||
frames: [],
|
||||
annotations: [],
|
||||
masks: [],
|
||||
selectedMaskIds: [],
|
||||
maskHistory: [],
|
||||
maskFuture: [],
|
||||
activeTemplateId: null,
|
||||
@@ -195,6 +200,7 @@ export const useStore = create<AppState>((set) => ({
|
||||
currentFrameIndex: 0,
|
||||
annotations: [],
|
||||
masks: [],
|
||||
selectedMaskIds: [],
|
||||
maskHistory: [],
|
||||
maskFuture: [],
|
||||
setActiveModule: (activeModule: string) => set({ activeModule }),
|
||||
@@ -227,9 +233,11 @@ export const useStore = create<AppState>((set) => ({
|
||||
maskFuture: [],
|
||||
};
|
||||
}),
|
||||
setSelectedMaskIds: (selectedMaskIds: string[]) => set({ selectedMaskIds }),
|
||||
clearMasks: () =>
|
||||
set((state) => ({
|
||||
masks: [],
|
||||
selectedMaskIds: [],
|
||||
maskHistory: [...state.maskHistory, state.masks],
|
||||
maskFuture: [],
|
||||
})),
|
||||
|
||||
@@ -71,7 +71,11 @@ vi.mock('react-konva', () => ({
|
||||
data-fill={props.fill}
|
||||
data-x={props.x}
|
||||
data-y={props.y}
|
||||
onClick={() => props.onClick?.({ cancelBubble: false })}
|
||||
onClick={(event) => {
|
||||
const konvaEvent = { cancelBubble: false };
|
||||
props.onClick?.(konvaEvent);
|
||||
if (konvaEvent.cancelBubble) event.stopPropagation();
|
||||
}}
|
||||
onMouseUp={(event: React.MouseEvent<HTMLSpanElement>) => props.onDragEnd?.({
|
||||
target: {
|
||||
x: () => event.clientX || props.x || 0,
|
||||
@@ -92,7 +96,12 @@ vi.mock('react-konva', () => ({
|
||||
data-testid="konva-path"
|
||||
data-path={props.data}
|
||||
data-fill={props.fill}
|
||||
onClick={() => props.onClick?.({ cancelBubble: false })}
|
||||
data-fill-rule={props.fillRule}
|
||||
onClick={(event) => {
|
||||
const konvaEvent = { cancelBubble: false };
|
||||
props.onClick?.(konvaEvent);
|
||||
if (konvaEvent.cancelBubble) event.stopPropagation();
|
||||
}}
|
||||
/>
|
||||
),
|
||||
}));
|
||||
|
||||
@@ -13,6 +13,7 @@ export function resetStore() {
|
||||
currentFrameIndex: 0,
|
||||
annotations: [],
|
||||
masks: [],
|
||||
selectedMaskIds: [],
|
||||
maskHistory: [],
|
||||
maskFuture: [],
|
||||
templates: [],
|
||||
|
||||
Reference in New Issue
Block a user