- 接入 SAM2 视频传播能力:新增 /api/ai/propagate,支持用当前帧 mask/polygon/bbox 作为 seed,通过 SAM2 video predictor 向前、向后或双向传播,并可保存为真实 annotation。 - 接入 SAM3 video tracker:通过独立 Python 3.12 external worker 调用 SAM3 video predictor/tracker,使用本地 checkpoint 与 bbox seed 执行视频级跟踪,并在模型状态中标记 video_track 能力。 - 完善 SAM 模型分发:sam_registry 按 model_id 明确区分 sam2 propagation 与 sam3 video_track,避免两个模型链路混用。 - 打通前端“传播片段”:VideoWorkspace 使用当前选中 mask 和当前 AI 模型调用后端传播接口,传播结果回写并刷新工作区已保存标注。 - 增强 SAM3 本地 checkpoint 配置:新增 sam3_checkpoint_path 配置和 .env.example 示例,状态检查改为基于本地 checkpoint/独立环境/模型包可用性。 - 完善视频拆帧参数:/api/media/parse 支持 parse_fps、max_frames、target_width,后端任务保存帧时间戳、源帧号和 frame_sequence 元数据。 - 增加运行时 schema 兼容处理:启动时为旧 frames 表补充 timestamp_ms 和 source_frame_number 列,避免旧库升级后缺字段。 - 强化 Canvas 标注编辑:补齐多边形闭合、点工具、顶点拖拽、边中点插入、Delete/Backspace 删除、区域合并和重叠去除等交互。 - 增强语义分类联动:选中 mask 后可通过右侧语义分类树更新标签、颜色和 class metadata,并同步到保存/导出链路。 - 增加关键帧时间轴体验:FrameTimeline 显示具体时间信息,并支持键盘左右方向键切换关键帧。 - 完善 AI 交互分割参数:前端保留正向点、反向点、框选和 interactive prompt 的调用状态,支持 SAM2 细化候选区域与 SAM3 bbox 入口。 - 扩展后端/前端 API 类型:新增 propagateMasks、传播请求/响应 schema,并补齐 annotation、导出、模型状态和任务接口的测试覆盖。 - 更新项目文档:同步 README、AGENTS、接口契约、需求冻结、设计冻结、前端元素审计、实施计划和测试计划,标明真实功能边界与剩余风险。 - 增加测试覆盖:补充 SAM2/SAM3 传播、SAM3 状态、媒体拆帧参数、Canvas 编辑、语义标签切换、时间轴、工作区传播和 API 合约测试。 - 加强仓库安全边界:将 sam3权重/ 加入 .gitignore,避免本地模型权重被误提交。 验证:npm run test:run;pytest backend/tests;npm run lint;npm run build;python -m py_compile;git diff --check。
461 lines
16 KiB
Python
461 lines
16 KiB
Python
import numpy as np
|
|
import cv2
|
|
|
|
|
|
def _create_project_and_frame(client):
    """Create a project, a single 640x360 frame inside it, and an empty template.

    Returns the decoded JSON bodies as a ``(project, frame, template)`` tuple.
    """
    project = client.post("/api/projects", json={"name": "AI Project"}).json()
    project_id = project["id"]

    frame_payload = {
        "project_id": project_id,
        "frame_index": 0,
        "image_url": "frames/0.jpg",
        "width": 640,
        "height": 360,
    }
    frame = client.post(f"/api/projects/{project_id}/frames", json=frame_payload).json()

    template_payload = {
        "name": "Template",
        "color": "#06b6d4",
        "z_index": 0,
        "classes": [],
        "rules": [],
    }
    template = client.post("/api/templates", json=template_payload).json()

    return project, frame, template
|
|
|
|
|
|
def test_predict_accepts_point_object_with_labels(client, monkeypatch):
    """A ``{"points": ..., "labels": ...}`` point prompt reaches the model unchanged."""
    _, frame, _ = _create_project_and_frame(client)
    recorded = {}

    def blank_image(frame):
        return np.zeros((10, 10, 3), dtype=np.uint8)

    def fake_predict_points(model, image, points, labels):
        # Only the prompt is inspected; model and image are ignored.
        recorded["args"] = (points, labels)
        polygons = [[[0.1, 0.1], [0.9, 0.1], [0.9, 0.9]]]
        return polygons, [0.95]

    monkeypatch.setattr("routers.ai._load_frame_image", blank_image)
    monkeypatch.setattr("routers.ai.sam_registry.predict_points", fake_predict_points)

    payload = {
        "image_id": frame["id"],
        "prompt_type": "point",
        "prompt_data": {"points": [[0.5, 0.5], [0.1, 0.1]], "labels": [1, 0]},
    }
    response = client.post("/api/ai/predict", json=payload)

    assert response.status_code == 200
    assert response.json()["scores"] == [0.95]
    assert recorded["args"] == ([[0.5, 0.5], [0.1, 0.1]], [1, 0])
