Files
Pre_Seg_Server/backend/tests/test_ai.py
admin 4899c8a08a feat: 完善分割工作区交互与传播去重
功能增加:点击 Canvas mask 后,右侧语义分类树会按 classId/className/label 自动匹配分类,并滚动聚焦到对应分类按钮。

功能增加:工作区新增按起止帧批量清空片段遮罩,复用传播范围输入,范围内已保存标注走 DELETE /api/ai/annotations/{id},本地 draft mask 同步移除。

功能增加:右侧语义分类树上方新增工作区 mask 透明度滑杆,写入 Zustand maskPreviewOpacity,Canvas mask 预览按该值渲染并保留选中加亮反馈。

功能增加:视频处理进度条记录最近自动传播区间,使用不同色系深浅渐变提示最近处理片段。

功能增加:工作区自动传播前会先保存 draft/dirty seed mask,使用稳定的后端 source_annotation_id 入队,减少二次传播重复结果。

Bugfix:后端传播任务对旧临时 seed id、不同 SAM 2.1 权重结果做兼容清理;相同 seed 和相同权重才跳过,否则先删旧自动传播标注再重传。

Bugfix:修复 polygon 顶点拖拽结束后触发 Stage 平移导致画布中心偏移的问题,并补充测试环境对 drag target 的模拟。

Bugfix:工具提示会在数秒后自动隐藏,避免创建多边形/矩形等提示长期遮挡画布。

UI 调整:移除右侧面板顶部“本体论与属性分类管理树”说明栏,减少无效占位。

UI 调整:左侧工具栏和右侧语义面板使用低对比 seg-scrollbar;左侧工具栏外扩滚动条槽位,避免滚动条挤占图标列。

UI 调整:工作区模型状态徽标改为紧凑显示,减少与传播权重选择重复;传播权重下拉改成深色背景和青色文字,避免灰底白字不可读。

UI 调整:缩略图状态框固定优先级,当前帧、人工/AI 标注帧、自动传播帧可用外框/内框组合同时表达。

测试:补充 VideoWorkspace、CanvasArea、FrameTimeline、OntologyInspector、ToolsPalette、useStore 和后端 test_ai 覆盖新增交互、传播去重、批量清空、透明度、滚动条和 UI 状态。

文档:同步更新 README、AGENTS 和 doc/03、doc/04、doc/07、doc/08、doc/09,记录当前功能、接口契约、需求设计冻结和测试覆盖。
2026-05-02 06:45:47 +08:00

793 lines
29 KiB
Python

import numpy as np
import cv2
from pathlib import Path
from models import Annotation, ProcessingTask
from services.propagation_task_runner import run_propagate_project_task
def _create_project_and_frame(client):
project = client.post("/api/projects", json={"name": "AI Project"}).json()
frame = client.post(f"/api/projects/{project['id']}/frames", json={
"project_id": project["id"],
"frame_index": 0,
"image_url": "frames/0.jpg",
"width": 640,
"height": 360,
}).json()
template = client.post("/api/templates", json={
"name": "Template",
"color": "#06b6d4",
"z_index": 0,
"classes": [],
"rules": [],
}).json()
return project, frame, template
def test_predict_accepts_point_object_with_labels(client, monkeypatch):
    """A point prompt given as ``{"points": ..., "labels": ...}`` reaches SAM unchanged."""
    _, frame, _ = _create_project_and_frame(client)
    recorded = {}

    def blank_image(frame):
        return np.zeros((10, 10, 3), dtype=np.uint8)

    def fake_predict_points(model, image, points, labels):
        # Only the prompt arguments are inspected; the model id is ignored here.
        recorded["args"] = (points, labels)
        triangle = [[0.1, 0.1], [0.9, 0.1], [0.9, 0.9]]
        return [triangle], [0.95]

    monkeypatch.setattr("routers.ai._load_frame_image", blank_image)
    monkeypatch.setattr("routers.ai.sam_registry.predict_points", fake_predict_points)

    prompt = {"points": [[0.5, 0.5], [0.1, 0.1]], "labels": [1, 0]}
    response = client.post("/api/ai/predict", json={
        "image_id": frame["id"],
        "prompt_type": "point",
        "prompt_data": prompt,
    })

    assert response.status_code == 200
    assert response.json()["scores"] == [0.95]
    assert recorded["args"] == ([[0.5, 0.5], [0.1, 0.1]], [1, 0])
