feat: 完善工作区交互提示与后端属性分析

功能新增:
- 新增 POST /api/ai/analyze-mask 后端接口,基于 mask polygon、bbox、points 和 score 返回置信度来源、面积、拓扑锚点和后端分析提示。
- 前端新增 analyzeMask API 封装,并在本体检查面板读取选中 mask 的后端几何属性和重新提取拓扑锚点结果。
- 右侧语义分类树点击分类时,会给当前选中 mask 换标签、更新 class 元数据,并将选中 mask 移到前端渲染最上层,方便继续编辑。
- 分割工作区画布新增上下文操作提示,覆盖多边形 Enter 完成、Esc 取消、首节点闭合、拖拽图形、点区域、SAM 点/框提示、区域合并/去除选择顺序和多边形编辑。
- AI 智能分割画布新增正向点、反向点、边界框选和视口控制的上下文提示。
- 自动传播交互收敛为参考帧加起止帧范围加单个“自动传播”按钮,默认使用当前参考帧全部 mask 作为 seed。
- 时间轴改为用浅蓝色进度条区段标记自动传播生成的帧,而不是已编辑帧竖线提示。

Bugfix:
- AI 分割页无当前帧时移除外部演示背景图,改为明确空状态提示,避免误以为外部图片可参与真实推理。
- 工具栏魔法棒文案改为“打开 AI 智能分割”,避免误导为直接触发 SAM 推理。
- Canvas 底部当前图层信息改为显示真实选中 mask 标签和 annotation id,不再使用固定占位文本。
- 已保存标注回显时保留 mask metadata 中的传播来源、score 等字段,供时间轴和属性面板识别。
- 清理 server.ts 中遗留的 /api/login、/api/projects、/api/templates 内存 mock API,避免和 FastAPI 真实后端混淆。

测试:
- 补充 analyze-mask 后端测试,覆盖后端几何属性和锚点返回。
- 补充 api.analyzeMask 前端契约测试,覆盖 normalized polygon、bbox、points 和 extract_skeleton payload。
- 补充本体面板测试,覆盖后端属性读取、自定义分类写回后端模板、选中 mask 换标签和置顶显示。
- 补充 Canvas 测试,覆盖上下文提示、多边形完成提示、布尔选择顺序提示、当前图层真实显示和编辑优先级。
- 补充 AI 分割测试,覆盖无帧空状态和提示工具上下文提示。
- 更新 Konva 测试 mock,支持拖动过程、stroke/dash/fillRule 等渲染断言。

文档:
- 更新 README 和 AGENTS,说明 server.ts 不再保留业务 mock API。
- 更新 doc/02、doc/03、doc/04、doc/05、doc/07、doc/08、doc/09,记录后端属性分析、分类置顶显示、上下文提示、自动传播按钮、传播帧标记、测试覆盖和当前剩余限制。
This commit is contained in:
2026-05-02 02:10:37 +08:00
parent 4c21de02f8
commit b6a276cb8d
28 changed files with 796 additions and 231 deletions

View File