|
|
|
|
|
|
def test_predict_applies_crop_and_background_filter_options(client, monkeypatch):
    """Cropping shrinks the model input and the score filter drops weak candidates.

    The fake model returns two polygons: one scored 0.9 and one scored 0.01,
    which is below the requested ``min_score`` of 0.05. Only the first should
    be returned, and every returned coordinate must stay in the normalised
    ``[0, 1]`` range.
    """
    _, frame, _ = _create_project_and_frame(client)
    calls = {}
    monkeypatch.setattr("routers.ai._load_frame_image", lambda frame: np.zeros((100, 200, 3), dtype=np.uint8))

    def fake_predict_points(model, image, points, labels):
        calls["shape"] = image.shape
        calls["points"] = points
        calls["labels"] = labels
        return (
            [
                [[0.0, 0.0], [0.2, 0.0], [0.2, 0.2]],
                [[0.45, 0.45], [0.55, 0.45], [0.55, 0.55]],
            ],
            [0.9, 0.01],
        )

    monkeypatch.setattr("routers.ai.sam_registry.predict_points", fake_predict_points)

    response = client.post("/api/ai/predict", json={
        "image_id": frame["id"],
        "prompt_type": "point",
        "prompt_data": {"points": [[0.5, 0.5], [0.52, 0.52]], "labels": [1, 0]},
        "options": {
            "crop_to_prompt": True,
            "crop_margin": 0.1,
            "auto_filter_background": True,
            "min_score": 0.05,
        },
    })

    assert response.status_code == 200
    # The stubbed frame is 100x200; crop_to_prompt must hand the model a smaller image.
    assert calls["shape"][0] < 100
    assert calls["shape"][1] < 200
    # Both prompt points must survive cropping, paired with their unchanged labels.
    assert len(calls["points"]) == 2
    assert calls["labels"] == [1, 0]
    assert response.json()["scores"] == [0.9]
    polygon = response.json()["polygons"][0]
    assert all(0.0 <= coord <= 1.0 for point in polygon for coord in point)
|
|
|
|
|
|
def test_predict_box_and_semantic_fallback(client, monkeypatch):
    """Box and semantic (text) prompts are dispatched to their registry handlers."""
    _, frame, _ = _create_project_and_frame(client)

    def fake_box(model, image, box):
        return [[[0.2, 0.2], [0.8, 0.2], [0.8, 0.8]]], [0.8]

    def fake_semantic(model, image, text):
        return [[[0.0, 0.0], [1.0, 0.0], [1.0, 1.0]]], [0.5]

    monkeypatch.setattr("routers.ai._load_frame_image", lambda frame: np.zeros((10, 10, 3), dtype=np.uint8))
    monkeypatch.setattr("routers.ai.sam_registry.predict_box", fake_box)
    monkeypatch.setattr("routers.ai.sam_registry.predict_semantic", fake_semantic)

    def predict(prompt_type, prompt_data):
        return client.post("/api/ai/predict", json={
            "image_id": frame["id"],
            "prompt_type": prompt_type,
            "prompt_data": prompt_data,
        })

    box_response = predict("box", [0.2, 0.2, 0.8, 0.8])
    semantic_response = predict("semantic", "胆囊")

    assert box_response.status_code == 200
    assert box_response.json()["scores"] == [0.8]
    assert semantic_response.status_code == 200
    assert semantic_response.json()["scores"] == [0.5]
|
|
|
|
|
|
def test_predict_interactive_combines_box_and_points(client, monkeypatch):
    """An interactive prompt forwards model, box, points and labels in one call."""
    _, frame, _ = _create_project_and_frame(client)
    prompt_box = [0.1, 0.1, 0.9, 0.9]
    prompt_points = [[0.5, 0.5], [0.2, 0.2]]
    prompt_labels = [1, 0]
    seen = {}
    monkeypatch.setattr("routers.ai._load_frame_image", lambda frame: np.zeros((10, 10, 3), dtype=np.uint8))

    def fake_predict_interactive(model, image, box, points, labels):
        seen.update(model=model, box=box, points=points, labels=labels)
        return [[[0.2, 0.2], [0.8, 0.2], [0.8, 0.8]]], [0.88]

    monkeypatch.setattr("routers.ai.sam_registry.predict_interactive", fake_predict_interactive)

    response = client.post("/api/ai/predict", json={
        "image_id": frame["id"],
        "prompt_type": "interactive",
        "prompt_data": {"box": prompt_box, "points": prompt_points, "labels": prompt_labels},
        "model": "sam2",
    })

    assert response.status_code == 200
    assert response.json()["scores"] == [0.88]
    assert seen == {
        "model": "sam2",
        "box": prompt_box,
        "points": prompt_points,
        "labels": prompt_labels,
    }
