feat: 完善 AI 分割与工作区标注闭环
功能增加: - 将视频导入和生成帧拆成两个明确动作,项目库生成帧时选择 FPS,工作区不再自动触发拆帧。 - 为工作区新增调整多边形工具,支持选中 mask、拖动顶点、边中点插点、双击边界按位置插点,并保留多 polygon 子区域编辑。 - 打通 AI 页 SAM2/SAM3 结果到工作区的联动,生成 mask 后自动选中,可在右侧分类树换标签,并推送到工作区继续编辑。 - 增强 Dashboard WebSocket 连接状态与心跳,使用真实 onopen/onclose/onerror 状态驱动前端显示。 - 完善 SAM3 external worker 适配,支持 box prompt、semantic 请求级阈值和 video tracker 路径。 bugfix: - 修复 SAM2 文本语义误走自动分割的问题,改为提示使用点提示或切换 SAM3。 - 修复 SAM2 多候选重叠显示的问题,点提示和 auto fallback 默认只采用最高分候选。 - 修复 SAM2 反向点看起来无效的问题,带负点时启用背景过滤,过滤为空时移除旧候选。 - 修复 SAM3 单个 2D mask 结果无法转 polygon、低阈值 semantic 返回被默认阈值吞掉的问题。 - 修复 AI 页 mask 未选中导致分类树无法修改 SAM2 结果标签的问题。 测试和文档: - 补充 CanvasArea、AISegmentation、ProjectLibrary、VideoWorkspace、Dashboard、websocket 和 SAM engine/API 测试。 - 新增 backend/tests/test_sam2_engine.py,覆盖 SAM2 单候选请求和 auto fallback 行为。 - 更新 README、AGENTS 和 doc 需求/设计/接口/测试矩阵,按当前实现冻结功能状态。
This commit is contained in:
@@ -63,6 +63,129 @@ describe('AISegmentation', () => {
|
||||
}));
|
||||
});
|
||||
|
||||
it('does not run SAM2 text-only prompts as semantic segmentation', async () => {
|
||||
render(<AISegmentation onSendToWorkspace={vi.fn()} />);
|
||||
|
||||
fireEvent.change(screen.getByPlaceholderText("例如:'分割出左侧车道上行驶的所有红色汽车'..."), {
|
||||
target: { value: '胆囊' },
|
||||
});
|
||||
fireEvent.click(await screen.findByText('执行高精度语义分割'));
|
||||
|
||||
expect(apiMock.predictMask).not.toHaveBeenCalled();
|
||||
expect(await screen.findByText('SAM2 不支持文本语义提示;请先放置正/反向点,或切换到 SAM3 使用文本语义。')).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it('keeps only the best SAM2 candidate when the backend returns overlapping alternatives', async () => {
|
||||
apiMock.predictMask.mockResolvedValueOnce({
|
||||
masks: [
|
||||
{
|
||||
id: 'sam2-best',
|
||||
pathData: 'M 0 0 L 10 0 L 10 10 Z',
|
||||
label: 'AI Mask',
|
||||
color: '#06b6d4',
|
||||
segmentation: [[0, 0, 10, 0, 10, 10]],
|
||||
bbox: [0, 0, 10, 10],
|
||||
area: 100,
|
||||
},
|
||||
{
|
||||
id: 'sam2-alt',
|
||||
pathData: 'M 1 1 L 11 1 L 11 11 Z',
|
||||
label: 'AI Mask',
|
||||
color: '#06b6d4',
|
||||
segmentation: [[1, 1, 11, 1, 11, 11]],
|
||||
bbox: [1, 1, 10, 10],
|
||||
area: 100,
|
||||
},
|
||||
],
|
||||
});
|
||||
|
||||
render(<AISegmentation onSendToWorkspace={vi.fn()} />);
|
||||
fireEvent.click(screen.getByText('正向选点'));
|
||||
fireEvent.click(screen.getByTestId('konva-stage'));
|
||||
fireEvent.click(await screen.findByText('执行高精度语义分割'));
|
||||
|
||||
await waitFor(() => expect(useStore.getState().masks).toHaveLength(1));
|
||||
expect(useStore.getState().masks[0].id).toBe('sam2-best');
|
||||
expect(useStore.getState().selectedMaskIds).toEqual(['sam2-best']);
|
||||
expect(await screen.findByText('SAM2 返回 2 个候选,已采用最高分区域。')).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it('lets a SAM2 result be selected and relabeled from the ontology panel', async () => {
|
||||
useStore.setState({
|
||||
templates: [
|
||||
{
|
||||
id: 'template-1',
|
||||
name: '腹腔镜模板',
|
||||
classes: [
|
||||
{ id: 'class-1', name: '胆囊', color: '#ff0000', zIndex: 30 },
|
||||
{ id: 'class-2', name: '肝脏', color: '#00ff00', zIndex: 20 },
|
||||
],
|
||||
rules: [],
|
||||
},
|
||||
],
|
||||
});
|
||||
apiMock.predictMask.mockResolvedValueOnce({
|
||||
masks: [
|
||||
{
|
||||
id: 'sam2-mask',
|
||||
pathData: 'M 10 10 L 40 10 L 40 40 Z',
|
||||
label: 'AI Mask',
|
||||
color: '#06b6d4',
|
||||
segmentation: [[10, 10, 40, 10, 40, 40]],
|
||||
bbox: [10, 10, 30, 30],
|
||||
area: 900,
|
||||
},
|
||||
],
|
||||
});
|
||||
|
||||
render(<AISegmentation onSendToWorkspace={vi.fn()} />);
|
||||
fireEvent.click(screen.getByText('正向选点'));
|
||||
fireEvent.click(screen.getByTestId('konva-stage'));
|
||||
fireEvent.click(await screen.findByText('执行高精度语义分割'));
|
||||
|
||||
await waitFor(() => expect(useStore.getState().selectedMaskIds).toEqual(['sam2-mask']));
|
||||
fireEvent.click(screen.getByText('肝脏'));
|
||||
|
||||
expect(useStore.getState().masks[0]).toEqual(expect.objectContaining({
|
||||
templateId: 'template-1',
|
||||
classId: 'class-2',
|
||||
className: '肝脏',
|
||||
classZIndex: 20,
|
||||
label: '肝脏',
|
||||
color: '#00ff00',
|
||||
saveStatus: 'draft',
|
||||
}));
|
||||
});
|
||||
|
||||
it('keeps the generated SAM2 mask selected when sending it to the workspace editor', async () => {
|
||||
const onSendToWorkspace = vi.fn();
|
||||
apiMock.predictMask.mockResolvedValueOnce({
|
||||
masks: [
|
||||
{
|
||||
id: 'sam2-mask',
|
||||
pathData: 'M 10 10 L 40 10 L 40 40 Z',
|
||||
label: 'AI Mask',
|
||||
color: '#06b6d4',
|
||||
segmentation: [[10, 10, 40, 10, 40, 40]],
|
||||
bbox: [10, 10, 30, 30],
|
||||
area: 900,
|
||||
},
|
||||
],
|
||||
});
|
||||
|
||||
render(<AISegmentation onSendToWorkspace={onSendToWorkspace} />);
|
||||
fireEvent.click(screen.getByText('正向选点'));
|
||||
fireEvent.click(screen.getByTestId('konva-stage'));
|
||||
fireEvent.click(await screen.findByText('执行高精度语义分割'));
|
||||
await waitFor(() => expect(useStore.getState().selectedMaskIds).toEqual(['sam2-mask']));
|
||||
|
||||
fireEvent.click(screen.getByText('推送至工作区编辑'));
|
||||
|
||||
expect(useStore.getState().activeTool).toBe('edit_polygon');
|
||||
expect(useStore.getState().selectedMaskIds).toEqual(['sam2-mask']);
|
||||
expect(onSendToWorkspace).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('prompts for semantic text before running SAM3 inference', async () => {
|
||||
apiMock.getAiModelStatus.mockResolvedValue({
|
||||
selected_model: 'sam3',
|
||||
@@ -106,7 +229,7 @@ describe('AISegmentation', () => {
|
||||
points: undefined,
|
||||
text: '胆囊',
|
||||
})));
|
||||
expect(await screen.findByText('模型没有返回可用区域,请换一个更具体的描述或调整提示。')).toBeInTheDocument();
|
||||
expect(await screen.findByText('SAM3 已完成语义推理,但没有返回区域。请尝试英文目标描述,或换到包含该目标的帧。当前提示: 胆囊')).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it('runs SAM3 semantic text inference and assigns the active class to returned masks', async () => {
|
||||
|
||||
@@ -17,6 +17,8 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
|
||||
const masks = useStore((state) => state.masks);
|
||||
const addMask = useStore((state) => state.addMask);
|
||||
const clearMasks = useStore((state) => state.clearMasks);
|
||||
const selectedMaskIds = useStore((state) => state.selectedMaskIds);
|
||||
const setSelectedMaskIds = useStore((state) => state.setSelectedMaskIds);
|
||||
const maskHistory = useStore((state) => state.maskHistory);
|
||||
const maskFuture = useStore((state) => state.maskFuture);
|
||||
const undoMasks = useStore((state) => state.undoMasks);
|
||||
@@ -97,6 +99,10 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
|
||||
setInferenceMessage('SAM3 当前使用文本语义提示,请先输入要分割的目标描述。');
|
||||
return;
|
||||
}
|
||||
if (aiModel === 'sam2' && textPrompt && points.length === 0) {
|
||||
setInferenceMessage('SAM2 不支持文本语义提示;请先放置正/反向点,或切换到 SAM3 使用文本语义。');
|
||||
return;
|
||||
}
|
||||
if (points.length === 0 && !textPrompt) {
|
||||
setInferenceMessage('请先放置正/反向提示点,或输入语义描述。');
|
||||
return;
|
||||
@@ -132,14 +138,22 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
|
||||
},
|
||||
});
|
||||
|
||||
if (result.masks.length === 0) {
|
||||
setInferenceMessage('模型没有返回可用区域,请换一个更具体的描述或调整提示。');
|
||||
const masksToApply = aiModel === 'sam2' ? result.masks.slice(0, 1) : result.masks;
|
||||
|
||||
if (masksToApply.length === 0) {
|
||||
setInferenceMessage(aiModel === 'sam3'
|
||||
? `SAM3 已完成语义推理,但没有返回区域。请尝试英文目标描述,或换到包含该目标的帧。当前提示: ${textPrompt}`
|
||||
: '模型没有返回可用区域,请换一个更具体的描述或调整提示。');
|
||||
} else {
|
||||
setInferenceMessage(`已生成 ${result.masks.length} 个候选区域。`);
|
||||
setInferenceMessage(aiModel === 'sam2' && result.masks.length > 1
|
||||
? `SAM2 返回 ${result.masks.length} 个候选,已采用最高分区域。`
|
||||
: `已生成 ${masksToApply.length} 个候选区域。`);
|
||||
}
|
||||
result.masks.forEach((m) => {
|
||||
const generatedMaskIds: string[] = [];
|
||||
masksToApply.forEach((m) => {
|
||||
const label = activeClass?.name || m.label;
|
||||
const color = activeClass?.color || m.color;
|
||||
generatedMaskIds.push(m.id);
|
||||
addMask({
|
||||
id: m.id,
|
||||
frameId: currentFrame.id,
|
||||
@@ -157,6 +171,9 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
|
||||
area: m.area,
|
||||
});
|
||||
});
|
||||
if (generatedMaskIds.length > 0) {
|
||||
setSelectedMaskIds(generatedMaskIds);
|
||||
}
|
||||
} catch (err) {
|
||||
console.error('AI inference failed:', err);
|
||||
const detail = (err as any)?.response?.data?.detail;
|
||||
@@ -164,7 +181,7 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
|
||||
} finally {
|
||||
setIsInferencing(false);
|
||||
}
|
||||
}, [activeClass, activeTemplateId, addMask, aiModel, autoDeleteBg, cropMode, currentFrame?.height, currentFrame?.id, currentFrame?.width, image?.height, image?.naturalHeight, image?.naturalWidth, image?.width, points, semanticText]);
|
||||
}, [activeClass, activeTemplateId, addMask, aiModel, autoDeleteBg, cropMode, currentFrame?.height, currentFrame?.id, currentFrame?.width, image?.height, image?.naturalHeight, image?.naturalWidth, image?.width, points, semanticText, setSelectedMaskIds]);
|
||||
|
||||
const handleStageClick = (e: any) => {
|
||||
if (effectiveTool === 'move') return;
|
||||
@@ -307,10 +324,13 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
|
||||
</div>
|
||||
)}
|
||||
<button
|
||||
onClick={onSendToWorkspace}
|
||||
onClick={() => {
|
||||
setActiveTool('edit_polygon');
|
||||
onSendToWorkspace();
|
||||
}}
|
||||
className="w-full py-3.5 rounded-lg flex items-center justify-center gap-2 transition-all font-medium tracking-wide text-xs uppercase bg-white/5 hover:bg-white/10 text-gray-300 border border-white/5 hover:border-white/10"
|
||||
>
|
||||
<SendToBack size={16} /> 退档推送至工作区重组
|
||||
<SendToBack size={16} /> 推送至工作区编辑
|
||||
</button>
|
||||
</div>
|
||||
</aside>
|
||||
@@ -376,12 +396,20 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
|
||||
|
||||
{/* AI Returned Masks */}
|
||||
{frameMasks.map((mask) => (
|
||||
<Group key={mask.id} opacity={0.45}>
|
||||
<Group key={mask.id} opacity={selectedMaskIds.includes(mask.id) ? 0.72 : 0.45}>
|
||||
<Path
|
||||
data={mask.pathData}
|
||||
fill={mask.color}
|
||||
stroke={mask.color}
|
||||
strokeWidth={1 / scale}
|
||||
strokeWidth={(selectedMaskIds.includes(mask.id) ? 2.5 : 1) / scale}
|
||||
onClick={(event: any) => {
|
||||
event.cancelBubble = true;
|
||||
setSelectedMaskIds([mask.id]);
|
||||
}}
|
||||
onTap={(event: any) => {
|
||||
event.cancelBubble = true;
|
||||
setSelectedMaskIds([mask.id]);
|
||||
}}
|
||||
/>
|
||||
</Group>
|
||||
))}
|
||||
|
||||
@@ -206,16 +206,58 @@ describe('CanvasArea', () => {
|
||||
{ x: 300, y: 150, type: 'neg' },
|
||||
],
|
||||
box: { x1: 120, y1: 80, x2: 260, y2: 200 },
|
||||
options: { auto_filter_background: true, min_score: 0.05 },
|
||||
}));
|
||||
expect(useStore.getState().masks).toHaveLength(1);
|
||||
expect(useStore.getState().masks[0]).toEqual(expect.objectContaining({
|
||||
id: 'mask-box',
|
||||
segmentation: [[30, 30, 70, 30, 70, 70]],
|
||||
points: [[150, 100]],
|
||||
metadata: expect.objectContaining({ promptPointCount: 2 }),
|
||||
metadata: expect.objectContaining({ promptPointCount: 2, promptNegativePointCount: 1 }),
|
||||
}));
|
||||
});
|
||||
|
||||
it('removes the SAM2 candidate when a negative point filters it out', async () => {
|
||||
apiMock.predictMask
|
||||
.mockResolvedValueOnce({
|
||||
masks: [
|
||||
{
|
||||
id: 'mask-box',
|
||||
pathData: 'M 10 10 L 90 10 L 90 90 Z',
|
||||
label: 'AI Mask',
|
||||
color: '#06b6d4',
|
||||
segmentation: [[10, 10, 90, 10, 90, 90]],
|
||||
bbox: [10, 10, 80, 80],
|
||||
area: 6400,
|
||||
},
|
||||
],
|
||||
})
|
||||
.mockResolvedValueOnce({ masks: [] });
|
||||
|
||||
const { rerender } = render(<CanvasArea activeTool="box_select" frame={frame} />);
|
||||
const stage = screen.getByTestId('konva-stage');
|
||||
fireEvent.mouseDown(stage, { clientX: 120, clientY: 80 });
|
||||
fireEvent.mouseMove(stage, { clientX: 260, clientY: 200 });
|
||||
fireEvent.mouseUp(stage, { clientX: 260, clientY: 200 });
|
||||
|
||||
await waitFor(() => expect(useStore.getState().masks).toHaveLength(1));
|
||||
|
||||
rerender(<CanvasArea activeTool="point_neg" frame={frame} />);
|
||||
fireEvent.click(stage, { clientX: 180, clientY: 120 });
|
||||
|
||||
await waitFor(() => expect(apiMock.predictMask).toHaveBeenNthCalledWith(2, {
|
||||
imageId: 'frame-1',
|
||||
imageWidth: 640,
|
||||
imageHeight: 360,
|
||||
model: 'sam2',
|
||||
points: [{ x: 180, y: 120, type: 'neg' }],
|
||||
box: { x1: 120, y1: 80, x2: 260, y2: 200 },
|
||||
options: { auto_filter_background: true, min_score: 0.05 },
|
||||
}));
|
||||
await waitFor(() => expect(useStore.getState().masks).toHaveLength(0));
|
||||
expect(await screen.findByText(/反向点已排除当前候选区域/)).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it('renders only masks that belong to the current frame', () => {
|
||||
useStore.setState({
|
||||
masks: [
|
||||
@@ -250,6 +292,28 @@ describe('CanvasArea', () => {
|
||||
await waitFor(() => expect(useStore.getState().selectedMaskIds).toEqual(['m1']));
|
||||
});
|
||||
|
||||
it('keeps a mask selected when opening the workspace polygon editor from AI results', () => {
|
||||
useStore.setState({
|
||||
selectedMaskIds: ['m1'],
|
||||
masks: [
|
||||
{
|
||||
id: 'm1',
|
||||
frameId: 'frame-1',
|
||||
pathData: 'M 0 0 L 10 0 L 10 10 Z',
|
||||
label: 'A',
|
||||
color: '#fff',
|
||||
segmentation: [[0, 0, 10, 0, 10, 10]],
|
||||
},
|
||||
],
|
||||
});
|
||||
|
||||
render(<CanvasArea activeTool="edit_polygon" frame={frame} />);
|
||||
|
||||
expect(useStore.getState().selectedMaskIds).toEqual(['m1']);
|
||||
expect(screen.getAllByTestId('konva-circle')
|
||||
.filter((element) => element.getAttribute('data-fill') === '#ffffff')).toHaveLength(3);
|
||||
});
|
||||
|
||||
it('renders imported GT seed points for editable point regions', () => {
|
||||
useStore.setState({
|
||||
masks: [
|
||||
@@ -415,6 +479,34 @@ describe('CanvasArea', () => {
|
||||
}));
|
||||
});
|
||||
|
||||
it('selects a polygon with the edit tool and inserts a vertex by double-clicking an edge', () => {
|
||||
useStore.setState({
|
||||
masks: [
|
||||
{
|
||||
id: 'draft-1',
|
||||
frameId: 'frame-1',
|
||||
pathData: 'M 10 10 L 90 10 L 90 40 Z',
|
||||
label: 'Draft',
|
||||
color: '#06b6d4',
|
||||
saveStatus: 'draft',
|
||||
segmentation: [[10, 10, 90, 10, 90, 40]],
|
||||
bbox: [10, 10, 80, 30],
|
||||
},
|
||||
],
|
||||
});
|
||||
|
||||
render(<CanvasArea activeTool="edit_polygon" frame={frame} />);
|
||||
const path = screen.getByTestId('konva-path');
|
||||
fireEvent.click(path);
|
||||
fireEvent.doubleClick(path, { clientX: 50, clientY: 10 });
|
||||
|
||||
expect(useStore.getState().masks[0]).toEqual(expect.objectContaining({
|
||||
segmentation: [[10, 10, 50, 10, 90, 10, 90, 40]],
|
||||
pathData: 'M 10 10 L 50 10 L 90 10 L 90 40 Z',
|
||||
saveStatus: 'draft',
|
||||
}));
|
||||
});
|
||||
|
||||
it('edits the selected polygon in a multi-polygon mask', () => {
|
||||
useStore.setState({
|
||||
masks: [
|
||||
|
||||
@@ -19,6 +19,7 @@ type PromptBox = { x1: number; y1: number; x2: number; y2: number };
|
||||
|
||||
const DRAG_MANUAL_TOOLS = new Set(['create_rectangle', 'create_circle', 'create_line']);
|
||||
const POLYGON_TOOL = 'create_polygon';
|
||||
const EDIT_POLYGON_TOOL = 'edit_polygon';
|
||||
const POINT_TOOL = 'create_point';
|
||||
const BOOLEAN_TOOLS = new Set(['area_merge', 'area_remove']);
|
||||
const POLYGON_CLOSE_RADIUS = 8;
|
||||
@@ -95,6 +96,32 @@ function pointDistance(a: CanvasPoint, b: CanvasPoint): number {
|
||||
return Math.hypot(a.x - b.x, a.y - b.y);
|
||||
}
|
||||
|
||||
function distanceToSegmentSquared(point: CanvasPoint, start: CanvasPoint, end: CanvasPoint): number {
|
||||
const dx = end.x - start.x;
|
||||
const dy = end.y - start.y;
|
||||
const lengthSquared = dx * dx + dy * dy;
|
||||
if (lengthSquared === 0) {
|
||||
return (point.x - start.x) ** 2 + (point.y - start.y) ** 2;
|
||||
}
|
||||
const t = clamp(((point.x - start.x) * dx + (point.y - start.y) * dy) / lengthSquared, 0, 1);
|
||||
const projected = { x: start.x + t * dx, y: start.y + t * dy };
|
||||
return (point.x - projected.x) ** 2 + (point.y - projected.y) ** 2;
|
||||
}
|
||||
|
||||
function nearestPolygonEdgeIndex(points: CanvasPoint[], point: CanvasPoint): number {
|
||||
return points.reduce((bestIndex, start, index) => {
|
||||
const end = points[(index + 1) % points.length];
|
||||
if (!end) return bestIndex;
|
||||
const bestStart = points[bestIndex];
|
||||
const bestEnd = points[(bestIndex + 1) % points.length];
|
||||
const currentDistance = distanceToSegmentSquared(point, start, end);
|
||||
const bestDistance = bestStart && bestEnd
|
||||
? distanceToSegmentSquared(point, bestStart, bestEnd)
|
||||
: Number.POSITIVE_INFINITY;
|
||||
return currentDistance < bestDistance ? index : bestIndex;
|
||||
}, 0);
|
||||
}
|
||||
|
||||
function segmentationArea(segmentation?: number[][]): number {
|
||||
return (segmentation || []).reduce((sum, polygon) => sum + polygonArea(flatPolygonToPoints(polygon)), 0);
|
||||
}
|
||||
@@ -210,10 +237,11 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
|
||||
const [manualStart, setManualStart] = useState<CanvasPoint | null>(null);
|
||||
const [manualCurrent, setManualCurrent] = useState<CanvasPoint | null>(null);
|
||||
const [polygonPoints, setPolygonPoints] = useState<CanvasPoint[]>([]);
|
||||
const [selectedMaskId, setSelectedMaskId] = useState<string | null>(null);
|
||||
const [selectedMaskIds, setSelectedMaskIds] = useState<string[]>([]);
|
||||
const [selectedMaskId, setSelectedMaskId] = useState<string | null>(() => useStore.getState().selectedMaskIds[0] || null);
|
||||
const [selectedMaskIds, setSelectedMaskIds] = useState<string[]>(() => useStore.getState().selectedMaskIds);
|
||||
const [selectedPolygonIndex, setSelectedPolygonIndex] = useState(0);
|
||||
const [selectedVertexIndex, setSelectedVertexIndex] = useState<number | null>(null);
|
||||
const previousFrameIdRef = useRef<string | undefined>(frame?.id);
|
||||
const [isInferencing, setIsInferencing] = useState(false);
|
||||
const [inferenceMessage, setInferenceMessage] = useState('');
|
||||
|
||||
@@ -253,6 +281,7 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
|
||||
const draftMaskCount = frameMasks.filter((mask) => !mask.annotationId).length;
|
||||
const dirtyMaskCount = frameMasks.filter((mask) => mask.saveStatus === 'dirty').length;
|
||||
const isBooleanTool = BOOLEAN_TOOLS.has(effectiveTool);
|
||||
const isPolygonEditTool = effectiveTool === 'move' || effectiveTool === EDIT_POLYGON_TOOL;
|
||||
|
||||
useEffect(() => {
|
||||
const handleResize = () => {
|
||||
@@ -273,11 +302,22 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
|
||||
setManualStart(null);
|
||||
setManualCurrent(null);
|
||||
setPolygonPoints([]);
|
||||
setSelectedVertexIndex(null);
|
||||
if (!isPolygonEditTool && !isBooleanTool) {
|
||||
setSelectedMaskId(null);
|
||||
setSelectedMaskIds([]);
|
||||
setSelectedPolygonIndex(0);
|
||||
}
|
||||
}, [effectiveTool, isBooleanTool, isPolygonEditTool]);
|
||||
|
||||
useEffect(() => {
|
||||
if (previousFrameIdRef.current === frame?.id) return;
|
||||
previousFrameIdRef.current = frame?.id;
|
||||
setSelectedMaskId(null);
|
||||
setSelectedMaskIds([]);
|
||||
setSelectedPolygonIndex(0);
|
||||
setSelectedVertexIndex(null);
|
||||
}, [effectiveTool, frame?.id]);
|
||||
}, [frame?.id]);
|
||||
|
||||
useEffect(() => {
|
||||
setPoints([]);
|
||||
@@ -420,6 +460,10 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
|
||||
setIsInferencing(true);
|
||||
setInferenceMessage('');
|
||||
try {
|
||||
const hasNegativePrompt = Boolean(promptPoints?.some((point) => point.type === 'neg'));
|
||||
const existingCandidate = !options.resetCandidate && samCandidateMaskId
|
||||
? masks.find((mask) => mask.id === samCandidateMaskId)
|
||||
: null;
|
||||
const result = await predictMask({
|
||||
imageId: frame.id,
|
||||
imageWidth,
|
||||
@@ -429,13 +473,11 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
|
||||
? promptPoints.map((p) => ({ x: p.x, y: p.y, type: p.type }))
|
||||
: undefined,
|
||||
box: promptBox,
|
||||
...(hasNegativePrompt ? { options: { auto_filter_background: true, min_score: 0.05 } } : {}),
|
||||
});
|
||||
|
||||
const [m] = result.masks;
|
||||
if (m) {
|
||||
const existingCandidate = !options.resetCandidate && samCandidateMaskId
|
||||
? masks.find((mask) => mask.id === samCandidateMaskId)
|
||||
: null;
|
||||
const label = activeClass?.name || existingCandidate?.label || m.label;
|
||||
const color = activeClass?.color || existingCandidate?.color || m.color;
|
||||
const metadata = {
|
||||
@@ -443,6 +485,7 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
|
||||
source: aiModel === 'sam3' ? 'sam3_box' : 'sam2_interactive',
|
||||
promptBox: promptBox || null,
|
||||
promptPointCount: promptPoints?.length || 0,
|
||||
promptNegativePointCount: promptPoints?.filter((point) => point.type === 'neg').length || 0,
|
||||
};
|
||||
const nextMask = {
|
||||
frameId: frame.id,
|
||||
@@ -476,7 +519,15 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
|
||||
});
|
||||
}
|
||||
} else {
|
||||
setInferenceMessage('模型没有返回可用区域,请调整点/框提示后重试。');
|
||||
if (existingCandidate && hasNegativePrompt) {
|
||||
setMasks(masks.filter((mask) => mask.id !== existingCandidate.id));
|
||||
setSamCandidateMaskId(null);
|
||||
setSelectedMaskId(null);
|
||||
setSelectedMaskIds([]);
|
||||
setInferenceMessage('反向点已排除当前候选区域,请重新框选或添加新的正向点。');
|
||||
} else {
|
||||
setInferenceMessage('模型没有返回可用区域,请调整点/框提示后重试。');
|
||||
}
|
||||
}
|
||||
} catch (err) {
|
||||
console.error('Inference failed:', err);
|
||||
@@ -485,7 +536,7 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
|
||||
} finally {
|
||||
setIsInferencing(false);
|
||||
}
|
||||
}, [activeClass, activeTemplateId, addMask, aiModel, frame?.height, frame?.id, frame?.width, image?.height, image?.naturalHeight, image?.naturalWidth, image?.width, masks, samCandidateMaskId, updateMask]);
|
||||
}, [activeClass, activeTemplateId, addMask, aiModel, frame?.height, frame?.id, frame?.width, image?.height, image?.naturalHeight, image?.naturalWidth, image?.width, masks, samCandidateMaskId, setMasks, updateMask]);
|
||||
|
||||
const handleApplyActiveClass = () => {
|
||||
if (!frame?.id || !activeClass) return;
|
||||
@@ -598,7 +649,7 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
|
||||
};
|
||||
|
||||
const handleStageClick = (e: any) => {
|
||||
if (effectiveTool === 'move') return;
|
||||
if (isPolygonEditTool) return;
|
||||
if (effectiveTool === 'box_select') return; // handled by mouseup
|
||||
if (DRAG_MANUAL_TOOLS.has(effectiveTool)) return;
|
||||
|
||||
@@ -716,7 +767,7 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
|
||||
|
||||
window.addEventListener('keydown', handleKeyDown);
|
||||
return () => window.removeEventListener('keydown', handleKeyDown);
|
||||
}, [deleteMasksById, effectiveTool, finishPolygon, polygonPoints, redoMasks, selectedMask, selectedMaskIds, selectedPolygonIndex, selectedVertexIndex, undoMasks, updatePolygonMask]);
|
||||
}, [deleteMasksById, effectiveTool, finishPolygon, isPolygonEditTool, polygonPoints, redoMasks, selectedMask, selectedMaskIds, selectedPolygonIndex, selectedVertexIndex, undoMasks, updatePolygonMask]);
|
||||
|
||||
const boxRect = React.useMemo(() => {
|
||||
if (!boxStart || !boxCurrent) return null;
|
||||
@@ -753,7 +804,7 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
|
||||
};
|
||||
|
||||
const handleMaskSelect = (mask: Mask, event: any, polygonIndex = 0) => {
|
||||
if (effectiveTool !== 'move' && !isBooleanTool) return;
|
||||
if (!isPolygonEditTool && !isBooleanTool) return;
|
||||
event.cancelBubble = true;
|
||||
if (isBooleanTool) {
|
||||
setSelectedMaskIds((current) => (
|
||||
@@ -807,6 +858,25 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
|
||||
updatePolygonMask(mask, nextPoints, selectedPolygonIndex);
|
||||
};
|
||||
|
||||
const handlePathDoubleClick = (mask: Mask, event: any, polygonIndex = 0) => {
|
||||
if (effectiveTool !== EDIT_POLYGON_TOOL) return;
|
||||
event.cancelBubble = true;
|
||||
const point = stagePoint(event);
|
||||
const currentPoints = segmentationToPoints(mask.segmentation, polygonIndex);
|
||||
if (!point || currentPoints.length < 3) return;
|
||||
const edgeIndex = nearestPolygonEdgeIndex(currentPoints, point);
|
||||
const nextPoints = [
|
||||
...currentPoints.slice(0, edgeIndex + 1),
|
||||
point,
|
||||
...currentPoints.slice(edgeIndex + 1),
|
||||
];
|
||||
setSelectedMaskId(mask.id);
|
||||
setSelectedMaskIds([mask.id]);
|
||||
setSelectedPolygonIndex(polygonIndex);
|
||||
setSelectedVertexIndex(edgeIndex + 1);
|
||||
updatePolygonMask(mask, nextPoints, polygonIndex);
|
||||
};
|
||||
|
||||
const handleBooleanOperation = async () => {
|
||||
if (!frame || booleanSelectedMasks.length < 2) return;
|
||||
const primary = booleanSelectedMasks[0];
|
||||
@@ -918,6 +988,8 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
|
||||
strokeWidth={(selectedMaskIds.includes(mask.id) ? 2 : 1) / scale}
|
||||
onClick={(event: any) => handleMaskSelect(mask, event, polygonIndex)}
|
||||
onTap={(event: any) => handleMaskSelect(mask, event, polygonIndex)}
|
||||
onDblClick={(event: any) => handlePathDoubleClick(mask, event, polygonIndex)}
|
||||
onDblTap={(event: any) => handlePathDoubleClick(mask, event, polygonIndex)}
|
||||
/>
|
||||
))}
|
||||
</Group>
|
||||
@@ -987,7 +1059,7 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
|
||||
)))}
|
||||
|
||||
{/* Polygon edge insertion handles */}
|
||||
{!isBooleanTool && selectedMask && selectedMaskPoints.map((point, index) => {
|
||||
{isPolygonEditTool && selectedMask && selectedMaskPoints.map((point, index) => {
|
||||
const next = selectedMaskPoints[(index + 1) % selectedMaskPoints.length];
|
||||
if (!next) return null;
|
||||
return (
|
||||
@@ -1006,7 +1078,7 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
|
||||
})}
|
||||
|
||||
{/* Polygon vertex editor */}
|
||||
{!isBooleanTool && selectedMask && selectedMaskPoints.map((point, index) => (
|
||||
{isPolygonEditTool && selectedMask && selectedMaskPoints.map((point, index) => (
|
||||
<Circle
|
||||
key={`${selectedMask.id}-vertex-${selectedPolygonIndex}-${index}`}
|
||||
x={point.x}
|
||||
|
||||
@@ -12,6 +12,7 @@ const apiMock = vi.hoisted(() => ({
|
||||
const wsMock = vi.hoisted(() => {
|
||||
const state = {
|
||||
callback: undefined as undefined | ((data: any) => void),
|
||||
statusCallback: undefined as undefined | ((status: any) => void),
|
||||
connected: false,
|
||||
};
|
||||
return {
|
||||
@@ -24,6 +25,11 @@ const wsMock = vi.hoisted(() => {
|
||||
state.callback = cb;
|
||||
return vi.fn();
|
||||
}),
|
||||
onStatus: vi.fn((cb: (status: any) => void) => {
|
||||
state.statusCallback = cb;
|
||||
cb(state.connected ? 'connected' : 'disconnected');
|
||||
return vi.fn();
|
||||
}),
|
||||
},
|
||||
};
|
||||
});
|
||||
@@ -45,6 +51,7 @@ describe('Dashboard', () => {
|
||||
vi.clearAllMocks();
|
||||
wsMock.state.connected = false;
|
||||
wsMock.state.callback = undefined;
|
||||
wsMock.state.statusCallback = undefined;
|
||||
apiMock.getDashboardOverview.mockResolvedValue({
|
||||
summary: {
|
||||
project_count: 2,
|
||||
@@ -109,6 +116,20 @@ describe('Dashboard', () => {
|
||||
expect(screen.getByText('44%')).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it('updates the websocket badge from connection status callbacks', async () => {
|
||||
render(<Dashboard />);
|
||||
|
||||
await waitFor(() => expect(wsMock.progressWS.onStatus).toHaveBeenCalled());
|
||||
expect(screen.getByText('WebSocket 断开')).toBeInTheDocument();
|
||||
|
||||
act(() => {
|
||||
wsMock.state.connected = true;
|
||||
wsMock.state.statusCallback?.('connected');
|
||||
});
|
||||
|
||||
expect(screen.getByText('WebSocket 已连接')).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it('adds activity logs for complete and status messages', async () => {
|
||||
render(<Dashboard />);
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import React, { useState, useEffect } from 'react';
|
||||
import { Activity, AlertTriangle, Clock, Folders, CheckCircle2, Info, Loader2, RotateCcw, XCircle } from 'lucide-react';
|
||||
import { progressWS, type ProgressMessage } from '../lib/websocket';
|
||||
import { progressWS, type ConnectionStatus, type ProgressMessage } from '../lib/websocket';
|
||||
import { cn } from '../lib/utils';
|
||||
import {
|
||||
cancelTask,
|
||||
@@ -178,6 +178,9 @@ export function Dashboard() {
|
||||
]);
|
||||
}
|
||||
});
|
||||
const unsubscribeStatus = progressWS.onStatus((status: ConnectionStatus) => {
|
||||
if (mounted) setIsConnected(status === 'connected');
|
||||
});
|
||||
|
||||
const checkConnection = setInterval(() => {
|
||||
if (mounted) setIsConnected(progressWS.isConnected());
|
||||
@@ -186,6 +189,7 @@ export function Dashboard() {
|
||||
return () => {
|
||||
mounted = false;
|
||||
unsubscribe();
|
||||
unsubscribeStatus();
|
||||
clearInterval(checkConnection);
|
||||
progressWS.disconnect();
|
||||
};
|
||||
|
||||
@@ -56,10 +56,9 @@ describe('ProjectLibrary', () => {
|
||||
expect(useStore.getState().projects[0]).toEqual(expect.objectContaining({ id: 'p2' }));
|
||||
});
|
||||
|
||||
it('imports video by creating a project, uploading media, parsing frames and refreshing projects', async () => {
|
||||
it('imports video by creating a project and uploading media without parsing frames', async () => {
|
||||
apiMock.createProject.mockResolvedValueOnce({ id: 'p3', name: 'clip.mp4', status: 'pending' });
|
||||
apiMock.uploadMedia.mockResolvedValueOnce({ url: 'http://file', id: 'object' });
|
||||
apiMock.parseMedia.mockResolvedValueOnce({ frames_extracted: 1 });
|
||||
apiMock.getProjects.mockResolvedValue([]);
|
||||
|
||||
const { container } = render(<ProjectLibrary onProjectSelect={vi.fn()} />);
|
||||
@@ -70,10 +69,24 @@ describe('ProjectLibrary', () => {
|
||||
|
||||
await waitFor(() => expect(apiMock.createProject).toHaveBeenCalledWith(expect.objectContaining({
|
||||
name: 'clip.mp4',
|
||||
parse_fps: 30,
|
||||
})));
|
||||
expect(apiMock.uploadMedia).toHaveBeenCalledWith(file, 'p3');
|
||||
expect(apiMock.parseMedia).toHaveBeenCalledWith('p3');
|
||||
expect(apiMock.parseMedia).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('generates frames from an imported video with the selected FPS', async () => {
|
||||
apiMock.getProjects
|
||||
.mockResolvedValueOnce([{ id: 'p4', name: 'clip.mp4', status: 'pending', frames: 0, video_path: 'uploads/clip.mp4', parse_fps: 30 }])
|
||||
.mockResolvedValueOnce([{ id: 'p4', name: 'clip.mp4', status: 'parsing', frames: 0, video_path: 'uploads/clip.mp4', parse_fps: 12 }]);
|
||||
apiMock.parseMedia.mockResolvedValueOnce({ id: 22, status: 'queued', progress: 0 });
|
||||
|
||||
const { container } = render(<ProjectLibrary onProjectSelect={vi.fn()} />);
|
||||
|
||||
fireEvent.click(await screen.findByRole('button', { name: '生成帧' }));
|
||||
fireEvent.change(container.querySelector('input[type="range"]') as HTMLInputElement, { target: { value: '12' } });
|
||||
fireEvent.click(screen.getByRole('button', { name: '开始生成帧' }));
|
||||
|
||||
await waitFor(() => expect(apiMock.parseMedia).toHaveBeenCalledWith('p4', { parseFps: 12 }));
|
||||
});
|
||||
|
||||
it('imports only valid DICOM files and parses the returned project', async () => {
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import React, { useState, useEffect, useRef } from 'react';
|
||||
import { UploadCloud, Film, Settings2, MoreHorizontal, Plus, Loader2, Activity } from 'lucide-react';
|
||||
import { UploadCloud, Film, Settings2, MoreHorizontal, Plus, Loader2, Activity, Images } from 'lucide-react';
|
||||
import { cn } from '../lib/utils';
|
||||
import { useStore } from '../store/useStore';
|
||||
import { getProjects, createProject, uploadMedia, parseMedia, uploadDicomBatch } from '../lib/api';
|
||||
@@ -22,7 +22,10 @@ export function ProjectLibrary({ onProjectSelect }: ProjectLibraryProps) {
|
||||
const [showImportMenu, setShowImportMenu] = useState(false);
|
||||
const [showVideoConfig, setShowVideoConfig] = useState(false);
|
||||
const [pendingFile, setPendingFile] = useState<File | null>(null);
|
||||
const [parseFps, setParseFps] = useState(30);
|
||||
const [frameProject, setFrameProject] = useState<Project | null>(null);
|
||||
const [showFrameConfig, setShowFrameConfig] = useState(false);
|
||||
const [frameParseFps, setFrameParseFps] = useState(30);
|
||||
const [isGeneratingFrames, setIsGeneratingFrames] = useState(false);
|
||||
const videoInputRef = useRef<HTMLInputElement>(null);
|
||||
const dicomInputRef = useRef<HTMLInputElement>(null);
|
||||
|
||||
@@ -57,7 +60,6 @@ export function ProjectLibrary({ onProjectSelect }: ProjectLibraryProps) {
|
||||
|
||||
const handleVideoSelect = (file: File) => {
|
||||
setPendingFile(file);
|
||||
setParseFps(30);
|
||||
setShowVideoConfig(true);
|
||||
};
|
||||
|
||||
@@ -69,11 +71,9 @@ export function ProjectLibrary({ onProjectSelect }: ProjectLibraryProps) {
|
||||
const newProject = await createProject({
|
||||
name: pendingFile.name,
|
||||
description: `导入于 ${new Date().toLocaleString()}`,
|
||||
parse_fps: parseFps,
|
||||
});
|
||||
const result = await uploadMedia(pendingFile, String(newProject.id));
|
||||
await parseMedia(String(newProject.id));
|
||||
alert(`上传成功: ${pendingFile.name}\n已保存至: ${result.url}`);
|
||||
alert(`视频导入成功: ${pendingFile.name}\n已保存至: ${result.url}\n需要生成帧时,请在项目卡片点击“生成帧”。`);
|
||||
const data = await getProjects();
|
||||
setProjects(data);
|
||||
} catch (err) {
|
||||
@@ -86,6 +86,31 @@ export function ProjectLibrary({ onProjectSelect }: ProjectLibraryProps) {
|
||||
}
|
||||
};
|
||||
|
||||
const openFrameConfig = (project: Project, event: React.MouseEvent) => {
|
||||
event.stopPropagation();
|
||||
setFrameProject(project);
|
||||
setFrameParseFps(Math.round(project.parse_fps || 30));
|
||||
setShowFrameConfig(true);
|
||||
};
|
||||
|
||||
const handleGenerateFrames = async () => {
|
||||
if (!frameProject?.id) return;
|
||||
setIsGeneratingFrames(true);
|
||||
try {
|
||||
const task = await parseMedia(frameProject.id, { parseFps: frameParseFps });
|
||||
alert(`生成帧任务已入队 #${task.id}\n帧率: ${frameParseFps} FPS\n可在 Dashboard 查看进度。`);
|
||||
const data = await getProjects();
|
||||
setProjects(data);
|
||||
setShowFrameConfig(false);
|
||||
setFrameProject(null);
|
||||
} catch (err) {
|
||||
console.error('Frame generation failed:', err);
|
||||
alert('生成帧失败,请检查后端服务或项目源文件');
|
||||
} finally {
|
||||
setIsGeneratingFrames(false);
|
||||
}
|
||||
};
|
||||
|
||||
const handleDicomUpload = async (files: FileList | null) => {
|
||||
if (!files || files.length === 0) return;
|
||||
const dcmFiles = Array.from(files).filter((f) => f.name.toLowerCase().endsWith('.dcm'));
|
||||
@@ -209,7 +234,7 @@ export function ProjectLibrary({ onProjectSelect }: ProjectLibraryProps) {
|
||||
)}
|
||||
<div className="absolute top-2 right-2 flex gap-2">
|
||||
<span className="backdrop-blur-md bg-black/40 text-gray-200 text-[10px] font-mono px-2 py-1 rounded border border-white/10 uppercase tracking-widest">
|
||||
{proj.source_type === 'dicom' ? 'DICOM' : (proj.fps || '30FPS')}
|
||||
{proj.source_type === 'dicom' ? 'DICOM' : (proj.video_path && (proj.frames ?? 0) === 0 ? '待生成帧' : (proj.fps || '30FPS'))}
|
||||
</span>
|
||||
<span className="backdrop-blur-md bg-black/40 text-gray-200 text-[10px] px-2 py-1 rounded border border-white/10 flex items-center gap-1 uppercase tracking-widest">
|
||||
{proj.status === 'ready' ? (
|
||||
@@ -235,6 +260,15 @@ export function ProjectLibrary({ onProjectSelect }: ProjectLibraryProps) {
|
||||
<span className="flex items-center gap-1.5 text-cyan-400/80"><Activity size={12} /> 原 {proj.original_fps.toFixed(1)}fps</span>
|
||||
)}
|
||||
</div>
|
||||
{proj.video_path && (proj.frames ?? 0) === 0 && proj.status !== 'parsing' && (
|
||||
<button
|
||||
onClick={(event) => openFrameConfig(proj, event)}
|
||||
className="mt-3 inline-flex items-center justify-center gap-2 rounded-md border border-cyan-500/30 bg-cyan-500/10 px-3 py-2 text-xs font-medium text-cyan-200 hover:bg-cyan-500/20 transition-colors"
|
||||
>
|
||||
<Images size={14} />
|
||||
生成帧
|
||||
</button>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
))}
|
||||
@@ -245,24 +279,10 @@ export function ProjectLibrary({ onProjectSelect }: ProjectLibraryProps) {
|
||||
{showVideoConfig && pendingFile && (
|
||||
<div className="fixed inset-0 z-50 flex items-center justify-center bg-black/60 backdrop-blur-sm">
|
||||
<div className="bg-[#111] border border-white/10 rounded-2xl p-6 w-full max-w-md shadow-2xl">
|
||||
<h2 className="text-lg font-semibold text-white mb-4">导入视频配置</h2>
|
||||
<h2 className="text-lg font-semibold text-white mb-4">导入视频文件</h2>
|
||||
<div className="space-y-4">
|
||||
<div className="text-sm text-gray-400">文件: <span className="text-gray-200">{pendingFile.name}</span></div>
|
||||
<div>
|
||||
<label className="block text-xs font-medium text-gray-400 uppercase tracking-widest mb-2">解析帧率 (FPS)</label>
|
||||
<div className="flex items-center gap-3">
|
||||
<input
|
||||
type="range"
|
||||
min="1"
|
||||
max="60"
|
||||
value={parseFps}
|
||||
onChange={(e) => setParseFps(parseInt(e.target.value))}
|
||||
className="flex-1 accent-cyan-500"
|
||||
/>
|
||||
<span className="text-sm font-mono text-cyan-400 w-12 text-right">{parseFps}</span>
|
||||
</div>
|
||||
<p className="text-[10px] text-gray-600 mt-1">帧率越低,提取的帧越少,处理速度越快</p>
|
||||
</div>
|
||||
<p className="text-xs leading-5 text-gray-500">此步骤只上传源视频并创建项目,不会立即拆帧。拆帧时再选择目标 FPS。</p>
|
||||
</div>
|
||||
<div className="flex justify-end gap-3 mt-6">
|
||||
<button
|
||||
@@ -282,6 +302,49 @@ export function ProjectLibrary({ onProjectSelect }: ProjectLibraryProps) {
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Frame generation FPS config modal */}
|
||||
{showFrameConfig && frameProject && (
|
||||
<div className="fixed inset-0 z-50 flex items-center justify-center bg-black/60 backdrop-blur-sm">
|
||||
<div className="bg-[#111] border border-white/10 rounded-2xl p-6 w-full max-w-md shadow-2xl">
|
||||
<h2 className="text-lg font-semibold text-white mb-4">生成帧</h2>
|
||||
<div className="space-y-4">
|
||||
<div className="text-sm text-gray-400">项目: <span className="text-gray-200">{frameProject.name}</span></div>
|
||||
<div>
|
||||
<label className="block text-xs font-medium text-gray-400 uppercase tracking-widest mb-2">生成帧率 (FPS)</label>
|
||||
<div className="flex items-center gap-3">
|
||||
<input
|
||||
type="range"
|
||||
min="1"
|
||||
max="60"
|
||||
value={frameParseFps}
|
||||
onChange={(e) => setFrameParseFps(parseInt(e.target.value))}
|
||||
className="flex-1 accent-cyan-500"
|
||||
/>
|
||||
<span className="text-sm font-mono text-cyan-400 w-12 text-right">{frameParseFps}</span>
|
||||
</div>
|
||||
<p className="text-[10px] text-gray-600 mt-1">帧率越低,提取的帧越少,处理速度越快</p>
|
||||
</div>
|
||||
</div>
|
||||
<div className="flex justify-end gap-3 mt-6">
|
||||
<button
|
||||
onClick={() => { setShowFrameConfig(false); setFrameProject(null); }}
|
||||
disabled={isGeneratingFrames}
|
||||
className="px-4 py-2 rounded-lg text-sm text-gray-400 hover:text-white transition-colors disabled:opacity-50"
|
||||
>
|
||||
取消
|
||||
</button>
|
||||
<button
|
||||
onClick={handleGenerateFrames}
|
||||
disabled={isGeneratingFrames}
|
||||
className="px-4 py-2 rounded-lg text-sm font-medium bg-cyan-500 hover:bg-cyan-400 text-black transition-all disabled:opacity-60"
|
||||
>
|
||||
{isGeneratingFrames ? '入队中...' : '开始生成帧'}
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* New project modal */}
|
||||
{showModal && (
|
||||
<div className="fixed inset-0 z-50 flex items-center justify-center bg-black/60 backdrop-blur-sm">
|
||||
|
||||
@@ -20,12 +20,14 @@ describe('ToolsPalette', () => {
|
||||
);
|
||||
|
||||
fireEvent.click(screen.getByTitle('创建多边形 (P)'));
|
||||
fireEvent.click(screen.getByTitle('调整多边形 (E)'));
|
||||
fireEvent.click(screen.getByTitle('正向选点 (SAM)'));
|
||||
fireEvent.click(screen.getByTitle('撤销操作 (Ctrl+Z)'));
|
||||
fireEvent.click(screen.getByTitle('重做操作 (Ctrl+Shift+Z)'));
|
||||
|
||||
expect(setActiveTool).toHaveBeenNthCalledWith(1, 'create_polygon');
|
||||
expect(setActiveTool).toHaveBeenNthCalledWith(2, 'point_pos');
|
||||
expect(setActiveTool).toHaveBeenNthCalledWith(2, 'edit_polygon');
|
||||
expect(setActiveTool).toHaveBeenNthCalledWith(3, 'point_pos');
|
||||
expect(onUndo).toHaveBeenCalled();
|
||||
expect(onRedo).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import React from 'react';
|
||||
import { MousePointer2, Hexagon, Square, Circle, Minus, Combine, Scissors, Wand2, Undo, Redo, Crosshair, PlusCircle, MinusCircle, SquareDashed } from 'lucide-react';
|
||||
import { MousePointer2, PencilLine, Hexagon, Square, Circle, Minus, Combine, Scissors, Wand2, Undo, Redo, Crosshair, PlusCircle, MinusCircle, SquareDashed } from 'lucide-react';
|
||||
import { cn } from '../lib/utils';
|
||||
|
||||
interface ToolsPaletteProps {
|
||||
@@ -23,6 +23,7 @@ export function ToolsPalette({
|
||||
}: ToolsPaletteProps) {
|
||||
const tools = [
|
||||
{ id: 'move', icon: MousePointer2, label: '拖拽 / 选择 (V)' },
|
||||
{ id: 'edit_polygon', icon: PencilLine, label: '调整多边形 (E)' },
|
||||
{ id: 'create_polygon', icon: Hexagon, label: '创建多边形 (P)' },
|
||||
{ id: 'create_rectangle', icon: Square, label: '创建矩形 (R)' },
|
||||
{ id: 'create_circle', icon: Circle, label: '创建圆 (O)' },
|
||||
|
||||
@@ -82,23 +82,16 @@ describe('VideoWorkspace', () => {
|
||||
expect(apiMock.getProjectAnnotations).toHaveBeenCalledWith('1');
|
||||
});
|
||||
|
||||
it('triggers parsing when a media project has no frames yet', async () => {
|
||||
apiMock.getProjectFrames
|
||||
.mockResolvedValueOnce([])
|
||||
.mockResolvedValueOnce([
|
||||
{ id: 11, project_id: 1, frame_index: 0, image_url: '/parsed.jpg', width: 320, height: 240 },
|
||||
]);
|
||||
apiMock.parseMedia.mockResolvedValueOnce({ id: 7, status: 'queued', progress: 0 });
|
||||
apiMock.getTask.mockResolvedValueOnce({ id: 7, status: 'success', progress: 100, message: '解析完成' });
|
||||
it('does not auto-generate frames when a media project has no frames yet', async () => {
|
||||
apiMock.getProjectFrames.mockResolvedValueOnce([]);
|
||||
|
||||
render(<VideoWorkspace />);
|
||||
|
||||
await waitFor(() => expect(apiMock.parseMedia).toHaveBeenCalledWith('1'));
|
||||
expect(apiMock.getTask).toHaveBeenCalledWith(7);
|
||||
await waitFor(() => expect(useStore.getState().frames[0]).toEqual(expect.objectContaining({
|
||||
id: '11',
|
||||
url: '/parsed.jpg',
|
||||
})));
|
||||
await waitFor(() => expect(apiMock.getProjectFrames).toHaveBeenCalledWith('1'));
|
||||
expect(apiMock.parseMedia).not.toHaveBeenCalled();
|
||||
expect(apiMock.getTask).not.toHaveBeenCalled();
|
||||
expect(useStore.getState().frames).toEqual([]);
|
||||
expect(await screen.findByText('该项目已导入视频但尚未生成帧,请在项目库点击“生成帧”')).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it('hydrates saved annotations after loading frames', async () => {
|
||||
|
||||
@@ -8,10 +8,8 @@ import {
|
||||
exportMasks,
|
||||
getProjectAnnotations,
|
||||
getProjectFrames,
|
||||
getTask,
|
||||
getTemplates,
|
||||
importGtMask,
|
||||
parseMedia,
|
||||
propagateMasks,
|
||||
saveAnnotation,
|
||||
updateAnnotation,
|
||||
@@ -23,10 +21,6 @@ import { FrameTimeline } from './FrameTimeline';
|
||||
import { ModelStatusBadge } from './ModelStatusBadge';
|
||||
import type { Frame } from '../store/useStore';
|
||||
|
||||
function sleep(ms: number) {
|
||||
return new Promise((resolve) => setTimeout(resolve, ms));
|
||||
}
|
||||
|
||||
export function VideoWorkspace({ onNavigateToAI }: { onNavigateToAI?: () => void }) {
|
||||
const gtMaskInputRef = React.useRef<HTMLInputElement>(null);
|
||||
const activeTool = useStore((state) => state.activeTool);
|
||||
@@ -72,64 +66,31 @@ export function VideoWorkspace({ onNavigateToAI }: { onNavigateToAI?: () => void
|
||||
const data = await getProjectFrames(String(currentProject.id));
|
||||
if (cancelled) return;
|
||||
|
||||
if (data.length === 0 && currentProject.video_path) {
|
||||
// No frames yet but video exists -> queue parsing and poll the task.
|
||||
try {
|
||||
const task = await parseMedia(String(currentProject.id));
|
||||
if (cancelled) return;
|
||||
setStatusMessage(`解析任务已入队 #${task.id}`);
|
||||
let completed = false;
|
||||
for (let attempt = 0; attempt < 60; attempt += 1) {
|
||||
const freshTask = await getTask(task.id);
|
||||
if (cancelled) return;
|
||||
setStatusMessage(freshTask.message || `解析进度 ${freshTask.progress}%`);
|
||||
if (freshTask.status === 'success') {
|
||||
completed = true;
|
||||
break;
|
||||
}
|
||||
if (freshTask.status === 'failed') {
|
||||
setStatusMessage(freshTask.error || '解析任务失败');
|
||||
return;
|
||||
}
|
||||
await sleep(2000);
|
||||
}
|
||||
if (!completed) {
|
||||
setStatusMessage('解析仍在后台运行,可稍后刷新工作区');
|
||||
return;
|
||||
}
|
||||
const fresh = await getProjectFrames(String(currentProject.id));
|
||||
if (cancelled) return;
|
||||
const mappedFrames = fresh.map((f) => ({
|
||||
id: String(f.id),
|
||||
projectId: String(f.project_id),
|
||||
index: f.frame_index,
|
||||
url: f.image_url,
|
||||
width: f.width ?? 0,
|
||||
height: f.height ?? 0,
|
||||
timestampMs: f.timestamp_ms ?? undefined,
|
||||
sourceFrameNumber: f.source_frame_number ?? undefined,
|
||||
}));
|
||||
setFrames(mappedFrames);
|
||||
setCurrentFrame(0);
|
||||
await hydrateSavedAnnotations(String(currentProject.id), mappedFrames);
|
||||
} catch (err) {
|
||||
console.error('Parse failed:', err);
|
||||
const mappedFrames = data.map((f) => ({
|
||||
id: String(f.id),
|
||||
projectId: String(f.project_id),
|
||||
index: f.frame_index,
|
||||
url: f.image_url,
|
||||
width: f.width ?? 0,
|
||||
height: f.height ?? 0,
|
||||
timestampMs: f.timestamp_ms ?? undefined,
|
||||
sourceFrameNumber: f.source_frame_number ?? undefined,
|
||||
}));
|
||||
setFrames(mappedFrames);
|
||||
setCurrentFrame(0);
|
||||
if (mappedFrames.length === 0) {
|
||||
setMasks([]);
|
||||
if (currentProject.status === 'parsing') {
|
||||
setStatusMessage('生成帧任务正在后台运行,可在 Dashboard 查看进度');
|
||||
} else if (currentProject.video_path) {
|
||||
setStatusMessage('该项目已导入视频但尚未生成帧,请在项目库点击“生成帧”');
|
||||
} else {
|
||||
setStatusMessage('当前项目没有可显示帧');
|
||||
}
|
||||
} else {
|
||||
const mappedFrames = data.map((f) => ({
|
||||
id: String(f.id),
|
||||
projectId: String(f.project_id),
|
||||
index: f.frame_index,
|
||||
url: f.image_url,
|
||||
width: f.width ?? 0,
|
||||
height: f.height ?? 0,
|
||||
timestampMs: f.timestamp_ms ?? undefined,
|
||||
sourceFrameNumber: f.source_frame_number ?? undefined,
|
||||
}));
|
||||
setFrames(mappedFrames);
|
||||
setCurrentFrame(0);
|
||||
await hydrateSavedAnnotations(String(currentProject.id), mappedFrames);
|
||||
return;
|
||||
}
|
||||
setStatusMessage('');
|
||||
await hydrateSavedAnnotations(String(currentProject.id), mappedFrames);
|
||||
} catch (err) {
|
||||
console.error('Failed to load frames:', err);
|
||||
}
|
||||
|
||||
@@ -2,12 +2,14 @@ import { afterEach, describe, expect, it, vi } from 'vitest';
|
||||
|
||||
describe('progress websocket client', () => {
|
||||
afterEach(() => {
|
||||
vi.useRealTimers();
|
||||
vi.restoreAllMocks();
|
||||
vi.resetModules();
|
||||
vi.unstubAllGlobals();
|
||||
});
|
||||
|
||||
it('connects using the configured URL and reports open state', async () => {
|
||||
it('connects using the configured URL, reports open state, and sends heartbeat pings', async () => {
|
||||
vi.useFakeTimers();
|
||||
const instances: any[] = [];
|
||||
class FakeWebSocket {
|
||||
static CONNECTING = 0;
|
||||
@@ -21,14 +23,26 @@ describe('progress websocket client', () => {
|
||||
instances.push(this);
|
||||
}
|
||||
close = vi.fn();
|
||||
send = vi.fn();
|
||||
}
|
||||
vi.stubGlobal('WebSocket', FakeWebSocket);
|
||||
|
||||
const { progressWS } = await import('./websocket');
|
||||
const statusCallback = vi.fn();
|
||||
const unsubscribeStatus = progressWS.onStatus(statusCallback);
|
||||
progressWS.connect();
|
||||
instances[0].onopen?.();
|
||||
|
||||
expect(instances[0].url).toContain('/ws/progress');
|
||||
expect(progressWS.isConnected()).toBe(true);
|
||||
expect(statusCallback).toHaveBeenCalledWith('connected');
|
||||
expect(instances[0].send).toHaveBeenCalledWith('ping');
|
||||
|
||||
vi.advanceTimersByTime(15000);
|
||||
expect(instances[0].send).toHaveBeenCalledTimes(2);
|
||||
|
||||
unsubscribeStatus();
|
||||
progressWS.disconnect();
|
||||
});
|
||||
|
||||
it('subscribes and unsubscribes progress callbacks', async () => {
|
||||
@@ -43,4 +57,41 @@ describe('progress websocket client', () => {
|
||||
expect(callback).toHaveBeenCalledTimes(1);
|
||||
expect(callback).toHaveBeenCalledWith({ type: 'status', message: 'ok' });
|
||||
});
|
||||
|
||||
it('notifies connection status changes and schedules reconnect on close', async () => {
|
||||
vi.useFakeTimers();
|
||||
const instances: any[] = [];
|
||||
class FakeWebSocket {
|
||||
static CONNECTING = 0;
|
||||
static OPEN = 1;
|
||||
readyState = FakeWebSocket.OPEN;
|
||||
onopen?: () => void;
|
||||
onmessage?: (event: MessageEvent) => void;
|
||||
onclose?: () => void;
|
||||
onerror?: () => void;
|
||||
constructor(public url: string) {
|
||||
instances.push(this);
|
||||
}
|
||||
close = vi.fn();
|
||||
send = vi.fn();
|
||||
}
|
||||
vi.stubGlobal('WebSocket', FakeWebSocket);
|
||||
|
||||
const { progressWS } = await import('./websocket');
|
||||
const statusCallback = vi.fn();
|
||||
const unsubscribeStatus = progressWS.onStatus(statusCallback);
|
||||
|
||||
progressWS.connect();
|
||||
instances[0].onopen?.();
|
||||
instances[0].onclose?.();
|
||||
|
||||
expect(statusCallback).toHaveBeenCalledWith('disconnected');
|
||||
expect(statusCallback).toHaveBeenCalledWith('reconnecting');
|
||||
|
||||
vi.advanceTimersByTime(3000);
|
||||
expect(instances).toHaveLength(2);
|
||||
|
||||
unsubscribeStatus();
|
||||
progressWS.disconnect();
|
||||
});
|
||||
});
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
import { WS_PROGRESS_URL } from './config';
|
||||
|
||||
type ProgressCallback = (data: ProgressMessage) => void;
|
||||
type ConnectionStatus = 'connecting' | 'connected' | 'reconnecting' | 'disconnected';
|
||||
type StatusCallback = (status: ConnectionStatus) => void;
|
||||
|
||||
interface ProgressMessage {
|
||||
type: 'progress' | 'status' | 'error' | 'complete' | 'cancelled';
|
||||
@@ -20,9 +22,12 @@ class ProgressWebSocket {
|
||||
private ws: WebSocket | null = null;
|
||||
private url: string;
|
||||
private callbacks: Set<ProgressCallback> = new Set();
|
||||
private statusCallbacks: Set<StatusCallback> = new Set();
|
||||
private reconnectTimer: ReturnType<typeof setTimeout> | null = null;
|
||||
private heartbeatTimer: ReturnType<typeof setInterval> | null = null;
|
||||
private reconnectInterval = 3000;
|
||||
private maxReconnectInterval = 30000;
|
||||
private heartbeatInterval = 15000;
|
||||
private shouldReconnect = false;
|
||||
private shouldCloseAfterOpen = false;
|
||||
private currentInterval = 3000;
|
||||
@@ -38,6 +43,7 @@ class ProgressWebSocket {
|
||||
|
||||
this.shouldReconnect = true;
|
||||
this.shouldCloseAfterOpen = false;
|
||||
this.notifyStatus('connecting');
|
||||
|
||||
try {
|
||||
this.ws = new WebSocket(this.url);
|
||||
@@ -50,6 +56,8 @@ class ProgressWebSocket {
|
||||
return;
|
||||
}
|
||||
this.currentInterval = this.reconnectInterval;
|
||||
this.startHeartbeat();
|
||||
this.notifyStatus('connected');
|
||||
console.log('[WebSocket] Connected to progress stream');
|
||||
};
|
||||
|
||||
@@ -64,7 +72,9 @@ class ProgressWebSocket {
|
||||
|
||||
this.ws.onclose = () => {
|
||||
console.log('[WebSocket] Connection closed');
|
||||
this.stopHeartbeat();
|
||||
this.ws = null;
|
||||
this.notifyStatus('disconnected');
|
||||
if (this.shouldReconnect) {
|
||||
this.scheduleReconnect();
|
||||
}
|
||||
@@ -72,7 +82,9 @@ class ProgressWebSocket {
|
||||
|
||||
this.ws.onerror = () => {
|
||||
// 静默处理错误,避免在 CONNECTING 状态时 close 触发浏览器报错
|
||||
this.stopHeartbeat();
|
||||
this.ws = null;
|
||||
this.notifyStatus('disconnected');
|
||||
if (this.shouldReconnect) {
|
||||
this.scheduleReconnect();
|
||||
}
|
||||
@@ -85,6 +97,7 @@ class ProgressWebSocket {
|
||||
|
||||
disconnect() {
|
||||
this.shouldReconnect = false;
|
||||
this.stopHeartbeat();
|
||||
if (this.reconnectTimer) {
|
||||
clearTimeout(this.reconnectTimer);
|
||||
this.reconnectTimer = null;
|
||||
@@ -102,6 +115,7 @@ class ProgressWebSocket {
|
||||
this.ws.close();
|
||||
}
|
||||
this.ws = null;
|
||||
this.notifyStatus('disconnected');
|
||||
}
|
||||
|
||||
onProgress(callback: ProgressCallback) {
|
||||
@@ -111,21 +125,53 @@ class ProgressWebSocket {
|
||||
};
|
||||
}
|
||||
|
||||
onStatus(callback: StatusCallback) {
|
||||
this.statusCallbacks.add(callback);
|
||||
callback(this.isConnected() ? 'connected' : 'disconnected');
|
||||
return () => {
|
||||
this.statusCallbacks.delete(callback);
|
||||
};
|
||||
}
|
||||
|
||||
private scheduleReconnect() {
|
||||
if (this.reconnectTimer) {
|
||||
clearTimeout(this.reconnectTimer);
|
||||
}
|
||||
this.notifyStatus('reconnecting');
|
||||
this.reconnectTimer = setTimeout(() => {
|
||||
console.log(`[WebSocket] Reconnecting in ${this.currentInterval}ms...`);
|
||||
console.log('[WebSocket] Reconnecting to progress stream...');
|
||||
this.connect();
|
||||
this.currentInterval = Math.min(this.currentInterval * 1.5, this.maxReconnectInterval);
|
||||
}, this.currentInterval);
|
||||
}
|
||||
|
||||
private startHeartbeat() {
|
||||
this.stopHeartbeat();
|
||||
this.sendHeartbeat();
|
||||
this.heartbeatTimer = setInterval(() => this.sendHeartbeat(), this.heartbeatInterval);
|
||||
}
|
||||
|
||||
private stopHeartbeat() {
|
||||
if (this.heartbeatTimer) {
|
||||
clearInterval(this.heartbeatTimer);
|
||||
this.heartbeatTimer = null;
|
||||
}
|
||||
}
|
||||
|
||||
private sendHeartbeat() {
|
||||
if (this.ws?.readyState === WebSocket.OPEN) {
|
||||
this.ws.send('ping');
|
||||
}
|
||||
}
|
||||
|
||||
private notifyStatus(status: ConnectionStatus) {
|
||||
this.statusCallbacks.forEach((cb) => cb(status));
|
||||
}
|
||||
|
||||
isConnected(): boolean {
|
||||
return this.ws !== null && this.ws.readyState === WebSocket.OPEN;
|
||||
}
|
||||
}
|
||||
|
||||
export const progressWS = new ProgressWebSocket();
|
||||
export type { ProgressMessage };
|
||||
export type { ConnectionStatus, ProgressMessage };
|
||||
|
||||
@@ -102,6 +102,15 @@ vi.mock('react-konva', () => ({
|
||||
props.onClick?.(konvaEvent);
|
||||
if (konvaEvent.cancelBubble) event.stopPropagation();
|
||||
}}
|
||||
onDoubleClick={(event) => {
|
||||
const point = {
|
||||
x: event.clientX || 120,
|
||||
y: event.clientY || 80,
|
||||
};
|
||||
const konvaEvent = { ...makeStageEvent(point.x, point.y), cancelBubble: false };
|
||||
props.onDblClick?.(konvaEvent);
|
||||
if (konvaEvent.cancelBubble) event.stopPropagation();
|
||||
}}
|
||||
/>
|
||||
),
|
||||
}));
|
||||
|
||||
Reference in New Issue
Block a user