@@ -42,6 +42,30 @@ describe('AISegmentation', () => {
expect(apiMock.getAiModelStatus).toHaveBeenCalledWith('sam2.1_hiera_tiny');
});
it('does not render the legacy upload-replace-background mock button', () => {
render(<AISegmentation onSendToWorkspace={vi.fn()} />);
expect(screen.queryByText('上传替换底图')).not.toBeInTheDocument();
});
it('shows an empty state instead of a demo image when no project frame is selected', () => {
useStore.setState({ frames: [] });
render(<AISegmentation onSendToWorkspace={vi.fn()} />);
expect(screen.getByText('请先在项目库选择项目并生成帧')).toBeInTheDocument();
});
it('shows contextual guidance for prompt tools', () => {
render(<AISegmentation onSendToWorkspace={vi.fn()} />);
fireEvent.click(screen.getByText('正向选点'));
expect(screen.getByText(/点击目标内部添加正向点/)).toBeInTheDocument();
fireEvent.click(screen.getByText('边界框选'));
expect(screen.getByText(/按住并拖拽建立框选区域/)).toBeInTheDocument();
});
it('passes enabled inference parameters to the backend', async () => {
apiMock.predictMask.mockResolvedValueOnce({ masks: [] });
render(<AISegmentation onSendToWorkspace={vi.fn()} />);

View File

@@ -1,5 +1,5 @@
import React, { useState, useCallback, useEffect } from 'react';
import { Target, PlusCircle, MinusCircle, SquareDashed, Sparkles, SendToBack, Image as ImageIcon, Undo, Redo, Loader2, XCircle, Trash2 } from 'lucide-react';
import { Target, PlusCircle, MinusCircle, SquareDashed, Sparkles, SendToBack, Undo, Redo, Loader2, XCircle, Trash2 } from 'lucide-react';
import { cn } from '../lib/utils';
import { Stage, Layer, Image as KonvaImage, Circle, Path, Group, Rect } from 'react-konva';
import useImage from 'use-image';
@@ -13,6 +13,7 @@ interface AISegmentationProps {
type PromptPoint = { x: number; y: number; type: 'pos' | 'neg' };
type PromptBox = { x1: number; y1: number; x2: number; y2: number };
type ToolHint = { title: string; body: string };
export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
const storeActiveTool = useStore((state) => state.activeTool);
@@ -49,8 +50,7 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
const [boxCurrent, setBoxCurrent] = useState<{ x: number; y: number } | null>(null);
const [cursorPos, setCursorPos] = useState({ x: 0, y: 0 });
const currentFrame = frames[currentFrameIndex] || null;
const previewUrl = currentFrame?.url || 'https://images.unsplash.com/photo-1549317661-bd32c8ce0be2?q=80&w=2070&auto=format&fit=crop';
const [image] = useImage(previewUrl);
const [image] = useImage(currentFrame?.url || '');
const aiMaskIdSet = new Set(aiMaskIds);
const frameMasks = currentFrame
? masks.filter((mask) => mask.frameId === currentFrame.id && aiMaskIdSet.has(mask.id))
@@ -59,6 +59,33 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
const modelCanInfer = selectedModelStatus?.available ?? true;
const effectiveTool = storeActiveTool;
const toolHint = React.useMemo<ToolHint | null>(() => {
if (!currentFrame) return null;
if (effectiveTool === 'point_pos') {
return {
title: '正向选点',
body: '点击目标内部添加正向点;点击已有提示点可删除。完成提示后点击“执行高精度语义分割”。',
};
}
if (effectiveTool === 'point_neg') {
return {
title: '反向选点',
body: '点击不应包含的区域添加反向点;可和框选/正向点一起使用来细化结果。',
};
}
if (effectiveTool === 'box_select') {
return {
title: promptBox ? '边界框已建立' : '边界框选',
body: promptBox
? '当前框会随推理一起发送;也可以继续添加正向/反向点细化。重新拖拽会替换框。'
: '按住并拖拽建立框选区域,松开后保留框,再点击“执行高精度语义分割”。',
};
}
if (effectiveTool === 'move') {
return { title: '视口控制', body: '拖拽移动画布,滚轮缩放;切回正向/反向点或框选后继续放置提示。' };
}
return null;
}, [currentFrame, effectiveTool, promptBox]);
const boxRect = React.useMemo(() => {
const activeBox = boxStart && boxCurrent
@@ -505,9 +532,6 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
<Redo size={14} />
</button>
<div className="w-px h-4 bg-white/10 mx-1"></div>
<button className="flex items-center gap-2 text-xs text-gray-400 hover:text-white transition-colors bg-white/5 hover:bg-white/10 px-3 py-1.5 rounded-md border border-white/5">
<ImageIcon size={14} />
</button>
<button
className="flex items-center gap-2 text-xs text-gray-400 hover:text-white transition-colors bg-white/5 hover:bg-white/10 px-3 py-1.5 rounded-md border border-white/5 disabled:opacity-30 disabled:hover:bg-white/5 disabled:hover:text-gray-400 disabled:cursor-not-allowed"
onClick={removeLastPromptPoint}
@@ -534,6 +558,17 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
<div className="flex-1 relative p-8">
<div className="w-full h-full relative border border-white/5 rounded shadow-2xl bg-[#1e1e1e] overflow-hidden cursor-crosshair">
{!currentFrame && (
<div className="absolute inset-0 z-20 flex items-center justify-center bg-[#151515] text-xs text-gray-500">
</div>
)}
{toolHint && (
<div className="absolute top-4 left-4 z-20 max-w-sm rounded-lg border border-cyan-400/20 bg-[#0d0d0d]/95 px-3 py-2 shadow-xl pointer-events-none">
<div className="text-[10px] font-semibold uppercase tracking-widest text-cyan-300">{toolHint.title}</div>
<div className="mt-1 text-xs leading-relaxed text-gray-300">{toolHint.body}</div>
</div>
)}
<Stage
width={window.innerWidth - 320 - 64}
height={window.innerHeight - 64 - 64}

View File

@@ -297,9 +297,10 @@ describe('CanvasArea', () => {
masks: [
{
id: 'm1',
annotationId: '42',
frameId: 'frame-1',
pathData: 'M 0 0 L 10 0 L 10 10 Z',
label: 'A',
label: '胆囊',
color: '#fff',
segmentation: [[0, 0, 10, 0, 10, 10]],
},
@@ -310,6 +311,7 @@ describe('CanvasArea', () => {
fireEvent.click(screen.getByTestId('konva-path'));
await waitFor(() => expect(useStore.getState().selectedMaskIds).toEqual(['m1']));
expect(screen.getByText('当前图层: 胆囊 #42')).toBeInTheDocument();
});
it('keeps a mask selected when opening the workspace polygon editor from AI results', () => {
@@ -389,6 +391,37 @@ describe('CanvasArea', () => {
}));
});
it('moves a polygon vertex directly while dragging without a prior vertex click', () => {
useStore.setState({
selectedMaskIds: ['draft-1'],
masks: [
{
id: 'draft-1',
frameId: 'frame-1',
pathData: 'M 10 10 L 90 10 L 90 40 Z',
label: 'Draft',
color: '#06b6d4',
saveStatus: 'draft',
segmentation: [[10, 10, 90, 10, 90, 40]],
bbox: [10, 10, 80, 30],
},
],
});
render(<CanvasArea activeTool="edit_polygon" frame={frame} />);
const handles = screen.getAllByTestId('konva-circle')
.filter((element) => element.getAttribute('data-fill') === '#ffffff');
fireEvent.mouseDown(handles[0]);
fireEvent.mouseMove(handles[0], { clientX: 25, clientY: 35 });
expect(useStore.getState().masks[0]).toEqual(expect.objectContaining({
pathData: 'M 25 35 L 90 10 L 90 40 Z',
segmentation: [[25, 35, 90, 10, 90, 40]],
saveStatus: 'draft',
}));
});
it('deletes a selected polygon vertex without dropping below three points', () => {
useStore.setState({
masks: [
@@ -626,6 +659,11 @@ describe('CanvasArea', () => {
const paths = screen.getAllByTestId('konva-path');
fireEvent.click(paths[0]);
fireEvent.click(paths[1]);
const selectedPaths = screen.getAllByTestId('konva-path');
expect(selectedPaths[0]).toHaveAttribute('data-stroke', '#facc15');
expect(selectedPaths[0]).toHaveAttribute('data-dash', '');
expect(selectedPaths[1]).toHaveAttribute('data-stroke', '#fb7185');
expect(selectedPaths[1]).toHaveAttribute('data-dash', '6,4');
fireEvent.click(screen.getByRole('button', { name: '从主区域去除' }));
expect(useStore.getState().masks).toHaveLength(2);
@@ -796,9 +834,11 @@ describe('CanvasArea', () => {
it('finalizes a clicked polygon with Enter', () => {
render(<CanvasArea activeTool="create_polygon" frame={frame} />);
const stage = screen.getByTestId('konva-stage');
expect(screen.getByText(/点击画布添加顶点/)).toBeInTheDocument();
fireEvent.click(stage, { clientX: 120, clientY: 80 });
fireEvent.click(stage, { clientX: 220, clientY: 80 });
fireEvent.click(stage, { clientX: 180, clientY: 160 });
expect(screen.getByText(/点击黄色首节点或按 Enter 闭合完成/)).toBeInTheDocument();
fireEvent.keyDown(window, { key: 'Enter' });
expect(useStore.getState().masks).toHaveLength(1);
@@ -831,6 +871,35 @@ describe('CanvasArea', () => {
expect(screen.queryAllByTestId('konva-circle')).toHaveLength(0);
});
it('shows contextual guidance for boolean selection ordering', () => {
useStore.setState({
masks: [
{
id: 'm1',
frameId: 'frame-1',
pathData: 'M 10 10 L 90 10 L 90 50 Z',
label: 'A',
color: '#06b6d4',
segmentation: [[10, 10, 90, 10, 90, 50]],
},
{
id: 'm2',
frameId: 'frame-1',
pathData: 'M 50 30 L 120 30 L 120 80 Z',
label: 'B',
color: '#ff0000',
segmentation: [[50, 30, 120, 30, 120, 80]],
},
],
});
render(<CanvasArea activeTool="area_remove" frame={frame} />);
expect(screen.getByText(/先点击要保留的主区域/)).toBeInTheDocument();
fireEvent.click(screen.getAllByTestId('konva-path')[0]);
expect(screen.getByText(/第一个是保留主区域/)).toBeInTheDocument();
});
it('applies the selected class to current-frame masks and marks saved masks dirty', () => {
useStore.setState({
activeTemplateId: '2',

View File

@@ -16,6 +16,7 @@ interface CanvasAreaProps {
type CanvasPoint = { x: number; y: number };
type PromptPoint = CanvasPoint & { type: 'pos' | 'neg' };
type PromptBox = { x1: number; y1: number; x2: number; y2: number };
type ToolHint = { title: string; body: string };
const DRAG_MANUAL_TOOLS = new Set(['create_rectangle', 'create_circle', 'create_line']);
const POLYGON_TOOL = 'create_polygon';
@@ -282,6 +283,81 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
const dirtyMaskCount = frameMasks.filter((mask) => mask.saveStatus === 'dirty').length;
const isBooleanTool = BOOLEAN_TOOLS.has(effectiveTool);
const isPolygonEditTool = effectiveTool === 'move' || effectiveTool === EDIT_POLYGON_TOOL;
const currentLayerLabel = selectedMask
? `${selectedMask.className || selectedMask.label}${selectedMask.annotationId ? ` #${selectedMask.annotationId}` : ' (未保存)'}`
: '未选择';
const toolHint = React.useMemo<ToolHint | null>(() => {
if (!frame) return null;
if (effectiveTool === POLYGON_TOOL) {
if (polygonPoints.length === 0) {
return {
title: '创建多边形',
body: '点击画布添加顶点;至少 3 个点后,点击首节点或按 Enter 完成,按 Esc 取消。',
};
}
if (polygonPoints.length < 3) {
return {
title: `创建多边形 · 已放置 ${polygonPoints.length}`,
body: '继续点击添加顶点;满 3 个点后才能闭合,按 Esc 可取消当前多边形。',
};
}
return {
title: `创建多边形 · 已放置 ${polygonPoints.length}`,
body: '点击黄色首节点或按 Enter 闭合完成;按 Esc 放弃当前多边形。',
};
}
if (effectiveTool === 'create_rectangle') {
return { title: '创建矩形', body: '按住并拖拽框出区域,松开鼠标后生成 mask切换工具可放弃当前操作。' };
}
if (effectiveTool === 'create_circle') {
return { title: '创建圆形', body: '按住并拖拽确定外接范围,松开鼠标后生成椭圆 mask。' };
}
if (effectiveTool === 'create_line') {
return { title: '创建线段', body: '按住并拖拽画出线段,松开后生成有宽度的线状 mask。' };
}
if (effectiveTool === POINT_TOOL) {
return { title: '创建点区域', body: '点击画布创建一个小型点区域;也可以在已有 mask 上继续落点。' };
}
if (effectiveTool === 'box_select') {
return {
title: samPromptBox ? '边界框已建立' : '边界框选',
body: samPromptBox
? '继续添加正向/反向点可细化同一个候选区域;重新拖拽会替换当前框。'
: '按住并拖拽建立框选区域,松开后会触发 SAM 推理。',
};
}
if (effectiveTool === 'point_pos') {
return { title: '正向选点', body: '点击目标内部添加正向点并触发细化;点击已有提示点可删除并重新推理。' };
}
if (effectiveTool === 'point_neg') {
return { title: '反向选点', body: '点击不应包含的区域添加反向点;点击已有提示点可删除并重新推理。' };
}
if (effectiveTool === 'area_merge') {
return {
title: '区域合并',
body: booleanSelectedMasks.length > 0
? `已选 ${booleanSelectedMasks.length} 个区域;第一个选中的是主区域,点击“合并选中”完成。`
: '依次点击多个 mask第一个选中的区域会作为合并后的主区域。',
};
}
if (effectiveTool === 'area_remove') {
return {
title: '重叠区域去除',
body: booleanSelectedMasks.length > 0
? `已选 ${booleanSelectedMasks.length} 个区域;第一个是保留主区域,后续区域会被扣除。`
: '先点击要保留的主区域,再点击要扣除的干涉区域。',
};
}
if (effectiveTool === EDIT_POLYGON_TOOL || (effectiveTool === 'move' && selectedMask)) {
return {
title: selectedMask ? '调整多边形' : '调整多边形',
body: selectedMask
? '可直接拖动白色顶点;点击青色边中点或双击边线新增顶点;选中顶点/区域后按 Delete 删除。'
: '点击一个 mask 后,可拖动顶点、点击边中点新增顶点,或按 Delete 删除选中区域。',
};
}
return null;
}, [booleanSelectedMasks.length, effectiveTool, frame, polygonPoints.length, samPromptBox, selectedMask]);
useEffect(() => {
const handleResize = () => {
@@ -860,7 +936,14 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
setSelectedVertexIndex(null);
};
const handleVertexDragEnd = (mask: Mask, vertexIndex: number, event: any) => {
const handleVertexDragStart = (mask: Mask, vertexIndex: number, event?: any) => {
if (event) event.cancelBubble = true;
setSelectedMaskId(mask.id);
setSelectedMaskIds([mask.id]);
setSelectedVertexIndex(vertexIndex);
};
const handleVertexDrag = (mask: Mask, vertexIndex: number, event: any) => {
const imageWidth = frame?.width || image?.naturalWidth || image?.width || stageSize.width;
const imageHeight = frame?.height || image?.naturalHeight || image?.height || stageSize.height;
const currentPoints = segmentationToPoints(mask.segmentation, selectedPolygonIndex);
@@ -874,6 +957,7 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
: point
));
setSelectedMaskId(mask.id);
setSelectedMaskIds([mask.id]);
setSelectedVertexIndex(vertexIndex);
updatePolygonMask(mask, nextPoints, selectedPolygonIndex);
};
@@ -977,6 +1061,12 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
{inferenceMessage}
</div>
)}
{toolHint && (
<div className="absolute top-4 left-4 z-20 max-w-sm rounded-lg border border-cyan-400/20 bg-[#0d0d0d]/95 px-3 py-2 shadow-xl pointer-events-none">
<div className="text-[10px] font-semibold uppercase tracking-widest text-cyan-300">{toolHint.title}</div>
<div className="mt-1 text-xs leading-relaxed text-gray-300">{toolHint.body}</div>
</div>
)}
<Stage
width={stageSize.width}
@@ -1005,6 +1095,16 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
{/* AI Returned Masks */}
{frameMasks.map((mask) => {
const selectedIndex = selectedMaskIds.indexOf(mask.id);
const isMaskSelected = selectedIndex >= 0;
const isBooleanPrimary = isBooleanTool && selectedIndex === 0;
const isBooleanSecondary = isBooleanTool && selectedIndex > 0;
const strokeColor = isBooleanPrimary
? '#facc15'
: isBooleanSecondary
? '#fb7185'
: mask.color;
const strokeDash = isBooleanSecondary ? [6 / scale, 4 / scale] : undefined;
const hasHoles = Boolean(mask.metadata?.hasHoles);
const paths = hasHoles
? [{ data: segmentationPath(mask.segmentation), polygonIndex: 0, fillRule: 'evenodd' }]
@@ -1014,15 +1114,16 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
fillRule: undefined,
}));
return (
<Group key={mask.id} opacity={selectedMaskIds.includes(mask.id) ? 0.65 : 0.5}>
<Group key={mask.id} opacity={isMaskSelected ? 0.65 : 0.5}>
{paths.map(({ data, polygonIndex, fillRule }) => (
<Path
key={`${mask.id}-polygon-${polygonIndex}`}
data={data}
fill={mask.color}
fillRule={fillRule}
stroke={mask.color}
strokeWidth={(selectedMaskIds.includes(mask.id) ? 2 : 1) / scale}
stroke={strokeColor}
strokeWidth={(isMaskSelected ? 2 : 1) / scale}
dash={strokeDash}
onClick={(event: any) => handleMaskSelect(mask, event, polygonIndex)}
onTap={(event: any) => handleMaskSelect(mask, event, polygonIndex)}
onDblClick={(event: any) => handlePathDoubleClick(mask, event, polygonIndex)}
@@ -1125,6 +1226,9 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
stroke={selectedMask.color}
strokeWidth={2 / scale}
draggable
onMouseDown={(event: any) => handleVertexDragStart(selectedMask, index, event)}
onTouchStart={(event: any) => handleVertexDragStart(selectedMask, index, event)}
onDragStart={(event: any) => handleVertexDragStart(selectedMask, index, event)}
onClick={(event: any) => {
event.cancelBubble = true;
setSelectedVertexIndex(index);
@@ -1133,7 +1237,8 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
event.cancelBubble = true;
setSelectedVertexIndex(index);
}}
onDragEnd={(event: any) => handleVertexDragEnd(selectedMask, index, event)}
onDragMove={(event: any) => handleVertexDrag(selectedMask, index, event)}
onDragEnd={(event: any) => handleVertexDrag(selectedMask, index, event)}
/>
))}
@@ -1163,7 +1268,7 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
<div className="absolute bottom-4 left-4 flex gap-4 text-[10px] font-mono text-gray-500 pointer-events-none">
<span>: {cursorPos.x.toFixed(2)}, {cursorPos.y.toFixed(2)}</span>
<span>当前图层树: OBJECT_VEHICLE_01</span>
<span>: {currentLayerLabel}</span>
<span>: {(scale * 100).toFixed(0)}%</span>
<span>: {frameMasks.length}</span>
<span>: {savedMaskCount}</span>

View File

@@ -51,7 +51,7 @@ describe('FrameTimeline', () => {
expect(screen.getAllByText('00:00.20').length).toBeGreaterThan(0);
});
it('overlays edited frame markers as amber vertical lines on the time progress bar', () => {
it('marks propagated frames as light-blue progress bar segments', () => {
useStore.setState({
currentFrameIndex: 1,
frames: [
@@ -61,20 +61,26 @@ describe('FrameTimeline', () => {
],
masks: [
{ id: 'm1', frameId: 'f2', pathData: 'M 0 0 Z', label: 'Draft', color: '#06b6d4' },
{ id: 'm2', frameId: 'f3', annotationId: '9', pathData: 'M 0 0 Z', label: 'Saved', color: '#22c55e' },
{
id: 'm2',
frameId: 'f3',
annotationId: '9',
pathData: 'M 0 0 Z',
label: 'Saved',
color: '#22c55e',
metadata: { source: 'sam2.1_hiera_tiny_propagation' },
},
{ id: 'outside', frameId: 'other-frame', pathData: 'M 0 0 Z', label: 'Other', color: '#fff' },
],
});
render(<FrameTimeline />);
expect(screen.getByText('已编辑 2 帧')).toBeInTheDocument();
expect(screen.getByText('自动传播 1 帧')).toBeInTheDocument();
expect(screen.queryByTestId('current-frame-line')).not.toBeInTheDocument();
expect(screen.getByLabelText('跳转到已编辑帧 2').className).toContain('before:bg-amber-300');
expect(screen.getByLabelText('跳转到已编辑帧 3').className).toContain('before:h-5');
expect(screen.getByLabelText('跳转到已编辑帧 3').className).not.toContain('h-2 w-2');
fireEvent.click(screen.getByLabelText('跳转到已编辑帧 3'));
expect(useStore.getState().currentFrameIndex).toBe(2);
expect(screen.getAllByTestId('propagated-frame-segment')).toHaveLength(1);
expect(screen.getByTestId('propagated-frame-segment').className).toContain('bg-sky-200');
expect(screen.queryByLabelText('跳转到已编辑帧 3')).not.toBeInTheDocument();
});
it('changes frames with left and right arrow keys without leaving bounds', () => {

View File

@@ -23,16 +23,20 @@ export function FrameTimeline() {
}, [currentProject?.original_fps, currentProject?.parse_fps]);
const currentSeconds = totalFrames > 0 ? currentFrameIndex / timeBaseFps : 0;
const totalSeconds = totalFrames > 0 ? Math.max(totalFrames - 1, 0) / timeBaseFps : 0;
const editedFrameMarkers = useMemo(() => {
const propagatedFrameMarkers = useMemo(() => {
const frameIds = new Set(frames.map((frame) => frame.id));
const editedIds = new Set(
const propagatedIds = new Set(
masks
.filter((mask) => frameIds.has(mask.frameId))
.filter((mask) => {
const source = typeof mask.metadata?.source === 'string' ? mask.metadata.source : '';
return source.includes('_propagation') || mask.metadata?.propagated_from_frame_id !== undefined;
})
.map((mask) => mask.frameId),
);
return frames
.map((frame, index) => ({ frame, index }))
.filter(({ frame }) => editedIds.has(frame.id));
.filter(({ frame }) => propagatedIds.has(frame.id));
}, [frames, masks]);
const formatTime = (seconds: number) => {
@@ -117,21 +121,16 @@ export function FrameTimeline() {
className="h-full bg-cyan-500 absolute left-0"
style={{ width: `${totalFrames > 0 ? (currentFrame / totalFrames) * 100 : 0}%` }}
/>
{editedFrameMarkers.map(({ frame, index }) => {
const left = totalFrames > 0 ? ((index + 1) / totalFrames) * 100 : 0;
{propagatedFrameMarkers.map(({ frame, index }) => {
const left = totalFrames > 0 ? (index / totalFrames) * 100 : 0;
const width = totalFrames > 0 ? 100 / totalFrames : 0;
return (
<button
<div
key={frame.id}
type="button"
aria-label={`跳转到已编辑${index + 1}`}
title={`已编辑帧 ${index + 1}`}
onClick={() => setCurrentFrame(index)}
className={cn(
"absolute left-0 top-1/2 z-30 w-3 -translate-x-1/2 -translate-y-1/2 cursor-pointer rounded-sm transition-all",
"before:absolute before:left-1/2 before:top-1/2 before:w-px before:-translate-x-1/2 before:-translate-y-1/2 before:rounded-full before:content-['']",
"before:h-5 before:bg-amber-300 before:shadow-[0_0_8px_rgba(251,191,36,0.5)] hover:before:h-7 hover:before:bg-amber-100"
)}
style={{ left: `${left}%` }}
data-testid="propagated-frame-segment"
title={`自动传播${index + 1}`}
className="absolute inset-y-0 z-10 bg-sky-200/80 shadow-[0_0_10px_rgba(186,230,253,0.55)]"
style={{ left: `${left}%`, width: `${width}%` }}
/>
);
})}
@@ -143,7 +142,7 @@ export function FrameTimeline() {
</div>
</div>
<div className="absolute bottom-0 right-3 text-[9px] font-mono text-gray-500 pointer-events-none">
{editedFrameMarkers.length}
{propagatedFrameMarkers.length}
</div>
</div>

View File

@@ -1,12 +1,33 @@
import { fireEvent, render, screen, within } from '@testing-library/react';
import { beforeEach, describe, expect, it } from 'vitest';
import { beforeEach, describe, expect, it, vi } from 'vitest';
import { resetStore } from '../test/storeTestUtils';
import { useStore } from '../store/useStore';
import { OntologyInspector } from './OntologyInspector';
const apiMock = vi.hoisted(() => ({
analyzeMask: vi.fn(),
updateTemplate: vi.fn(),
}));
vi.mock('../lib/api', () => ({
analyzeMask: apiMock.analyzeMask,
updateTemplate: apiMock.updateTemplate,
}));
describe('OntologyInspector', () => {
beforeEach(() => {
resetStore();
vi.clearAllMocks();
apiMock.analyzeMask.mockResolvedValue({
confidence: 0.82,
confidence_source: 'model_score',
topology_anchor_count: 4,
topology_anchors: [],
area: 0.1,
bbox: [0, 0, 0.1, 0.1],
source: 'sam2.1_hiera_tiny',
message: '已读取后端几何属性',
});
useStore.setState({
templates: [
{
@@ -49,6 +70,14 @@ describe('OntologyInspector', () => {
useStore.setState({
selectedMaskIds: ['m1'],
masks: [
{
id: 'm2',
frameId: 'frame-1',
pathData: 'M 10 10 Z',
label: '未选区域',
color: '#ffffff',
saveStatus: 'draft',
},
{
id: 'm1',
annotationId: '99',
@@ -66,7 +95,8 @@ describe('OntologyInspector', () => {
fireEvent.click(screen.getByText('肝脏'));
expect(useStore.getState().activeClassId).toBe('c2');
expect(useStore.getState().masks[0]).toEqual(expect.objectContaining({
expect(useStore.getState().masks.map((mask) => mask.id)).toEqual(['m2', 'm1']);
expect(useStore.getState().masks[1]).toEqual(expect.objectContaining({
templateId: 't1',
classId: 'c2',
className: '肝脏',
@@ -80,16 +110,59 @@ describe('OntologyInspector', () => {
expect(screen.getByText('1')).toBeInTheDocument();
});
it('adds custom classes locally without backend persistence', () => {
const { container } = render(<OntologyInspector />);
it('persists custom classes to the active backend template', async () => {
apiMock.updateTemplate.mockResolvedValueOnce({
id: 't1',
name: '腹腔镜模板',
classes: [
{ id: 'c1', name: '胆囊', color: '#ff0000', zIndex: 20, category: '器官' },
{ id: 'c2', name: '肝脏', color: '#00ff00', zIndex: 10, category: '器官' },
{ id: 'custom-1', name: '新局部分类', color: '#06b6d4', zIndex: 30, category: '自定义' },
],
rules: [],
});
render(<OntologyInspector />);
fireEvent.change(screen.getByRole('combobox'), { target: { value: 't1' } });
const customSection = screen.getByText('自定义分类').parentElement!;
fireEvent.click(within(customSection).getByRole('button'));
fireEvent.change(screen.getByPlaceholderText('分类名称'), { target: { value: '新局部分类' } });
fireEvent.keyDown(screen.getByPlaceholderText('分类名称'), { key: 'Enter' });
expect(screen.getAllByText('新局部分类')).toHaveLength(2);
expect(await screen.findByText('自定义分类已保存到后端模板')).toBeInTheDocument();
expect(apiMock.updateTemplate).toHaveBeenCalledWith('t1', expect.objectContaining({
classes: expect.arrayContaining([expect.objectContaining({ name: '新局部分类', category: '自定义' })]),
}));
expect(useStore.getState().activeClass).toEqual(expect.objectContaining({ name: '新局部分类' }));
expect(useStore.getState().templates[0].classes).toHaveLength(2);
expect(container).toHaveTextContent('2 个分类来自模板 + 1 个自定义');
expect(useStore.getState().templates[0].classes).toHaveLength(3);
});
it('loads selected mask properties from the backend analyzer', async () => {
useStore.setState({
frames: [{ id: 'frame-1', projectId: 'p1', index: 0, url: '/1.jpg', width: 100, height: 100 }],
selectedMaskIds: ['m1'],
masks: [
{
id: 'm1',
frameId: 'frame-1',
pathData: 'M 0 0 Z',
label: '胆囊',
color: '#ff0000',
segmentation: [[10, 10, 20, 10, 20, 20]],
metadata: { source: 'sam2.1_hiera_tiny', score: 0.82 },
},
],
});
render(<OntologyInspector />);
expect(await screen.findByText('0.8200')).toBeInTheDocument();
expect(screen.getByText('4 节点')).toBeInTheDocument();
fireEvent.click(screen.getByRole('button', { name: '重新提取拓扑锚点' }));
expect(apiMock.analyzeMask).toHaveBeenLastCalledWith(
expect.objectContaining({ id: 'm1' }),
expect.objectContaining({ id: 'frame-1' }),
{ extractSkeleton: true },
);
});
});

View File

@@ -1,30 +1,39 @@
import React, { useState } from 'react';
import { Layers, ChevronDown, Tag, Eye, Plus, X } from 'lucide-react';
import { Layers, ChevronDown, Tag, Eye, Plus, X, Loader2 } from 'lucide-react';
import { useStore } from '../store/useStore';
import type { TemplateClass } from '../store/useStore';
import { cn } from '../lib/utils';
import { getActiveTemplate } from '../lib/templateSelection';
import { analyzeMask, updateTemplate, type MaskAnalysisResult } from '../lib/api';
export function OntologyInspector() {
const templates = useStore((state) => state.templates);
const activeTemplateId = useStore((state) => state.activeTemplateId);
const activeClassId = useStore((state) => state.activeClassId);
const activeClass = useStore((state) => state.activeClass);
const frames = useStore((state) => state.frames);
const currentFrameIndex = useStore((state) => state.currentFrameIndex);
const masks = useStore((state) => state.masks);
const selectedMaskIds = useStore((state) => state.selectedMaskIds);
const setMasks = useStore((state) => state.setMasks);
const updateTemplateStore = useStore((state) => state.updateTemplate);
const setActiveTemplateId = useStore((state) => state.setActiveTemplateId);
const setActiveClass = useStore((state) => state.setActiveClass);
// Project-level custom classes (in addition to template classes)
const [customClasses, setCustomClasses] = useState<TemplateClass[]>([]);
const [showAddForm, setShowAddForm] = useState(false);
const [newClassName, setNewClassName] = useState('');
const [newClassColor, setNewClassColor] = useState('#06b6d4');
const [isSavingClass, setIsSavingClass] = useState(false);
const [classSaveMessage, setClassSaveMessage] = useState('');
const [maskAnalysis, setMaskAnalysis] = useState<MaskAnalysisResult | null>(null);
const [isAnalyzingMask, setIsAnalyzingMask] = useState(false);
const [analysisMessage, setAnalysisMessage] = useState('');
const activeTemplate = getActiveTemplate(templates, activeTemplateId);
const templateClasses = activeTemplate?.classes || [];
const allClasses = [...templateClasses, ...customClasses].sort((a, b) => b.zIndex - a.zIndex);
const allClasses = [...templateClasses].sort((a, b) => b.zIndex - a.zIndex);
const selectedMask = masks.find((mask) => selectedMaskIds.includes(mask.id)) || null;
const currentFrame = frames[currentFrameIndex] || null;
const handleSelectClass = (templateClass: TemplateClass) => {
if (activeTemplate && !activeTemplateId) {
@@ -36,7 +45,7 @@ export function OntologyInspector() {
if (!hasSelectedMasks) return;
const templateId = activeTemplate?.id || activeTemplateId || undefined;
setMasks(masks.map((mask) => {
const updatedMasks = masks.map((mask) => {
if (!selectedIdSet.has(mask.id)) return mask;
return {
...mask,
@@ -46,15 +55,53 @@ export function OntologyInspector() {
classZIndex: templateClass.zIndex,
label: templateClass.name,
color: templateClass.color,
saveStatus: mask.annotationId ? 'dirty' : 'draft',
saveStatus: mask.annotationId ? 'dirty' as const : 'draft' as const,
saved: mask.annotationId ? false : mask.saved,
};
}));
});
const selectedMasksOnTop = selectedMaskIds
.map((id) => updatedMasks.find((mask) => mask.id === id))
.filter((mask): mask is (typeof updatedMasks)[number] => Boolean(mask));
setMasks([
...updatedMasks.filter((mask) => !selectedIdSet.has(mask.id)),
...selectedMasksOnTop,
]);
};
const handleAddCustom = () => {
const refreshMaskAnalysis = async (extractSkeleton = false) => {
if (!selectedMask || !currentFrame) {
setMaskAnalysis(null);
setAnalysisMessage(selectedMask ? '当前帧信息不可用,无法读取后端属性' : '请选择一个 mask 查看后端属性');
return;
}
setIsAnalyzingMask(true);
setAnalysisMessage('');
try {
const result = await analyzeMask(selectedMask, currentFrame, { extractSkeleton });
setMaskAnalysis(result);
setAnalysisMessage(result.message);
} catch (err) {
console.error('Mask analysis failed:', err);
setMaskAnalysis(null);
setAnalysisMessage('后端属性读取失败');
} finally {
setIsAnalyzingMask(false);
}
};
React.useEffect(() => {
void refreshMaskAnalysis(false);
// selectedMask is intentionally tracked by id and geometry fields to avoid
// re-running analysis for unrelated store changes.
}, [selectedMask?.id, selectedMask?.segmentation, selectedMask?.points, currentFrame?.id]);
const handleAddCustom = async () => {
if (!newClassName.trim()) return;
const maxZ = allClasses.length > 0 ? Math.max(...allClasses.map((c) => c.zIndex)) : 0;
if (!activeTemplate) {
setClassSaveMessage('请先选择一个模板');
return;
}
const maxZ = templateClasses.length > 0 ? Math.max(...templateClasses.map((c) => c.zIndex)) : 0;
const newClass: TemplateClass = {
id: `custom-${Date.now()}`,
name: newClassName.trim(),
@@ -62,10 +109,27 @@ export function OntologyInspector() {
zIndex: maxZ + 10,
category: '自定义',
};
setCustomClasses([...customClasses, newClass]);
handleSelectClass(newClass);
setNewClassName('');
setShowAddForm(false);
setIsSavingClass(true);
setClassSaveMessage('');
try {
const updated = await updateTemplate(activeTemplate.id, {
name: activeTemplate.name,
description: activeTemplate.description,
classes: [...templateClasses, newClass],
rules: activeTemplate.rules || [],
});
updateTemplateStore(updated);
setActiveTemplateId(updated.id);
handleSelectClass(newClass);
setNewClassName('');
setShowAddForm(false);
setClassSaveMessage('自定义分类已保存到后端模板');
} catch (err) {
console.error('Save custom class failed:', err);
setClassSaveMessage('自定义分类保存失败');
} finally {
setIsSavingClass(false);
}
};
return (
@@ -98,7 +162,6 @@ export function OntologyInspector() {
{activeTemplate && (
<div className="mt-2 text-[10px] text-gray-600">
{activeTemplate.classes?.length ?? 0}
{customClasses.length > 0 && ` + ${customClasses.length} 个自定义`}
</div>
)}
</div>
@@ -165,7 +228,7 @@ export function OntologyInspector() {
onKeyDown={(e) => e.key === 'Enter' && handleAddCustom()}
/>
<button onClick={handleAddCustom} className="text-cyan-400 hover:text-cyan-300">
<Plus size={14} />
{isSavingClass ? <Loader2 size={14} className="animate-spin" /> : <Plus size={14} />}
</button>
<button onClick={() => setShowAddForm(false)} className="text-gray-500 hover:text-gray-300">
<X size={14} />
@@ -173,6 +236,9 @@ export function OntologyInspector() {
</div>
</div>
)}
{classSaveMessage && (
<div className="mt-2 text-[10px] text-gray-500">{classSaveMessage}</div>
)}
</div>
{/* Current Active Object Properties */}
@@ -191,18 +257,30 @@ export function OntologyInspector() {
<span className="text-xs font-mono text-gray-300">{selectedMaskIds.length}</span>
</div>
<div className="space-y-1">
<label className="text-[10px] text-gray-500 uppercase"></label>
<label className="text-[10px] text-gray-500 uppercase"></label>
<div className="h-1.5 w-full bg-white/10 rounded-full overflow-hidden">
<div className="h-full bg-green-500 w-[94%]" />
<div
className="h-full bg-green-500"
style={{ width: `${Math.round((maskAnalysis?.confidence ?? 0) * 100)}%` }}
/>
</div>
<div className="text-[10px] font-mono text-green-500 text-right">
{maskAnalysis?.confidence != null ? maskAnalysis.confidence.toFixed(4) : '无模型分数'}
</div>
<div className="text-[10px] font-mono text-green-500 text-right">0.9412</div>
</div>
<div className="flex items-center justify-between">
<span className="text-[10px] text-gray-500 uppercase">:</span>
<span className="text-xs font-mono text-gray-300">12 </span>
<span className="text-[10px] text-gray-500 uppercase">:</span>
<span className="text-xs font-mono text-gray-300">{maskAnalysis?.topology_anchor_count ?? 0} </span>
</div>
<button className="w-full mt-2 bg-white/5 hover:bg-white/10 border border-white/10 text-xs text-gray-300 py-1.5 rounded transition-colors">
{analysisMessage && (
<div className="text-[10px] leading-relaxed text-gray-500">{analysisMessage}</div>
)}
<button
onClick={() => refreshMaskAnalysis(true)}
disabled={!selectedMask || isAnalyzingMask}
className="w-full mt-2 bg-white/5 hover:bg-white/10 border border-white/10 text-xs text-gray-300 py-1.5 rounded transition-colors disabled:opacity-40 disabled:cursor-not-allowed"
>
{isAnalyzingMask ? '提取中...' : '重新提取拓扑锚点'}
</button>
</div>
</div>

View File

@@ -37,7 +37,7 @@ describe('ToolsPalette', () => {
const onTriggerAI = vi.fn();
render(<ToolsPalette activeTool="move" setActiveTool={setActiveTool} onTriggerAI={onTriggerAI} />);
fireEvent.click(screen.getByTitle('触发 SAM 推理 (Enter)'));
fireEvent.click(screen.getByTitle('打开 AI 智能分割'));
expect(setActiveTool).toHaveBeenCalledWith('sam_trigger');
expect(onTriggerAI).toHaveBeenCalled();

View File

@@ -91,7 +91,7 @@ export function ToolsPalette({
setActiveTool('sam_trigger');
if (onTriggerAI) onTriggerAI();
}}
title="触发 SAM 推理 (Enter)"
title="打开 AI 智能分割"
className={cn(
"w-10 h-10 rounded-lg flex items-center justify-center transition-all",
activeTool === 'sam_trigger'

View File

@@ -380,7 +380,7 @@ describe('VideoWorkspace', () => {
]));
});
it('propagates the selected current-frame mask through the configured frame range', async () => {
it('auto-propagates reference-frame masks through the configured frame range', async () => {
apiMock.getProjectFrames.mockResolvedValueOnce([
{ id: 10, project_id: 1, frame_index: 0, image_url: '/frame.jpg', width: 640, height: 360 },
{ id: 11, project_id: 1, frame_index: 1, image_url: '/frame-1.jpg', width: 640, height: 360 },
@@ -404,7 +404,6 @@ describe('VideoWorkspace', () => {
useStore.setState({
aiModel: 'sam2.1_hiera_tiny',
activeTemplateId: '2',
selectedMaskIds: ['mask-1'],
masks: [{
id: 'mask-1',
frameId: '10',
@@ -417,7 +416,7 @@ describe('VideoWorkspace', () => {
});
});
fireEvent.click(screen.getByRole('button', { name: '按范围传播' }));
fireEvent.click(screen.getByRole('button', { name: '自动传播' }));
await waitFor(() => expect(apiMock.propagateMasks).toHaveBeenCalledWith({
project_id: 1,
@@ -437,10 +436,10 @@ describe('VideoWorkspace', () => {
template_id: 2,
},
}));
await waitFor(() => expect(screen.getByText('已传播 1 个 seed,处理 3 帧次,保存 2 个区域')).toBeInTheDocument());
await waitFor(() => expect(screen.getByText('已自动传播 1 个参考 mask,处理 3 帧次,保存 2 个区域')).toBeInTheDocument());
});
it('propagates all current-frame masks to all reachable frames in both directions', async () => {
it('auto-propagates all reference-frame masks in both directions inside the selected range', async () => {
apiMock.getProjectFrames.mockResolvedValueOnce([
{ id: 10, project_id: 1, frame_index: 0, image_url: '/frame-0.jpg', width: 640, height: 360 },
{ id: 11, project_id: 1, frame_index: 1, image_url: '/frame-1.jpg', width: 640, height: 360 },
@@ -502,8 +501,9 @@ describe('VideoWorkspace', () => {
});
});
fireEvent.change(screen.getByLabelText('传播对象'), { target: { value: 'all' } });
fireEvent.click(screen.getByRole('button', { name: '传播全部可达' }));
fireEvent.change(screen.getByLabelText('传播起始帧'), { target: { value: '1' } });
fireEvent.change(screen.getByLabelText('传播结束帧'), { target: { value: '3' } });
fireEvent.click(screen.getByRole('button', { name: '自动传播' }));
await waitFor(() => expect(apiMock.propagateMasks).toHaveBeenCalledTimes(4));
expect(apiMock.propagateMasks).toHaveBeenNthCalledWith(1, expect.objectContaining({
@@ -526,6 +526,6 @@ describe('VideoWorkspace', () => {
max_frames: 2,
seed: expect.objectContaining({ label: '肝脏' }),
}));
await waitFor(() => expect(screen.getByText('已传播 2 个 seed,处理 8 帧次,保存 4 个区域')).toBeInTheDocument());
await waitFor(() => expect(screen.getByText('已自动传播 2 个参考 mask,处理 8 帧次,保存 4 个区域')).toBeInTheDocument());
});
});

View File

@@ -21,7 +21,6 @@ import { FrameTimeline } from './FrameTimeline';
import { ModelStatusBadge } from './ModelStatusBadge';
import type { Frame, Mask } from '../store/useStore';
type PropagationTarget = 'selected' | 'all';
type PropagationDirection = 'forward' | 'backward';
export function VideoWorkspace({ onNavigateToAI }: { onNavigateToAI?: () => void }) {
@@ -52,7 +51,6 @@ export function VideoWorkspace({ onNavigateToAI }: { onNavigateToAI?: () => void
const [isImportingGt, setIsImportingGt] = useState(false);
const [isPropagating, setIsPropagating] = useState(false);
const [statusMessage, setStatusMessage] = useState('');
const [propagationTarget, setPropagationTarget] = useState<PropagationTarget>('selected');
const [propagationStartFrame, setPropagationStartFrame] = useState(1);
const [propagationEndFrame, setPropagationEndFrame] = useState(1);
@@ -354,20 +352,16 @@ export function VideoWorkspace({ onNavigateToAI }: { onNavigateToAI?: () => void
};
}, [activeTemplateId, currentFrame, currentProject?.id]);
const handlePropagateSegment = async (rangeOverride?: { startFrameNumber: number; endFrameNumber: number }) => {
const handleAutoPropagate = async () => {
if (!currentProject?.id || !currentFrame?.id) return;
const currentFrameMasks = masks.filter((mask) => mask.frameId === currentFrame.id);
const selectedMasks = selectedMaskIds
.map((id) => currentFrameMasks.find((mask) => mask.id === id))
.filter((mask): mask is Mask => Boolean(mask));
const seedMasks = propagationTarget === 'all' ? currentFrameMasks : selectedMasks;
const seedMasks = masks.filter((mask) => mask.frameId === currentFrame.id);
if (seedMasks.length === 0) {
setStatusMessage(propagationTarget === 'all' ? '当前帧没有可传播区域' : '请先选择一个或多个当前帧区域');
setStatusMessage('请先在当前参考帧创建或保存至少一个 mask');
return;
}
const startFrameNumber = clampFrameNumber(rangeOverride?.startFrameNumber ?? propagationStartFrame);
const endFrameNumber = clampFrameNumber(rangeOverride?.endFrameNumber ?? propagationEndFrame);
const startFrameNumber = clampFrameNumber(propagationStartFrame);
const endFrameNumber = clampFrameNumber(propagationEndFrame);
const rangeStartIndex = Math.min(startFrameNumber, endFrameNumber) - 1;
const rangeEndIndex = Math.max(startFrameNumber, endFrameNumber) - 1;
const propagationDirections: Array<{ direction: PropagationDirection; maxFrames: number }> = [];
@@ -397,7 +391,7 @@ export function VideoWorkspace({ onNavigateToAI }: { onNavigateToAI?: () => void
}
setIsPropagating(true);
setStatusMessage(`${aiModel.toUpperCase()} 正在传播 ${seeds.length}区域到第 ${rangeStartIndex + 1}-${rangeEndIndex + 1} 帧...`);
setStatusMessage(`${aiModel.toUpperCase()} 正在以第 ${currentFrameNumber} 帧为参考,自动传播 ${seeds.length} mask 到第 ${rangeStartIndex + 1}-${rangeEndIndex + 1} 帧...`);
try {
let createdCount = 0;
let processedCount = 0;
@@ -418,7 +412,7 @@ export function VideoWorkspace({ onNavigateToAI }: { onNavigateToAI?: () => void
}
}
await hydrateSavedAnnotations(currentProject.id, frames);
setStatusMessage(`已传播 ${seeds.length} seed,处理 ${processedCount} 帧次,保存 ${createdCount} 个区域`);
setStatusMessage(`自动传播 ${seeds.length}参考 mask,处理 ${processedCount} 帧次,保存 ${createdCount} 个区域`);
} catch (err) {
console.error('Propagation failed:', err);
setStatusMessage('传播失败,请检查模型状态或后端日志');
@@ -427,16 +421,6 @@ export function VideoWorkspace({ onNavigateToAI }: { onNavigateToAI?: () => void
}
};
const handlePropagateAllReachable = () => {
if (totalFrames <= 1) {
setStatusMessage('当前项目没有可传播的前后帧');
return;
}
setPropagationStartFrame(1);
setPropagationEndFrame(totalFrames);
void handlePropagateSegment({ startFrameNumber: 1, endFrameNumber: totalFrames });
};
return (
<div className="w-full h-full flex flex-col bg-[#0a0a0a]">
{/* Top Header / Status bar */}
@@ -468,16 +452,7 @@ export function VideoWorkspace({ onNavigateToAI }: { onNavigateToAI?: () => void
{isImportingGt ? '导入中...' : '导入 GT Mask'}
</button>
<div className="flex items-center gap-1 rounded-md border border-white/10 bg-white/[0.03] px-2 py-1">
<select
aria-label="传播对象"
value={propagationTarget}
onChange={(event) => setPropagationTarget(event.target.value as PropagationTarget)}
disabled={isPropagating || isSaving || isExporting || isImportingGt}
className="h-6 bg-transparent text-[10px] text-gray-300 outline-none disabled:opacity-40"
>
<option value="selected"></option>
<option value="all"></option>
</select>
<span className="text-[10px] text-gray-500 whitespace-nowrap"> {currentFrameNumber || 0}</span>
<span className="text-[10px] text-gray-600"></span>
<input
aria-label="传播起始帧"
@@ -502,18 +477,11 @@ export function VideoWorkspace({ onNavigateToAI }: { onNavigateToAI?: () => void
/>
</div>
<button
onClick={() => handlePropagateSegment()}
onClick={handleAutoPropagate}
disabled={!currentProject?.id || !currentFrame?.id || isSaving || isExporting || isImportingGt || isPropagating}
className="px-4 py-1.5 bg-white/5 hover:bg-white/10 border border-white/10 rounded-md text-xs transition-colors text-white disabled:opacity-40 disabled:cursor-not-allowed"
>
{isPropagating ? '传播中...' : '按范围传播'}
</button>
<button
onClick={handlePropagateAllReachable}
disabled={!currentProject?.id || !currentFrame?.id || totalFrames <= 1 || isSaving || isExporting || isImportingGt || isPropagating}
className="px-4 py-1.5 bg-white/5 hover:bg-white/10 border border-white/10 rounded-md text-xs transition-colors text-white disabled:opacity-40 disabled:cursor-not-allowed"
>
{isPropagating ? '传播中...' : '自动传播'}
</button>
<button
onClick={handleExportMasks}

View File

@@ -327,6 +327,8 @@ describe('api client contracts', () => {
label: '旧标签',
color: '#06b6d4',
class: { id: 'c1', name: '胆囊', color: '#ff0000', zIndex: 20 },
source: 'sam2.1_hiera_tiny_propagation',
propagated_from_frame_id: 4,
},
points: [[0.5, 0.5]],
bbox: null,
@@ -347,6 +349,10 @@ describe('api client contracts', () => {
pathData: 'M 10 10 L 90 10 L 90 40 Z',
points: [[50, 25]],
bbox: [10, 10, 80, 30],
metadata: {
source: 'sam2.1_hiera_tiny_propagation',
propagated_from_frame_id: 4,
},
}));
});
@@ -423,6 +429,48 @@ describe('api client contracts', () => {
});
});
it('sends normalized mask geometry to the backend analyzer', async () => {
const { analyzeMask } = await import('./api');
axiosMock.client.post.mockResolvedValueOnce({
data: {
confidence: 0.87,
confidence_source: 'model_score',
topology_anchor_count: 3,
topology_anchors: [],
area: 0.12,
bbox: [0.1, 0.2, 0.8, 0.6],
source: 'sam2.1_hiera_tiny',
message: 'ok',
},
});
const result = await analyzeMask({
id: 'm1',
frameId: '5',
pathData: 'M 10 10 L 90 10 L 90 40 Z',
label: '胆囊',
color: '#ff0000',
segmentation: [[10, 10, 90, 10, 90, 40]],
bbox: [10, 10, 80, 30],
metadata: { source: 'sam2.1_hiera_tiny', score: 0.87 },
}, { id: '5', projectId: '9', index: 0, url: '/frame.jpg', width: 100, height: 50 }, { extractSkeleton: true });
expect(axiosMock.client.post).toHaveBeenCalledWith('/api/ai/analyze-mask', {
frame_id: 5,
mask_data: {
polygons: [[[0.1, 0.2], [0.9, 0.2], [0.9, 0.8]]],
label: '胆囊',
color: '#ff0000',
source: 'sam2.1_hiera_tiny',
score: 0.87,
},
points: undefined,
bbox: [0.1, 0.2, 0.8, 0.6],
extract_skeleton: true,
});
expect(result.confidence).toBe(0.87);
});
it('normalizes combined box and point prompts for interactive SAM2 refinement', async () => {
const { predictMask } = await import('./api');
axiosMock.client.post.mockResolvedValueOnce({ data: { polygons: [], scores: [] } });

View File

@@ -294,6 +294,11 @@ export interface SavedAnnotation {
zIndex?: number;
category?: string;
};
source?: string;
propagated_from_frame_id?: number;
propagated_from_frame_index?: number;
score?: number | null;
[key: string]: unknown;
} | null;
points: number[][] | null;
bbox: number[] | null;
@@ -357,6 +362,17 @@ export interface PropagateMasksResult {
annotations: SavedAnnotation[];
}
export interface MaskAnalysisResult {
confidence: number | null;
confidence_source: string;
topology_anchor_count: number;
topology_anchors: number[][];
area: number;
bbox?: number[] | null;
source?: string | null;
message: string;
}
export interface DashboardTask {
id: string;
task_id?: number;
@@ -498,6 +514,8 @@ export function annotationToMask(annotation: SavedAnnotation, frame: Frame): Mas
if (!firstPolygon || firstPolygon.length === 0) return null;
const bbox = polygonToBbox(firstPolygon, frame.width, frame.height);
const classMetadata = annotation.mask_data?.class;
const { polygons: _polygons, label: _label, color: _color, class: _classMetadata, ...metadata } = annotation.mask_data || {};
const hasMetadata = Object.keys(metadata).length > 0;
return {
id: `annotation-${annotation.id}`,
annotationId: String(annotation.id),
@@ -515,9 +533,39 @@ export function annotationToMask(annotation: SavedAnnotation, frame: Frame): Mas
points: annotation.points?.map(([x, y]) => [x * frame.width, y * frame.height]),
bbox,
area: bbox[2] * bbox[3],
metadata: hasMetadata ? metadata : undefined,
};
}
export async function analyzeMask(mask: Mask, frame: Frame, options: { extractSkeleton?: boolean } = {}): Promise<MaskAnalysisResult> {
const polygons = pixelSegmentationToNormalizedPolygons(mask.segmentation, frame.width, frame.height);
const metadata = mask.metadata || {};
const response = await apiClient.post('/api/ai/analyze-mask', {
frame_id: Number(frame.id),
mask_data: {
polygons,
label: mask.label,
color: mask.color,
...(typeof metadata.source === 'string' ? { source: metadata.source } : {}),
...(typeof metadata.score === 'number' ? { score: metadata.score } : {}),
},
points: mask.points?.map(([x, y]) => [
clamp01(x / Math.max(frame.width, 1)),
clamp01(y / Math.max(frame.height, 1)),
]),
bbox: mask.bbox
? [
clamp01(mask.bbox[0] / Math.max(frame.width, 1)),
clamp01(mask.bbox[1] / Math.max(frame.height, 1)),
clamp01(mask.bbox[2] / Math.max(frame.width, 1)),
clamp01(mask.bbox[3] / Math.max(frame.height, 1)),
]
: undefined,
extract_skeleton: options.extractSkeleton ?? false,
});
return response.data;
}
export async function predictMask(payload: PredictMaskPayload): Promise<PredictMaskResult> {
let prompt_type: 'point' | 'box' | 'semantic' | 'interactive';
let prompt_data: unknown;

View File

@@ -80,6 +80,22 @@ vi.mock('react-konva', () => ({
props.onClick?.(konvaEvent);
if (konvaEvent.cancelBubble) event.stopPropagation();
}}
onMouseDown={(event) => {
const point = {
x: event.clientX || props.x || 120,
y: event.clientY || props.y || 80,
};
const konvaEvent = { ...makeStageEvent(point.x, point.y), cancelBubble: false };
props.onMouseDown?.(konvaEvent);
props.onDragStart?.(konvaEvent);
if (konvaEvent.cancelBubble) event.stopPropagation();
}}
onMouseMove={(event) => props.onDragMove?.({
target: {
x: () => event.clientX || props.x || 0,
y: () => event.clientY || props.y || 0,
},
})}
onMouseUp={(event: React.MouseEvent<HTMLSpanElement>) => props.onDragEnd?.({
target: {
x: () => event.clientX || props.x || 0,
@@ -100,6 +116,9 @@ vi.mock('react-konva', () => ({
data-testid="konva-path"
data-path={props.data}
data-fill={props.fill}
data-stroke={props.stroke}
data-stroke-width={props.strokeWidth}
data-dash={props.dash?.join(',') || ''}
data-fill-rule={props.fillRule}
onClick={(event) => {
const point = {