|
|
|
|
|
|
def test_model_status_reports_runtime(client, monkeypatch):
    """The status endpoint echoes the requested model and each model's availability."""
    cpu_only_gpu = {
        "available": False,
        "device": "cpu",
        "name": None,
        "torch_available": True,
        "torch_version": "2.x",
        "cuda_version": None,
    }
    sam2_entry = {
        "id": "sam2",
        "label": "SAM 2",
        "available": True,
        "loaded": False,
        "device": "cpu",
        "supports": ["point", "box", "auto"],
        "message": "ready",
        "package_available": True,
        "checkpoint_exists": True,
        "checkpoint_path": "model.pt",
        "python_ok": True,
        "torch_ok": True,
        "cuda_required": False,
    }
    sam3_entry = {
        "id": "sam3",
        "label": "SAM 3",
        "available": False,
        "loaded": False,
        "device": "unavailable",
        "supports": ["semantic"],
        "message": "missing Python 3.12+ runtime",
        "package_available": False,
        "checkpoint_exists": False,
        "checkpoint_path": None,
        "python_ok": False,
        "torch_ok": True,
        "cuda_required": True,
    }

    def fake_runtime_status(selected_model=None):
        return {
            "selected_model": selected_model or "sam2",
            "gpu": cpu_only_gpu,
            "models": [sam2_entry, sam3_entry],
        }

    monkeypatch.setattr("routers.ai.sam_registry.runtime_status", fake_runtime_status)

    response = client.get("/api/ai/models/status?selected_model=sam3")

    assert response.status_code == 200
    payload = response.json()
    assert payload["selected_model"] == "sam3"
    sam3_status = payload["models"][1]
    assert sam3_status["id"] == "sam3"
    assert sam3_status["available"] is False
|
|
|
|
|
|
def test_propagate_saves_tracked_annotations(client, monkeypatch):
    """Forward SAM2 propagation saves tracked masks as real annotations."""
    project = client.post("/api/projects", json={"name": "Video Project"}).json()
    project_id = project["id"]

    frames = []
    for idx in range(3):
        created = client.post(f"/api/projects/{project_id}/frames", json={
            "project_id": project_id,
            "frame_index": idx,
            "image_url": f"frames/{idx}.jpg",
            "width": 640,
            "height": 360,
        })
        frames.append(created.json())

    calls = {}
    # Stub storage downloads with placeholder bytes; the fake tracker below only counts paths.
    monkeypatch.setattr("routers.ai.download_file", lambda object_name: b"jpeg")

    def fake_propagate_video(model, frame_paths, source_frame_index, seed, direction, max_frames):
        calls.update(
            model=model,
            source_frame_index=source_frame_index,
            seed=seed,
            direction=direction,
            max_frames=max_frames,
            frame_count=len(frame_paths),
        )
        tracked = [
            (0, [[0.1, 0.1], [0.2, 0.1], [0.2, 0.2]], 0.9),
            (1, [[0.15, 0.15], [0.25, 0.15], [0.25, 0.25]], 0.8),
        ]
        return [
            {"frame_index": index, "polygons": [polygon], "scores": [score], "object_ids": [1]}
            for index, polygon, score in tracked
        ]

    monkeypatch.setattr("routers.ai.sam_registry.propagate_video", fake_propagate_video)

    seed = {
        "polygons": [[[0.1, 0.1], [0.2, 0.1], [0.2, 0.2]]],
        "bbox": [0.1, 0.1, 0.1, 0.1],
        "label": "胆囊",
        "color": "#ff0000",
        "class_metadata": {"id": "c1", "name": "胆囊", "color": "#ff0000", "zIndex": 20},
        "template_id": None,
    }
    response = client.post("/api/ai/propagate", json={
        "project_id": project_id,
        "frame_id": frames[0]["id"],
        "model": "sam2",
        "direction": "forward",
        "max_frames": 2,
        "include_source": False,
        "seed": seed,
    })

    assert response.status_code == 200
    body = response.json()
    assert body["created_annotation_count"] == 1
    assert body["processed_frame_count"] == 2

    assert calls["model"] == "sam2"
    assert calls["source_frame_index"] == 0
    assert calls["direction"] == "forward"
    assert calls["frame_count"] == 2

    # With include_source=False only the frame-1 result is expected to be saved.
    saved = body["annotations"][0]
    assert saved["frame_id"] == frames[1]["id"]
    mask = saved["mask_data"]
    assert mask["source"] == "sam2_propagation"
    assert mask["class"]["name"] == "胆囊"
    assert mask["score"] == 0.8

    listing = client.get(f"/api/ai/annotations?project_id={project_id}")
    assert len(listing.json()) == 1
