feat: 建立 SAM2 标注闭环基线

- 打通工作区真实标注闭环:支持手工多边形、矩形、圆形、点区域和线段生成 mask,并可保存、回显、更新和删除后端 annotation。

- 增强 polygon 编辑器:支持顶点拖动、顶点删除、边中点插入、多 polygon 子区域选择编辑,以及区域合并和区域去除。

- 接入 GT mask 导入:后端支持二值/多类别 mask 拆分、contour 转 polygon、distance transform seed point,前端支持导入、回显和 seed point 拖动编辑。

- 完善导出能力:COCO JSON 导出对齐前端,PNG mask ZIP 同时包含单标注 mask、按 zIndex 融合的 semantic_frame 和 semantic_classes.json。

- 打通异步任务管理:新增任务取消、重试、失败详情接口与 Dashboard 控件,worker 支持取消状态检查并通过 Redis/WebSocket 推送 cancelled 事件。

- 对接 Dashboard 后端数据:概览统计、解析队列和实时流转记录从 FastAPI 聚合接口与 WebSocket 更新。

- 增强 AI 推理参数:前端发送 crop_to_prompt、auto_filter_background 和 min_score,后端支持点/框 prompt 局部裁剪推理、结果回映射和负向点/低分过滤。

- 接入 SAM3 基础设施:新增独立 Python 3.12 sam3 环境安装脚本、外部 worker helper、后端桥接和真实 Python/CUDA/包/HF checkpoint access 状态检测。

- 保留 SAM3 授权边界:当前官方 facebook/sam3 gated 权重未授权时状态接口会返回不可用,不伪装成可推理。

- 增强前端状态管理:新增 mask undo/redo 历史栈、AI 模型选择状态、保存状态 dirty/draft/saved 流转和项目状态归一化。

- 更新前端 API 封装:补充 annotation CRUD、GT mask import、mask ZIP export、task cancel/retry/detail、AI runtime status 和 prediction options。

- 更新 UI 控件:ToolsPalette、AISegmentation、VideoWorkspace 和 CanvasArea 接入真实操作、导入导出、撤销重做、任务控制和模型状态。

- 新增 polygon-clipping 依赖,用于前端区域 union/difference 几何运算。

- 完善后端 schemas/status/progress:补充 AI 模型外部状态字段、任务 cancelled 状态和进度事件 payload。

- 补充测试覆盖:新增后端任务控制、SAM3 桥接、GT mask、导出融合、AI options 测试;补充前端 Canvas、Dashboard、VideoWorkspace、ToolsPalette、API 和 store 测试。

- 更新 README、AGENTS 和 doc 文档:冻结当前需求/设计/测试计划,标注真实功能、剩余 Mock、SAM3 授权边界和后续实施顺序。
This commit is contained in:
2026-05-01 15:26:25 +08:00
parent f020ff3b4f
commit 689a9ba283
48 changed files with 3280 additions and 176 deletions

View File

@@ -40,4 +40,26 @@ describe('AISegmentation', () => {
expect(useStore.getState().aiModel).toBe('sam3');
expect(await screen.findByText('SAM 3 missing runtime')).toBeInTheDocument();
});
// Places one positive prompt point, runs inference, and checks the request payload sent to
// predictMask. No option control is toggled in this test, so the asserted `options` object is
// whatever the component's initial state produces (crop off, background filter on, min_score 0.05).
it('passes enabled inference parameters to the backend', async () => {
apiMock.predictMask.mockResolvedValueOnce({ masks: [] });
render(<AISegmentation onSendToWorkspace={vi.fn()} />);
fireEvent.click(screen.getByText('正向选点'));
fireEvent.click(screen.getByTestId('konva-stage'));
fireEvent.click(await screen.findByText('执行高精度语义分割'));
// NOTE(review): the point coordinates (120, 80) are not passed by the click above — presumably
// they come from the mocked Konva stage pointer position; confirm against the test setup.
expect(apiMock.predictMask).toHaveBeenCalledWith(expect.objectContaining({
imageId: 'frame-1',
imageWidth: 640,
imageHeight: 360,
model: 'sam2',
points: [{ x: 120, y: 80, type: 'pos' }],
options: {
crop_to_prompt: false,
auto_filter_background: true,
min_score: 0.05,
},
}));
});
});

View File

