feat: 完善 SAM2.1 模型选择与标注工作流

- 后端 SAM2 引擎新增 sam2.1_hiera_tiny、sam2.1_hiera_small、sam2.1_hiera_base_plus、sam2.1_hiera_large 四个变体定义,并按变体维护 checkpoint/config、image predictor、video predictor、加载状态、错误信息和真实状态回报。

- 后端 SAM registry 仅暴露当前产品启用的 SAM2.1 变体,保留 sam2 作为 tiny 兼容别名,拒绝 sam3 产品入口,并把 point、box、interactive、auto、propagate 都分发到所选 SAM2.1 变体。

- 后端默认配置和下载脚本切换到 SAM2.1 checkpoint 命名,支持 legacy SAM2 checkpoint fallback,并在状态消息中标出 fallback 使用情况。

- 前端全局 AI 模型状态新增 SAM2.1 tiny/small/base+/large 类型和默认 tiny,API 请求默认携带 sam2.1_hiera_tiny,AI 页面提供模型变体选择和所选模型状态展示。

- AI 智能分割页移除当前产品不使用的 SAM3/文本提示入口,保留正向点、反向点、框选和参数开关;AI 页只展示本页生成的候选 mask,并支持遮罩清晰度调节、候选 mask 上继续加正/反点、清空本页候选、推送到工作区编辑。

- 工作区和 Canvas 补强 SAM2 交互式细化链路:框选后正/反点继续细化同一个候选 mask,反向点请求启用背景过滤,空结果会移除被否定候选;AI 推送到工作区后保留选中态和未保存 draft mask。

- 工作区标注保存闭环补强:未保存 mask 可归档保存,dirty saved mask 可更新,保存后用后端 saved annotation 替换已提交 draft,清空/删除已保存 mask 时同步后端删除。

- Dashboard 任务进度区改为展示 queued、running、success、failed、cancelled 最近任务,处理中统计只计算 queued/running,并保留近期完成记录。

- 时间轴在顶部时间进度条和底部缩略图导航轴之间新增已编辑帧标记带,基于当前项目帧内 masks 标出已有编辑/标注的帧,并支持点击标记跳转。

- 前端测试覆盖 SAM2.1 变体选择、模型状态徽标、AI 页候选隔离、遮罩透明度、候选上追加正/反点、推送工作区保留选择、Canvas 交互式细化、VideoWorkspace 传播/保存、Dashboard 进度和时间轴已编辑帧标记。

- 后端测试覆盖 SAM2.1 变体状态、sam2 alias 兼容、sam3 禁用、semantic 禁用、传播标注保存、Dashboard 最近任务状态和 SAM3 历史测试跳过说明。

- README、AGENTS 和 doc 文档同步当前真实进度,更新 SAM2.1 变体、SAM3 禁用、接口契约、设计冻结、需求冻结、前端元素审计、实施计划、FastAPI docs 说明和测试矩阵。
This commit is contained in:
2026-05-01 23:39:53 +08:00
parent 8a9247075e
commit 29a1a87e52
38 changed files with 1087 additions and 631 deletions

View File

