- 在传播结果写入前的空间清理中识别同一任务里的其它 seed,避免后续同类 seed 删除前面刚写入的传播结果。 - 保留旧结果替换逻辑:不属于本次其它 seed 的旧传播结果仍可按语义和空间重叠被新 seed 替换。 - 将同类多实例传播回归测试改为重叠输出场景,覆盖此前两个同类方块只保留一个的问题。
1499 lines
56 KiB
Python
1499 lines
56 KiB
Python
import numpy as np
|
|
import cv2
|
|
from pathlib import Path
|
|
from models import Annotation, ProcessingTask
|
|
from services.propagation_task_runner import _seed_signature, run_propagate_project_task
|
|
|
|
|
|
def _create_project_and_frame(client):
    """Create a project, a single frame in it, and a template via the API.

    Returns the decoded JSON bodies as a ``(project, frame, template)`` tuple.
    """
    project = client.post("/api/projects", json={"name": "AI Project"}).json()

    frame_payload = {
        "project_id": project["id"],
        "frame_index": 0,
        "image_url": "frames/0.jpg",
        "width": 640,
        "height": 360,
    }
    frame_response = client.post(f"/api/projects/{project['id']}/frames", json=frame_payload)
    frame = frame_response.json()

    template_payload = {
        "name": "Template",
        "color": "#06b6d4",
        "z_index": 0,
        "classes": [],
        "rules": [],
    }
    template_response = client.post("/api/templates", json=template_payload)
    template = template_response.json()

    return project, frame, template
|
|
|
|
|
|
def test_predict_accepts_point_object_with_labels(client, monkeypatch):
    """Point prompts sent as a ``{"points", "labels"}`` object reach the predictor.

    The frame loader and ``sam_registry.predict_points`` are stubbed; the test
    checks the returned scores and the points/labels the stub received.
    """
    _, frame, _ = _create_project_and_frame(client)
    recorded = {}

    def load_blank_image(frame):
        return np.zeros((10, 10, 3), dtype=np.uint8)

    def stub_predict_points(model, image, points, labels):
        # Only the prompt arguments are recorded; model and image are ignored.
        recorded["args"] = (points, labels)
        triangle = [[0.1, 0.1], [0.9, 0.1], [0.9, 0.9]]
        return ([triangle], [0.95])

    monkeypatch.setattr("routers.ai._load_frame_image", load_blank_image)
    monkeypatch.setattr("routers.ai.sam_registry.predict_points", stub_predict_points)

    request_body = {
        "image_id": frame["id"],
        "prompt_type": "point",
        "prompt_data": {"points": [[0.5, 0.5], [0.1, 0.1]], "labels": [1, 0]},
    }
    response = client.post("/api/ai/predict", json=request_body)

    assert response.status_code == 200
    assert response.json()["scores"] == [0.95]
    expected_points = [[0.5, 0.5], [0.1, 0.1]]
    expected_labels = [1, 0]
    assert recorded["args"] == (expected_points, expected_labels)
|
|
|
|
|
|
def test_predict_applies_crop_and_background_filter_options(client, monkeypatch):
    """Predict with crop and background-filter options enabled.

    The stubbed predictor records the image shape it is given and returns two
    candidates (scores 0.9 and 0.01). The test expects a smaller-than-source
    image, unchanged labels, a single surviving score, and a polygon whose
    coordinates stay within [0, 1].
    """
    _, frame, _ = _create_project_and_frame(client)
    recorded = {}

    def load_blank_image(frame):
        return np.zeros((100, 200, 3), dtype=np.uint8)

    def fake_predict_points(model, image, points, labels):
        recorded.update(shape=image.shape, points=points, labels=labels)
        candidates = [
            [[0.0, 0.0], [0.2, 0.0], [0.2, 0.2]],
            [[0.45, 0.45], [0.55, 0.45], [0.55, 0.55]],
        ]
        candidate_scores = [0.9, 0.01]
        return (candidates, candidate_scores)

    monkeypatch.setattr("routers.ai._load_frame_image", load_blank_image)
    monkeypatch.setattr("routers.ai.sam_registry.predict_points", fake_predict_points)

    options = {
        "crop_to_prompt": True,
        "crop_margin": 0.1,
        "auto_filter_background": True,
        "min_score": 0.05,
    }
    request_body = {
        "image_id": frame["id"],
        "prompt_type": "point",
        "prompt_data": {"points": [[0.5, 0.5], [0.52, 0.52]], "labels": [1, 0]},
        "options": options,
    }
    response = client.post("/api/ai/predict", json=request_body)

    assert response.status_code == 200
    # The source image is 100x200, so a cropped image must be smaller on both axes.
    assert recorded["shape"][0] < 100
    assert recorded["shape"][1] < 200
    assert recorded["labels"] == [1, 0]
    assert response.json()["scores"] == [0.9]
    first_polygon = response.json()["polygons"][0]
    for point in first_polygon:
        for coord in point:
            assert 0.0 <= coord <= 1.0
|
|
|
|
|
|
def test_predict_box_and_rejects_semantic_prompt(client, monkeypatch):
    """A box prompt succeeds while a semantic prompt for ``sam3`` returns 400."""
    _, frame, _ = _create_project_and_frame(client)

    def load_blank_image(frame):
        return np.zeros((10, 10, 3), dtype=np.uint8)

    def stub_predict_box(model, image, box):
        polygons = [[[0.2, 0.2], [0.8, 0.2], [0.8, 0.8]]]
        scores = [0.8]
        return (polygons, scores)

    monkeypatch.setattr("routers.ai._load_frame_image", load_blank_image)
    monkeypatch.setattr("routers.ai.sam_registry.predict_box", stub_predict_box)

    box_request = {
        "image_id": frame["id"],
        "prompt_type": "box",
        "prompt_data": [0.2, 0.2, 0.8, 0.8],
    }
    semantic_request = {
        "image_id": frame["id"],
        "prompt_type": "semantic",
        "prompt_data": "胆囊",
        "model": "sam3",
        "options": {"min_score": 0.05},
    }
    box_response = client.post("/api/ai/predict", json=box_request)
    semantic_response = client.post("/api/ai/predict", json=semantic_request)

    assert box_response.status_code == 200
    assert box_response.json()["scores"] == [0.8]

    assert semantic_response.status_code == 400
    detail = semantic_response.json()["detail"]
    assert "Unsupported model: sam3" in detail
|
|
|
|
|
|
def test_predict_interactive_combines_box_and_points(client, monkeypatch):
    """An interactive prompt passes model, box, points and labels to the registry."""
    _, frame, _ = _create_project_and_frame(client)
    recorded = {}

    def load_blank_image(frame):
        return np.zeros((10, 10, 3), dtype=np.uint8)

    def fake_predict_interactive(model, image, box, points, labels):
        # Record everything except the image so the full call can be compared.
        recorded.update(model=model, box=box, points=points, labels=labels)
        polygons = [[[0.2, 0.2], [0.8, 0.2], [0.8, 0.8]]]
        return (polygons, [0.88])

    monkeypatch.setattr("routers.ai._load_frame_image", load_blank_image)
    monkeypatch.setattr("routers.ai.sam_registry.predict_interactive", fake_predict_interactive)

    prompt = {
        "box": [0.1, 0.1, 0.9, 0.9],
        "points": [[0.5, 0.5], [0.2, 0.2]],
        "labels": [1, 0],
    }
    request_body = {
        "image_id": frame["id"],
        "prompt_type": "interactive",
        "prompt_data": prompt,
        "model": "sam2.1_hiera_small",
    }
    response = client.post("/api/ai/predict", json=request_body)

    assert response.status_code == 200
    assert response.json()["scores"] == [0.88]
    expected_call = {
        "model": "sam2.1_hiera_small",
        "box": [0.1, 0.1, 0.9, 0.9],
        "points": [[0.5, 0.5], [0.2, 0.2]],
        "labels": [1, 0],
    }
    assert recorded == expected_call
|
|
|
|
|
|
def test_model_status_reports_runtime(client, monkeypatch):
    """The model status endpoint returns what ``runtime_status`` reports."""
    gpu_info = {
        "available": False,
        "device": "cpu",
        "name": None,
        "torch_available": True,
        "torch_version": "2.x",
        "cuda_version": None,
    }
    tiny_model = {
        "id": "sam2.1_hiera_tiny",
        "label": "SAM 2.1 Tiny",
        "available": True,
        "loaded": False,
        "device": "cpu",
        "supports": ["point", "box", "auto"],
        "message": "ready",
        "package_available": True,
        "checkpoint_exists": True,
        "checkpoint_path": "model.pt",
        "python_ok": True,
        "torch_ok": True,
        "cuda_required": False,
    }

    def fake_runtime_status(selected_model=None):
        return {
            "selected_model": "sam2.1_hiera_tiny",
            "gpu": gpu_info,
            "models": [tiny_model],
        }

    monkeypatch.setattr("routers.ai.sam_registry.runtime_status", fake_runtime_status)

    response = client.get("/api/ai/models/status")

    assert response.status_code == 200
    body = response.json()
    assert body["selected_model"] == "sam2.1_hiera_tiny"
    models = body["models"]
    assert len(models) == 1
    assert models[0]["id"] == "sam2.1_hiera_tiny"
