feat: 建立 SAM2 标注闭环基线
- 打通工作区真实标注闭环:支持手工多边形、矩形、圆形、点区域和线段生成 mask,并可保存、回显、更新和删除后端 annotation。 - 增强 polygon 编辑器:支持顶点拖动、顶点删除、边中点插入、多 polygon 子区域选择编辑,以及区域合并和区域去除。 - 接入 GT mask 导入:后端支持二值/多类别 mask 拆分、contour 转 polygon、distance transform seed point,前端支持导入、回显和 seed point 拖动编辑。 - 完善导出能力:COCO JSON 导出对齐前端,PNG mask ZIP 同时包含单标注 mask、按 zIndex 融合的 semantic_frame 和 semantic_classes.json。 - 打通异步任务管理:新增任务取消、重试、失败详情接口与 Dashboard 控件,worker 支持取消状态检查并通过 Redis/WebSocket 推送 cancelled 事件。 - 对接 Dashboard 后端数据:概览统计、解析队列和实时流转记录从 FastAPI 聚合接口与 WebSocket 更新。 - 增强 AI 推理参数:前端发送 crop_to_prompt、auto_filter_background 和 min_score,后端支持点/框 prompt 局部裁剪推理、结果回映射和负向点/低分过滤。 - 接入 SAM3 基础设施:新增独立 Python 3.12 sam3 环境安装脚本、外部 worker helper、后端桥接和真实 Python/CUDA/包/HF checkpoint access 状态检测。 - 保留 SAM3 授权边界:当前官方 facebook/sam3 gated 权重未授权时状态接口会返回不可用,不伪装成可推理。 - 增强前端状态管理:新增 mask undo/redo 历史栈、AI 模型选择状态、保存状态 dirty/draft/saved 流转和项目状态归一化。 - 更新前端 API 封装:补充 annotation CRUD、GT mask import、mask ZIP export、task cancel/retry/detail、AI runtime status 和 prediction options。 - 更新 UI 控件:ToolsPalette、AISegmentation、VideoWorkspace 和 CanvasArea 接入真实操作、导入导出、撤销重做、任务控制和模型状态。 - 新增 polygon-clipping 依赖,用于前端区域 union/difference 几何运算。 - 完善后端 schemas/status/progress:补充 AI 模型外部状态字段、任务 cancelled 状态和进度事件 payload。 - 补充测试覆盖:新增后端任务控制、SAM3 桥接、GT mask、导出融合、AI options 测试;补充前端 Canvas、Dashboard、VideoWorkspace、ToolsPalette、API 和 store 测试。 - 更新 README、AGENTS 和 doc 文档:冻结当前需求/设计/测试计划,标注真实功能、剩余 Mock、SAM3 授权边界和后续实施顺序。
This commit is contained in:
@@ -40,4 +40,26 @@ describe('AISegmentation', () => {
|
||||
expect(useStore.getState().aiModel).toBe('sam3');
|
||||
expect(await screen.findByText('SAM 3 missing runtime')).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it('passes enabled inference parameters to the backend', async () => {
|
||||
apiMock.predictMask.mockResolvedValueOnce({ masks: [] });
|
||||
render(<AISegmentation onSendToWorkspace={vi.fn()} />);
|
||||
|
||||
fireEvent.click(screen.getByText('正向选点'));
|
||||
fireEvent.click(screen.getByTestId('konva-stage'));
|
||||
fireEvent.click(await screen.findByText('执行高精度语义分割'));
|
||||
|
||||
expect(apiMock.predictMask).toHaveBeenCalledWith(expect.objectContaining({
|
||||
imageId: 'frame-1',
|
||||
imageWidth: 640,
|
||||
imageHeight: 360,
|
||||
model: 'sam2',
|
||||
points: [{ x: 120, y: 80, type: 'pos' }],
|
||||
options: {
|
||||
crop_to_prompt: false,
|
||||
auto_filter_background: true,
|
||||
min_score: 0.05,
|
||||
},
|
||||
}));
|
||||
});
|
||||
});
|
||||
|
||||
@@ -17,6 +17,10 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
|
||||
const masks = useStore((state) => state.masks);
|
||||
const addMask = useStore((state) => state.addMask);
|
||||
const clearMasks = useStore((state) => state.clearMasks);
|
||||
const maskHistory = useStore((state) => state.maskHistory);
|
||||
const maskFuture = useStore((state) => state.maskFuture);
|
||||
const undoMasks = useStore((state) => state.undoMasks);
|
||||
const redoMasks = useStore((state) => state.redoMasks);
|
||||
const frames = useStore((state) => state.frames);
|
||||
const currentFrameIndex = useStore((state) => state.currentFrameIndex);
|
||||
const activeTemplateId = useStore((state) => state.activeTemplateId);
|
||||
@@ -109,6 +113,11 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
|
||||
model: aiModel,
|
||||
points: points.map((p) => ({ x: p.x, y: p.y, type: p.type })),
|
||||
text: semanticText.trim() || undefined,
|
||||
options: {
|
||||
crop_to_prompt: cropMode,
|
||||
auto_filter_background: autoDeleteBg,
|
||||
min_score: autoDeleteBg ? 0.05 : 0,
|
||||
},
|
||||
});
|
||||
|
||||
result.masks.forEach((m) => {
|
||||
@@ -136,7 +145,7 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
|
||||
} finally {
|
||||
setIsInferencing(false);
|
||||
}
|
||||
}, [activeClass, activeTemplateId, addMask, aiModel, currentFrame?.height, currentFrame?.id, currentFrame?.width, image?.height, image?.naturalHeight, image?.naturalWidth, image?.width, points, semanticText]);
|
||||
}, [activeClass, activeTemplateId, addMask, aiModel, autoDeleteBg, cropMode, currentFrame?.height, currentFrame?.id, currentFrame?.width, image?.height, image?.naturalHeight, image?.naturalWidth, image?.width, points, semanticText]);
|
||||
|
||||
const handleStageClick = (e: any) => {
|
||||
if (effectiveTool === 'move') return;
|
||||
@@ -290,10 +299,20 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
|
||||
<span className="text-[10px] text-gray-500 uppercase tracking-widest font-mono">{aiModel.toUpperCase()} 动态推理渲染</span>
|
||||
</div>
|
||||
<div className="flex items-center gap-4">
|
||||
<button className="w-8 h-8 rounded text-gray-400 hover:bg-white/5 hover:text-white flex items-center justify-center transition-colors" title="撤销操作 (Ctrl+Z)">
|
||||
<button
|
||||
onClick={undoMasks}
|
||||
disabled={maskHistory.length === 0}
|
||||
className="w-8 h-8 rounded text-gray-400 hover:bg-white/5 hover:text-white flex items-center justify-center transition-colors disabled:opacity-30 disabled:hover:bg-transparent disabled:hover:text-gray-400 disabled:cursor-not-allowed"
|
||||
title="撤销操作 (Ctrl+Z)"
|
||||
>
|
||||
<Undo size={14} />
|
||||
</button>
|
||||
<button className="w-8 h-8 rounded text-gray-400 hover:bg-white/5 hover:text-white flex items-center justify-center transition-colors" title="重做操作 (Ctrl+Shift+Z)">
|
||||
<button
|
||||
onClick={redoMasks}
|
||||
disabled={maskFuture.length === 0}
|
||||
className="w-8 h-8 rounded text-gray-400 hover:bg-white/5 hover:text-white flex items-center justify-center transition-colors disabled:opacity-30 disabled:hover:bg-transparent disabled:hover:text-gray-400 disabled:cursor-not-allowed"
|
||||
title="重做操作 (Ctrl+Shift+Z)"
|
||||
>
|
||||
<Redo size={14} />
|
||||
</button>
|
||||
<div className="w-px h-4 bg-white/10 mx-1"></div>
|
||||
|
||||
@@ -79,6 +79,271 @@ describe('CanvasArea', () => {
|
||||
expect(screen.getByText('遮罩数: 1')).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it('renders imported GT seed points for editable point regions', () => {
|
||||
useStore.setState({
|
||||
masks: [
|
||||
{
|
||||
id: 'gt-1',
|
||||
frameId: 'frame-1',
|
||||
pathData: 'M 0 0 L 10 0 L 10 10 Z',
|
||||
label: 'GT',
|
||||
color: '#22c55e',
|
||||
points: [[120, 80]],
|
||||
},
|
||||
],
|
||||
});
|
||||
|
||||
render(<CanvasArea activeTool="move" frame={frame} />);
|
||||
|
||||
expect(screen.getAllByTestId('konva-circle')).toHaveLength(2);
|
||||
});
|
||||
|
||||
it('selects a polygon mask and drags a vertex into dirty saved state', () => {
|
||||
useStore.setState({
|
||||
masks: [
|
||||
{
|
||||
id: 'annotation-99',
|
||||
annotationId: '99',
|
||||
frameId: 'frame-1',
|
||||
pathData: 'M 10 10 L 90 10 L 90 40 Z',
|
||||
label: 'Saved',
|
||||
color: '#06b6d4',
|
||||
saved: true,
|
||||
saveStatus: 'saved',
|
||||
segmentation: [[10, 10, 90, 10, 90, 40]],
|
||||
bbox: [10, 10, 80, 30],
|
||||
},
|
||||
],
|
||||
});
|
||||
|
||||
render(<CanvasArea activeTool="move" frame={frame} />);
|
||||
fireEvent.click(screen.getByTestId('konva-path'));
|
||||
const handles = screen.getAllByTestId('konva-circle')
|
||||
.filter((element) => element.getAttribute('data-fill') === '#ffffff');
|
||||
expect(handles).toHaveLength(3);
|
||||
|
||||
fireEvent.mouseUp(handles[0], { clientX: 20, clientY: 30 });
|
||||
|
||||
expect(useStore.getState().masks[0]).toEqual(expect.objectContaining({
|
||||
pathData: 'M 20 30 L 90 10 L 90 40 Z',
|
||||
segmentation: [[20, 30, 90, 10, 90, 40]],
|
||||
bbox: [20, 10, 70, 30],
|
||||
area: 1050,
|
||||
saveStatus: 'dirty',
|
||||
saved: false,
|
||||
}));
|
||||
});
|
||||
|
||||
it('deletes a selected polygon vertex without dropping below three points', () => {
|
||||
useStore.setState({
|
||||
masks: [
|
||||
{
|
||||
id: 'draft-1',
|
||||
frameId: 'frame-1',
|
||||
pathData: 'M 10 10 L 90 10 L 90 40 L 10 40 Z',
|
||||
label: 'Draft',
|
||||
color: '#06b6d4',
|
||||
saveStatus: 'draft',
|
||||
segmentation: [[10, 10, 90, 10, 90, 40, 10, 40]],
|
||||
bbox: [10, 10, 80, 30],
|
||||
},
|
||||
],
|
||||
});
|
||||
|
||||
render(<CanvasArea activeTool="move" frame={frame} />);
|
||||
fireEvent.click(screen.getByTestId('konva-path'));
|
||||
const handles = screen.getAllByTestId('konva-circle')
|
||||
.filter((element) => element.getAttribute('data-fill') === '#ffffff');
|
||||
fireEvent.click(handles[0]);
|
||||
fireEvent.keyDown(window, { key: 'Delete' });
|
||||
|
||||
expect(useStore.getState().masks[0]).toEqual(expect.objectContaining({
|
||||
pathData: 'M 90 10 L 90 40 L 10 40 Z',
|
||||
segmentation: [[90, 10, 90, 40, 10, 40]],
|
||||
saveStatus: 'draft',
|
||||
}));
|
||||
});
|
||||
|
||||
it('inserts a polygon vertex from an edge midpoint handle', () => {
|
||||
useStore.setState({
|
||||
masks: [
|
||||
{
|
||||
id: 'draft-1',
|
||||
frameId: 'frame-1',
|
||||
pathData: 'M 10 10 L 90 10 L 90 40 Z',
|
||||
label: 'Draft',
|
||||
color: '#06b6d4',
|
||||
saveStatus: 'draft',
|
||||
segmentation: [[10, 10, 90, 10, 90, 40]],
|
||||
bbox: [10, 10, 80, 30],
|
||||
},
|
||||
],
|
||||
});
|
||||
|
||||
render(<CanvasArea activeTool="move" frame={frame} />);
|
||||
fireEvent.click(screen.getByTestId('konva-path'));
|
||||
const edgeHandles = screen.getAllByTestId('konva-circle')
|
||||
.filter((element) => element.getAttribute('data-fill') === '#22d3ee');
|
||||
fireEvent.click(edgeHandles[0]);
|
||||
|
||||
expect(useStore.getState().masks[0]).toEqual(expect.objectContaining({
|
||||
segmentation: [[10, 10, 50, 10, 90, 10, 90, 40]],
|
||||
pathData: 'M 10 10 L 50 10 L 90 10 L 90 40 Z',
|
||||
saveStatus: 'draft',
|
||||
}));
|
||||
});
|
||||
|
||||
it('edits the selected polygon in a multi-polygon mask', () => {
|
||||
useStore.setState({
|
||||
masks: [
|
||||
{
|
||||
id: 'multi-1',
|
||||
frameId: 'frame-1',
|
||||
pathData: 'M 10 10 L 50 10 L 50 40 Z M 100 100 L 150 100 L 150 140 Z',
|
||||
label: 'Multi',
|
||||
color: '#06b6d4',
|
||||
saveStatus: 'draft',
|
||||
segmentation: [
|
||||
[10, 10, 50, 10, 50, 40],
|
||||
[100, 100, 150, 100, 150, 140],
|
||||
],
|
||||
bbox: [10, 10, 140, 130],
|
||||
},
|
||||
],
|
||||
});
|
||||
|
||||
render(<CanvasArea activeTool="move" frame={frame} />);
|
||||
const paths = screen.getAllByTestId('konva-path');
|
||||
fireEvent.click(paths[1]);
|
||||
const vertexHandles = screen.getAllByTestId('konva-circle')
|
||||
.filter((element) => element.getAttribute('data-fill') === '#ffffff');
|
||||
fireEvent.mouseUp(vertexHandles[0], { clientX: 120, clientY: 120 });
|
||||
|
||||
expect(useStore.getState().masks[0].segmentation).toEqual([
|
||||
[10, 10, 50, 10, 50, 40],
|
||||
[120, 120, 150, 100, 150, 140],
|
||||
]);
|
||||
});
|
||||
|
||||
it('merges selected draft masks with polygon union', () => {
|
||||
useStore.setState({
|
||||
masks: [
|
||||
{
|
||||
id: 'm1',
|
||||
frameId: 'frame-1',
|
||||
pathData: 'M 10 10 L 90 10 L 90 50 L 10 50 Z',
|
||||
label: 'A',
|
||||
color: '#06b6d4',
|
||||
segmentation: [[10, 10, 90, 10, 90, 50, 10, 50]],
|
||||
},
|
||||
{
|
||||
id: 'm2',
|
||||
frameId: 'frame-1',
|
||||
pathData: 'M 50 30 L 120 30 L 120 80 L 50 80 Z',
|
||||
label: 'B',
|
||||
color: '#ff0000',
|
||||
segmentation: [[50, 30, 120, 30, 120, 80, 50, 80]],
|
||||
},
|
||||
],
|
||||
});
|
||||
|
||||
render(<CanvasArea activeTool="area_merge" frame={frame} />);
|
||||
const paths = screen.getAllByTestId('konva-path');
|
||||
fireEvent.click(paths[0]);
|
||||
fireEvent.click(paths[1]);
|
||||
fireEvent.click(screen.getByRole('button', { name: '合并选中' }));
|
||||
|
||||
expect(useStore.getState().masks).toHaveLength(1);
|
||||
expect(useStore.getState().masks[0]).toEqual(expect.objectContaining({
|
||||
id: 'm1',
|
||||
segmentation: [[10, 10, 90, 10, 90, 30, 120, 30, 120, 80, 50, 80, 50, 50, 10, 50]],
|
||||
bbox: [10, 10, 110, 70],
|
||||
saveStatus: 'draft',
|
||||
}));
|
||||
});
|
||||
|
||||
it('removes overlap from the primary selected mask with polygon difference', () => {
|
||||
useStore.setState({
|
||||
masks: [
|
||||
{
|
||||
id: 'm1',
|
||||
frameId: 'frame-1',
|
||||
pathData: 'M 10 10 L 90 10 L 90 50 L 10 50 Z',
|
||||
label: 'A',
|
||||
color: '#06b6d4',
|
||||
segmentation: [[10, 10, 90, 10, 90, 50, 10, 50]],
|
||||
},
|
||||
{
|
||||
id: 'm2',
|
||||
frameId: 'frame-1',
|
||||
pathData: 'M 50 30 L 120 30 L 120 80 L 50 80 Z',
|
||||
label: 'B',
|
||||
color: '#ff0000',
|
||||
segmentation: [[50, 30, 120, 30, 120, 80, 50, 80]],
|
||||
},
|
||||
],
|
||||
});
|
||||
|
||||
render(<CanvasArea activeTool="area_remove" frame={frame} />);
|
||||
const paths = screen.getAllByTestId('konva-path');
|
||||
fireEvent.click(paths[0]);
|
||||
fireEvent.click(paths[1]);
|
||||
fireEvent.click(screen.getByRole('button', { name: '从主区域去除' }));
|
||||
|
||||
expect(useStore.getState().masks).toHaveLength(2);
|
||||
expect(useStore.getState().masks[0]).toEqual(expect.objectContaining({
|
||||
id: 'm1',
|
||||
segmentation: [[10, 10, 90, 10, 90, 30, 50, 30, 50, 50, 10, 50]],
|
||||
bbox: [10, 10, 80, 40],
|
||||
saveStatus: 'draft',
|
||||
}));
|
||||
expect(useStore.getState().masks[1].id).toBe('m2');
|
||||
});
|
||||
|
||||
it('creates a manual rectangle mask that can be undone and redone', () => {
|
||||
useStore.setState({
|
||||
activeTemplateId: '2',
|
||||
activeClass: { id: 'c1', name: '胆囊', color: '#ff0000', zIndex: 20 },
|
||||
activeClassId: 'c1',
|
||||
});
|
||||
|
||||
render(<CanvasArea activeTool="create_rectangle" frame={frame} />);
|
||||
const stage = screen.getByTestId('konva-stage');
|
||||
fireEvent.mouseDown(stage);
|
||||
fireEvent.mouseMove(stage);
|
||||
fireEvent.mouseUp(stage);
|
||||
|
||||
expect(useStore.getState().masks).toHaveLength(1);
|
||||
expect(useStore.getState().masks[0]).toEqual(expect.objectContaining({
|
||||
frameId: 'frame-1',
|
||||
label: '胆囊',
|
||||
color: '#ff0000',
|
||||
saveStatus: 'draft',
|
||||
segmentation: [[120, 80, 260, 80, 260, 200, 120, 200]],
|
||||
bbox: [120, 80, 140, 120],
|
||||
}));
|
||||
|
||||
useStore.getState().undoMasks();
|
||||
expect(useStore.getState().masks).toEqual([]);
|
||||
useStore.getState().redoMasks();
|
||||
expect(useStore.getState().masks).toHaveLength(1);
|
||||
});
|
||||
|
||||
it('finalizes a clicked polygon with Enter', () => {
|
||||
render(<CanvasArea activeTool="create_polygon" frame={frame} />);
|
||||
const stage = screen.getByTestId('konva-stage');
|
||||
fireEvent.click(stage, { clientX: 120, clientY: 80 });
|
||||
fireEvent.click(stage, { clientX: 220, clientY: 80 });
|
||||
fireEvent.click(stage, { clientX: 180, clientY: 160 });
|
||||
fireEvent.keyDown(window, { key: 'Enter' });
|
||||
|
||||
expect(useStore.getState().masks).toHaveLength(1);
|
||||
expect(useStore.getState().masks[0].metadata).toEqual(expect.objectContaining({
|
||||
source: 'manual',
|
||||
shape: '多边形',
|
||||
}));
|
||||
});
|
||||
|
||||
it('applies the selected class to current-frame masks and marks saved masks dirty', () => {
|
||||
useStore.setState({
|
||||
activeTemplateId: '2',
|
||||
|
||||
@@ -1,17 +1,180 @@
|
||||
import React, { useEffect, useRef, useState, useCallback } from 'react';
|
||||
import { Stage, Layer, Image as KonvaImage, Circle, Rect, Path, Group } from 'react-konva';
|
||||
import polygonClipping, { type MultiPolygon, type Pair } from 'polygon-clipping';
|
||||
import useImage from 'use-image';
|
||||
import { useStore } from '../store/useStore';
|
||||
import { predictMask } from '../lib/api';
|
||||
import type { Frame } from '../store/useStore';
|
||||
import type { Frame, Mask } from '../store/useStore';
|
||||
|
||||
interface CanvasAreaProps {
|
||||
activeTool: string;
|
||||
frame: Frame | null;
|
||||
onClearMasks?: () => void;
|
||||
onDeleteMaskAnnotations?: (annotationIds: string[]) => Promise<void> | void;
|
||||
}
|
||||
|
||||
export function CanvasArea({ activeTool, frame, onClearMasks }: CanvasAreaProps) {
|
||||
type CanvasPoint = { x: number; y: number };
|
||||
|
||||
const DRAG_MANUAL_TOOLS = new Set(['create_rectangle', 'create_circle', 'create_line']);
|
||||
const POLYGON_TOOL = 'create_polygon';
|
||||
const POINT_TOOL = 'create_point';
|
||||
const BOOLEAN_TOOLS = new Set(['area_merge', 'area_remove']);
|
||||
|
||||
function clamp(value: number, min: number, max: number): number {
|
||||
return Math.min(Math.max(value, min), max);
|
||||
}
|
||||
|
||||
function polygonPath(points: CanvasPoint[]): string {
|
||||
if (points.length === 0) return '';
|
||||
return points
|
||||
.map((point, index) => `${index === 0 ? 'M' : 'L'} ${point.x} ${point.y}`)
|
||||
.join(' ')
|
||||
.concat(' Z');
|
||||
}
|
||||
|
||||
function segmentationPath(segmentation?: number[][]): string {
|
||||
return (segmentation || [])
|
||||
.map((polygon) => polygonPath(flatPolygonToPoints(polygon)))
|
||||
.filter(Boolean)
|
||||
.join(' ');
|
||||
}
|
||||
|
||||
function segmentationPolygonPath(segmentation: number[][] | undefined, polygonIndex: number): string {
|
||||
const polygon = segmentation?.[polygonIndex];
|
||||
return polygon ? polygonPath(flatPolygonToPoints(polygon)) : '';
|
||||
}
|
||||
|
||||
function polygonSegmentation(points: CanvasPoint[]): number[][] {
|
||||
return [points.flatMap((point) => [point.x, point.y])];
|
||||
}
|
||||
|
||||
function segmentationToPoints(segmentation?: number[][], polygonIndex = 0): CanvasPoint[] {
|
||||
const polygon = segmentation?.[polygonIndex] || [];
|
||||
const points: CanvasPoint[] = [];
|
||||
for (let index = 0; index < polygon.length - 1; index += 2) {
|
||||
points.push({ x: polygon[index], y: polygon[index + 1] });
|
||||
}
|
||||
return points;
|
||||
}
|
||||
|
||||
function flatPolygonToPoints(polygon: number[]): CanvasPoint[] {
|
||||
const points: CanvasPoint[] = [];
|
||||
for (let index = 0; index < polygon.length - 1; index += 2) {
|
||||
points.push({ x: polygon[index], y: polygon[index + 1] });
|
||||
}
|
||||
return points;
|
||||
}
|
||||
|
||||
function segmentationAllPoints(segmentation?: number[][]): CanvasPoint[] {
|
||||
return (segmentation || []).flatMap((polygon) => flatPolygonToPoints(polygon));
|
||||
}
|
||||
|
||||
function polygonBbox(points: CanvasPoint[]): [number, number, number, number] {
|
||||
const xs = points.map((point) => point.x);
|
||||
const ys = points.map((point) => point.y);
|
||||
const minX = Math.min(...xs);
|
||||
const minY = Math.min(...ys);
|
||||
const maxX = Math.max(...xs);
|
||||
const maxY = Math.max(...ys);
|
||||
return [minX, minY, maxX - minX, maxY - minY];
|
||||
}
|
||||
|
||||
function polygonArea(points: CanvasPoint[]): number {
|
||||
if (points.length < 3) return 0;
|
||||
const sum = points.reduce((acc, point, index) => {
|
||||
const next = points[(index + 1) % points.length];
|
||||
return acc + point.x * next.y - next.x * point.y;
|
||||
}, 0);
|
||||
return Math.abs(sum) / 2;
|
||||
}
|
||||
|
||||
function segmentationArea(segmentation?: number[][]): number {
|
||||
return (segmentation || []).reduce((sum, polygon) => sum + polygonArea(flatPolygonToPoints(polygon)), 0);
|
||||
}
|
||||
|
||||
function segmentationBbox(segmentation?: number[][]): [number, number, number, number] | undefined {
|
||||
const points = segmentationAllPoints(segmentation);
|
||||
return points.length > 0 ? polygonBbox(points) : undefined;
|
||||
}
|
||||
|
||||
function closeRing(points: CanvasPoint[]): Pair[] {
|
||||
const ring = points.map((point) => [point.x, point.y] as Pair);
|
||||
const first = ring[0];
|
||||
const last = ring[ring.length - 1];
|
||||
if (first && last && (first[0] !== last[0] || first[1] !== last[1])) {
|
||||
ring.push([first[0], first[1]]);
|
||||
}
|
||||
return ring;
|
||||
}
|
||||
|
||||
function maskToMultiPolygon(mask: Mask): MultiPolygon | null {
|
||||
const polygons = (mask.segmentation || [])
|
||||
.map((polygon) => flatPolygonToPoints(polygon))
|
||||
.filter((points) => points.length >= 3)
|
||||
.map((points) => [closeRing(points)]);
|
||||
return polygons.length > 0 ? polygons : null;
|
||||
}
|
||||
|
||||
function multiPolygonToSegmentation(geometry: MultiPolygon): number[][] {
|
||||
return geometry
|
||||
.map((polygon) => polygon[0] || [])
|
||||
.map((ring) => {
|
||||
const openRing = ring.length > 1
|
||||
&& ring[0][0] === ring[ring.length - 1][0]
|
||||
&& ring[0][1] === ring[ring.length - 1][1]
|
||||
? ring.slice(0, -1)
|
||||
: ring;
|
||||
return openRing.flatMap(([x, y]) => [x, y]);
|
||||
})
|
||||
.filter((polygon) => polygon.length >= 6);
|
||||
}
|
||||
|
||||
function rectanglePoints(start: CanvasPoint, end: CanvasPoint): CanvasPoint[] {
|
||||
const x1 = Math.min(start.x, end.x);
|
||||
const y1 = Math.min(start.y, end.y);
|
||||
const x2 = Math.max(start.x, end.x);
|
||||
const y2 = Math.max(start.y, end.y);
|
||||
return [
|
||||
{ x: x1, y: y1 },
|
||||
{ x: x2, y: y1 },
|
||||
{ x: x2, y: y2 },
|
||||
{ x: x1, y: y2 },
|
||||
];
|
||||
}
|
||||
|
||||
function circlePoints(start: CanvasPoint, end: CanvasPoint): CanvasPoint[] {
|
||||
const cx = (start.x + end.x) / 2;
|
||||
const cy = (start.y + end.y) / 2;
|
||||
const rx = Math.abs(end.x - start.x) / 2;
|
||||
const ry = Math.abs(end.y - start.y) / 2;
|
||||
return Array.from({ length: 32 }, (_, index) => {
|
||||
const angle = (Math.PI * 2 * index) / 32;
|
||||
return { x: cx + Math.cos(angle) * rx, y: cy + Math.sin(angle) * ry };
|
||||
});
|
||||
}
|
||||
|
||||
function pointRegion(point: CanvasPoint, radius = 5): CanvasPoint[] {
|
||||
return Array.from({ length: 12 }, (_, index) => {
|
||||
const angle = (Math.PI * 2 * index) / 12;
|
||||
return { x: point.x + Math.cos(angle) * radius, y: point.y + Math.sin(angle) * radius };
|
||||
});
|
||||
}
|
||||
|
||||
function lineRegion(start: CanvasPoint, end: CanvasPoint, halfWidth = 4): CanvasPoint[] {
|
||||
const dx = end.x - start.x;
|
||||
const dy = end.y - start.y;
|
||||
const length = Math.hypot(dx, dy) || 1;
|
||||
const nx = (-dy / length) * halfWidth;
|
||||
const ny = (dx / length) * halfWidth;
|
||||
return [
|
||||
{ x: start.x + nx, y: start.y + ny },
|
||||
{ x: end.x + nx, y: end.y + ny },
|
||||
{ x: end.x - nx, y: end.y - ny },
|
||||
{ x: start.x - nx, y: start.y - ny },
|
||||
];
|
||||
}
|
||||
|
||||
export function CanvasArea({ activeTool, frame, onClearMasks, onDeleteMaskAnnotations }: CanvasAreaProps) {
|
||||
const containerRef = useRef<HTMLDivElement>(null);
|
||||
const [stageSize, setStageSize] = useState({ width: 800, height: 600 });
|
||||
const [scale, setScale] = useState(1);
|
||||
@@ -20,22 +183,46 @@ export function CanvasArea({ activeTool, frame, onClearMasks }: CanvasAreaProps)
|
||||
const [cursorPos, setCursorPos] = useState({ x: 0, y: 0 });
|
||||
const [boxStart, setBoxStart] = useState<{ x: number, y: number } | null>(null);
|
||||
const [boxCurrent, setBoxCurrent] = useState<{ x: number, y: number } | null>(null);
|
||||
const [manualStart, setManualStart] = useState<CanvasPoint | null>(null);
|
||||
const [manualCurrent, setManualCurrent] = useState<CanvasPoint | null>(null);
|
||||
const [polygonPoints, setPolygonPoints] = useState<CanvasPoint[]>([]);
|
||||
const [selectedMaskId, setSelectedMaskId] = useState<string | null>(null);
|
||||
const [selectedMaskIds, setSelectedMaskIds] = useState<string[]>([]);
|
||||
const [selectedPolygonIndex, setSelectedPolygonIndex] = useState(0);
|
||||
const [selectedVertexIndex, setSelectedVertexIndex] = useState<number | null>(null);
|
||||
const [isInferencing, setIsInferencing] = useState(false);
|
||||
|
||||
const masks = useStore((state) => state.masks);
|
||||
const addMask = useStore((state) => state.addMask);
|
||||
const updateMask = useStore((state) => state.updateMask);
|
||||
const clearMasks = useStore((state) => state.clearMasks);
|
||||
const setMasks = useStore((state) => state.setMasks);
|
||||
const storeActiveTool = useStore((state) => state.activeTool);
|
||||
const aiModel = useStore((state) => state.aiModel);
|
||||
const activeTemplateId = useStore((state) => state.activeTemplateId);
|
||||
const activeClass = useStore((state) => state.activeClass);
|
||||
const undoMasks = useStore((state) => state.undoMasks);
|
||||
const redoMasks = useStore((state) => state.redoMasks);
|
||||
|
||||
const effectiveTool = activeTool || storeActiveTool;
|
||||
|
||||
// Load the actual frame image
|
||||
const [image] = useImage(frame?.url || '');
|
||||
const frameMasks = masks.filter((mask) => mask.frameId === frame?.id);
|
||||
const selectedMask = React.useMemo(
|
||||
() => frameMasks.find((mask) => mask.id === selectedMaskId) || null,
|
||||
[frameMasks, selectedMaskId],
|
||||
);
|
||||
const booleanSelectedMasks = React.useMemo(
|
||||
() => selectedMaskIds
|
||||
.map((id) => frameMasks.find((mask) => mask.id === id))
|
||||
.filter((mask): mask is Mask => Boolean(mask)),
|
||||
[frameMasks, selectedMaskIds],
|
||||
);
|
||||
const selectedMaskPoints = React.useMemo(
|
||||
() => segmentationToPoints(selectedMask?.segmentation, selectedPolygonIndex),
|
||||
[selectedMask?.segmentation, selectedPolygonIndex],
|
||||
);
|
||||
const savedMaskCount = frameMasks.filter((mask) => mask.saveStatus === 'saved' || mask.saved).length;
|
||||
const draftMaskCount = frameMasks.filter((mask) => !mask.annotationId).length;
|
||||
const dirtyMaskCount = frameMasks.filter((mask) => mask.saveStatus === 'dirty').length;
|
||||
@@ -55,6 +242,25 @@ export function CanvasArea({ activeTool, frame, onClearMasks }: CanvasAreaProps)
|
||||
return () => window.removeEventListener('resize', handleResize);
|
||||
}, []);
|
||||
|
||||
useEffect(() => {
|
||||
setManualStart(null);
|
||||
setManualCurrent(null);
|
||||
setPolygonPoints([]);
|
||||
setSelectedMaskId(null);
|
||||
setSelectedMaskIds([]);
|
||||
setSelectedPolygonIndex(0);
|
||||
setSelectedVertexIndex(null);
|
||||
}, [effectiveTool, frame?.id]);
|
||||
|
||||
useEffect(() => {
|
||||
if (selectedMaskId && !frameMasks.some((mask) => mask.id === selectedMaskId)) {
|
||||
setSelectedMaskId(null);
|
||||
setSelectedMaskIds([]);
|
||||
setSelectedPolygonIndex(0);
|
||||
setSelectedVertexIndex(null);
|
||||
}
|
||||
}, [frameMasks, selectedMaskId]);
|
||||
|
||||
const handleWheel = (e: any) => {
|
||||
e.evt.preventDefault();
|
||||
const scaleBy = 1.1;
|
||||
@@ -74,6 +280,50 @@ export function CanvasArea({ activeTool, frame, onClearMasks }: CanvasAreaProps)
|
||||
});
|
||||
};
|
||||
|
||||
const stagePoint = (e: any): CanvasPoint | null => {
|
||||
const stage = e.target.getStage();
|
||||
const relPos = stage?.getRelativePointerPosition();
|
||||
if (!relPos) return null;
|
||||
const imageWidth = frame?.width || image?.naturalWidth || image?.width || stageSize.width;
|
||||
const imageHeight = frame?.height || image?.naturalHeight || image?.height || stageSize.height;
|
||||
return {
|
||||
x: clamp(relPos.x, 0, imageWidth),
|
||||
y: clamp(relPos.y, 0, imageHeight),
|
||||
};
|
||||
};
|
||||
|
||||
const createManualMask = useCallback((shape: string, polygon: CanvasPoint[]) => {
|
||||
if (!frame?.id || polygon.length < 3) return;
|
||||
const area = polygonArea(polygon);
|
||||
if (area <= 1) return;
|
||||
const color = activeClass?.color || '#06b6d4';
|
||||
const label = activeClass?.name || `手工${shape}`;
|
||||
const mask: Mask = {
|
||||
id: `manual-${frame.id}-${shape}-${Date.now()}`,
|
||||
frameId: frame.id,
|
||||
templateId: activeTemplateId || undefined,
|
||||
classId: activeClass?.id,
|
||||
className: activeClass?.name,
|
||||
classZIndex: activeClass?.zIndex,
|
||||
saveStatus: 'draft',
|
||||
saved: false,
|
||||
pathData: polygonPath(polygon),
|
||||
label,
|
||||
color,
|
||||
segmentation: polygonSegmentation(polygon),
|
||||
points: shape === '点区域'
|
||||
? [[
|
||||
polygon.reduce((sum, point) => sum + point.x, 0) / polygon.length,
|
||||
polygon.reduce((sum, point) => sum + point.y, 0) / polygon.length,
|
||||
]]
|
||||
: undefined,
|
||||
bbox: polygonBbox(polygon),
|
||||
area,
|
||||
metadata: { source: 'manual', shape },
|
||||
};
|
||||
addMask(mask);
|
||||
}, [activeClass, activeTemplateId, addMask, frame?.id]);
|
||||
|
||||
const handleMouseMove = (e: any) => {
|
||||
const stage = e.target.getStage();
|
||||
if (!stage) return;
|
||||
@@ -90,6 +340,13 @@ export function CanvasArea({ activeTool, frame, onClearMasks }: CanvasAreaProps)
|
||||
setBoxCurrent({ x: relPos.x, y: relPos.y });
|
||||
}
|
||||
}
|
||||
|
||||
if (manualStart && DRAG_MANUAL_TOOLS.has(effectiveTool)) {
|
||||
const pos = stage.getRelativePointerPosition();
|
||||
if (pos) {
|
||||
setManualCurrent({ x: pos.x, y: pos.y });
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
const runInference = useCallback(async (promptPoints?: typeof points, promptBox?: { x1: number, y1: number, x2: number, y2: number }) => {
|
||||
@@ -132,6 +389,7 @@ export function CanvasArea({ activeTool, frame, onClearMasks }: CanvasAreaProps)
|
||||
label,
|
||||
color,
|
||||
segmentation: m.segmentation,
|
||||
points: promptPoints?.filter((p) => p.type === 'pos').map((p) => [p.x, p.y]),
|
||||
bbox: m.bbox,
|
||||
area: m.area,
|
||||
});
|
||||
@@ -170,6 +428,15 @@ export function CanvasArea({ activeTool, frame, onClearMasks }: CanvasAreaProps)
|
||||
};
|
||||
|
||||
const handleStageMouseDown = (e: any) => {
|
||||
if (DRAG_MANUAL_TOOLS.has(effectiveTool)) {
|
||||
const pos = stagePoint(e);
|
||||
if (pos) {
|
||||
setManualStart(pos);
|
||||
setManualCurrent(pos);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
if (effectiveTool === 'box_select') {
|
||||
const stage = e.target.getStage();
|
||||
const pos = stage.getRelativePointerPosition();
|
||||
@@ -181,6 +448,27 @@ export function CanvasArea({ activeTool, frame, onClearMasks }: CanvasAreaProps)
|
||||
};
|
||||
|
||||
const handleStageMouseUp = (e: any) => {
|
||||
if (DRAG_MANUAL_TOOLS.has(effectiveTool) && manualStart) {
|
||||
const end = stagePoint(e) || manualCurrent || manualStart;
|
||||
const width = Math.abs(end.x - manualStart.x);
|
||||
const height = Math.abs(end.y - manualStart.y);
|
||||
const distance = Math.hypot(width, height);
|
||||
|
||||
if (effectiveTool === 'create_rectangle' && width > 4 && height > 4) {
|
||||
createManualMask('矩形', rectanglePoints(manualStart, end));
|
||||
}
|
||||
if (effectiveTool === 'create_circle' && width > 4 && height > 4) {
|
||||
createManualMask('圆形', circlePoints(manualStart, end));
|
||||
}
|
||||
if (effectiveTool === 'create_line' && distance > 4) {
|
||||
createManualMask('线段', lineRegion(manualStart, end));
|
||||
}
|
||||
|
||||
setManualStart(null);
|
||||
setManualCurrent(null);
|
||||
return;
|
||||
}
|
||||
|
||||
if (effectiveTool === 'box_select' && boxStart && boxCurrent) {
|
||||
const x1 = Math.min(boxStart.x, boxCurrent.x);
|
||||
const y1 = Math.min(boxStart.y, boxCurrent.y);
|
||||
@@ -199,12 +487,32 @@ export function CanvasArea({ activeTool, frame, onClearMasks }: CanvasAreaProps)
|
||||
const handleStageClick = (e: any) => {
|
||||
if (effectiveTool === 'move') return;
|
||||
if (effectiveTool === 'box_select') return; // handled by mouseup
|
||||
if (DRAG_MANUAL_TOOLS.has(effectiveTool)) return;
|
||||
|
||||
if (effectiveTool === POINT_TOOL) {
|
||||
const pos = stagePoint(e);
|
||||
if (pos) {
|
||||
createManualMask('点区域', pointRegion(pos));
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
if (effectiveTool === POLYGON_TOOL) {
|
||||
const pos = stagePoint(e);
|
||||
if (pos) {
|
||||
setPolygonPoints((current) => [...current, pos]);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
if (effectiveTool === 'point_pos' || effectiveTool === 'point_neg') {
|
||||
const stage = e.target.getStage();
|
||||
const pos = stage.getRelativePointerPosition();
|
||||
if (pos) {
|
||||
const newPoints = [...points, { x: pos.x, y: pos.y, type: effectiveTool === 'point_pos' ? 'pos' : 'neg' as 'pos'|'neg' }];
|
||||
const newPoints = [
|
||||
...points,
|
||||
{ x: pos.x, y: pos.y, type: (effectiveTool === 'point_pos' ? 'pos' : 'neg') as 'pos' | 'neg' },
|
||||
];
|
||||
setPoints(newPoints);
|
||||
// Auto-trigger inference after point selection
|
||||
runInference(newPoints);
|
||||
@@ -212,6 +520,74 @@ export function CanvasArea({ activeTool, frame, onClearMasks }: CanvasAreaProps)
|
||||
}
|
||||
};
|
||||
|
||||
const updatePolygonMask = useCallback((mask: Mask, nextPoints: CanvasPoint[], polygonIndex = 0) => {
|
||||
if (nextPoints.length < 3) return;
|
||||
const nextSegmentation = [...(mask.segmentation || [])];
|
||||
nextSegmentation[polygonIndex] = nextPoints.flatMap((point) => [point.x, point.y]);
|
||||
const bbox = segmentationBbox(nextSegmentation) || polygonBbox(nextPoints);
|
||||
updateMask(mask.id, {
|
||||
pathData: segmentationPath(nextSegmentation),
|
||||
segmentation: nextSegmentation,
|
||||
bbox,
|
||||
area: segmentationArea(nextSegmentation),
|
||||
saveStatus: mask.annotationId ? 'dirty' : 'draft',
|
||||
saved: mask.annotationId ? false : mask.saved,
|
||||
});
|
||||
}, [updateMask]);
|
||||
|
||||
const updateMaskFromSegmentation = useCallback((mask: Mask, segmentation: number[][]): Mask => {
|
||||
const bbox = segmentationBbox(segmentation);
|
||||
return {
|
||||
...mask,
|
||||
pathData: segmentationPath(segmentation),
|
||||
segmentation,
|
||||
bbox,
|
||||
area: segmentationArea(segmentation),
|
||||
saveStatus: mask.annotationId ? 'dirty' : 'draft',
|
||||
saved: mask.annotationId ? false : mask.saved,
|
||||
};
|
||||
}, []);
|
||||
|
||||
useEffect(() => {
|
||||
const handleKeyDown = (event: KeyboardEvent) => {
|
||||
const key = event.key.toLowerCase();
|
||||
if ((event.metaKey || event.ctrlKey) && key === 'z') {
|
||||
event.preventDefault();
|
||||
if (event.shiftKey) redoMasks();
|
||||
else undoMasks();
|
||||
return;
|
||||
}
|
||||
if ((event.metaKey || event.ctrlKey) && key === 'y') {
|
||||
event.preventDefault();
|
||||
redoMasks();
|
||||
return;
|
||||
}
|
||||
if ((event.key === 'Delete' || event.key === 'Backspace') && selectedMask && selectedVertexIndex !== null) {
|
||||
const currentPoints = segmentationToPoints(selectedMask.segmentation, selectedPolygonIndex);
|
||||
if (currentPoints.length > 3) {
|
||||
event.preventDefault();
|
||||
const nextPoints = currentPoints.filter((_, index) => index !== selectedVertexIndex);
|
||||
updatePolygonMask(selectedMask, nextPoints, selectedPolygonIndex);
|
||||
setSelectedVertexIndex(null);
|
||||
}
|
||||
return;
|
||||
}
|
||||
if (effectiveTool !== POLYGON_TOOL) return;
|
||||
if (event.key === 'Enter' && polygonPoints.length >= 3) {
|
||||
event.preventDefault();
|
||||
createManualMask('多边形', polygonPoints);
|
||||
setPolygonPoints([]);
|
||||
}
|
||||
if (event.key === 'Escape') {
|
||||
event.preventDefault();
|
||||
setPolygonPoints([]);
|
||||
}
|
||||
};
|
||||
|
||||
window.addEventListener('keydown', handleKeyDown);
|
||||
return () => window.removeEventListener('keydown', handleKeyDown);
|
||||
}, [createManualMask, effectiveTool, polygonPoints, redoMasks, selectedMask, selectedPolygonIndex, selectedVertexIndex, undoMasks, updatePolygonMask]);
|
||||
|
||||
const boxRect = React.useMemo(() => {
|
||||
if (!boxStart || !boxCurrent) return null;
|
||||
const x = Math.min(boxStart.x, boxCurrent.x);
|
||||
@@ -221,6 +597,132 @@ export function CanvasArea({ activeTool, frame, onClearMasks }: CanvasAreaProps)
|
||||
return { x, y, width, height };
|
||||
}, [boxStart, boxCurrent]);
|
||||
|
||||
const manualPreviewPath = React.useMemo(() => {
|
||||
if (manualStart && manualCurrent) {
|
||||
if (effectiveTool === 'create_rectangle') return polygonPath(rectanglePoints(manualStart, manualCurrent));
|
||||
if (effectiveTool === 'create_circle') return polygonPath(circlePoints(manualStart, manualCurrent));
|
||||
if (effectiveTool === 'create_line') return polygonPath(lineRegion(manualStart, manualCurrent));
|
||||
}
|
||||
if (effectiveTool === POLYGON_TOOL && polygonPoints.length > 0) {
|
||||
const previewPoints = [...polygonPoints, cursorPos];
|
||||
return polygonPath(previewPoints);
|
||||
}
|
||||
return null;
|
||||
}, [cursorPos, effectiveTool, manualCurrent, manualStart, polygonPoints]);
|
||||
|
||||
const handleSeedPointDragEnd = (mask: Mask, pointIndex: number, event: any) => {
|
||||
const x = event.target.x();
|
||||
const y = event.target.y();
|
||||
const nextPoints = [...(mask.points || [])];
|
||||
nextPoints[pointIndex] = [x, y];
|
||||
updateMask(mask.id, {
|
||||
points: nextPoints,
|
||||
saveStatus: mask.annotationId ? 'dirty' : 'draft',
|
||||
saved: mask.annotationId ? false : mask.saved,
|
||||
});
|
||||
};
|
||||
|
||||
const handleMaskSelect = (mask: Mask, event: any, polygonIndex = 0) => {
|
||||
event.cancelBubble = true;
|
||||
if (BOOLEAN_TOOLS.has(effectiveTool)) {
|
||||
setSelectedMaskIds((current) => (
|
||||
current.includes(mask.id)
|
||||
? current.filter((id) => id !== mask.id)
|
||||
: [...current, mask.id]
|
||||
));
|
||||
setSelectedMaskId(mask.id);
|
||||
setSelectedPolygonIndex(polygonIndex);
|
||||
setSelectedVertexIndex(null);
|
||||
return;
|
||||
}
|
||||
setSelectedMaskId(mask.id);
|
||||
setSelectedMaskIds([mask.id]);
|
||||
setSelectedPolygonIndex(polygonIndex);
|
||||
setSelectedVertexIndex(null);
|
||||
};
|
||||
|
||||
const handleVertexDragEnd = (mask: Mask, vertexIndex: number, event: any) => {
|
||||
const imageWidth = frame?.width || image?.naturalWidth || image?.width || stageSize.width;
|
||||
const imageHeight = frame?.height || image?.naturalHeight || image?.height || stageSize.height;
|
||||
const currentPoints = segmentationToPoints(mask.segmentation, selectedPolygonIndex);
|
||||
if (!currentPoints[vertexIndex]) return;
|
||||
const nextPoints = currentPoints.map((point, index) => (
|
||||
index === vertexIndex
|
||||
? {
|
||||
x: clamp(event.target.x(), 0, imageWidth),
|
||||
y: clamp(event.target.y(), 0, imageHeight),
|
||||
}
|
||||
: point
|
||||
));
|
||||
setSelectedMaskId(mask.id);
|
||||
setSelectedVertexIndex(vertexIndex);
|
||||
updatePolygonMask(mask, nextPoints, selectedPolygonIndex);
|
||||
};
|
||||
|
||||
const handleEdgeInsert = (mask: Mask, edgeIndex: number, event: any) => {
|
||||
event.cancelBubble = true;
|
||||
const currentPoints = segmentationToPoints(mask.segmentation, selectedPolygonIndex);
|
||||
const start = currentPoints[edgeIndex];
|
||||
const end = currentPoints[(edgeIndex + 1) % currentPoints.length];
|
||||
if (!start || !end) return;
|
||||
const inserted = { x: (start.x + end.x) / 2, y: (start.y + end.y) / 2 };
|
||||
const nextPoints = [
|
||||
...currentPoints.slice(0, edgeIndex + 1),
|
||||
inserted,
|
||||
...currentPoints.slice(edgeIndex + 1),
|
||||
];
|
||||
setSelectedMaskId(mask.id);
|
||||
setSelectedVertexIndex(edgeIndex + 1);
|
||||
updatePolygonMask(mask, nextPoints, selectedPolygonIndex);
|
||||
};
|
||||
|
||||
const handleBooleanOperation = async () => {
|
||||
if (!frame || booleanSelectedMasks.length < 2) return;
|
||||
const primary = booleanSelectedMasks[0];
|
||||
const primaryGeometry = maskToMultiPolygon(primary);
|
||||
if (!primaryGeometry) return;
|
||||
|
||||
const clipGeometries = booleanSelectedMasks
|
||||
.slice(1)
|
||||
.map(maskToMultiPolygon)
|
||||
.filter((geometry): geometry is MultiPolygon => Boolean(geometry));
|
||||
if (clipGeometries.length === 0) return;
|
||||
|
||||
const resultGeometry = effectiveTool === 'area_merge'
|
||||
? polygonClipping.union(primaryGeometry, ...clipGeometries)
|
||||
: polygonClipping.difference(primaryGeometry, ...clipGeometries);
|
||||
const resultSegmentation = multiPolygonToSegmentation(resultGeometry);
|
||||
|
||||
if (resultSegmentation.length === 0) {
|
||||
const deleteIds = primary.annotationId ? [primary.annotationId] : [];
|
||||
setMasks(masks.filter((mask) => mask.id !== primary.id));
|
||||
if (deleteIds.length > 0) await onDeleteMaskAnnotations?.(deleteIds);
|
||||
setSelectedMaskId(null);
|
||||
setSelectedMaskIds([]);
|
||||
setSelectedVertexIndex(null);
|
||||
return;
|
||||
}
|
||||
|
||||
const nextPrimary = updateMaskFromSegmentation(primary, resultSegmentation);
|
||||
const secondaryIds = effectiveTool === 'area_merge'
|
||||
? new Set(booleanSelectedMasks.slice(1).map((mask) => mask.id))
|
||||
: new Set<string>();
|
||||
const secondaryAnnotationIds = effectiveTool === 'area_merge'
|
||||
? booleanSelectedMasks
|
||||
.slice(1)
|
||||
.map((mask) => mask.annotationId)
|
||||
.filter((annotationId): annotationId is string => Boolean(annotationId))
|
||||
: [];
|
||||
|
||||
setMasks(masks
|
||||
.filter((mask) => !secondaryIds.has(mask.id))
|
||||
.map((mask) => (mask.id === primary.id ? nextPrimary : mask)));
|
||||
if (secondaryAnnotationIds.length > 0) await onDeleteMaskAnnotations?.(secondaryAnnotationIds);
|
||||
setSelectedMaskId(primary.id);
|
||||
setSelectedMaskIds([primary.id]);
|
||||
setSelectedVertexIndex(null);
|
||||
};
|
||||
|
||||
return (
|
||||
<div ref={containerRef} className="w-full h-full relative cursor-crosshair overflow-hidden rounded-sm">
|
||||
{isInferencing && (
|
||||
@@ -257,13 +759,18 @@ export function CanvasArea({ activeTool, frame, onClearMasks }: CanvasAreaProps)
|
||||
|
||||
{/* AI Returned Masks */}
|
||||
{frameMasks.map((mask) => (
|
||||
<Group key={mask.id} opacity={0.5}>
|
||||
<Path
|
||||
data={mask.pathData}
|
||||
fill={mask.color}
|
||||
stroke={mask.color}
|
||||
strokeWidth={1 / scale}
|
||||
/>
|
||||
<Group key={mask.id} opacity={selectedMaskIds.includes(mask.id) ? 0.65 : 0.5}>
|
||||
{(mask.segmentation && mask.segmentation.length > 0 ? mask.segmentation : [undefined]).map((_, polygonIndex) => (
|
||||
<Path
|
||||
key={`${mask.id}-polygon-${polygonIndex}`}
|
||||
data={mask.segmentation ? segmentationPolygonPath(mask.segmentation, polygonIndex) : mask.pathData}
|
||||
fill={mask.color}
|
||||
stroke={mask.color}
|
||||
strokeWidth={(selectedMaskIds.includes(mask.id) ? 2 : 1) / scale}
|
||||
onClick={(event: any) => handleMaskSelect(mask, event, polygonIndex)}
|
||||
onTap={(event: any) => handleMaskSelect(mask, event, polygonIndex)}
|
||||
/>
|
||||
))}
|
||||
</Group>
|
||||
))}
|
||||
|
||||
@@ -281,6 +788,86 @@ export function CanvasArea({ activeTool, frame, onClearMasks }: CanvasAreaProps)
|
||||
/>
|
||||
)}
|
||||
|
||||
{/* Manual shape preview */}
|
||||
{manualPreviewPath && (
|
||||
<Path
|
||||
data={manualPreviewPath}
|
||||
fill="rgba(34, 211, 238, 0.12)"
|
||||
stroke="#22d3ee"
|
||||
strokeWidth={2 / scale}
|
||||
dash={[5 / scale, 5 / scale]}
|
||||
/>
|
||||
)}
|
||||
|
||||
{polygonPoints.map((point, index) => (
|
||||
<Circle
|
||||
key={`poly-point-${index}`}
|
||||
x={point.x}
|
||||
y={point.y}
|
||||
radius={4 / scale}
|
||||
fill="#22d3ee"
|
||||
stroke="#ffffff"
|
||||
strokeWidth={1 / scale}
|
||||
/>
|
||||
))}
|
||||
|
||||
{/* Imported GT seed points / editable point regions */}
|
||||
{frameMasks.flatMap((mask) => (mask.points || []).map(([x, y], index) => (
|
||||
<Group key={`${mask.id}-seed-${index}`} x={x} y={y}>
|
||||
<Circle
|
||||
radius={5 / scale}
|
||||
fill="#facc15"
|
||||
stroke="#111827"
|
||||
strokeWidth={2 / scale}
|
||||
draggable
|
||||
onDragEnd={(event: any) => handleSeedPointDragEnd(mask, index, event)}
|
||||
/>
|
||||
<Circle radius={1.5 / scale} fill="#111827" />
|
||||
</Group>
|
||||
)))}
|
||||
|
||||
{/* Polygon edge insertion handles */}
|
||||
{selectedMask && selectedMaskPoints.map((point, index) => {
|
||||
const next = selectedMaskPoints[(index + 1) % selectedMaskPoints.length];
|
||||
if (!next) return null;
|
||||
return (
|
||||
<Circle
|
||||
key={`${selectedMask.id}-edge-${selectedPolygonIndex}-${index}`}
|
||||
x={(point.x + next.x) / 2}
|
||||
y={(point.y + next.y) / 2}
|
||||
radius={3.5 / scale}
|
||||
fill="#22d3ee"
|
||||
stroke="#111827"
|
||||
strokeWidth={1.5 / scale}
|
||||
onClick={(event: any) => handleEdgeInsert(selectedMask, index, event)}
|
||||
onTap={(event: any) => handleEdgeInsert(selectedMask, index, event)}
|
||||
/>
|
||||
);
|
||||
})}
|
||||
|
||||
{/* Polygon vertex editor */}
|
||||
{selectedMask && selectedMaskPoints.map((point, index) => (
|
||||
<Circle
|
||||
key={`${selectedMask.id}-vertex-${selectedPolygonIndex}-${index}`}
|
||||
x={point.x}
|
||||
y={point.y}
|
||||
radius={(selectedVertexIndex === index ? 6 : 4.5) / scale}
|
||||
fill={selectedVertexIndex === index ? '#22d3ee' : '#ffffff'}
|
||||
stroke={selectedMask.color}
|
||||
strokeWidth={2 / scale}
|
||||
draggable
|
||||
onClick={(event: any) => {
|
||||
event.cancelBubble = true;
|
||||
setSelectedVertexIndex(index);
|
||||
}}
|
||||
onTap={(event: any) => {
|
||||
event.cancelBubble = true;
|
||||
setSelectedVertexIndex(index);
|
||||
}}
|
||||
onDragEnd={(event: any) => handleVertexDragEnd(selectedMask, index, event)}
|
||||
/>
|
||||
))}
|
||||
|
||||
{/* AI Prompts Point Regions */}
|
||||
{points.map((p, i) => (
|
||||
<Group key={i} x={p.x} y={p.y}>
|
||||
@@ -313,6 +900,14 @@ export function CanvasArea({ activeTool, frame, onClearMasks }: CanvasAreaProps)
|
||||
|
||||
{frameMasks.length > 0 && (
|
||||
<div className="absolute bottom-4 right-4 flex gap-2">
|
||||
{BOOLEAN_TOOLS.has(effectiveTool) && booleanSelectedMasks.length >= 2 && (
|
||||
<button
|
||||
onClick={handleBooleanOperation}
|
||||
className="text-xs bg-emerald-500/10 hover:bg-emerald-500/20 text-emerald-300 border border-emerald-500/20 px-3 py-1.5 rounded transition-colors"
|
||||
>
|
||||
{effectiveTool === 'area_merge' ? '合并选中' : '从主区域去除'}
|
||||
</button>
|
||||
)}
|
||||
{activeClass && (
|
||||
<button
|
||||
onClick={handleApplyActiveClass}
|
||||
|
||||
@@ -1,9 +1,12 @@
|
||||
import { act, render, screen, waitFor } from '@testing-library/react';
|
||||
import { act, fireEvent, render, screen, waitFor } from '@testing-library/react';
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest';
|
||||
import { Dashboard } from './Dashboard';
|
||||
|
||||
const apiMock = vi.hoisted(() => ({
|
||||
getDashboardOverview: vi.fn(),
|
||||
cancelTask: vi.fn(),
|
||||
retryTask: vi.fn(),
|
||||
getTask: vi.fn(),
|
||||
}));
|
||||
|
||||
const wsMock = vi.hoisted(() => {
|
||||
@@ -31,6 +34,9 @@ vi.mock('../lib/websocket', () => ({
|
||||
|
||||
vi.mock('../lib/api', () => ({
|
||||
getDashboardOverview: apiMock.getDashboardOverview,
|
||||
cancelTask: apiMock.cancelTask,
|
||||
retryTask: apiMock.retryTask,
|
||||
getTask: apiMock.getTask,
|
||||
}));
|
||||
|
||||
describe('Dashboard', () => {
|
||||
@@ -55,6 +61,8 @@ describe('Dashboard', () => {
|
||||
name: '真实项目.mp4',
|
||||
progress: 60,
|
||||
status: 'pending',
|
||||
raw_status: 'running',
|
||||
error: null,
|
||||
frame_count: 10,
|
||||
updated_at: '2026-05-01T00:00:00Z',
|
||||
},
|
||||
@@ -112,4 +120,100 @@ describe('Dashboard', () => {
|
||||
await waitFor(() => expect(screen.getByText('Progress stream active')).toBeInTheDocument());
|
||||
expect(screen.getByText('解析完成: done.mp4')).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it('cancels, retries, and opens task failure details', async () => {
|
||||
apiMock.getDashboardOverview.mockResolvedValueOnce({
|
||||
summary: {
|
||||
project_count: 1,
|
||||
parsing_task_count: 1,
|
||||
annotation_count: 0,
|
||||
frame_count: 0,
|
||||
template_count: 0,
|
||||
system_load_percent: 5,
|
||||
},
|
||||
tasks: [
|
||||
{
|
||||
id: 'task-10',
|
||||
task_id: 10,
|
||||
project_id: 1,
|
||||
name: 'running.mp4',
|
||||
progress: 30,
|
||||
status: '正在下载媒体文件',
|
||||
raw_status: 'running',
|
||||
error: null,
|
||||
frame_count: 0,
|
||||
updated_at: '2026-05-01T00:00:00Z',
|
||||
},
|
||||
{
|
||||
id: 'task-11',
|
||||
task_id: 11,
|
||||
project_id: 1,
|
||||
name: 'failed.mp4',
|
||||
progress: 100,
|
||||
status: '解析失败',
|
||||
raw_status: 'failed',
|
||||
error: 'ffmpeg failed',
|
||||
frame_count: 0,
|
||||
updated_at: '2026-05-01T00:01:00Z',
|
||||
},
|
||||
],
|
||||
activity: [],
|
||||
});
|
||||
apiMock.cancelTask.mockResolvedValueOnce({
|
||||
id: 10,
|
||||
task_type: 'parse_video',
|
||||
status: 'cancelled',
|
||||
progress: 100,
|
||||
message: '任务已取消',
|
||||
project_id: 1,
|
||||
error: 'Cancelled by user',
|
||||
result: null,
|
||||
payload: { source_type: 'video' },
|
||||
created_at: 'created',
|
||||
updated_at: 'updated',
|
||||
});
|
||||
apiMock.retryTask.mockResolvedValueOnce({
|
||||
id: 12,
|
||||
task_type: 'parse_video',
|
||||
status: 'queued',
|
||||
progress: 0,
|
||||
message: '重试任务已入队(源任务 #11)',
|
||||
project_id: 1,
|
||||
error: null,
|
||||
result: null,
|
||||
payload: { source_type: 'video', retry_of: 11 },
|
||||
created_at: 'created',
|
||||
updated_at: 'updated',
|
||||
});
|
||||
apiMock.getTask.mockResolvedValueOnce({
|
||||
id: 11,
|
||||
task_type: 'parse_video',
|
||||
status: 'failed',
|
||||
progress: 100,
|
||||
message: '解析失败',
|
||||
project_id: 1,
|
||||
celery_task_id: 'celery-11',
|
||||
payload: { source_type: 'video' },
|
||||
result: null,
|
||||
error: 'ffmpeg failed',
|
||||
created_at: 'created',
|
||||
started_at: 'started',
|
||||
finished_at: 'finished',
|
||||
updated_at: 'updated',
|
||||
});
|
||||
|
||||
render(<Dashboard />);
|
||||
|
||||
await screen.findByText('running.mp4');
|
||||
fireEvent.click(screen.getByRole('button', { name: '取消' }));
|
||||
await waitFor(() => expect(apiMock.cancelTask).toHaveBeenCalledWith(10));
|
||||
|
||||
fireEvent.click(screen.getAllByRole('button', { name: '详情' })[1]);
|
||||
await waitFor(() => expect(apiMock.getTask).toHaveBeenCalledWith(11));
|
||||
expect(await screen.findByText('任务详情 #11')).toBeInTheDocument();
|
||||
expect(screen.getByText('ffmpeg failed')).toBeInTheDocument();
|
||||
|
||||
fireEvent.click(screen.getAllByRole('button', { name: '重试' })[1]);
|
||||
await waitFor(() => expect(apiMock.retryTask).toHaveBeenCalledWith(11));
|
||||
});
|
||||
});
|
||||
|
||||
@@ -1,8 +1,17 @@
|
||||
import React, { useState, useEffect } from 'react';
|
||||
import { Activity, Clock, Folders, CheckCircle2, Loader2 } from 'lucide-react';
|
||||
import { Activity, AlertTriangle, Clock, Folders, CheckCircle2, Info, Loader2, RotateCcw, XCircle } from 'lucide-react';
|
||||
import { progressWS, type ProgressMessage } from '../lib/websocket';
|
||||
import { cn } from '../lib/utils';
|
||||
import { getDashboardOverview, type DashboardActivity, type DashboardOverview, type DashboardTask } from '../lib/api';
|
||||
import {
|
||||
cancelTask,
|
||||
getDashboardOverview,
|
||||
getTask,
|
||||
retryTask,
|
||||
type DashboardActivity,
|
||||
type DashboardOverview,
|
||||
type DashboardTask,
|
||||
type ProcessingTask,
|
||||
} from '../lib/api';
|
||||
|
||||
const emptySummary: DashboardOverview['summary'] = {
|
||||
project_count: 0,
|
||||
@@ -20,6 +29,29 @@ export function Dashboard() {
|
||||
const [activityLog, setActivityLog] = useState<DashboardActivity[]>([]);
|
||||
const [isLoading, setIsLoading] = useState(true);
|
||||
const [loadError, setLoadError] = useState('');
|
||||
const [selectedTask, setSelectedTask] = useState<ProcessingTask | null>(null);
|
||||
const [taskActionMessage, setTaskActionMessage] = useState('');
|
||||
const [busyTaskId, setBusyTaskId] = useState<string | null>(null);
|
||||
|
||||
const taskFromProcessingTask = (task: ProcessingTask, name = `任务 ${task.id}`): DashboardTask => ({
|
||||
id: `task-${task.id}`,
|
||||
task_id: task.id,
|
||||
project_id: task.project_id ?? 0,
|
||||
name,
|
||||
progress: task.progress,
|
||||
status: task.message || task.status,
|
||||
raw_status: task.status,
|
||||
error: task.error,
|
||||
frame_count: Number(task.result?.frames_extracted || 0),
|
||||
updated_at: task.updated_at,
|
||||
});
|
||||
|
||||
const prependActivity = (message: string, project = '系统') => {
|
||||
setActivityLog((prev) => [
|
||||
{ id: `task-action-${Date.now()}`, kind: 'task', time: new Date().toISOString(), message, project },
|
||||
...prev.slice(0, 9),
|
||||
]);
|
||||
};
|
||||
|
||||
useEffect(() => {
|
||||
let cancelled = false;
|
||||
@@ -90,6 +122,8 @@ export function Dashboard() {
|
||||
name: taskTitle(data),
|
||||
progress: data.progress ?? 0,
|
||||
status: data.status ?? '处理中',
|
||||
raw_status: 'running',
|
||||
error: data.error,
|
||||
frame_count: 0,
|
||||
updated_at: new Date().toISOString(),
|
||||
},
|
||||
@@ -100,7 +134,7 @@ export function Dashboard() {
|
||||
if (data.type === 'complete' && data.taskId) {
|
||||
setTasks((prev) =>
|
||||
prev.map((t) =>
|
||||
t.id === data.taskId ? { ...t, progress: 100, status: '已完成' } : t
|
||||
t.id === data.taskId ? { ...t, progress: 100, status: '已完成', raw_status: 'success' } : t
|
||||
)
|
||||
);
|
||||
setActivityLog((prev) => [
|
||||
@@ -109,10 +143,26 @@ export function Dashboard() {
|
||||
]);
|
||||
}
|
||||
|
||||
if (data.type === 'cancelled' && data.taskId) {
|
||||
setTasks((prev) =>
|
||||
prev.map((t) =>
|
||||
t.id === data.taskId
|
||||
? { ...t, progress: 100, status: data.message || '任务已取消', raw_status: 'cancelled', error: data.error }
|
||||
: t
|
||||
)
|
||||
);
|
||||
setActivityLog((prev) => [
|
||||
{ id: `ws-cancelled-${Date.now()}`, kind: 'websocket', time: new Date().toISOString(), message: data.message || `任务已取消: ${taskTitle(data)}`, project: data.projectName || '系统' },
|
||||
...prev.slice(0, 9),
|
||||
]);
|
||||
}
|
||||
|
||||
if (data.type === 'error' && data.taskId) {
|
||||
setTasks((prev) =>
|
||||
prev.map((t) =>
|
||||
t.id === data.taskId ? { ...t, progress: data.progress ?? t.progress, status: `错误: ${data.error || data.message || '未知错误'}` } : t
|
||||
t.id === data.taskId
|
||||
? { ...t, progress: data.progress ?? t.progress, status: `错误: ${data.error || data.message || '未知错误'}`, raw_status: 'failed', error: data.error }
|
||||
: t
|
||||
)
|
||||
);
|
||||
setActivityLog((prev) => [
|
||||
@@ -160,6 +210,65 @@ export function Dashboard() {
|
||||
});
|
||||
}
|
||||
|
||||
const taskRawStatus = (task: DashboardTask): string => task.raw_status || (
|
||||
task.status.includes('取消') ? 'cancelled'
|
||||
: task.status.includes('失败') || task.status.includes('错误') ? 'failed'
|
||||
: task.progress >= 100 ? 'success'
|
||||
: 'running'
|
||||
);
|
||||
|
||||
const canCancel = (task: DashboardTask): boolean => ['queued', 'running'].includes(taskRawStatus(task)) && Boolean(task.task_id);
|
||||
const canRetry = (task: DashboardTask): boolean => ['failed', 'cancelled'].includes(taskRawStatus(task)) && Boolean(task.task_id);
|
||||
|
||||
const handleCancelTask = async (task: DashboardTask) => {
|
||||
if (!task.task_id) return;
|
||||
setBusyTaskId(task.id);
|
||||
setTaskActionMessage('');
|
||||
try {
|
||||
const updated = await cancelTask(task.task_id);
|
||||
setTasks((prev) => prev.map((item) => (
|
||||
item.id === task.id ? taskFromProcessingTask(updated, task.name) : item
|
||||
)));
|
||||
prependActivity(`任务已取消 #${updated.id}`, task.name);
|
||||
} catch (err) {
|
||||
console.error('Cancel task failed:', err);
|
||||
setTaskActionMessage('任务取消失败,请检查后端服务');
|
||||
} finally {
|
||||
setBusyTaskId(null);
|
||||
}
|
||||
};
|
||||
|
||||
const handleRetryTask = async (task: DashboardTask) => {
|
||||
if (!task.task_id) return;
|
||||
setBusyTaskId(task.id);
|
||||
setTaskActionMessage('');
|
||||
try {
|
||||
const retried = await retryTask(task.task_id);
|
||||
const dashboardTask = taskFromProcessingTask(retried, task.name);
|
||||
setTasks((prev) => [dashboardTask, ...prev.filter((item) => item.id !== dashboardTask.id)]);
|
||||
prependActivity(`重试任务已入队 #${retried.id}`, task.name);
|
||||
} catch (err) {
|
||||
console.error('Retry task failed:', err);
|
||||
setTaskActionMessage('任务重试失败,请检查后端服务');
|
||||
} finally {
|
||||
setBusyTaskId(null);
|
||||
}
|
||||
};
|
||||
|
||||
const handleOpenTaskDetail = async (task: DashboardTask) => {
|
||||
if (!task.task_id) return;
|
||||
setBusyTaskId(task.id);
|
||||
setTaskActionMessage('');
|
||||
try {
|
||||
setSelectedTask(await getTask(task.task_id));
|
||||
} catch (err) {
|
||||
console.error('Load task detail failed:', err);
|
||||
setTaskActionMessage('失败详情加载失败');
|
||||
} finally {
|
||||
setBusyTaskId(null);
|
||||
}
|
||||
};
|
||||
|
||||
return (
|
||||
<div className="p-8 w-full h-full overflow-y-auto bg-[#0a0a0a]">
|
||||
<header className="mb-8">
|
||||
@@ -177,6 +286,7 @@ export function Dashboard() {
|
||||
</div>
|
||||
<p className="text-gray-400 text-sm mt-1">系统全局数据吞吐状态与所有接入项目进度实时洞察驾驶舱。</p>
|
||||
{loadError && <p className="text-red-400 text-xs mt-2">{loadError}</p>}
|
||||
{taskActionMessage && <p className="text-amber-400 text-xs mt-2">{taskActionMessage}</p>}
|
||||
</header>
|
||||
|
||||
<div className="grid grid-cols-1 md:grid-cols-2 lg:grid-cols-4 gap-4 mb-8">
|
||||
@@ -213,16 +323,47 @@ export function Dashboard() {
|
||||
<div className="h-full bg-gradient-to-r from-cyan-600 to-cyan-400 rounded-full transition-all duration-500" style={{ width: `${task.progress}%` }} />
|
||||
</div>
|
||||
<div className="text-xs text-gray-500 flex items-center gap-2">
|
||||
{task.status === '已完成' || task.progress >= 100 ? (
|
||||
{taskRawStatus(task) === 'success' || task.status === '已完成' ? (
|
||||
<CheckCircle2 size={12} className="text-emerald-400" />
|
||||
) : task.status.includes('错误') ? (
|
||||
<span className="text-red-400">●</span>
|
||||
) : taskRawStatus(task) === 'failed' ? (
|
||||
<AlertTriangle size={12} className="text-red-400" />
|
||||
) : taskRawStatus(task) === 'cancelled' ? (
|
||||
<XCircle size={12} className="text-amber-400" />
|
||||
) : (
|
||||
<Loader2 size={12} className="text-cyan-400 animate-spin" />
|
||||
)}
|
||||
{task.status}
|
||||
<span className="text-gray-600">帧: {task.frame_count}</span>
|
||||
</div>
|
||||
<div className="mt-3 flex flex-wrap items-center gap-2">
|
||||
{canCancel(task) && (
|
||||
<button
|
||||
onClick={() => handleCancelTask(task)}
|
||||
disabled={busyTaskId === task.id}
|
||||
className="inline-flex items-center gap-1 rounded border border-red-500/20 bg-red-500/10 px-2 py-1 text-[11px] text-red-300 hover:bg-red-500/20 disabled:opacity-40 disabled:cursor-not-allowed"
|
||||
>
|
||||
<XCircle size={12} /> 取消
|
||||
</button>
|
||||
)}
|
||||
{canRetry(task) && (
|
||||
<button
|
||||
onClick={() => handleRetryTask(task)}
|
||||
disabled={busyTaskId === task.id}
|
||||
className="inline-flex items-center gap-1 rounded border border-cyan-500/20 bg-cyan-500/10 px-2 py-1 text-[11px] text-cyan-300 hover:bg-cyan-500/20 disabled:opacity-40 disabled:cursor-not-allowed"
|
||||
>
|
||||
<RotateCcw size={12} /> 重试
|
||||
</button>
|
||||
)}
|
||||
{task.task_id && (
|
||||
<button
|
||||
onClick={() => handleOpenTaskDetail(task)}
|
||||
disabled={busyTaskId === task.id}
|
||||
className="inline-flex items-center gap-1 rounded border border-white/10 bg-white/5 px-2 py-1 text-[11px] text-gray-300 hover:bg-white/10 disabled:opacity-40 disabled:cursor-not-allowed"
|
||||
>
|
||||
<Info size={12} /> 详情
|
||||
</button>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
))}
|
||||
{!isLoading && tasks.length === 0 && (
|
||||
@@ -253,6 +394,46 @@ export function Dashboard() {
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{selectedTask && (
|
||||
<div className="fixed inset-0 z-50 flex items-center justify-center bg-black/70 px-4">
|
||||
<div className="w-full max-w-2xl rounded-lg border border-white/10 bg-[#111] p-5 shadow-2xl">
|
||||
<div className="flex items-center justify-between gap-3 border-b border-white/10 pb-3">
|
||||
<div>
|
||||
<h3 className="text-sm font-semibold text-white">任务详情 #{selectedTask.id}</h3>
|
||||
<p className="mt-1 text-xs text-gray-500">{selectedTask.message || selectedTask.status}</p>
|
||||
</div>
|
||||
<button
|
||||
onClick={() => setSelectedTask(null)}
|
||||
className="rounded border border-white/10 bg-white/5 px-2 py-1 text-xs text-gray-300 hover:bg-white/10"
|
||||
>
|
||||
关闭
|
||||
</button>
|
||||
</div>
|
||||
<div className="mt-4 grid grid-cols-2 gap-3 text-xs text-gray-400">
|
||||
<div>状态: <span className="text-gray-200">{selectedTask.status}</span></div>
|
||||
<div>进度: <span className="text-gray-200">{selectedTask.progress}%</span></div>
|
||||
<div>项目 ID: <span className="text-gray-200">{selectedTask.project_id ?? '-'}</span></div>
|
||||
<div>Celery ID: <span className="text-gray-200">{selectedTask.celery_task_id || '-'}</span></div>
|
||||
<div>创建: <span className="text-gray-200">{selectedTask.created_at}</span></div>
|
||||
<div>结束: <span className="text-gray-200">{selectedTask.finished_at || '-'}</span></div>
|
||||
</div>
|
||||
{selectedTask.error && (
|
||||
<div className="mt-4 rounded border border-red-500/20 bg-red-500/10 p-3 text-xs text-red-200">
|
||||
{selectedTask.error}
|
||||
</div>
|
||||
)}
|
||||
<div className="mt-4 grid gap-3 md:grid-cols-2">
|
||||
<pre className="max-h-48 overflow-auto rounded border border-white/10 bg-[#0a0a0a] p-3 text-[11px] text-gray-300">
|
||||
{JSON.stringify(selectedTask.payload || {}, null, 2)}
|
||||
</pre>
|
||||
<pre className="max-h-48 overflow-auto rounded border border-white/10 bg-[#0a0a0a] p-3 text-[11px] text-gray-300">
|
||||
{JSON.stringify(selectedTask.result || {}, null, 2)}
|
||||
</pre>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -3,18 +3,31 @@ import { describe, expect, it, vi } from 'vitest';
|
||||
import { ToolsPalette } from './ToolsPalette';
|
||||
|
||||
describe('ToolsPalette', () => {
|
||||
it('switches tools and exposes UI-only placeholder buttons', () => {
|
||||
it('switches tools and dispatches undo/redo actions when available', () => {
|
||||
const setActiveTool = vi.fn();
|
||||
const onUndo = vi.fn();
|
||||
const onRedo = vi.fn();
|
||||
|
||||
render(<ToolsPalette activeTool="move" setActiveTool={setActiveTool} />);
|
||||
render(
|
||||
<ToolsPalette
|
||||
activeTool="move"
|
||||
setActiveTool={setActiveTool}
|
||||
onUndo={onUndo}
|
||||
onRedo={onRedo}
|
||||
canUndo
|
||||
canRedo
|
||||
/>,
|
||||
);
|
||||
|
||||
fireEvent.click(screen.getByTitle('创建多边形 (P)'));
|
||||
fireEvent.click(screen.getByTitle('正向选点 (SAM)'));
|
||||
fireEvent.click(screen.getByTitle('撤销操作 (Ctrl+Z)'));
|
||||
fireEvent.click(screen.getByTitle('重做操作 (Ctrl+Shift+Z)'));
|
||||
|
||||
expect(setActiveTool).toHaveBeenNthCalledWith(1, 'create_polygon');
|
||||
expect(setActiveTool).toHaveBeenNthCalledWith(2, 'point_pos');
|
||||
expect(screen.getByTitle('撤销操作 (Ctrl+Z)')).toBeInTheDocument();
|
||||
expect(screen.getByTitle('重做操作 (Ctrl+Shift+Z)')).toBeInTheDocument();
|
||||
expect(onUndo).toHaveBeenCalled();
|
||||
expect(onRedo).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('switches to SAM trigger and calls the AI navigation hook', () => {
|
||||
|
||||
@@ -6,9 +6,21 @@ interface ToolsPaletteProps {
|
||||
activeTool: string;
|
||||
setActiveTool: (tool: string) => void;
|
||||
onTriggerAI?: () => void;
|
||||
onUndo?: () => void;
|
||||
onRedo?: () => void;
|
||||
canUndo?: boolean;
|
||||
canRedo?: boolean;
|
||||
}
|
||||
|
||||
export function ToolsPalette({ activeTool, setActiveTool, onTriggerAI }: ToolsPaletteProps) {
|
||||
export function ToolsPalette({
|
||||
activeTool,
|
||||
setActiveTool,
|
||||
onTriggerAI,
|
||||
onUndo,
|
||||
onRedo,
|
||||
canUndo = false,
|
||||
canRedo = false,
|
||||
}: ToolsPaletteProps) {
|
||||
const tools = [
|
||||
{ id: 'move', icon: MousePointer2, label: '拖拽 / 选择 (V)' },
|
||||
{ id: 'create_polygon', icon: Hexagon, label: '创建多边形 (P)' },
|
||||
@@ -91,10 +103,20 @@ export function ToolsPalette({ activeTool, setActiveTool, onTriggerAI }: ToolsPa
|
||||
|
||||
<div className="w-full h-px bg-white/10 my-1" />
|
||||
|
||||
<button className="w-10 h-10 rounded text-gray-500 hover:bg-white/5 hover:text-white flex items-center justify-center transition-colors" title="撤销操作 (Ctrl+Z)">
|
||||
<button
|
||||
onClick={onUndo}
|
||||
disabled={!canUndo}
|
||||
className="w-10 h-10 rounded text-gray-500 hover:bg-white/5 hover:text-white flex items-center justify-center transition-colors disabled:opacity-30 disabled:hover:bg-transparent disabled:hover:text-gray-500 disabled:cursor-not-allowed"
|
||||
title="撤销操作 (Ctrl+Z)"
|
||||
>
|
||||
<Undo size={18} />
|
||||
</button>
|
||||
<button className="w-10 h-10 rounded text-gray-500 hover:bg-white/5 hover:text-white flex items-center justify-center transition-colors" title="重做操作 (Ctrl+Shift+Z)">
|
||||
<button
|
||||
onClick={onRedo}
|
||||
disabled={!canRedo}
|
||||
className="w-10 h-10 rounded text-gray-500 hover:bg-white/5 hover:text-white flex items-center justify-center transition-colors disabled:opacity-30 disabled:hover:bg-transparent disabled:hover:text-gray-500 disabled:cursor-not-allowed"
|
||||
title="重做操作 (Ctrl+Shift+Z)"
|
||||
>
|
||||
<Redo size={18} />
|
||||
</button>
|
||||
|
||||
|
||||
@@ -14,6 +14,8 @@ const apiMock = vi.hoisted(() => ({
|
||||
updateAnnotation: vi.fn(),
|
||||
deleteAnnotation: vi.fn(),
|
||||
exportCoco: vi.fn(),
|
||||
exportMasks: vi.fn(),
|
||||
importGtMask: vi.fn(),
|
||||
annotationToMask: vi.fn(),
|
||||
buildAnnotationPayload: vi.fn(),
|
||||
getAiModelStatus: vi.fn(),
|
||||
@@ -29,6 +31,8 @@ vi.mock('../lib/api', () => ({
|
||||
updateAnnotation: apiMock.updateAnnotation,
|
||||
deleteAnnotation: apiMock.deleteAnnotation,
|
||||
exportCoco: apiMock.exportCoco,
|
||||
exportMasks: apiMock.exportMasks,
|
||||
importGtMask: apiMock.importGtMask,
|
||||
annotationToMask: apiMock.annotationToMask,
|
||||
buildAnnotationPayload: apiMock.buildAnnotationPayload,
|
||||
getAiModelStatus: apiMock.getAiModelStatus,
|
||||
@@ -256,4 +260,64 @@ describe('VideoWorkspace', () => {
|
||||
await waitFor(() => expect(apiMock.saveAnnotation).toHaveBeenCalled());
|
||||
expect(apiMock.exportCoco).toHaveBeenCalledWith('1');
|
||||
});
|
||||
|
||||
it('auto-saves pending masks before exporting PNG masks', async () => {
|
||||
apiMock.getProjectFrames.mockResolvedValueOnce([
|
||||
{ id: 10, project_id: 1, frame_index: 0, image_url: '/frame.jpg', width: 640, height: 360 },
|
||||
]);
|
||||
apiMock.buildAnnotationPayload.mockReturnValueOnce({ project_id: 1, frame_id: 10, mask_data: { polygons: [] } });
|
||||
apiMock.saveAnnotation.mockResolvedValueOnce({ id: 5 });
|
||||
apiMock.exportMasks.mockResolvedValueOnce(new Blob(['zip'], { type: 'application/zip' }));
|
||||
|
||||
render(<VideoWorkspace />);
|
||||
await waitFor(() => expect(useStore.getState().frames).toHaveLength(1));
|
||||
act(() => {
|
||||
useStore.setState({
|
||||
masks: [{
|
||||
id: 'mask-1',
|
||||
frameId: '10',
|
||||
pathData: 'M 0 0 Z',
|
||||
label: 'AI Mask',
|
||||
color: '#06b6d4',
|
||||
segmentation: [[0, 0, 10, 0, 10, 10]],
|
||||
}],
|
||||
});
|
||||
});
|
||||
|
||||
fireEvent.click(screen.getByRole('button', { name: '导出 PNG Mask ZIP' }));
|
||||
|
||||
await waitFor(() => expect(apiMock.saveAnnotation).toHaveBeenCalled());
|
||||
expect(apiMock.exportMasks).toHaveBeenCalledWith('1');
|
||||
});
|
||||
|
||||
it('imports a GT mask for the current frame and hydrates saved annotations', async () => {
|
||||
apiMock.getProjectFrames.mockResolvedValueOnce([
|
||||
{ id: 10, project_id: 1, frame_index: 0, image_url: '/frame.jpg', width: 640, height: 360 },
|
||||
]);
|
||||
apiMock.importGtMask.mockResolvedValueOnce([{ id: 88, frame_id: 10 }]);
|
||||
apiMock.getProjectAnnotations
|
||||
.mockResolvedValueOnce([])
|
||||
.mockResolvedValueOnce([{ id: 88, frame_id: 10 }]);
|
||||
apiMock.annotationToMask.mockReturnValueOnce({
|
||||
id: 'annotation-88',
|
||||
annotationId: '88',
|
||||
frameId: '10',
|
||||
saved: true,
|
||||
pathData: 'M 0 0 Z',
|
||||
label: 'GT Mask',
|
||||
color: '#22c55e',
|
||||
});
|
||||
|
||||
render(<VideoWorkspace />);
|
||||
await waitFor(() => expect(useStore.getState().frames).toHaveLength(1));
|
||||
|
||||
const fileInput = document.querySelector('input[type="file"]') as HTMLInputElement;
|
||||
const file = new File(['mask'], 'mask.png', { type: 'image/png' });
|
||||
fireEvent.change(fileInput, { target: { files: [file] } });
|
||||
|
||||
await waitFor(() => expect(apiMock.importGtMask).toHaveBeenCalledWith(file, '1', '10'));
|
||||
await waitFor(() => expect(useStore.getState().masks).toEqual([
|
||||
expect.objectContaining({ id: 'annotation-88', label: 'GT Mask' }),
|
||||
]));
|
||||
});
|
||||
});
|
||||
|
||||
@@ -5,10 +5,12 @@ import {
|
||||
buildAnnotationPayload,
|
||||
deleteAnnotation,
|
||||
exportCoco,
|
||||
exportMasks,
|
||||
getProjectAnnotations,
|
||||
getProjectFrames,
|
||||
getTask,
|
||||
getTemplates,
|
||||
importGtMask,
|
||||
parseMedia,
|
||||
saveAnnotation,
|
||||
updateAnnotation,
|
||||
@@ -25,18 +27,24 @@ function sleep(ms: number) {
|
||||
}
|
||||
|
||||
export function VideoWorkspace({ onNavigateToAI }: { onNavigateToAI?: () => void }) {
|
||||
const gtMaskInputRef = React.useRef<HTMLInputElement>(null);
|
||||
const activeTool = useStore((state) => state.activeTool);
|
||||
const setActiveTool = useStore((state) => state.setActiveTool);
|
||||
const currentProject = useStore((state) => state.currentProject);
|
||||
const frames = useStore((state) => state.frames);
|
||||
const currentFrameIndex = useStore((state) => state.currentFrameIndex);
|
||||
const masks = useStore((state) => state.masks);
|
||||
const maskHistory = useStore((state) => state.maskHistory);
|
||||
const maskFuture = useStore((state) => state.maskFuture);
|
||||
const activeTemplateId = useStore((state) => state.activeTemplateId);
|
||||
const setFrames = useStore((state) => state.setFrames);
|
||||
const setCurrentFrame = useStore((state) => state.setCurrentFrame);
|
||||
const setMasks = useStore((state) => state.setMasks);
|
||||
const undoMasks = useStore((state) => state.undoMasks);
|
||||
const redoMasks = useStore((state) => state.redoMasks);
|
||||
const [isSaving, setIsSaving] = useState(false);
|
||||
const [isExporting, setIsExporting] = useState(false);
|
||||
const [isImportingGt, setIsImportingGt] = useState(false);
|
||||
const [statusMessage, setStatusMessage] = useState('');
|
||||
|
||||
const hydrateSavedAnnotations = useCallback(async (projectId: string, projectFrames: Frame[]) => {
|
||||
@@ -216,6 +224,18 @@ export function VideoWorkspace({ onNavigateToAI }: { onNavigateToAI?: () => void
|
||||
}
|
||||
}, [currentFrame, masks, setMasks]);
|
||||
|
||||
const handleDeleteMaskAnnotations = useCallback(async (annotationIds: string[]) => {
|
||||
if (annotationIds.length === 0) return;
|
||||
try {
|
||||
await Promise.all(annotationIds.map((annotationId) => deleteAnnotation(annotationId)));
|
||||
setStatusMessage(`已删除 ${annotationIds.length} 个被合并标注`);
|
||||
} catch (err) {
|
||||
console.error('Delete merged annotations failed:', err);
|
||||
setStatusMessage('合并后删除原标注失败,请检查后端服务');
|
||||
throw err;
|
||||
}
|
||||
}, []);
|
||||
|
||||
const handleSave = async () => {
|
||||
try {
|
||||
await savePendingAnnotations();
|
||||
@@ -248,6 +268,52 @@ export function VideoWorkspace({ onNavigateToAI }: { onNavigateToAI?: () => void
|
||||
}
|
||||
};
|
||||
|
||||
const downloadBlob = (blob: Blob, filename: string) => {
|
||||
const url = URL.createObjectURL(blob);
|
||||
const link = document.createElement('a');
|
||||
link.href = url;
|
||||
link.download = filename;
|
||||
document.body.appendChild(link);
|
||||
link.click();
|
||||
link.remove();
|
||||
URL.revokeObjectURL(url);
|
||||
};
|
||||
|
||||
const handleExportMasks = async () => {
|
||||
if (!currentProject?.id) return;
|
||||
setIsExporting(true);
|
||||
setStatusMessage('正在准备导出语义 Mask ZIP...');
|
||||
try {
|
||||
await savePendingAnnotations({ silent: true });
|
||||
const blob = await exportMasks(currentProject.id);
|
||||
downloadBlob(blob, `project_${currentProject.id}_masks.zip`);
|
||||
setStatusMessage('PNG Mask ZIP 已导出');
|
||||
} catch (err) {
|
||||
console.error('Mask export failed:', err);
|
||||
setStatusMessage('Mask 导出失败,请检查后端服务');
|
||||
} finally {
|
||||
setIsExporting(false);
|
||||
}
|
||||
};
|
||||
|
||||
const handleImportGtMask = async (event: React.ChangeEvent<HTMLInputElement>) => {
|
||||
const file = event.target.files?.[0];
|
||||
if (!file || !currentProject?.id || !currentFrame?.id) return;
|
||||
setIsImportingGt(true);
|
||||
setStatusMessage('正在导入 GT Mask...');
|
||||
try {
|
||||
const imported = await importGtMask(file, currentProject.id, currentFrame.id);
|
||||
await hydrateSavedAnnotations(currentProject.id, frames);
|
||||
setStatusMessage(`已导入 ${imported.length} 个 GT 区域`);
|
||||
} catch (err) {
|
||||
console.error('GT mask import failed:', err);
|
||||
setStatusMessage('GT Mask 导入失败,请检查文件或后端服务');
|
||||
} finally {
|
||||
setIsImportingGt(false);
|
||||
event.target.value = '';
|
||||
}
|
||||
};
|
||||
|
||||
return (
|
||||
<div className="w-full h-full flex flex-col bg-[#0a0a0a]">
|
||||
{/* Top Header / Status bar */}
|
||||
@@ -264,6 +330,27 @@ export function VideoWorkspace({ onNavigateToAI }: { onNavigateToAI?: () => void
|
||||
</span>
|
||||
)}
|
||||
<ModelStatusBadge />
|
||||
<input
|
||||
ref={gtMaskInputRef}
|
||||
type="file"
|
||||
accept="image/png,image/jpeg,image/bmp,image/tiff"
|
||||
className="hidden"
|
||||
onChange={handleImportGtMask}
|
||||
/>
|
||||
<button
|
||||
onClick={() => gtMaskInputRef.current?.click()}
|
||||
disabled={!currentProject?.id || !currentFrame?.id || isImportingGt || isSaving || isExporting}
|
||||
className="px-4 py-1.5 bg-white/5 hover:bg-white/10 border border-white/10 rounded-md text-xs transition-colors text-white disabled:opacity-40 disabled:cursor-not-allowed"
|
||||
>
|
||||
{isImportingGt ? '导入中...' : '导入 GT Mask'}
|
||||
</button>
|
||||
<button
|
||||
onClick={handleExportMasks}
|
||||
disabled={!currentProject?.id || isExporting || isSaving}
|
||||
className="px-4 py-1.5 bg-white/5 hover:bg-white/10 border border-white/10 rounded-md text-xs transition-colors text-white disabled:opacity-40 disabled:cursor-not-allowed"
|
||||
>
|
||||
{isExporting ? '导出中...' : '导出 PNG Mask ZIP'}
|
||||
</button>
|
||||
<button
|
||||
onClick={handleExport}
|
||||
disabled={!currentProject?.id || isExporting || isSaving}
|
||||
@@ -283,11 +370,24 @@ export function VideoWorkspace({ onNavigateToAI }: { onNavigateToAI?: () => void
|
||||
|
||||
{/* Main Workspace Area */}
|
||||
<div className="flex-1 flex overflow-hidden">
|
||||
<ToolsPalette activeTool={activeTool} setActiveTool={setActiveTool} onTriggerAI={onNavigateToAI} />
|
||||
<ToolsPalette
|
||||
activeTool={activeTool}
|
||||
setActiveTool={setActiveTool}
|
||||
onTriggerAI={onNavigateToAI}
|
||||
onUndo={undoMasks}
|
||||
onRedo={redoMasks}
|
||||
canUndo={maskHistory.length > 0}
|
||||
canRedo={maskFuture.length > 0}
|
||||
/>
|
||||
|
||||
<div className="flex-1 relative flex items-center justify-center p-8 bg-[#151515] overflow-hidden">
|
||||
<div className="relative w-full h-full bg-[#1e1e1e] border border-white/5 shadow-2xl rounded-sm">
|
||||
<CanvasArea activeTool={activeTool} frame={currentFrame} onClearMasks={handleClearCurrentFrameMasks} />
|
||||
<CanvasArea
|
||||
activeTool={activeTool}
|
||||
frame={currentFrame}
|
||||
onClearMasks={handleClearCurrentFrameMasks}
|
||||
onDeleteMaskAnnotations={handleDeleteMaskAnnotations}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
|
||||
@@ -101,6 +101,17 @@ describe('api client contracts', () => {
|
||||
});
|
||||
});
|
||||
|
||||
it('exports PNG masks from the backend route shape', async () => {
|
||||
const { exportMasks } = await import('./api');
|
||||
const blob = new Blob(['zip'], { type: 'application/zip' });
|
||||
axiosMock.client.get.mockResolvedValueOnce({ data: blob });
|
||||
|
||||
await expect(exportMasks('9')).resolves.toBe(blob);
|
||||
expect(axiosMock.client.get).toHaveBeenCalledWith('/api/export/9/masks', {
|
||||
responseType: 'blob',
|
||||
});
|
||||
});
|
||||
|
||||
it('loads dashboard overview from the backend summary endpoint', async () => {
|
||||
const { getDashboardOverview } = await import('./api');
|
||||
const overview = {
|
||||
@@ -125,8 +136,8 @@ describe('api client contracts', () => {
|
||||
expect(axiosMock.client.get).toHaveBeenCalledWith('/api/dashboard/overview');
|
||||
});
|
||||
|
||||
it('queues media parsing and reads processing task status', async () => {
|
||||
const { getTask, parseMedia } = await import('./api');
|
||||
it('queues media parsing and manages processing task lifecycle', async () => {
|
||||
const { cancelTask, getTask, parseMedia, retryTask } = await import('./api');
|
||||
const task = {
|
||||
id: 12,
|
||||
task_type: 'parse_video',
|
||||
@@ -145,6 +156,8 @@ describe('api client contracts', () => {
|
||||
};
|
||||
axiosMock.client.post.mockResolvedValueOnce({ data: task });
|
||||
axiosMock.client.get.mockResolvedValueOnce({ data: { ...task, status: 'success', progress: 100 } });
|
||||
axiosMock.client.post.mockResolvedValueOnce({ data: { ...task, status: 'cancelled', progress: 100 } });
|
||||
axiosMock.client.post.mockResolvedValueOnce({ data: { ...task, id: 13, status: 'queued', progress: 0 } });
|
||||
|
||||
await expect(parseMedia('9')).resolves.toEqual(task);
|
||||
expect(axiosMock.client.post).toHaveBeenCalledWith('/api/media/parse', null, {
|
||||
@@ -153,6 +166,12 @@ describe('api client contracts', () => {
|
||||
|
||||
await expect(getTask(12)).resolves.toEqual(expect.objectContaining({ status: 'success', progress: 100 }));
|
||||
expect(axiosMock.client.get).toHaveBeenCalledWith('/api/tasks/12');
|
||||
|
||||
await expect(cancelTask(12)).resolves.toEqual(expect.objectContaining({ status: 'cancelled' }));
|
||||
expect(axiosMock.client.post).toHaveBeenCalledWith('/api/tasks/12/cancel');
|
||||
|
||||
await expect(retryTask(12)).resolves.toEqual(expect.objectContaining({ id: 13, status: 'queued' }));
|
||||
expect(axiosMock.client.post).toHaveBeenCalledWith('/api/tasks/12/retry');
|
||||
});
|
||||
|
||||
it('lists, saves, updates, and deletes annotations with the backend annotation contract', async () => {
|
||||
@@ -204,6 +223,25 @@ describe('api client contracts', () => {
|
||||
expect(axiosMock.client.delete).toHaveBeenCalledWith('/api/ai/annotations/1');
|
||||
});
|
||||
|
||||
it('imports GT masks through multipart form data', async () => {
|
||||
const { importGtMask } = await import('./api');
|
||||
const file = new File(['mask'], 'mask.png', { type: 'image/png' });
|
||||
const saved = [{ id: 1, project_id: 9, frame_id: 5, template_id: null, mask_data: null, points: null, bbox: null }];
|
||||
axiosMock.client.post.mockResolvedValueOnce({ data: saved });
|
||||
|
||||
await expect(importGtMask(file, '9', '5', '2')).resolves.toEqual(saved);
|
||||
expect(axiosMock.client.post).toHaveBeenCalledWith(
|
||||
'/api/ai/import-gt-mask',
|
||||
expect.any(FormData),
|
||||
{ headers: { 'Content-Type': 'multipart/form-data' } },
|
||||
);
|
||||
const form = axiosMock.client.post.mock.calls.at(-1)?.[1] as FormData;
|
||||
expect(form.get('file')).toBe(file);
|
||||
expect(form.get('project_id')).toBe('9');
|
||||
expect(form.get('frame_id')).toBe('5');
|
||||
expect(form.get('template_id')).toBe('2');
|
||||
});
|
||||
|
||||
it('builds annotation payloads from frontend masks and restores saved annotations to masks', async () => {
|
||||
const { annotationToMask, buildAnnotationPayload } = await import('./api');
|
||||
const frame = { id: '5', projectId: '9', index: 0, url: '/frame.jpg', width: 100, height: 50 };
|
||||
@@ -244,7 +282,7 @@ describe('api client contracts', () => {
|
||||
color: '#06b6d4',
|
||||
class: { id: 'c1', name: '胆囊', color: '#ff0000', zIndex: 20 },
|
||||
},
|
||||
points: null,
|
||||
points: [[0.5, 0.5]],
|
||||
bbox: null,
|
||||
created_at: 'created',
|
||||
updated_at: 'updated',
|
||||
@@ -261,10 +299,28 @@ describe('api client contracts', () => {
|
||||
saveStatus: 'saved',
|
||||
saved: true,
|
||||
pathData: 'M 10 10 L 90 10 L 90 40 Z',
|
||||
points: [[50, 25]],
|
||||
bbox: [10, 10, 80, 30],
|
||||
}));
|
||||
});
|
||||
|
||||
it('preserves editable point regions in annotation payloads', async () => {
|
||||
const { buildAnnotationPayload } = await import('./api');
|
||||
const frame = { id: '5', projectId: '9', index: 0, url: '/frame.jpg', width: 100, height: 50 };
|
||||
|
||||
expect(buildAnnotationPayload('9', {
|
||||
id: 'm1',
|
||||
frameId: '5',
|
||||
pathData: 'M 10 10 L 90 10 L 90 40 Z',
|
||||
label: 'GT Mask',
|
||||
color: '#22c55e',
|
||||
segmentation: [[10, 10, 90, 10, 90, 40]],
|
||||
points: [[50, 25]],
|
||||
}, frame)).toEqual(expect.objectContaining({
|
||||
points: [[0.5, 0.5]],
|
||||
}));
|
||||
});
|
||||
|
||||
it('normalizes positive and negative point prompts for AI prediction', async () => {
|
||||
const { predictMask } = await import('./api');
|
||||
axiosMock.client.post.mockResolvedValueOnce({
|
||||
@@ -341,6 +397,38 @@ describe('api client contracts', () => {
|
||||
});
|
||||
});
|
||||
|
||||
it('passes AI post-processing options to prediction endpoint', async () => {
|
||||
const { predictMask } = await import('./api');
|
||||
axiosMock.client.post.mockResolvedValueOnce({ data: { polygons: [], scores: [] } });
|
||||
|
||||
await predictMask({
|
||||
imageId: '7',
|
||||
imageWidth: 640,
|
||||
imageHeight: 360,
|
||||
points: [{ x: 320, y: 180, type: 'pos' }],
|
||||
options: {
|
||||
crop_to_prompt: true,
|
||||
auto_filter_background: true,
|
||||
min_score: 0.05,
|
||||
},
|
||||
});
|
||||
|
||||
expect(axiosMock.client.post).toHaveBeenCalledWith('/api/ai/predict', {
|
||||
image_id: 7,
|
||||
prompt_type: 'point',
|
||||
prompt_data: {
|
||||
points: [[0.5, 0.5]],
|
||||
labels: [1],
|
||||
},
|
||||
model: 'sam2',
|
||||
options: {
|
||||
crop_to_prompt: true,
|
||||
auto_filter_background: true,
|
||||
min_score: 0.05,
|
||||
},
|
||||
});
|
||||
});
|
||||
|
||||
it('loads AI model and GPU runtime status', async () => {
|
||||
const { getAiModelStatus } = await import('./api');
|
||||
const status = {
|
||||
|
||||
@@ -197,6 +197,16 @@ export async function getTask(taskId: string | number): Promise<ProcessingTask>
|
||||
return response.data;
|
||||
}
|
||||
|
||||
export async function cancelTask(taskId: string | number): Promise<ProcessingTask> {
|
||||
const response = await apiClient.post(`/api/tasks/${taskId}/cancel`);
|
||||
return response.data;
|
||||
}
|
||||
|
||||
export async function retryTask(taskId: string | number): Promise<ProcessingTask> {
|
||||
const response = await apiClient.post(`/api/tasks/${taskId}/retry`);
|
||||
return response.data;
|
||||
}
|
||||
|
||||
interface PredictMaskPayload {
|
||||
imageId: string;
|
||||
imageWidth: number;
|
||||
@@ -205,6 +215,12 @@ interface PredictMaskPayload {
|
||||
points?: { x: number; y: number; type: 'pos' | 'neg' }[];
|
||||
box?: { x1: number; y1: number; x2: number; y2: number };
|
||||
text?: string;
|
||||
options?: {
|
||||
crop_to_prompt?: boolean;
|
||||
auto_filter_background?: boolean;
|
||||
min_score?: number;
|
||||
crop_margin?: number;
|
||||
};
|
||||
}
|
||||
|
||||
interface PredictMaskResult {
|
||||
@@ -234,6 +250,8 @@ export interface AiModelStatus {
|
||||
python_ok: boolean;
|
||||
torch_ok: boolean;
|
||||
cuda_required: boolean;
|
||||
external_available?: boolean;
|
||||
external_python?: string | null;
|
||||
}
|
||||
|
||||
export interface AiRuntimeStatus {
|
||||
@@ -301,6 +319,8 @@ export interface DashboardTask {
|
||||
name: string;
|
||||
progress: number;
|
||||
status: string;
|
||||
raw_status?: string;
|
||||
error?: string | null;
|
||||
frame_count: number;
|
||||
updated_at: string | null;
|
||||
}
|
||||
@@ -397,7 +417,7 @@ export function buildAnnotationPayload(
|
||||
}
|
||||
: undefined;
|
||||
|
||||
return {
|
||||
const payload: SaveAnnotationPayload = {
|
||||
project_id: Number(projectId),
|
||||
frame_id: Number(frame.id),
|
||||
template_id: effectiveTemplateId ? Number(effectiveTemplateId) : undefined,
|
||||
@@ -416,6 +436,15 @@ export function buildAnnotationPayload(
|
||||
]
|
||||
: undefined,
|
||||
};
|
||||
|
||||
if (mask.points) {
|
||||
payload.points = mask.points.map(([x, y]) => [
|
||||
clamp01(x / Math.max(frame.width, 1)),
|
||||
clamp01(y / Math.max(frame.height, 1)),
|
||||
]);
|
||||
}
|
||||
|
||||
return payload;
|
||||
}
|
||||
|
||||
export function annotationToMask(annotation: SavedAnnotation, frame: Frame): Mask | null {
|
||||
@@ -438,6 +467,7 @@ export function annotationToMask(annotation: SavedAnnotation, frame: Frame): Mas
|
||||
label: classMetadata?.name || annotation.mask_data?.label || `Annotation ${annotation.id}`,
|
||||
color: classMetadata?.color || annotation.mask_data?.color || '#06b6d4',
|
||||
segmentation: polygons.map((polygon) => polygon.flatMap(([x, y]) => [x * frame.width, y * frame.height])),
|
||||
points: annotation.points?.map(([x, y]) => [x * frame.width, y * frame.height]),
|
||||
bbox,
|
||||
area: bbox[2] * bbox[3],
|
||||
};
|
||||
@@ -471,6 +501,7 @@ export async function predictMask(payload: PredictMaskPayload): Promise<PredictM
|
||||
prompt_type,
|
||||
prompt_data,
|
||||
model: payload.model || 'sam2',
|
||||
...(payload.options ? { options: payload.options } : {}),
|
||||
});
|
||||
|
||||
const polygons: number[][][] = response.data.polygons || [];
|
||||
@@ -523,6 +554,23 @@ export async function deleteAnnotation(annotationId: string): Promise<void> {
|
||||
await apiClient.delete(`/api/ai/annotations/${annotationId}`);
|
||||
}
|
||||
|
||||
export async function importGtMask(
|
||||
file: File,
|
||||
projectId: string,
|
||||
frameId: string,
|
||||
templateId?: string | null,
|
||||
): Promise<SavedAnnotation[]> {
|
||||
const formData = new FormData();
|
||||
formData.append('file', file);
|
||||
formData.append('project_id', projectId);
|
||||
formData.append('frame_id', frameId);
|
||||
if (templateId) formData.append('template_id', templateId);
|
||||
const response = await apiClient.post('/api/ai/import-gt-mask', formData, {
|
||||
headers: { 'Content-Type': 'multipart/form-data' },
|
||||
});
|
||||
return response.data;
|
||||
}
|
||||
|
||||
export async function getDashboardOverview(): Promise<DashboardOverview> {
|
||||
const response = await apiClient.get('/api/dashboard/overview');
|
||||
return response.data;
|
||||
@@ -536,4 +584,11 @@ export async function exportCoco(projectId: string): Promise<Blob> {
|
||||
return response.data;
|
||||
}
|
||||
|
||||
export async function exportMasks(projectId: string): Promise<Blob> {
|
||||
const response = await apiClient.get(`/api/export/${projectId}/masks`, {
|
||||
responseType: 'blob',
|
||||
});
|
||||
return response.data;
|
||||
}
|
||||
|
||||
export default apiClient;
|
||||
|
||||
@@ -3,7 +3,7 @@ import { WS_PROGRESS_URL } from './config';
|
||||
type ProgressCallback = (data: ProgressMessage) => void;
|
||||
|
||||
interface ProgressMessage {
|
||||
type: 'progress' | 'status' | 'error' | 'complete';
|
||||
type: 'progress' | 'status' | 'error' | 'complete' | 'cancelled';
|
||||
taskId?: string;
|
||||
task_id?: number;
|
||||
project_id?: number;
|
||||
|
||||
@@ -53,4 +53,15 @@ describe('useStore', () => {
|
||||
expect(useStore.getState().masks).toEqual([]);
|
||||
expect(useStore.getState().templates).toEqual([]);
|
||||
});
|
||||
|
||||
it('keeps undo and redo history for mask edits', () => {
|
||||
useStore.getState().addMask({ id: 'm1', frameId: 'f1', pathData: 'M 0 0 Z', label: 'mask 1', color: '#fff' });
|
||||
useStore.getState().addMask({ id: 'm2', frameId: 'f1', pathData: 'M 1 1 Z', label: 'mask 2', color: '#000' });
|
||||
|
||||
expect(useStore.getState().masks.map((mask) => mask.id)).toEqual(['m1', 'm2']);
|
||||
useStore.getState().undoMasks();
|
||||
expect(useStore.getState().masks.map((mask) => mask.id)).toEqual(['m1']);
|
||||
useStore.getState().redoMasks();
|
||||
expect(useStore.getState().masks.map((mask) => mask.id)).toEqual(['m1', 'm2']);
|
||||
});
|
||||
});
|
||||
|
||||
@@ -56,8 +56,10 @@ export interface Mask {
|
||||
color: string;
|
||||
opacity?: number;
|
||||
segmentation?: number[][];
|
||||
points?: number[][];
|
||||
bbox?: [number, number, number, number];
|
||||
area?: number;
|
||||
metadata?: Record<string, unknown>;
|
||||
}
|
||||
|
||||
export interface Template {
|
||||
@@ -110,6 +112,8 @@ export interface AppState {
|
||||
currentFrameIndex: number;
|
||||
annotations: Annotation[];
|
||||
masks: Mask[];
|
||||
maskHistory: Mask[][];
|
||||
maskFuture: Mask[][];
|
||||
setActiveModule: (module: string) => void;
|
||||
setActiveTool: (tool: string) => void;
|
||||
setAiModel: (model: AiModelId) => void;
|
||||
@@ -120,6 +124,8 @@ export interface AppState {
|
||||
updateMask: (id: string, updates: Partial<Mask>) => void;
|
||||
setMasks: (masks: Mask[]) => void;
|
||||
clearMasks: () => void;
|
||||
undoMasks: () => void;
|
||||
redoMasks: () => void;
|
||||
removeAnnotation: (id: string) => void;
|
||||
|
||||
// Templates
|
||||
@@ -161,6 +167,8 @@ export const useStore = create<AppState>((set) => ({
|
||||
frames: [],
|
||||
annotations: [],
|
||||
masks: [],
|
||||
maskHistory: [],
|
||||
maskFuture: [],
|
||||
activeTemplateId: null,
|
||||
activeClassId: null,
|
||||
activeClass: null,
|
||||
@@ -187,6 +195,8 @@ export const useStore = create<AppState>((set) => ({
|
||||
currentFrameIndex: 0,
|
||||
annotations: [],
|
||||
masks: [],
|
||||
maskHistory: [],
|
||||
maskFuture: [],
|
||||
setActiveModule: (activeModule: string) => set({ activeModule }),
|
||||
setActiveTool: (activeTool: string) => set({ activeTool }),
|
||||
setAiModel: (aiModel: AiModelId) => set({ aiModel }),
|
||||
@@ -195,13 +205,54 @@ export const useStore = create<AppState>((set) => ({
|
||||
addAnnotation: (annotation: Annotation) =>
|
||||
set((state) => ({ annotations: [...state.annotations, annotation] })),
|
||||
addMask: (mask: Mask) =>
|
||||
set((state) => ({ masks: [...state.masks, mask] })),
|
||||
set((state) => ({
|
||||
masks: [...state.masks, mask],
|
||||
maskHistory: [...state.maskHistory, state.masks],
|
||||
maskFuture: [],
|
||||
})),
|
||||
updateMask: (id: string, updates: Partial<Mask>) =>
|
||||
set((state) => ({
|
||||
masks: state.masks.map((mask) => (mask.id === id ? { ...mask, ...updates } : mask)),
|
||||
maskHistory: [...state.maskHistory, state.masks],
|
||||
maskFuture: [],
|
||||
})),
|
||||
setMasks: (masks: Mask[]) => set({ masks }),
|
||||
clearMasks: () => set({ masks: [] }),
|
||||
setMasks: (masks: Mask[]) =>
|
||||
set((state) => {
|
||||
const isInitialHydration = state.masks.length === 0
|
||||
&& state.maskHistory.length === 0
|
||||
&& state.maskFuture.length === 0;
|
||||
return {
|
||||
masks,
|
||||
maskHistory: isInitialHydration ? [] : [...state.maskHistory, state.masks],
|
||||
maskFuture: [],
|
||||
};
|
||||
}),
|
||||
clearMasks: () =>
|
||||
set((state) => ({
|
||||
masks: [],
|
||||
maskHistory: [...state.maskHistory, state.masks],
|
||||
maskFuture: [],
|
||||
})),
|
||||
undoMasks: () =>
|
||||
set((state) => {
|
||||
if (state.maskHistory.length === 0) return state;
|
||||
const previous = state.maskHistory[state.maskHistory.length - 1];
|
||||
return {
|
||||
masks: previous,
|
||||
maskHistory: state.maskHistory.slice(0, -1),
|
||||
maskFuture: [state.masks, ...state.maskFuture],
|
||||
};
|
||||
}),
|
||||
redoMasks: () =>
|
||||
set((state) => {
|
||||
if (state.maskFuture.length === 0) return state;
|
||||
const [next, ...rest] = state.maskFuture;
|
||||
return {
|
||||
masks: next,
|
||||
maskHistory: [...state.maskHistory, state.masks],
|
||||
maskFuture: rest,
|
||||
};
|
||||
}),
|
||||
removeAnnotation: (id: string) =>
|
||||
set((state) => ({
|
||||
annotations: state.annotations.filter((a) => a.id !== id),
|
||||
|
||||
@@ -32,24 +32,69 @@ function makeStageEvent(x = 120, y = 80) {
|
||||
}
|
||||
|
||||
vi.mock('react-konva', () => ({
|
||||
Stage: ({ children, onClick, onMouseDown, onMouseUp, onMouseMove, onWheel }: any) => (
|
||||
Stage: ({ children, onClick, onMouseDown, onMouseUp, onMouseMove, onWheel }: any) => {
|
||||
const coords = (event: React.MouseEvent<HTMLDivElement>, fallbackX: number, fallbackY: number) => ({
|
||||
x: event.clientX || fallbackX,
|
||||
y: event.clientY || fallbackY,
|
||||
});
|
||||
return (
|
||||
<div
|
||||
data-testid="konva-stage"
|
||||
onClick={() => onClick?.(makeStageEvent())}
|
||||
onMouseDown={() => onMouseDown?.(makeStageEvent())}
|
||||
onMouseUp={() => onMouseUp?.(makeStageEvent(260, 200))}
|
||||
onMouseMove={() => onMouseMove?.(makeStageEvent(180, 120))}
|
||||
onClick={(event) => {
|
||||
const point = coords(event, 120, 80);
|
||||
onClick?.(makeStageEvent(point.x, point.y));
|
||||
}}
|
||||
onMouseDown={(event) => {
|
||||
const point = coords(event, 120, 80);
|
||||
onMouseDown?.(makeStageEvent(point.x, point.y));
|
||||
}}
|
||||
onMouseUp={(event) => {
|
||||
const point = coords(event, 260, 200);
|
||||
onMouseUp?.(makeStageEvent(point.x, point.y));
|
||||
}}
|
||||
onMouseMove={(event) => {
|
||||
const point = coords(event, 180, 120);
|
||||
onMouseMove?.(makeStageEvent(point.x, point.y));
|
||||
}}
|
||||
onWheel={() => onWheel?.(makeStageEvent())}
|
||||
>
|
||||
{children}
|
||||
</div>
|
||||
),
|
||||
);
|
||||
},
|
||||
Layer: ({ children }: any) => <div data-testid="konva-layer">{children}</div>,
|
||||
Group: ({ children }: any) => <div data-testid="konva-group">{children}</div>,
|
||||
Image: ({ image }: any) => <img data-testid="konva-image" alt="" src={image?.src || ''} />,
|
||||
Circle: (props: any) => <span data-testid="konva-circle" data-fill={props.fill} />,
|
||||
Circle: (props: any) => (
|
||||
<span
|
||||
data-testid="konva-circle"
|
||||
data-fill={props.fill}
|
||||
data-x={props.x}
|
||||
data-y={props.y}
|
||||
onClick={() => props.onClick?.({ cancelBubble: false })}
|
||||
onMouseUp={(event: React.MouseEvent<HTMLSpanElement>) => props.onDragEnd?.({
|
||||
target: {
|
||||
x: () => event.clientX || props.x || 0,
|
||||
y: () => event.clientY || props.y || 0,
|
||||
},
|
||||
})}
|
||||
onDragEnd={(event: React.DragEvent<HTMLSpanElement>) => props.onDragEnd?.({
|
||||
target: {
|
||||
x: () => event.clientX || props.x || 0,
|
||||
y: () => event.clientY || props.y || 0,
|
||||
},
|
||||
})}
|
||||
/>
|
||||
),
|
||||
Rect: (props: any) => <span data-testid="konva-rect" data-width={props.width} />,
|
||||
Path: (props: any) => <span data-testid="konva-path" data-path={props.data} data-fill={props.fill} />,
|
||||
Path: (props: any) => (
|
||||
<span
|
||||
data-testid="konva-path"
|
||||
data-path={props.data}
|
||||
data-fill={props.fill}
|
||||
onClick={() => props.onClick?.({ cancelBubble: false })}
|
||||
/>
|
||||
),
|
||||
}));
|
||||
|
||||
vi.mock('use-image', () => ({
|
||||
|
||||
@@ -13,6 +13,8 @@ export function resetStore() {
|
||||
currentFrameIndex: 0,
|
||||
annotations: [],
|
||||
masks: [],
|
||||
maskHistory: [],
|
||||
maskFuture: [],
|
||||
templates: [],
|
||||
activeTemplateId: null,
|
||||
activeClassId: null,
|
||||
|
||||
Reference in New Issue
Block a user