def test_predict_applies_crop_and_background_filter_options(client, monkeypatch):
    """Prediction options crop the model input and drop the low-score mask.

    The fake SAM returns two masks scored 0.9 and 0.01. With ``min_score`` 0.05
    only the 0.9 mask is expected back, and every returned coordinate must stay
    normalized to ``[0, 1]``.
    """
    _, frame, _ = _create_project_and_frame(client)
    seen = {}

    def fake_predict_points(model, image, points, labels):
        seen["shape"] = image.shape
        seen["points"] = points
        seen["labels"] = labels
        corner_polygon = [[0.0, 0.0], [0.2, 0.0], [0.2, 0.2]]
        center_polygon = [[0.45, 0.45], [0.55, 0.45], [0.55, 0.55]]
        return [corner_polygon, center_polygon], [0.9, 0.01]

    monkeypatch.setattr(
        "routers.ai._load_frame_image",
        lambda frame: np.zeros((100, 200, 3), dtype=np.uint8),
    )
    monkeypatch.setattr("routers.ai.sam_registry.predict_points", fake_predict_points)

    request_body = {
        "image_id": frame["id"],
        "prompt_type": "point",
        "prompt_data": {"points": [[0.5, 0.5], [0.52, 0.52]], "labels": [1, 0]},
        "options": {
            "crop_to_prompt": True,
            "crop_margin": 0.1,
            "auto_filter_background": True,
            "min_score": 0.05,
        },
    }
    response = client.post("/api/ai/predict", json=request_body)

    assert response.status_code == 200
    # The 100x200 source frame must have been cropped before reaching the model.
    cropped_height, cropped_width = seen["shape"][0], seen["shape"][1]
    assert cropped_height < 100
    assert cropped_width < 200
    assert seen["labels"] == [1, 0]
    payload = response.json()
    assert payload["scores"] == [0.9]
    first_polygon = payload["polygons"][0]
    assert all(0.0 <= value <= 1.0 for vertex in first_polygon for value in vertex)
def test_predict_box_and_rejects_semantic_prompt(client, monkeypatch):
    """A box prompt is served; a semantic prompt targeting ``sam3`` is rejected with 400."""
    _, frame, _ = _create_project_and_frame(client)
    monkeypatch.setattr(
        "routers.ai._load_frame_image",
        lambda frame: np.zeros((10, 10, 3), dtype=np.uint8),
    )

    def fake_predict_box(model, image, box):
        return [[[0.2, 0.2], [0.8, 0.2], [0.8, 0.8]]], [0.8]

    monkeypatch.setattr("routers.ai.sam_registry.predict_box", fake_predict_box)

    def predict(**fields):
        # Every request targets the same frame; callers supply the prompt fields.
        return client.post("/api/ai/predict", json={"image_id": frame["id"], **fields})

    box_response = predict(prompt_type="box", prompt_data=[0.2, 0.2, 0.8, 0.8])
    semantic_response = predict(
        prompt_type="semantic",
        prompt_data="胆囊",
        model="sam3",
        options={"min_score": 0.05},
    )

    assert box_response.status_code == 200
    assert box_response.json()["scores"] == [0.8]
    assert semantic_response.status_code == 400
    assert "Unsupported model: sam3" in semantic_response.json()["detail"]
def test_predict_interactive_combines_box_and_points(client, monkeypatch):
    """An interactive prompt forwards box, points, labels and model id in one call."""
    _, frame, _ = _create_project_and_frame(client)
    monkeypatch.setattr(
        "routers.ai._load_frame_image",
        lambda frame: np.zeros((10, 10, 3), dtype=np.uint8),
    )
    received = {}

    def fake_predict_interactive(model, image, box, points, labels):
        received.update(model=model, box=box, points=points, labels=labels)
        return [[[0.2, 0.2], [0.8, 0.2], [0.8, 0.8]]], [0.88]

    monkeypatch.setattr("routers.ai.sam_registry.predict_interactive", fake_predict_interactive)

    interactive_prompt = {
        "box": [0.1, 0.1, 0.9, 0.9],
        "points": [[0.5, 0.5], [0.2, 0.2]],
        "labels": [1, 0],
    }
    response = client.post("/api/ai/predict", json={
        "image_id": frame["id"],
        "prompt_type": "interactive",
        "prompt_data": interactive_prompt,
        "model": "sam2.1_hiera_small",
    })

    assert response.status_code == 200
    assert response.json()["scores"] == [0.88]
    expected_call = {
        "model": "sam2.1_hiera_small",
        "box": [0.1, 0.1, 0.9, 0.9],
        "points": [[0.5, 0.5], [0.2, 0.2]],
        "labels": [1, 0],
    }
    assert received == expected_call