|
|
|
|
|
|
def test_model_status_rejects_disabled_sam3(client):
    """Requesting status for ``sam3`` yields a 400 with an unsupported-model detail."""
    status_url = "/api/ai/models/status?selected_model=sam3"
    response = client.get(status_url)

    assert response.status_code == 400
    detail = response.json()["detail"]
    assert "Unsupported model" in detail
|
|
|
|
|
|
def test_analyze_mask_returns_backend_geometry_properties(client):
    """Analyzing a square mask returns confidence, anchor count, area and message."""
    _, frame, _ = _create_project_and_frame(client)

    square = [[0.1, 0.1], [0.3, 0.1], [0.3, 0.3], [0.1, 0.3]]
    mask_data = {
        "polygons": [square],
        "source": "sam2.1_hiera_tiny",
        "score": 0.87,
    }
    request_body = {
        "frame_id": frame["id"],
        "mask_data": mask_data,
        "extract_skeleton": True,
    }
    response = client.post("/api/ai/analyze-mask", json=request_body)

    assert response.status_code == 200
    body = response.json()
    # Confidence is expected to be taken from the score supplied in the mask data.
    assert body["confidence"] == 0.87
    assert body["confidence_source"] == "model_score"
    assert body["topology_anchor_count"] == 4
    assert body["area"] > 0
    assert body["message"] == "已从后端重新提取几何拓扑锚点"
|
|
|
|
|
|
def test_analyze_mask_reports_actual_polygon_anchor_count(client):
    """The anchor count equals the polygon length while returned anchors are capped at 64."""
    _, frame, _ = _create_project_and_frame(client)

    # 80-vertex zig-zag: x advances steadily, y alternates between two levels.
    polygon = []
    for index in range(80):
        y_offset = 0.01 if index % 2 else 0
        polygon.append([0.1 + index * 0.005, 0.1 + y_offset])

    mask_data = {
        "polygons": [polygon],
        "label": "AI Mask",
        "color": "#06b6d4",
    }
    request_body = {
        "frame_id": frame["id"],
        "mask_data": mask_data,
        "points": [[0.2, 0.2]],
    }
    response = client.post("/api/ai/analyze-mask", json=request_body)

    assert response.status_code == 200
    body = response.json()
    assert body["topology_anchor_count"] == len(polygon)
    assert len(body["topology_anchors"]) <= 64
|
|
|
|
|
|
def test_smooth_mask_simplifies_noisy_ai_polygon(client):
    """Smoothing a jagged 80-vertex outline at strength 80 yields fewer vertices."""
    _, frame, _ = _create_project_and_frame(client)

    def jitter(index):
        # Odd vertices are nudged by 0.01 to make the outline noisy.
        return 0.01 if index % 2 else 0

    steps = range(20)
    # Four jagged edges traced in order: top, right, bottom, left.
    polygon = [[0.1 + index * 0.02, 0.1 + jitter(index)] for index in steps]
    polygon += [[0.5 + jitter(index), 0.1 + index * 0.02] for index in steps]
    polygon += [[0.5 - index * 0.02, 0.5 + jitter(index)] for index in steps]
    polygon += [[0.1 + jitter(index), 0.5 - index * 0.02] for index in steps]

    mask_data = {
        "polygons": [polygon],
        "label": "AI Mask",
        "color": "#06b6d4",
    }
    request_body = {
        "frame_id": frame["id"],
        "mask_data": mask_data,
        "strength": 80,
    }
    response = client.post("/api/ai/smooth-mask", json=request_body)

    assert response.status_code == 200
    body = response.json()
    smoothed = body["polygons"][0]
    assert body["topology_anchor_count"] == len(smoothed)
    assert len(smoothed) < len(polygon)
|
|
|
|
|
|
def test_smooth_mask_uses_eased_strength_curve(client):
    """Higher smoothing strength gives strictly fewer vertices at 20, 70 and 95."""
    _, frame, _ = _create_project_and_frame(client)

    def jitter(index):
        # Odd vertices are nudged by 0.01 to make the outline noisy.
        return 0.01 if index % 2 else 0

    steps = range(20)
    # Four jagged edges traced in order: top, right, bottom, left.
    polygon = [[0.1 + index * 0.02, 0.1 + jitter(index)] for index in steps]
    polygon += [[0.5 + jitter(index), 0.1 + index * 0.02] for index in steps]
    polygon += [[0.5 - index * 0.02, 0.5 + jitter(index)] for index in steps]
    polygon += [[0.1 + jitter(index), 0.5 - index * 0.02] for index in steps]

    def smoothed_count(strength: int) -> int:
        """Smooth the shared polygon at ``strength`` and return its vertex count."""
        request_body = {
            "frame_id": frame["id"],
            "mask_data": {
                "polygons": [polygon],
                "label": "AI Mask",
                "color": "#06b6d4",
            },
            "strength": strength,
        }
        response = client.post("/api/ai/smooth-mask", json=request_body)
        assert response.status_code == 200
        first_polygon = response.json()["polygons"][0]
        return len(first_polygon)

    low_count = smoothed_count(20)
    mid_count = smoothed_count(70)
    high_count = smoothed_count(95)

    assert low_count <= len(polygon)
    assert mid_count < low_count
    assert high_count < mid_count
|
|
|
|
|
|
def test_smooth_mask_returns_backend_smoothed_geometry(client):
    """Smoothing a square at strength 45 reports the smoothing and adds vertices."""
    _, frame, _ = _create_project_and_frame(client)

    square = [[0.1, 0.1], [0.3, 0.1], [0.3, 0.3], [0.1, 0.3]]
    mask_data = {
        "polygons": [square],
        "label": "胆囊",
        "color": "#ff0000",
    }
    request_body = {
        "frame_id": frame["id"],
        "mask_data": mask_data,
        "strength": 45,
    }
    response = client.post("/api/ai/smooth-mask", json=request_body)

    assert response.status_code == 200
    body = response.json()
    assert body["smoothing"] == {"strength": 45.0, "method": "chaikin"}
    polygons = body["polygons"]
    assert len(polygons) == 1
    # The input has 4 vertices; the smoothed outline must have more.
    assert len(polygons[0]) > 4
    assert body["topology_anchor_count"] > 0
    assert body["message"] == "已应用边缘平滑强度 45"
|
|
|
|
|
|
def test_seed_signature_includes_smoothing_parameters():
    """Seeds differing only in smoothing strength must get different signatures."""
    seed = {
        "polygons": [[[0.1, 0.1], [0.3, 0.1], [0.3, 0.3]]],
        "label": "胆囊",
        "color": "#ff0000",
    }
    light_seed = dict(seed, smoothing={"strength": 20, "method": "chaikin"})
    strong_seed = dict(seed, smoothing={"strength": 40, "method": "chaikin"})

    light_signature = _seed_signature(light_seed)
    strong_signature = _seed_signature(strong_seed)

    assert light_signature != strong_signature