@@ -17,6 +17,10 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
const masks = useStore((state) => state.masks);
const addMask = useStore((state) => state.addMask);
const clearMasks = useStore((state) => state.clearMasks);
const maskHistory = useStore((state) => state.maskHistory);
const maskFuture = useStore((state) => state.maskFuture);
const undoMasks = useStore((state) => state.undoMasks);
const redoMasks = useStore((state) => state.redoMasks);
const frames = useStore((state) => state.frames);
const currentFrameIndex = useStore((state) => state.currentFrameIndex);
const activeTemplateId = useStore((state) => state.activeTemplateId);
@@ -109,6 +113,11 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
model: aiModel,
points: points.map((p) => ({ x: p.x, y: p.y, type: p.type })),
text: semanticText.trim() || undefined,
options: {
crop_to_prompt: cropMode,
auto_filter_background: autoDeleteBg,
min_score: autoDeleteBg ? 0.05 : 0,
},
});
result.masks.forEach((m) => {
@@ -136,7 +145,7 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
} finally {
setIsInferencing(false);
}
}, [activeClass, activeTemplateId, addMask, aiModel, currentFrame?.height, currentFrame?.id, currentFrame?.width, image?.height, image?.naturalHeight, image?.naturalWidth, image?.width, points, semanticText]);
}, [activeClass, activeTemplateId, addMask, aiModel, autoDeleteBg, cropMode, currentFrame?.height, currentFrame?.id, currentFrame?.width, image?.height, image?.naturalHeight, image?.naturalWidth, image?.width, points, semanticText]);
const handleStageClick = (e: any) => {
if (effectiveTool === 'move') return;
@@ -290,10 +299,20 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
<span className="text-[10px] text-gray-500 uppercase tracking-widest font-mono">{aiModel.toUpperCase()} </span>
</div>
<div className="flex items-center gap-4">
<button className="w-8 h-8 rounded text-gray-400 hover:bg-white/5 hover:text-white flex items-center justify-center transition-colors" title="撤销操作 (Ctrl+Z)">
<button
onClick={undoMasks}
disabled={maskHistory.length === 0}
className="w-8 h-8 rounded text-gray-400 hover:bg-white/5 hover:text-white flex items-center justify-center transition-colors disabled:opacity-30 disabled:hover:bg-transparent disabled:hover:text-gray-400 disabled:cursor-not-allowed"
title="撤销操作 (Ctrl+Z)"
>
<Undo size={14} />
</button>
<button className="w-8 h-8 rounded text-gray-400 hover:bg-white/5 hover:text-white flex items-center justify-center transition-colors" title="重做操作 (Ctrl+Shift+Z)">
<button
onClick={redoMasks}
disabled={maskFuture.length === 0}
className="w-8 h-8 rounded text-gray-400 hover:bg-white/5 hover:text-white flex items-center justify-center transition-colors disabled:opacity-30 disabled:hover:bg-transparent disabled:hover:text-gray-400 disabled:cursor-not-allowed"
title="重做操作 (Ctrl+Shift+Z)"
>
<Redo size={14} />
</button>
<div className="w-px h-4 bg-white/10 mx-1"></div>

View File

@@ -79,6 +79,271 @@ describe('CanvasArea', () => {
expect(screen.getByText('遮罩数: 1')).toBeInTheDocument();
});
// Seeds the store with a single mask that carries one seed point and checks how many Konva
// circles are rendered for it.
it('renders imported GT seed points for editable point regions', () => {
useStore.setState({
masks: [
{
id: 'gt-1',
frameId: 'frame-1',
pathData: 'M 0 0 L 10 0 L 10 10 Z',
label: 'GT',
color: '#22c55e',
points: [[120, 80]],
},
],
});
render(<CanvasArea activeTool="move" frame={frame} />);
// NOTE(review): two circles are expected for one seed point — presumably the second circle
// belongs to another overlay rendered by the component; confirm against the component markup.
expect(screen.getAllByTestId('konva-circle')).toHaveLength(2);
});
// Selects a previously saved triangle, moves its first vertex to (20, 30) and checks that the
// path, segmentation, bbox and area are recomputed and that the mask flips from saved to dirty.
it('selects a polygon mask and drags a vertex into dirty saved state', () => {
useStore.setState({
masks: [
{
id: 'annotation-99',
annotationId: '99',
frameId: 'frame-1',
pathData: 'M 10 10 L 90 10 L 90 40 Z',
label: 'Saved',
color: '#06b6d4',
saved: true,
saveStatus: 'saved',
segmentation: [[10, 10, 90, 10, 90, 40]],
bbox: [10, 10, 80, 30],
},
],
});
render(<CanvasArea activeTool="move" frame={frame} />);
fireEvent.click(screen.getByTestId('konva-path'));
// Vertex handles are identified by their white fill; a triangle exposes three of them.
const handles = screen.getAllByTestId('konva-circle')
.filter((element) => element.getAttribute('data-fill') === '#ffffff');
expect(handles).toHaveLength(3);
// NOTE(review): mouseUp with client coordinates presumably stands in for a Konva drag-end in
// the mocked Circle; confirm against the react-konva mock.
fireEvent.mouseUp(handles[0], { clientX: 20, clientY: 30 });
// Area of the triangle (20,30) (90,10) (90,40) is 1050 by the shoelace formula.
expect(useStore.getState().masks[0]).toEqual(expect.objectContaining({
pathData: 'M 20 30 L 90 10 L 90 40 Z',
segmentation: [[20, 30, 90, 10, 90, 40]],
bbox: [20, 10, 70, 30],
area: 1050,
saveStatus: 'dirty',
saved: false,
}));
});
// Selects the first vertex of a four-point draft polygon and presses Delete; the polygon must
// shrink to the remaining three points and stay in draft state.
// Note: only the 4 -> 3 case is asserted here; the refusal to go below three points is not
// exercised by this test.
it('deletes a selected polygon vertex without dropping below three points', () => {
useStore.setState({
masks: [
{
id: 'draft-1',
frameId: 'frame-1',
pathData: 'M 10 10 L 90 10 L 90 40 L 10 40 Z',
label: 'Draft',
color: '#06b6d4',
saveStatus: 'draft',
segmentation: [[10, 10, 90, 10, 90, 40, 10, 40]],
bbox: [10, 10, 80, 30],
},
],
});
render(<CanvasArea activeTool="move" frame={frame} />);
fireEvent.click(screen.getByTestId('konva-path'));
const handles = screen.getAllByTestId('konva-circle')
.filter((element) => element.getAttribute('data-fill') === '#ffffff');
fireEvent.click(handles[0]);
// The component listens for key presses on window, so the event is dispatched there.
fireEvent.keyDown(window, { key: 'Delete' });
expect(useStore.getState().masks[0]).toEqual(expect.objectContaining({
pathData: 'M 90 10 L 90 40 L 10 40 Z',
segmentation: [[90, 10, 90, 40, 10, 40]],
saveStatus: 'draft',
}));
});
// Clicks the first edge-midpoint handle of a draft triangle and expects a new vertex at the
// midpoint (50, 10) of the edge (10,10)-(90,10), inserted between those two vertices.
it('inserts a polygon vertex from an edge midpoint handle', () => {
useStore.setState({
masks: [
{
id: 'draft-1',
frameId: 'frame-1',
pathData: 'M 10 10 L 90 10 L 90 40 Z',
label: 'Draft',
color: '#06b6d4',
saveStatus: 'draft',
segmentation: [[10, 10, 90, 10, 90, 40]],
bbox: [10, 10, 80, 30],
},
],
});
render(<CanvasArea activeTool="move" frame={frame} />);
fireEvent.click(screen.getByTestId('konva-path'));
// Edge-midpoint handles are distinguished from vertex handles by their cyan fill.
const edgeHandles = screen.getAllByTestId('konva-circle')
.filter((element) => element.getAttribute('data-fill') === '#22d3ee');
fireEvent.click(edgeHandles[0]);
expect(useStore.getState().masks[0]).toEqual(expect.objectContaining({
segmentation: [[10, 10, 50, 10, 90, 10, 90, 40]],
pathData: 'M 10 10 L 50 10 L 90 10 L 90 40 Z',
saveStatus: 'draft',
}));
});
// For a mask made of two polygons, selects the second rendered path and moves its first vertex;
// only the second polygon of the segmentation may change.
it('edits the selected polygon in a multi-polygon mask', () => {
useStore.setState({
masks: [
{
id: 'multi-1',
frameId: 'frame-1',
pathData: 'M 10 10 L 50 10 L 50 40 Z M 100 100 L 150 100 L 150 140 Z',
label: 'Multi',
color: '#06b6d4',
saveStatus: 'draft',
segmentation: [
[10, 10, 50, 10, 50, 40],
[100, 100, 150, 100, 150, 140],
],
bbox: [10, 10, 140, 130],
},
],
});
render(<CanvasArea activeTool="move" frame={frame} />);
// NOTE(review): paths[1] is assumed to be the path rendered for the second polygon; confirm
// the render order in the component.
const paths = screen.getAllByTestId('konva-path');
fireEvent.click(paths[1]);
const vertexHandles = screen.getAllByTestId('konva-circle')
.filter((element) => element.getAttribute('data-fill') === '#ffffff');
fireEvent.mouseUp(vertexHandles[0], { clientX: 120, clientY: 120 });
expect(useStore.getState().masks[0].segmentation).toEqual([
[10, 10, 50, 10, 50, 40],
[120, 120, 150, 100, 150, 140],
]);
});
// With the area_merge tool, selects two overlapping rectangles and triggers the merge button.
// The first selected mask (m1) must survive with the union outline and the second must be removed.
it('merges selected draft masks with polygon union', () => {
useStore.setState({
masks: [
{
id: 'm1',
frameId: 'frame-1',
pathData: 'M 10 10 L 90 10 L 90 50 L 10 50 Z',
label: 'A',
color: '#06b6d4',
segmentation: [[10, 10, 90, 10, 90, 50, 10, 50]],
},
{
id: 'm2',
frameId: 'frame-1',
pathData: 'M 50 30 L 120 30 L 120 80 L 50 80 Z',
label: 'B',
color: '#ff0000',
segmentation: [[50, 30, 120, 30, 120, 80, 50, 80]],
},
],
});
render(<CanvasArea activeTool="area_merge" frame={frame} />);
const paths = screen.getAllByTestId('konva-path');
// Click order matters: the first clicked mask becomes the primary operand.
fireEvent.click(paths[0]);
fireEvent.click(paths[1]);
fireEvent.click(screen.getByRole('button', { name: '合并选中' }));
expect(useStore.getState().masks).toHaveLength(1);
// The expected vertex order and starting vertex mirror what the polygon-clipping union returns.
expect(useStore.getState().masks[0]).toEqual(expect.objectContaining({
id: 'm1',
segmentation: [[10, 10, 90, 10, 90, 30, 120, 30, 120, 80, 50, 80, 50, 50, 10, 50]],
bbox: [10, 10, 110, 70],
saveStatus: 'draft',
}));
});
// With the area_remove tool, selects two overlapping rectangles and triggers the subtract button.
// The overlap is cut out of the first selected mask (m1) while the second mask (m2) is kept.
it('removes overlap from the primary selected mask with polygon difference', () => {
useStore.setState({
masks: [
{
id: 'm1',
frameId: 'frame-1',
pathData: 'M 10 10 L 90 10 L 90 50 L 10 50 Z',
label: 'A',
color: '#06b6d4',
segmentation: [[10, 10, 90, 10, 90, 50, 10, 50]],
},
{
id: 'm2',
frameId: 'frame-1',
pathData: 'M 50 30 L 120 30 L 120 80 L 50 80 Z',
label: 'B',
color: '#ff0000',
segmentation: [[50, 30, 120, 30, 120, 80, 50, 80]],
},
],
});
render(<CanvasArea activeTool="area_remove" frame={frame} />);
const paths = screen.getAllByTestId('konva-path');
// Click order matters: the first clicked mask is the one that gets clipped.
fireEvent.click(paths[0]);
fireEvent.click(paths[1]);
fireEvent.click(screen.getByRole('button', { name: '从主区域去除' }));
expect(useStore.getState().masks).toHaveLength(2);
expect(useStore.getState().masks[0]).toEqual(expect.objectContaining({
id: 'm1',
segmentation: [[10, 10, 90, 10, 90, 30, 50, 30, 50, 50, 10, 50]],
bbox: [10, 10, 80, 40],
saveStatus: 'draft',
}));
expect(useStore.getState().masks[1].id).toBe('m2');
});
// Draws a rectangle with a mouse down / move / up sequence while a class is active, then checks
// that the new draft mask takes the class label and colour and that undo/redo remove and restore it.
it('creates a manual rectangle mask that can be undone and redone', () => {
useStore.setState({
activeTemplateId: '2',
activeClass: { id: 'c1', name: '胆囊', color: '#ff0000', zIndex: 20 },
activeClassId: 'c1',
});
render(<CanvasArea activeTool="create_rectangle" frame={frame} />);
const stage = screen.getByTestId('konva-stage');
// NOTE(review): no coordinates are passed with these events — the rectangle corners asserted
// below presumably come from the mocked stage pointer position; confirm against the Konva mock.
fireEvent.mouseDown(stage);
fireEvent.mouseMove(stage);
fireEvent.mouseUp(stage);
expect(useStore.getState().masks).toHaveLength(1);
expect(useStore.getState().masks[0]).toEqual(expect.objectContaining({
frameId: 'frame-1',
label: '胆囊',
color: '#ff0000',
saveStatus: 'draft',
segmentation: [[120, 80, 260, 80, 260, 200, 120, 200]],
bbox: [120, 80, 140, 120],
}));
useStore.getState().undoMasks();
expect(useStore.getState().masks).toEqual([]);
useStore.getState().redoMasks();
expect(useStore.getState().masks).toHaveLength(1);
});
// With the polygon tool, three stage clicks collect vertices and pressing Enter on window commits
// them as a single mask whose metadata marks it as a manually drawn polygon.
it('finalizes a clicked polygon with Enter', () => {
render(<CanvasArea activeTool="create_polygon" frame={frame} />);
const stage = screen.getByTestId('konva-stage');
fireEvent.click(stage, { clientX: 120, clientY: 80 });
fireEvent.click(stage, { clientX: 220, clientY: 80 });
fireEvent.click(stage, { clientX: 180, clientY: 160 });
fireEvent.keyDown(window, { key: 'Enter' });
expect(useStore.getState().masks).toHaveLength(1);
expect(useStore.getState().masks[0].metadata).toEqual(expect.objectContaining({
source: 'manual',
shape: '多边形',
}));
});
it('applies the selected class to current-frame masks and marks saved masks dirty', () => {
useStore.setState({
activeTemplateId: '2',

View File

@@ -1,17 +1,180 @@
import React, { useEffect, useRef, useState, useCallback } from 'react';
import { Stage, Layer, Image as KonvaImage, Circle, Rect, Path, Group } from 'react-konva';
import polygonClipping, { type MultiPolygon, type Pair } from 'polygon-clipping';
import useImage from 'use-image';
import { useStore } from '../store/useStore';
import { predictMask } from '../lib/api';
import type { Frame } from '../store/useStore';
import type { Frame, Mask } from '../store/useStore';
/** Props accepted by the CanvasArea component. */
interface CanvasAreaProps {
/** Tool identifier requested by the parent; when it is falsy the component falls back to the store's activeTool. */
activeTool: string;
/** Frame being annotated (its id, url and dimensions are read), or null when no frame is available. */
frame: Frame | null;
/** Optional callback. NOTE(review): its call site is outside this view — presumably invoked when masks are cleared; confirm. */
onClearMasks?: () => void;
/** Optional callback receiving the backend annotation ids of masks removed by a boolean region operation; may be async. */
onDeleteMaskAnnotations?: (annotationIds: string[]) => Promise<void> | void;
}
export function CanvasArea({ activeTool, frame, onClearMasks }: CanvasAreaProps) {
/** A 2D position used by the geometry helpers and the drawing handlers in this file. */
type CanvasPoint = { x: number; y: number };
/** Manual shape tools that are drawn with a press-drag-release gesture. */
const DRAG_MANUAL_TOOLS = new Set(['create_rectangle', 'create_circle', 'create_line']);
/** Manual polygon tool: stage clicks accumulate vertices and Enter commits them. */
const POLYGON_TOOL = 'create_polygon';
/** Manual point tool: a single stage click creates a small point region. */
const POINT_TOOL = 'create_point';
/** Tools for which clicking a mask toggles it in the multi-selection used by region union/difference. */
const BOOLEAN_TOOLS = new Set(['area_merge', 'area_remove']);
/**
 * Restricts a number to a range.
 *
 * @param value - The number to restrict.
 * @param min - Lower bound applied first.
 * @param max - Upper bound applied last (so it wins if `min > max`).
 * @returns `value` limited to the range `[min, max]`.
 */
function clamp(value: number, min: number, max: number): number {
  const raisedToMin = Math.max(value, min);
  return Math.min(raisedToMin, max);
}
/**
 * Serialises a list of vertices into a closed SVG-style path string.
 *
 * @param points - Polygon vertices in drawing order.
 * @returns `'M x y L x y ... Z'`, or an empty string when there are no vertices.
 */
function polygonPath(points: CanvasPoint[]): string {
  if (points.length === 0) {
    return '';
  }
  const commands: string[] = [];
  points.forEach((vertex, position) => {
    // The first vertex starts the sub-path, every following vertex draws a line to it.
    const verb = position === 0 ? 'M' : 'L';
    commands.push(`${verb} ${vertex.x} ${vertex.y}`);
  });
  return `${commands.join(' ')} Z`;
}
/**
 * Builds one path string covering every polygon of a segmentation.
 *
 * @param segmentation - Flat `[x0, y0, x1, y1, ...]` polygons; may be undefined.
 * @returns The closed sub-paths joined by a space; polygons that yield an empty path are left out.
 */
function segmentationPath(segmentation?: number[][]): string {
  const subPaths: string[] = [];
  for (const flatPolygon of segmentation || []) {
    const subPath = polygonPath(flatPolygonToPoints(flatPolygon));
    if (subPath) {
      subPaths.push(subPath);
    }
  }
  return subPaths.join(' ');
}
/**
 * Builds the path string of a single polygon inside a segmentation.
 *
 * @param segmentation - Flat `[x0, y0, x1, y1, ...]` polygons; may be undefined.
 * @param polygonIndex - Index of the polygon to serialise.
 * @returns The closed path of that polygon, or an empty string when it does not exist.
 */
function segmentationPolygonPath(segmentation: number[][] | undefined, polygonIndex: number): string {
  const flatPolygon = segmentation?.[polygonIndex];
  if (!flatPolygon) {
    return '';
  }
  return polygonPath(flatPolygonToPoints(flatPolygon));
}
/**
 * Converts a vertex list into a segmentation holding exactly one flat polygon.
 *
 * @param points - Polygon vertices in drawing order.
 * @returns `[[x0, y0, x1, y1, ...]]` (a single, possibly empty, inner array).
 */
function polygonSegmentation(points: CanvasPoint[]): number[][] {
  const flat: number[] = [];
  for (const vertex of points) {
    flat.push(vertex.x, vertex.y);
  }
  return [flat];
}
/**
 * Decodes one polygon of a segmentation into a vertex list.
 *
 * @param segmentation - Flat `[x0, y0, x1, y1, ...]` polygons; may be undefined.
 * @param polygonIndex - Index of the polygon to decode (defaults to the first one).
 * @returns The polygon's vertices; empty when the polygon is missing. A trailing unpaired
 *          coordinate is ignored.
 */
function segmentationToPoints(segmentation?: number[][], polygonIndex = 0): CanvasPoint[] {
  const flat = segmentation?.[polygonIndex] || [];
  // Only complete (x, y) pairs become vertices.
  const pairCount = Math.floor(flat.length / 2);
  return Array.from({ length: pairCount }, (_, pair) => ({
    x: flat[pair * 2],
    y: flat[pair * 2 + 1],
  }));
}
/**
 * Decodes a flat coordinate array into a vertex list.
 *
 * @param polygon - Coordinates laid out as `[x0, y0, x1, y1, ...]`.
 * @returns One vertex per complete (x, y) pair; a trailing unpaired coordinate is ignored.
 */
function flatPolygonToPoints(polygon: number[]): CanvasPoint[] {
  const pairCount = Math.floor(polygon.length / 2);
  return Array.from({ length: pairCount }, (_, pair) => ({
    x: polygon[pair * 2],
    y: polygon[pair * 2 + 1],
  }));
}
/**
 * Collects the vertices of every polygon in a segmentation into one list.
 *
 * @param segmentation - Flat `[x0, y0, x1, y1, ...]` polygons; may be undefined.
 * @returns All vertices in polygon order; empty when there is no segmentation.
 */
function segmentationAllPoints(segmentation?: number[][]): CanvasPoint[] {
  const collected: CanvasPoint[] = [];
  for (const flatPolygon of segmentation || []) {
    for (const vertex of flatPolygonToPoints(flatPolygon)) {
      collected.push(vertex);
    }
  }
  return collected;
}
/**
 * Computes the axis-aligned bounding box of a set of vertices.
 *
 * The extremes are accumulated in a loop rather than with `Math.min(...xs)`: spreading a very
 * large vertex list into a call exceeds the engine's argument limit and throws a RangeError.
 *
 * @param points - Vertices to enclose.
 * @returns `[minX, minY, width, height]`; `[0, 0, 0, 0]` for an empty list instead of an
 *          infinite box.
 */
function polygonBbox(points: CanvasPoint[]): [number, number, number, number] {
  if (points.length === 0) {
    return [0, 0, 0, 0];
  }
  let minX = Infinity;
  let minY = Infinity;
  let maxX = -Infinity;
  let maxY = -Infinity;
  for (const point of points) {
    // Math.min / Math.max keep the NaN-propagating behaviour of the variadic form.
    minX = Math.min(minX, point.x);
    minY = Math.min(minY, point.y);
    maxX = Math.max(maxX, point.x);
    maxY = Math.max(maxY, point.y);
  }
  return [minX, minY, maxX - minX, maxY - minY];
}
/**
 * Computes the unsigned area of a simple polygon with the shoelace formula.
 *
 * @param points - Polygon vertices in drawing order (either winding direction).
 * @returns The enclosed area, or 0 when fewer than three vertices are given.
 */
function polygonArea(points: CanvasPoint[]): number {
  const vertexCount = points.length;
  if (vertexCount < 3) {
    return 0;
  }
  let twiceSignedArea = 0;
  for (let index = 0; index < vertexCount; index += 1) {
    const current = points[index];
    // The last vertex wraps around to the first to close the outline.
    const following = points[(index + 1) % vertexCount];
    twiceSignedArea = twiceSignedArea + current.x * following.y - following.x * current.y;
  }
  return Math.abs(twiceSignedArea) / 2;
}
/**
 * Sums the areas of every polygon in a segmentation.
 *
 * @param segmentation - Flat `[x0, y0, x1, y1, ...]` polygons; may be undefined.
 * @returns The total of the individual polygon areas (overlaps are not deducted); 0 when empty.
 */
function segmentationArea(segmentation?: number[][]): number {
  let total = 0;
  for (const flatPolygon of segmentation || []) {
    total = total + polygonArea(flatPolygonToPoints(flatPolygon));
  }
  return total;
}
/**
 * Computes the bounding box enclosing every polygon of a segmentation.
 *
 * @param segmentation - Flat `[x0, y0, x1, y1, ...]` polygons; may be undefined.
 * @returns `[minX, minY, width, height]`, or undefined when the segmentation has no vertices.
 */
function segmentationBbox(segmentation?: number[][]): [number, number, number, number] | undefined {
  const vertices = segmentationAllPoints(segmentation);
  if (vertices.length === 0) {
    return undefined;
  }
  return polygonBbox(vertices);
}
/**
 * Converts vertices into a coordinate ring whose last pair repeats the first one.
 *
 * @param points - Polygon vertices, with or without the closing vertex.
 * @returns `[x, y]` pairs; the first pair is appended again only when the ring is not already
 *          closed. An empty input yields an empty ring.
 */
function closeRing(points: CanvasPoint[]): Pair[] {
  const ring: Pair[] = [];
  for (const vertex of points) {
    ring.push([vertex.x, vertex.y]);
  }
  if (ring.length === 0) {
    return ring;
  }
  const [startX, startY] = ring[0];
  const [endX, endY] = ring[ring.length - 1];
  const alreadyClosed = startX === endX && startY === endY;
  if (!alreadyClosed) {
    ring.push([startX, startY]);
  }
  return ring;
}
/**
 * Converts a mask's segmentation into the multi-polygon structure used for boolean operations.
 *
 * Each usable polygon becomes a polygon with a single closed ring; polygons with fewer than
 * three vertices are ignored.
 *
 * @param mask - Mask whose `segmentation` is read.
 * @returns The geometry, or null when the mask has no usable polygon.
 */
function maskToMultiPolygon(mask: Mask): MultiPolygon | null {
  const geometry: MultiPolygon = [];
  for (const flatPolygon of mask.segmentation || []) {
    const vertices = flatPolygonToPoints(flatPolygon);
    if (vertices.length < 3) {
      continue;
    }
    geometry.push([closeRing(vertices)]);
  }
  return geometry.length > 0 ? geometry : null;
}
/**
 * Flattens a multi-polygon back into segmentation form.
 *
 * Only the first ring of each polygon is kept and its duplicated closing vertex is removed.
 * Rings left with fewer than three vertices are dropped.
 *
 * NOTE(review): any further rings of a polygon (holes, as produced when a region is cut out of
 * the interior of another) are discarded here — confirm this is acceptable for callers.
 *
 * @param geometry - Result geometry of a boolean operation.
 * @returns One flat `[x0, y0, x1, y1, ...]` array per kept ring.
 */
function multiPolygonToSegmentation(geometry: MultiPolygon): number[][] {
  const segmentation: number[][] = [];
  for (const polygon of geometry) {
    const outline = polygon[0] || [];
    const lastIndex = outline.length - 1;
    const isClosed = outline.length > 1
      && outline[0][0] === outline[lastIndex][0]
      && outline[0][1] === outline[lastIndex][1];
    const vertices = isClosed ? outline.slice(0, lastIndex) : outline;
    const flat: number[] = [];
    for (const [x, y] of vertices) {
      flat.push(x, y);
    }
    // Six numbers are three vertices, the minimum for a polygon.
    if (flat.length >= 6) {
      segmentation.push(flat);
    }
  }
  return segmentation;
}
/**
 * Builds the four corners of the axis-aligned rectangle spanned by two opposite corners.
 *
 * @param start - One corner of the drag gesture.
 * @param end - The opposite corner.
 * @returns Corners ordered top-left, top-right, bottom-right, bottom-left, regardless of the
 *          direction in which the rectangle was dragged.
 */
function rectanglePoints(start: CanvasPoint, end: CanvasPoint): CanvasPoint[] {
  const left = Math.min(start.x, end.x);
  const right = Math.max(start.x, end.x);
  const top = Math.min(start.y, end.y);
  const bottom = Math.max(start.y, end.y);
  const corner = (x: number, y: number): CanvasPoint => ({ x, y });
  return [
    corner(left, top),
    corner(right, top),
    corner(right, bottom),
    corner(left, bottom),
  ];
}
/**
 * Approximates the ellipse inscribed in the box spanned by two drag corners with a polygon.
 *
 * @param start - One corner of the bounding box.
 * @param end - The opposite corner of the bounding box.
 * @param segments - Number of polygon vertices. Defaults to 32; fractional values are floored,
 *                   values below 3 are raised to 3, and non-finite values fall back to 32.
 * @returns The vertices, starting at the right-most point of the ellipse and advancing by equal
 *          angle steps.
 */
function circlePoints(start: CanvasPoint, end: CanvasPoint, segments = 32): CanvasPoint[] {
  // A polygon needs at least three vertices.
  const vertexCount = Number.isFinite(segments) ? Math.max(3, Math.floor(segments)) : 32;
  const cx = (start.x + end.x) / 2;
  const cy = (start.y + end.y) / 2;
  const rx = Math.abs(end.x - start.x) / 2;
  const ry = Math.abs(end.y - start.y) / 2;
  return Array.from({ length: vertexCount }, (_, index) => {
    const angle = (Math.PI * 2 * index) / vertexCount;
    return { x: cx + Math.cos(angle) * rx, y: cy + Math.sin(angle) * ry };
  });
}
/**
 * Builds a small regular polygon around a point so a single click can be stored as a region.
 *
 * @param point - Centre of the region.
 * @param radius - Distance from the centre to each vertex (defaults to 5).
 * @param segments - Number of polygon vertices. Defaults to 12; fractional values are floored,
 *                   values below 3 are raised to 3, and non-finite values fall back to 12.
 * @returns The vertices, starting at `(point.x + radius, point.y)` and advancing by equal
 *          angle steps.
 */
function pointRegion(point: CanvasPoint, radius = 5, segments = 12): CanvasPoint[] {
  // A polygon needs at least three vertices.
  const vertexCount = Number.isFinite(segments) ? Math.max(3, Math.floor(segments)) : 12;
  return Array.from({ length: vertexCount }, (_, index) => {
    const angle = (Math.PI * 2 * index) / vertexCount;
    return { x: point.x + Math.cos(angle) * radius, y: point.y + Math.sin(angle) * radius };
  });
}
/**
 * Builds a thin quadrilateral around a line segment so the segment can be stored as a region.
 *
 * @param start - First end of the segment.
 * @param end - Second end of the segment.
 * @param halfWidth - Distance from the segment to each long side (defaults to 4).
 * @returns Four corners: both ends shifted along the segment normal, then both ends shifted
 *          the opposite way. A zero-length segment yields four copies of the same point.
 */
function lineRegion(start: CanvasPoint, end: CanvasPoint, halfWidth = 4): CanvasPoint[] {
  const deltaX = end.x - start.x;
  const deltaY = end.y - start.y;
  // Fall back to 1 so a zero-length segment does not divide by zero.
  const segmentLength = Math.hypot(deltaX, deltaY) || 1;
  const offsetX = (-deltaY / segmentLength) * halfWidth;
  const offsetY = (deltaX / segmentLength) * halfWidth;
  const startPlus: CanvasPoint = { x: start.x + offsetX, y: start.y + offsetY };
  const endPlus: CanvasPoint = { x: end.x + offsetX, y: end.y + offsetY };
  const endMinus: CanvasPoint = { x: end.x - offsetX, y: end.y - offsetY };
  const startMinus: CanvasPoint = { x: start.x - offsetX, y: start.y - offsetY };
  return [startPlus, endPlus, endMinus, startMinus];
}
export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnotations }: CanvasAreaProps) {
const containerRef = useRef<HTMLDivElement>(null);
const [stageSize, setStageSize] = useState({ width: 800, height: 600 });
const [scale, setScale] = useState(1);
@@ -20,22 +183,46 @@ export function CanvasArea({ activeTool, frame, onClearMasks }: CanvasAreaProps)
const [cursorPos, setCursorPos] = useState({ x: 0, y: 0 });
const [boxStart, setBoxStart] = useState<{ x: number, y: number } | null>(null);
const [boxCurrent, setBoxCurrent] = useState<{ x: number, y: number } | null>(null);
const [manualStart, setManualStart] = useState<CanvasPoint | null>(null);
const [manualCurrent, setManualCurrent] = useState<CanvasPoint | null>(null);
const [polygonPoints, setPolygonPoints] = useState<CanvasPoint[]>([]);
const [selectedMaskId, setSelectedMaskId] = useState<string | null>(null);
const [selectedMaskIds, setSelectedMaskIds] = useState<string[]>([]);
const [selectedPolygonIndex, setSelectedPolygonIndex] = useState(0);
const [selectedVertexIndex, setSelectedVertexIndex] = useState<number | null>(null);
const [isInferencing, setIsInferencing] = useState(false);
const masks = useStore((state) => state.masks);
const addMask = useStore((state) => state.addMask);
const updateMask = useStore((state) => state.updateMask);
const clearMasks = useStore((state) => state.clearMasks);
const setMasks = useStore((state) => state.setMasks);
const storeActiveTool = useStore((state) => state.activeTool);
const aiModel = useStore((state) => state.aiModel);
const activeTemplateId = useStore((state) => state.activeTemplateId);
const activeClass = useStore((state) => state.activeClass);
const undoMasks = useStore((state) => state.undoMasks);
const redoMasks = useStore((state) => state.redoMasks);
const effectiveTool = activeTool || storeActiveTool;
// Load the actual frame image
const [image] = useImage(frame?.url || '');
const frameMasks = masks.filter((mask) => mask.frameId === frame?.id);
const selectedMask = React.useMemo(
() => frameMasks.find((mask) => mask.id === selectedMaskId) || null,
[frameMasks, selectedMaskId],
);
const booleanSelectedMasks = React.useMemo(
() => selectedMaskIds
.map((id) => frameMasks.find((mask) => mask.id === id))
.filter((mask): mask is Mask => Boolean(mask)),
[frameMasks, selectedMaskIds],
);
const selectedMaskPoints = React.useMemo(
() => segmentationToPoints(selectedMask?.segmentation, selectedPolygonIndex),
[selectedMask?.segmentation, selectedPolygonIndex],
);
const savedMaskCount = frameMasks.filter((mask) => mask.saveStatus === 'saved' || mask.saved).length;
const draftMaskCount = frameMasks.filter((mask) => !mask.annotationId).length;
const dirtyMaskCount = frameMasks.filter((mask) => mask.saveStatus === 'dirty').length;
@@ -55,6 +242,25 @@ export function CanvasArea({ activeTool, frame, onClearMasks }: CanvasAreaProps)
return () => window.removeEventListener('resize', handleResize);
}, []);
// Discard any in-progress manual shape and every selection whenever the effective tool or the
// displayed frame changes.
useEffect(() => {
setManualStart(null);
setManualCurrent(null);
setPolygonPoints([]);
setSelectedMaskId(null);
setSelectedMaskIds([]);
setSelectedPolygonIndex(0);
setSelectedVertexIndex(null);
}, [effectiveTool, frame?.id]);
// Clear the selection when the selected mask is no longer among the current frame's masks.
useEffect(() => {
if (selectedMaskId && !frameMasks.some((mask) => mask.id === selectedMaskId)) {
setSelectedMaskId(null);
setSelectedMaskIds([]);
setSelectedPolygonIndex(0);
setSelectedVertexIndex(null);
}
}, [frameMasks, selectedMaskId]);
const handleWheel = (e: any) => {
e.evt.preventDefault();
const scaleBy = 1.1;
@@ -74,6 +280,50 @@ export function CanvasArea({ activeTool, frame, onClearMasks }: CanvasAreaProps)
});
};
// Resolves the pointer position of a Konva event in stage-relative coordinates, limited to the
// image bounds. Returns null when the stage or its pointer position is unavailable.
const stagePoint = (e: any): CanvasPoint | null => {
  const pointer = e.target.getStage()?.getRelativePointerPosition();
  if (!pointer) {
    return null;
  }
  // Frame metadata wins; the loaded image and finally the stage size act as fallbacks.
  const maxX = frame?.width || image?.naturalWidth || image?.width || stageSize.width;
  const maxY = frame?.height || image?.naturalHeight || image?.height || stageSize.height;
  const x = clamp(pointer.x, 0, maxX);
  const y = clamp(pointer.y, 0, maxY);
  return { x, y };
};
// Builds a draft mask from a manually drawn outline and adds it to the store.
// `shape` is the display name of the drawn shape; it is used in the mask id, in the fallback
// label and in the metadata. Nothing is added without a frame, with fewer than three vertices,
// or when the outline's area is at most 1.
const createManualMask = useCallback((shape: string, polygon: CanvasPoint[]) => {
  const frameId = frame?.id;
  if (!frameId || polygon.length < 3) return;
  const area = polygonArea(polygon);
  if (area <= 1) return;
  // Average of the vertices; only stored as the seed point for point regions.
  const vertexCount = polygon.length;
  let sumX = 0;
  let sumY = 0;
  for (const vertex of polygon) {
    sumX += vertex.x;
    sumY += vertex.y;
  }
  const mask: Mask = {
    id: `manual-${frameId}-${shape}-${Date.now()}`,
    frameId,
    templateId: activeTemplateId || undefined,
    classId: activeClass?.id,
    className: activeClass?.name,
    classZIndex: activeClass?.zIndex,
    saveStatus: 'draft',
    saved: false,
    pathData: polygonPath(polygon),
    // The active class supplies label and colour; otherwise fall back to a generic label and cyan.
    label: activeClass?.name || `手工${shape}`,
    color: activeClass?.color || '#06b6d4',
    segmentation: polygonSegmentation(polygon),
    points: shape === '点区域' ? [[sumX / vertexCount, sumY / vertexCount]] : undefined,
    bbox: polygonBbox(polygon),
    area,
    metadata: { source: 'manual', shape },
  };
  addMask(mask);
}, [activeClass, activeTemplateId, addMask, frame?.id]);
const handleMouseMove = (e: any) => {
const stage = e.target.getStage();
if (!stage) return;
@@ -90,6 +340,13 @@ export function CanvasArea({ activeTool, frame, onClearMasks }: CanvasAreaProps)
setBoxCurrent({ x: relPos.x, y: relPos.y });
}
}
if (manualStart && DRAG_MANUAL_TOOLS.has(effectiveTool)) {
const pos = stage.getRelativePointerPosition();
if (pos) {
setManualCurrent({ x: pos.x, y: pos.y });
}
}
};
const runInference = useCallback(async (promptPoints?: typeof points, promptBox?: { x1: number, y1: number, x2: number, y2: number }) => {
@@ -132,6 +389,7 @@ export function CanvasArea({ activeTool, frame, onClearMasks }: CanvasAreaProps)
label,
color,
segmentation: m.segmentation,
points: promptPoints?.filter((p) => p.type === 'pos').map((p) => [p.x, p.y]),
bbox: m.bbox,
area: m.area,
});
@@ -170,6 +428,15 @@ export function CanvasArea({ activeTool, frame, onClearMasks }: CanvasAreaProps)
};
const handleStageMouseDown = (e: any) => {
if (DRAG_MANUAL_TOOLS.has(effectiveTool)) {
const pos = stagePoint(e);
if (pos) {
setManualStart(pos);
setManualCurrent(pos);
}
return;
}
if (effectiveTool === 'box_select') {
const stage = e.target.getStage();
const pos = stage.getRelativePointerPosition();
@@ -181,6 +448,27 @@ export function CanvasArea({ activeTool, frame, onClearMasks }: CanvasAreaProps)
};
const handleStageMouseUp = (e: any) => {
if (DRAG_MANUAL_TOOLS.has(effectiveTool) && manualStart) {
const end = stagePoint(e) || manualCurrent || manualStart;
const width = Math.abs(end.x - manualStart.x);
const height = Math.abs(end.y - manualStart.y);
const distance = Math.hypot(width, height);
if (effectiveTool === 'create_rectangle' && width > 4 && height > 4) {
createManualMask('矩形', rectanglePoints(manualStart, end));
}
if (effectiveTool === 'create_circle' && width > 4 && height > 4) {
createManualMask('圆形', circlePoints(manualStart, end));
}
if (effectiveTool === 'create_line' && distance > 4) {
createManualMask('线段', lineRegion(manualStart, end));
}
setManualStart(null);
setManualCurrent(null);
return;
}
if (effectiveTool === 'box_select' && boxStart && boxCurrent) {
const x1 = Math.min(boxStart.x, boxCurrent.x);
const y1 = Math.min(boxStart.y, boxCurrent.y);
@@ -199,12 +487,32 @@ export function CanvasArea({ activeTool, frame, onClearMasks }: CanvasAreaProps)
const handleStageClick = (e: any) => {
if (effectiveTool === 'move') return;
if (effectiveTool === 'box_select') return; // handled by mouseup
if (DRAG_MANUAL_TOOLS.has(effectiveTool)) return;
if (effectiveTool === POINT_TOOL) {
const pos = stagePoint(e);
if (pos) {
createManualMask('点区域', pointRegion(pos));
}
return;
}
if (effectiveTool === POLYGON_TOOL) {
const pos = stagePoint(e);
if (pos) {
setPolygonPoints((current) => [...current, pos]);
}
return;
}
if (effectiveTool === 'point_pos' || effectiveTool === 'point_neg') {
const stage = e.target.getStage();
const pos = stage.getRelativePointerPosition();
if (pos) {
const newPoints = [...points, { x: pos.x, y: pos.y, type: effectiveTool === 'point_pos' ? 'pos' : 'neg' as 'pos'|'neg' }];
const newPoints = [
...points,
{ x: pos.x, y: pos.y, type: (effectiveTool === 'point_pos' ? 'pos' : 'neg') as 'pos' | 'neg' },
];
setPoints(newPoints);
// Auto-trigger inference after point selection
runInference(newPoints);
@@ -212,6 +520,74 @@ export function CanvasArea({ activeTool, frame, onClearMasks }: CanvasAreaProps)
}
};
// Replaces one polygon of a mask with new vertices and writes the recomputed path, bbox and
// area to the store. Masks that already have a backend annotation become 'dirty' and unsaved;
// others stay 'draft'. Edits that would leave fewer than three vertices are ignored.
const updatePolygonMask = useCallback((mask: Mask, nextPoints: CanvasPoint[], polygonIndex = 0) => {
  if (nextPoints.length < 3) return;
  const flatRing: number[] = [];
  nextPoints.forEach((vertex) => {
    flatRing.push(vertex.x, vertex.y);
  });
  // Work on a copy so the mask's current segmentation array is not mutated.
  const rings = (mask.segmentation || []).slice();
  rings[polygonIndex] = flatRing;
  const hasAnnotation = Boolean(mask.annotationId);
  updateMask(mask.id, {
    pathData: segmentationPath(rings),
    segmentation: rings,
    bbox: segmentationBbox(rings) || polygonBbox(nextPoints),
    area: segmentationArea(rings),
    saveStatus: hasAnnotation ? 'dirty' : 'draft',
    saved: hasAnnotation ? false : mask.saved,
  });
}, [updateMask]);
// Returns a copy of `mask` carrying a new segmentation together with the derived path, bbox and
// area. Masks that already have a backend annotation are marked 'dirty' and unsaved; others are
// marked 'draft'. The store is not touched here.
const updateMaskFromSegmentation = useCallback((mask: Mask, segmentation: number[][]): Mask => {
  const hasAnnotation = Boolean(mask.annotationId);
  const pathData = segmentationPath(segmentation);
  const bbox = segmentationBbox(segmentation);
  const area = segmentationArea(segmentation);
  return {
    ...mask,
    pathData,
    segmentation,
    bbox,
    area,
    saveStatus: hasAnnotation ? 'dirty' : 'draft',
    saved: hasAnnotation ? false : mask.saved,
  };
}, []);
// Window-level keyboard shortcuts for the canvas:
//   Ctrl/Cmd+Z            undo, Ctrl/Cmd+Shift+Z or Ctrl/Cmd+Y  redo
//   Delete / Backspace    remove the selected vertex (only while more than three remain)
//   Enter / Escape        commit / discard the polygon being drawn (polygon tool only)
// Key presses that originate from a text-entry element are ignored so that typing, deleting
// and the browser's own undo inside inputs are not hijacked by the canvas.
useEffect(() => {
  const isEditableTarget = (target: EventTarget | null): boolean => {
    // Events dispatched on window or document have no element target.
    if (!(target instanceof HTMLElement)) return false;
    if (target.isContentEditable) return true;
    return target.tagName === 'INPUT' || target.tagName === 'TEXTAREA' || target.tagName === 'SELECT';
  };
  const handleKeyDown = (event: KeyboardEvent) => {
    if (isEditableTarget(event.target)) return;
    const key = event.key.toLowerCase();
    if ((event.metaKey || event.ctrlKey) && key === 'z') {
      event.preventDefault();
      if (event.shiftKey) redoMasks();
      else undoMasks();
      return;
    }
    if ((event.metaKey || event.ctrlKey) && key === 'y') {
      event.preventDefault();
      redoMasks();
      return;
    }
    if ((event.key === 'Delete' || event.key === 'Backspace') && selectedMask && selectedVertexIndex !== null) {
      const currentPoints = segmentationToPoints(selectedMask.segmentation, selectedPolygonIndex);
      // A polygon must keep at least three vertices, so the deletion is skipped at that size.
      if (currentPoints.length > 3) {
        event.preventDefault();
        const nextPoints = currentPoints.filter((_, index) => index !== selectedVertexIndex);
        updatePolygonMask(selectedMask, nextPoints, selectedPolygonIndex);
        setSelectedVertexIndex(null);
      }
      return;
    }
    if (effectiveTool !== POLYGON_TOOL) return;
    if (event.key === 'Enter' && polygonPoints.length >= 3) {
      event.preventDefault();
      createManualMask('多边形', polygonPoints);
      setPolygonPoints([]);
    }
    if (event.key === 'Escape') {
      event.preventDefault();
      setPolygonPoints([]);
    }
  };
  window.addEventListener('keydown', handleKeyDown);
  return () => window.removeEventListener('keydown', handleKeyDown);
}, [createManualMask, effectiveTool, polygonPoints, redoMasks, selectedMask, selectedPolygonIndex, selectedVertexIndex, undoMasks, updatePolygonMask]);
const boxRect = React.useMemo(() => {
if (!boxStart || !boxCurrent) return null;
const x = Math.min(boxStart.x, boxCurrent.x);
@@ -221,6 +597,132 @@ export function CanvasArea({ activeTool, frame, onClearMasks }: CanvasAreaProps)
return { x, y, width, height };
}, [boxStart, boxCurrent]);
// Path string for the live preview of the shape being drawn, or null when nothing is in progress.
// Drag tools preview the shape between the press point and the current pointer; the polygon tool
// previews the placed vertices plus the current cursor position.
const manualPreviewPath = React.useMemo(() => {
  if (manualStart && manualCurrent) {
    switch (effectiveTool) {
      case 'create_rectangle':
        return polygonPath(rectanglePoints(manualStart, manualCurrent));
      case 'create_circle':
        return polygonPath(circlePoints(manualStart, manualCurrent));
      case 'create_line':
        return polygonPath(lineRegion(manualStart, manualCurrent));
      default:
        break;
    }
  }
  const isDrawingPolygon = effectiveTool === POLYGON_TOOL && polygonPoints.length > 0;
  if (!isDrawingPolygon) {
    return null;
  }
  return polygonPath([...polygonPoints, cursorPos]);
}, [cursorPos, effectiveTool, manualCurrent, manualStart, polygonPoints]);
// Stores the new position of a dragged seed point on its mask. Masks that already have a backend
// annotation become 'dirty' and unsaved; others stay 'draft'.
const handleSeedPointDragEnd = (mask: Mask, pointIndex: number, event: any) => {
  const node = event.target;
  const droppedX = node.x();
  const droppedY = node.y();
  // Copy first so the mask's current points array is not mutated.
  const movedPoints = (mask.points || []).slice();
  movedPoints[pointIndex] = [droppedX, droppedY];
  const hasAnnotation = Boolean(mask.annotationId);
  updateMask(mask.id, {
    points: movedPoints,
    saveStatus: hasAnnotation ? 'dirty' : 'draft',
    saved: hasAnnotation ? false : mask.saved,
  });
};
// Handles a click on a mask (or on one polygon of it). The event is stopped from bubbling.
// With a boolean tool active the mask is toggled in the multi-selection; otherwise it becomes the
// only selected mask. In both cases it becomes the focused mask, the clicked polygon becomes the
// edited polygon, and any vertex selection is cleared.
const handleMaskSelect = (mask: Mask, event: any, polygonIndex = 0) => {
  event.cancelBubble = true;
  if (BOOLEAN_TOOLS.has(effectiveTool)) {
    setSelectedMaskIds((current) => {
      const withoutMask = current.filter((id) => id !== mask.id);
      // Same length means the mask was not selected yet, so add it; otherwise it was removed.
      return withoutMask.length === current.length ? [...current, mask.id] : withoutMask;
    });
  } else {
    setSelectedMaskIds([mask.id]);
  }
  setSelectedMaskId(mask.id);
  setSelectedPolygonIndex(polygonIndex);
  setSelectedVertexIndex(null);
};
// Commits the new position of a dragged vertex of the currently edited polygon. The position is
// limited to the image bounds; an index that does not exist in the polygon is ignored.
const handleVertexDragEnd = (mask: Mask, vertexIndex: number, event: any) => {
  // Frame metadata wins; the loaded image and finally the stage size act as fallbacks.
  const maxX = frame?.width || image?.naturalWidth || image?.width || stageSize.width;
  const maxY = frame?.height || image?.naturalHeight || image?.height || stageSize.height;
  const ring = segmentationToPoints(mask.segmentation, selectedPolygonIndex);
  if (!ring[vertexIndex]) return;
  const dropped: CanvasPoint = {
    x: clamp(event.target.x(), 0, maxX),
    y: clamp(event.target.y(), 0, maxY),
  };
  const movedRing = ring.slice();
  movedRing[vertexIndex] = dropped;
  setSelectedMaskId(mask.id);
  setSelectedVertexIndex(vertexIndex);
  updatePolygonMask(mask, movedRing, selectedPolygonIndex);
};
// Inserts a new vertex at the midpoint of an edge of the currently edited polygon and selects it.
// `edgeIndex` is the index of the edge's first vertex; the edge after the last vertex wraps
// around to the first one. The event is stopped from bubbling.
const handleEdgeInsert = (mask: Mask, edgeIndex: number, event: any) => {
  event.cancelBubble = true;
  const ring = segmentationToPoints(mask.segmentation, selectedPolygonIndex);
  const from = ring[edgeIndex];
  const to = ring[(edgeIndex + 1) % ring.length];
  if (!from || !to) return;
  const midpoint: CanvasPoint = {
    x: (from.x + to.x) / 2,
    y: (from.y + to.y) / 2,
  };
  const extendedRing = ring.slice();
  extendedRing.splice(edgeIndex + 1, 0, midpoint);
  setSelectedMaskId(mask.id);
  setSelectedVertexIndex(edgeIndex + 1);
  updatePolygonMask(mask, extendedRing, selectedPolygonIndex);
};
/**
 * Runs the active boolean tool on the selected masks.
 *
 * The first selected mask is the primary region; the remaining masks are
 * either unioned into it (`area_merge`) or subtracted from it (any other
 * boolean tool). For a merge the secondary masks are removed locally and
 * their saved annotations are deleted through `onDeleteMaskAnnotations`. If
 * the result is empty the primary mask itself is removed instead.
 *
 * Local mask and selection state is updated before the backend delete is
 * awaited, so a failed delete can no longer leave the selection pointing at
 * masks that were already removed. The delete failure is caught and logged
 * here because nothing awaits this click handler.
 */
const handleBooleanOperation = async () => {
  if (!frame || booleanSelectedMasks.length < 2) return;
  const isMerge = effectiveTool === 'area_merge';
  const primary = booleanSelectedMasks[0];
  const secondaryMasks = booleanSelectedMasks.slice(1);
  const primaryGeometry = maskToMultiPolygon(primary);
  if (!primaryGeometry) return;
  const clipGeometries = secondaryMasks
    .map(maskToMultiPolygon)
    .filter((geometry): geometry is MultiPolygon => Boolean(geometry));
  if (clipGeometries.length === 0) return;
  const resultGeometry = isMerge
    ? polygonClipping.union(primaryGeometry, ...clipGeometries)
    : polygonClipping.difference(primaryGeometry, ...clipGeometries);
  const resultSegmentation = multiPolygonToSegmentation(resultGeometry);

  let staleAnnotationIds: string[];
  if (resultSegmentation.length === 0) {
    // Nothing is left of the primary region: drop it entirely.
    setMasks(masks.filter((mask) => mask.id !== primary.id));
    setSelectedMaskId(null);
    setSelectedMaskIds([]);
    staleAnnotationIds = primary.annotationId ? [primary.annotationId] : [];
  } else {
    const nextPrimary = updateMaskFromSegmentation(primary, resultSegmentation);
    // Only a merge consumes the secondary masks; a subtraction keeps them.
    const consumedIds = new Set<string>(isMerge ? secondaryMasks.map((mask) => mask.id) : []);
    setMasks(masks
      .filter((mask) => !consumedIds.has(mask.id))
      .map((mask) => (mask.id === primary.id ? nextPrimary : mask)));
    setSelectedMaskId(primary.id);
    setSelectedMaskIds([primary.id]);
    staleAnnotationIds = isMerge
      ? secondaryMasks
        .map((mask) => mask.annotationId)
        .filter((annotationId): annotationId is string => Boolean(annotationId))
      : [];
  }
  setSelectedVertexIndex(null);

  if (staleAnnotationIds.length === 0) return;
  try {
    await onDeleteMaskAnnotations?.(staleAnnotationIds);
  } catch (err) {
    console.error('Delete annotations after boolean operation failed:', err);
  }
};
return (
<div ref={containerRef} className="w-full h-full relative cursor-crosshair overflow-hidden rounded-sm">
{isInferencing && (
@@ -257,13 +759,18 @@ export function CanvasArea({ activeTool, frame, onClearMasks }: CanvasAreaProps)
{/* AI Returned Masks */}
{frameMasks.map((mask) => (
<Group key={mask.id} opacity={0.5}>
<Path
data={mask.pathData}
fill={mask.color}
stroke={mask.color}
strokeWidth={1 / scale}
/>
<Group key={mask.id} opacity={selectedMaskIds.includes(mask.id) ? 0.65 : 0.5}>
{(mask.segmentation && mask.segmentation.length > 0 ? mask.segmentation : [undefined]).map((_, polygonIndex) => (
<Path
key={`${mask.id}-polygon-${polygonIndex}`}
data={mask.segmentation ? segmentationPolygonPath(mask.segmentation, polygonIndex) : mask.pathData}
fill={mask.color}
stroke={mask.color}
strokeWidth={(selectedMaskIds.includes(mask.id) ? 2 : 1) / scale}
onClick={(event: any) => handleMaskSelect(mask, event, polygonIndex)}
onTap={(event: any) => handleMaskSelect(mask, event, polygonIndex)}
/>
))}
</Group>
))}
@@ -281,6 +788,86 @@ export function CanvasArea({ activeTool, frame, onClearMasks }: CanvasAreaProps)
/>
)}
{/* Manual shape preview */}
{manualPreviewPath && (
<Path
data={manualPreviewPath}
fill="rgba(34, 211, 238, 0.12)"
stroke="#22d3ee"
strokeWidth={2 / scale}
dash={[5 / scale, 5 / scale]}
/>
)}
{polygonPoints.map((point, index) => (
<Circle
key={`poly-point-${index}`}
x={point.x}
y={point.y}
radius={4 / scale}
fill="#22d3ee"
stroke="#ffffff"
strokeWidth={1 / scale}
/>
))}
{/* Imported GT seed points / editable point regions */}
{frameMasks.flatMap((mask) => (mask.points || []).map(([x, y], index) => (
<Group key={`${mask.id}-seed-${index}`} x={x} y={y}>
<Circle
radius={5 / scale}
fill="#facc15"
stroke="#111827"
strokeWidth={2 / scale}
draggable
onDragEnd={(event: any) => handleSeedPointDragEnd(mask, index, event)}
/>
<Circle radius={1.5 / scale} fill="#111827" />
</Group>
)))}
{/* Polygon edge insertion handles */}
{selectedMask && selectedMaskPoints.map((point, index) => {
const next = selectedMaskPoints[(index + 1) % selectedMaskPoints.length];
if (!next) return null;
return (
<Circle
key={`${selectedMask.id}-edge-${selectedPolygonIndex}-${index}`}
x={(point.x + next.x) / 2}
y={(point.y + next.y) / 2}
radius={3.5 / scale}
fill="#22d3ee"
stroke="#111827"
strokeWidth={1.5 / scale}
onClick={(event: any) => handleEdgeInsert(selectedMask, index, event)}
onTap={(event: any) => handleEdgeInsert(selectedMask, index, event)}
/>
);
})}
{/* Polygon vertex editor */}
{selectedMask && selectedMaskPoints.map((point, index) => (
<Circle
key={`${selectedMask.id}-vertex-${selectedPolygonIndex}-${index}`}
x={point.x}
y={point.y}
radius={(selectedVertexIndex === index ? 6 : 4.5) / scale}
fill={selectedVertexIndex === index ? '#22d3ee' : '#ffffff'}
stroke={selectedMask.color}
strokeWidth={2 / scale}
draggable
onClick={(event: any) => {
event.cancelBubble = true;
setSelectedVertexIndex(index);
}}
onTap={(event: any) => {
event.cancelBubble = true;
setSelectedVertexIndex(index);
}}
onDragEnd={(event: any) => handleVertexDragEnd(selectedMask, index, event)}
/>
))}
{/* AI Prompts Point Regions */}
{points.map((p, i) => (
<Group key={i} x={p.x} y={p.y}>
@@ -313,6 +900,14 @@ export function CanvasArea({ activeTool, frame, onClearMasks }: CanvasAreaProps)
{frameMasks.length > 0 && (
<div className="absolute bottom-4 right-4 flex gap-2">
{BOOLEAN_TOOLS.has(effectiveTool) && booleanSelectedMasks.length >= 2 && (
<button
onClick={handleBooleanOperation}
className="text-xs bg-emerald-500/10 hover:bg-emerald-500/20 text-emerald-300 border border-emerald-500/20 px-3 py-1.5 rounded transition-colors"
>
{effectiveTool === 'area_merge' ? '合并选中' : '从主区域去除'}
</button>
)}
{activeClass && (
<button
onClick={handleApplyActiveClass}

View File

@@ -1,9 +1,12 @@
import { act, render, screen, waitFor } from '@testing-library/react';
import { act, fireEvent, render, screen, waitFor } from '@testing-library/react';
import { beforeEach, describe, expect, it, vi } from 'vitest';
import { Dashboard } from './Dashboard';
const apiMock = vi.hoisted(() => ({
getDashboardOverview: vi.fn(),
cancelTask: vi.fn(),
retryTask: vi.fn(),
getTask: vi.fn(),
}));
const wsMock = vi.hoisted(() => {
@@ -31,6 +34,9 @@ vi.mock('../lib/websocket', () => ({
vi.mock('../lib/api', () => ({
getDashboardOverview: apiMock.getDashboardOverview,
cancelTask: apiMock.cancelTask,
retryTask: apiMock.retryTask,
getTask: apiMock.getTask,
}));
describe('Dashboard', () => {
@@ -55,6 +61,8 @@ describe('Dashboard', () => {
name: '真实项目.mp4',
progress: 60,
status: 'pending',
raw_status: 'running',
error: null,
frame_count: 10,
updated_at: '2026-05-01T00:00:00Z',
},
@@ -112,4 +120,100 @@ describe('Dashboard', () => {
await waitFor(() => expect(screen.getByText('Progress stream active')).toBeInTheDocument());
expect(screen.getByText('解析完成: done.mp4')).toBeInTheDocument();
});
// Exercises the three task controls on the dashboard: cancel, details, retry.
it('cancels, retries, and opens task failure details', async () => {
// Overview fixture: one running task (#10) and one failed task (#11).
apiMock.getDashboardOverview.mockResolvedValueOnce({
summary: {
project_count: 1,
parsing_task_count: 1,
annotation_count: 0,
frame_count: 0,
template_count: 0,
system_load_percent: 5,
},
tasks: [
{
id: 'task-10',
task_id: 10,
project_id: 1,
name: 'running.mp4',
progress: 30,
status: '正在下载媒体文件',
raw_status: 'running',
error: null,
frame_count: 0,
updated_at: '2026-05-01T00:00:00Z',
},
{
id: 'task-11',
task_id: 11,
project_id: 1,
name: 'failed.mp4',
progress: 100,
status: '解析失败',
raw_status: 'failed',
error: 'ffmpeg failed',
frame_count: 0,
updated_at: '2026-05-01T00:01:00Z',
},
],
activity: [],
});
// Backend reply for cancelling task #10.
apiMock.cancelTask.mockResolvedValueOnce({
id: 10,
task_type: 'parse_video',
status: 'cancelled',
progress: 100,
message: '任务已取消',
project_id: 1,
error: 'Cancelled by user',
result: null,
payload: { source_type: 'video' },
created_at: 'created',
updated_at: 'updated',
});
// Backend reply for retrying task #11: a new queued task #12.
apiMock.retryTask.mockResolvedValueOnce({
id: 12,
task_type: 'parse_video',
status: 'queued',
progress: 0,
message: '重试任务已入队(源任务 #11',
project_id: 1,
error: null,
result: null,
payload: { source_type: 'video', retry_of: 11 },
created_at: 'created',
updated_at: 'updated',
});
// Full task record returned when the details dialog is opened for task #11.
apiMock.getTask.mockResolvedValueOnce({
id: 11,
task_type: 'parse_video',
status: 'failed',
progress: 100,
message: '解析失败',
project_id: 1,
celery_task_id: 'celery-11',
payload: { source_type: 'video' },
result: null,
error: 'ffmpeg failed',
created_at: 'created',
started_at: 'started',
finished_at: 'finished',
updated_at: 'updated',
});
render(<Dashboard />);
// Wait until the overview has been rendered before interacting.
await screen.findByText('running.mp4');
fireEvent.click(screen.getByRole('button', { name: '取消' }));
await waitFor(() => expect(apiMock.cancelTask).toHaveBeenCalledWith(10));
// Index 1 picks the second details button, expected to belong to task #11.
fireEvent.click(screen.getAllByRole('button', { name: '详情' })[1]);
await waitFor(() => expect(apiMock.getTask).toHaveBeenCalledWith(11));
expect(await screen.findByText('任务详情 #11')).toBeInTheDocument();
expect(screen.getByText('ffmpeg failed')).toBeInTheDocument();
// NOTE(review): index 1 presumably relies on the just-cancelled task #10
// rendering its own retry button ahead of task #11 — confirm against the component.
fireEvent.click(screen.getAllByRole('button', { name: '重试' })[1]);
await waitFor(() => expect(apiMock.retryTask).toHaveBeenCalledWith(11));
});
});

View File

@@ -1,8 +1,17 @@
import React, { useState, useEffect } from 'react';
import { Activity, Clock, Folders, CheckCircle2, Loader2 } from 'lucide-react';
import { Activity, AlertTriangle, Clock, Folders, CheckCircle2, Info, Loader2, RotateCcw, XCircle } from 'lucide-react';
import { progressWS, type ProgressMessage } from '../lib/websocket';
import { cn } from '../lib/utils';
import { getDashboardOverview, type DashboardActivity, type DashboardOverview, type DashboardTask } from '../lib/api';
import {
cancelTask,
getDashboardOverview,
getTask,
retryTask,
type DashboardActivity,
type DashboardOverview,
type DashboardTask,
type ProcessingTask,
} from '../lib/api';
const emptySummary: DashboardOverview['summary'] = {
project_count: 0,
@@ -20,6 +29,29 @@ export function Dashboard() {
const [activityLog, setActivityLog] = useState<DashboardActivity[]>([]);
const [isLoading, setIsLoading] = useState(true);
const [loadError, setLoadError] = useState('');
const [selectedTask, setSelectedTask] = useState<ProcessingTask | null>(null);
const [taskActionMessage, setTaskActionMessage] = useState('');
const [busyTaskId, setBusyTaskId] = useState<string | null>(null);
/**
 * Converts a backend processing task into the row shape used by the
 * dashboard task list.
 *
 * @param task - Task record returned by the task API.
 * @param name - Display name for the row; defaults to a label built from the task id.
 * @returns The dashboard row; the human-readable status prefers `task.message`,
 * and a missing project id or frame count falls back to 0.
 */
const taskFromProcessingTask = (task: ProcessingTask, name = `任务 ${task.id}`): DashboardTask => {
  const statusText = task.message || task.status;
  const framesExtracted = Number(task.result?.frames_extracted || 0);
  return {
    id: `task-${task.id}`,
    task_id: task.id,
    project_id: task.project_id ?? 0,
    name,
    progress: task.progress,
    status: statusText,
    raw_status: task.status,
    error: task.error,
    frame_count: framesExtracted,
    updated_at: task.updated_at,
  };
};
/**
 * Pushes a task-action entry to the top of the activity log, keeping at most
 * ten entries (the new one plus the nine most recent).
 *
 * @param message - Text shown for the activity entry.
 * @param project - Project label for the entry; defaults to the system label.
 */
const prependActivity = (message: string, project = '系统') => {
  setActivityLog((previous) => {
    // Built inside the updater so the timestamps are taken when it runs.
    const entry: DashboardActivity = {
      id: `task-action-${Date.now()}`,
      kind: 'task',
      time: new Date().toISOString(),
      message,
      project,
    };
    return [entry, ...previous.slice(0, 9)];
  });
};
useEffect(() => {
let cancelled = false;
@@ -90,6 +122,8 @@ export function Dashboard() {
name: taskTitle(data),
progress: data.progress ?? 0,
status: data.status ?? '处理中',
raw_status: 'running',
error: data.error,
frame_count: 0,
updated_at: new Date().toISOString(),
},
@@ -100,7 +134,7 @@ export function Dashboard() {
if (data.type === 'complete' && data.taskId) {
setTasks((prev) =>
prev.map((t) =>
t.id === data.taskId ? { ...t, progress: 100, status: '已完成' } : t
t.id === data.taskId ? { ...t, progress: 100, status: '已完成', raw_status: 'success' } : t
)
);
setActivityLog((prev) => [
@@ -109,10 +143,26 @@ export function Dashboard() {
]);
}
if (data.type === 'cancelled' && data.taskId) {
setTasks((prev) =>
prev.map((t) =>
t.id === data.taskId
? { ...t, progress: 100, status: data.message || '任务已取消', raw_status: 'cancelled', error: data.error }
: t
)
);
setActivityLog((prev) => [
{ id: `ws-cancelled-${Date.now()}`, kind: 'websocket', time: new Date().toISOString(), message: data.message || `任务已取消: ${taskTitle(data)}`, project: data.projectName || '系统' },
...prev.slice(0, 9),
]);
}
if (data.type === 'error' && data.taskId) {
setTasks((prev) =>
prev.map((t) =>
t.id === data.taskId ? { ...t, progress: data.progress ?? t.progress, status: `错误: ${data.error || data.message || '未知错误'}` } : t
t.id === data.taskId
? { ...t, progress: data.progress ?? t.progress, status: `错误: ${data.error || data.message || '未知错误'}`, raw_status: 'failed', error: data.error }
: t
)
);
setActivityLog((prev) => [
@@ -160,6 +210,65 @@ export function Dashboard() {
});
}
/**
 * Resolves the machine-readable status of a dashboard task.
 *
 * Uses `raw_status` when present; otherwise infers it from the display
 * status text and the progress value.
 *
 * @param task - Dashboard task row.
 * @returns `raw_status`, or one of 'cancelled', 'failed', 'success', 'running'.
 */
const taskRawStatus = (task: DashboardTask): string => {
  if (task.raw_status) return task.raw_status;
  const label = task.status;
  if (label.includes('取消')) return 'cancelled';
  if (label.includes('失败') || label.includes('错误')) return 'failed';
  return task.progress >= 100 ? 'success' : 'running';
};
/** A task can be cancelled while it is queued or running and has a backend task id. */
const canCancel = (task: DashboardTask): boolean => {
  if (!task.task_id) return false;
  const state = taskRawStatus(task);
  return state === 'queued' || state === 'running';
};
/** A task can be retried once it has failed or been cancelled and has a backend task id. */
const canRetry = (task: DashboardTask): boolean => {
  if (!task.task_id) return false;
  const state = taskRawStatus(task);
  return state === 'failed' || state === 'cancelled';
};
/**
 * Cancels a task through the backend and replaces its row with the returned
 * task state. While the request is in flight the task is marked busy; on
 * failure a warning message is shown instead.
 *
 * @param task - Dashboard row to cancel; ignored when it has no backend task id.
 */
const handleCancelTask = async (task: DashboardTask) => {
  const backendId = task.task_id;
  if (!backendId) return;
  setBusyTaskId(task.id);
  setTaskActionMessage('');
  try {
    const cancelledTask = await cancelTask(backendId);
    setTasks((rows) => rows.map((row) => {
      if (row.id !== task.id) return row;
      // Keep the existing display name; the backend record carries none.
      return taskFromProcessingTask(cancelledTask, task.name);
    }));
    prependActivity(`任务已取消 #${cancelledTask.id}`, task.name);
  } catch (err) {
    console.error('Cancel task failed:', err);
    setTaskActionMessage('任务取消失败,请检查后端服务');
  } finally {
    setBusyTaskId(null);
  }
};
/**
 * Requests a retry of a task and puts the newly queued task at the top of
 * the list, replacing any row that already has the same id.
 *
 * @param task - Dashboard row to retry; ignored when it has no backend task id.
 */
const handleRetryTask = async (task: DashboardTask) => {
  const backendId = task.task_id;
  if (!backendId) return;
  setBusyTaskId(task.id);
  setTaskActionMessage('');
  try {
    const retried = await retryTask(backendId);
    const queuedRow = taskFromProcessingTask(retried, task.name);
    setTasks((rows) => {
      const others = rows.filter((row) => row.id !== queuedRow.id);
      return [queuedRow, ...others];
    });
    prependActivity(`重试任务已入队 #${retried.id}`, task.name);
  } catch (err) {
    console.error('Retry task failed:', err);
    setTaskActionMessage('任务重试失败,请检查后端服务');
  } finally {
    setBusyTaskId(null);
  }
};
/**
 * Loads the full backend record for a task and opens it in the detail dialog.
 *
 * @param task - Dashboard row to inspect; ignored when it has no backend task id.
 */
const handleOpenTaskDetail = async (task: DashboardTask) => {
  const backendId = task.task_id;
  if (!backendId) return;
  setBusyTaskId(task.id);
  setTaskActionMessage('');
  try {
    const detail = await getTask(backendId);
    setSelectedTask(detail);
  } catch (err) {
    console.error('Load task detail failed:', err);
    setTaskActionMessage('失败详情加载失败');
  } finally {
    setBusyTaskId(null);
  }
};
return (
<div className="p-8 w-full h-full overflow-y-auto bg-[#0a0a0a]">
<header className="mb-8">
@@ -177,6 +286,7 @@ export function Dashboard() {
</div>
<p className="text-gray-400 text-sm mt-1"></p>
{loadError && <p className="text-red-400 text-xs mt-2">{loadError}</p>}
{taskActionMessage && <p className="text-amber-400 text-xs mt-2">{taskActionMessage}</p>}
</header>
<div className="grid grid-cols-1 md:grid-cols-2 lg:grid-cols-4 gap-4 mb-8">
@@ -213,16 +323,47 @@ export function Dashboard() {
<div className="h-full bg-gradient-to-r from-cyan-600 to-cyan-400 rounded-full transition-all duration-500" style={{ width: `${task.progress}%` }} />
</div>
<div className="text-xs text-gray-500 flex items-center gap-2">
{task.status === '已完成' || task.progress >= 100 ? (
{taskRawStatus(task) === 'success' || task.status === '已完成' ? (
<CheckCircle2 size={12} className="text-emerald-400" />
) : task.status.includes('错误') ? (
<span className="text-red-400"></span>
) : taskRawStatus(task) === 'failed' ? (
<AlertTriangle size={12} className="text-red-400" />
) : taskRawStatus(task) === 'cancelled' ? (
<XCircle size={12} className="text-amber-400" />
) : (
<Loader2 size={12} className="text-cyan-400 animate-spin" />
)}
{task.status}
<span className="text-gray-600">: {task.frame_count}</span>
</div>
<div className="mt-3 flex flex-wrap items-center gap-2">
{canCancel(task) && (
<button
onClick={() => handleCancelTask(task)}
disabled={busyTaskId === task.id}
className="inline-flex items-center gap-1 rounded border border-red-500/20 bg-red-500/10 px-2 py-1 text-[11px] text-red-300 hover:bg-red-500/20 disabled:opacity-40 disabled:cursor-not-allowed"
>
<XCircle size={12} />
</button>
)}
{canRetry(task) && (
<button
onClick={() => handleRetryTask(task)}
disabled={busyTaskId === task.id}
className="inline-flex items-center gap-1 rounded border border-cyan-500/20 bg-cyan-500/10 px-2 py-1 text-[11px] text-cyan-300 hover:bg-cyan-500/20 disabled:opacity-40 disabled:cursor-not-allowed"
>
<RotateCcw size={12} />
</button>
)}
{task.task_id && (
<button
onClick={() => handleOpenTaskDetail(task)}
disabled={busyTaskId === task.id}
className="inline-flex items-center gap-1 rounded border border-white/10 bg-white/5 px-2 py-1 text-[11px] text-gray-300 hover:bg-white/10 disabled:opacity-40 disabled:cursor-not-allowed"
>
<Info size={12} />
</button>
)}
</div>
</div>
))}
{!isLoading && tasks.length === 0 && (
@@ -253,6 +394,46 @@ export function Dashboard() {
</div>
</div>
</div>
{selectedTask && (
<div className="fixed inset-0 z-50 flex items-center justify-center bg-black/70 px-4">
<div className="w-full max-w-2xl rounded-lg border border-white/10 bg-[#111] p-5 shadow-2xl">
<div className="flex items-center justify-between gap-3 border-b border-white/10 pb-3">
<div>
<h3 className="text-sm font-semibold text-white"> #{selectedTask.id}</h3>
<p className="mt-1 text-xs text-gray-500">{selectedTask.message || selectedTask.status}</p>
</div>
<button
onClick={() => setSelectedTask(null)}
className="rounded border border-white/10 bg-white/5 px-2 py-1 text-xs text-gray-300 hover:bg-white/10"
>
</button>
</div>
<div className="mt-4 grid grid-cols-2 gap-3 text-xs text-gray-400">
<div>: <span className="text-gray-200">{selectedTask.status}</span></div>
<div>: <span className="text-gray-200">{selectedTask.progress}%</span></div>
<div> ID: <span className="text-gray-200">{selectedTask.project_id ?? '-'}</span></div>
<div>Celery ID: <span className="text-gray-200">{selectedTask.celery_task_id || '-'}</span></div>
<div>: <span className="text-gray-200">{selectedTask.created_at}</span></div>
<div>: <span className="text-gray-200">{selectedTask.finished_at || '-'}</span></div>
</div>
{selectedTask.error && (
<div className="mt-4 rounded border border-red-500/20 bg-red-500/10 p-3 text-xs text-red-200">
{selectedTask.error}
</div>
)}
<div className="mt-4 grid gap-3 md:grid-cols-2">
<pre className="max-h-48 overflow-auto rounded border border-white/10 bg-[#0a0a0a] p-3 text-[11px] text-gray-300">
{JSON.stringify(selectedTask.payload || {}, null, 2)}
</pre>
<pre className="max-h-48 overflow-auto rounded border border-white/10 bg-[#0a0a0a] p-3 text-[11px] text-gray-300">
{JSON.stringify(selectedTask.result || {}, null, 2)}
</pre>
</div>
</div>
</div>
)}
</div>
);
}

View File

@@ -3,18 +3,31 @@ import { describe, expect, it, vi } from 'vitest';
import { ToolsPalette } from './ToolsPalette';
describe('ToolsPalette', () => {
it('switches tools and exposes UI-only placeholder buttons', () => {
it('switches tools and dispatches undo/redo actions when available', () => {
const setActiveTool = vi.fn();
const onUndo = vi.fn();
const onRedo = vi.fn();
render(<ToolsPalette activeTool="move" setActiveTool={setActiveTool} />);
render(
<ToolsPalette
activeTool="move"
setActiveTool={setActiveTool}
onUndo={onUndo}
onRedo={onRedo}
canUndo
canRedo
/>,
);
fireEvent.click(screen.getByTitle('创建多边形 (P)'));
fireEvent.click(screen.getByTitle('正向选点 (SAM)'));
fireEvent.click(screen.getByTitle('撤销操作 (Ctrl+Z)'));
fireEvent.click(screen.getByTitle('重做操作 (Ctrl+Shift+Z)'));
expect(setActiveTool).toHaveBeenNthCalledWith(1, 'create_polygon');
expect(setActiveTool).toHaveBeenNthCalledWith(2, 'point_pos');
expect(screen.getByTitle('撤销操作 (Ctrl+Z)')).toBeInTheDocument();
expect(screen.getByTitle('重做操作 (Ctrl+Shift+Z)')).toBeInTheDocument();
expect(onUndo).toHaveBeenCalled();
expect(onRedo).toHaveBeenCalled();
});
it('switches to SAM trigger and calls the AI navigation hook', () => {

View File

@@ -6,9 +6,21 @@ interface ToolsPaletteProps {
activeTool: string;
setActiveTool: (tool: string) => void;
onTriggerAI?: () => void;
onUndo?: () => void;
onRedo?: () => void;
canUndo?: boolean;
canRedo?: boolean;
}
export function ToolsPalette({ activeTool, setActiveTool, onTriggerAI }: ToolsPaletteProps) {
export function ToolsPalette({
activeTool,
setActiveTool,
onTriggerAI,
onUndo,
onRedo,
canUndo = false,
canRedo = false,
}: ToolsPaletteProps) {
const tools = [
{ id: 'move', icon: MousePointer2, label: '拖拽 / 选择 (V)' },
{ id: 'create_polygon', icon: Hexagon, label: '创建多边形 (P)' },
@@ -91,10 +103,20 @@ export function ToolsPalette({ activeTool, setActiveTool, onTriggerAI }: ToolsPa
<div className="w-full h-px bg-white/10 my-1" />
<button className="w-10 h-10 rounded text-gray-500 hover:bg-white/5 hover:text-white flex items-center justify-center transition-colors" title="撤销操作 (Ctrl+Z)">
<button
onClick={onUndo}
disabled={!canUndo}
className="w-10 h-10 rounded text-gray-500 hover:bg-white/5 hover:text-white flex items-center justify-center transition-colors disabled:opacity-30 disabled:hover:bg-transparent disabled:hover:text-gray-500 disabled:cursor-not-allowed"
title="撤销操作 (Ctrl+Z)"
>
<Undo size={18} />
</button>
<button className="w-10 h-10 rounded text-gray-500 hover:bg-white/5 hover:text-white flex items-center justify-center transition-colors" title="重做操作 (Ctrl+Shift+Z)">
<button
onClick={onRedo}
disabled={!canRedo}
className="w-10 h-10 rounded text-gray-500 hover:bg-white/5 hover:text-white flex items-center justify-center transition-colors disabled:opacity-30 disabled:hover:bg-transparent disabled:hover:text-gray-500 disabled:cursor-not-allowed"
title="重做操作 (Ctrl+Shift+Z)"
>
<Redo size={18} />
</button>

View File

@@ -14,6 +14,8 @@ const apiMock = vi.hoisted(() => ({
updateAnnotation: vi.fn(),
deleteAnnotation: vi.fn(),
exportCoco: vi.fn(),
exportMasks: vi.fn(),
importGtMask: vi.fn(),
annotationToMask: vi.fn(),
buildAnnotationPayload: vi.fn(),
getAiModelStatus: vi.fn(),
@@ -29,6 +31,8 @@ vi.mock('../lib/api', () => ({
updateAnnotation: apiMock.updateAnnotation,
deleteAnnotation: apiMock.deleteAnnotation,
exportCoco: apiMock.exportCoco,
exportMasks: apiMock.exportMasks,
importGtMask: apiMock.importGtMask,
annotationToMask: apiMock.annotationToMask,
buildAnnotationPayload: apiMock.buildAnnotationPayload,
getAiModelStatus: apiMock.getAiModelStatus,
@@ -256,4 +260,64 @@ describe('VideoWorkspace', () => {
await waitFor(() => expect(apiMock.saveAnnotation).toHaveBeenCalled());
expect(apiMock.exportCoco).toHaveBeenCalledWith('1');
});
// Verifies that unsaved masks are persisted before the PNG mask ZIP export is requested.
it('auto-saves pending masks before exporting PNG masks', async () => {
// One frame for the project, plus mocked save and export responses.
apiMock.getProjectFrames.mockResolvedValueOnce([
{ id: 10, project_id: 1, frame_index: 0, image_url: '/frame.jpg', width: 640, height: 360 },
]);
apiMock.buildAnnotationPayload.mockReturnValueOnce({ project_id: 1, frame_id: 10, mask_data: { polygons: [] } });
apiMock.saveAnnotation.mockResolvedValueOnce({ id: 5 });
apiMock.exportMasks.mockResolvedValueOnce(new Blob(['zip'], { type: 'application/zip' }));
render(<VideoWorkspace />);
// Wait for the frames to reach the store before injecting a mask.
await waitFor(() => expect(useStore.getState().frames).toHaveLength(1));
// Inject a mask with no annotationId, i.e. one that has not been saved yet.
act(() => {
useStore.setState({
masks: [{
id: 'mask-1',
frameId: '10',
pathData: 'M 0 0 Z',
label: 'AI Mask',
color: '#06b6d4',
segmentation: [[0, 0, 10, 0, 10, 10]],
}],
});
});
fireEvent.click(screen.getByRole('button', { name: '导出 PNG Mask ZIP' }));
// The save must have been triggered, and the export called with the project id.
await waitFor(() => expect(apiMock.saveAnnotation).toHaveBeenCalled());
expect(apiMock.exportMasks).toHaveBeenCalledWith('1');
});
// Verifies the GT mask import flow: upload for the current frame, then reload saved annotations.
it('imports a GT mask for the current frame and hydrates saved annotations', async () => {
apiMock.getProjectFrames.mockResolvedValueOnce([
{ id: 10, project_id: 1, frame_index: 0, image_url: '/frame.jpg', width: 640, height: 360 },
]);
apiMock.importGtMask.mockResolvedValueOnce([{ id: 88, frame_id: 10 }]);
// First annotation fetch returns nothing; the second returns the imported annotation.
apiMock.getProjectAnnotations
.mockResolvedValueOnce([])
.mockResolvedValueOnce([{ id: 88, frame_id: 10 }]);
apiMock.annotationToMask.mockReturnValueOnce({
id: 'annotation-88',
annotationId: '88',
frameId: '10',
saved: true,
pathData: 'M 0 0 Z',
label: 'GT Mask',
color: '#22c55e',
});
render(<VideoWorkspace />);
await waitFor(() => expect(useStore.getState().frames).toHaveLength(1));
// The file input is hidden, so it is located directly in the DOM.
const fileInput = document.querySelector('input[type="file"]') as HTMLInputElement;
const file = new File(['mask'], 'mask.png', { type: 'image/png' });
fireEvent.change(fileInput, { target: { files: [file] } });
// Import is called with the file, project id and current frame id.
await waitFor(() => expect(apiMock.importGtMask).toHaveBeenCalledWith(file, '1', '10'));
// The store ends up holding the mask produced from the imported annotation.
await waitFor(() => expect(useStore.getState().masks).toEqual([
expect.objectContaining({ id: 'annotation-88', label: 'GT Mask' }),
]));
});
});

View File

@@ -5,10 +5,12 @@ import {
buildAnnotationPayload,
deleteAnnotation,
exportCoco,
exportMasks,
getProjectAnnotations,
getProjectFrames,
getTask,
getTemplates,
importGtMask,
parseMedia,
saveAnnotation,
updateAnnotation,
@@ -25,18 +27,24 @@ function sleep(ms: number) {
}
export function VideoWorkspace({ onNavigateToAI }: { onNavigateToAI?: () => void }) {
const gtMaskInputRef = React.useRef<HTMLInputElement>(null);
const activeTool = useStore((state) => state.activeTool);
const setActiveTool = useStore((state) => state.setActiveTool);
const currentProject = useStore((state) => state.currentProject);
const frames = useStore((state) => state.frames);
const currentFrameIndex = useStore((state) => state.currentFrameIndex);
const masks = useStore((state) => state.masks);
const maskHistory = useStore((state) => state.maskHistory);
const maskFuture = useStore((state) => state.maskFuture);
const activeTemplateId = useStore((state) => state.activeTemplateId);
const setFrames = useStore((state) => state.setFrames);
const setCurrentFrame = useStore((state) => state.setCurrentFrame);
const setMasks = useStore((state) => state.setMasks);
const undoMasks = useStore((state) => state.undoMasks);
const redoMasks = useStore((state) => state.redoMasks);
const [isSaving, setIsSaving] = useState(false);
const [isExporting, setIsExporting] = useState(false);
const [isImportingGt, setIsImportingGt] = useState(false);
const [statusMessage, setStatusMessage] = useState('');
const hydrateSavedAnnotations = useCallback(async (projectId: string, projectFrames: Frame[]) => {
@@ -216,6 +224,18 @@ export function VideoWorkspace({ onNavigateToAI }: { onNavigateToAI?: () => void
}
}, [currentFrame, masks, setMasks]);
/**
 * Deletes the given saved annotations on the backend in parallel and reports
 * the outcome in the status bar.
 *
 * @param annotationIds - Backend annotation ids to delete; an empty list is a no-op.
 * @throws Re-throws the deletion error after logging it and updating the status message.
 */
const handleDeleteMaskAnnotations = useCallback(async (annotationIds: string[]) => {
  if (annotationIds.length === 0) return;
  try {
    const deletions = annotationIds.map((annotationId) => deleteAnnotation(annotationId));
    await Promise.all(deletions);
    setStatusMessage(`已删除 ${annotationIds.length} 个被合并标注`);
  } catch (err) {
    console.error('Delete merged annotations failed:', err);
    setStatusMessage('合并后删除原标注失败,请检查后端服务');
    throw err;
  }
}, []);
const handleSave = async () => {
try {
await savePendingAnnotations();
@@ -248,6 +268,52 @@ export function VideoWorkspace({ onNavigateToAI }: { onNavigateToAI?: () => void
}
};
/**
 * Starts a browser download of a blob by clicking a temporary anchor element
 * that points at an object URL, then removes the anchor and revokes the URL.
 *
 * @param blob - Data to download.
 * @param filename - Suggested file name for the download.
 */
const downloadBlob = (blob: Blob, filename: string) => {
  const objectUrl = URL.createObjectURL(blob);
  const anchor = document.createElement('a');
  Object.assign(anchor, { href: objectUrl, download: filename });
  document.body.appendChild(anchor);
  anchor.click();
  anchor.remove();
  URL.revokeObjectURL(objectUrl);
};
/**
 * Exports the project's masks as a ZIP download.
 *
 * Pending annotations are saved silently first, then the ZIP is fetched and
 * handed to `downloadBlob`. Progress and failures are reported through the
 * status message.
 */
const handleExportMasks = async () => {
  const projectId = currentProject?.id;
  if (!projectId) return;
  setIsExporting(true);
  setStatusMessage('正在准备导出语义 Mask ZIP...');
  try {
    await savePendingAnnotations({ silent: true });
    const archive = await exportMasks(projectId);
    downloadBlob(archive, `project_${projectId}_masks.zip`);
    setStatusMessage('PNG Mask ZIP 已导出');
  } catch (err) {
    console.error('Mask export failed:', err);
    setStatusMessage('Mask 导出失败,请检查后端服务');
  } finally {
    setIsExporting(false);
  }
};
/**
 * Imports the chosen GT mask file for the current frame, then reloads the
 * saved annotations for the project's frames.
 *
 * @param event - Change event of the file input; only the first file is used.
 */
const handleImportGtMask = async (event: React.ChangeEvent<HTMLInputElement>) => {
  const selectedFile = event.target.files?.[0];
  const projectId = currentProject?.id;
  const frameId = currentFrame?.id;
  if (!selectedFile || !projectId || !frameId) return;
  setIsImportingGt(true);
  setStatusMessage('正在导入 GT Mask...');
  try {
    const importedAnnotations = await importGtMask(selectedFile, projectId, frameId);
    await hydrateSavedAnnotations(projectId, frames);
    setStatusMessage(`已导入 ${importedAnnotations.length} 个 GT 区域`);
  } catch (err) {
    console.error('GT mask import failed:', err);
    setStatusMessage('GT Mask 导入失败,请检查文件或后端服务');
  } finally {
    setIsImportingGt(false);
    // Clear the input so choosing the same file again fires another change event.
    event.target.value = '';
  }
};
return (
<div className="w-full h-full flex flex-col bg-[#0a0a0a]">
{/* Top Header / Status bar */}
@@ -264,6 +330,27 @@ export function VideoWorkspace({ onNavigateToAI }: { onNavigateToAI?: () => void
</span>
)}
<ModelStatusBadge />
<input
ref={gtMaskInputRef}
type="file"
accept="image/png,image/jpeg,image/bmp,image/tiff"
className="hidden"
onChange={handleImportGtMask}
/>
<button
onClick={() => gtMaskInputRef.current?.click()}
disabled={!currentProject?.id || !currentFrame?.id || isImportingGt || isSaving || isExporting}
className="px-4 py-1.5 bg-white/5 hover:bg-white/10 border border-white/10 rounded-md text-xs transition-colors text-white disabled:opacity-40 disabled:cursor-not-allowed"
>
{isImportingGt ? '导入中...' : '导入 GT Mask'}
</button>
<button
onClick={handleExportMasks}
disabled={!currentProject?.id || isExporting || isSaving}
className="px-4 py-1.5 bg-white/5 hover:bg-white/10 border border-white/10 rounded-md text-xs transition-colors text-white disabled:opacity-40 disabled:cursor-not-allowed"
>
{isExporting ? '导出中...' : '导出 PNG Mask ZIP'}
</button>
<button
onClick={handleExport}
disabled={!currentProject?.id || isExporting || isSaving}
@@ -283,11 +370,24 @@ export function VideoWorkspace({ onNavigateToAI }: { onNavigateToAI?: () => void
{/* Main Workspace Area */}
<div className="flex-1 flex overflow-hidden">
<ToolsPalette activeTool={activeTool} setActiveTool={setActiveTool} onTriggerAI={onNavigateToAI} />
<ToolsPalette
activeTool={activeTool}
setActiveTool={setActiveTool}
onTriggerAI={onNavigateToAI}
onUndo={undoMasks}
onRedo={redoMasks}
canUndo={maskHistory.length > 0}
canRedo={maskFuture.length > 0}
/>
<div className="flex-1 relative flex items-center justify-center p-8 bg-[#151515] overflow-hidden">
<div className="relative w-full h-full bg-[#1e1e1e] border border-white/5 shadow-2xl rounded-sm">
<CanvasArea activeTool={activeTool} frame={currentFrame} onClearMasks={handleClearCurrentFrameMasks} />
<CanvasArea
activeTool={activeTool}
frame={currentFrame}
onClearMasks={handleClearCurrentFrameMasks}
onDeleteMaskAnnotations={handleDeleteMaskAnnotations}
/>
</div>
</div>

View File

@@ -101,6 +101,17 @@ describe('api client contracts', () => {
});
});
// Pins the request shape of exportMasks: GET on the project's masks route with a blob response type.
it('exports PNG masks from the backend route shape', async () => {
const { exportMasks } = await import('./api');
const blob = new Blob(['zip'], { type: 'application/zip' });
axiosMock.client.get.mockResolvedValueOnce({ data: blob });
// The helper resolves to the response's `data` blob itself.
await expect(exportMasks('9')).resolves.toBe(blob);
expect(axiosMock.client.get).toHaveBeenCalledWith('/api/export/9/masks', {
responseType: 'blob',
});
});
it('loads dashboard overview from the backend summary endpoint', async () => {
const { getDashboardOverview } = await import('./api');
const overview = {
@@ -125,8 +136,8 @@ describe('api client contracts', () => {
expect(axiosMock.client.get).toHaveBeenCalledWith('/api/dashboard/overview');
});
it('queues media parsing and reads processing task status', async () => {
const { getTask, parseMedia } = await import('./api');
it('queues media parsing and manages processing task lifecycle', async () => {
const { cancelTask, getTask, parseMedia, retryTask } = await import('./api');
const task = {
id: 12,
task_type: 'parse_video',
@@ -145,6 +156,8 @@ describe('api client contracts', () => {
};
axiosMock.client.post.mockResolvedValueOnce({ data: task });
axiosMock.client.get.mockResolvedValueOnce({ data: { ...task, status: 'success', progress: 100 } });
axiosMock.client.post.mockResolvedValueOnce({ data: { ...task, status: 'cancelled', progress: 100 } });
axiosMock.client.post.mockResolvedValueOnce({ data: { ...task, id: 13, status: 'queued', progress: 0 } });
await expect(parseMedia('9')).resolves.toEqual(task);
expect(axiosMock.client.post).toHaveBeenCalledWith('/api/media/parse', null, {
@@ -153,6 +166,12 @@ describe('api client contracts', () => {
await expect(getTask(12)).resolves.toEqual(expect.objectContaining({ status: 'success', progress: 100 }));
expect(axiosMock.client.get).toHaveBeenCalledWith('/api/tasks/12');
await expect(cancelTask(12)).resolves.toEqual(expect.objectContaining({ status: 'cancelled' }));
expect(axiosMock.client.post).toHaveBeenCalledWith('/api/tasks/12/cancel');
await expect(retryTask(12)).resolves.toEqual(expect.objectContaining({ id: 13, status: 'queued' }));
expect(axiosMock.client.post).toHaveBeenCalledWith('/api/tasks/12/retry');
});
it('lists, saves, updates, and deletes annotations with the backend annotation contract', async () => {
@@ -204,6 +223,25 @@ describe('api client contracts', () => {
expect(axiosMock.client.delete).toHaveBeenCalledWith('/api/ai/annotations/1');
});
// Contract check for importGtMask: the file and the project/frame/template ids are sent
// as FormData fields to /api/ai/import-gt-mask, and the response body is returned as-is.
it('imports GT masks through multipart form data', async () => {
const { importGtMask } = await import('./api');
const file = new File(['mask'], 'mask.png', { type: 'image/png' });
const saved = [{ id: 1, project_id: 9, frame_id: 5, template_id: null, mask_data: null, points: null, bbox: null }];
// Queue a single mocked POST response carrying the saved annotations.
axiosMock.client.post.mockResolvedValueOnce({ data: saved });
await expect(importGtMask(file, '9', '5', '2')).resolves.toEqual(saved);
expect(axiosMock.client.post).toHaveBeenCalledWith(
'/api/ai/import-gt-mask',
expect.any(FormData),
{ headers: { 'Content-Type': 'multipart/form-data' } },
);
// Pull the FormData (second argument) out of the most recent POST call to inspect its fields.
const form = axiosMock.client.post.mock.calls.at(-1)?.[1] as FormData;
expect(form.get('file')).toBe(file);
expect(form.get('project_id')).toBe('9');
expect(form.get('frame_id')).toBe('5');
expect(form.get('template_id')).toBe('2');
});
it('builds annotation payloads from frontend masks and restores saved annotations to masks', async () => {
const { annotationToMask, buildAnnotationPayload } = await import('./api');
const frame = { id: '5', projectId: '9', index: 0, url: '/frame.jpg', width: 100, height: 50 };
@@ -244,7 +282,7 @@ describe('api client contracts', () => {
color: '#06b6d4',
class: { id: 'c1', name: '胆囊', color: '#ff0000', zIndex: 20 },
},
points: null,
points: [[0.5, 0.5]],
bbox: null,
created_at: 'created',
updated_at: 'updated',
@@ -261,10 +299,28 @@ describe('api client contracts', () => {
saveStatus: 'saved',
saved: true,
pathData: 'M 10 10 L 90 10 L 90 40 Z',
points: [[50, 25]],
bbox: [10, 10, 80, 30],
}));
});
// Checks that a mask's `points` survive into the annotation payload, expressed relative to
// the frame size: pixel point [50, 25] on a 100x50 frame is expected as [0.5, 0.5].
it('preserves editable point regions in annotation payloads', async () => {
const { buildAnnotationPayload } = await import('./api');
const frame = { id: '5', projectId: '9', index: 0, url: '/frame.jpg', width: 100, height: 50 };
expect(buildAnnotationPayload('9', {
id: 'm1',
frameId: '5',
pathData: 'M 10 10 L 90 10 L 90 40 Z',
label: 'GT Mask',
color: '#22c55e',
segmentation: [[10, 10, 90, 10, 90, 40]],
points: [[50, 25]],
}, frame)).toEqual(expect.objectContaining({
// Only `points` is pinned here; the other payload fields are not asserted by this test.
points: [[0.5, 0.5]],
}));
});
it('normalizes positive and negative point prompts for AI prediction', async () => {
const { predictMask } = await import('./api');
axiosMock.client.post.mockResolvedValueOnce({
@@ -341,6 +397,38 @@ describe('api client contracts', () => {
});
});
// Checks that predictMask forwards the caller's `options` object unchanged in the
// POST /api/ai/predict body, next to the normalized point prompt.
it('passes AI post-processing options to prediction endpoint', async () => {
const { predictMask } = await import('./api');
// An empty prediction result is enough; only the outgoing request is asserted.
axiosMock.client.post.mockResolvedValueOnce({ data: { polygons: [], scores: [] } });
await predictMask({
imageId: '7',
imageWidth: 640,
imageHeight: 360,
points: [{ x: 320, y: 180, type: 'pos' }],
options: {
crop_to_prompt: true,
auto_filter_background: true,
min_score: 0.05,
},
});
// Expected request body: numeric image_id, the point (320, 180) on a 640x360 image
// expressed as [0.5, 0.5] with label 1 for a 'pos' point, and model 'sam2' because
// the call above did not specify one.
expect(axiosMock.client.post).toHaveBeenCalledWith('/api/ai/predict', {
image_id: 7,
prompt_type: 'point',
prompt_data: {
points: [[0.5, 0.5]],
labels: [1],
},
model: 'sam2',
options: {
crop_to_prompt: true,
auto_filter_background: true,
min_score: 0.05,
},
});
});
it('loads AI model and GPU runtime status', async () => {
const { getAiModelStatus } = await import('./api');
const status = {

View File

@@ -197,6 +197,16 @@ export async function getTask(taskId: string | number): Promise<ProcessingTask>
return response.data;
}
/**
 * Requests cancellation of a processing task.
 *
 * @param taskId - Identifier of the task to cancel.
 * @returns The body of the `POST /api/tasks/{taskId}/cancel` response.
 */
export async function cancelTask(taskId: string | number): Promise<ProcessingTask> {
  const endpoint = `/api/tasks/${taskId}/cancel`;
  const { data } = await apiClient.post(endpoint);
  return data;
}
/**
 * Requests a retry of a processing task.
 *
 * @param taskId - Identifier of the task to retry.
 * @returns The body of the `POST /api/tasks/{taskId}/retry` response.
 */
export async function retryTask(taskId: string | number): Promise<ProcessingTask> {
  const endpoint = `/api/tasks/${taskId}/retry`;
  const { data } = await apiClient.post(endpoint);
  return data;
}
interface PredictMaskPayload {
imageId: string;
imageWidth: number;
@@ -205,6 +215,12 @@ interface PredictMaskPayload {
points?: { x: number; y: number; type: 'pos' | 'neg' }[];
box?: { x1: number; y1: number; x2: number; y2: number };
text?: string;
options?: {
crop_to_prompt?: boolean;
auto_filter_background?: boolean;
min_score?: number;
crop_margin?: number;
};
}
interface PredictMaskResult {
@@ -234,6 +250,8 @@ export interface AiModelStatus {
python_ok: boolean;
torch_ok: boolean;
cuda_required: boolean;
external_available?: boolean;
external_python?: string | null;
}
export interface AiRuntimeStatus {
@@ -301,6 +319,8 @@ export interface DashboardTask {
name: string;
progress: number;
status: string;
raw_status?: string;
error?: string | null;
frame_count: number;
updated_at: string | null;
}
@@ -397,7 +417,7 @@ export function buildAnnotationPayload(
}
: undefined;
return {
const payload: SaveAnnotationPayload = {
project_id: Number(projectId),
frame_id: Number(frame.id),
template_id: effectiveTemplateId ? Number(effectiveTemplateId) : undefined,
@@ -416,6 +436,15 @@ export function buildAnnotationPayload(
]
: undefined,
};
if (mask.points) {
payload.points = mask.points.map(([x, y]) => [
clamp01(x / Math.max(frame.width, 1)),
clamp01(y / Math.max(frame.height, 1)),
]);
}
return payload;
}
export function annotationToMask(annotation: SavedAnnotation, frame: Frame): Mask | null {
@@ -438,6 +467,7 @@ export function annotationToMask(annotation: SavedAnnotation, frame: Frame): Mas
label: classMetadata?.name || annotation.mask_data?.label || `Annotation ${annotation.id}`,
color: classMetadata?.color || annotation.mask_data?.color || '#06b6d4',
segmentation: polygons.map((polygon) => polygon.flatMap(([x, y]) => [x * frame.width, y * frame.height])),
points: annotation.points?.map(([x, y]) => [x * frame.width, y * frame.height]),
bbox,
area: bbox[2] * bbox[3],
};
@@ -471,6 +501,7 @@ export async function predictMask(payload: PredictMaskPayload): Promise<PredictM
prompt_type,
prompt_data,
model: payload.model || 'sam2',
...(payload.options ? { options: payload.options } : {}),
});
const polygons: number[][][] = response.data.polygons || [];
@@ -523,6 +554,23 @@ export async function deleteAnnotation(annotationId: string): Promise<void> {
await apiClient.delete(`/api/ai/annotations/${annotationId}`);
}
/**
 * Uploads a ground-truth mask file for one frame as multipart form data.
 *
 * @param file - Mask file to upload (sent as the `file` field).
 * @param projectId - Sent as the `project_id` field.
 * @param frameId - Sent as the `frame_id` field.
 * @param templateId - Sent as the `template_id` field; omitted when null, undefined or empty.
 * @returns The body of the `POST /api/ai/import-gt-mask` response.
 */
export async function importGtMask(
  file: File,
  projectId: string,
  frameId: string,
  templateId?: string | null,
): Promise<SavedAnnotation[]> {
  // Collect fields first so they are appended in a fixed order.
  const fields: [string, string | File][] = [
    ['file', file],
    ['project_id', projectId],
    ['frame_id', frameId],
  ];
  if (templateId) {
    fields.push(['template_id', templateId]);
  }

  const body = new FormData();
  for (const [name, value] of fields) {
    body.append(name, value);
  }

  const requestConfig = { headers: { 'Content-Type': 'multipart/form-data' } };
  const { data } = await apiClient.post('/api/ai/import-gt-mask', body, requestConfig);
  return data;
}
export async function getDashboardOverview(): Promise<DashboardOverview> {
const response = await apiClient.get('/api/dashboard/overview');
return response.data;
@@ -536,4 +584,11 @@ export async function exportCoco(projectId: string): Promise<Blob> {
return response.data;
}
/**
 * Downloads the mask export for a project as a binary blob.
 *
 * @param projectId - Project whose masks are exported.
 * @returns The body of the `GET /api/export/{projectId}/masks` response, requested as a blob.
 */
export async function exportMasks(projectId: string): Promise<Blob> {
  const endpoint = `/api/export/${projectId}/masks`;
  const { data } = await apiClient.get(endpoint, { responseType: 'blob' });
  return data;
}
export default apiClient;

View File

@@ -3,7 +3,7 @@ import { WS_PROGRESS_URL } from './config';
type ProgressCallback = (data: ProgressMessage) => void;
interface ProgressMessage {
type: 'progress' | 'status' | 'error' | 'complete';
type: 'progress' | 'status' | 'error' | 'complete' | 'cancelled';
taskId?: string;
task_id?: number;
project_id?: number;

View File

@@ -53,4 +53,15 @@ describe('useStore', () => {
expect(useStore.getState().masks).toEqual([]);
expect(useStore.getState().templates).toEqual([]);
});
// Checks the mask undo/redo stack: after adding two masks, undoMasks goes back to the
// single-mask state and redoMasks restores both masks.
it('keeps undo and redo history for mask edits', () => {
useStore.getState().addMask({ id: 'm1', frameId: 'f1', pathData: 'M 0 0 Z', label: 'mask 1', color: '#fff' });
useStore.getState().addMask({ id: 'm2', frameId: 'f1', pathData: 'M 1 1 Z', label: 'mask 2', color: '#000' });
expect(useStore.getState().masks.map((mask) => mask.id)).toEqual(['m1', 'm2']);
// Undo removes only the most recent addition.
useStore.getState().undoMasks();
expect(useStore.getState().masks.map((mask) => mask.id)).toEqual(['m1']);
// Redo re-applies the step that was just undone.
useStore.getState().redoMasks();
expect(useStore.getState().masks.map((mask) => mask.id)).toEqual(['m1', 'm2']);
});
});

View File

@@ -56,8 +56,10 @@ export interface Mask {
color: string;
opacity?: number;
segmentation?: number[][];
points?: number[][];
bbox?: [number, number, number, number];
area?: number;
metadata?: Record<string, unknown>;
}
export interface Template {
@@ -110,6 +112,8 @@ export interface AppState {
currentFrameIndex: number;
annotations: Annotation[];
masks: Mask[];
maskHistory: Mask[][];
maskFuture: Mask[][];
setActiveModule: (module: string) => void;
setActiveTool: (tool: string) => void;
setAiModel: (model: AiModelId) => void;
@@ -120,6 +124,8 @@ export interface AppState {
updateMask: (id: string, updates: Partial<Mask>) => void;
setMasks: (masks: Mask[]) => void;
clearMasks: () => void;
undoMasks: () => void;
redoMasks: () => void;
removeAnnotation: (id: string) => void;
// Templates
@@ -161,6 +167,8 @@ export const useStore = create<AppState>((set) => ({
frames: [],
annotations: [],
masks: [],
maskHistory: [],
maskFuture: [],
activeTemplateId: null,
activeClassId: null,
activeClass: null,
@@ -187,6 +195,8 @@ export const useStore = create<AppState>((set) => ({
currentFrameIndex: 0,
annotations: [],
masks: [],
maskHistory: [],
maskFuture: [],
setActiveModule: (activeModule: string) => set({ activeModule }),
setActiveTool: (activeTool: string) => set({ activeTool }),
setAiModel: (aiModel: AiModelId) => set({ aiModel }),
@@ -195,13 +205,54 @@ export const useStore = create<AppState>((set) => ({
addAnnotation: (annotation: Annotation) =>
set((state) => ({ annotations: [...state.annotations, annotation] })),
addMask: (mask: Mask) =>
set((state) => ({ masks: [...state.masks, mask] })),
set((state) => ({
masks: [...state.masks, mask],
maskHistory: [...state.maskHistory, state.masks],
maskFuture: [],
})),
/**
 * Merges `updates` into the mask with the given id and records an undo step.
 *
 * When no mask has that id the state is returned unchanged, so a no-op call
 * neither pushes a duplicate undo snapshot nor discards the redo stack.
 */
updateMask: (id: string, updates: Partial<Mask>) =>
  set((state) => {
    const exists = state.masks.some((mask) => mask.id === id);
    // Returning the same state object leaves the store untouched.
    if (!exists) return state;
    return {
      masks: state.masks.map((mask) => (mask.id === id ? { ...mask, ...updates } : mask)),
      // Snapshot the pre-edit masks so undoMasks can restore them.
      maskHistory: [...state.maskHistory, state.masks],
      // A fresh edit invalidates any previously undone steps.
      maskFuture: [],
    };
  }),
setMasks: (masks: Mask[]) => set({ masks }),
clearMasks: () => set({ masks: [] }),
/**
 * Replaces the whole mask list.
 *
 * The first load into a completely empty store (no masks, no undo history,
 * no redo stack) is not recorded as an undo step; every later replacement
 * snapshots the previous masks. The redo stack is always cleared.
 */
setMasks: (masks: Mask[]) =>
  set((state) => {
    const hasPriorState =
      state.masks.length !== 0 || state.maskHistory.length !== 0 || state.maskFuture.length !== 0;
    const maskHistory = hasPriorState ? [...state.maskHistory, state.masks] : [];
    return { masks, maskHistory, maskFuture: [] };
  }),
/**
 * Removes every mask, keeping the previous list as an undo snapshot and
 * clearing the redo stack.
 */
clearMasks: () =>
  set(({ masks, maskHistory }) => {
    // Append the current masks array as one history entry.
    const snapshots = maskHistory.concat([masks]);
    return { masks: [], maskHistory: snapshots, maskFuture: [] };
  }),
/**
 * Steps back to the most recent history snapshot and pushes the current
 * masks onto the front of the redo stack. Does nothing when history is empty.
 */
undoMasks: () =>
  set((state) => {
    const { masks, maskHistory, maskFuture } = state;
    const lastIndex = maskHistory.length - 1;
    // Empty history: hand back the same state object so nothing changes.
    if (lastIndex < 0) return state;
    return {
      masks: maskHistory[lastIndex],
      maskHistory: maskHistory.slice(0, lastIndex),
      maskFuture: [masks, ...maskFuture],
    };
  }),
/**
 * Re-applies the most recently undone snapshot (the front of the redo stack)
 * and records the current masks as an undo step. Does nothing when the redo
 * stack is empty.
 */
redoMasks: () =>
  set((state) => {
    const { masks, maskHistory, maskFuture } = state;
    // Nothing to redo: hand back the same state object so nothing changes.
    if (maskFuture.length === 0) return state;
    return {
      masks: maskFuture[0],
      maskHistory: [...maskHistory, masks],
      maskFuture: maskFuture.slice(1),
    };
  }),
removeAnnotation: (id: string) =>
set((state) => ({
annotations: state.annotations.filter((a) => a.id !== id),

View File

@@ -32,24 +32,69 @@ function makeStageEvent(x = 120, y = 80) {
}
vi.mock('react-konva', () => ({
Stage: ({ children, onClick, onMouseDown, onMouseUp, onMouseMove, onWheel }: any) => (
Stage: ({ children, onClick, onMouseDown, onMouseUp, onMouseMove, onWheel }: any) => {
const coords = (event: React.MouseEvent<HTMLDivElement>, fallbackX: number, fallbackY: number) => ({
x: event.clientX || fallbackX,
y: event.clientY || fallbackY,
});
return (
<div
data-testid="konva-stage"
onClick={() => onClick?.(makeStageEvent())}
onMouseDown={() => onMouseDown?.(makeStageEvent())}
onMouseUp={() => onMouseUp?.(makeStageEvent(260, 200))}
onMouseMove={() => onMouseMove?.(makeStageEvent(180, 120))}
onClick={(event) => {
const point = coords(event, 120, 80);
onClick?.(makeStageEvent(point.x, point.y));
}}
onMouseDown={(event) => {
const point = coords(event, 120, 80);
onMouseDown?.(makeStageEvent(point.x, point.y));
}}
onMouseUp={(event) => {
const point = coords(event, 260, 200);
onMouseUp?.(makeStageEvent(point.x, point.y));
}}
onMouseMove={(event) => {
const point = coords(event, 180, 120);
onMouseMove?.(makeStageEvent(point.x, point.y));
}}
onWheel={() => onWheel?.(makeStageEvent())}
>
{children}
</div>
),
);
},
Layer: ({ children }: any) => <div data-testid="konva-layer">{children}</div>,
Group: ({ children }: any) => <div data-testid="konva-group">{children}</div>,
Image: ({ image }: any) => <img data-testid="konva-image" alt="" src={image?.src || ''} />,
Circle: (props: any) => <span data-testid="konva-circle" data-fill={props.fill} />,
Circle: (props: any) => (
<span
data-testid="konva-circle"
data-fill={props.fill}
data-x={props.x}
data-y={props.y}
onClick={() => props.onClick?.({ cancelBubble: false })}
onMouseUp={(event: React.MouseEvent<HTMLSpanElement>) => props.onDragEnd?.({
target: {
x: () => event.clientX || props.x || 0,
y: () => event.clientY || props.y || 0,
},
})}
onDragEnd={(event: React.DragEvent<HTMLSpanElement>) => props.onDragEnd?.({
target: {
x: () => event.clientX || props.x || 0,
y: () => event.clientY || props.y || 0,
},
})}
/>
),
Rect: (props: any) => <span data-testid="konva-rect" data-width={props.width} />,
Path: (props: any) => <span data-testid="konva-path" data-path={props.data} data-fill={props.fill} />,
Path: (props: any) => (
<span
data-testid="konva-path"
data-path={props.data}
data-fill={props.fill}
onClick={() => props.onClick?.({ cancelBubble: false })}
/>
),
}));
vi.mock('use-image', () => ({

View File

@@ -13,6 +13,8 @@ export function resetStore() {
currentFrameIndex: 0,
annotations: [],
masks: [],
maskHistory: [],
maskFuture: [],
templates: [],
activeTemplateId: null,
activeClassId: null,