feat: 完善 AI 分割与工作区标注闭环

功能增加:

- 将视频导入和生成帧拆成两个明确动作,项目库生成帧时选择 FPS,工作区不再自动触发拆帧。

- 为工作区新增调整多边形工具,支持选中 mask、拖动顶点、边中点插点、双击边界按位置插点,并保留多 polygon 子区域编辑。

- 打通 AI 页 SAM2/SAM3 结果到工作区的联动,生成 mask 后自动选中,可在右侧分类树换标签,并推送到工作区继续编辑。

- 增强 Dashboard WebSocket 连接状态与心跳,使用真实 onopen/onclose/onerror 状态驱动前端显示。

- 完善 SAM3 external worker 适配,支持 box prompt、semantic 请求级阈值和 video tracker 路径。

bugfix:

- 修复 SAM2 文本语义误走自动分割的问题,改为提示使用点提示或切换 SAM3。

- 修复 SAM2 多候选重叠显示的问题,点提示和 auto fallback 默认只采用最高分候选。

- 修复 SAM2 反向点看起来无效的问题,带负点时启用背景过滤,过滤为空时移除旧候选。

- 修复 SAM3 单个 2D mask 结果无法转 polygon、低阈值 semantic 返回被默认阈值吞掉的问题。

- 修复 AI 页 mask 未选中导致分类树无法修改 SAM2 结果标签的问题。

测试和文档:

- 补充 CanvasArea、AISegmentation、ProjectLibrary、VideoWorkspace、Dashboard、websocket 和 SAM engine/API 测试。

- 新增 backend/tests/test_sam2_engine.py,覆盖 SAM2 单候选请求和 auto fallback 行为。

- 更新 README、AGENTS 和 doc 需求/设计/接口/测试矩阵,按当前实现冻结功能状态。
This commit is contained in:
2026-05-01 21:50:17 +08:00
parent 5ab4602535
commit 8a9247075e
31 changed files with 920 additions and 216 deletions

View File

@@ -63,6 +63,129 @@ describe('AISegmentation', () => {
}));
});
it('does not run SAM2 text-only prompts as semantic segmentation', async () => {
render(<AISegmentation onSendToWorkspace={vi.fn()} />);
fireEvent.change(screen.getByPlaceholderText("例如:'分割出左侧车道上行驶的所有红色汽车'..."), {
target: { value: '胆囊' },
});
fireEvent.click(await screen.findByText('执行高精度语义分割'));
expect(apiMock.predictMask).not.toHaveBeenCalled();
expect(await screen.findByText('SAM2 不支持文本语义提示;请先放置正/反向点,或切换到 SAM3 使用文本语义。')).toBeInTheDocument();
});
it('keeps only the best SAM2 candidate when the backend returns overlapping alternatives', async () => {
apiMock.predictMask.mockResolvedValueOnce({
masks: [
{
id: 'sam2-best',
pathData: 'M 0 0 L 10 0 L 10 10 Z',
label: 'AI Mask',
color: '#06b6d4',
segmentation: [[0, 0, 10, 0, 10, 10]],
bbox: [0, 0, 10, 10],
area: 100,
},
{
id: 'sam2-alt',
pathData: 'M 1 1 L 11 1 L 11 11 Z',
label: 'AI Mask',
color: '#06b6d4',
segmentation: [[1, 1, 11, 1, 11, 11]],
bbox: [1, 1, 10, 10],
area: 100,
},
],
});
render(<AISegmentation onSendToWorkspace={vi.fn()} />);
fireEvent.click(screen.getByText('正向选点'));
fireEvent.click(screen.getByTestId('konva-stage'));
fireEvent.click(await screen.findByText('执行高精度语义分割'));
await waitFor(() => expect(useStore.getState().masks).toHaveLength(1));
expect(useStore.getState().masks[0].id).toBe('sam2-best');
expect(useStore.getState().selectedMaskIds).toEqual(['sam2-best']);
expect(await screen.findByText('SAM2 返回 2 个候选,已采用最高分区域。')).toBeInTheDocument();
});
it('lets a SAM2 result be selected and relabeled from the ontology panel', async () => {
useStore.setState({
templates: [
{
id: 'template-1',
name: '腹腔镜模板',
classes: [
{ id: 'class-1', name: '胆囊', color: '#ff0000', zIndex: 30 },
{ id: 'class-2', name: '肝脏', color: '#00ff00', zIndex: 20 },
],
rules: [],
},
],
});
apiMock.predictMask.mockResolvedValueOnce({
masks: [
{
id: 'sam2-mask',
pathData: 'M 10 10 L 40 10 L 40 40 Z',
label: 'AI Mask',
color: '#06b6d4',
segmentation: [[10, 10, 40, 10, 40, 40]],
bbox: [10, 10, 30, 30],
area: 900,
},
],
});
render(<AISegmentation onSendToWorkspace={vi.fn()} />);
fireEvent.click(screen.getByText('正向选点'));
fireEvent.click(screen.getByTestId('konva-stage'));
fireEvent.click(await screen.findByText('执行高精度语义分割'));
await waitFor(() => expect(useStore.getState().selectedMaskIds).toEqual(['sam2-mask']));
fireEvent.click(screen.getByText('肝脏'));
expect(useStore.getState().masks[0]).toEqual(expect.objectContaining({
templateId: 'template-1',
classId: 'class-2',
className: '肝脏',
classZIndex: 20,
label: '肝脏',
color: '#00ff00',
saveStatus: 'draft',
}));
});
it('keeps the generated SAM2 mask selected when sending it to the workspace editor', async () => {
const onSendToWorkspace = vi.fn();
apiMock.predictMask.mockResolvedValueOnce({
masks: [
{
id: 'sam2-mask',
pathData: 'M 10 10 L 40 10 L 40 40 Z',
label: 'AI Mask',
color: '#06b6d4',
segmentation: [[10, 10, 40, 10, 40, 40]],
bbox: [10, 10, 30, 30],
area: 900,
},
],
});
render(<AISegmentation onSendToWorkspace={onSendToWorkspace} />);
fireEvent.click(screen.getByText('正向选点'));
fireEvent.click(screen.getByTestId('konva-stage'));
fireEvent.click(await screen.findByText('执行高精度语义分割'));
await waitFor(() => expect(useStore.getState().selectedMaskIds).toEqual(['sam2-mask']));
fireEvent.click(screen.getByText('推送至工作区编辑'));
expect(useStore.getState().activeTool).toBe('edit_polygon');
expect(useStore.getState().selectedMaskIds).toEqual(['sam2-mask']);
expect(onSendToWorkspace).toHaveBeenCalled();
});
it('prompts for semantic text before running SAM3 inference', async () => {
apiMock.getAiModelStatus.mockResolvedValue({
selected_model: 'sam3',
@@ -106,7 +229,7 @@ describe('AISegmentation', () => {
points: undefined,
text: '胆囊',
})));
expect(await screen.findByText('模型没有返回可用区域,请换一个更具体的描述或调整提示。')).toBeInTheDocument();
expect(await screen.findByText('SAM3 已完成语义推理,但没有返回区域。请尝试英文目标描述,或换到包含该目标的帧。当前提示: 胆囊')).toBeInTheDocument();
});
it('runs SAM3 semantic text inference and assigns the active class to returned masks', async () => {

View File

@@ -17,6 +17,8 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
const masks = useStore((state) => state.masks);
const addMask = useStore((state) => state.addMask);
const clearMasks = useStore((state) => state.clearMasks);
const selectedMaskIds = useStore((state) => state.selectedMaskIds);
const setSelectedMaskIds = useStore((state) => state.setSelectedMaskIds);
const maskHistory = useStore((state) => state.maskHistory);
const maskFuture = useStore((state) => state.maskFuture);
const undoMasks = useStore((state) => state.undoMasks);
@@ -97,6 +99,10 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
setInferenceMessage('SAM3 当前使用文本语义提示,请先输入要分割的目标描述。');
return;
}
if (aiModel === 'sam2' && textPrompt && points.length === 0) {
setInferenceMessage('SAM2 不支持文本语义提示;请先放置正/反向点,或切换到 SAM3 使用文本语义。');
return;
}
if (points.length === 0 && !textPrompt) {
setInferenceMessage('请先放置正/反向提示点,或输入语义描述。');
return;
@@ -132,14 +138,22 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
},
});
if (result.masks.length === 0) {
setInferenceMessage('模型没有返回可用区域,请换一个更具体的描述或调整提示。');
const masksToApply = aiModel === 'sam2' ? result.masks.slice(0, 1) : result.masks;
if (masksToApply.length === 0) {
setInferenceMessage(aiModel === 'sam3'
? `SAM3 已完成语义推理,但没有返回区域。请尝试英文目标描述,或换到包含该目标的帧。当前提示: ${textPrompt}`
: '模型没有返回可用区域,请换一个更具体的描述或调整提示。');
} else {
setInferenceMessage(`已生成 ${result.masks.length} 个候选区域。`);
setInferenceMessage(aiModel === 'sam2' && result.masks.length > 1
? `SAM2 返回 ${result.masks.length} 个候选,已采用最高分区域。`
: `已生成 ${masksToApply.length} 个候选区域。`);
}
result.masks.forEach((m) => {
const generatedMaskIds: string[] = [];
masksToApply.forEach((m) => {
const label = activeClass?.name || m.label;
const color = activeClass?.color || m.color;
generatedMaskIds.push(m.id);
addMask({
id: m.id,
frameId: currentFrame.id,
@@ -157,6 +171,9 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
area: m.area,
});
});
if (generatedMaskIds.length > 0) {
setSelectedMaskIds(generatedMaskIds);
}
} catch (err) {
console.error('AI inference failed:', err);
const detail = (err as any)?.response?.data?.detail;
@@ -164,7 +181,7 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
} finally {
setIsInferencing(false);
}
}, [activeClass, activeTemplateId, addMask, aiModel, autoDeleteBg, cropMode, currentFrame?.height, currentFrame?.id, currentFrame?.width, image?.height, image?.naturalHeight, image?.naturalWidth, image?.width, points, semanticText]);
}, [activeClass, activeTemplateId, addMask, aiModel, autoDeleteBg, cropMode, currentFrame?.height, currentFrame?.id, currentFrame?.width, image?.height, image?.naturalHeight, image?.naturalWidth, image?.width, points, semanticText, setSelectedMaskIds]);
const handleStageClick = (e: any) => {
if (effectiveTool === 'move') return;
@@ -307,10 +324,13 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
</div>
)}
<button
onClick={onSendToWorkspace}
onClick={() => {
setActiveTool('edit_polygon');
onSendToWorkspace();
}}
className="w-full py-3.5 rounded-lg flex items-center justify-center gap-2 transition-all font-medium tracking-wide text-xs uppercase bg-white/5 hover:bg-white/10 text-gray-300 border border-white/5 hover:border-white/10"
>
<SendToBack size={16} /> 退
<SendToBack size={16} />
</button>
</div>
</aside>
@@ -376,12 +396,20 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
{/* AI Returned Masks */}
{frameMasks.map((mask) => (
<Group key={mask.id} opacity={0.45}>
<Group key={mask.id} opacity={selectedMaskIds.includes(mask.id) ? 0.72 : 0.45}>
<Path
data={mask.pathData}
fill={mask.color}
stroke={mask.color}
strokeWidth={1 / scale}
strokeWidth={(selectedMaskIds.includes(mask.id) ? 2.5 : 1) / scale}
onClick={(event: any) => {
event.cancelBubble = true;
setSelectedMaskIds([mask.id]);
}}
onTap={(event: any) => {
event.cancelBubble = true;
setSelectedMaskIds([mask.id]);
}}
/>
</Group>
))}

