Files
Pre_Seg_Server/src/lib/api.test.ts
admin 5ab4602535 feat: 完善视频传播、标注编辑和拆帧闭环
- 接入 SAM2 视频传播能力:新增 /api/ai/propagate,支持用当前帧 mask/polygon/bbox 作为 seed,通过 SAM2 video predictor 向前、向后或双向传播,并可保存为真实 annotation。
- 接入 SAM3 video tracker:通过独立 Python 3.12 external worker 调用 SAM3 video predictor/tracker,使用本地 checkpoint 与 bbox seed 执行视频级跟踪,并在模型状态中标记 video_track 能力。
- 完善 SAM 模型分发:sam_registry 按 model_id 明确区分 sam2 propagation 与 sam3 video_track,避免两个模型链路混用。
- 打通前端“传播片段”:VideoWorkspace 使用当前选中 mask 和当前 AI 模型调用后端传播接口,传播结果回写并刷新工作区已保存标注。
- 增强 SAM3 本地 checkpoint 配置:新增 sam3_checkpoint_path 配置和 .env.example 示例,状态检查改为基于本地 checkpoint/独立环境/模型包可用性。
- 完善视频拆帧参数:/api/media/parse 支持 parse_fps、max_frames、target_width,后端任务保存帧时间戳、源帧号和 frame_sequence 元数据。
- 增加运行时 schema 兼容处理:启动时为旧 frames 表补充 timestamp_ms 和 source_frame_number 列,避免旧库升级后缺字段。
- 强化 Canvas 标注编辑:补齐多边形闭合、点工具、顶点拖拽、边中点插入、Delete/Backspace 删除、区域合并和重叠去除等交互。
- 增强语义分类联动:选中 mask 后可通过右侧语义分类树更新标签、颜色和 class metadata,并同步到保存/导出链路。
- 增加关键帧时间轴体验:FrameTimeline 显示具体时间信息,并支持键盘左右方向键切换关键帧。
- 完善 AI 交互分割参数:前端保留正向点、反向点、框选和 interactive prompt 的调用状态,支持 SAM2 细化候选区域与 SAM3 bbox 入口。
- 扩展后端/前端 API 类型:新增 propagateMasks、传播请求/响应 schema,并补齐 annotation、导出、模型状态和任务接口的测试覆盖。
- 更新项目文档:同步 README、AGENTS、接口契约、需求冻结、设计冻结、前端元素审计、实施计划和测试计划,标明真实功能边界与剩余风险。
- 增加测试覆盖:补充 SAM2/SAM3 传播、SAM3 状态、媒体拆帧参数、Canvas 编辑、语义标签切换、时间轴、工作区传播和 API 合约测试。
- 加强仓库安全边界:将 sam3权重/ 加入 .gitignore,避免本地模型权重被误提交。

验证:npm run test:run;pytest backend/tests;npm run lint;npm run build;python -m py_compile;git diff --check。
2026-05-01 20:27:33 +08:00

514 lines
17 KiB
TypeScript