|
|
|
|
|
|
def test_predict_validation_errors(client, monkeypatch):
    """Unknown frame ids yield 404; a malformed box prompt yields 400."""
    project, _, _ = _create_project_and_frame(client)

    unknown_frame = client.post("/api/ai/predict", json={
        "image_id": 999,
        "prompt_type": "point",
        "prompt_data": [[0.5, 0.5]],
    })
    assert unknown_frame.status_code == 404

    extra_frame = client.post(f"/api/projects/{project['id']}/frames", json={
        "project_id": project["id"],
        "frame_index": 1,
        "image_url": "frames/1.jpg",
    }).json()
    monkeypatch.setattr("routers.ai._load_frame_image", lambda frame: np.zeros((10, 10, 3), dtype=np.uint8))

    # A box prompt carrying only two values must be rejected.
    short_box = client.post("/api/ai/predict", json={
        "image_id": extra_frame["id"],
        "prompt_type": "box",
        "prompt_data": [0.1, 0.2],
    })
    assert short_box.status_code == 400
|
|
|
|
|
|
def test_save_annotation_validates_project_and_frame(client):
    """Annotations save and list for real ids and return 404 for unknown ones."""
    project, frame, template = _create_project_and_frame(client)
    annotations_url = f"/api/ai/annotations?project_id={project['id']}"

    saved = client.post("/api/ai/annotate", json={
        "project_id": project["id"],
        "frame_id": frame["id"],
        "template_id": template["id"],
        "mask_data": {"polygons": [[[0.1, 0.1], [0.9, 0.1], [0.9, 0.9]]]},
        "points": [[0.5, 0.5]],
        "bbox": [0.1, 0.1, 0.8, 0.8],
    })
    assert saved.status_code == 201
    saved_body = saved.json()
    assert saved_body["project_id"] == project["id"]

    listing = client.get(annotations_url)
    assert listing.status_code == 200
    assert listing.json()[0]["id"] == saved_body["id"]

    frame_listing = client.get(f"{annotations_url}&frame_id={frame['id']}")
    assert frame_listing.status_code == 200
    assert len(frame_listing.json()) == 1

    # Unknown project on save, unknown frame on save, unknown project on list.
    assert client.post("/api/ai/annotate", json={"project_id": 999}).status_code == 404
    assert client.post("/api/ai/annotate", json={
        "project_id": project["id"],
        "frame_id": 999,
    }).status_code == 404
    assert client.get("/api/ai/annotations?project_id=999").status_code == 404
|
|
|
|
|
|
def test_update_and_delete_annotation(client):
    """PATCH replaces an annotation's mask, points and bbox; DELETE removes it.

    Also checks that the updated class metadata is visible through the
    project-level listing before the annotation is deleted.
    """
    project, frame, template = _create_project_and_frame(client)
    created = client.post("/api/ai/annotate", json={
        "project_id": project["id"],
        "frame_id": frame["id"],
        "template_id": template["id"],
        "mask_data": {
            "polygons": [[[0.1, 0.1], [0.9, 0.1], [0.9, 0.9]]],
            "label": "AI Mask",
            "color": "#06b6d4",
        },
        "points": [[0.5, 0.5]],
        "bbox": [0.1, 0.1, 0.8, 0.8],
    })
    # Fail on the fixture step with a clear status mismatch instead of a
    # KeyError on saved["id"] further down.
    assert created.status_code == 201
    saved = created.json()

    updated = client.patch(f"/api/ai/annotations/{saved['id']}", json={
        "template_id": template["id"],
        "mask_data": {
            "polygons": [[[0.2, 0.2], [0.8, 0.2], [0.8, 0.8]]],
            "label": "胆囊",
            "color": "#ff0000",
            "class": {"id": "c1", "name": "胆囊", "color": "#ff0000", "zIndex": 20},
        },
        "points": [[0.4, 0.4]],
        "bbox": [0.2, 0.2, 0.6, 0.6],
    })

    assert updated.status_code == 200
    body = updated.json()
    assert body["mask_data"]["label"] == "胆囊"
    assert body["mask_data"]["class"]["id"] == "c1"
    assert body["points"] == [[0.4, 0.4]]
    assert body["bbox"] == [0.2, 0.2, 0.6, 0.6]

    listing = client.get(f"/api/ai/annotations?project_id={project['id']}")
    assert listing.status_code == 200
    assert listing.json()[0]["mask_data"]["class"]["name"] == "胆囊"

    deleted = client.delete(f"/api/ai/annotations/{saved['id']}")
    assert deleted.status_code == 204

    empty_listing = client.get(f"/api/ai/annotations?project_id={project['id']}")
    assert empty_listing.status_code == 200
    assert empty_listing.json() == []