def test_model_status_reports_runtime(client, monkeypatch):
    """``GET /api/ai/models/status`` returns the registry's runtime status payload."""

    def fake_runtime_status(selected_model=None):
        # A fresh dict per call, so nothing the router does can leak between calls.
        cpu_only_gpu = {
            "available": False,
            "device": "cpu",
            "name": None,
            "torch_available": True,
            "torch_version": "2.x",
            "cuda_version": None,
        }
        tiny_model = {
            "id": "sam2.1_hiera_tiny",
            "label": "SAM 2.1 Tiny",
            "available": True,
            "loaded": False,
            "device": "cpu",
            "supports": ["point", "box", "auto"],
            "message": "ready",
            "package_available": True,
            "checkpoint_exists": True,
            "checkpoint_path": "model.pt",
            "python_ok": True,
            "torch_ok": True,
            "cuda_required": False,
        }
        return {
            "selected_model": "sam2.1_hiera_tiny",
            "gpu": cpu_only_gpu,
            "models": [tiny_model],
        }

    monkeypatch.setattr("routers.ai.sam_registry.runtime_status", fake_runtime_status)

    response = client.get("/api/ai/models/status")

    assert response.status_code == 200
    status = response.json()
    assert status["selected_model"] == "sam2.1_hiera_tiny"
    models = status["models"]
    assert len(models) == 1
    assert models[0]["id"] == "sam2.1_hiera_tiny"
def test_model_status_rejects_disabled_sam3(client):
    """Requesting status for the disabled ``sam3`` model yields HTTP 400."""
    status_url = "/api/ai/models/status?selected_model=sam3"
    response = client.get(status_url)
    assert response.status_code == 400
    detail = response.json()["detail"]
    assert "Unsupported model" in detail
def test_analyze_mask_returns_backend_geometry_properties(client):
    """Mask analysis reports the model score, topology anchors and a positive area."""
    _, frame, _ = _create_project_and_frame(client)
    square = [[0.1, 0.1], [0.3, 0.1], [0.3, 0.3], [0.1, 0.3]]
    request_body = {
        "frame_id": frame["id"],
        "mask_data": {
            "polygons": [square],
            "source": "sam2.1_hiera_tiny",
            "score": 0.87,
        },
        "extract_skeleton": True,
    }

    response = client.post("/api/ai/analyze-mask", json=request_body)

    assert response.status_code == 200
    analysis = response.json()
    assert analysis["confidence"] == 0.87
    assert analysis["confidence_source"] == "model_score"
    # The 4-vertex square is expected to yield 4 topology anchors.
    assert analysis["topology_anchor_count"] == 4
    assert analysis["area"] > 0
    assert analysis["message"] == "已从后端重新提取几何拓扑锚点"
def test_propagate_saves_tracked_annotations(client, monkeypatch):
    """Synchronous propagation saves tracked masks for the non-source frames.

    The fake ``propagate_video`` returns masks for frames 0 and 1. With
    ``include_source`` disabled, only the frame-1 mask should be persisted, and
    it must be the only annotation listed for the project.
    """
    project = client.post("/api/projects", json={"name": "Video Project"}).json()
    frames = [
        client.post(f"/api/projects/{project['id']}/frames", json={
            "project_id": project["id"],
            "frame_index": idx,
            "image_url": f"frames/{idx}.jpg",
            "width": 640,
            "height": 360,
        }).json()
        for idx in range(3)
    ]
    calls = {}
    monkeypatch.setattr("routers.ai.download_file", lambda object_name: b"jpeg")

    def fake_propagate_video(model, frame_paths, source_frame_index, seed, direction, max_frames):
        # Capture what the router forwards so the assertions below can inspect it.
        calls["model"] = model
        calls["source_frame_index"] = source_frame_index
        calls["seed"] = seed
        calls["direction"] = direction
        calls["max_frames"] = max_frames
        calls["frame_count"] = len(frame_paths)
        return [
            {
                "frame_index": 0,
                "polygons": [[[0.1, 0.1], [0.2, 0.1], [0.2, 0.2]]],
                "scores": [0.9],
                "object_ids": [1],
            },
            {
                "frame_index": 1,
                "polygons": [[[0.15, 0.15], [0.25, 0.15], [0.25, 0.25]]],
                "scores": [0.8],
                "object_ids": [1],
            },
        ]

    monkeypatch.setattr("routers.ai.sam_registry.propagate_video", fake_propagate_video)
    response = client.post("/api/ai/propagate", json={
        "project_id": project["id"],
        "frame_id": frames[0]["id"],
        "model": "sam2.1_hiera_tiny",
        "direction": "forward",
        "max_frames": 2,
        "include_source": False,
        "seed": {
            "polygons": [[[0.1, 0.1], [0.2, 0.1], [0.2, 0.2]]],
            "bbox": [0.1, 0.1, 0.1, 0.1],
            "label": "胆囊",
            "color": "#ff0000",
            "class_metadata": {"id": "c1", "name": "胆囊", "color": "#ff0000", "zIndex": 20},
            "template_id": None,
        },
    })
    assert response.status_code == 200
    body = response.json()
    assert body["created_annotation_count"] == 1
    assert body["processed_frame_count"] == 2
    assert calls["model"] == "sam2.1_hiera_tiny"
    assert calls["source_frame_index"] == 0
    assert calls["direction"] == "forward"
    assert calls["frame_count"] == 2
    saved = body["annotations"][0]
    assert saved["frame_id"] == frames[1]["id"]
    assert saved["mask_data"]["source"] == "sam2.1_hiera_tiny_propagation"
    assert saved["mask_data"]["class"]["name"] == "胆囊"
    assert saved["mask_data"]["score"] == 0.8
    listing = client.get(f"/api/ai/annotations?project_id={project['id']}")
    # Check the status first so a failed listing is reported clearly instead of
    # as a confusing length mismatch.
    assert listing.status_code == 200
    listed = listing.json()
    assert len(listed) == 1
    # The persisted annotation must be the tracked frame-1 result, not the source frame.
    assert listed[0]["frame_id"] == frames[1]["id"]