|
|
|
|
|
|
def test_propagate_saves_tracked_annotations(client, monkeypatch):
    """Forward propagation from frame 0 saves one annotation on frame 1.

    ``download_file`` and ``sam_registry.propagate_video`` are stubbed. With
    ``include_source`` off, the test expects a single saved annotation on the
    second frame that carries the seed's class, the propagation source tag and
    a smoothed polygon, and expects it in the annotation listing.
    """
    project = client.post("/api/projects", json={"name": "Video Project"}).json()
    frames_url = f"/api/projects/{project['id']}/frames"
    frames = []
    for idx in range(3):
        frame_payload = {
            "project_id": project["id"],
            "frame_index": idx,
            "image_url": f"frames/{idx}.jpg",
            "width": 640,
            "height": 360,
        }
        frames.append(client.post(frames_url, json=frame_payload).json())

    recorded = {}

    def fake_download(object_name):
        return b"jpeg"

    def fake_propagate_video(model, frame_paths, source_frame_index, seed, direction, max_frames):
        recorded.update(
            model=model,
            source_frame_index=source_frame_index,
            seed=seed,
            direction=direction,
            max_frames=max_frames,
            frame_count=len(frame_paths),
        )
        source_result = {
            "frame_index": 0,
            "polygons": [[[0.1, 0.1], [0.2, 0.1], [0.2, 0.2]]],
            "scores": [0.9],
            "object_ids": [1],
        }
        tracked_result = {
            "frame_index": 1,
            "polygons": [[[0.15, 0.15], [0.25, 0.15], [0.25, 0.25]]],
            "scores": [0.8],
            "object_ids": [1],
        }
        return [source_result, tracked_result]

    monkeypatch.setattr("routers.ai.download_file", fake_download)
    monkeypatch.setattr("routers.ai.sam_registry.propagate_video", fake_propagate_video)

    seed_payload = {
        "polygons": [[[0.1, 0.1], [0.2, 0.1], [0.2, 0.2]]],
        "bbox": [0.1, 0.1, 0.1, 0.1],
        "label": "胆囊",
        "color": "#ff0000",
        "class_metadata": {"id": "c1", "name": "胆囊", "color": "#ff0000", "zIndex": 20},
        "template_id": None,
        "smoothing": {"strength": 45, "method": "chaikin"},
    }
    request_body = {
        "project_id": project["id"],
        "frame_id": frames[0]["id"],
        "model": "sam2.1_hiera_tiny",
        "direction": "forward",
        "max_frames": 2,
        "include_source": False,
        "seed": seed_payload,
    }
    response = client.post("/api/ai/propagate", json=request_body)

    assert response.status_code == 200
    body = response.json()
    assert body["created_annotation_count"] == 1
    assert body["processed_frame_count"] == 2

    assert recorded["model"] == "sam2.1_hiera_tiny"
    assert recorded["source_frame_index"] == 0
    assert recorded["direction"] == "forward"
    assert recorded["frame_count"] == 2

    saved = body["annotations"][0]
    assert saved["frame_id"] == frames[1]["id"]
    saved_mask = saved["mask_data"]
    assert saved_mask["source"] == "sam2.1_hiera_tiny_propagation"
    assert saved_mask["class"]["name"] == "胆囊"
    assert saved_mask["score"] == 0.8
    assert saved_mask["geometry_smoothing"] == {"strength": 45.0, "method": "chaikin"}
    # The stored outline must differ from the raw tracked triangle.
    assert saved_mask["polygons"][0] != [[0.15, 0.15], [0.25, 0.15], [0.25, 0.25]]
    assert len(saved_mask["polygons"][0]) > 3

    listing = client.get(f"/api/ai/annotations?project_id={project['id']}")
    assert len(listing.json()) == 1
|
|
|
|
|
|
def test_queue_propagation_task_creates_processing_task(client, monkeypatch):
    """Queuing propagation returns 202 with a queued task and enqueues its id."""
    project = client.post("/api/projects", json={"name": "Queued Propagation"}).json()
    frame_payload = {
        "project_id": project["id"],
        "frame_index": 0,
        "image_url": "frames/0.jpg",
        "width": 640,
        "height": 360,
    }
    frame = client.post(f"/api/projects/{project['id']}/frames", json=frame_payload).json()

    class FakeAsyncResult:
        id = "celery-propagate-1"

    queued = []

    def fake_delay(task_id):
        # Remember which task id was handed to the stubbed delay call.
        queued.append(task_id)
        return FakeAsyncResult()

    def ignore_progress_event(task):
        return None

    monkeypatch.setattr("routers.ai.propagate_project_masks.delay", fake_delay)
    monkeypatch.setattr("routers.ai.publish_task_progress_event", ignore_progress_event)

    step = {
        "direction": "forward",
        "max_frames": 2,
        "seed": {
            "polygons": [[[0.1, 0.1], [0.2, 0.1], [0.2, 0.2]]],
            "label": "胆囊",
            "smoothing": {"strength": 35, "method": "chaikin"},
        },
    }
    request_body = {
        "project_id": project["id"],
        "frame_id": frame["id"],
        "model": "sam2.1_hiera_tiny",
        "steps": [step],
    }
    response = client.post("/api/ai/propagate/task", json=request_body)

    assert response.status_code == 202
    body = response.json()
    assert body["task_type"] == "propagate_masks"
    assert body["status"] == "queued"
    assert body["celery_task_id"] == "celery-propagate-1"

    payload = body["payload"]
    assert payload["model"] == "sam2.1_hiera_tiny"
    stored_seed = payload["steps"][0]["seed"]
    assert stored_seed["label"] == "胆囊"
    assert stored_seed["smoothing"] == {"strength": 35, "method": "chaikin"}
    assert queued == [body["id"]]
|
|
|
|
|
|
def test_queue_propagation_task_normalizes_model_and_rejects_unsupported(client, monkeypatch):
    """Model ``sam2`` is stored as ``sam2.1_hiera_tiny``; ``sam3`` is rejected with 400."""
    project = client.post("/api/projects", json={"name": "Propagation Model"}).json()
    frame_payload = {
        "project_id": project["id"],
        "frame_index": 0,
        "image_url": "frames/0.jpg",
        "width": 640,
        "height": 360,
    }
    frame = client.post(f"/api/projects/{project['id']}/frames", json=frame_payload).json()

    class FakeAsyncResult:
        id = "celery-propagate-model"

    def fake_delay(task_id):
        return FakeAsyncResult()

    def ignore_progress_event(task):
        return None

    monkeypatch.setattr("routers.ai.propagate_project_masks.delay", fake_delay)
    monkeypatch.setattr("routers.ai.publish_task_progress_event", ignore_progress_event)

    def queue_with_model(model_name):
        """Post the same single-step propagation request for ``model_name``."""
        request_body = {
            "project_id": project["id"],
            "frame_id": frame["id"],
            "model": model_name,
            "steps": [{
                "direction": "forward",
                "max_frames": 2,
                "seed": {
                    "polygons": [[[0.1, 0.1], [0.2, 0.1], [0.2, 0.2]]],
                },
            }],
        }
        return client.post("/api/ai/propagate/task", json=request_body)

    response = queue_with_model("sam2")

    assert response.status_code == 202
    assert response.json()["payload"]["model"] == "sam2.1_hiera_tiny"

    unsupported = queue_with_model("sam3")

    assert unsupported.status_code == 400
    assert "Unsupported model" in unsupported.json()["detail"]
|
|
|
|
|
|
def test_propagation_task_runner_saves_annotations_and_progress(client, db_session, monkeypatch):
    """Running a queued propagation task stores results and publishes progress.

    The runner's ``download_file``, progress publisher and
    ``sam_registry.propagate_video`` are stubbed. The test expects the task to
    finish as success at 100%, one annotation on the second frame, and a
    smoothed stored polygon.
    """
    project = client.post("/api/projects", json={"name": "Propagation Worker"}).json()
    frames_url = f"/api/projects/{project['id']}/frames"
    frames = []
    for idx in range(2):
        frame_payload = {
            "project_id": project["id"],
            "frame_index": idx,
            "image_url": f"frames/{idx}.jpg",
            "width": 640,
            "height": 360,
        }
        frames.append(client.post(frames_url, json=frame_payload).json())

    seed = {
        "polygons": [[[0.1, 0.1], [0.2, 0.1], [0.2, 0.2]]],
        "label": "胆囊",
        "color": "#ff0000",
        "class_metadata": {"id": "c1", "name": "胆囊"},
        "smoothing": {"strength": 40, "method": "chaikin"},
    }
    task_payload = {
        "project_id": project["id"],
        "frame_id": frames[0]["id"],
        "model": "sam2.1_hiera_tiny",
        "include_source": False,
        "save_annotations": True,
        "steps": [{
            "direction": "forward",
            "max_frames": 2,
            "seed": seed,
        }],
    }
    task = ProcessingTask(
        task_type="propagate_masks",
        status="queued",
        progress=0,
        project_id=project["id"],
        payload=task_payload,
    )
    db_session.add(task)
    db_session.commit()
    db_session.refresh(task)

    published = []

    def fake_download(object_name):
        return b"jpeg"

    def record_progress(event_task):
        published.append((event_task.status, event_task.progress))

    def fake_propagate_video(model, frame_paths, source_frame_index, seed, direction, max_frames):
        # The runner is expected to hand over zero-padded, index-named frame files.
        file_names = [Path(path).name for path in frame_paths]
        assert file_names == ["000000.jpg", "000001.jpg"]
        return [
            {"frame_index": 0, "polygons": [[[0.1, 0.1], [0.2, 0.1], [0.2, 0.2]]], "scores": [0.9]},
            {"frame_index": 1, "polygons": [[[0.15, 0.15], [0.25, 0.15], [0.25, 0.25]]], "scores": [0.8]},
        ]

    monkeypatch.setattr("services.propagation_task_runner.download_file", fake_download)
    monkeypatch.setattr("services.propagation_task_runner.publish_task_progress_event", record_progress)
    monkeypatch.setattr("services.propagation_task_runner.sam_registry.propagate_video", fake_propagate_video)

    result = run_propagate_project_task(db_session, task.id)

    db_session.refresh(task)
    assert task.status == "success"
    assert task.progress == 100
    assert task.result["model"] == "sam2.1_hiera_tiny"
    assert task.result["steps"][0]["model"] == "sam2.1_hiera_tiny"
    assert result["created_annotation_count"] == 1
    assert result["processed_frame_count"] == 2
    assert published[0][0] == "running"
    assert published[-1] == ("success", 100)

    listing = client.get(f"/api/ai/annotations?project_id={project['id']}")
    first_annotation = listing.json()[0]
    assert first_annotation["frame_id"] == frames[1]["id"]
    stored_mask = first_annotation["mask_data"]
    assert stored_mask["source"] == "sam2.1_hiera_tiny_propagation"
    stored_polygon = stored_mask["polygons"][0]
    assert stored_mask["geometry_smoothing"] == {"strength": 40.0, "method": "chaikin"}
    assert stored_polygon != [[0.15, 0.15], [0.25, 0.15], [0.25, 0.25]]
    assert len(stored_polygon) > 3