import { beforeEach, describe, expect, it, vi } from 'vitest';
// Shared fake axios instance. Built inside vi.hoisted so it exists before the
// hoisted vi.mock('axios') factory below runs.
const axiosMock = vi.hoisted(() => {
  // Interceptor registration only needs to be callable; nothing is asserted on it here.
  const interceptors = {
    request: { use: vi.fn() },
    response: { use: vi.fn() },
  };
  // One spy per HTTP verb so each test can queue responses and inspect requests.
  const client = {
    get: vi.fn(),
    post: vi.fn(),
    patch: vi.fn(),
    delete: vi.fn(),
    interceptors,
  };
  // axios.create() always hands back this same client.
  const create = vi.fn(() => client);
  return { client, create };
});
// Replace the axios default export so the api module's axios.create() call
// receives the fake client. The factory only touches axiosMock, which is safe
// because vi.hoisted makes it available to hoisted vi.mock factories.
vi.mock('axios', () => {
  const axiosDefault = { create: axiosMock.create };
  return { default: axiosDefault };
});
describe('api client contracts', () => {
beforeEach(() => {
  // Clear recorded calls on every spy (axios.create, interceptors, verbs).
  // Implementations are intentionally kept: axiosMock.create must keep
  // returning the fake client, so vi.resetAllMocks() is not an option.
  vi.clearAllMocks();
  // clearAllMocks does not drop queued mockResolvedValueOnce responses. If a
  // test fails before consuming its queue, those responses would leak into the
  // next test. Reset only the HTTP verb spies to discard any leftovers.
  const verbs = [
    axiosMock.client.get,
    axiosMock.client.post,
    axiosMock.client.patch,
    axiosMock.client.delete,
  ];
  for (const verb of verbs) {
    verb.mockReset();
  }
  // Pin Date.* to a fixed instant so any time-derived values are deterministic.
  vi.setSystemTime(new Date('2026-05-01T00:00:00Z'));
});
it('maps backend project fields into frontend project fields', async () => {
  const { getProjects } = await import('./api');
  // A single project row as GET /api/projects returns it (snake_case, numeric id).
  const backendProject = {
    id: 7,
    name: 'Demo',
    description: 'desc',
    status: 'ready',
    frame_count: 12,
    original_fps: 29.97,
    parse_fps: 10,
    thumbnail_url: 'thumb',
    video_path: 'uploads/demo.mp4',
    source_type: 'video',
    created_at: 'created',
    updated_at: 'updated',
  };
  axiosMock.client.get.mockResolvedValueOnce({ data: [backendProject] });

  // Frontend shape: string id, frame_count -> frames, camelCase timestamps,
  // and original_fps 29.97 surfaced as the '30FPS' label.
  const expectedProject = expect.objectContaining({
    id: '7',
    name: 'Demo',
    status: 'ready',
    frames: 12,
    fps: '30FPS',
    thumbnail_url: 'thumb',
    video_path: 'uploads/demo.mp4',
    source_type: 'video',
    createdAt: 'created',
    updatedAt: 'updated',
  });
  await expect(getProjects()).resolves.toEqual([expectedProject]);
  expect(axiosMock.client.get).toHaveBeenCalledWith('/api/projects');
});
it('updates projects with PATCH instead of the old PUT contract', async () => {
  const { updateProject } = await import('./api');
  // Derive the payload type from updateProject itself instead of `as any`, so
  // the payload stays type-checked against the real update contract.
  type ProjectPatch = Parameters<typeof updateProject>[1];
  const patch = { name: 'Renamed' } as ProjectPatch;
  axiosMock.client.patch.mockResolvedValueOnce({ data: { id: 3, name: 'Renamed', status: 'ready' } });

  await updateProject('3', patch);

  // Exactly one PATCH to the project resource, carrying the payload unchanged.
  // (The fake client has no `put`, so a PUT call would throw.)
  expect(axiosMock.client.patch).toHaveBeenCalledTimes(1);
  expect(axiosMock.client.patch).toHaveBeenCalledWith('/api/projects/3', { name: 'Renamed' });
});
it('normalizes legacy project status values returned by existing databases', async () => {
  const { getProjects } = await import('./api');
  // Rows carrying capitalized status strings from older databases.
  const legacyRows = [
    { id: 1, name: 'Old Ready', status: 'Ready' },
    { id: 2, name: 'Old Parsing', status: 'Parsing' },
    { id: 3, name: 'Old Error', status: 'Error' },
  ];
  axiosMock.client.get.mockResolvedValueOnce({ data: legacyRows });

  // Each status must come back lower-cased, in the same order.
  const normalizedStatuses = ['ready', 'parsing', 'error'];
  await expect(getProjects()).resolves.toEqual(
    normalizedStatuses.map((status) => expect.objectContaining({ status })),
  );
});
it('exports COCO from the backend route shape', async () => {
  const { exportCoco } = await import('./api');
  const cocoBlob = new Blob(['{}'], { type: 'application/json' });
  axiosMock.client.get.mockResolvedValueOnce({ data: cocoBlob });

  // The exporter must return the response body itself (same Blob instance).
  const exported = exportCoco('9');
  await expect(exported).resolves.toBe(cocoBlob);
  expect(axiosMock.client.get).toHaveBeenCalledWith('/api/export/9/coco', { responseType: 'blob' });
});
it('exports PNG masks from the backend route shape', async () => {
  const { exportMasks } = await import('./api');
  const zipBlob = new Blob(['zip'], { type: 'application/zip' });
  axiosMock.client.get.mockResolvedValueOnce({ data: zipBlob });

  // The exporter must return the response body itself (same Blob instance).
  const exported = exportMasks('9');
  await expect(exported).resolves.toBe(zipBlob);
  expect(axiosMock.client.get).toHaveBeenCalledWith('/api/export/9/masks', { responseType: 'blob' });
});
it('loads dashboard overview from the backend summary endpoint', async () => {
  const { getDashboardOverview } = await import('./api');
  // Build the overview payload piece by piece; the client must pass it through untouched.
  const summary = {
    project_count: 2,
    parsing_task_count: 1,
    annotation_count: 5,
    frame_count: 100,
    template_count: 3,
    system_load_percent: 12,
  };
  const tasks = [
    { id: 'project-1', project_id: 1, name: 'Demo', progress: 60, status: 'pending', frame_count: 10, updated_at: 'now' },
  ];
  const activity = [
    { id: 'project-1', kind: 'project', time: 'now', message: '项目状态: pending', project: 'Demo' },
  ];
  const overview = { summary, tasks, activity };
  axiosMock.client.get.mockResolvedValueOnce({ data: overview });

  await expect(getDashboardOverview()).resolves.toEqual(overview);
  expect(axiosMock.client.get).toHaveBeenCalledWith('/api/dashboard/overview');
});
it('queues media parsing and manages processing task lifecycle', async () => {
  const { cancelTask, getTask, parseMedia, retryTask } = await import('./api');
  const task = {
    id: 12,
    task_type: 'parse_video',
    status: 'queued',
    progress: 0,
    message: '解析任务已入队',
    project_id: 9,
    celery_task_id: 'celery-12',
    payload: { source_type: 'video' },
    result: null,
    error: null,
    created_at: 'created',
    started_at: null,
    finished_at: null,
    updated_at: 'updated',
  };
  const { get, post } = axiosMock.client;
  // POST responses are consumed in call order: parse, cancel, retry.
  post
    .mockResolvedValueOnce({ data: task })
    .mockResolvedValueOnce({ data: { ...task, status: 'cancelled', progress: 100 } })
    .mockResolvedValueOnce({ data: { ...task, id: 13, status: 'queued', progress: 0 } });
  get.mockResolvedValueOnce({ data: { ...task, status: 'success', progress: 100 } });

  // Parsing options are sent as snake_case query params with no request body.
  await expect(parseMedia('9', { parseFps: 15, maxFrames: 120, targetWidth: 960 })).resolves.toEqual(task);
  expect(post).toHaveBeenCalledWith('/api/media/parse', null, {
    params: { project_id: '9', parse_fps: 15, max_frames: 120, target_width: 960 },
  });

  // Polling a task.
  await expect(getTask(12)).resolves.toEqual(expect.objectContaining({ status: 'success', progress: 100 }));
  expect(get).toHaveBeenCalledWith('/api/tasks/12');

  // Cancelling a task.
  await expect(cancelTask(12)).resolves.toEqual(expect.objectContaining({ status: 'cancelled' }));
  expect(post).toHaveBeenCalledWith('/api/tasks/12/cancel');

  // Retrying yields a new task record (id 13).
  await expect(retryTask(12)).resolves.toEqual(expect.objectContaining({ id: 13, status: 'queued' }));
  expect(post).toHaveBeenCalledWith('/api/tasks/12/retry');
});
// Fresh annotation record in the backend response shape, shared by the
// annotation CRUD and propagation tests below (a new object per call, so
// tests cannot affect each other through shared state).
function makeSavedAnnotation() {
  return {
    id: 1,
    project_id: 9,
    frame_id: 5,
    template_id: 2,
    mask_data: { polygons: [[[0, 0], [1, 0], [1, 1]]] },
    points: null,
    bbox: null,
    created_at: 'created',
    updated_at: 'updated',
  };
}
it('lists, saves, updates, and deletes annotations with the backend annotation contract', async () => {
  const { deleteAnnotation, getProjectAnnotations, saveAnnotation, updateAnnotation } = await import('./api');
  const saved = makeSavedAnnotation();
  axiosMock.client.get.mockResolvedValueOnce({ data: [saved] });
  axiosMock.client.post.mockResolvedValueOnce({ data: saved });

  // List: string ids from the UI are sent as numeric query params.
  await expect(getProjectAnnotations('9', '5')).resolves.toEqual([saved]);
  expect(axiosMock.client.get).toHaveBeenCalledWith('/api/ai/annotations', {
    params: { project_id: 9, frame_id: 5 },
  });

  // Save: the payload is posted unchanged.
  await expect(saveAnnotation({
    project_id: 9,
    frame_id: 5,
    template_id: 2,
    mask_data: { polygons: [[[0, 0], [1, 0], [1, 1]]], label: 'mask' },
  })).resolves.toEqual(saved);
  expect(axiosMock.client.post).toHaveBeenCalledWith('/api/ai/annotate', {
    project_id: 9,
    frame_id: 5,
    template_id: 2,
    mask_data: { polygons: [[[0, 0], [1, 0], [1, 1]]], label: 'mask' },
  });

  // Update: PATCH to the annotation resource.
  axiosMock.client.patch.mockResolvedValueOnce({ data: { ...saved, mask_data: { ...saved.mask_data, label: 'updated' } } });
  await expect(updateAnnotation('1', {
    template_id: 2,
    mask_data: { polygons: [[[0, 0], [1, 0], [1, 1]]], label: 'updated' },
  })).resolves.toEqual(expect.objectContaining({ mask_data: expect.objectContaining({ label: 'updated' }) }));
  expect(axiosMock.client.patch).toHaveBeenCalledWith('/api/ai/annotations/1', {
    template_id: 2,
    mask_data: { polygons: [[[0, 0], [1, 0], [1, 1]]], label: 'updated' },
  });

  // Delete: resolves to undefined regardless of the response body.
  axiosMock.client.delete.mockResolvedValueOnce({ data: null });
  await expect(deleteAnnotation('1')).resolves.toBeUndefined();
  expect(axiosMock.client.delete).toHaveBeenCalledWith('/api/ai/annotations/1');
});
it('propagates masks through the backend propagation endpoint with an extended timeout', async () => {
  const { propagateMasks } = await import('./api');
  const saved = makeSavedAnnotation();
  axiosMock.client.post.mockResolvedValueOnce({
    data: {
      model: 'sam2',
      direction: 'forward',
      source_frame_id: 5,
      processed_frame_count: 3,
      created_annotation_count: 2,
      annotations: [saved],
    },
  });

  await expect(propagateMasks({
    project_id: 9,
    frame_id: 5,
    model: 'sam2',
    seed: {
      polygons: [[[0, 0], [1, 0], [1, 1]]],
      label: 'mask',
      color: '#06b6d4',
    },
    direction: 'forward',
    max_frames: 30,
  })).resolves.toEqual(expect.objectContaining({ created_annotation_count: 2 }));
  // The request body is forwarded unchanged. Video propagation is long-running,
  // so the call must override the client timeout with 600000 ms (10 minutes).
  expect(axiosMock.client.post).toHaveBeenCalledWith('/api/ai/propagate', {
    project_id: 9,
    frame_id: 5,
    model: 'sam2',
    seed: {
      polygons: [[[0, 0], [1, 0], [1, 1]]],
      label: 'mask',
      color: '#06b6d4',
    },
    direction: 'forward',
    max_frames: 30,
  }, {
    timeout: 600000,
  });
});
it('imports GT masks through multipart form data', async () => {
  const { importGtMask } = await import('./api');
  const maskFile = new File(['mask'], 'mask.png', { type: 'image/png' });
  const importedAnnotations = [{ id: 1, project_id: 9, frame_id: 5, template_id: null, mask_data: null, points: null, bbox: null }];
  axiosMock.client.post.mockResolvedValueOnce({ data: importedAnnotations });

  await expect(importGtMask(maskFile, '9', '5', '2')).resolves.toEqual(importedAnnotations);
  expect(axiosMock.client.post).toHaveBeenCalledWith(
    '/api/ai/import-gt-mask',
    expect.any(FormData),
    { headers: { 'Content-Type': 'multipart/form-data' } },
  );

  // Inspect the multipart body that was actually sent: the File is passed
  // through as-is and the ids travel as strings.
  const sentForm = axiosMock.client.post.mock.calls.at(-1)?.[1] as FormData;
  const expectedFields: Array<[string, unknown]> = [
    ['file', maskFile],
    ['project_id', '9'],
    ['frame_id', '5'],
    ['template_id', '2'],
  ];
  for (const [field, value] of expectedFields) {
    expect(sentForm.get(field)).toBe(value);
  }
});
it('builds annotation payloads from frontend masks and restores saved annotations to masks', async () => {
  const { annotationToMask, buildAnnotationPayload } = await import('./api');
  // A 100x50 px frame: pixel x divides by 100 and pixel y by 50 when normalizing.
  const frame = { id: '5', projectId: '9', index: 0, url: '/frame.jpg', width: 100, height: 50 };

  // Direction 1: frontend mask (pixels) -> backend payload (normalized).
  const payload = buildAnnotationPayload('9', {
    id: 'm1',
    frameId: '5',
    pathData: 'M 10 10 L 90 10 L 90 40 Z',
    label: '胆囊',
    color: '#ff0000',
    classId: 'c1',
    className: '胆囊',
    classZIndex: 20,
    segmentation: [[10, 10, 90, 10, 90, 40]],
    bbox: [10, 10, 80, 30],
  }, frame, '2');
  const expectedPayload = {
    project_id: 9,
    frame_id: 5,
    template_id: 2,
    mask_data: {
      polygons: [[[0.1, 0.2], [0.9, 0.2], [0.9, 0.8]]],
      label: '胆囊',
      color: '#ff0000',
      class: { id: 'c1', name: '胆囊', color: '#ff0000', zIndex: 20 },
    },
    bbox: [0.1, 0.2, 0.8, 0.6],
  };
  expect(payload).toEqual(expectedPayload);

  // Direction 2: saved annotation (normalized) -> frontend mask (pixels).
  // The stored class metadata takes precedence over the stale label/color.
  const restored = annotationToMask({
    id: 3,
    project_id: 9,
    frame_id: 5,
    template_id: 2,
    mask_data: {
      polygons: [[[0.1, 0.2], [0.9, 0.2], [0.9, 0.8]]],
      label: '旧标签',
      color: '#06b6d4',
      class: { id: 'c1', name: '胆囊', color: '#ff0000', zIndex: 20 },
    },
    points: [[0.5, 0.5]],
    bbox: null,
    created_at: 'created',
    updated_at: 'updated',
  }, frame);
  const expectedMask = expect.objectContaining({
    id: 'annotation-3',
    annotationId: '3',
    frameId: '5',
    templateId: '2',
    classId: 'c1',
    className: '胆囊',
    classZIndex: 20,
    label: '胆囊',
    color: '#ff0000',
    saveStatus: 'saved',
    saved: true,
    pathData: 'M 10 10 L 90 10 L 90 40 Z',
    points: [[50, 25]],
    bbox: [10, 10, 80, 30],
  });
  expect(restored).toEqual(expectedMask);
});
it('preserves editable point regions in annotation payloads', async () => {
  const { buildAnnotationPayload } = await import('./api');
  const frame = { id: '5', projectId: '9', index: 0, url: '/frame.jpg', width: 100, height: 50 };

  const payload = buildAnnotationPayload('9', {
    id: 'm1',
    frameId: '5',
    pathData: 'M 10 10 L 90 10 L 90 40 Z',
    label: 'GT Mask',
    color: '#22c55e',
    segmentation: [[10, 10, 90, 10, 90, 40]],
    points: [[50, 25]],
  }, frame);

  // The pixel point (50, 25) on a 100x50 frame is stored normalized as (0.5, 0.5).
  expect(payload).toEqual(expect.objectContaining({ points: [[0.5, 0.5]] }));
});
it('normalizes positive and negative point prompts for AI prediction', async () => {
  const { predictMask } = await import('./api');
  // The backend answers with a polygon in normalized [0, 1] coordinates.
  const prediction = {
    polygons: [[[0.25, 0.25], [0.75, 0.25], [0.75, 0.75], [0.25, 0.75]]],
    scores: [0.9],
  };
  axiosMock.client.post.mockResolvedValueOnce({ data: prediction });

  const { masks } = await predictMask({
    imageId: '42',
    imageWidth: 400,
    imageHeight: 200,
    points: [
      { x: 200, y: 100, type: 'pos' },
      { x: 40, y: 20, type: 'neg' },
    ],
  });

  // Request: clicks divided by the 400x200 image size; 'pos' -> 1, 'neg' -> 0.
  expect(axiosMock.client.post).toHaveBeenCalledWith('/api/ai/predict', {
    image_id: 42,
    prompt_type: 'point',
    prompt_data: {
      points: [[0.5, 0.5], [0.1, 0.1]],
      labels: [1, 0],
    },
    model: 'sam2',
  });
  // Response: the polygon is scaled back to pixels, with derived path/bbox/area.
  expect(masks[0]).toEqual(expect.objectContaining({
    pathData: 'M 100 50 L 300 50 L 300 150 L 100 150 Z',
    segmentation: [[100, 50, 300, 50, 300, 150, 100, 150]],
    bbox: [100, 50, 200, 100],
    area: 20000,
    confidence: 0.9,
  }));
});
it('normalizes box prompts for AI prediction', async () => {
  const { predictMask } = await import('./api');
  const emptyPrediction = { polygons: [], scores: [] };
  axiosMock.client.post.mockResolvedValueOnce({ data: emptyPrediction });

  await predictMask({
    imageId: '5',
    imageWidth: 640,
    imageHeight: 320,
    box: { x1: 64, y1: 32, x2: 320, y2: 160 },
  });

  // A box alone is sent as a flat normalized [x1, y1, x2, y2] array.
  const expectedRequest = {
    image_id: 5,
    prompt_type: 'box',
    prompt_data: [0.1, 0.1, 0.5, 0.5],
    model: 'sam2',
  };
  expect(axiosMock.client.post).toHaveBeenCalledWith('/api/ai/predict', expectedRequest);
});
it('normalizes combined box and point prompts for interactive SAM2 refinement', async () => {
  const { predictMask } = await import('./api');
  const emptyPrediction = { polygons: [], scores: [] };
  axiosMock.client.post.mockResolvedValueOnce({ data: emptyPrediction });

  await predictMask({
    imageId: '5',
    imageWidth: 640,
    imageHeight: 320,
    box: { x1: 64, y1: 32, x2: 320, y2: 160 },
    points: [
      { x: 128, y: 64, type: 'pos' },
      { x: 256, y: 128, type: 'neg' },
    ],
  });

  // Box + points together switch the prompt type to 'interactive' and bundle
  // the normalized box, points and labels into one object.
  const expectedRequest = {
    image_id: 5,
    prompt_type: 'interactive',
    prompt_data: {
      box: [0.1, 0.1, 0.5, 0.5],
      points: [[0.2, 0.2], [0.4, 0.4]],
      labels: [1, 0],
    },
    model: 'sam2',
  };
  expect(axiosMock.client.post).toHaveBeenCalledWith('/api/ai/predict', expectedRequest);
});
it('uses semantic prompt type for text-only AI prediction', async () => {
  const { predictMask } = await import('./api');
  const emptyPrediction = { polygons: [], scores: [] };
  axiosMock.client.post.mockResolvedValueOnce({ data: emptyPrediction });

  await predictMask({
    imageId: '6',
    imageWidth: 640,
    imageHeight: 360,
    model: 'sam3',
    text: '分割胆囊',
  });

  // Text with no points or box is sent verbatim as a 'semantic' prompt.
  const expectedRequest = {
    image_id: 6,
    prompt_type: 'semantic',
    prompt_data: '分割胆囊',
    model: 'sam3',
  };
  expect(axiosMock.client.post).toHaveBeenCalledWith('/api/ai/predict', expectedRequest);
});
it('passes AI post-processing options to prediction endpoint', async () => {
  const { predictMask } = await import('./api');
  const emptyPrediction = { polygons: [], scores: [] };
  axiosMock.client.post.mockResolvedValueOnce({ data: emptyPrediction });

  await predictMask({
    imageId: '7',
    imageWidth: 640,
    imageHeight: 360,
    points: [{ x: 320, y: 180, type: 'pos' }],
    options: {
      crop_to_prompt: true,
      auto_filter_background: true,
      min_score: 0.05,
    },
  });

  // Post-processing options are forwarded unchanged next to the prompt.
  const expectedRequest = {
    image_id: 7,
    prompt_type: 'point',
    prompt_data: {
      points: [[0.5, 0.5]],
      labels: [1],
    },
    model: 'sam2',
    options: {
      crop_to_prompt: true,
      auto_filter_background: true,
      min_score: 0.05,
    },
  };
  expect(axiosMock.client.post).toHaveBeenCalledWith('/api/ai/predict', expectedRequest);
});
it('loads AI model and GPU runtime status', async () => {
  const { getAiModelStatus } = await import('./api');
  const gpu = { available: false, device: 'cpu', name: null, torch_available: true, torch_version: '2.x', cuda_version: null };
  const models = [
    { id: 'sam2', label: 'SAM 2', available: true, loaded: false, device: 'cpu', supports: ['point'], message: 'ready', package_available: true, checkpoint_exists: true, checkpoint_path: 'model.pt', python_ok: true, torch_ok: true, cuda_required: false },
    { id: 'sam3', label: 'SAM 3', available: false, loaded: false, device: 'unavailable', supports: ['semantic'], message: 'missing runtime', package_available: false, checkpoint_exists: false, checkpoint_path: null, python_ok: false, torch_ok: true, cuda_required: true },
  ];
  const status = { selected_model: 'sam2', gpu, models };
  axiosMock.client.get.mockResolvedValueOnce({ data: status });

  // The status payload is returned as-is; the requested model goes in the query string.
  await expect(getAiModelStatus('sam3')).resolves.toEqual(status);
  expect(axiosMock.client.get).toHaveBeenCalledWith('/api/ai/models/status', {
    params: { selected_model: 'sam3' },
  });
});
});