def test_queue_propagation_task_creates_processing_task(client, monkeypatch):
    """Queueing propagation returns 202, stores the payload and dispatches the task id."""
    project = client.post("/api/projects", json={"name": "Queued Propagation"}).json()
    frame_payload = {
        "project_id": project["id"],
        "frame_index": 0,
        "image_url": "frames/0.jpg",
        "width": 640,
        "height": 360,
    }
    frame = client.post(f"/api/projects/{project['id']}/frames", json=frame_payload).json()
    dispatched_ids = []

    class FakeAsyncResult:
        id = "celery-propagate-1"

    def fake_delay(task_id):
        dispatched_ids.append(task_id)
        return FakeAsyncResult()

    monkeypatch.setattr("routers.ai.propagate_project_masks.delay", fake_delay)
    monkeypatch.setattr("routers.ai.publish_task_progress_event", lambda task: None)

    seed = {
        "polygons": [[[0.1, 0.1], [0.2, 0.1], [0.2, 0.2]]],
        "label": "胆囊",
    }
    response = client.post("/api/ai/propagate/task", json={
        "project_id": project["id"],
        "frame_id": frame["id"],
        "model": "sam2.1_hiera_tiny",
        "steps": [{"direction": "forward", "max_frames": 2, "seed": seed}],
    })

    assert response.status_code == 202
    queued_task = response.json()
    assert queued_task["task_type"] == "propagate_masks"
    assert queued_task["status"] == "queued"
    assert queued_task["celery_task_id"] == "celery-propagate-1"
    assert queued_task["payload"]["model"] == "sam2.1_hiera_tiny"
    assert queued_task["payload"]["steps"][0]["seed"]["label"] == "胆囊"
    assert dispatched_ids == [queued_task["id"]]
def test_queue_propagation_task_normalizes_model_and_rejects_unsupported(client, monkeypatch):
    """The ``sam2`` alias is stored as ``sam2.1_hiera_tiny``; ``sam3`` is refused with 400."""
    project = client.post("/api/projects", json={"name": "Propagation Model"}).json()
    frame = client.post(f"/api/projects/{project['id']}/frames", json={
        "project_id": project["id"],
        "frame_index": 0,
        "image_url": "frames/0.jpg",
        "width": 640,
        "height": 360,
    }).json()

    class FakeAsyncResult:
        id = "celery-propagate-model"

    monkeypatch.setattr("routers.ai.propagate_project_masks.delay", lambda task_id: FakeAsyncResult())
    monkeypatch.setattr("routers.ai.publish_task_progress_event", lambda task: None)

    def queue_with_model(model_name):
        # Both requests share one forward step; only the requested model differs.
        return client.post("/api/ai/propagate/task", json={
            "project_id": project["id"],
            "frame_id": frame["id"],
            "model": model_name,
            "steps": [{
                "direction": "forward",
                "max_frames": 2,
                "seed": {"polygons": [[[0.1, 0.1], [0.2, 0.1], [0.2, 0.2]]]},
            }],
        })

    aliased = queue_with_model("sam2")
    assert aliased.status_code == 202
    assert aliased.json()["payload"]["model"] == "sam2.1_hiera_tiny"

    unsupported = queue_with_model("sam3")
    assert unsupported.status_code == 400
    assert "Unsupported model" in unsupported.json()["detail"]