|
|
|
|
|
|
def test_propagation_task_runner_skips_unchanged_seed_and_replaces_changed_seed(client, db_session, monkeypatch):
    """Re-running with the same seed is skipped; a changed seed replaces the result.

    Three tasks run in sequence for the same source annotation: the first
    creates one annotation, the second (identical seed) creates nothing and
    does not call the propagator, and the third (changed seed polygon) deletes
    the old annotation and stores the replacement.
    """
    project = client.post("/api/projects", json={"name": "Propagation Dedupe"}).json()
    frames_url = f"/api/projects/{project['id']}/frames"
    frames = []
    for idx in range(2):
        frame_payload = {
            "project_id": project["id"],
            "frame_index": idx,
            "image_url": f"frames/{idx}.jpg",
            "width": 640,
            "height": 360,
        }
        frames.append(client.post(frames_url, json=frame_payload).json())

    def make_task(polygon):
        """Persist a queued single-step forward task seeded with ``polygon``."""
        seed = {
            "polygons": [polygon],
            "label": "胆囊",
            "color": "#ff0000",
            "source_annotation_id": 7,
            "source_mask_id": "annotation-7",
        }
        task_payload = {
            "project_id": project["id"],
            "frame_id": frames[0]["id"],
            "model": "sam2.1_hiera_tiny",
            "include_source": False,
            "save_annotations": True,
            "steps": [{
                "direction": "forward",
                "max_frames": 2,
                "seed": seed,
            }],
        }
        task = ProcessingTask(
            task_type="propagate_masks",
            status="queued",
            progress=0,
            project_id=project["id"],
            payload=task_payload,
        )
        db_session.add(task)
        db_session.commit()
        db_session.refresh(task)
        return task

    seed_polygon = [[0.1, 0.1], [0.2, 0.1], [0.2, 0.2]]
    first_output_polygon = [[0.15, 0.15], [0.25, 0.15], [0.25, 0.25]]
    changed_seed_polygon = [[0.2, 0.2], [0.3, 0.2], [0.3, 0.3]]
    replacement_output_polygon = [[0.22, 0.22], [0.32, 0.22], [0.32, 0.32]]

    propagate_calls = []

    def fake_download(object_name):
        return b"jpeg"

    def ignore_progress_event(event_task):
        return None

    def fake_propagate_video(model, frame_paths, source_frame_index, seed, direction, max_frames):
        incoming_polygon = seed["polygons"][0]
        propagate_calls.append(incoming_polygon)
        # The changed seed gets its own output so the replacement is detectable.
        if incoming_polygon == changed_seed_polygon:
            output_polygon = replacement_output_polygon
        else:
            output_polygon = first_output_polygon
        return [
            {"frame_index": 0, "polygons": [seed["polygons"][0]], "scores": [0.9]},
            {"frame_index": 1, "polygons": [output_polygon], "scores": [0.8]},
        ]

    monkeypatch.setattr("services.propagation_task_runner.download_file", fake_download)
    monkeypatch.setattr("services.propagation_task_runner.publish_task_progress_event", ignore_progress_event)
    monkeypatch.setattr("services.propagation_task_runner.sam_registry.propagate_video", fake_propagate_video)

    def project_annotations():
        return db_session.query(Annotation).filter(Annotation.project_id == project["id"])

    first_task = make_task(seed_polygon)
    first_result = run_propagate_project_task(db_session, first_task.id)
    assert first_result["created_annotation_count"] == 1
    assert len(propagate_calls) == 1

    unchanged_task = make_task(seed_polygon)
    unchanged_result = run_propagate_project_task(db_session, unchanged_task.id)
    assert unchanged_result["created_annotation_count"] == 0
    assert unchanged_result["skipped_seed_count"] == 1
    assert len(propagate_calls) == 1
    assert project_annotations().count() == 1

    changed_task = make_task(changed_seed_polygon)
    changed_result = run_propagate_project_task(db_session, changed_task.id)
    assert changed_result["created_annotation_count"] == 1
    assert changed_result["deleted_annotation_count"] == 1
    assert len(propagate_calls) == 2

    annotations = project_annotations().all()
    assert len(annotations) == 1
    remaining_mask = annotations[0].mask_data
    assert remaining_mask["polygons"] == [replacement_output_polygon]
    assert remaining_mask["source_annotation_id"] == 7
|
|
|
|
|
|
def test_propagation_task_runner_replaces_legacy_or_different_weight_results(client, db_session, monkeypatch):
    """An existing result from another model with a temporary seed key is replaced.

    A pre-existing ``sam2.1_hiera_tiny`` propagation annotation sits on frame 1;
    the task runs with ``sam2.1_hiera_small`` and is expected to delete it and
    store exactly one new annotation.
    """
    project = client.post("/api/projects", json={"name": "Propagation Legacy Cleanup"}).json()
    frames_url = f"/api/projects/{project['id']}/frames"
    frames = []
    for idx in range(2):
        frame_payload = {
            "project_id": project["id"],
            "frame_index": idx,
            "image_url": f"frames/{idx}.jpg",
            "width": 640,
            "height": 360,
        }
        frames.append(client.post(frames_url, json=frame_payload).json())

    seed_polygon = [[0.1, 0.1], [0.2, 0.1], [0.2, 0.2]]
    output_polygon = [[0.15, 0.15], [0.25, 0.15], [0.25, 0.25]]

    legacy_mask = {
        "polygons": [[[0.12, 0.12], [0.22, 0.12], [0.22, 0.22]]],
        "label": "胆囊",
        "color": "#ff0000",
        "source": "sam2.1_hiera_tiny_propagation",
        "propagated_from_frame_id": frames[0]["id"],
        "propagation_seed_key": "mask:temporary-front-end-id",
        "propagation_direction": "forward",
    }
    legacy_annotation = Annotation(
        project_id=project["id"],
        frame_id=frames[1]["id"],
        mask_data=legacy_mask,
        bbox=[0.12, 0.12, 0.1, 0.1],
    )
    db_session.add(legacy_annotation)
    db_session.commit()

    seed = {
        "polygons": [seed_polygon],
        "label": "胆囊",
        "color": "#ff0000",
        "source_annotation_id": 7,
        "source_mask_id": "annotation-7",
    }
    task_payload = {
        "project_id": project["id"],
        "frame_id": frames[0]["id"],
        "model": "sam2.1_hiera_small",
        "include_source": False,
        "save_annotations": True,
        "steps": [{
            "direction": "forward",
            "max_frames": 2,
            "seed": seed,
        }],
    }
    task = ProcessingTask(
        task_type="propagate_masks",
        status="queued",
        progress=0,
        project_id=project["id"],
        payload=task_payload,
    )
    db_session.add(task)
    db_session.commit()
    db_session.refresh(task)

    def fake_download(object_name):
        return b"jpeg"

    def ignore_progress_event(event_task):
        return None

    def fake_propagate_video(model, frame_paths, source_frame_index, seed, direction, max_frames):
        return [
            {"frame_index": 0, "polygons": [seed["polygons"][0]], "scores": [0.9]},
            {"frame_index": 1, "polygons": [output_polygon], "scores": [0.8]},
        ]

    monkeypatch.setattr("services.propagation_task_runner.download_file", fake_download)
    monkeypatch.setattr("services.propagation_task_runner.publish_task_progress_event", ignore_progress_event)
    monkeypatch.setattr("services.propagation_task_runner.sam_registry.propagate_video", fake_propagate_video)

    result = run_propagate_project_task(db_session, task.id)

    assert result["created_annotation_count"] == 1
    assert result["deleted_annotation_count"] == 1
    annotations = db_session.query(Annotation).filter(Annotation.project_id == project["id"]).all()
    assert len(annotations) == 1
    stored_mask = annotations[0].mask_data
    assert stored_mask["source"] == "sam2.1_hiera_small_propagation"
    assert stored_mask["source_annotation_id"] == 7
    assert stored_mask["polygons"] == [output_polygon]