|
|
|
|
|
|
def test_update_and_delete_annotation_validation(client):
    """PATCH/DELETE return 404 for unknown annotations and for unknown template ids."""
    project, frame, template = _create_project_and_frame(client)
    created = client.post("/api/ai/annotate", json={
        "project_id": project["id"],
        "frame_id": frame["id"],
        "template_id": template["id"],
    })
    # Guard the fixture step so a failed create is reported as such rather
    # than as a KeyError on saved["id"] below.
    assert created.status_code == 201
    saved = created.json()

    assert client.patch("/api/ai/annotations/999", json={"bbox": [0, 0, 1, 1]}).status_code == 404
    assert client.delete("/api/ai/annotations/999").status_code == 404
    # An existing annotation cannot be re-pointed at a template id that does not exist.
    assert client.patch(
        f"/api/ai/annotations/{saved['id']}",
        json={"template_id": 999},
    ).status_code == 404
|
|
|
|
|
|
def test_import_gt_mask_creates_annotations_with_seed_points(client):
    """A single-value GT mask PNG becomes one annotation with one seed point."""
    project, frame, template = _create_project_and_frame(client)

    # 360x640 mask holding one filled rectangle of value 255.
    mask = np.zeros((360, 640), dtype=np.uint8)
    cv2.rectangle(mask, (100, 80), (260, 220), 255, thickness=-1)
    encoded_ok, png = cv2.imencode(".png", mask)
    assert encoded_ok

    form = {
        "project_id": str(project["id"]),
        "frame_id": str(frame["id"]),
        "template_id": str(template["id"]),
        "label": "Imported GT",
        "color": "#22c55e",
    }
    upload = {"file": ("mask.png", png.tobytes(), "image/png")}
    response = client.post("/api/ai/import-gt-mask", data=form, files=upload)

    assert response.status_code == 201
    body = response.json()
    assert len(body) == 1

    annotation = body[0]
    assert annotation["project_id"] == project["id"]
    assert annotation["frame_id"] == frame["id"]
    assert annotation["template_id"] == template["id"]

    mask_data = annotation["mask_data"]
    assert mask_data["label"] == "Imported GT"
    assert mask_data["source"] == "gt_mask"
    assert mask_data["gt_label_value"] == 255
    assert len(mask_data["polygons"][0]) >= 3

    assert len(annotation["points"]) == 1
    seed = annotation["points"][0]
    assert 0.0 <= seed[0] <= 1.0
    assert 0.0 <= seed[1] <= 1.0
|
|
|
|
|
|
def test_import_gt_mask_splits_label_values(client):
    """A multi-label mask yields one annotation per distinct non-zero label value."""
    project, frame, _ = _create_project_and_frame(client)

    label_map = np.zeros((360, 640), dtype=np.uint8)
    regions = (
        (1, (20, 20), (120, 120)),
        (2, (220, 80), (320, 180)),
    )
    for value, top_left, bottom_right in regions:
        cv2.rectangle(label_map, top_left, bottom_right, value, thickness=-1)
    ok, encoded = cv2.imencode(".png", label_map)
    assert ok

    response = client.post(
        "/api/ai/import-gt-mask",
        data={
            "project_id": str(project["id"]),
            "frame_id": str(frame["id"]),
            "label": "GT Class",
        },
        files={"file": ("labels.png", encoded.tobytes(), "image/png")},
    )

    assert response.status_code == 201
    by_value = sorted(response.json(), key=lambda entry: entry["mask_data"]["gt_label_value"])
    assert [entry["mask_data"]["gt_label_value"] for entry in by_value] == [1, 2]
    assert [entry["mask_data"]["label"] for entry in by_value] == ["GT Class 1", "GT Class 2"]
    assert all(len(entry["points"]) == 1 for entry in by_value)
|