def test_propagation_task_runner_saves_annotations_and_progress(client, db_session, monkeypatch):
    """Run a queued propagation task in-process and check results and progress events.

    The fake ``propagate_video`` returns masks for frames 0 and 1. Because
    ``include_source`` is False, only the frame-1 mask should be saved. The task
    should end in ``success`` at 100% progress, after a ``running`` event was
    published first.
    """
    project = client.post("/api/projects", json={"name": "Propagation Worker"}).json()
    frames = [
        client.post(f"/api/projects/{project['id']}/frames", json={
            "project_id": project["id"],
            "frame_index": idx,
            "image_url": f"frames/{idx}.jpg",
            "width": 640,
            "height": 360,
        }).json()
        for idx in range(2)
    ]
    task = ProcessingTask(
        task_type="propagate_masks",
        status="queued",
        progress=0,
        project_id=project["id"],
        payload={
            "project_id": project["id"],
            "frame_id": frames[0]["id"],
            "model": "sam2.1_hiera_tiny",
            "include_source": False,
            "save_annotations": True,
            "steps": [{
                "direction": "forward",
                "max_frames": 2,
                "seed": {
                    "polygons": [[[0.1, 0.1], [0.2, 0.1], [0.2, 0.2]]],
                    "label": "胆囊",
                    "color": "#ff0000",
                    "class_metadata": {"id": "c1", "name": "胆囊"},
                },
            }],
        },
    )
    db_session.add(task)
    db_session.commit()
    db_session.refresh(task)
    # Record (status, progress) for every progress event the runner publishes.
    published = []
    monkeypatch.setattr("services.propagation_task_runner.download_file", lambda object_name: b"jpeg")
    monkeypatch.setattr(
        "services.propagation_task_runner.publish_task_progress_event",
        lambda event_task: published.append((event_task.status, event_task.progress)),
    )

    def fake_propagate_video(model, frame_paths, source_frame_index, seed, direction, max_frames):
        # The runner is expected to hand SAM zero-padded per-frame JPEG paths.
        assert [Path(path).name for path in frame_paths] == ["000000.jpg", "000001.jpg"]
        return [
            {"frame_index": 0, "polygons": [[[0.1, 0.1], [0.2, 0.1], [0.2, 0.2]]], "scores": [0.9]},
            {"frame_index": 1, "polygons": [[[0.15, 0.15], [0.25, 0.15], [0.25, 0.25]]], "scores": [0.8]},
        ]

    monkeypatch.setattr("services.propagation_task_runner.sam_registry.propagate_video", fake_propagate_video)
    result = run_propagate_project_task(db_session, task.id)
    db_session.refresh(task)
    assert task.status == "success"
    assert task.progress == 100
    assert task.result["model"] == "sam2.1_hiera_tiny"
    assert task.result["steps"][0]["model"] == "sam2.1_hiera_tiny"
    assert result["created_annotation_count"] == 1
    assert result["processed_frame_count"] == 2
    assert published[0][0] == "running"
    assert published[-1] == ("success", 100)
    listing = client.get(f"/api/ai/annotations?project_id={project['id']}")
    assert listing.status_code == 200
    saved = listing.json()
    # Guard the index below: exactly one annotation (the frame-1 result) should
    # exist, so an empty listing or an extra source-frame save fails clearly.
    assert len(saved) == 1
    assert saved[0]["frame_id"] == frames[1]["id"]
    assert saved[0]["mask_data"]["source"] == "sam2.1_hiera_tiny_propagation"