|
|
|
|
|
|
def test_propagation_task_runner_keeps_same_class_seeds_separate(client, db_session, monkeypatch):
    """Two same-class seeds in one task each keep their own propagated result.

    The stubbed propagator returns overlapping squares for the two seeds
    (source annotations 7 and 8); the test expects two created annotations,
    none deleted, stored in seed order.
    """
    project = client.post("/api/projects", json={"name": "Propagation Multi Instance"}).json()
    frames_url = f"/api/projects/{project['id']}/frames"
    frames = []
    for idx in range(2):
        frame_payload = {
            "project_id": project["id"],
            "frame_index": idx,
            "image_url": f"frames/{idx}.jpg",
            "width": 640,
            "height": 360,
        }
        frames.append(client.post(frames_url, json=frame_payload).json())

    # Outputs keyed by source annotation id; the two squares overlap heavily.
    output_by_source = {
        7: [[0.10, 0.10], [0.30, 0.10], [0.30, 0.30], [0.10, 0.30]],
        8: [[0.12, 0.12], [0.32, 0.12], [0.32, 0.32], [0.12, 0.32]],
    }

    def forward_step(seed_polygon, annotation_id):
        """Build one forward step for a seed of the shared class."""
        return {
            "direction": "forward",
            "max_frames": 2,
            "seed": {
                "polygons": [seed_polygon],
                "label": "胆囊",
                "color": "#ff0000",
                "source_annotation_id": annotation_id,
                "source_mask_id": f"annotation-{annotation_id}",
                "class_metadata": {"id": "gallbladder", "name": "胆囊"},
            },
        }

    task_payload = {
        "project_id": project["id"],
        "frame_id": frames[0]["id"],
        "model": "sam2.1_hiera_tiny",
        "include_source": False,
        "save_annotations": True,
        "steps": [
            forward_step([[0.05, 0.05], [0.15, 0.05], [0.15, 0.15]], 7),
            forward_step([[0.65, 0.65], [0.75, 0.65], [0.75, 0.75]], 8),
        ],
    }
    task = ProcessingTask(
        task_type="propagate_masks",
        status="queued",
        progress=0,
        project_id=project["id"],
        payload=task_payload,
    )
    db_session.add(task)
    db_session.commit()
    db_session.refresh(task)

    def fake_download(object_name):
        return b"jpeg"

    def ignore_progress_event(event_task):
        return None

    def fake_propagate_video(model, frame_paths, source_frame_index, seed, direction, max_frames):
        tracked_polygon = output_by_source[seed["source_annotation_id"]]
        return [
            {"frame_index": 0, "polygons": [seed["polygons"][0]], "scores": [0.9]},
            {"frame_index": 1, "polygons": [tracked_polygon], "scores": [0.8]},
        ]

    monkeypatch.setattr("services.propagation_task_runner.download_file", fake_download)
    monkeypatch.setattr("services.propagation_task_runner.publish_task_progress_event", ignore_progress_event)
    monkeypatch.setattr("services.propagation_task_runner.sam_registry.propagate_video", fake_propagate_video)

    result = run_propagate_project_task(db_session, task.id)

    assert result["created_annotation_count"] == 2
    assert result["deleted_annotation_count"] == 0
    annotations = (
        db_session.query(Annotation)
        .filter(Annotation.project_id == project["id"])
        .order_by(Annotation.id)
        .all()
    )
    stored_source_ids = [annotation.mask_data["source_annotation_id"] for annotation in annotations]
    stored_polygons = [annotation.mask_data["polygons"][0] for annotation in annotations]
    assert stored_source_ids == [7, 8]
    assert stored_polygons == [output_by_source[7], output_by_source[8]]
|
|
|
|
|
|
def test_propagation_task_runner_replaces_downstream_result_from_middle_frame_manual_seed(client, db_session, monkeypatch):
    """A seed on the middle frame replaces an older result on the last frame.

    Frame 2 already holds a propagation result from source annotation 7.
    Propagating forward from frame 1 with source annotation 20 is expected to
    delete that result and leave one annotation on frame 2 tied to the new seed.
    """
    project = client.post("/api/projects", json={"name": "Propagation Middle Frame Replacement"}).json()
    frames_url = f"/api/projects/{project['id']}/frames"
    frames = []
    for idx in range(3):
        frame_payload = {
            "project_id": project["id"],
            "frame_index": idx,
            "image_url": f"frames/{idx}.jpg",
            "width": 640,
            "height": 360,
        }
        frames.append(client.post(frames_url, json=frame_payload).json())

    old_downstream_polygon = [[0.18, 0.18], [0.28, 0.18], [0.28, 0.28]]
    replacement_seed_polygon = [[0.16, 0.16], [0.26, 0.16], [0.26, 0.26]]
    replacement_downstream_polygon = [[0.19, 0.19], [0.29, 0.19], [0.29, 0.29]]

    old_mask = {
        "polygons": [old_downstream_polygon],
        "label": "胆囊",
        "color": "#ff0000",
        "class": {"id": "c1", "name": "胆囊", "color": "#ff0000"},
        "source": "sam2.1_hiera_tiny_propagation",
        "propagated_from_frame_id": frames[0]["id"],
        "propagation_seed_key": "annotation:7",
        "propagation_seed_signature": "old-signature",
        "propagation_direction": "forward",
        "source_annotation_id": 7,
        "source_mask_id": "annotation-7",
    }
    old_annotation = Annotation(
        project_id=project["id"],
        frame_id=frames[2]["id"],
        template_id=3,
        mask_data=old_mask,
        bbox=[0.18, 0.18, 0.1, 0.1],
    )
    db_session.add(old_annotation)
    db_session.commit()

    seed = {
        "polygons": [replacement_seed_polygon],
        "label": "胆囊",
        "color": "#ff0000",
        "source_annotation_id": 20,
        "source_mask_id": "annotation-20",
    }
    task_payload = {
        "project_id": project["id"],
        "frame_id": frames[1]["id"],
        "model": "sam2.1_hiera_tiny",
        "include_source": False,
        "save_annotations": True,
        "steps": [{
            "direction": "forward",
            "max_frames": 2,
            "seed": seed,
        }],
    }
    task = ProcessingTask(
        task_type="propagate_masks",
        status="queued",
        progress=0,
        project_id=project["id"],
        payload=task_payload,
    )
    db_session.add(task)
    db_session.commit()
    db_session.refresh(task)

    def fake_download(object_name):
        return b"jpeg"

    def ignore_progress_event(event_task):
        return None

    def fake_propagate_video(model, frame_paths, source_frame_index, seed, direction, max_frames):
        return [
            {"frame_index": 0, "polygons": [seed["polygons"][0]], "scores": [0.9]},
            {"frame_index": 1, "polygons": [replacement_downstream_polygon], "scores": [0.8]},
        ]

    monkeypatch.setattr("services.propagation_task_runner.download_file", fake_download)
    monkeypatch.setattr("services.propagation_task_runner.publish_task_progress_event", ignore_progress_event)
    monkeypatch.setattr("services.propagation_task_runner.sam_registry.propagate_video", fake_propagate_video)

    result = run_propagate_project_task(db_session, task.id)

    assert result["created_annotation_count"] == 1
    assert result["deleted_annotation_count"] == 1
    annotations = db_session.query(Annotation).filter(Annotation.project_id == project["id"]).all()
    assert len(annotations) == 1
    survivor = annotations[0]
    assert survivor.frame_id == frames[2]["id"]
    assert survivor.mask_data["polygons"] == [replacement_downstream_polygon]
    assert survivor.mask_data["source_annotation_id"] == 20
    assert survivor.mask_data["source_mask_id"] == "annotation-20"
