完善项目导入、模板与分割工作区交互
- 增强 DICOM/视频项目导入与演示数据:DICOM 按文件名自然顺序处理,导入后展示上传与解析任务进度,恢复演示出厂设置保留演示视频和演示 DICOM 项目,并补充 demo media seed 逻辑。 - 完善项目管理:项目支持重命名、删除、复制,删除使用站内确认弹窗,复制支持新项目重置和全内容复制,DICOM 项目不显示生成帧入口。 - 完善 GT Mask 与导出链路:只支持 8-bit maskid 图导入,非法/全背景图明确拒绝,尺寸自动适配,高精度 polygon 回显;统一导出默认当前帧,GT_label 使用 uint8 和真实 maskid,待分类 maskid 0 与背景一致。 - 完善分割工作区交互:新增画笔和橡皮擦并支持尺寸控制,移除创建点/线段入口,工具栏按类别分隔,AI 智能分割使用明确 AI 图标,取消黄色 seed point,清空/删除传播 mask 后同步清理空帧时间轴状态。 - 完善传播与时间轴:自动传播使用 SAM 2.1 权重任务,参考帧无遮罩时提示,传播历史按同一蓝色系递进变暗,删除/清空传播链时保留人工或独立 AI 标注来源。 - 完善模板库:新增头颈部 CT 分割默认模板,所有模板保留 maskid 0 待分类,支持鼠标复制模板、拖拽层级、JSON 批量导入预览、删除 label 和站内删除确认。 - 完善用户与高风险确认:用户改密码、删除用户、恢复演示出厂设置和清空人工/AI 标注帧均改为站内确认交互,避免浏览器原生 prompt/confirm。 - 补充前后端测试与文档:更新项目、模板、GT 导入、导出、传播、DICOM、用户管理等测试,并同步 README、AGENTS 和 doc 下实现/契约/测试计划文档。
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
from models import Annotation, AuditLog, Frame, Mask, ProcessingTask, Project, Template, User
|
||||
from routers.auth import create_access_token, hash_password
|
||||
from statuses import PROJECT_STATUS_PENDING
|
||||
from statuses import PROJECT_STATUS_READY
|
||||
|
||||
|
||||
def test_admin_user_management_and_audit_logs(client, db_session):
|
||||
@@ -83,17 +83,34 @@ def test_admin_cannot_delete_self_or_user_with_projects(client, db_session):
|
||||
assert response.status_code == 409
|
||||
|
||||
|
||||
def test_demo_factory_reset_leaves_admin_and_unparsed_demo_video(client, db_session, monkeypatch, tmp_path):
|
||||
def test_demo_factory_reset_leaves_admin_and_parsed_demo_dicom(client, db_session, monkeypatch, tmp_path):
|
||||
video_path = tmp_path / "Data_MyVideo_1.mp4"
|
||||
video_path.write_bytes(b"demo-video")
|
||||
monkeypatch.setattr("routers.admin.settings.demo_video_path", str(video_path))
|
||||
dicom_dir = tmp_path / "dicom"
|
||||
dicom_dir.mkdir()
|
||||
for name in ["10.dcm", "2.dcm", "1.dcm"]:
|
||||
(dicom_dir / name).write_bytes(name.encode())
|
||||
monkeypatch.setattr("routers.admin.settings.demo_dicom_dir", str(dicom_dir))
|
||||
|
||||
parsed_frame_paths = []
|
||||
for idx in range(3):
|
||||
frame_path = tmp_path / f"frame_{idx:06d}.jpg"
|
||||
frame_path.write_bytes(b"frame")
|
||||
parsed_frame_paths.append(str(frame_path))
|
||||
|
||||
uploaded = []
|
||||
monkeypatch.setattr("routers.admin.upload_file", lambda object_name, data, content_type, length: uploaded.append({
|
||||
monkeypatch.setattr("services.demo_media.upload_file", lambda object_name, data, content_type, length: uploaded.append({
|
||||
"object_name": object_name,
|
||||
"data": data,
|
||||
"content_type": content_type,
|
||||
"length": length,
|
||||
}))
|
||||
monkeypatch.setattr("services.demo_media.parse_dicom", lambda dicom_dir_arg, output_dir: parsed_frame_paths)
|
||||
monkeypatch.setattr(
|
||||
"services.demo_media.upload_frames_to_minio",
|
||||
lambda frame_files, project_id: [f"projects/{project_id}/frames/frame_{idx:06d}.jpg" for idx, _ in enumerate(frame_files)],
|
||||
)
|
||||
|
||||
extra_user = User(username="doctor", password_hash=hash_password("secret123"), role="annotator", is_active=1)
|
||||
db_session.add(extra_user)
|
||||
@@ -113,7 +130,15 @@ def test_demo_factory_reset_leaves_admin_and_unparsed_demo_video(client, db_sess
|
||||
z_index=1,
|
||||
owner_user_id=extra_user.id,
|
||||
)
|
||||
db_session.add_all([task, private_template])
|
||||
system_template = Template(
|
||||
name="头颈部CT分割",
|
||||
description="头颈部CT分割",
|
||||
color="#ef4444",
|
||||
z_index=10,
|
||||
owner_user_id=None,
|
||||
mapping_rules={"classes": [{"name": "肿瘤/结节 (Tumor/Nodule)", "color": "#ff0000", "maskId": 1}], "rules": []},
|
||||
)
|
||||
db_session.add_all([task, private_template, system_template])
|
||||
db_session.commit()
|
||||
db_session.refresh(frame)
|
||||
annotation = Annotation(project_id=old_project.id, frame_id=frame.id, mask_data={"label": "old"})
|
||||
@@ -130,24 +155,36 @@ def test_demo_factory_reset_leaves_admin_and_unparsed_demo_video(client, db_sess
|
||||
data = response.json()
|
||||
assert data["message"] == "演示环境已恢复出厂设置"
|
||||
assert data["admin_user"]["username"] == "admin"
|
||||
assert data["project"]["name"] == "Data_MyVideo_1"
|
||||
assert data["project"]["status"] == PROJECT_STATUS_PENDING
|
||||
assert data["project"]["frame_count"] == 0
|
||||
assert data["project"]["video_path"] == f"uploads/{data['project']['id']}/Data_MyVideo_1.mp4"
|
||||
assert uploaded == [{
|
||||
"object_name": data["project"]["video_path"],
|
||||
"data": b"demo-video",
|
||||
"content_type": "video/mp4",
|
||||
"length": len(b"demo-video"),
|
||||
}]
|
||||
assert data["project"]["name"] == "演示DICOM序列"
|
||||
assert data["project"]["status"] == PROJECT_STATUS_READY
|
||||
assert data["project"]["source_type"] == "dicom"
|
||||
assert data["project"]["frame_count"] == 3
|
||||
assert data["project"]["video_path"] == f"uploads/{data['project']['id']}/dicom"
|
||||
assert [project["name"] for project in data["projects"]] == ["Data_MyVideo_1", "演示DICOM序列"]
|
||||
assert data["projects"][0]["status"] == "pending"
|
||||
assert data["projects"][0]["source_type"] == "video"
|
||||
assert data["projects"][0]["frame_count"] == 0
|
||||
assert data["projects"][1]["status"] == PROJECT_STATUS_READY
|
||||
assert data["projects"][1]["source_type"] == "dicom"
|
||||
assert data["projects"][1]["frame_count"] == 3
|
||||
assert [item["object_name"] for item in uploaded] == [
|
||||
f"uploads/{data['projects'][0]['id']}/Data_MyVideo_1.mp4",
|
||||
f"uploads/{data['project']['id']}/dicom/1.dcm",
|
||||
f"uploads/{data['project']['id']}/dicom/2.dcm",
|
||||
f"uploads/{data['project']['id']}/dicom/10.dcm",
|
||||
]
|
||||
assert [item["content_type"] for item in uploaded] == ["video/mp4", "application/dicom", "application/dicom", "application/dicom"]
|
||||
|
||||
assert [user.username for user in db_session.query(User).all()] == ["admin"]
|
||||
assert db_session.query(Project).count() == 1
|
||||
assert db_session.query(Frame).count() == 0
|
||||
assert db_session.query(Project).count() == 2
|
||||
assert db_session.query(Frame).count() == 3
|
||||
assert [frame.source_frame_number for frame in db_session.query(Frame).order_by(Frame.frame_index).all()] == [0, 1, 2]
|
||||
assert db_session.query(Annotation).count() == 0
|
||||
assert db_session.query(Mask).count() == 0
|
||||
assert db_session.query(ProcessingTask).count() == 0
|
||||
assert db_session.query(Template).filter(Template.owner_user_id.is_not(None)).count() == 0
|
||||
preserved_templates = db_session.query(Template).filter(Template.owner_user_id.is_(None)).all()
|
||||
assert [template.name for template in preserved_templates] == ["头颈部CT分割"]
|
||||
assert db_session.query(AuditLog).count() == 1
|
||||
assert db_session.query(AuditLog).first().action == "admin.demo_factory_reset"
|
||||
|
||||
|
||||
@@ -1149,6 +1149,81 @@ def test_import_gt_mask_creates_annotations_with_seed_points(client):
|
||||
assert 0.0 <= body[0]["points"][0][1] <= 1.0
|
||||
|
||||
|
||||
def test_import_gt_mask_polygons_work_with_analysis_and_smoothing(client):
|
||||
project, frame, _ = _create_project_and_frame(client)
|
||||
mask = np.zeros((360, 640), dtype=np.uint8)
|
||||
cv2.ellipse(mask, (260, 160), (130, 70), 20, 0, 360, 1, thickness=-1)
|
||||
ok, encoded = cv2.imencode(".png", mask)
|
||||
assert ok
|
||||
|
||||
response = client.post(
|
||||
"/api/ai/import-gt-mask",
|
||||
data={
|
||||
"project_id": str(project["id"]),
|
||||
"frame_id": str(frame["id"]),
|
||||
"label": "Imported GT",
|
||||
"color": "#22c55e",
|
||||
},
|
||||
files={"file": ("mask.png", encoded.tobytes(), "image/png")},
|
||||
)
|
||||
|
||||
assert response.status_code == 201
|
||||
annotation = response.json()[0]
|
||||
assert annotation["mask_data"]["source"] == "gt_mask"
|
||||
|
||||
analysis = client.post("/api/ai/analyze-mask", json={
|
||||
"frame_id": frame["id"],
|
||||
"mask_data": annotation["mask_data"],
|
||||
"points": annotation["points"],
|
||||
"bbox": annotation["bbox"],
|
||||
})
|
||||
assert analysis.status_code == 200
|
||||
assert analysis.json()["topology_anchor_count"] == len(annotation["mask_data"]["polygons"][0])
|
||||
|
||||
smoothing = client.post("/api/ai/smooth-mask", json={
|
||||
"frame_id": frame["id"],
|
||||
"mask_data": annotation["mask_data"],
|
||||
"points": annotation["points"],
|
||||
"bbox": annotation["bbox"],
|
||||
"strength": 35,
|
||||
})
|
||||
assert smoothing.status_code == 200
|
||||
assert smoothing.json()["topology_anchor_count"] == len(smoothing.json()["polygons"][0])
|
||||
|
||||
|
||||
def test_import_gt_mask_preserves_detailed_contours(client):
|
||||
project, frame, _ = _create_project_and_frame(client)
|
||||
mask = np.zeros((360, 640), dtype=np.uint8)
|
||||
center = np.array([320, 180])
|
||||
vertices = []
|
||||
for index in range(96):
|
||||
angle = 2 * np.pi * index / 96
|
||||
radius = 120 if index % 2 == 0 else 88
|
||||
vertices.append([
|
||||
int(center[0] + np.cos(angle) * radius),
|
||||
int(center[1] + np.sin(angle) * radius),
|
||||
])
|
||||
cv2.fillPoly(mask, [np.array(vertices, dtype=np.int32)], 1)
|
||||
ok, encoded = cv2.imencode(".png", mask)
|
||||
assert ok
|
||||
|
||||
response = client.post(
|
||||
"/api/ai/import-gt-mask",
|
||||
data={
|
||||
"project_id": str(project["id"]),
|
||||
"frame_id": str(frame["id"]),
|
||||
"label": "Detailed GT",
|
||||
"color": "#22c55e",
|
||||
},
|
||||
files={"file": ("mask.png", encoded.tobytes(), "image/png")},
|
||||
)
|
||||
|
||||
assert response.status_code == 201
|
||||
polygon = response.json()[0]["mask_data"]["polygons"][0]
|
||||
assert len(polygon) > 80
|
||||
assert len(polygon) <= 2048
|
||||
|
||||
|
||||
def test_import_gt_mask_splits_label_values(client):
|
||||
project, frame, _ = _create_project_and_frame(client)
|
||||
mask = np.zeros((360, 640), dtype=np.uint8)
|
||||
@@ -1174,7 +1249,27 @@ def test_import_gt_mask_splits_label_values(client):
|
||||
assert all(len(item["points"]) == 1 for item in body)
|
||||
|
||||
|
||||
def test_import_gt_mask_preserves_low_value_gtlabel_png(client):
|
||||
def test_import_gt_mask_rejects_background_only_label_image(client):
|
||||
project, frame, _ = _create_project_and_frame(client)
|
||||
mask = np.zeros((360, 640), dtype=np.uint8)
|
||||
ok, encoded = cv2.imencode(".png", mask)
|
||||
assert ok
|
||||
|
||||
response = client.post(
|
||||
"/api/ai/import-gt-mask",
|
||||
data={
|
||||
"project_id": str(project["id"]),
|
||||
"frame_id": str(frame["id"]),
|
||||
"label": "GT Class",
|
||||
},
|
||||
files={"file": ("empty-gt-label.png", encoded.tobytes(), "image/png")},
|
||||
)
|
||||
|
||||
assert response.status_code == 400
|
||||
assert response.json()["detail"] == "GT Mask 图片中没有非背景 maskid 区域。"
|
||||
|
||||
|
||||
def test_import_gt_mask_accepts_uint8_low_value_gtlabel_png(client):
|
||||
project, frame, _ = _create_project_and_frame(client)
|
||||
template = client.post("/api/templates", json={
|
||||
"name": "GTLabel Template",
|
||||
@@ -1185,7 +1280,7 @@ def test_import_gt_mask_preserves_low_value_gtlabel_png(client):
|
||||
],
|
||||
"rules": [],
|
||||
}).json()
|
||||
mask = np.zeros((360, 640), dtype=np.uint16)
|
||||
mask = np.zeros((360, 640), dtype=np.uint8)
|
||||
cv2.rectangle(mask, (40, 40), (140, 140), 1, thickness=-1)
|
||||
ok, encoded = cv2.imencode(".png", mask)
|
||||
assert ok
|
||||
@@ -1241,7 +1336,7 @@ def test_import_gt_mask_rejects_rgb_color_masks(client):
|
||||
assert "RGB 三通道完全相同" in response.json()["detail"]
|
||||
|
||||
|
||||
def test_import_gt_mask_reads_uint16_gt_label_and_maps_maskid_class(client):
|
||||
def test_import_gt_mask_rejects_uint16_gt_label(client):
|
||||
project, frame, _ = _create_project_and_frame(client)
|
||||
template = client.post("/api/templates", json={
|
||||
"name": "Label Template",
|
||||
@@ -1266,13 +1361,8 @@ def test_import_gt_mask_reads_uint16_gt_label_and_maps_maskid_class(client):
|
||||
files={"file": ("gt_label.png", encoded.tobytes(), "image/png")},
|
||||
)
|
||||
|
||||
assert response.status_code == 201
|
||||
body = response.json()
|
||||
assert len(body) == 1
|
||||
assert body[0]["mask_data"]["gt_label_value"] == 1
|
||||
assert body[0]["mask_data"]["label"] == "肿瘤"
|
||||
assert body[0]["mask_data"]["class"]["maskId"] == 1
|
||||
assert body[0]["mask_data"]["class"]["color"] == "#ff0000"
|
||||
assert response.status_code == 400
|
||||
assert "仅支持 8-bit" in response.json()["detail"]
|
||||
|
||||
|
||||
def test_import_gt_mask_handles_unknown_maskid_policy_and_resizes_to_frame(client):
|
||||
|
||||
@@ -169,6 +169,7 @@ def test_export_results_zip_contains_coco_original_images_and_selected_mask_outp
|
||||
"key": f"template:{annotation['template_id']}",
|
||||
"template_id": annotation["template_id"],
|
||||
}]
|
||||
assert gt_label.dtype == np.uint8
|
||||
assert gt_label[0, 0] == 0
|
||||
assert gt_label[20, 50] == 1
|
||||
assert pro_label[20, 50].tolist() == [212, 182, 6]
|
||||
@@ -234,6 +235,7 @@ def test_export_results_uses_internal_layer_order_for_gt_pro_and_mix_outputs(cli
|
||||
cv2.IMREAD_COLOR,
|
||||
)
|
||||
|
||||
assert gt_label.dtype == np.uint8
|
||||
assert gt_label[10, 10] == high_value
|
||||
assert pro_label[10, 10].tolist() == [0, 0, 255]
|
||||
assert mix_label[10, 10].tolist() == [127, 127, 255]
|
||||
@@ -365,10 +367,74 @@ def test_export_results_preserves_template_maskid_consistently_across_frames(cli
|
||||
"key": "class:tumor",
|
||||
"template_id": None,
|
||||
}]
|
||||
assert first_label.dtype == np.uint8
|
||||
assert second_label.dtype == np.uint8
|
||||
assert first_label[5, 5] == 7
|
||||
assert second_label[5, 5] == 7
|
||||
|
||||
|
||||
def test_export_results_keeps_unclassified_maskid_zero_black_in_gt_and_pro(client, monkeypatch):
|
||||
monkeypatch.setattr("routers.export.download_file", lambda object_name: _fake_image_bytes(20, 20))
|
||||
project = client.post("/api/projects", json={
|
||||
"name": "Unclassified Export Project",
|
||||
"video_path": "uploads/8/unclassified.mp4",
|
||||
}).json()
|
||||
frame = client.post(f"/api/projects/{project['id']}/frames", json={
|
||||
"project_id": project["id"],
|
||||
"frame_index": 0,
|
||||
"image_url": "frames/source.jpg",
|
||||
"width": 20,
|
||||
"height": 20,
|
||||
"timestamp_ms": 0,
|
||||
}).json()
|
||||
client.post("/api/ai/annotate", json={
|
||||
"project_id": project["id"],
|
||||
"frame_id": frame["id"],
|
||||
"mask_data": {
|
||||
"polygons": [[[0.1, 0.1], [0.8, 0.1], [0.8, 0.8], [0.1, 0.8]]],
|
||||
"label": "待分类",
|
||||
"color": "#000000",
|
||||
"class": {
|
||||
"id": "reserved-unclassified",
|
||||
"name": "待分类",
|
||||
"color": "#000000",
|
||||
"maskId": 0,
|
||||
"zIndex": 0,
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
response = client.get(f"/api/export/{project['id']}/results?scope=all&outputs=gt_label,pro_label")
|
||||
|
||||
assert response.status_code == 200
|
||||
with zipfile.ZipFile(BytesIO(response.content)) as archive:
|
||||
mapping = json.loads(archive.read("maskid_GT像素值_类别映射.json"))
|
||||
stem = "unclassified_0h00m00s000ms_frame000001"
|
||||
gt_label = cv2.imdecode(
|
||||
np.frombuffer(archive.read(f"GT_label图/{stem}.png"), dtype=np.uint8),
|
||||
cv2.IMREAD_UNCHANGED,
|
||||
)
|
||||
pro_label = cv2.imdecode(
|
||||
np.frombuffer(archive.read(f"Pro_label彩色分割结果/{stem}.png"), dtype=np.uint8),
|
||||
cv2.IMREAD_COLOR,
|
||||
)
|
||||
|
||||
assert mapping["classes"] == [{
|
||||
"gt_pixel_value": 0,
|
||||
"maskid": 0,
|
||||
"chineseName": "待分类",
|
||||
"className": "待分类",
|
||||
"categoryName": "",
|
||||
"rgb": [0, 0, 0],
|
||||
"color": "#000000",
|
||||
"key": "class:reserved-unclassified",
|
||||
"template_id": None,
|
||||
}]
|
||||
assert gt_label.dtype == np.uint8
|
||||
assert gt_label[5, 5] == 0
|
||||
assert pro_label[5, 5].tolist() == [0, 0, 0]
|
||||
|
||||
|
||||
def test_exported_gtlabel_round_trips_through_gt_mask_import_with_template_maskid(client, monkeypatch):
|
||||
monkeypatch.setattr("routers.export.download_file", lambda object_name: _fake_image_bytes(20, 20))
|
||||
project = client.post("/api/projects", json={
|
||||
@@ -423,6 +489,7 @@ def test_exported_gtlabel_round_trips_through_gt_mask_import_with_template_maski
|
||||
gt_label = cv2.imdecode(np.frombuffer(exported_gt_label, dtype=np.uint8), cv2.IMREAD_UNCHANGED)
|
||||
mapping = json.loads(archive.read("maskid_GT像素值_类别映射.json"))
|
||||
|
||||
assert gt_label.dtype == np.uint8
|
||||
assert gt_label[5, 5] == 7
|
||||
assert mapping["classes"][0]["maskid"] == 7
|
||||
|
||||
@@ -446,6 +513,36 @@ def test_exported_gtlabel_round_trips_through_gt_mask_import_with_template_maski
|
||||
assert imported[0]["mask_data"]["class"]["maskId"] == 7
|
||||
|
||||
|
||||
def test_export_results_rejects_gtlabel_maskid_outside_uint8_range(client, monkeypatch):
|
||||
monkeypatch.setattr("routers.export.download_file", lambda object_name: _fake_image_bytes(20, 20))
|
||||
project = client.post("/api/projects", json={
|
||||
"name": "Large MaskId Project",
|
||||
"video_path": "uploads/8/large-maskid.mp4",
|
||||
}).json()
|
||||
frame = client.post(f"/api/projects/{project['id']}/frames", json={
|
||||
"project_id": project["id"],
|
||||
"frame_index": 0,
|
||||
"image_url": "frames/source.jpg",
|
||||
"width": 20,
|
||||
"height": 20,
|
||||
}).json()
|
||||
client.post("/api/ai/annotate", json={
|
||||
"project_id": project["id"],
|
||||
"frame_id": frame["id"],
|
||||
"mask_data": {
|
||||
"polygons": [[[0.1, 0.1], [0.8, 0.1], [0.8, 0.8], [0.1, 0.8]]],
|
||||
"label": "TooLarge",
|
||||
"color": "#ff0000",
|
||||
"class": {"id": "too-large", "name": "TooLarge", "color": "#ff0000", "maskId": 300, "zIndex": 30},
|
||||
},
|
||||
})
|
||||
|
||||
response = client.get(f"/api/export/{project['id']}/results?scope=all&outputs=gt_label")
|
||||
|
||||
assert response.status_code == 400
|
||||
assert "8-bit maskid" in response.json()["detail"]
|
||||
|
||||
|
||||
def test_export_missing_project_returns_404(client):
|
||||
assert client.get("/api/export/999/coco").status_code == 404
|
||||
assert client.get("/api/export/999/masks").status_code == 404
|
||||
|
||||
@@ -48,15 +48,37 @@ def test_upload_dicom_batch_filters_files_and_creates_project(client, monkeypatc
|
||||
response = client.post(
|
||||
"/api/media/upload/dicom",
|
||||
files=[
|
||||
("files", ("a.dcm", b"dcm", "application/dicom")),
|
||||
("files", ("10.dcm", b"dcm10", "application/dicom")),
|
||||
("files", ("skip.txt", b"text", "text/plain")),
|
||||
("files", ("2.dcm", b"dcm2", "application/dicom")),
|
||||
("files", ("1.dcm", b"dcm1", "application/dicom")),
|
||||
],
|
||||
)
|
||||
|
||||
assert response.status_code == 201
|
||||
data = response.json()
|
||||
assert data["uploaded_count"] == 1
|
||||
assert uploaded == [f"uploads/{data['project_id']}/dicom/a.dcm"]
|
||||
assert data["uploaded_count"] == 3
|
||||
assert uploaded == [
|
||||
f"uploads/{data['project_id']}/dicom/1.dcm",
|
||||
f"uploads/{data['project_id']}/dicom/2.dcm",
|
||||
f"uploads/{data['project_id']}/dicom/10.dcm",
|
||||
]
|
||||
project_detail = client.get(f"/api/projects/{data['project_id']}").json()
|
||||
assert project_detail["name"] == "1.dcm"
|
||||
|
||||
|
||||
def test_upload_dicom_batch_rejects_when_no_valid_dicom(client, monkeypatch):
|
||||
monkeypatch.setattr("routers.media.upload_file", lambda *args, **kwargs: None)
|
||||
|
||||
response = client.post(
|
||||
"/api/media/upload/dicom",
|
||||
files=[
|
||||
("files", ("notes.txt", b"text", "text/plain")),
|
||||
],
|
||||
)
|
||||
|
||||
assert response.status_code == 400
|
||||
assert response.json()["detail"] == "No valid DICOM files uploaded"
|
||||
|
||||
|
||||
def test_parse_media_queues_background_task(client, monkeypatch):
|
||||
@@ -194,6 +216,101 @@ def test_parse_task_runner_registers_frames(client, db_session, monkeypatch, tmp
|
||||
assert frames[0]["source_frame_number"] == 0
|
||||
|
||||
|
||||
def test_parse_dicom_reads_files_in_natural_filename_order(monkeypatch, tmp_path):
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
|
||||
from services.frame_parser import parse_dicom
|
||||
|
||||
dcm_dir = tmp_path / "dcm"
|
||||
output_dir = tmp_path / "frames"
|
||||
dcm_dir.mkdir()
|
||||
for name in ["10.dcm", "2.dcm", "1.dcm"]:
|
||||
(dcm_dir / name).write_bytes(b"dcm")
|
||||
|
||||
read_order = []
|
||||
|
||||
class FakeDicom:
|
||||
pixel_array = np.ones((2, 2), dtype=np.uint8)
|
||||
|
||||
def fake_dcmread(path):
|
||||
read_order.append(Path(path).name)
|
||||
return FakeDicom()
|
||||
|
||||
def fake_imwrite(path, image, params=None):
|
||||
Path(path).write_bytes(image.tobytes())
|
||||
return True
|
||||
|
||||
monkeypatch.setattr("services.frame_parser.dcmread", fake_dcmread)
|
||||
monkeypatch.setattr("services.frame_parser.cv2.imwrite", fake_imwrite)
|
||||
|
||||
frame_files = parse_dicom(str(dcm_dir), str(output_dir))
|
||||
|
||||
assert read_order == ["1.dcm", "2.dcm", "10.dcm"]
|
||||
assert [Path(path).name for path in frame_files] == ["frame_000000.jpg", "frame_000001.jpg", "frame_000002.jpg"]
|
||||
|
||||
|
||||
def test_parse_task_runner_downloads_dicom_objects_in_natural_filename_order(client, db_session, monkeypatch, tmp_path):
|
||||
from types import SimpleNamespace
|
||||
|
||||
from models import ProcessingTask
|
||||
from services.media_task_runner import run_parse_media_task
|
||||
|
||||
project = client.post("/api/projects", json={
|
||||
"name": "DICOM",
|
||||
"video_path": "uploads/1/dicom",
|
||||
"source_type": "dicom",
|
||||
"parse_fps": 30,
|
||||
}).json()
|
||||
task = ProcessingTask(
|
||||
task_type="parse_dicom",
|
||||
status="queued",
|
||||
progress=0,
|
||||
project_id=project["id"],
|
||||
payload={"source_type": "dicom"},
|
||||
)
|
||||
db_session.add(task)
|
||||
db_session.commit()
|
||||
db_session.refresh(task)
|
||||
|
||||
class FakeClient:
|
||||
def list_objects(self, bucket, prefix, recursive=True):
|
||||
return [
|
||||
SimpleNamespace(object_name=f"{prefix}/10.dcm"),
|
||||
SimpleNamespace(object_name=f"{prefix}/2.dcm"),
|
||||
SimpleNamespace(object_name=f"{prefix}/1.dcm"),
|
||||
]
|
||||
|
||||
downloaded = []
|
||||
frame_files = []
|
||||
for idx in range(3):
|
||||
frame_file = tmp_path / f"frame_{idx:06d}.jpg"
|
||||
frame_file.write_bytes(b"fake image")
|
||||
frame_files.append(str(frame_file))
|
||||
|
||||
monkeypatch.setattr("services.media_task_runner.get_minio_client", lambda: FakeClient())
|
||||
monkeypatch.setattr(
|
||||
"services.media_task_runner.download_file",
|
||||
lambda object_name: downloaded.append(object_name) or b"dcm",
|
||||
)
|
||||
monkeypatch.setattr("services.media_task_runner.parse_dicom", lambda *args, **kwargs: frame_files)
|
||||
monkeypatch.setattr(
|
||||
"services.media_task_runner.upload_frames_to_minio",
|
||||
lambda frames, project_id: [f"projects/{project_id}/frames/{idx}.jpg" for idx, _ in enumerate(frames)],
|
||||
)
|
||||
monkeypatch.setattr("services.media_task_runner.publish_task_progress_event", lambda task: None)
|
||||
|
||||
result = run_parse_media_task(db_session, task.id)
|
||||
|
||||
assert result["frames_extracted"] == 3
|
||||
assert downloaded == [
|
||||
"uploads/1/dicom/1.dcm",
|
||||
"uploads/1/dicom/2.dcm",
|
||||
"uploads/1/dicom/10.dcm",
|
||||
]
|
||||
|
||||
|
||||
def test_parse_task_runner_skips_already_cancelled_task(db_session):
|
||||
from models import ProcessingTask
|
||||
from services.media_task_runner import run_parse_media_task
|
||||
|
||||
@@ -42,6 +42,9 @@ def test_project_crud_and_frames(client, monkeypatch):
|
||||
assert updated.json()["name"] == "Renamed"
|
||||
assert updated.json()["status"] == "ready"
|
||||
|
||||
empty_name = client.patch(f"/api/projects/{project_id}", json={"name": " "})
|
||||
assert empty_name.status_code == 400
|
||||
|
||||
deleted = client.delete(f"/api/projects/{project_id}")
|
||||
assert deleted.status_code == 204
|
||||
assert client.get(f"/api/projects/{project_id}").status_code == 404
|
||||
@@ -83,10 +86,97 @@ def test_delete_project_cascades_related_records(client, db_session):
|
||||
assert db_session.query(ProcessingTask).filter(ProcessingTask.project_id == project_id).count() == 0
|
||||
|
||||
|
||||
def test_copy_project_reset_copies_frame_sequence_without_annotations(client, db_session):
|
||||
created = client.post("/api/projects", json={
|
||||
"name": "Reset Source",
|
||||
"description": "desc",
|
||||
"video_path": "uploads/source.mp4",
|
||||
"thumbnail_url": "thumb.jpg",
|
||||
"status": "ready",
|
||||
"parse_fps": 12,
|
||||
})
|
||||
assert created.status_code == 201
|
||||
project_id = created.json()["id"]
|
||||
frame = client.post(f"/api/projects/{project_id}/frames", json={
|
||||
"project_id": project_id,
|
||||
"frame_index": 0,
|
||||
"image_url": "frames/source/frame_000000.jpg",
|
||||
"width": 640,
|
||||
"height": 360,
|
||||
"timestamp_ms": 0,
|
||||
"source_frame_number": 0,
|
||||
})
|
||||
assert frame.status_code == 201
|
||||
annotation = client.post("/api/ai/annotate", json={
|
||||
"project_id": project_id,
|
||||
"frame_id": frame.json()["id"],
|
||||
"mask_data": {"label": "Tumor", "polygons": [[[0.1, 0.1], [0.2, 0.1], [0.2, 0.2]]]},
|
||||
})
|
||||
assert annotation.status_code == 201
|
||||
|
||||
copied = client.post(f"/api/projects/{project_id}/copy", json={"mode": "reset"})
|
||||
assert copied.status_code == 201
|
||||
copied_body = copied.json()
|
||||
assert copied_body["name"] == "Reset Source 副本"
|
||||
assert copied_body["frame_count"] == 1
|
||||
assert copied_body["video_path"] == "uploads/source.mp4"
|
||||
assert copied_body["parse_fps"] == 12
|
||||
|
||||
copied_frames = db_session.query(Frame).filter(Frame.project_id == copied_body["id"]).all()
|
||||
assert len(copied_frames) == 1
|
||||
assert copied_frames[0].image_url == "frames/source/frame_000000.jpg"
|
||||
assert db_session.query(Annotation).filter(Annotation.project_id == copied_body["id"]).count() == 0
|
||||
|
||||
|
||||
def test_copy_project_full_copies_annotations_and_mask_metadata(client, db_session):
|
||||
created = client.post("/api/projects", json={
|
||||
"name": "Full Source",
|
||||
"status": "ready",
|
||||
})
|
||||
assert created.status_code == 201
|
||||
project_id = created.json()["id"]
|
||||
frame = client.post(f"/api/projects/{project_id}/frames", json={
|
||||
"project_id": project_id,
|
||||
"frame_index": 0,
|
||||
"image_url": "frames/source/frame_000000.jpg",
|
||||
"width": 640,
|
||||
"height": 360,
|
||||
})
|
||||
assert frame.status_code == 201
|
||||
frame_id = frame.json()["id"]
|
||||
annotation = client.post("/api/ai/annotate", json={
|
||||
"project_id": project_id,
|
||||
"frame_id": frame_id,
|
||||
"mask_data": {"label": "Tumor", "polygons": [[[0.1, 0.1], [0.2, 0.1], [0.2, 0.2]]]},
|
||||
"points": [[0.1, 0.1]],
|
||||
"bbox": [0.1, 0.1, 0.1, 0.1],
|
||||
})
|
||||
assert annotation.status_code == 201
|
||||
annotation_id = annotation.json()["id"]
|
||||
db_session.add(Mask(annotation_id=annotation_id, mask_url="masks/source.png", format="png"))
|
||||
db_session.commit()
|
||||
|
||||
copied = client.post(f"/api/projects/{project_id}/copy", json={"mode": "full"})
|
||||
assert copied.status_code == 201
|
||||
copied_body = copied.json()
|
||||
copied_frames = db_session.query(Frame).filter(Frame.project_id == copied_body["id"]).all()
|
||||
copied_annotations = db_session.query(Annotation).filter(Annotation.project_id == copied_body["id"]).all()
|
||||
|
||||
assert copied_body["name"] == "Full Source 副本"
|
||||
assert len(copied_frames) == 1
|
||||
assert len(copied_annotations) == 1
|
||||
assert copied_annotations[0].id != annotation_id
|
||||
assert copied_annotations[0].frame_id == copied_frames[0].id
|
||||
assert copied_annotations[0].mask_data["label"] == "Tumor"
|
||||
assert copied_annotations[0].bbox == [0.1, 0.1, 0.1, 0.1]
|
||||
assert copied_annotations[0].masks[0].mask_url == "masks/source.png"
|
||||
|
||||
|
||||
def test_project_and_frame_404s(client):
|
||||
assert client.get("/api/projects/999").status_code == 404
|
||||
assert client.patch("/api/projects/999", json={"name": "x"}).status_code == 404
|
||||
assert client.delete("/api/projects/999").status_code == 404
|
||||
assert client.post("/api/projects/999/copy", json={"mode": "reset"}).status_code == 404
|
||||
assert client.post("/api/projects/999/frames", json={
|
||||
"project_id": 999,
|
||||
"frame_index": 0,
|
||||
|
||||
@@ -37,3 +37,55 @@ def test_template_404s(client):
|
||||
assert client.get("/api/templates/999").status_code == 404
|
||||
assert client.patch("/api/templates/999", json={"name": "x"}).status_code == 404
|
||||
assert client.delete("/api/templates/999").status_code == 404
|
||||
|
||||
|
||||
def test_default_head_neck_ct_template_is_seeded_and_visible(client, db_session):
|
||||
from main import ensure_default_templates
|
||||
from models import Template
|
||||
|
||||
ensure_default_templates(db_session)
|
||||
ensure_default_templates(db_session)
|
||||
|
||||
templates = db_session.query(Template).filter(Template.owner_user_id.is_(None)).all()
|
||||
names = [template.name for template in templates]
|
||||
assert names.count("头颈部CT分割") == 1
|
||||
|
||||
listing = client.get("/api/templates")
|
||||
assert listing.status_code == 200
|
||||
head_neck = next(template for template in listing.json() if template["name"] == "头颈部CT分割")
|
||||
assert head_neck["description"] == "头颈部CT分割"
|
||||
expected_names = [
|
||||
"肿瘤/结节 (Tumor/Nodule)",
|
||||
"下颌骨 (Mandible)",
|
||||
"甲状腺 (Thyroid)",
|
||||
"气管 (Trachea)",
|
||||
"颈椎 (Cervical Spine)",
|
||||
"颈动脉 (Carotid Artery)",
|
||||
"颈静脉 (Jugular Vein)",
|
||||
"腮腺 (Parotid Gland)",
|
||||
"下颌下腺 (Submandibular Gland)",
|
||||
"舌骨 (Hyoid Bone)",
|
||||
"待分类",
|
||||
]
|
||||
expected_colors = [
|
||||
"#ff0000",
|
||||
"#00ff00",
|
||||
"#0000ff",
|
||||
"#ffff00",
|
||||
"#ff00ff",
|
||||
"#00ffff",
|
||||
"#ff8000",
|
||||
"#800080",
|
||||
"#008080",
|
||||
"#808000",
|
||||
"#000000",
|
||||
]
|
||||
actual_names = [item["name"] for item in head_neck["classes"]]
|
||||
actual_colors = [item["color"] for item in head_neck["classes"]]
|
||||
actual_mask_ids = [item["maskId"] for item in head_neck["classes"]]
|
||||
if actual_names != expected_names:
|
||||
raise AssertionError(f"Unexpected head-neck classes: {actual_names}")
|
||||
if actual_colors != expected_colors:
|
||||
raise AssertionError(f"Unexpected head-neck colors: {actual_colors}")
|
||||
if actual_mask_ids != [*list(range(1, 11)), 0]:
|
||||
raise AssertionError(f"Unexpected head-neck mask IDs: {actual_mask_ids}")
|
||||
|
||||
Reference in New Issue
Block a user