def test_propagation_task_runner_skips_unchanged_seed_and_replaces_changed_seed(client, db_session, monkeypatch):
    """An identical seed is skipped on re-run; a changed seed replaces the old result.

    Three tasks run for the same ``source_annotation_id`` (7):
    1. The first creates one annotation.
    2. The second uses the same seed and the same weights. It must be skipped
       without invoking SAM.
    3. The third uses a different seed polygon. It must delete the stale
       annotation and save the new one.
    """
    project = client.post("/api/projects", json={"name": "Propagation Dedupe"}).json()
    frames = [
        client.post(f"/api/projects/{project['id']}/frames", json={
            "project_id": project["id"],
            "frame_index": idx,
            "image_url": f"frames/{idx}.jpg",
            "width": 640,
            "height": 360,
        }).json()
        for idx in range(2)
    ]

    def make_task(seed_polygon):
        # Persist a queued propagation task whose only step uses the given seed polygon.
        seed = {
            "polygons": [seed_polygon],
            "label": "胆囊",
            "color": "#ff0000",
            "source_annotation_id": 7,
            "source_mask_id": "annotation-7",
        }
        task = ProcessingTask(
            task_type="propagate_masks",
            status="queued",
            progress=0,
            project_id=project["id"],
            payload={
                "project_id": project["id"],
                "frame_id": frames[0]["id"],
                "model": "sam2.1_hiera_tiny",
                "include_source": False,
                "save_annotations": True,
                "steps": [{"direction": "forward", "max_frames": 2, "seed": seed}],
            },
        )
        db_session.add(task)
        db_session.commit()
        db_session.refresh(task)
        return task

    seed_polygon = [[0.1, 0.1], [0.2, 0.1], [0.2, 0.2]]
    first_output_polygon = [[0.15, 0.15], [0.25, 0.15], [0.25, 0.25]]
    changed_seed_polygon = [[0.2, 0.2], [0.3, 0.2], [0.3, 0.3]]
    replacement_output_polygon = [[0.22, 0.22], [0.32, 0.22], [0.32, 0.32]]
    monkeypatch.setattr("services.propagation_task_runner.download_file", lambda object_name: b"jpeg")
    monkeypatch.setattr("services.propagation_task_runner.publish_task_progress_event", lambda event_task: None)
    seen_seeds = []

    def fake_propagate_video(model, frame_paths, source_frame_index, seed, direction, max_frames):
        seed_shape = seed["polygons"][0]
        seen_seeds.append(seed_shape)
        if seed_shape == changed_seed_polygon:
            tracked_shape = replacement_output_polygon
        else:
            tracked_shape = first_output_polygon
        return [
            {"frame_index": 0, "polygons": [seed_shape], "scores": [0.9]},
            {"frame_index": 1, "polygons": [tracked_shape], "scores": [0.8]},
        ]

    monkeypatch.setattr("services.propagation_task_runner.sam_registry.propagate_video", fake_propagate_video)

    def project_annotations():
        return db_session.query(Annotation).filter(Annotation.project_id == project["id"])

    first_result = run_propagate_project_task(db_session, make_task(seed_polygon).id)
    assert first_result["created_annotation_count"] == 1
    assert len(seen_seeds) == 1

    unchanged_result = run_propagate_project_task(db_session, make_task(seed_polygon).id)
    assert unchanged_result["created_annotation_count"] == 0
    assert unchanged_result["skipped_seed_count"] == 1
    # SAM must not have been invoked for the unchanged seed.
    assert len(seen_seeds) == 1
    assert project_annotations().count() == 1

    changed_result = run_propagate_project_task(db_session, make_task(changed_seed_polygon).id)
    assert changed_result["created_annotation_count"] == 1
    assert changed_result["deleted_annotation_count"] == 1
    assert len(seen_seeds) == 2
    remaining = project_annotations().all()
    assert len(remaining) == 1
    assert remaining[0].mask_data["polygons"] == [replacement_output_polygon]
    assert remaining[0].mask_data["source_annotation_id"] == 7
def test_propagation_task_runner_replaces_legacy_or_different_weight_results(client, db_session, monkeypatch):
    """A stale result from a temporary seed key or other weights is replaced.

    A pre-existing frame-1 annotation carries a temporary front-end seed key and
    was produced by ``sam2.1_hiera_tiny``. Re-propagating the stable seed
    (``source_annotation_id`` 7) with ``sam2.1_hiera_small`` must delete it and
    leave exactly one fresh result.
    """
    project = client.post("/api/projects", json={"name": "Propagation Legacy Cleanup"}).json()
    frames = [
        client.post(f"/api/projects/{project['id']}/frames", json={
            "project_id": project["id"],
            "frame_index": idx,
            "image_url": f"frames/{idx}.jpg",
            "width": 640,
            "height": 360,
        }).json()
        for idx in range(2)
    ]
    seed_polygon = [[0.1, 0.1], [0.2, 0.1], [0.2, 0.2]]
    output_polygon = [[0.15, 0.15], [0.25, 0.15], [0.25, 0.25]]

    legacy_mask = {
        "polygons": [[[0.12, 0.12], [0.22, 0.12], [0.22, 0.22]]],
        "label": "胆囊",
        "color": "#ff0000",
        "source": "sam2.1_hiera_tiny_propagation",
        "propagated_from_frame_id": frames[0]["id"],
        "propagation_seed_key": "mask:temporary-front-end-id",
        "propagation_direction": "forward",
    }
    db_session.add(Annotation(
        project_id=project["id"],
        frame_id=frames[1]["id"],
        mask_data=legacy_mask,
        bbox=[0.12, 0.12, 0.1, 0.1],
    ))
    db_session.commit()

    stable_seed = {
        "polygons": [seed_polygon],
        "label": "胆囊",
        "color": "#ff0000",
        "source_annotation_id": 7,
        "source_mask_id": "annotation-7",
    }
    task = ProcessingTask(
        task_type="propagate_masks",
        status="queued",
        progress=0,
        project_id=project["id"],
        payload={
            "project_id": project["id"],
            "frame_id": frames[0]["id"],
            "model": "sam2.1_hiera_small",
            "include_source": False,
            "save_annotations": True,
            "steps": [{"direction": "forward", "max_frames": 2, "seed": stable_seed}],
        },
    )
    db_session.add(task)
    db_session.commit()
    db_session.refresh(task)

    monkeypatch.setattr("services.propagation_task_runner.download_file", lambda object_name: b"jpeg")
    monkeypatch.setattr("services.propagation_task_runner.publish_task_progress_event", lambda event_task: None)

    def fake_propagate_video(model, frame_paths, source_frame_index, seed, direction, max_frames):
        return [
            {"frame_index": 0, "polygons": [seed["polygons"][0]], "scores": [0.9]},
            {"frame_index": 1, "polygons": [output_polygon], "scores": [0.8]},
        ]

    monkeypatch.setattr("services.propagation_task_runner.sam_registry.propagate_video", fake_propagate_video)

    result = run_propagate_project_task(db_session, task.id)

    assert result["created_annotation_count"] == 1
    assert result["deleted_annotation_count"] == 1
    survivors = db_session.query(Annotation).filter(Annotation.project_id == project["id"]).all()
    assert len(survivors) == 1
    survivor_mask = survivors[0].mask_data
    assert survivor_mask["source"] == "sam2.1_hiera_small_propagation"
    assert survivor_mask["source_annotation_id"] == 7
    assert survivor_mask["polygons"] == [output_polygon]