|
|
|
|
|
|
def test_propagation_task_runner_replaces_forward_result_when_middle_frame_propagates_backward(client, db_session, monkeypatch):
    """Backward propagation from the middle frame replaces an older upstream result.

    Frame 0 starts with a propagated annotation recorded under seed key
    ``annotation:7``. A backward step seeded on frame 1 (source annotation 20)
    yields an overlapping polygon on frame 0. The test expects the old
    annotation to be deleted and only the new one to remain.
    """
    project = client.post("/api/projects", json={"name": "Propagation Backward Middle Replacement"}).json()
    project_id = project["id"]

    # Three consecutive frames; the task below is seeded on the middle one.
    frames = []
    for idx in range(3):
        frame_response = client.post(f"/api/projects/{project_id}/frames", json={
            "project_id": project_id,
            "frame_index": idx,
            "image_url": f"frames/{idx}.jpg",
            "width": 640,
            "height": 360,
        })
        frames.append(frame_response.json())
    first_frame_id = frames[0]["id"]

    old_upstream_polygon = [[0.12, 0.12], [0.22, 0.12], [0.22, 0.22]]
    replacement_seed_polygon = [[0.16, 0.16], [0.26, 0.16], [0.26, 0.26]]
    replacement_upstream_polygon = [[0.13, 0.13], [0.23, 0.13], [0.23, 0.23]]

    # Pre-existing propagated result on frame 0 that the new run should replace.
    stale_mask_data = {
        "polygons": [old_upstream_polygon],
        "label": "胆囊",
        "color": "#ff0000",
        "source": "sam2.1_hiera_tiny_propagation",
        "propagated_from_frame_id": first_frame_id,
        "propagation_seed_key": "annotation:7",
        "propagation_seed_signature": "old-signature",
        "propagation_direction": "forward",
        "source_annotation_id": 7,
        "source_mask_id": "annotation-7",
    }
    stale_annotation = Annotation(
        project_id=project_id,
        frame_id=first_frame_id,
        mask_data=stale_mask_data,
        bbox=[0.12, 0.12, 0.1, 0.1],
    )
    db_session.add(stale_annotation)
    db_session.commit()

    backward_step = {
        "direction": "backward",
        "max_frames": 2,
        "seed": {
            "polygons": [replacement_seed_polygon],
            "label": "胆囊",
            "color": "#ff0000",
            "source_annotation_id": 20,
            "source_mask_id": "annotation-20",
        },
    }
    task_payload = {
        "project_id": project_id,
        "frame_id": frames[1]["id"],
        "model": "sam2.1_hiera_tiny",
        "include_source": False,
        "save_annotations": True,
        "steps": [backward_step],
    }
    task = ProcessingTask(
        task_type="propagate_masks",
        status="queued",
        progress=0,
        project_id=project_id,
        payload=task_payload,
    )
    db_session.add(task)
    db_session.commit()
    db_session.refresh(task)

    def fake_download_file(object_name):
        return b"jpeg"

    def fake_publish_progress(event_task):
        return None

    def fake_propagate_video(model, frame_paths, source_frame_index, seed, direction, max_frames):
        # Frame 0 gets the replacement polygon; frame 1 echoes the seed polygon.
        return [
            {"frame_index": 0, "polygons": [replacement_upstream_polygon], "scores": [0.8]},
            {"frame_index": 1, "polygons": [seed["polygons"][0]], "scores": [0.9]},
        ]

    monkeypatch.setattr("services.propagation_task_runner.download_file", fake_download_file)
    monkeypatch.setattr("services.propagation_task_runner.publish_task_progress_event", fake_publish_progress)
    monkeypatch.setattr("services.propagation_task_runner.sam_registry.propagate_video", fake_propagate_video)

    result = run_propagate_project_task(db_session, task.id)

    assert result["created_annotation_count"] == 1
    assert result["deleted_annotation_count"] == 1

    annotations = db_session.query(Annotation).filter(Annotation.project_id == project_id).all()
    assert len(annotations) == 1
    survivor = annotations[0]
    assert survivor.frame_id == first_frame_id
    assert survivor.mask_data["polygons"] == [replacement_upstream_polygon]
    assert survivor.mask_data["propagation_direction"] == "backward"
    assert survivor.mask_data["source_annotation_id"] == 20
|
|
|
|
|
|
def test_propagation_task_runner_skips_unmodified_propagated_seed_on_overlapping_frames(client, db_session, monkeypatch):
    """A seed carrying an already-recorded signature is skipped entirely.

    Frame 2 already holds a propagated annotation stamped with the signature of
    the original seed. The task re-submits a seed from frame 1 that carries the
    same ``propagation_seed_signature``. The test expects no model call, no
    created or deleted annotations, and one skipped seed.
    """
    project = client.post("/api/projects", json={"name": "Propagation Overlap Skip"}).json()
    project_id = project["id"]

    frames = []
    for idx in range(3):
        frame_response = client.post(f"/api/projects/{project_id}/frames", json={
            "project_id": project_id,
            "frame_index": idx,
            "image_url": f"frames/{idx}.jpg",
            "width": 640,
            "height": 360,
        })
        frames.append(frame_response.json())

    original_seed_polygon = [[0.1, 0.1], [0.2, 0.1], [0.2, 0.2]]
    propagated_seed_polygon = [[0.14, 0.14], [0.24, 0.14], [0.24, 0.24]]
    downstream_polygon = [[0.18, 0.18], [0.28, 0.18], [0.28, 0.28]]

    # Signature of the original seed; both the stored annotation and the new
    # seed below carry this same value.
    original_seed = {
        "polygons": [original_seed_polygon],
        "label": "胆囊",
        "color": "#ff0000",
        "source_annotation_id": 7,
        "source_mask_id": "annotation-7",
    }
    inherited_signature = _seed_signature(original_seed)

    existing_mask_data = {
        "polygons": [downstream_polygon],
        "label": "胆囊",
        "color": "#ff0000",
        "source": "sam2.1_hiera_tiny_propagation",
        "propagated_from_frame_id": frames[0]["id"],
        "propagation_seed_key": "annotation:7",
        "propagation_seed_signature": inherited_signature,
        "propagation_direction": "forward",
        "source_annotation_id": 7,
        "source_mask_id": "annotation-7",
    }
    existing_annotation = Annotation(
        project_id=project_id,
        frame_id=frames[2]["id"],
        mask_data=existing_mask_data,
        bbox=[0.18, 0.18, 0.1, 0.1],
    )
    db_session.add(existing_annotation)
    db_session.commit()

    forward_step = {
        "direction": "forward",
        "max_frames": 2,
        "seed": {
            "polygons": [propagated_seed_polygon],
            "label": "胆囊",
            "color": "#ff0000",
            "source_annotation_id": 7,
            "source_mask_id": "annotation-7",
            "propagation_seed_signature": inherited_signature,
        },
    }
    task_payload = {
        "project_id": project_id,
        "frame_id": frames[1]["id"],
        "model": "sam2.1_hiera_tiny",
        "include_source": False,
        "save_annotations": True,
        "steps": [forward_step],
    }
    task = ProcessingTask(
        task_type="propagate_masks",
        status="queued",
        progress=0,
        project_id=project_id,
        payload=task_payload,
    )
    db_session.add(task)
    db_session.commit()
    db_session.refresh(task)

    propagate_calls = []

    def fake_download_file(object_name):
        return b"jpeg"

    def fake_publish_progress(event_task):
        return None

    def recording_propagate_video(*args, **kwargs):
        # Record positional arguments so the test can assert it was never called.
        propagate_calls.append(args)
        return []

    monkeypatch.setattr("services.propagation_task_runner.download_file", fake_download_file)
    monkeypatch.setattr("services.propagation_task_runner.publish_task_progress_event", fake_publish_progress)
    monkeypatch.setattr("services.propagation_task_runner.sam_registry.propagate_video", recording_propagate_video)

    result = run_propagate_project_task(db_session, task.id)

    assert result["created_annotation_count"] == 0
    assert result["deleted_annotation_count"] == 0
    assert result["skipped_seed_count"] == 1
    assert propagate_calls == []

    annotations = db_session.query(Annotation).filter(Annotation.project_id == project_id).all()
    assert len(annotations) == 1
    assert annotations[0].mask_data["polygons"] == [downstream_polygon]
|
|
|
|
|
|
def test_predict_validation_errors(client, monkeypatch):
    """Predict returns 404 for an unknown image id and 400 for a two-value box prompt."""
    project, _, _ = _create_project_and_frame(client)

    unknown_image_payload = {
        "image_id": 999,
        "prompt_type": "point",
        "prompt_data": [[0.5, 0.5]],
    }
    unknown_image_response = client.post("/api/ai/predict", json=unknown_image_payload)
    assert unknown_image_response.status_code == 404

    # Second frame, created without width/height.
    frame = client.post(f"/api/projects/{project['id']}/frames", json={
        "project_id": project["id"],
        "frame_index": 1,
        "image_url": "frames/1.jpg",
    }).json()

    def fake_load_frame_image(frame):
        return np.zeros((10, 10, 3), dtype=np.uint8)

    monkeypatch.setattr("routers.ai._load_frame_image", fake_load_frame_image)

    short_box_payload = {
        "image_id": frame["id"],
        "prompt_type": "box",
        "prompt_data": [0.1, 0.2],
    }
    short_box_response = client.post("/api/ai/predict", json=short_box_payload)
    assert short_box_response.status_code == 400
|
|
|
|
|
|
def test_save_annotation_validates_project_and_frame(client):
    """Saving and listing annotations works; unknown project or frame ids give 404."""
    project, frame, template = _create_project_and_frame(client)
    project_id = project["id"]
    frame_id = frame["id"]

    annotation_payload = {
        "project_id": project_id,
        "frame_id": frame_id,
        "template_id": template["id"],
        "mask_data": {"polygons": [[[0.1, 0.1], [0.9, 0.1], [0.9, 0.9]]]},
        "points": [[0.5, 0.5]],
        "bbox": [0.1, 0.1, 0.8, 0.8],
    }
    saved = client.post("/api/ai/annotate", json=annotation_payload)
    assert saved.status_code == 201
    saved_body = saved.json()
    assert saved_body["project_id"] == project_id

    # Project-wide listing returns the saved annotation first.
    listing = client.get(f"/api/ai/annotations?project_id={project_id}")
    assert listing.status_code == 200
    assert listing.json()[0]["id"] == saved_body["id"]

    # Filtering by frame returns exactly one annotation.
    frame_listing = client.get(f"/api/ai/annotations?project_id={project_id}&frame_id={frame_id}")
    assert frame_listing.status_code == 200
    assert len(frame_listing.json()) == 1

    missing_project = client.post("/api/ai/annotate", json={"project_id": 999})
    assert missing_project.status_code == 404

    missing_frame_payload = {
        "project_id": project_id,
        "frame_id": 999,
    }
    missing_frame = client.post("/api/ai/annotate", json=missing_frame_payload)
    assert missing_frame.status_code == 404

    missing_project_list = client.get("/api/ai/annotations?project_id=999")
    assert missing_project_list.status_code == 404