View File

@@ -206,16 +206,58 @@ describe('CanvasArea', () => {
{ x: 300, y: 150, type: 'neg' },
],
box: { x1: 120, y1: 80, x2: 260, y2: 200 },
options: { auto_filter_background: true, min_score: 0.05 },
}));
expect(useStore.getState().masks).toHaveLength(1);
expect(useStore.getState().masks[0]).toEqual(expect.objectContaining({
id: 'mask-box',
segmentation: [[30, 30, 70, 30, 70, 70]],
points: [[150, 100]],
metadata: expect.objectContaining({ promptPointCount: 2 }),
metadata: expect.objectContaining({ promptPointCount: 2, promptNegativePointCount: 1 }),
}));
});
it('removes the SAM2 candidate when a negative point filters it out', async () => {
apiMock.predictMask
.mockResolvedValueOnce({
masks: [
{
id: 'mask-box',
pathData: 'M 10 10 L 90 10 L 90 90 Z',
label: 'AI Mask',
color: '#06b6d4',
segmentation: [[10, 10, 90, 10, 90, 90]],
bbox: [10, 10, 80, 80],
area: 6400,
},
],
})
.mockResolvedValueOnce({ masks: [] });
const { rerender } = render(<CanvasArea activeTool="box_select" frame={frame} />);
const stage = screen.getByTestId('konva-stage');
fireEvent.mouseDown(stage, { clientX: 120, clientY: 80 });
fireEvent.mouseMove(stage, { clientX: 260, clientY: 200 });
fireEvent.mouseUp(stage, { clientX: 260, clientY: 200 });
await waitFor(() => expect(useStore.getState().masks).toHaveLength(1));
rerender(<CanvasArea activeTool="point_neg" frame={frame} />);
fireEvent.click(stage, { clientX: 180, clientY: 120 });
await waitFor(() => expect(apiMock.predictMask).toHaveBeenNthCalledWith(2, {
imageId: 'frame-1',
imageWidth: 640,
imageHeight: 360,
model: 'sam2',
points: [{ x: 180, y: 120, type: 'neg' }],
box: { x1: 120, y1: 80, x2: 260, y2: 200 },
options: { auto_filter_background: true, min_score: 0.05 },
}));
await waitFor(() => expect(useStore.getState().masks).toHaveLength(0));
expect(await screen.findByText(/反向点已排除当前候选区域/)).toBeInTheDocument();
});
it('renders only masks that belong to the current frame', () => {
useStore.setState({
masks: [
@@ -250,6 +292,28 @@ describe('CanvasArea', () => {
await waitFor(() => expect(useStore.getState().selectedMaskIds).toEqual(['m1']));
});
it('keeps a mask selected when opening the workspace polygon editor from AI results', () => {
useStore.setState({
selectedMaskIds: ['m1'],
masks: [
{
id: 'm1',
frameId: 'frame-1',
pathData: 'M 0 0 L 10 0 L 10 10 Z',
label: 'A',
color: '#fff',
segmentation: [[0, 0, 10, 0, 10, 10]],
},
],
});
render(<CanvasArea activeTool="edit_polygon" frame={frame} />);
expect(useStore.getState().selectedMaskIds).toEqual(['m1']);
expect(screen.getAllByTestId('konva-circle')
.filter((element) => element.getAttribute('data-fill') === '#ffffff')).toHaveLength(3);
});
it('renders imported GT seed points for editable point regions', () => {
useStore.setState({
masks: [
@@ -415,6 +479,34 @@ describe('CanvasArea', () => {
}));
});
it('selects a polygon with the edit tool and inserts a vertex by double-clicking an edge', () => {
useStore.setState({
masks: [
{
id: 'draft-1',
frameId: 'frame-1',
pathData: 'M 10 10 L 90 10 L 90 40 Z',
label: 'Draft',
color: '#06b6d4',
saveStatus: 'draft',
segmentation: [[10, 10, 90, 10, 90, 40]],
bbox: [10, 10, 80, 30],
},
],
});
render(<CanvasArea activeTool="edit_polygon" frame={frame} />);
const path = screen.getByTestId('konva-path');
fireEvent.click(path);
fireEvent.doubleClick(path, { clientX: 50, clientY: 10 });
expect(useStore.getState().masks[0]).toEqual(expect.objectContaining({
segmentation: [[10, 10, 50, 10, 90, 10, 90, 40]],
pathData: 'M 10 10 L 50 10 L 90 10 L 90 40 Z',
saveStatus: 'draft',
}));
});
it('edits the selected polygon in a multi-polygon mask', () => {
useStore.setState({
masks: [

View File

@@ -19,6 +19,7 @@ type PromptBox = { x1: number; y1: number; x2: number; y2: number };
const DRAG_MANUAL_TOOLS = new Set(['create_rectangle', 'create_circle', 'create_line']);
const POLYGON_TOOL = 'create_polygon';
const EDIT_POLYGON_TOOL = 'edit_polygon';
const POINT_TOOL = 'create_point';
const BOOLEAN_TOOLS = new Set(['area_merge', 'area_remove']);
const POLYGON_CLOSE_RADIUS = 8;
@@ -95,6 +96,32 @@ function pointDistance(a: CanvasPoint, b: CanvasPoint): number {
return Math.hypot(a.x - b.x, a.y - b.y);
}
function distanceToSegmentSquared(point: CanvasPoint, start: CanvasPoint, end: CanvasPoint): number {
const dx = end.x - start.x;
const dy = end.y - start.y;
const lengthSquared = dx * dx + dy * dy;
if (lengthSquared === 0) {
return (point.x - start.x) ** 2 + (point.y - start.y) ** 2;
}
const t = clamp(((point.x - start.x) * dx + (point.y - start.y) * dy) / lengthSquared, 0, 1);
const projected = { x: start.x + t * dx, y: start.y + t * dy };
return (point.x - projected.x) ** 2 + (point.y - projected.y) ** 2;
}
function nearestPolygonEdgeIndex(points: CanvasPoint[], point: CanvasPoint): number {
return points.reduce((bestIndex, start, index) => {
const end = points[(index + 1) % points.length];
if (!end) return bestIndex;
const bestStart = points[bestIndex];
const bestEnd = points[(bestIndex + 1) % points.length];
const currentDistance = distanceToSegmentSquared(point, start, end);
const bestDistance = bestStart && bestEnd
? distanceToSegmentSquared(point, bestStart, bestEnd)
: Number.POSITIVE_INFINITY;
return currentDistance < bestDistance ? index : bestIndex;
}, 0);
}
function segmentationArea(segmentation?: number[][]): number {
return (segmentation || []).reduce((sum, polygon) => sum + polygonArea(flatPolygonToPoints(polygon)), 0);
}
@@ -210,10 +237,11 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
const [manualStart, setManualStart] = useState<CanvasPoint | null>(null);
const [manualCurrent, setManualCurrent] = useState<CanvasPoint | null>(null);
const [polygonPoints, setPolygonPoints] = useState<CanvasPoint[]>([]);
const [selectedMaskId, setSelectedMaskId] = useState<string | null>(null);
const [selectedMaskIds, setSelectedMaskIds] = useState<string[]>([]);
const [selectedMaskId, setSelectedMaskId] = useState<string | null>(() => useStore.getState().selectedMaskIds[0] || null);
const [selectedMaskIds, setSelectedMaskIds] = useState<string[]>(() => useStore.getState().selectedMaskIds);
const [selectedPolygonIndex, setSelectedPolygonIndex] = useState(0);
const [selectedVertexIndex, setSelectedVertexIndex] = useState<number | null>(null);
const previousFrameIdRef = useRef<string | undefined>(frame?.id);
const [isInferencing, setIsInferencing] = useState(false);
const [inferenceMessage, setInferenceMessage] = useState('');
@@ -253,6 +281,7 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
const draftMaskCount = frameMasks.filter((mask) => !mask.annotationId).length;
const dirtyMaskCount = frameMasks.filter((mask) => mask.saveStatus === 'dirty').length;
const isBooleanTool = BOOLEAN_TOOLS.has(effectiveTool);
const isPolygonEditTool = effectiveTool === 'move' || effectiveTool === EDIT_POLYGON_TOOL;
useEffect(() => {
const handleResize = () => {
@@ -273,11 +302,22 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
setManualStart(null);
setManualCurrent(null);
setPolygonPoints([]);
setSelectedVertexIndex(null);
if (!isPolygonEditTool && !isBooleanTool) {
setSelectedMaskId(null);
setSelectedMaskIds([]);
setSelectedPolygonIndex(0);
}
}, [effectiveTool, isBooleanTool, isPolygonEditTool]);
useEffect(() => {
if (previousFrameIdRef.current === frame?.id) return;
previousFrameIdRef.current = frame?.id;
setSelectedMaskId(null);
setSelectedMaskIds([]);
setSelectedPolygonIndex(0);
setSelectedVertexIndex(null);
}, [effectiveTool, frame?.id]);
}, [frame?.id]);
useEffect(() => {
setPoints([]);
@@ -420,6 +460,10 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
setIsInferencing(true);
setInferenceMessage('');
try {
const hasNegativePrompt = Boolean(promptPoints?.some((point) => point.type === 'neg'));
const existingCandidate = !options.resetCandidate && samCandidateMaskId
? masks.find((mask) => mask.id === samCandidateMaskId)
: null;
const result = await predictMask({
imageId: frame.id,
imageWidth,
@@ -429,13 +473,11 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
? promptPoints.map((p) => ({ x: p.x, y: p.y, type: p.type }))
: undefined,
box: promptBox,
...(hasNegativePrompt ? { options: { auto_filter_background: true, min_score: 0.05 } } : {}),
});
const [m] = result.masks;
if (m) {
const existingCandidate = !options.resetCandidate && samCandidateMaskId
? masks.find((mask) => mask.id === samCandidateMaskId)
: null;
const label = activeClass?.name || existingCandidate?.label || m.label;
const color = activeClass?.color || existingCandidate?.color || m.color;
const metadata = {
@@ -443,6 +485,7 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
source: aiModel === 'sam3' ? 'sam3_box' : 'sam2_interactive',
promptBox: promptBox || null,
promptPointCount: promptPoints?.length || 0,
promptNegativePointCount: promptPoints?.filter((point) => point.type === 'neg').length || 0,
};
const nextMask = {
frameId: frame.id,
@@ -476,7 +519,15 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
});
}
} else {
setInferenceMessage('模型没有返回可用区域,请调整点/框提示后重试。');
if (existingCandidate && hasNegativePrompt) {
setMasks(masks.filter((mask) => mask.id !== existingCandidate.id));
setSamCandidateMaskId(null);
setSelectedMaskId(null);
setSelectedMaskIds([]);
setInferenceMessage('反向点已排除当前候选区域,请重新框选或添加新的正向点。');
} else {
setInferenceMessage('模型没有返回可用区域,请调整点/框提示后重试。');
}
}
} catch (err) {
console.error('Inference failed:', err);
@@ -485,7 +536,7 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
} finally {
setIsInferencing(false);
}
}, [activeClass, activeTemplateId, addMask, aiModel, frame?.height, frame?.id, frame?.width, image?.height, image?.naturalHeight, image?.naturalWidth, image?.width, masks, samCandidateMaskId, updateMask]);
}, [activeClass, activeTemplateId, addMask, aiModel, frame?.height, frame?.id, frame?.width, image?.height, image?.naturalHeight, image?.naturalWidth, image?.width, masks, samCandidateMaskId, setMasks, updateMask]);
const handleApplyActiveClass = () => {
if (!frame?.id || !activeClass) return;
@@ -598,7 +649,7 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
};
const handleStageClick = (e: any) => {
if (effectiveTool === 'move') return;
if (isPolygonEditTool) return;
if (effectiveTool === 'box_select') return; // handled by mouseup
if (DRAG_MANUAL_TOOLS.has(effectiveTool)) return;
@@ -716,7 +767,7 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
window.addEventListener('keydown', handleKeyDown);
return () => window.removeEventListener('keydown', handleKeyDown);
}, [deleteMasksById, effectiveTool, finishPolygon, polygonPoints, redoMasks, selectedMask, selectedMaskIds, selectedPolygonIndex, selectedVertexIndex, undoMasks, updatePolygonMask]);
}, [deleteMasksById, effectiveTool, finishPolygon, isPolygonEditTool, polygonPoints, redoMasks, selectedMask, selectedMaskIds, selectedPolygonIndex, selectedVertexIndex, undoMasks, updatePolygonMask]);
const boxRect = React.useMemo(() => {
if (!boxStart || !boxCurrent) return null;
@@ -753,7 +804,7 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
};
const handleMaskSelect = (mask: Mask, event: any, polygonIndex = 0) => {
if (effectiveTool !== 'move' && !isBooleanTool) return;
if (!isPolygonEditTool && !isBooleanTool) return;
event.cancelBubble = true;
if (isBooleanTool) {
setSelectedMaskIds((current) => (
@@ -807,6 +858,25 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
updatePolygonMask(mask, nextPoints, selectedPolygonIndex);
};
const handlePathDoubleClick = (mask: Mask, event: any, polygonIndex = 0) => {
if (effectiveTool !== EDIT_POLYGON_TOOL) return;
event.cancelBubble = true;
const point = stagePoint(event);
const currentPoints = segmentationToPoints(mask.segmentation, polygonIndex);
if (!point || currentPoints.length < 3) return;
const edgeIndex = nearestPolygonEdgeIndex(currentPoints, point);
const nextPoints = [
...currentPoints.slice(0, edgeIndex + 1),
point,
...currentPoints.slice(edgeIndex + 1),
];
setSelectedMaskId(mask.id);
setSelectedMaskIds([mask.id]);
setSelectedPolygonIndex(polygonIndex);
setSelectedVertexIndex(edgeIndex + 1);
updatePolygonMask(mask, nextPoints, polygonIndex);
};
const handleBooleanOperation = async () => {
if (!frame || booleanSelectedMasks.length < 2) return;
const primary = booleanSelectedMasks[0];
@@ -918,6 +988,8 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
strokeWidth={(selectedMaskIds.includes(mask.id) ? 2 : 1) / scale}
onClick={(event: any) => handleMaskSelect(mask, event, polygonIndex)}
onTap={(event: any) => handleMaskSelect(mask, event, polygonIndex)}
onDblClick={(event: any) => handlePathDoubleClick(mask, event, polygonIndex)}
onDblTap={(event: any) => handlePathDoubleClick(mask, event, polygonIndex)}
/>
))}
</Group>
@@ -987,7 +1059,7 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
)))}
{/* Polygon edge insertion handles */}
{!isBooleanTool && selectedMask && selectedMaskPoints.map((point, index) => {
{isPolygonEditTool && selectedMask && selectedMaskPoints.map((point, index) => {
const next = selectedMaskPoints[(index + 1) % selectedMaskPoints.length];
if (!next) return null;
return (
@@ -1006,7 +1078,7 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
})}
{/* Polygon vertex editor */}
{!isBooleanTool && selectedMask && selectedMaskPoints.map((point, index) => (
{isPolygonEditTool && selectedMask && selectedMaskPoints.map((point, index) => (
<Circle
key={`${selectedMask.id}-vertex-${selectedPolygonIndex}-${index}`}
x={point.x}

View File

@@ -12,6 +12,7 @@ const apiMock = vi.hoisted(() => ({
const wsMock = vi.hoisted(() => {
const state = {
callback: undefined as undefined | ((data: any) => void),
statusCallback: undefined as undefined | ((status: any) => void),
connected: false,
};
return {
@@ -24,6 +25,11 @@ const wsMock = vi.hoisted(() => {
state.callback = cb;
return vi.fn();
}),
onStatus: vi.fn((cb: (status: any) => void) => {
state.statusCallback = cb;
cb(state.connected ? 'connected' : 'disconnected');
return vi.fn();
}),
},
};
});
@@ -45,6 +51,7 @@ describe('Dashboard', () => {
vi.clearAllMocks();
wsMock.state.connected = false;
wsMock.state.callback = undefined;
wsMock.state.statusCallback = undefined;
apiMock.getDashboardOverview.mockResolvedValue({
summary: {
project_count: 2,
@@ -109,6 +116,20 @@ describe('Dashboard', () => {
expect(screen.getByText('44%')).toBeInTheDocument();
});
it('updates the websocket badge from connection status callbacks', async () => {
render(<Dashboard />);
await waitFor(() => expect(wsMock.progressWS.onStatus).toHaveBeenCalled());
expect(screen.getByText('WebSocket 断开')).toBeInTheDocument();
act(() => {
wsMock.state.connected = true;
wsMock.state.statusCallback?.('connected');
});
expect(screen.getByText('WebSocket 已连接')).toBeInTheDocument();
});
it('adds activity logs for complete and status messages', async () => {
render(<Dashboard />);

View File

@@ -1,6 +1,6 @@
import React, { useState, useEffect } from 'react';
import { Activity, AlertTriangle, Clock, Folders, CheckCircle2, Info, Loader2, RotateCcw, XCircle } from 'lucide-react';
import { progressWS, type ProgressMessage } from '../lib/websocket';
import { progressWS, type ConnectionStatus, type ProgressMessage } from '../lib/websocket';
import { cn } from '../lib/utils';
import {
cancelTask,
@@ -178,6 +178,9 @@ export function Dashboard() {
]);
}
});
const unsubscribeStatus = progressWS.onStatus((status: ConnectionStatus) => {
if (mounted) setIsConnected(status === 'connected');
});
const checkConnection = setInterval(() => {
if (mounted) setIsConnected(progressWS.isConnected());
@@ -186,6 +189,7 @@ export function Dashboard() {
return () => {
mounted = false;
unsubscribe();
unsubscribeStatus();
clearInterval(checkConnection);
progressWS.disconnect();
};

View File

@@ -56,10 +56,9 @@ describe('ProjectLibrary', () => {
expect(useStore.getState().projects[0]).toEqual(expect.objectContaining({ id: 'p2' }));
});
it('imports video by creating a project, uploading media, parsing frames and refreshing projects', async () => {
it('imports video by creating a project and uploading media without parsing frames', async () => {
apiMock.createProject.mockResolvedValueOnce({ id: 'p3', name: 'clip.mp4', status: 'pending' });
apiMock.uploadMedia.mockResolvedValueOnce({ url: 'http://file', id: 'object' });
apiMock.parseMedia.mockResolvedValueOnce({ frames_extracted: 1 });
apiMock.getProjects.mockResolvedValue([]);
const { container } = render(<ProjectLibrary onProjectSelect={vi.fn()} />);
@@ -70,10 +69,24 @@ describe('ProjectLibrary', () => {
await waitFor(() => expect(apiMock.createProject).toHaveBeenCalledWith(expect.objectContaining({
name: 'clip.mp4',
parse_fps: 30,
})));
expect(apiMock.uploadMedia).toHaveBeenCalledWith(file, 'p3');
expect(apiMock.parseMedia).toHaveBeenCalledWith('p3');
expect(apiMock.parseMedia).not.toHaveBeenCalled();
});
it('generates frames from an imported video with the selected FPS', async () => {
apiMock.getProjects
.mockResolvedValueOnce([{ id: 'p4', name: 'clip.mp4', status: 'pending', frames: 0, video_path: 'uploads/clip.mp4', parse_fps: 30 }])
.mockResolvedValueOnce([{ id: 'p4', name: 'clip.mp4', status: 'parsing', frames: 0, video_path: 'uploads/clip.mp4', parse_fps: 12 }]);
apiMock.parseMedia.mockResolvedValueOnce({ id: 22, status: 'queued', progress: 0 });
const { container } = render(<ProjectLibrary onProjectSelect={vi.fn()} />);
fireEvent.click(await screen.findByRole('button', { name: '生成帧' }));
fireEvent.change(container.querySelector('input[type="range"]') as HTMLInputElement, { target: { value: '12' } });
fireEvent.click(screen.getByRole('button', { name: '开始生成帧' }));
await waitFor(() => expect(apiMock.parseMedia).toHaveBeenCalledWith('p4', { parseFps: 12 }));
});
it('imports only valid DICOM files and parses the returned project', async () => {

View File

@@ -1,5 +1,5 @@
import React, { useState, useEffect, useRef } from 'react';
import { UploadCloud, Film, Settings2, MoreHorizontal, Plus, Loader2, Activity } from 'lucide-react';
import { UploadCloud, Film, Settings2, MoreHorizontal, Plus, Loader2, Activity, Images } from 'lucide-react';
import { cn } from '../lib/utils';
import { useStore } from '../store/useStore';
import { getProjects, createProject, uploadMedia, parseMedia, uploadDicomBatch } from '../lib/api';
@@ -22,7 +22,10 @@ export function ProjectLibrary({ onProjectSelect }: ProjectLibraryProps) {
const [showImportMenu, setShowImportMenu] = useState(false);
const [showVideoConfig, setShowVideoConfig] = useState(false);
const [pendingFile, setPendingFile] = useState<File | null>(null);
const [parseFps, setParseFps] = useState(30);
const [frameProject, setFrameProject] = useState<Project | null>(null);
const [showFrameConfig, setShowFrameConfig] = useState(false);
const [frameParseFps, setFrameParseFps] = useState(30);
const [isGeneratingFrames, setIsGeneratingFrames] = useState(false);
const videoInputRef = useRef<HTMLInputElement>(null);
const dicomInputRef = useRef<HTMLInputElement>(null);
@@ -57,7 +60,6 @@ export function ProjectLibrary({ onProjectSelect }: ProjectLibraryProps) {
const handleVideoSelect = (file: File) => {
setPendingFile(file);
setParseFps(30);
setShowVideoConfig(true);
};
@@ -69,11 +71,9 @@ export function ProjectLibrary({ onProjectSelect }: ProjectLibraryProps) {
const newProject = await createProject({
name: pendingFile.name,
description: `导入于 ${new Date().toLocaleString()}`,
parse_fps: parseFps,
});
const result = await uploadMedia(pendingFile, String(newProject.id));
await parseMedia(String(newProject.id));
alert(`上传成功: ${pendingFile.name}\n已保存至: ${result.url}`);
alert(`视频导入成功: ${pendingFile.name}\n已保存至: ${result.url}\n需要生成帧时请在项目卡片点击“生成帧”。`);
const data = await getProjects();
setProjects(data);
} catch (err) {
@@ -86,6 +86,31 @@ export function ProjectLibrary({ onProjectSelect }: ProjectLibraryProps) {
}
};
const openFrameConfig = (project: Project, event: React.MouseEvent) => {
event.stopPropagation();
setFrameProject(project);
setFrameParseFps(Math.round(project.parse_fps || 30));
setShowFrameConfig(true);
};
const handleGenerateFrames = async () => {
if (!frameProject?.id) return;
setIsGeneratingFrames(true);
try {
const task = await parseMedia(frameProject.id, { parseFps: frameParseFps });
alert(`生成帧任务已入队 #${task.id}\n帧率: ${frameParseFps} FPS\n可在 Dashboard 查看进度。`);
const data = await getProjects();
setProjects(data);
setShowFrameConfig(false);
setFrameProject(null);
} catch (err) {
console.error('Frame generation failed:', err);
alert('生成帧失败,请检查后端服务或项目源文件');
} finally {
setIsGeneratingFrames(false);
}
};
const handleDicomUpload = async (files: FileList | null) => {
if (!files || files.length === 0) return;
const dcmFiles = Array.from(files).filter((f) => f.name.toLowerCase().endsWith('.dcm'));
@@ -209,7 +234,7 @@ export function ProjectLibrary({ onProjectSelect }: ProjectLibraryProps) {
)}
<div className="absolute top-2 right-2 flex gap-2">
<span className="backdrop-blur-md bg-black/40 text-gray-200 text-[10px] font-mono px-2 py-1 rounded border border-white/10 uppercase tracking-widest">
{proj.source_type === 'dicom' ? 'DICOM' : (proj.fps || '30FPS')}
{proj.source_type === 'dicom' ? 'DICOM' : (proj.video_path && (proj.frames ?? 0) === 0 ? '待生成帧' : (proj.fps || '30FPS'))}
</span>
<span className="backdrop-blur-md bg-black/40 text-gray-200 text-[10px] px-2 py-1 rounded border border-white/10 flex items-center gap-1 uppercase tracking-widest">
{proj.status === 'ready' ? (
@@ -235,6 +260,15 @@ export function ProjectLibrary({ onProjectSelect }: ProjectLibraryProps) {
<span className="flex items-center gap-1.5 text-cyan-400/80"><Activity size={12} /> {proj.original_fps.toFixed(1)}fps</span>
)}
</div>
{proj.video_path && (proj.frames ?? 0) === 0 && proj.status !== 'parsing' && (
<button
onClick={(event) => openFrameConfig(proj, event)}
className="mt-3 inline-flex items-center justify-center gap-2 rounded-md border border-cyan-500/30 bg-cyan-500/10 px-3 py-2 text-xs font-medium text-cyan-200 hover:bg-cyan-500/20 transition-colors"
>
<Images size={14} />
</button>
)}
</div>
</div>
))}
@@ -245,24 +279,10 @@ export function ProjectLibrary({ onProjectSelect }: ProjectLibraryProps) {
{showVideoConfig && pendingFile && (
<div className="fixed inset-0 z-50 flex items-center justify-center bg-black/60 backdrop-blur-sm">
<div className="bg-[#111] border border-white/10 rounded-2xl p-6 w-full max-w-md shadow-2xl">
<h2 className="text-lg font-semibold text-white mb-4"></h2>
<h2 className="text-lg font-semibold text-white mb-4"></h2>
<div className="space-y-4">
<div className="text-sm text-gray-400">: <span className="text-gray-200">{pendingFile.name}</span></div>
<div>
<label className="block text-xs font-medium text-gray-400 uppercase tracking-widest mb-2"> (FPS)</label>
<div className="flex items-center gap-3">
<input
type="range"
min="1"
max="60"
value={parseFps}
onChange={(e) => setParseFps(parseInt(e.target.value))}
className="flex-1 accent-cyan-500"
/>
<span className="text-sm font-mono text-cyan-400 w-12 text-right">{parseFps}</span>
</div>
<p className="text-[10px] text-gray-600 mt-1"></p>
</div>
<p className="text-xs leading-5 text-gray-500"> FPS</p>
</div>
<div className="flex justify-end gap-3 mt-6">
<button
@@ -282,6 +302,49 @@ export function ProjectLibrary({ onProjectSelect }: ProjectLibraryProps) {
</div>
)}
{/* Frame generation FPS config modal */}
{showFrameConfig && frameProject && (
<div className="fixed inset-0 z-50 flex items-center justify-center bg-black/60 backdrop-blur-sm">
<div className="bg-[#111] border border-white/10 rounded-2xl p-6 w-full max-w-md shadow-2xl">
<h2 className="text-lg font-semibold text-white mb-4"></h2>
<div className="space-y-4">
<div className="text-sm text-gray-400">: <span className="text-gray-200">{frameProject.name}</span></div>
<div>
<label className="block text-xs font-medium text-gray-400 uppercase tracking-widest mb-2"> (FPS)</label>
<div className="flex items-center gap-3">
<input
type="range"
min="1"
max="60"
value={frameParseFps}
onChange={(e) => setFrameParseFps(parseInt(e.target.value))}
className="flex-1 accent-cyan-500"
/>
<span className="text-sm font-mono text-cyan-400 w-12 text-right">{frameParseFps}</span>
</div>
<p className="text-[10px] text-gray-600 mt-1"></p>
</div>
</div>
<div className="flex justify-end gap-3 mt-6">
<button
onClick={() => { setShowFrameConfig(false); setFrameProject(null); }}
disabled={isGeneratingFrames}
className="px-4 py-2 rounded-lg text-sm text-gray-400 hover:text-white transition-colors disabled:opacity-50"
>
</button>
<button
onClick={handleGenerateFrames}
disabled={isGeneratingFrames}
className="px-4 py-2 rounded-lg text-sm font-medium bg-cyan-500 hover:bg-cyan-400 text-black transition-all disabled:opacity-60"
>
{isGeneratingFrames ? '入队中...' : '开始生成帧'}
</button>
</div>
</div>
</div>
)}
{/* New project modal */}
{showModal && (
<div className="fixed inset-0 z-50 flex items-center justify-center bg-black/60 backdrop-blur-sm">

View File

@@ -20,12 +20,14 @@ describe('ToolsPalette', () => {
);
fireEvent.click(screen.getByTitle('创建多边形 (P)'));
fireEvent.click(screen.getByTitle('调整多边形 (E)'));
fireEvent.click(screen.getByTitle('正向选点 (SAM)'));
fireEvent.click(screen.getByTitle('撤销操作 (Ctrl+Z)'));
fireEvent.click(screen.getByTitle('重做操作 (Ctrl+Shift+Z)'));
expect(setActiveTool).toHaveBeenNthCalledWith(1, 'create_polygon');
expect(setActiveTool).toHaveBeenNthCalledWith(2, 'point_pos');
expect(setActiveTool).toHaveBeenNthCalledWith(2, 'edit_polygon');
expect(setActiveTool).toHaveBeenNthCalledWith(3, 'point_pos');
expect(onUndo).toHaveBeenCalled();
expect(onRedo).toHaveBeenCalled();
});

View File

@@ -1,5 +1,5 @@
import React from 'react';
import { MousePointer2, Hexagon, Square, Circle, Minus, Combine, Scissors, Wand2, Undo, Redo, Crosshair, PlusCircle, MinusCircle, SquareDashed } from 'lucide-react';
import { MousePointer2, PencilLine, Hexagon, Square, Circle, Minus, Combine, Scissors, Wand2, Undo, Redo, Crosshair, PlusCircle, MinusCircle, SquareDashed } from 'lucide-react';
import { cn } from '../lib/utils';
interface ToolsPaletteProps {
@@ -23,6 +23,7 @@ export function ToolsPalette({
}: ToolsPaletteProps) {
const tools = [
{ id: 'move', icon: MousePointer2, label: '拖拽 / 选择 (V)' },
{ id: 'edit_polygon', icon: PencilLine, label: '调整多边形 (E)' },
{ id: 'create_polygon', icon: Hexagon, label: '创建多边形 (P)' },
{ id: 'create_rectangle', icon: Square, label: '创建矩形 (R)' },
{ id: 'create_circle', icon: Circle, label: '创建圆 (O)' },

View File

@@ -82,23 +82,16 @@ describe('VideoWorkspace', () => {
expect(apiMock.getProjectAnnotations).toHaveBeenCalledWith('1');
});
it('triggers parsing when a media project has no frames yet', async () => {
apiMock.getProjectFrames
.mockResolvedValueOnce([])
.mockResolvedValueOnce([
{ id: 11, project_id: 1, frame_index: 0, image_url: '/parsed.jpg', width: 320, height: 240 },
]);
apiMock.parseMedia.mockResolvedValueOnce({ id: 7, status: 'queued', progress: 0 });
apiMock.getTask.mockResolvedValueOnce({ id: 7, status: 'success', progress: 100, message: '解析完成' });
it('does not auto-generate frames when a media project has no frames yet', async () => {
apiMock.getProjectFrames.mockResolvedValueOnce([]);
render(<VideoWorkspace />);
await waitFor(() => expect(apiMock.parseMedia).toHaveBeenCalledWith('1'));
expect(apiMock.getTask).toHaveBeenCalledWith(7);
await waitFor(() => expect(useStore.getState().frames[0]).toEqual(expect.objectContaining({
id: '11',
url: '/parsed.jpg',
})));
await waitFor(() => expect(apiMock.getProjectFrames).toHaveBeenCalledWith('1'));
expect(apiMock.parseMedia).not.toHaveBeenCalled();
expect(apiMock.getTask).not.toHaveBeenCalled();
expect(useStore.getState().frames).toEqual([]);
expect(await screen.findByText('该项目已导入视频但尚未生成帧,请在项目库点击“生成帧”')).toBeInTheDocument();
});
it('hydrates saved annotations after loading frames', async () => {

View File

@@ -8,10 +8,8 @@ import {
exportMasks,
getProjectAnnotations,
getProjectFrames,
getTask,
getTemplates,
importGtMask,
parseMedia,
propagateMasks,
saveAnnotation,
updateAnnotation,
@@ -23,10 +21,6 @@ import { FrameTimeline } from './FrameTimeline';
import { ModelStatusBadge } from './ModelStatusBadge';
import type { Frame } from '../store/useStore';
function sleep(ms: number) {
return new Promise((resolve) => setTimeout(resolve, ms));
}
export function VideoWorkspace({ onNavigateToAI }: { onNavigateToAI?: () => void }) {
const gtMaskInputRef = React.useRef<HTMLInputElement>(null);
const activeTool = useStore((state) => state.activeTool);
@@ -72,64 +66,31 @@ export function VideoWorkspace({ onNavigateToAI }: { onNavigateToAI?: () => void
const data = await getProjectFrames(String(currentProject.id));
if (cancelled) return;
if (data.length === 0 && currentProject.video_path) {
// No frames yet but video exists -> queue parsing and poll the task.
try {
const task = await parseMedia(String(currentProject.id));
if (cancelled) return;
setStatusMessage(`解析任务已入队 #${task.id}`);
let completed = false;
for (let attempt = 0; attempt < 60; attempt += 1) {
const freshTask = await getTask(task.id);
if (cancelled) return;
setStatusMessage(freshTask.message || `解析进度 ${freshTask.progress}%`);
if (freshTask.status === 'success') {
completed = true;
break;
}
if (freshTask.status === 'failed') {
setStatusMessage(freshTask.error || '解析任务失败');
return;
}
await sleep(2000);
}
if (!completed) {
setStatusMessage('解析仍在后台运行,可稍后刷新工作区');
return;
}
const fresh = await getProjectFrames(String(currentProject.id));
if (cancelled) return;
const mappedFrames = fresh.map((f) => ({
id: String(f.id),
projectId: String(f.project_id),
index: f.frame_index,
url: f.image_url,
width: f.width ?? 0,
height: f.height ?? 0,
timestampMs: f.timestamp_ms ?? undefined,
sourceFrameNumber: f.source_frame_number ?? undefined,
}));
setFrames(mappedFrames);
setCurrentFrame(0);
await hydrateSavedAnnotations(String(currentProject.id), mappedFrames);
} catch (err) {
console.error('Parse failed:', err);
const mappedFrames = data.map((f) => ({
id: String(f.id),
projectId: String(f.project_id),
index: f.frame_index,
url: f.image_url,
width: f.width ?? 0,
height: f.height ?? 0,
timestampMs: f.timestamp_ms ?? undefined,
sourceFrameNumber: f.source_frame_number ?? undefined,
}));
setFrames(mappedFrames);
setCurrentFrame(0);
if (mappedFrames.length === 0) {
setMasks([]);
if (currentProject.status === 'parsing') {
setStatusMessage('生成帧任务正在后台运行,可在 Dashboard 查看进度');
} else if (currentProject.video_path) {
setStatusMessage('该项目已导入视频但尚未生成帧,请在项目库点击“生成帧”');
} else {
setStatusMessage('当前项目没有可显示帧');
}
} else {
const mappedFrames = data.map((f) => ({
id: String(f.id),
projectId: String(f.project_id),
index: f.frame_index,
url: f.image_url,
width: f.width ?? 0,
height: f.height ?? 0,
timestampMs: f.timestamp_ms ?? undefined,
sourceFrameNumber: f.source_frame_number ?? undefined,
}));
setFrames(mappedFrames);
setCurrentFrame(0);
await hydrateSavedAnnotations(String(currentProject.id), mappedFrames);
return;
}
setStatusMessage('');
await hydrateSavedAnnotations(String(currentProject.id), mappedFrames);
} catch (err) {
console.error('Failed to load frames:', err);
}

View File

@@ -2,12 +2,14 @@ import { afterEach, describe, expect, it, vi } from 'vitest';
describe('progress websocket client', () => {
afterEach(() => {
vi.useRealTimers();
vi.restoreAllMocks();
vi.resetModules();
vi.unstubAllGlobals();
});
it('connects using the configured URL and reports open state', async () => {
it('connects using the configured URL, reports open state, and sends heartbeat pings', async () => {
vi.useFakeTimers();
const instances: any[] = [];
class FakeWebSocket {
static CONNECTING = 0;
@@ -21,14 +23,26 @@ describe('progress websocket client', () => {
instances.push(this);
}
close = vi.fn();
send = vi.fn();
}
vi.stubGlobal('WebSocket', FakeWebSocket);
const { progressWS } = await import('./websocket');
const statusCallback = vi.fn();
const unsubscribeStatus = progressWS.onStatus(statusCallback);
progressWS.connect();
instances[0].onopen?.();
expect(instances[0].url).toContain('/ws/progress');
expect(progressWS.isConnected()).toBe(true);
expect(statusCallback).toHaveBeenCalledWith('connected');
expect(instances[0].send).toHaveBeenCalledWith('ping');
vi.advanceTimersByTime(15000);
expect(instances[0].send).toHaveBeenCalledTimes(2);
unsubscribeStatus();
progressWS.disconnect();
});
it('subscribes and unsubscribes progress callbacks', async () => {
@@ -43,4 +57,41 @@ describe('progress websocket client', () => {
expect(callback).toHaveBeenCalledTimes(1);
expect(callback).toHaveBeenCalledWith({ type: 'status', message: 'ok' });
});
it('notifies connection status changes and schedules reconnect on close', async () => {
vi.useFakeTimers();
const instances: any[] = [];
class FakeWebSocket {
static CONNECTING = 0;
static OPEN = 1;
readyState = FakeWebSocket.OPEN;
onopen?: () => void;
onmessage?: (event: MessageEvent) => void;
onclose?: () => void;
onerror?: () => void;
constructor(public url: string) {
instances.push(this);
}
close = vi.fn();
send = vi.fn();
}
vi.stubGlobal('WebSocket', FakeWebSocket);
const { progressWS } = await import('./websocket');
const statusCallback = vi.fn();
const unsubscribeStatus = progressWS.onStatus(statusCallback);
progressWS.connect();
instances[0].onopen?.();
instances[0].onclose?.();
expect(statusCallback).toHaveBeenCalledWith('disconnected');
expect(statusCallback).toHaveBeenCalledWith('reconnecting');
vi.advanceTimersByTime(3000);
expect(instances).toHaveLength(2);
unsubscribeStatus();
progressWS.disconnect();
});
});

View File

@@ -1,6 +1,8 @@
import { WS_PROGRESS_URL } from './config';
type ProgressCallback = (data: ProgressMessage) => void;
type ConnectionStatus = 'connecting' | 'connected' | 'reconnecting' | 'disconnected';
type StatusCallback = (status: ConnectionStatus) => void;
interface ProgressMessage {
type: 'progress' | 'status' | 'error' | 'complete' | 'cancelled';
@@ -20,9 +22,12 @@ class ProgressWebSocket {
private ws: WebSocket | null = null;
private url: string;
private callbacks: Set<ProgressCallback> = new Set();
private statusCallbacks: Set<StatusCallback> = new Set();
private reconnectTimer: ReturnType<typeof setTimeout> | null = null;
private heartbeatTimer: ReturnType<typeof setInterval> | null = null;
private reconnectInterval = 3000;
private maxReconnectInterval = 30000;
private heartbeatInterval = 15000;
private shouldReconnect = false;
private shouldCloseAfterOpen = false;
private currentInterval = 3000;
@@ -38,6 +43,7 @@ class ProgressWebSocket {
this.shouldReconnect = true;
this.shouldCloseAfterOpen = false;
this.notifyStatus('connecting');
try {
this.ws = new WebSocket(this.url);
@@ -50,6 +56,8 @@ class ProgressWebSocket {
return;
}
this.currentInterval = this.reconnectInterval;
this.startHeartbeat();
this.notifyStatus('connected');
console.log('[WebSocket] Connected to progress stream');
};
@@ -64,7 +72,9 @@ class ProgressWebSocket {
this.ws.onclose = () => {
console.log('[WebSocket] Connection closed');
this.stopHeartbeat();
this.ws = null;
this.notifyStatus('disconnected');
if (this.shouldReconnect) {
this.scheduleReconnect();
}
@@ -72,7 +82,9 @@ class ProgressWebSocket {
this.ws.onerror = () => {
// 静默处理错误,避免在 CONNECTING 状态时 close 触发浏览器报错
this.stopHeartbeat();
this.ws = null;
this.notifyStatus('disconnected');
if (this.shouldReconnect) {
this.scheduleReconnect();
}
@@ -85,6 +97,7 @@ class ProgressWebSocket {
disconnect() {
this.shouldReconnect = false;
this.stopHeartbeat();
if (this.reconnectTimer) {
clearTimeout(this.reconnectTimer);
this.reconnectTimer = null;
@@ -102,6 +115,7 @@ class ProgressWebSocket {
this.ws.close();
}
this.ws = null;
this.notifyStatus('disconnected');
}
onProgress(callback: ProgressCallback) {
@@ -111,21 +125,53 @@ class ProgressWebSocket {
};
}
onStatus(callback: StatusCallback) {
this.statusCallbacks.add(callback);
callback(this.isConnected() ? 'connected' : 'disconnected');
return () => {
this.statusCallbacks.delete(callback);
};
}
private scheduleReconnect() {
if (this.reconnectTimer) {
clearTimeout(this.reconnectTimer);
}
this.notifyStatus('reconnecting');
this.reconnectTimer = setTimeout(() => {
console.log(`[WebSocket] Reconnecting in ${this.currentInterval}ms...`);
console.log('[WebSocket] Reconnecting to progress stream...');
this.connect();
this.currentInterval = Math.min(this.currentInterval * 1.5, this.maxReconnectInterval);
}, this.currentInterval);
}
private startHeartbeat() {
this.stopHeartbeat();
this.sendHeartbeat();
this.heartbeatTimer = setInterval(() => this.sendHeartbeat(), this.heartbeatInterval);
}
private stopHeartbeat() {
if (this.heartbeatTimer) {
clearInterval(this.heartbeatTimer);
this.heartbeatTimer = null;
}
}
private sendHeartbeat() {
if (this.ws?.readyState === WebSocket.OPEN) {
this.ws.send('ping');
}
}
private notifyStatus(status: ConnectionStatus) {
this.statusCallbacks.forEach((cb) => cb(status));
}
isConnected(): boolean {
return this.ws !== null && this.ws.readyState === WebSocket.OPEN;
}
}
export const progressWS = new ProgressWebSocket();
export type { ProgressMessage };
export type { ConnectionStatus, ProgressMessage };

View File

@@ -102,6 +102,15 @@ vi.mock('react-konva', () => ({
props.onClick?.(konvaEvent);
if (konvaEvent.cancelBubble) event.stopPropagation();
}}
onDoubleClick={(event) => {
const point = {
x: event.clientX || 120,
y: event.clientY || 80,
};
const konvaEvent = { ...makeStageEvent(point.x, point.y), cancelBubble: false };
props.onDblClick?.(konvaEvent);
if (konvaEvent.cancelBubble) event.stopPropagation();
}}
/>
),
}));