def test_predict_validation_errors(client, monkeypatch):
    """Predict returns 404 for an unknown frame and 400 for a two-value box prompt."""
    project, _, _ = _create_project_and_frame(client)

    missing_frame_response = client.post("/api/ai/predict", json={
        "image_id": 999,
        "prompt_type": "point",
        "prompt_data": [[0.5, 0.5]],
    })
    assert missing_frame_response.status_code == 404

    # A second frame registered without width/height.
    extra_frame = client.post(f"/api/projects/{project['id']}/frames", json={
        "project_id": project["id"],
        "frame_index": 1,
        "image_url": "frames/1.jpg",
    }).json()
    monkeypatch.setattr(
        "routers.ai._load_frame_image",
        lambda frame: np.zeros((10, 10, 3), dtype=np.uint8),
    )
    short_box_response = client.post("/api/ai/predict", json={
        "image_id": extra_frame["id"],
        "prompt_type": "box",
        "prompt_data": [0.1, 0.2],
    })
    assert short_box_response.status_code == 400
def test_save_annotation_validates_project_and_frame(client):
    """Annotations can be saved and listed; unknown project or frame ids give 404."""
    project, frame, template = _create_project_and_frame(client)
    create_response = client.post("/api/ai/annotate", json={
        "project_id": project["id"],
        "frame_id": frame["id"],
        "template_id": template["id"],
        "mask_data": {"polygons": [[[0.1, 0.1], [0.9, 0.1], [0.9, 0.9]]]},
        "points": [[0.5, 0.5]],
        "bbox": [0.1, 0.1, 0.8, 0.8],
    })
    assert create_response.status_code == 201
    created = create_response.json()
    assert created["project_id"] == project["id"]

    project_listing_url = f"/api/ai/annotations?project_id={project['id']}"
    listing = client.get(project_listing_url)
    assert listing.status_code == 200
    assert listing.json()[0]["id"] == created["id"]

    frame_listing = client.get(f"{project_listing_url}&frame_id={frame['id']}")
    assert frame_listing.status_code == 200
    assert len(frame_listing.json()) == 1

    # Unknown ids are rejected on both create and list.
    assert client.post("/api/ai/annotate", json={"project_id": 999}).status_code == 404
    missing_frame = client.post("/api/ai/annotate", json={
        "project_id": project["id"],
        "frame_id": 999,
    })
    assert missing_frame.status_code == 404
    assert client.get("/api/ai/annotations?project_id=999").status_code == 404
def test_update_and_delete_annotation(client):
    """PATCH replaces an annotation's fields; DELETE removes it from listings."""
    project, frame, template = _create_project_and_frame(client)
    listing_url = f"/api/ai/annotations?project_id={project['id']}"
    original = client.post("/api/ai/annotate", json={
        "project_id": project["id"],
        "frame_id": frame["id"],
        "template_id": template["id"],
        "mask_data": {
            "polygons": [[[0.1, 0.1], [0.9, 0.1], [0.9, 0.9]]],
            "label": "AI Mask",
            "color": "#06b6d4",
        },
        "points": [[0.5, 0.5]],
        "bbox": [0.1, 0.1, 0.8, 0.8],
    }).json()
    annotation_url = f"/api/ai/annotations/{original['id']}"

    patch_body = {
        "template_id": template["id"],
        "mask_data": {
            "polygons": [[[0.2, 0.2], [0.8, 0.2], [0.8, 0.8]]],
            "label": "胆囊",
            "color": "#ff0000",
            "class": {"id": "c1", "name": "胆囊", "color": "#ff0000", "zIndex": 20},
        },
        "points": [[0.4, 0.4]],
        "bbox": [0.2, 0.2, 0.6, 0.6],
    }
    patched = client.patch(annotation_url, json=patch_body)
    assert patched.status_code == 200
    patched_body = patched.json()
    assert patched_body["mask_data"]["label"] == "胆囊"
    assert patched_body["mask_data"]["class"]["id"] == "c1"
    assert patched_body["points"] == [[0.4, 0.4]]
    assert patched_body["bbox"] == [0.2, 0.2, 0.6, 0.6]

    after_patch = client.get(listing_url)
    assert after_patch.status_code == 200
    assert after_patch.json()[0]["mask_data"]["class"]["name"] == "胆囊"

    assert client.delete(annotation_url).status_code == 204
    after_delete = client.get(listing_url)
    assert after_delete.status_code == 200
    assert after_delete.json() == []