|
|
|
|
|
|
def test_update_and_delete_annotation(client):
    """An annotation can be patched, shows the patch in listings, and can be deleted."""
    project, frame, template = _create_project_and_frame(client)
    listing_url = f"/api/ai/annotations?project_id={project['id']}"

    initial_payload = {
        "project_id": project["id"],
        "frame_id": frame["id"],
        "template_id": template["id"],
        "mask_data": {
            "polygons": [[[0.1, 0.1], [0.9, 0.1], [0.9, 0.9]]],
            "label": "AI Mask",
            "color": "#06b6d4",
        },
        "points": [[0.5, 0.5]],
        "bbox": [0.1, 0.1, 0.8, 0.8],
    }
    saved = client.post("/api/ai/annotate", json=initial_payload).json()
    annotation_url = f"/api/ai/annotations/{saved['id']}"

    patch_payload = {
        "template_id": template["id"],
        "mask_data": {
            "polygons": [[[0.2, 0.2], [0.8, 0.2], [0.8, 0.8]]],
            "label": "胆囊",
            "color": "#ff0000",
            "class": {"id": "c1", "name": "胆囊", "color": "#ff0000", "zIndex": 20},
        },
        "points": [[0.4, 0.4]],
        "bbox": [0.2, 0.2, 0.6, 0.6],
    }
    updated = client.patch(annotation_url, json=patch_payload)

    assert updated.status_code == 200
    body = updated.json()
    assert body["mask_data"]["label"] == "胆囊"
    assert body["mask_data"]["class"]["id"] == "c1"
    assert body["points"] == [[0.4, 0.4]]
    assert body["bbox"] == [0.2, 0.2, 0.6, 0.6]

    # The patched class metadata is visible through the listing endpoint.
    listing = client.get(listing_url)
    assert listing.status_code == 200
    assert listing.json()[0]["mask_data"]["class"]["name"] == "胆囊"

    deleted = client.delete(annotation_url)
    assert deleted.status_code == 204

    # After deletion the project has no annotations left.
    empty_listing = client.get(listing_url)
    assert empty_listing.status_code == 200
    assert empty_listing.json() == []
|
|
|
|
|
|
def test_update_and_delete_annotation_validation(client):
    """Patch/delete on an unknown annotation, or patch with an unknown template, gives 404."""
    project, frame, template = _create_project_and_frame(client)
    create_payload = {
        "project_id": project["id"],
        "frame_id": frame["id"],
        "template_id": template["id"],
    }
    saved = client.post("/api/ai/annotate", json=create_payload).json()

    patch_missing = client.patch("/api/ai/annotations/999", json={"bbox": [0, 0, 1, 1]})
    assert patch_missing.status_code == 404

    delete_missing = client.delete("/api/ai/annotations/999")
    assert delete_missing.status_code == 404

    # Existing annotation, but the referenced template id does not exist.
    patch_bad_template = client.patch(
        f"/api/ai/annotations/{saved['id']}",
        json={"template_id": 999},
    )
    assert patch_bad_template.status_code == 404
|
|
|
|
|
|
def test_import_gt_mask_creates_annotations_with_seed_points(client):
    """Importing a single-rectangle GT mask yields one annotation with one normalized seed point."""
    project, frame, template = _create_project_and_frame(client)

    # 640x360 single-channel mask with one filled rectangle of value 255.
    mask = np.zeros((360, 640), dtype=np.uint8)
    cv2.rectangle(mask, (100, 80), (260, 220), 255, thickness=-1)
    ok, encoded = cv2.imencode(".png", mask)
    assert ok

    form_fields = {
        "project_id": str(project["id"]),
        "frame_id": str(frame["id"]),
        "template_id": str(template["id"]),
        "label": "Imported GT",
        "color": "#22c55e",
    }
    upload = {"file": ("mask.png", encoded.tobytes(), "image/png")}
    response = client.post("/api/ai/import-gt-mask", data=form_fields, files=upload)

    assert response.status_code == 201
    body = response.json()
    assert len(body) == 1

    annotation = body[0]
    assert annotation["project_id"] == project["id"]
    assert annotation["frame_id"] == frame["id"]
    assert annotation["template_id"] == template["id"]

    mask_data = annotation["mask_data"]
    assert mask_data["label"] == "Imported GT"
    assert mask_data["source"] == "gt_mask"
    assert mask_data["gt_label_value"] == 255
    assert len(mask_data["polygons"][0]) >= 3

    # Exactly one seed point, with both coordinates inside [0, 1].
    points = annotation["points"]
    assert len(points) == 1
    assert 0.0 <= points[0][0] <= 1.0
    assert 0.0 <= points[0][1] <= 1.0
|
|
|
|
|
|
def test_import_gt_mask_polygons_work_with_analysis_and_smoothing(client):
    """Polygons from an imported GT mask are accepted by analyze-mask and smooth-mask."""
    project, frame, _ = _create_project_and_frame(client)

    # Filled, rotated ellipse drawn with label value 1.
    mask = np.zeros((360, 640), dtype=np.uint8)
    cv2.ellipse(mask, (260, 160), (130, 70), 20, 0, 360, 1, thickness=-1)
    ok, encoded = cv2.imencode(".png", mask)
    assert ok

    form_fields = {
        "project_id": str(project["id"]),
        "frame_id": str(frame["id"]),
        "label": "Imported GT",
        "color": "#22c55e",
    }
    upload = {"file": ("mask.png", encoded.tobytes(), "image/png")}
    response = client.post("/api/ai/import-gt-mask", data=form_fields, files=upload)

    assert response.status_code == 201
    annotation = response.json()[0]
    assert annotation["mask_data"]["source"] == "gt_mask"

    analysis_request = {
        "frame_id": frame["id"],
        "mask_data": annotation["mask_data"],
        "points": annotation["points"],
        "bbox": annotation["bbox"],
    }
    analysis = client.post("/api/ai/analyze-mask", json=analysis_request)
    assert analysis.status_code == 200
    # The anchor count equals the vertex count of the imported polygon.
    assert analysis.json()["topology_anchor_count"] == len(annotation["mask_data"]["polygons"][0])

    smoothing_request = {
        "frame_id": frame["id"],
        "mask_data": annotation["mask_data"],
        "points": annotation["points"],
        "bbox": annotation["bbox"],
        "strength": 35,
    }
    smoothing = client.post("/api/ai/smooth-mask", json=smoothing_request)
    assert smoothing.status_code == 200
    smoothed = smoothing.json()
    # After smoothing, the anchor count equals the smoothed polygon's vertex count.
    assert smoothed["topology_anchor_count"] == len(smoothed["polygons"][0])
|
|
|
|
|
|
def test_import_gt_mask_preserves_detailed_contours(client):
    """A 96-vertex star keeps more than 80 polygon points after import, and at most 2048."""
    project, frame, _ = _create_project_and_frame(client)
    mask = np.zeros((360, 640), dtype=np.uint8)
    center = np.array([320, 180])

    def star_vertex(index):
        # Alternate between the outer (120) and inner (88) radius to form spikes.
        angle = 2 * np.pi * index / 96
        radius = 120 if index % 2 == 0 else 88
        return [
            int(center[0] + np.cos(angle) * radius),
            int(center[1] + np.sin(angle) * radius),
        ]

    vertices = [star_vertex(index) for index in range(96)]
    cv2.fillPoly(mask, [np.array(vertices, dtype=np.int32)], 1)
    ok, encoded = cv2.imencode(".png", mask)
    assert ok

    form_fields = {
        "project_id": str(project["id"]),
        "frame_id": str(frame["id"]),
        "label": "Detailed GT",
        "color": "#22c55e",
    }
    upload = {"file": ("mask.png", encoded.tobytes(), "image/png")}
    response = client.post("/api/ai/import-gt-mask", data=form_fields, files=upload)

    assert response.status_code == 201
    polygon = response.json()[0]["mask_data"]["polygons"][0]
    point_count = len(polygon)
    assert point_count > 80
    assert point_count <= 2048