@@ -22,23 +22,24 @@ describe('AISegmentation', () => {
frames: [{ id: 'frame-1', projectId: 'project-1', index: 0, url: '/frame.jpg', width: 640, height: 360 }],
});
apiMock.getAiModelStatus.mockResolvedValue({
selected_model: 'sam2',
selected_model: 'sam2.1_hiera_tiny',
gpu: { available: true, device: 'cuda', name: 'RTX 4090', torch_available: true },
models: [
{ id: 'sam2', label: 'SAM 2', available: true, loaded: false, device: 'cuda', supports: ['point', 'box'], message: 'SAM 2 ready', package_available: true, checkpoint_exists: true, python_ok: true, torch_ok: true, cuda_required: false },
{ id: 'sam3', label: 'SAM 3', available: false, loaded: false, device: 'cuda', supports: ['semantic'], message: 'SAM 3 missing runtime', package_available: false, checkpoint_exists: false, python_ok: false, torch_ok: true, cuda_required: true },
{ id: 'sam2.1_hiera_tiny', label: 'SAM 2.1 Tiny', available: true, loaded: false, device: 'cuda', supports: ['point', 'box'], message: 'SAM 2.1 Tiny ready', package_available: true, checkpoint_exists: true, python_ok: true, torch_ok: true, cuda_required: false },
],
});
});
it('lets the user choose SAM3 for subsequent predictions', async () => {
it('shows the SAM2.1 variant selector without exposing SAM3', async () => {
render(<AISegmentation onSendToWorkspace={vi.fn()} />);
const sam3Button = (await screen.findByText('SAM3')).closest('button')!;
fireEvent.click(sam3Button);
expect(useStore.getState().aiModel).toBe('sam3');
expect(await screen.findByText('SAM 3 missing runtime')).toBeInTheDocument();
expect(await screen.findByText('SAM 2.1 Tiny')).toBeInTheDocument();
expect(screen.getByText('tiny')).toBeInTheDocument();
expect(screen.getByText('small')).toBeInTheDocument();
expect(screen.getByText('base+')).toBeInTheDocument();
expect(screen.getByText('large')).toBeInTheDocument();
expect(screen.queryByText('SAM3')).not.toBeInTheDocument();
expect(apiMock.getAiModelStatus).toHaveBeenCalledWith('sam2.1_hiera_tiny');
});
it('passes enabled inference parameters to the backend', async () => {
@@ -53,7 +54,7 @@ describe('AISegmentation', () => {
imageId: 'frame-1',
imageWidth: 640,
imageHeight: 360,
model: 'sam2',
model: 'sam2.1_hiera_tiny',
points: [{ x: 120, y: 80, type: 'pos' }],
options: {
crop_to_prompt: false,
@@ -63,16 +64,50 @@ describe('AISegmentation', () => {
}));
});
it('does not run SAM2 text-only prompts as semantic segmentation', async () => {
it('sends the selected SAM2.1 variant to prediction', async () => {
apiMock.predictMask.mockResolvedValueOnce({ masks: [] });
render(<AISegmentation onSendToWorkspace={vi.fn()} />);
fireEvent.change(screen.getByPlaceholderText("例如:'分割出左侧车道上行驶的所有红色汽车'..."), {
target: { value: '胆囊' },
fireEvent.click(await screen.findByText('small'));
fireEvent.click(screen.getByText('正向选点'));
fireEvent.click(screen.getByTestId('konva-stage'));
fireEvent.click(await screen.findByText('执行高精度语义分割'));
expect(apiMock.getAiModelStatus).toHaveBeenCalledWith('sam2.1_hiera_small');
expect(apiMock.predictMask).toHaveBeenCalledWith(expect.objectContaining({
model: 'sam2.1_hiera_small',
}));
});
// Verifies AI-page isolation: a mask that exists in the shared store but was not generated
// on this page must not be drawn, and a selection pointing at it is cleared after mount.
it('does not render masks that were created in the workspace', async () => {
// Seed the store with a manual mask on the current frame and select it.
useStore.setState({
masks: [
{
id: 'workspace-mask',
frameId: 'frame-1',
pathData: 'M 0 0 L 10 0 L 10 10 Z',
label: 'Manual Mask',
color: '#ff0000',
segmentation: [[0, 0, 10, 0, 10, 10]],
metadata: { source: 'manual' },
},
],
selectedMaskIds: ['workspace-mask'],
});
render(<AISegmentation onSendToWorkspace={vi.fn()} />);
// No mask path is rendered because the mask id is not among this page's generated ids.
expect(screen.queryAllByTestId('konva-path')).toHaveLength(0);
// The selection is pruned by an effect, so wait for the store to settle.
await waitFor(() => expect(useStore.getState().selectedMaskIds).toEqual([]));
});
it('requires point prompts before running SAM2 inference', async () => {
render(<AISegmentation onSendToWorkspace={vi.fn()} />);
fireEvent.click(await screen.findByText('执行高精度语义分割'));
expect(apiMock.predictMask).not.toHaveBeenCalled();
expect(await screen.findByText('SAM2 不支持文本语义提示;请先放置正/反向点,或切换到 SAM3 使用文本语义。')).toBeInTheDocument();
expect(await screen.findByText('请先放置正/反向提示点。')).toBeInTheDocument();
});
it('keeps only the best SAM2 candidate when the backend returns overlapping alternatives', async () => {
@@ -106,8 +141,116 @@ describe('AISegmentation', () => {
await waitFor(() => expect(useStore.getState().masks).toHaveLength(1));
expect(useStore.getState().masks[0].id).toBe('sam2-best');
expect(useStore.getState().masks[0].metadata).toEqual({ source: 'ai_segmentation' });
expect(useStore.getState().selectedMaskIds).toEqual(['sam2-best']);
expect(await screen.findByText('SAM2 返回 2 个候选,已采用最高分区域。')).toBeInTheDocument();
expect(await screen.findByText('SAM 2.1 Tiny 返回 2 个候选,已采用最高分区域。')).toBeInTheDocument();
});
// Verifies the opacity slider only changes how the mask preview is drawn; the mask data
// kept in the store is left untouched.
it('adjusts the AI mask preview opacity without changing mask data', async () => {
// The next prediction resolves with a single triangular candidate mask.
apiMock.predictMask.mockResolvedValueOnce({
masks: [
{
id: 'sam2-mask',
pathData: 'M 10 10 L 40 10 L 40 40 Z',
label: 'AI Mask',
color: '#06b6d4',
segmentation: [[10, 10, 40, 10, 40, 40]],
bbox: [10, 10, 30, 30],
area: 900,
},
],
});
render(<AISegmentation onSendToWorkspace={vi.fn()} />);
// Place one positive prompt point and run inference so the mocked mask gets rendered.
fireEvent.click(screen.getByText('正向选点'));
fireEvent.click(screen.getByTestId('konva-stage'));
fireEvent.click(await screen.findByText('执行高精度语义分割'));
await waitFor(() => expect(screen.getByTestId('konva-path')).toBeInTheDocument());
// Re-query on every call: returns the first mocked Konva group carrying a data-opacity attribute.
const maskGroup = () => screen.getAllByTestId('konva-group').find((group) => group.getAttribute('data-opacity'));
// 0.72 corresponds to the component's initial maskOpacity of 72 applied to a selected mask.
expect(maskGroup()).toHaveAttribute('data-opacity', '0.72');
fireEvent.change(screen.getByLabelText('遮罩清晰度'), { target: { value: '35' } });
expect(maskGroup()).toHaveAttribute('data-opacity', '0.35');
// The stored segmentation must be identical to what the backend returned.
expect(useStore.getState().masks[0].segmentation).toEqual([[10, 10, 40, 10, 40, 40]]);
});
// Verifies that clicking on an already rendered AI mask while a point tool is active adds a
// prompt point instead of only selecting the mask, and that the next prediction request
// carries every accumulated point.
it('lets positive and negative prompt points be added on top of an AI mask', async () => {
// First prediction returns one candidate mask; the follow-up prediction returns none.
apiMock.predictMask
.mockResolvedValueOnce({
masks: [
{
id: 'sam2-mask',
pathData: 'M 10 10 L 40 10 L 40 40 Z',
label: 'AI Mask',
color: '#06b6d4',
segmentation: [[10, 10, 40, 10, 40, 40]],
bbox: [10, 10, 30, 30],
area: 900,
},
],
})
.mockResolvedValueOnce({ masks: [] });
render(<AISegmentation onSendToWorkspace={vi.fn()} />);
// Positive point on the empty stage, then run inference to obtain the mask.
fireEvent.click(screen.getByText('正向选点'));
fireEvent.click(screen.getByTestId('konva-stage'), { clientX: 120, clientY: 80 });
fireEvent.click(await screen.findByText('执行高精度语义分割'));
await waitFor(() => expect(screen.getByTestId('konva-path')).toBeInTheDocument());
// Switch to the negative point tool and click directly on the mask path.
fireEvent.click(screen.getByText('反向选点'));
fireEvent.click(screen.getByTestId('konva-path'), { clientX: 220, clientY: 140 });
// NOTE(review): four circles for two points presumably means each point renders two circles — confirm against the points layer.
await waitFor(() => expect(screen.getAllByTestId('konva-circle')).toHaveLength(4));
fireEvent.click(screen.getByText('执行高精度语义分割'));
expect(apiMock.predictMask).toHaveBeenLastCalledWith(expect.objectContaining({
points: [
{ x: 120, y: 80, type: 'pos' },
{ x: 220, y: 140, type: 'neg' },
],
}));
// Checked right after the click, without waiting for the second prediction to resolve.
expect(useStore.getState().selectedMaskIds).toEqual(['sam2-mask']);
});
// Verifies the clear action removes only masks generated on the AI page and leaves
// pre-existing store masks in place.
it('clears only AI page candidates and keeps workspace masks in the store', async () => {
// A manual mask that is already in the store before the AI page is used.
useStore.setState({
masks: [
{
id: 'workspace-mask',
frameId: 'frame-1',
pathData: 'M 0 0 L 10 0 L 10 10 Z',
label: 'Manual Mask',
color: '#ff0000',
segmentation: [[0, 0, 10, 0, 10, 10]],
metadata: { source: 'manual' },
},
],
});
// The next prediction resolves with one AI candidate mask.
apiMock.predictMask.mockResolvedValueOnce({
masks: [
{
id: 'sam2-mask',
pathData: 'M 10 10 L 40 10 L 40 40 Z',
label: 'AI Mask',
color: '#06b6d4',
segmentation: [[10, 10, 40, 10, 40, 40]],
bbox: [10, 10, 30, 30],
area: 900,
},
],
});
render(<AISegmentation onSendToWorkspace={vi.fn()} />);
fireEvent.click(screen.getByText('正向选点'));
fireEvent.click(screen.getByTestId('konva-stage'));
// Wait until the prompt point is rendered before triggering inference.
await waitFor(() => expect(screen.getAllByTestId('konva-circle')).toHaveLength(2));
fireEvent.click(await screen.findByText('执行高精度语义分割'));
// The AI candidate is appended after the existing workspace mask.
await waitFor(() => expect(useStore.getState().masks.map((mask) => mask.id)).toEqual(['workspace-mask', 'sam2-mask']));
// Trigger the clear action; only the AI candidate and the selection should disappear.
fireEvent.click(screen.getByText('清空全体锚点'));
expect(useStore.getState().masks.map((mask) => mask.id)).toEqual(['workspace-mask']);
expect(useStore.getState().selectedMaskIds).toEqual([]);
});
it('lets a SAM2 result be selected and relabeled from the ontology panel', async () => {
@@ -186,111 +329,4 @@ describe('AISegmentation', () => {
expect(onSendToWorkspace).toHaveBeenCalled();
});
it('prompts for semantic text before running SAM3 inference', async () => {
apiMock.getAiModelStatus.mockResolvedValue({
selected_model: 'sam3',
gpu: { available: true, device: 'cuda', name: 'RTX 4090', torch_available: true },
models: [
{ id: 'sam2', label: 'SAM 2', available: true, loaded: false, device: 'cuda', supports: ['point', 'box'], message: 'SAM 2 ready', package_available: true, checkpoint_exists: true, python_ok: true, torch_ok: true, cuda_required: false },
{ id: 'sam3', label: 'SAM 3', available: true, loaded: true, device: 'cuda', supports: ['semantic'], message: 'SAM 3 ready', package_available: true, checkpoint_exists: true, python_ok: true, torch_ok: true, cuda_required: true },
],
});
render(<AISegmentation onSendToWorkspace={vi.fn()} />);
const sam3Button = (await screen.findByText('SAM3')).closest('button')!;
fireEvent.click(sam3Button);
fireEvent.click(screen.getByText('执行高精度语义分割'));
expect(apiMock.predictMask).not.toHaveBeenCalled();
expect(await screen.findByText('SAM3 当前使用文本语义提示,请先输入要分割的目标描述。')).toBeInTheDocument();
});
it('shows feedback when SAM3 semantic inference returns no masks', async () => {
apiMock.getAiModelStatus.mockResolvedValue({
selected_model: 'sam3',
gpu: { available: true, device: 'cuda', name: 'RTX 4090', torch_available: true },
models: [
{ id: 'sam2', label: 'SAM 2', available: true, loaded: false, device: 'cuda', supports: ['point', 'box'], message: 'SAM 2 ready', package_available: true, checkpoint_exists: true, python_ok: true, torch_ok: true, cuda_required: false },
{ id: 'sam3', label: 'SAM 3', available: true, loaded: true, device: 'cuda', supports: ['semantic'], message: 'SAM 3 ready', package_available: true, checkpoint_exists: true, python_ok: true, torch_ok: true, cuda_required: true },
],
});
apiMock.predictMask.mockResolvedValueOnce({ masks: [] });
render(<AISegmentation onSendToWorkspace={vi.fn()} />);
const sam3Button = (await screen.findByText('SAM3')).closest('button')!;
fireEvent.click(sam3Button);
fireEvent.change(screen.getByPlaceholderText("例如:'分割出左侧车道上行驶的所有红色汽车'..."), {
target: { value: '胆囊' },
});
fireEvent.click(screen.getByText('执行高精度语义分割'));
await waitFor(() => expect(apiMock.predictMask).toHaveBeenCalledWith(expect.objectContaining({
model: 'sam3',
points: undefined,
text: '胆囊',
})));
expect(await screen.findByText('SAM3 已完成语义推理,但没有返回区域。请尝试英文目标描述,或换到包含该目标的帧。当前提示: 胆囊')).toBeInTheDocument();
});
it('runs SAM3 semantic text inference and assigns the active class to returned masks', async () => {
apiMock.getAiModelStatus.mockResolvedValue({
selected_model: 'sam3',
gpu: { available: true, device: 'cuda', name: 'RTX 4090', torch_available: true },
models: [
{ id: 'sam2', label: 'SAM 2', available: true, loaded: false, device: 'cuda', supports: ['point', 'box'], message: 'SAM 2 ready', package_available: true, checkpoint_exists: true, python_ok: true, torch_ok: true, cuda_required: false },
{ id: 'sam3', label: 'SAM 3', available: true, loaded: true, device: 'cuda', supports: ['semantic'], message: 'SAM 3 ready', package_available: true, checkpoint_exists: true, python_ok: true, torch_ok: true, cuda_required: true },
],
});
apiMock.predictMask.mockResolvedValueOnce({
masks: [
{
id: 'semantic-1',
pathData: 'M 10 10 L 40 10 L 40 40 Z',
label: 'semantic result',
color: '#06b6d4',
segmentation: [[10, 10, 40, 10, 40, 40]],
bbox: [10, 10, 30, 30],
area: 900,
},
],
});
useStore.setState({
activeTemplateId: 'template-1',
activeClassId: 'class-1',
activeClass: { id: 'class-1', name: '胆囊', color: '#ff0000', zIndex: 30 },
});
render(<AISegmentation onSendToWorkspace={vi.fn()} />);
const sam3Button = (await screen.findByText('SAM3')).closest('button')!;
fireEvent.click(sam3Button);
fireEvent.change(screen.getByPlaceholderText("例如:'分割出左侧车道上行驶的所有红色汽车'..."), {
target: { value: '胆囊' },
});
fireEvent.click(screen.getByText('执行高精度语义分割'));
await waitFor(() => expect(apiMock.predictMask).toHaveBeenCalledWith(expect.objectContaining({
imageId: 'frame-1',
imageWidth: 640,
imageHeight: 360,
model: 'sam3',
points: undefined,
text: '胆囊',
options: {
crop_to_prompt: false,
auto_filter_background: true,
min_score: 0.05,
},
})));
expect(useStore.getState().masks[0]).toEqual(expect.objectContaining({
id: 'semantic-1',
frameId: 'frame-1',
templateId: 'template-1',
classId: 'class-1',
className: '胆囊',
classZIndex: 30,
label: '胆囊',
color: '#ff0000',
saveStatus: 'draft',
}));
});
});

View File

@@ -4,7 +4,7 @@ import { cn } from '../lib/utils';
import { Stage, Layer, Image as KonvaImage, Circle, Path, Group } from 'react-konva';
import useImage from 'use-image';
import { OntologyInspector } from './OntologyInspector';
import { useStore } from '../store/useStore';
import { SAM2_MODEL_OPTIONS, useStore } from '../store/useStore';
import { getAiModelStatus, predictMask, type AiRuntimeStatus } from '../lib/api';
interface AISegmentationProps {
@@ -16,7 +16,7 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
const setActiveTool = useStore((state) => state.setActiveTool);
const masks = useStore((state) => state.masks);
const addMask = useStore((state) => state.addMask);
const clearMasks = useStore((state) => state.clearMasks);
const setMasks = useStore((state) => state.setMasks);
const selectedMaskIds = useStore((state) => state.selectedMaskIds);
const setSelectedMaskIds = useStore((state) => state.setSelectedMaskIds);
const maskHistory = useStore((state) => state.maskHistory);
@@ -25,17 +25,18 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
const redoMasks = useStore((state) => state.redoMasks);
const frames = useStore((state) => state.frames);
const currentFrameIndex = useStore((state) => state.currentFrameIndex);
const activeTemplateId = useStore((state) => state.activeTemplateId);
const activeClass = useStore((state) => state.activeClass);
const aiModel = useStore((state) => state.aiModel);
const setAiModel = useStore((state) => state.setAiModel);
const activeTemplateId = useStore((state) => state.activeTemplateId);
const activeClass = useStore((state) => state.activeClass);
const [semanticText, setSemanticText] = useState('');
const [modelStatus, setModelStatus] = useState<AiRuntimeStatus | null>(null);
const [autoDeleteBg, setAutoDeleteBg] = useState(true);
const [cropMode, setCropMode] = useState(false);
const [maskOpacity, setMaskOpacity] = useState(72);
const [isInferencing, setIsInferencing] = useState(false);
const [inferenceMessage, setInferenceMessage] = useState('');
const [aiMaskIds, setAiMaskIds] = useState<string[]>([]);
// Canvas state
const [scale, setScale] = useState(1);
@@ -45,7 +46,10 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
const currentFrame = frames[currentFrameIndex] || null;
const previewUrl = currentFrame?.url || 'https://images.unsplash.com/photo-1549317661-bd32c8ce0be2?q=80&w=2070&auto=format&fit=crop';
const [image] = useImage(previewUrl);
const frameMasks = currentFrame ? masks.filter((mask) => mask.frameId === currentFrame.id) : masks;
const aiMaskIdSet = new Set(aiMaskIds);
const frameMasks = currentFrame
? masks.filter((mask) => mask.frameId === currentFrame.id && aiMaskIdSet.has(mask.id))
: masks.filter((mask) => aiMaskIdSet.has(mask.id));
const selectedModelStatus = modelStatus?.models.find((model) => model.id === aiModel);
const modelCanInfer = selectedModelStatus?.available ?? true;
@@ -65,6 +69,16 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
};
}, [aiModel]);
// Keep the shared selection restricted to masks this page currently renders.
// The store is only written when the filtered list differs from the current one,
// so re-running this effect does not cause redundant selection updates.
useEffect(() => {
  const renderedIds = new Set(frameMasks.map(({ id }) => id));
  const keptIds = selectedMaskIds.filter((id) => renderedIds.has(id));
  const sameLength = keptIds.length === selectedMaskIds.length;
  const sameOrder = keptIds.every((id, position) => id === selectedMaskIds[position]);
  if (sameLength && sameOrder) {
    return;
  }
  setSelectedMaskIds(keptIds);
}, [frameMasks, selectedMaskIds, setSelectedMaskIds]);
const handleWheel = (e: any) => {
e.evt.preventDefault();
const scaleBy = 1.1;
@@ -94,17 +108,8 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
};
const runInference = useCallback(async () => {
const textPrompt = semanticText.trim();
if (aiModel === 'sam3' && !textPrompt) {
setInferenceMessage('SAM3 当前使用文本语义提示,请先输入要分割的目标描述。');
return;
}
if (aiModel === 'sam2' && textPrompt && points.length === 0) {
setInferenceMessage('SAM2 不支持文本语义提示;请先放置正/反向点,或切换到 SAM3 使用文本语义。');
return;
}
if (points.length === 0 && !textPrompt) {
setInferenceMessage('请先放置正/反向提示点,或输入语义描述。');
if (points.length === 0) {
setInferenceMessage('请先放置正/反向提示点。');
return;
}
if (!currentFrame?.id) {
@@ -129,8 +134,7 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
imageWidth,
imageHeight,
model: aiModel,
points: aiModel === 'sam3' ? undefined : points.map((p) => ({ x: p.x, y: p.y, type: p.type })),
text: textPrompt || undefined,
points: points.map((p) => ({ x: p.x, y: p.y, type: p.type })),
options: {
crop_to_prompt: cropMode,
auto_filter_background: autoDeleteBg,
@@ -138,15 +142,13 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
},
});
const masksToApply = aiModel === 'sam2' ? result.masks.slice(0, 1) : result.masks;
const masksToApply = result.masks.slice(0, 1);
if (masksToApply.length === 0) {
setInferenceMessage(aiModel === 'sam3'
? `SAM3 已完成语义推理,但没有返回区域。请尝试英文目标描述,或换到包含该目标的帧。当前提示: ${textPrompt}`
: '模型没有返回可用区域,请换一个更具体的描述或调整提示。');
setInferenceMessage('模型没有返回可用区域,请调整提示点后重试。');
} else {
setInferenceMessage(aiModel === 'sam2' && result.masks.length > 1
? `SAM2 返回 ${result.masks.length} 个候选,已采用最高分区域。`
setInferenceMessage(result.masks.length > 1
? `${selectedModelStatus?.label || 'SAM 2.1'} 返回 ${result.masks.length} 个候选,已采用最高分区域。`
: `已生成 ${masksToApply.length} 个候选区域。`);
}
const generatedMaskIds: string[] = [];
@@ -169,9 +171,11 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
segmentation: m.segmentation,
bbox: m.bbox,
area: m.area,
metadata: { source: 'ai_segmentation' },
});
});
if (generatedMaskIds.length > 0) {
setAiMaskIds((existingIds) => [...existingIds, ...generatedMaskIds]);
setSelectedMaskIds(generatedMaskIds);
}
} catch (err) {
@@ -181,17 +185,32 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
} finally {
setIsInferencing(false);
}
}, [activeClass, activeTemplateId, addMask, aiModel, autoDeleteBg, cropMode, currentFrame?.height, currentFrame?.id, currentFrame?.width, image?.height, image?.naturalHeight, image?.naturalWidth, image?.width, points, semanticText, setSelectedMaskIds]);
}, [activeClass, activeTemplateId, addMask, aiModel, autoDeleteBg, cropMode, currentFrame?.height, currentFrame?.id, currentFrame?.width, image?.height, image?.naturalHeight, image?.naturalWidth, image?.width, points, selectedModelStatus?.label, setSelectedMaskIds]);
// Reset the AI page layer: always drop the prompt points, and when this page has
// generated masks, remove exactly those masks (and their ids from the selection)
// from the shared store while leaving every other mask untouched.
const clearAiLayer = useCallback(() => {
  setPoints([]);
  if (aiMaskIds.length > 0) {
    const generatedIds = new Set(aiMaskIds);
    const remainingMasks = masks.filter((mask) => !generatedIds.has(mask.id));
    const remainingSelection = selectedMaskIds.filter((id) => !generatedIds.has(id));
    setMasks(remainingMasks);
    setSelectedMaskIds(remainingSelection);
    setAiMaskIds([]);
  }
}, [aiMaskIds, masks, selectedMaskIds, setMasks, setSelectedMaskIds]);
// Append a prompt point at the pointer position of a canvas event.
// Returns true when a point was added, and false when the active tool is not a
// positive/negative point tool or no pointer position could be read from the stage.
const addPromptPointFromEvent = useCallback((event: any) => {
  const placingPositive = effectiveTool === 'point_pos';
  const placingNegative = effectiveTool === 'point_neg';
  if (!placingPositive && !placingNegative) {
    return false;
  }
  const pointer = event.target?.getStage?.()?.getRelativePointerPosition?.();
  if (!pointer) {
    return false;
  }
  // Functional update so consecutive clicks never overwrite each other.
  setPoints((previousPoints) => [
    ...previousPoints,
    { x: pointer.x, y: pointer.y, type: placingPositive ? 'pos' : 'neg' },
  ]);
  return true;
}, [effectiveTool]);
const handleStageClick = (e: any) => {
if (effectiveTool === 'move') return;
if (effectiveTool === 'point_pos' || effectiveTool === 'point_neg') {
const stage = e.target.getStage();
const pos = stage.getRelativePointerPosition();
if (pos) {
setPoints([...points, { x: pos.x, y: pos.y, type: effectiveTool === 'point_pos' ? 'pos' : 'neg' }]);
}
}
addPromptPointFromEvent(e);
};
return (
@@ -206,24 +225,39 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
</div>
<div className="flex-1 overflow-y-auto p-6 flex flex-col gap-8">
{/* Model Select */}
{/* Model Status */}
<div>
<h3 className="text-[10px] font-bold text-gray-500 uppercase tracking-widest mb-3"></h3>
<div className="bg-[#111] border border-white/5 grid grid-cols-2 gap-1 p-1 rounded-lg">
{(modelStatus?.models || [
{ id: 'sam2' as const, label: 'SAM 2', available: true, message: '正在读取 SAM 2 状态' },
{ id: 'sam3' as const, label: 'SAM 3', available: false, message: '正在读取 SAM 3 状态' },
]).map((m) => (
<button
key={m.id}
className={cn("text-xs py-2 rounded-md transition-colors text-center uppercase tracking-wider font-mono", aiModel === m.id ? "bg-white/10 text-white font-medium shadow-sm" : "text-gray-500 hover:text-gray-300 hover:bg-white/5")}
onClick={() => setAiModel(m.id)}
title={m.message}
>
{m.label.replace(' ', '')}
<span className={cn("ml-1", m.available ? "text-emerald-400" : "text-amber-400")}></span>
</button>
))}
<h3 className="text-[10px] font-bold text-gray-500 uppercase tracking-widest mb-3"></h3>
<div className="bg-[#111] border border-white/5 p-3 rounded-lg">
<div className="flex items-center justify-between">
<span className="text-xs uppercase tracking-wider font-mono text-white">{selectedModelStatus?.label || 'SAM 2.1'}</span>
<span className={cn("text-xs", modelCanInfer ? "text-emerald-400" : "text-amber-400")}>
{modelCanInfer ? '可用' : '不可用'}
</span>
</div>
<div className="mt-3 grid grid-cols-2 gap-2">
{SAM2_MODEL_OPTIONS.map((option) => {
const status = modelStatus?.models.find((model) => model.id === option.id);
const available = status?.available ?? false;
const selected = aiModel === option.id;
return (
<button
key={option.id}
type="button"
onClick={() => setAiModel(option.id)}
className={cn(
"h-8 rounded border px-2 text-[10px] uppercase tracking-wider transition-colors flex items-center justify-between",
selected
? "bg-cyan-500/10 border-cyan-400/40 text-cyan-300"
: "bg-white/[0.03] border-white/5 text-gray-400 hover:bg-white/5 hover:text-gray-200"
)}
>
<span>{option.shortLabel}</span>
<span className={cn("h-1.5 w-1.5 rounded-full", available ? "bg-emerald-400" : "bg-amber-400")} />
</button>
);
})}
</div>
</div>
<div className="mt-2 text-[10px] text-gray-500 leading-relaxed">
<div>{selectedModelStatus?.message || '正在读取模型状态...'}</div>
@@ -269,20 +303,6 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
</div>
</div>
{/* Semantic Description */}
<div>
<div className="flex justify-between items-center mb-3">
<h3 className="text-[10px] font-bold text-gray-500 uppercase tracking-widest"></h3>
<span className="text-[9px] bg-cyan-500/10 text-cyan-400 px-1.5 py-0.5 rounded border border-cyan-500/20 font-mono"></span>
</div>
<textarea
value={semanticText}
onChange={e => setSemanticText(e.target.value)}
placeholder="例如:'分割出左侧车道上行驶的所有红色汽车'..."
className="w-full bg-[#111] border border-white/5 rounded-lg p-3 text-sm text-white placeholder-gray-600 focus:outline-none focus:border-cyan-500/50 focus:ring-1 focus:ring-cyan-500/50 transition-all font-sans min-h-[100px] resize-none hover:border-white/10"
/>
</div>
{/* Parameters */}
<div>
<h3 className="text-[10px] font-bold text-gray-500 uppercase tracking-widest mb-3 flex items-center gap-2"></h3>
@@ -300,6 +320,23 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
<div className={cn("absolute top-0.5 left-0.5 w-3 h-3 bg-white rounded-full transition-transform shadow-sm", autoDeleteBg ? "translate-x-4" : "")} />
</button>
</div>
<div className="space-y-2">
<div className="flex items-center justify-between">
<label htmlFor="ai-mask-opacity" className="text-[11px] text-gray-400 uppercase tracking-wider font-medium"></label>
<span className="text-[10px] font-mono text-cyan-400">{maskOpacity}%</span>
</div>
<input
id="ai-mask-opacity"
type="range"
min="20"
max="100"
step="5"
value={maskOpacity}
onChange={(event) => setMaskOpacity(Number(event.target.value))}
className="w-full accent-cyan-400"
/>
</div>
</div>
</div>
</div>
@@ -340,7 +377,7 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
<header className="h-16 border-b border-white/5 bg-[#111] flex items-center justify-between px-6 shrink-0">
<div className="flex flex-col">
<h2 className="text-sm font-semibold tracking-wide text-white"> (Visualizer)</h2>
<span className="text-[10px] text-gray-500 uppercase tracking-widest font-mono">{aiModel.toUpperCase()} </span>
<span className="text-[10px] text-gray-500 uppercase tracking-widest font-mono">SAM 2.1 </span>
</div>
<div className="flex items-center gap-4">
<button
@@ -363,7 +400,7 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
<button className="flex items-center gap-2 text-xs text-gray-400 hover:text-white transition-colors bg-white/5 hover:bg-white/10 px-3 py-1.5 rounded-md border border-white/5">
<ImageIcon size={14} />
</button>
<button className="text-xs text-gray-400 hover:text-white transition-colors px-3 py-1.5" onClick={() => { setPoints([]); clearMasks(); }}>
<button className="text-xs text-gray-400 hover:text-white transition-colors px-3 py-1.5" onClick={clearAiLayer}>
</button>
</div>
@@ -395,24 +432,38 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
)}
{/* AI Returned Masks */}
{frameMasks.map((mask) => (
<Group key={mask.id} opacity={selectedMaskIds.includes(mask.id) ? 0.72 : 0.45}>
{frameMasks.map((mask) => {
const isSelected = selectedMaskIds.includes(mask.id);
const previewOpacity = isSelected
? maskOpacity / 100
: Math.max(0.18, (maskOpacity / 100) * 0.62);
return (
<Group key={mask.id} opacity={previewOpacity}>
<Path
data={mask.pathData}
fill={mask.color}
stroke={mask.color}
strokeWidth={(selectedMaskIds.includes(mask.id) ? 2.5 : 1) / scale}
strokeWidth={(isSelected ? 2.5 : 1) / scale}
onClick={(event: any) => {
if (addPromptPointFromEvent(event)) {
event.cancelBubble = true;
return;
}
event.cancelBubble = true;
setSelectedMaskIds([mask.id]);
}}
onTap={(event: any) => {
if (addPromptPointFromEvent(event)) {
event.cancelBubble = true;
return;
}
event.cancelBubble = true;
setSelectedMaskIds([mask.id]);
}}
/>
</Group>
))}
);
})}
{/* Points */}
{points.map((p, i) => (

View File

@@ -47,7 +47,7 @@ describe('CanvasArea', () => {
imageId: 'frame-1',
imageWidth: 640,
imageHeight: 360,
model: 'sam2',
model: 'sam2.1_hiera_tiny',
points: [{ x: 120, y: 80, type: 'pos' }],
box: undefined,
}));
@@ -65,55 +65,6 @@ describe('CanvasArea', () => {
}));
});
it('explains that SAM3 point prompts are not supported in the workspace', async () => {
useStore.setState({ aiModel: 'sam3' });
render(<CanvasArea activeTool="point_pos" frame={frame} />);
fireEvent.click(screen.getByTestId('konva-stage'));
expect(apiMock.predictMask).not.toHaveBeenCalled();
expect(await screen.findByText(/SAM3 当前工作区只支持框选提示/)).toBeInTheDocument();
});
it('calls SAM3 prediction with a box prompt from the workspace', async () => {
useStore.setState({ aiModel: 'sam3' });
apiMock.predictMask.mockResolvedValueOnce({
masks: [
{
id: 'sam3-box-mask',
pathData: 'M 20 20 L 80 20 L 80 80 Z',
label: 'AI Mask',
color: '#06b6d4',
segmentation: [[20, 20, 80, 20, 80, 80]],
bbox: [20, 20, 60, 60],
area: 3600,
},
],
});
render(<CanvasArea activeTool="box_select" frame={frame} />);
const stage = screen.getByTestId('konva-stage');
fireEvent.mouseDown(stage, { clientX: 120, clientY: 80 });
fireEvent.mouseMove(stage, { clientX: 260, clientY: 200 });
fireEvent.mouseUp(stage, { clientX: 260, clientY: 200 });
await waitFor(() => expect(apiMock.predictMask).toHaveBeenCalledWith({
imageId: 'frame-1',
imageWidth: 640,
imageHeight: 360,
model: 'sam3',
points: undefined,
box: { x1: 120, y1: 80, x2: 260, y2: 200 },
}));
expect(useStore.getState().masks[0]).toEqual(expect.objectContaining({
id: 'sam3-box-mask',
metadata: expect.objectContaining({
source: 'sam3_box',
promptBox: { x1: 120, y1: 80, x2: 260, y2: 200 },
}),
}));
});
it('refines one SAM2 candidate mask from an initial box with positive and negative points', async () => {
apiMock.predictMask
.mockResolvedValueOnce({
@@ -166,7 +117,7 @@ describe('CanvasArea', () => {
imageId: 'frame-1',
imageWidth: 640,
imageHeight: 360,
model: 'sam2',
model: 'sam2.1_hiera_tiny',
points: undefined,
box: { x1: 120, y1: 80, x2: 260, y2: 200 },
}));
@@ -179,7 +130,7 @@ describe('CanvasArea', () => {
imageId: 'frame-1',
imageWidth: 640,
imageHeight: 360,
model: 'sam2',
model: 'sam2.1_hiera_tiny',
points: [{ x: 150, y: 100, type: 'pos' }],
box: { x1: 120, y1: 80, x2: 260, y2: 200 },
}));
@@ -200,7 +151,7 @@ describe('CanvasArea', () => {
imageId: 'frame-1',
imageWidth: 640,
imageHeight: 360,
model: 'sam2',
model: 'sam2.1_hiera_tiny',
points: [
{ x: 150, y: 100, type: 'pos' },
{ x: 300, y: 150, type: 'neg' },
@@ -249,7 +200,7 @@ describe('CanvasArea', () => {
imageId: 'frame-1',
imageWidth: 640,
imageHeight: 360,
model: 'sam2',
model: 'sam2.1_hiera_tiny',
points: [{ x: 180, y: 120, type: 'neg' }],
box: { x1: 120, y1: 80, x2: 260, y2: 200 },
options: { auto_filter_background: true, min_score: 0.05 },

View File

@@ -326,12 +326,35 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
}, [frame?.id]);
useEffect(() => {
setGlobalSelectedMaskIds(selectedMaskIds);
}, [selectedMaskIds, setGlobalSelectedMaskIds]);
const currentGlobalSelectedIds = useStore.getState().selectedMaskIds;
if (selectedMaskIds.length === 0) {
const validGlobalSelectedIds = currentGlobalSelectedIds.filter((id) => (
frameMasks.some((mask) => mask.id === id)
));
if (validGlobalSelectedIds.length > 0) return;
}
const isSameSelection = currentGlobalSelectedIds.length === selectedMaskIds.length
&& currentGlobalSelectedIds.every((id, index) => id === selectedMaskIds[index]);
if (!isSameSelection) {
setGlobalSelectedMaskIds(selectedMaskIds);
}
}, [frameMasks, selectedMaskIds, setGlobalSelectedMaskIds]);
// Registers a cleanup that empties the shared selection when the effect is torn down.
useEffect(() => {
  return () => {
    setGlobalSelectedMaskIds([]);
  };
}, [setGlobalSelectedMaskIds]);
useEffect(() => {
if (!selectedMaskId) {
const validGlobalSelectedIds = useStore.getState().selectedMaskIds.filter((id) => (
frameMasks.some((mask) => mask.id === id)
));
if (validGlobalSelectedIds.length > 0) {
setSelectedMaskId(validGlobalSelectedIds[0]);
setSelectedMaskIds(validGlobalSelectedIds);
setSelectedPolygonIndex(0);
setSelectedVertexIndex(null);
return;
}
}
if (selectedMaskId && !frameMasks.some((mask) => mask.id === selectedMaskId)) {
setSelectedMaskId(null);
setSelectedMaskIds([]);
@@ -444,11 +467,6 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
setInferenceMessage('请先选择一帧图像。');
return;
}
if (aiModel === 'sam3' && (!promptBox || (promptPoints?.length ?? 0) > 0)) {
setInferenceMessage('SAM3 当前工作区只支持框选提示;正/反点修正请切回 SAM2。');
return;
}
const imageWidth = frame.width || image?.naturalWidth || image?.width || 0;
const imageHeight = frame.height || image?.naturalHeight || image?.height || 0;
if (imageWidth <= 0 || imageHeight <= 0) {
@@ -482,7 +500,7 @@ export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnota
const color = activeClass?.color || existingCandidate?.color || m.color;
const metadata = {
...(existingCandidate?.metadata || {}),
source: aiModel === 'sam3' ? 'sam3_box' : 'sam2_interactive',
source: 'sam2_interactive',
promptBox: promptBox || null,
promptPointCount: promptPoints?.length || 0,
promptNegativePointCount: promptPoints?.filter((point) => point.type === 'neg').length || 0,

View File

@@ -97,6 +97,42 @@ describe('Dashboard', () => {
expect(screen.queryByText('City_Driving_Dataset_004.mp4')).not.toBeInTheDocument();
});
it('keeps a recently completed task visible in the progress panel', async () => {
  // The overview contains a single task that has already finished (raw_status "success").
  const completedTask = {
    id: 'task-20',
    task_id: 20,
    project_id: 1,
    name: 'completed.mp4',
    progress: 100,
    status: '解析完成',
    raw_status: 'success',
    error: null,
    frame_count: 120,
    updated_at: '2026-05-01T00:00:00Z',
  };
  apiMock.getDashboardOverview.mockResolvedValueOnce({
    summary: {
      project_count: 1,
      parsing_task_count: 0,
      annotation_count: 0,
      frame_count: 120,
      template_count: 1,
      system_load_percent: 8,
    },
    tasks: [completedTask],
    activity: [],
  });

  render(<Dashboard />);

  // Wait for the panel heading, then check the finished task's name, progress and status are shown.
  const heading = await screen.findByText('任务进度 (当前 / 最近)');
  expect(heading).toBeInTheDocument();
  for (const visibleText of ['completed.mp4', '100%', '解析完成']) {
    expect(screen.getByText(visibleText)).toBeInTheDocument();
  }
  // The empty-state placeholder must not be rendered while a recent task is listed.
  expect(screen.queryByText(/当前无处理任务/)).not.toBeInTheDocument();
});
it('connects to the progress stream and updates progress tasks', async () => {
render(<Dashboard />);

View File

@@ -312,7 +312,7 @@ export function Dashboard() {
<div className="grid grid-cols-1 lg:grid-cols-3 gap-6">
<div className="lg:col-span-2 bg-[#111] border border-white/5 rounded-xl p-6 min-h-[400px]">
<h2 className="text-sm font-medium text-gray-400 uppercase tracking-widest mb-6"> ()</h2>
{/* Progress panel heading; the Dashboard test looks this exact text up with findByText. */}
<h2 className="text-sm font-medium text-gray-400 uppercase tracking-widest mb-6">任务进度 (当前 / 最近)</h2>
<div className="space-y-4">
{isLoading && (
<div className="text-sm text-gray-500 text-center py-12"> Dashboard ...</div>
@@ -371,7 +371,7 @@ export function Dashboard() {
</div>
))}
{!isLoading && tasks.length === 0 && (
<div className="text-sm text-gray-500 text-center py-12"></div>
<div className="text-sm text-gray-500 text-center py-12"></div>
)}
</div>
</div>

View File

@@ -51,6 +51,28 @@ describe('FrameTimeline', () => {
expect(screen.getAllByText('00:00.20').length).toBeGreaterThan(0);
});
it('marks edited frames between the time progress bar and frame navigator', () => {
  // Three frames f1..f3 of project p1, served from /1.jpg../3.jpg.
  const frames = ['f1', 'f2', 'f3'].map((id, index) => ({
    id,
    projectId: 'p1',
    index,
    url: `/${index + 1}.jpg`,
    width: 640,
    height: 360,
  }));
  // One draft mask on f2, one saved mask on f3, and one mask whose frame is not in `frames`.
  const masks = [
    { id: 'm1', frameId: 'f2', pathData: 'M 0 0 Z', label: 'Draft', color: '#06b6d4' },
    { id: 'm2', frameId: 'f3', annotationId: '9', pathData: 'M 0 0 Z', label: 'Saved', color: '#22c55e' },
    { id: 'outside', frameId: 'other-frame', pathData: 'M 0 0 Z', label: 'Other', color: '#fff' },
  ];
  useStore.setState({ frames, masks });

  render(<FrameTimeline />);

  // Only f2 and f3 count as edited, so the band reports two frames.
  expect(screen.getByText('已编辑')).toBeInTheDocument();
  expect(screen.getByText('2 帧')).toBeInTheDocument();

  // Clicking the marker labelled for frame 3 moves the store to index 2.
  const thirdFrameMarker = screen.getByLabelText('跳转到已编辑帧 3');
  fireEvent.click(thirdFrameMarker);
  expect(useStore.getState().currentFrameIndex).toBe(2);
});
it('changes frames with left and right arrow keys without leaving bounds', () => {
useStore.setState({
currentFrameIndex: 1,

View File

@@ -7,6 +7,7 @@ export function FrameTimeline() {
const frames = useStore((state) => state.frames);
const currentProject = useStore((state) => state.currentProject);
const currentFrameIndex = useStore((state) => state.currentFrameIndex);
const masks = useStore((state) => state.masks);
const setCurrentFrame = useStore((state) => state.setCurrentFrame);
const [isPlaying, setIsPlaying] = useState(false);
@@ -22,6 +23,17 @@ export function FrameTimeline() {
}, [currentProject?.original_fps, currentProject?.parse_fps]);
const currentSeconds = totalFrames > 0 ? currentFrameIndex / timeBaseFps : 0;
const totalSeconds = totalFrames > 0 ? Math.max(totalFrames - 1, 0) / timeBaseFps : 0;
// Frames of the current frame list that own at least one mask, each paired with
// its position in `frames`. Masks pointing at frames outside the list are ignored.
const editedFrameMarkers = useMemo(() => {
  const knownFrameIds = new Set(frames.map(({ id }) => id));
  const touchedFrameIds = new Set<string>();
  for (const { frameId } of masks) {
    if (knownFrameIds.has(frameId)) {
      touchedFrameIds.add(frameId);
    }
  }
  // Walk `frames` in order so markers keep timeline order and carry the frame index.
  return frames.flatMap((frame, index) => (
    touchedFrameIds.has(frame.id) ? [{ frame, index }] : []
  ));
}, [frames, masks]);
const formatTime = (seconds: number) => {
const safeSeconds = Math.max(0, seconds);
@@ -83,7 +95,7 @@ export function FrameTimeline() {
: [];
return (
<div className="h-32 bg-[#111] border-t border-white/5 flex flex-col shrink-0 z-20">
<div className="h-36 bg-[#111] border-t border-white/5 flex flex-col shrink-0 z-20">
<div className="h-4 bg-[#0d0d0d] flex items-center group relative">
<div className="absolute left-3 -top-5 text-[10px] font-mono text-gray-500 pointer-events-none">
{formatTime(currentSeconds)}
@@ -117,6 +129,34 @@ export function FrameTimeline() {
</div>
</div>
</div>
<div className="h-5 bg-[#0f0f0f] border-y border-white/[0.03] px-4 flex items-center gap-3">
  {/* Edited-frame marker band: label on the left, one clickable dot per edited frame, count on the right. */}
  <div className="w-20 text-[9px] font-mono uppercase tracking-widest text-gray-500 shrink-0">已编辑</div>
  <div className="relative h-3 flex-1">
    <div className="absolute left-0 right-0 top-1/2 h-px -translate-y-1/2 bg-white/5" />
    {editedFrameMarkers.map(({ frame, index }) => {
      const isCurrent = index === currentFrameIndex;
      // Horizontal position as a percentage of the band; (index + 1) / totalFrames puts the last frame at 100%.
      const left = totalFrames > 0 ? ((index + 1) / totalFrames) * 100 : 0;
      return (
        <button
          key={frame.id}
          type="button"
          aria-label={`跳转到已编辑帧 ${index + 1}`}
          title={`已编辑帧 ${index + 1}`}
          onClick={() => setCurrentFrame(index)}
          className={cn(
            "absolute top-1/2 -translate-x-1/2 -translate-y-1/2 rounded-full border transition-all",
            isCurrent
              ? "h-3 w-3 bg-cyan-300 border-cyan-100 shadow-[0_0_12px_rgba(34,211,238,0.65)]"
              : "h-2 w-2 bg-amber-300 border-amber-100/80 hover:h-3 hover:w-3 hover:bg-cyan-300 hover:border-cyan-100"
          )}
          style={{ left: `${left}%` }}
        />
      );
    })}
  </div>
  {/* Number of edited frames followed by the unit, e.g. "2 帧". */}
  <div className="w-20 text-right text-[9px] font-mono text-gray-500 shrink-0">{editedFrameMarkers.length} 帧</div>
</div>
<div className="flex-1 flex items-center px-4 gap-6">
<div className="flex flex-col items-center gap-2 px-4 border-r border-white/10 shrink-0">

View File

@@ -17,11 +17,10 @@ describe('ModelStatusBadge', () => {
resetStore();
vi.clearAllMocks();
apiMock.getAiModelStatus.mockResolvedValue({
selected_model: 'sam2',
selected_model: 'sam2.1_hiera_tiny',
gpu: { available: true, device: 'cuda', name: 'RTX 4090', torch_available: true },
models: [
{ id: 'sam2', label: 'SAM 2', available: true, loaded: false, device: 'cuda', supports: ['point', 'box'], message: 'SAM 2 ready', package_available: true, checkpoint_exists: true, python_ok: true, torch_ok: true, cuda_required: false },
{ id: 'sam3', label: 'SAM 3', available: false, loaded: false, device: 'cuda', supports: ['semantic'], message: 'SAM 3 missing runtime', package_available: false, checkpoint_exists: false, python_ok: false, torch_ok: true, cuda_required: true },
{ id: 'sam2.1_hiera_tiny', label: 'SAM 2.1 Tiny', available: true, loaded: false, device: 'cuda', supports: ['point', 'box'], message: 'SAM 2.1 Tiny ready', package_available: true, checkpoint_exists: true, python_ok: true, torch_ok: true, cuda_required: false },
],
});
});
@@ -29,17 +28,14 @@ describe('ModelStatusBadge', () => {
it('loads real model status for the selected model', async () => {
render(<ModelStatusBadge />);
expect(await screen.findByText('SAM 2 可用')).toBeInTheDocument();
expect(apiMock.getAiModelStatus).toHaveBeenCalledWith('sam2');
expect(await screen.findByText('SAM 2.1 Tiny 可用')).toBeInTheDocument();
expect(apiMock.getAiModelStatus).toHaveBeenCalledWith('sam2.1_hiera_tiny');
});
it('shows unavailable state when SAM3 is selected but not runnable', async () => {
useStore.getState().setAiModel('sam3');
it('does not expose disabled SAM3 status in the badge', async () => {
render(<ModelStatusBadge />);
await waitFor(() => expect(apiMock.getAiModelStatus).toHaveBeenCalledWith('sam3'));
expect(await screen.findByText('SAM 3 不可用')).toBeInTheDocument();
expect(screen.getByTitle('SAM 3 missing runtime')).toBeInTheDocument();
await waitFor(() => expect(apiMock.getAiModelStatus).toHaveBeenCalledWith('sam2.1_hiera_tiny'));
expect(screen.queryByText(/SAM 3/)).not.toBeInTheDocument();
});
});

View File

@@ -50,7 +50,7 @@ describe('VideoWorkspace', () => {
apiMock.annotationToMask.mockReturnValue(null);
apiMock.getTask.mockResolvedValue({ id: 1, status: 'success', progress: 100, message: '解析完成' });
apiMock.propagateMasks.mockResolvedValue({
model: 'sam2',
model: 'sam2.1_hiera_tiny',
direction: 'forward',
source_frame_id: 10,
processed_frame_count: 3,
@@ -58,11 +58,10 @@ describe('VideoWorkspace', () => {
annotations: [],
});
apiMock.getAiModelStatus.mockResolvedValue({
selected_model: 'sam2',
selected_model: 'sam2.1_hiera_tiny',
gpu: { available: false, device: 'cpu', name: null, torch_available: true },
models: [
{ id: 'sam2', label: 'SAM 2', available: true, loaded: false, device: 'cpu', supports: [], message: 'ready', package_available: true, checkpoint_exists: true, python_ok: true, torch_ok: true, cuda_required: false },
{ id: 'sam3', label: 'SAM 3', available: false, loaded: false, device: 'unavailable', supports: [], message: 'missing', package_available: false, checkpoint_exists: false, python_ok: false, torch_ok: true, cuda_required: true },
{ id: 'sam2.1_hiera_tiny', label: 'SAM 2.1 Tiny', available: true, loaded: false, device: 'cpu', supports: [], message: 'ready', package_available: true, checkpoint_exists: true, python_ok: true, torch_ok: true, cuda_required: false },
],
});
});
@@ -116,12 +115,65 @@ describe('VideoWorkspace', () => {
]));
});
it('preserves unsaved AI masks when hydrating saved annotations after entering the workspace', async () => {
  // Backend state: one frame (id 10) with one saved annotation (id 99) that maps to a saved mask.
  const savedMask = {
    id: 'annotation-99',
    annotationId: '99',
    frameId: '10',
    saved: true,
    pathData: 'M 0 0 Z',
    label: 'Saved',
    color: '#06b6d4',
  };
  apiMock.getProjectFrames.mockResolvedValueOnce([
    { id: 10, project_id: 1, frame_index: 0, image_url: '/frame.jpg', width: 640, height: 360 },
  ]);
  apiMock.getProjectAnnotations.mockResolvedValueOnce([{ id: 99, frame_id: 10 }]);
  apiMock.annotationToMask.mockReturnValueOnce(savedMask);

  // Store state before mounting: a selected, unsaved draft mask on the same frame.
  const draftMask = {
    id: 'ai-mask',
    frameId: '10',
    pathData: 'M 10 10 L 40 10 L 40 40 Z',
    label: 'AI Mask',
    color: '#06b6d4',
    segmentation: [[10, 10, 40, 10, 40, 40]],
    saveStatus: 'draft',
    saved: false,
    metadata: { source: 'ai_segmentation' },
  };
  useStore.setState({
    activeTool: 'edit_polygon',
    selectedMaskIds: ['ai-mask'],
    masks: [draftMask],
  });

  render(<VideoWorkspace />);

  // After hydration the draft stays first and the saved annotation mask is appended.
  await waitFor(() => {
    const maskIds = useStore.getState().masks.map((mask) => mask.id);
    expect(maskIds).toEqual(['ai-mask', 'annotation-99']);
  });
  // Selection and the active tool are left as they were before mounting.
  const stateAfterHydration = useStore.getState();
  expect(stateAfterHydration.selectedMaskIds).toEqual(['ai-mask']);
  expect(stateAfterHydration.activeTool).toBe('edit_polygon');
});
it('saves pending masks through the archive button', async () => {
apiMock.getProjectFrames.mockResolvedValueOnce([
{ id: 10, project_id: 1, frame_index: 0, image_url: '/frame.jpg', width: 640, height: 360 },
]);
apiMock.getProjectAnnotations
.mockResolvedValueOnce([])
.mockResolvedValueOnce([{ id: 5, frame_id: 10 }]);
apiMock.buildAnnotationPayload.mockReturnValueOnce({ project_id: 1, frame_id: 10, mask_data: { polygons: [] } });
apiMock.saveAnnotation.mockResolvedValueOnce({ id: 5 });
apiMock.annotationToMask.mockReturnValueOnce({
id: 'annotation-5',
annotationId: '5',
frameId: '10',
saved: true,
saveStatus: 'saved',
pathData: 'M 0 0 Z',
label: 'Saved AI Mask',
color: '#06b6d4',
});
render(<VideoWorkspace />);
await waitFor(() => expect(useStore.getState().frames).toHaveLength(1));
@@ -153,6 +205,10 @@ describe('VideoWorkspace', () => {
expect.objectContaining({ id: '10' }),
'2',
);
await waitFor(() => expect(useStore.getState().masks).toEqual([
expect.objectContaining({ id: 'annotation-5', saved: true, saveStatus: 'saved' }),
]));
expect(useStore.getState().masks.some((mask) => mask.id === 'mask-1')).toBe(false);
});
it('updates dirty saved masks through the archive button', async () => {
@@ -346,7 +402,7 @@ describe('VideoWorkspace', () => {
await waitFor(() => expect(useStore.getState().frames).toHaveLength(2));
act(() => {
useStore.setState({
aiModel: 'sam2',
aiModel: 'sam2.1_hiera_tiny',
activeTemplateId: '2',
selectedMaskIds: ['mask-1'],
masks: [{
@@ -366,7 +422,7 @@ describe('VideoWorkspace', () => {
await waitFor(() => expect(apiMock.propagateMasks).toHaveBeenCalledWith({
project_id: 1,
frame_id: 10,
model: 'sam2',
model: 'sam2.1_hiera_tiny',
direction: 'forward',
max_frames: 30,
include_source: false,

View File

@@ -34,9 +34,14 @@ export function VideoWorkspace({ onNavigateToAI }: { onNavigateToAI?: () => void
const activeTemplateId = useStore((state) => state.activeTemplateId);
const aiModel = useStore((state) => state.aiModel);
const selectedMaskIds = useStore((state) => state.selectedMaskIds);
// Remembers the most recent non-empty mask selection. The ref is seeded with the
// current selection and only overwritten when the selection is non-empty, so a
// later empty selection does not erase it. loadFrames reads it before fetching
// frames and hands it to hydrateSavedAnnotations as the ids to keep selected.
// NOTE(review): this assigns ref.current during render; React's useRef docs advise
// against writing refs while rendering — confirm this is intentional or move the
// write into an effect that runs before the frame-loading effect.
const latestSelectedMaskIdsRef = React.useRef<string[]>(selectedMaskIds);
if (selectedMaskIds.length > 0) {
latestSelectedMaskIdsRef.current = selectedMaskIds;
}
const setFrames = useStore((state) => state.setFrames);
const setCurrentFrame = useStore((state) => state.setCurrentFrame);
const setMasks = useStore((state) => state.setMasks);
const setSelectedMaskIds = useStore((state) => state.setSelectedMaskIds);
const undoMasks = useStore((state) => state.undoMasks);
const redoMasks = useStore((state) => state.redoMasks);
const [isSaving, setIsSaving] = useState(false);
@@ -45,8 +50,15 @@ export function VideoWorkspace({ onNavigateToAI }: { onNavigateToAI?: () => void
const [isPropagating, setIsPropagating] = useState(false);
const [statusMessage, setStatusMessage] = useState('');
const hydrateSavedAnnotations = useCallback(async (projectId: string, projectFrames: Frame[]) => {
const hydrateSavedAnnotations = useCallback(async (
projectId: string,
projectFrames: Frame[],
preserveSelectedIds: string[] = [],
excludeUnsavedMaskIds: string[] = [],
) => {
const frameById = new Map(projectFrames.map((frame) => [frame.id, frame]));
const projectFrameIds = new Set(projectFrames.map((frame) => frame.id));
const excludedDraftIds = new Set(excludeUnsavedMaskIds);
const annotations = await getProjectAnnotations(projectId);
const savedMasks = annotations
.map((annotation) => {
@@ -54,14 +66,27 @@ export function VideoWorkspace({ onNavigateToAI }: { onNavigateToAI?: () => void
return frame ? annotationToMask(annotation, frame) : null;
})
.filter((mask): mask is NonNullable<typeof mask> => Boolean(mask));
setMasks(savedMasks);
}, [setMasks]);
const currentMasks = useStore.getState().masks;
const unsavedMasks = currentMasks.filter((mask) => (
!projectFrameIds.has(mask.frameId) || (!mask.annotationId && !excludedDraftIds.has(mask.id))
));
const mergedMasks = [...unsavedMasks, ...savedMasks];
setMasks(mergedMasks);
if (preserveSelectedIds.length > 0) {
const mergedMaskIds = new Set(mergedMasks.map((mask) => mask.id));
const nextSelectedIds = preserveSelectedIds.filter((id) => mergedMaskIds.has(id));
if (nextSelectedIds.length > 0) {
setSelectedMaskIds(nextSelectedIds);
}
}
}, [setMasks, setSelectedMaskIds]);
useEffect(() => {
if (!currentProject?.id) return;
let cancelled = false;
const loadFrames = async () => {
const selectedIdsBeforeLoad = latestSelectedMaskIdsRef.current;
try {
const data = await getProjectFrames(String(currentProject.id));
if (cancelled) return;
@@ -90,7 +115,7 @@ export function VideoWorkspace({ onNavigateToAI }: { onNavigateToAI?: () => void
return;
}
setStatusMessage('');
await hydrateSavedAnnotations(String(currentProject.id), mappedFrames);
await hydrateSavedAnnotations(String(currentProject.id), mappedFrames, selectedIdsBeforeLoad);
} catch (err) {
console.error('Failed to load frames:', err);
}
@@ -126,12 +151,13 @@ export function VideoWorkspace({ onNavigateToAI }: { onNavigateToAI?: () => void
setIsSaving(true);
setStatusMessage('正在保存标注...');
try {
const createPayloads = pendingMasks
const createItems = pendingMasks
.map((mask) => {
const frame = frameById.get(mask.frameId);
return frame ? buildAnnotationPayload(currentProject.id, mask, frame, activeTemplateId) : null;
const payload = frame ? buildAnnotationPayload(currentProject.id, mask, frame, activeTemplateId) : null;
return payload ? { maskId: mask.id, payload } : null;
})
.filter((payload): payload is NonNullable<typeof payload> => Boolean(payload));
.filter((item): item is NonNullable<typeof item> => Boolean(item));
const updatePayloads = dirtyMasks
.map((mask) => {
@@ -148,17 +174,22 @@ export function VideoWorkspace({ onNavigateToAI }: { onNavigateToAI?: () => void
})
.filter((item): item is NonNullable<typeof item> => Boolean(item));
if (createPayloads.length === 0 && updatePayloads.length === 0) {
if (createItems.length === 0 && updatePayloads.length === 0) {
setStatusMessage('没有可保存的标注数据');
return 0;
}
await Promise.all([
...createPayloads.map((payload) => saveAnnotation(payload)),
...createItems.map(({ payload }) => saveAnnotation(payload)),
...updatePayloads.map(({ annotationId, payload }) => updateAnnotation(annotationId, payload)),
]);
await hydrateSavedAnnotations(currentProject.id, frames);
const savedCount = createPayloads.length + updatePayloads.length;
await hydrateSavedAnnotations(
currentProject.id,
frames,
useStore.getState().selectedMaskIds,
createItems.map(({ maskId }) => maskId),
);
const savedCount = createItems.length + updatePayloads.length;
setStatusMessage(`已保存 ${savedCount} 个标注`);
return savedCount;
} catch (err) {

View File

@@ -224,7 +224,7 @@ describe('api client contracts', () => {
axiosMock.client.post.mockResolvedValueOnce({
data: {
model: 'sam2',
model: 'sam2.1_hiera_tiny',
direction: 'forward',
source_frame_id: 5,
processed_frame_count: 3,
@@ -235,7 +235,7 @@ describe('api client contracts', () => {
await expect(propagateMasks({
project_id: 9,
frame_id: 5,
model: 'sam2',
model: 'sam2.1_hiera_tiny',
seed: {
polygons: [[[0, 0], [1, 0], [1, 1]]],
label: 'mask',
@@ -247,7 +247,7 @@ describe('api client contracts', () => {
expect(axiosMock.client.post).toHaveBeenCalledWith('/api/ai/propagate', {
project_id: 9,
frame_id: 5,
model: 'sam2',
model: 'sam2.1_hiera_tiny',
seed: {
polygons: [[[0, 0], [1, 0], [1, 1]]],
label: 'mask',
@@ -384,7 +384,7 @@ describe('api client contracts', () => {
points: [[0.5, 0.5], [0.1, 0.1]],
labels: [1, 0],
},
model: 'sam2',
model: 'sam2.1_hiera_tiny',
});
expect(result.masks[0]).toEqual(expect.objectContaining({
pathData: 'M 100 50 L 300 50 L 300 150 L 100 150 Z',
@@ -410,7 +410,7 @@ describe('api client contracts', () => {
image_id: 5,
prompt_type: 'box',
prompt_data: [0.1, 0.1, 0.5, 0.5],
model: 'sam2',
model: 'sam2.1_hiera_tiny',
});
});
@@ -437,11 +437,11 @@ describe('api client contracts', () => {
points: [[0.2, 0.2], [0.4, 0.4]],
labels: [1, 0],
},
model: 'sam2',
model: 'sam2.1_hiera_tiny',
});
});
it('uses semantic prompt type for text-only AI prediction', async () => {
it('serializes text-only prediction as semantic when called directly', async () => {
const { predictMask } = await import('./api');
axiosMock.client.post.mockResolvedValueOnce({ data: { polygons: [], scores: [] } });
@@ -449,7 +449,6 @@ describe('api client contracts', () => {
imageId: '6',
imageWidth: 640,
imageHeight: 360,
model: 'sam3',
text: '分割胆囊',
});
@@ -457,7 +456,7 @@ describe('api client contracts', () => {
image_id: 6,
prompt_type: 'semantic',
prompt_data: '分割胆囊',
model: 'sam3',
model: 'sam2.1_hiera_tiny',
});
});
@@ -484,7 +483,7 @@ describe('api client contracts', () => {
points: [[0.5, 0.5]],
labels: [1],
},
model: 'sam2',
model: 'sam2.1_hiera_tiny',
options: {
crop_to_prompt: true,
auto_filter_background: true,
@@ -496,18 +495,17 @@ describe('api client contracts', () => {
it('loads AI model and GPU runtime status', async () => {
const { getAiModelStatus } = await import('./api');
const status = {
selected_model: 'sam2',
selected_model: 'sam2.1_hiera_tiny',
gpu: { available: false, device: 'cpu', name: null, torch_available: true, torch_version: '2.x', cuda_version: null },
models: [
{ id: 'sam2', label: 'SAM 2', available: true, loaded: false, device: 'cpu', supports: ['point'], message: 'ready', package_available: true, checkpoint_exists: true, checkpoint_path: 'model.pt', python_ok: true, torch_ok: true, cuda_required: false },
{ id: 'sam3', label: 'SAM 3', available: false, loaded: false, device: 'unavailable', supports: ['semantic'], message: 'missing runtime', package_available: false, checkpoint_exists: false, checkpoint_path: null, python_ok: false, torch_ok: true, cuda_required: true },
{ id: 'sam2.1_hiera_tiny', label: 'SAM 2.1 Tiny', available: true, loaded: false, device: 'cpu', supports: ['point'], message: 'ready', package_available: true, checkpoint_exists: true, checkpoint_path: 'model.pt', python_ok: true, torch_ok: true, cuda_required: false },
],
};
axiosMock.client.get.mockResolvedValueOnce({ data: status });
await expect(getAiModelStatus('sam3')).resolves.toEqual(status);
await expect(getAiModelStatus('sam2.1_hiera_tiny')).resolves.toEqual(status);
expect(axiosMock.client.get).toHaveBeenCalledWith('/api/ai/models/status', {
params: { selected_model: 'sam3' },
params: { selected_model: 'sam2.1_hiera_tiny' },
});
});
});

View File

@@ -1,5 +1,5 @@
import axios, { AxiosError } from 'axios';
import type { AiModelId, Frame, Mask, Project, Template } from '../store/useStore';
import { DEFAULT_AI_MODEL_ID, type AiModelId, type Frame, type Mask, type Project, type Template } from '../store/useStore';
import { API_BASE_URL } from './config';
const apiClient = axios.create({
@@ -557,7 +557,7 @@ export async function predictMask(payload: PredictMaskPayload): Promise<PredictM
image_id: Number(payload.imageId),
prompt_type,
prompt_data,
model: payload.model || 'sam2',
model: payload.model || DEFAULT_AI_MODEL_ID,
...(payload.options ? { options: payload.options } : {}),
});

View File

@@ -17,7 +17,20 @@ export interface Project {
updatedAt?: string;
}
export type AiModelId = 'sam2' | 'sam3';
/** Identifiers of the selectable SAM 2.1 model variants. */
export type AiModelId = 'sam2.1_hiera_tiny' | 'sam2.1_hiera_small' | 'sam2.1_hiera_base_plus' | 'sam2.1_hiera_large';

/** Variant used when no model has been chosen explicitly (store default and request fallback). */
export const DEFAULT_AI_MODEL_ID: AiModelId = 'sam2.1_hiera_tiny';

/** Display metadata for every selectable variant: full label plus a short label. */
export const SAM2_MODEL_OPTIONS: { id: AiModelId; label: string; shortLabel: string }[] = [
  {
    id: 'sam2.1_hiera_tiny',
    label: 'SAM 2.1 Tiny',
    shortLabel: 'tiny',
  },
  {
    id: 'sam2.1_hiera_small',
    label: 'SAM 2.1 Small',
    shortLabel: 'small',
  },
  {
    id: 'sam2.1_hiera_base_plus',
    label: 'SAM 2.1 Base+',
    shortLabel: 'base+',
  },
  {
    id: 'sam2.1_hiera_large',
    label: 'SAM 2.1 Large',
    shortLabel: 'large',
  },
];
export interface Frame {
id: string;
@@ -195,7 +208,7 @@ export const useStore = create<AppState>((set) => ({
// Workspace
activeModule: 'workspace',
activeTool: 'move',
aiModel: 'sam2',
aiModel: DEFAULT_AI_MODEL_ID,
frames: [],
currentFrameIndex: 0,
annotations: [],

View File

@@ -63,7 +63,7 @@ vi.mock('react-konva', () => ({
);
},
Layer: ({ children }: any) => <div data-testid="konva-layer">{children}</div>,
Group: ({ children }: any) => <div data-testid="konva-group">{children}</div>,
Group: ({ children, opacity }: any) => <div data-testid="konva-group" data-opacity={opacity}>{children}</div>,
Image: ({ image }: any) => <img data-testid="konva-image" alt="" src={image?.src || ''} />,
Circle: (props: any) => (
<span
@@ -72,7 +72,11 @@ vi.mock('react-konva', () => ({
data-x={props.x}
data-y={props.y}
onClick={(event) => {
const konvaEvent = { cancelBubble: false };
const point = {
x: event.clientX || 120,
y: event.clientY || 80,
};
const konvaEvent = { ...makeStageEvent(point.x, point.y), cancelBubble: false };
props.onClick?.(konvaEvent);
if (konvaEvent.cancelBubble) event.stopPropagation();
}}
@@ -98,7 +102,11 @@ vi.mock('react-konva', () => ({
data-fill={props.fill}
data-fill-rule={props.fillRule}
onClick={(event) => {
const konvaEvent = { cancelBubble: false };
const point = {
x: event.clientX || 120,
y: event.clientY || 80,
};
const konvaEvent = { ...makeStageEvent(point.x, point.y), cancelBubble: false };
props.onClick?.(konvaEvent);
if (konvaEvent.cancelBubble) event.stopPropagation();
}}

View File

@@ -1,4 +1,4 @@
import { useStore } from '../store/useStore';
import { DEFAULT_AI_MODEL_ID, useStore } from '../store/useStore';
export function resetStore() {
useStore.setState({
@@ -8,7 +8,7 @@ export function resetStore() {
currentProject: null,
activeModule: 'workspace',
activeTool: 'move',
aiModel: 'sam2',
aiModel: DEFAULT_AI_MODEL_ID,
frames: [],
currentFrameIndex: 0,
annotations: [],