def test_update_and_delete_annotation_validation(client):
    """PATCH/DELETE of unknown annotations, and PATCH to an unknown template, return 404."""
    project, frame, template = _create_project_and_frame(client)
    create_response = client.post("/api/ai/annotate", json={
        "project_id": project["id"],
        "frame_id": frame["id"],
        "template_id": template["id"],
    })
    # Fail fast if the setup annotation could not be created, instead of an
    # opaque KeyError on ``saved['id']`` further down.
    assert create_response.status_code == 201
    saved = create_response.json()
    assert client.patch("/api/ai/annotations/999", json={"bbox": [0, 0, 1, 1]}).status_code == 404
    assert client.delete("/api/ai/annotations/999").status_code == 404
    assert client.patch(
        f"/api/ai/annotations/{saved['id']}",
        json={"template_id": 999},
    ).status_code == 404
def test_import_gt_mask_creates_annotations_with_seed_points(client):
    """A single-value GT mask PNG becomes one annotation with one normalized seed point."""
    project, frame, template = _create_project_and_frame(client)
    gt_mask = np.zeros((360, 640), dtype=np.uint8)
    # Filled rectangle with label value 255 spanning (100, 80)-(260, 220).
    cv2.rectangle(gt_mask, (100, 80), (260, 220), 255, thickness=-1)
    ok, encoded = cv2.imencode(".png", gt_mask)
    assert ok

    form_fields = {
        "project_id": str(project["id"]),
        "frame_id": str(frame["id"]),
        "template_id": str(template["id"]),
        "label": "Imported GT",
        "color": "#22c55e",
    }
    upload = {"file": ("mask.png", encoded.tobytes(), "image/png")}
    response = client.post("/api/ai/import-gt-mask", data=form_fields, files=upload)

    assert response.status_code == 201
    imported = response.json()
    assert len(imported) == 1
    annotation = imported[0]
    assert annotation["project_id"] == project["id"]
    assert annotation["frame_id"] == frame["id"]
    assert annotation["template_id"] == template["id"]
    mask_data = annotation["mask_data"]
    assert mask_data["label"] == "Imported GT"
    assert mask_data["source"] == "gt_mask"
    assert mask_data["gt_label_value"] == 255
    assert len(mask_data["polygons"][0]) >= 3
    seed_points = annotation["points"]
    assert len(seed_points) == 1
    assert 0.0 <= seed_points[0][0] <= 1.0
    assert 0.0 <= seed_points[0][1] <= 1.0
def test_import_gt_mask_splits_label_values(client):
    """A multi-label PNG is split into one annotation per label value (1 and 2 here)."""
    project, frame, _ = _create_project_and_frame(client)
    label_map = np.zeros((360, 640), dtype=np.uint8)
    cv2.rectangle(label_map, (20, 20), (120, 120), 1, thickness=-1)
    cv2.rectangle(label_map, (220, 80), (320, 180), 2, thickness=-1)
    ok, encoded = cv2.imencode(".png", label_map)
    assert ok

    response = client.post(
        "/api/ai/import-gt-mask",
        data={
            "project_id": str(project["id"]),
            "frame_id": str(frame["id"]),
            "label": "GT Class",
        },
        files={"file": ("labels.png", encoded.tobytes(), "image/png")},
    )

    assert response.status_code == 201
    # The response order is not asserted, so sort by label value before comparing.
    by_value = sorted(response.json(), key=lambda entry: entry["mask_data"]["gt_label_value"])
    assert [entry["mask_data"]["gt_label_value"] for entry in by_value] == [1, 2]
    assert [entry["mask_data"]["label"] for entry in by_value] == ["GT Class 1", "GT Class 2"]
    assert all(len(entry["points"]) == 1 for entry in by_value)