|
|
|
|
|
|
def test_import_gt_mask_splits_label_values(client):
    """Each distinct non-zero label value in the mask becomes its own annotation."""
    project, frame, _ = _create_project_and_frame(client)

    # Two disjoint rectangles carrying label values 1 and 2.
    mask = np.zeros((360, 640), dtype=np.uint8)
    cv2.rectangle(mask, (20, 20), (120, 120), 1, thickness=-1)
    cv2.rectangle(mask, (220, 80), (320, 180), 2, thickness=-1)
    ok, encoded = cv2.imencode(".png", mask)
    assert ok

    form_fields = {
        "project_id": str(project["id"]),
        "frame_id": str(frame["id"]),
        "label": "GT Class",
    }
    upload = {"file": ("labels.png", encoded.tobytes(), "image/png")}
    response = client.post("/api/ai/import-gt-mask", data=form_fields, files=upload)

    assert response.status_code == 201

    def label_value(item):
        return item["mask_data"]["gt_label_value"]

    # Sort by label value so the assertions do not depend on response order.
    body = sorted(response.json(), key=label_value)
    assert [label_value(item) for item in body] == [1, 2]
    assert [item["mask_data"]["label"] for item in body] == ["GT Class 1", "GT Class 2"]
    assert all(len(item["points"]) == 1 for item in body)
|
|
|
|
|
|
def test_import_gt_mask_rejects_background_only_label_image(client):
    """An all-zero mask is rejected with 400 and a specific detail message."""
    project, frame, _ = _create_project_and_frame(client)

    # Every pixel is zero, so no non-zero label region exists.
    mask = np.zeros((360, 640), dtype=np.uint8)
    ok, encoded = cv2.imencode(".png", mask)
    assert ok

    form_fields = {
        "project_id": str(project["id"]),
        "frame_id": str(frame["id"]),
        "label": "GT Class",
    }
    upload = {"file": ("empty-gt-label.png", encoded.tobytes(), "image/png")}
    response = client.post("/api/ai/import-gt-mask", data=form_fields, files=upload)

    assert response.status_code == 400
    assert response.json()["detail"] == "GT Mask 图片中没有非背景 maskid 区域。"
|
|
|
|
|
|
def test_import_gt_mask_accepts_uint8_low_value_gtlabel_png(client):
    """A uint8 PNG whose only label value is 1 maps to the template class with maskId 1."""
    project, frame, _ = _create_project_and_frame(client)

    template_payload = {
        "name": "GTLabel Template",
        "color": "#06b6d4",
        "z_index": 0,
        "classes": [
            {"id": "tumor", "name": "肿瘤", "color": "#ff0000", "zIndex": 10, "maskId": 1},
        ],
        "rules": [],
    }
    template = client.post("/api/templates", json=template_payload).json()

    mask = np.zeros((360, 640), dtype=np.uint8)
    cv2.rectangle(mask, (40, 40), (140, 140), 1, thickness=-1)
    ok, encoded = cv2.imencode(".png", mask)
    assert ok

    form_fields = {
        "project_id": str(project["id"]),
        "frame_id": str(frame["id"]),
        "template_id": str(template["id"]),
        "unknown_color_policy": "discard",
    }
    upload = {"file": ("GT_label.png", encoded.tobytes(), "image/png")}
    response = client.post("/api/ai/import-gt-mask", data=form_fields, files=upload)

    assert response.status_code == 201
    body = response.json()
    assert len(body) == 1

    mask_data = body[0]["mask_data"]
    assert mask_data["gt_label_value"] == 1
    assert mask_data["class"]["name"] == "肿瘤"
    assert mask_data["class"]["maskId"] == 1
|
|
|
|
|
|
def test_import_gt_mask_rejects_rgb_color_masks(client):
    """A three-channel mask whose channels differ is rejected with 400."""
    project, frame, _ = _create_project_and_frame(client)

    template_payload = {
        "name": "Color Template",
        "color": "#06b6d4",
        "z_index": 0,
        "classes": [
            {"id": "known", "name": "已知类别", "color": "#ff0000", "zIndex": 10, "maskId": 1},
        ],
        "rules": [],
    }
    template = client.post("/api/templates", json=template_payload).json()

    mask = np.zeros((80, 120, 3), dtype=np.uint8)
    mask[10:40, 10:40] = [0, 0, 255]  # BGR red -> #ff0000
    mask[40:70, 70:110] = [0, 255, 0]  # BGR green -> unknown #00ff00
    ok, encoded = cv2.imencode(".png", mask)
    assert ok

    form_fields = {
        "project_id": str(project["id"]),
        "frame_id": str(frame["id"]),
        "template_id": str(template["id"]),
        "unknown_color_policy": "discard",
    }
    upload = {"file": ("color-mask.png", encoded.tobytes(), "image/png")}
    response = client.post("/api/ai/import-gt-mask", data=form_fields, files=upload)

    assert response.status_code == 400
    detail = response.json()["detail"]
    assert "RGB 三通道完全相同" in detail
|
|
|
|
|
|
def test_import_gt_mask_rejects_uint16_gt_label(client):
    """A 16-bit label PNG is rejected with 400 and an 8-bit-only message."""
    project, frame, _ = _create_project_and_frame(client)

    template_payload = {
        "name": "Label Template",
        "color": "#06b6d4",
        "z_index": 0,
        "classes": [{"id": "tumor", "name": "肿瘤", "color": "#ff0000", "zIndex": 10, "maskId": 1}],
        "rules": [],
    }
    template = client.post("/api/templates", json=template_payload).json()

    # Same rectangle layout as the 8-bit tests, but with a uint16 array.
    mask = np.zeros((360, 640), dtype=np.uint16)
    cv2.rectangle(mask, (20, 20), (120, 120), 1, thickness=-1)
    ok, encoded = cv2.imencode(".png", mask)
    assert ok

    form_fields = {
        "project_id": str(project["id"]),
        "frame_id": str(frame["id"]),
        "template_id": str(template["id"]),
        "unknown_color_policy": "discard",
    }
    upload = {"file": ("gt_label.png", encoded.tobytes(), "image/png")}
    response = client.post("/api/ai/import-gt-mask", data=form_fields, files=upload)

    assert response.status_code == 400
    detail = response.json()["detail"]
    assert "仅支持 8-bit" in detail
|
|
|
|
|
|
def test_import_gt_mask_handles_unknown_maskid_policy_and_resizes_to_frame(client):
|
|
project, frame, _ = _create_project_and_frame(client)
|
|
template = client.post("/api/templates", json={
|
|
"name": "Color Template",
|
|
"color": "#06b6d4",
|
|
"z_index": 0,
|
|
"classes": [{"id": "known", "name": "已定义", "color": "#ff0000", "zIndex": 10, "maskId": 1}],
|
|
"rules": [],
|
|
}).json()
|
|
mask = np.zeros((90, 160, 3), dtype=np.uint8)
|
|
cv2.rectangle(mask, (5, 5), (40, 40), (1, 1, 1), thickness=-1)
|
|
cv2.rectangle(mask, (80, 5), (120, 40), (2, 2, 2), thickness=-1)
|
|
ok, encoded = cv2.imencode(".png", mask)
|
|
assert ok
|
|
|
|
discard_response = client.post(
|
|
"/api/ai/import-gt-mask",
|
|
data={
|
|
"project_id": str(project["id"]),
|
|
"frame_id": str(frame["id"]),
|
|
"template_id": str(template["id"]),
|
|
"unknown_color_policy": "discard",
|
|
},
|
|
files={"file": ("colors.png", encoded.tobytes(), "image/png")},
|
|
)
|
|
|
|
assert discard_response.status_code == 201
|
|
assert [item["mask_data"]["label"] for item in discard_response.json()] == ["已定义"]
|
|
assert discard_response.json()[0]["mask_data"]["gt_original_size"] == {"width": 160, "height": 90}
|
|
assert discard_response.json()[0]["mask_data"]["gt_resized_to_frame"] is True
|
|
assert discard_response.json()[0]["mask_data"]["image_size"] == {"width": 640, "height": 360}
|
|
|
|
undefined_response = client.post(
|
|
"/api/ai/import-gt-mask",
|
|
data={
|
|
"project_id": str(project["id"]),
|
|
"frame_id": str(frame["id"]),
|
|
"template_id": str(template["id"]),
|
|
"unknown_color_policy": "undefined",
|
|
},
|
|
files={"file": ("colors.png", encoded.tobytes(), "image/png")},
|
|
)
|
|
|
|
assert undefined_response.status_code == 201
|
|
labels = {item["mask_data"]["label"] for item in undefined_response.json()}
|
|
assert labels == {"已定义", "未定义类别 2"}
|
|
unknown = next(item for item in undefined_response.json() if item["mask_data"]["label"].startswith("未定义"))
|
|
assert unknown["mask_data"]["gt_unknown_class"] is True
|
|
assert unknown["mask_data"]["gt_label_value"] == 2
|
|
assert unknown["mask_data"]["gt_resized_to_frame"] is True
|