feat: 完善分割工作区导入导出与管理流程
- 新增基于 JWT 当前用户的登录恢复、角色权限、用户管理、审计日志和演示出厂重置后台接口与前端管理页。 - 重串 GT_label 导出和 GT Mask 导入逻辑:导出保留类别真实 maskid,导入仅接受灰度或 RGB 等通道 maskid 图,支持未知 maskid 策略、尺寸最近邻拉伸和导入预览。 - 统一分割结果导出体验:默认当前帧,按项目抽帧顺序和 XhXXmXXsXXXms 时间戳命名 ZIP 与图片,补齐 GT/Pro/Mix/分开 Mask 输出和映射 JSON。 - 调整工作区左侧工具栏:移除创建点/线段入口,新增画笔、橡皮擦及尺寸控制,并按绘制、布尔、导入/AI 工具分组分隔。 - 扩展 Canvas 编辑能力:画笔按语义分类绘制并可自动并入连通选中 mask,橡皮擦对选中区域扣除,优化布尔操作、选区、撤销重做和保存状态联动。 - 优化自动传播时间轴显示:同一蓝色系按传播新旧递进变暗,老传播记录达到阈值后统一旧记录色,并维护范围选择与清空后的历史显示。 - 将 AI 智能分割入口替换为更明确的 AI 元素图标,并同步侧栏、工作区和 AI 页面入口表现。 - 完善模板分类、maskid 工具函数、分类树联动、遮罩透明度、边缘平滑和传播链同步相关前端状态。 - 扩展后端项目、媒体、任务、Dashboard、模板和传播 runner 的用户隔离、任务控制、进度事件与兼容处理。 - 补充前后端测试,覆盖用户管理、GT_label 往返导入导出、GT Mask 校验和预览、画笔/橡皮擦、时间轴传播历史、导出范围、WebSocket 与 API 封装。 - 更新 AGENTS、README 和 doc 文档,记录当前接口契约、实现状态、测试计划、安装说明和 maskid/GT_label 规则。
This commit is contained in:
@@ -19,7 +19,7 @@ if str(BACKEND_DIR) not in sys.path:
|
||||
|
||||
from database import Base, get_db # noqa: E402
|
||||
from main import websocket_progress # noqa: E402
|
||||
from routers import ai, auth, dashboard, export, media, projects, tasks, templates # noqa: E402
|
||||
from routers import admin, ai, auth, dashboard, export, media, projects, tasks, templates # noqa: E402
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
@@ -32,6 +32,7 @@ def db_session() -> Iterator[Session]:
|
||||
TestingSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||
Base.metadata.create_all(bind=engine)
|
||||
session = TestingSessionLocal()
|
||||
auth.ensure_default_admin(session)
|
||||
try:
|
||||
yield session
|
||||
finally:
|
||||
@@ -56,6 +57,7 @@ def app(db_session: Session) -> FastAPI:
|
||||
test_app.include_router(export.router)
|
||||
test_app.include_router(dashboard.router)
|
||||
test_app.include_router(tasks.router)
|
||||
test_app.include_router(admin.router)
|
||||
|
||||
@test_app.get("/health")
|
||||
def health_check() -> dict[str, str]:
|
||||
@@ -67,6 +69,10 @@ def app(db_session: Session) -> FastAPI:
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def client(app: FastAPI) -> Iterator[TestClient]:
|
||||
def client(app: FastAPI, db_session: Session) -> Iterator[TestClient]:
|
||||
with TestClient(app) as test_client:
|
||||
admin = auth.ensure_default_admin(db_session)
|
||||
test_client.headers.update({
|
||||
"Authorization": f"Bearer {auth.create_access_token(admin)}"
|
||||
})
|
||||
yield test_client
|
||||
|
||||
158
backend/tests/test_admin.py
Normal file
158
backend/tests/test_admin.py
Normal file
@@ -0,0 +1,158 @@
|
||||
from models import Annotation, AuditLog, Frame, Mask, ProcessingTask, Project, Template, User
|
||||
from routers.auth import create_access_token, hash_password
|
||||
from statuses import PROJECT_STATUS_PENDING
|
||||
|
||||
|
||||
def test_admin_user_management_and_audit_logs(client, db_session):
|
||||
created = client.post("/api/admin/users", json={
|
||||
"username": "doctor",
|
||||
"password": "secret123",
|
||||
"role": "annotator",
|
||||
"is_active": True,
|
||||
})
|
||||
assert created.status_code == 201
|
||||
user_id = created.json()["id"]
|
||||
|
||||
updated = client.patch(f"/api/admin/users/{user_id}", json={
|
||||
"role": "viewer",
|
||||
"password": "newsecret",
|
||||
"is_active": False,
|
||||
})
|
||||
assert updated.status_code == 200
|
||||
assert updated.json()["role"] == "viewer"
|
||||
assert updated.json()["is_active"] == 0
|
||||
|
||||
users = client.get("/api/admin/users")
|
||||
assert users.status_code == 200
|
||||
assert any(user["username"] == "doctor" for user in users.json())
|
||||
|
||||
deleted = client.delete(f"/api/admin/users/{user_id}")
|
||||
assert deleted.status_code == 204
|
||||
|
||||
logs = client.get("/api/admin/audit-logs")
|
||||
assert logs.status_code == 200
|
||||
actions = [log["action"] for log in logs.json()]
|
||||
assert "admin.user_created" in actions
|
||||
assert "admin.user_updated" in actions
|
||||
assert "admin.user_deleted" in actions
|
||||
|
||||
|
||||
def test_admin_routes_require_admin_role(client, db_session):
|
||||
user = User(username="viewer", password_hash=hash_password("secret123"), role="viewer", is_active=1)
|
||||
db_session.add(user)
|
||||
db_session.commit()
|
||||
db_session.refresh(user)
|
||||
original_auth = client.headers["Authorization"]
|
||||
client.headers.update({"Authorization": f"Bearer {create_access_token(user)}"})
|
||||
try:
|
||||
response = client.get("/api/admin/users")
|
||||
assert response.status_code == 403
|
||||
finally:
|
||||
client.headers.update({"Authorization": original_auth})
|
||||
|
||||
|
||||
def test_viewer_role_is_read_only_for_business_mutations(client, db_session):
|
||||
project = client.post("/api/projects", json={"name": "Readonly Check"}).json()
|
||||
user = User(username="readonly", password_hash=hash_password("secret123"), role="viewer", is_active=1)
|
||||
db_session.add(user)
|
||||
db_session.commit()
|
||||
db_session.refresh(user)
|
||||
original_auth = client.headers["Authorization"]
|
||||
client.headers.update({"Authorization": f"Bearer {create_access_token(user)}"})
|
||||
try:
|
||||
assert client.get("/api/projects").status_code == 200
|
||||
assert client.post("/api/projects", json={"name": "Nope"}).status_code == 403
|
||||
assert client.patch(f"/api/projects/{project['id']}", json={"name": "Nope"}).status_code == 403
|
||||
assert client.post("/api/ai/annotate", json={"project_id": project["id"]}).status_code == 403
|
||||
finally:
|
||||
client.headers.update({"Authorization": original_auth})
|
||||
|
||||
|
||||
def test_admin_cannot_delete_self_or_user_with_projects(client, db_session):
|
||||
me = client.get("/api/auth/me").json()
|
||||
assert client.delete(f"/api/admin/users/{me['id']}").status_code == 400
|
||||
|
||||
user = User(username="owner", password_hash=hash_password("secret123"), role="annotator", is_active=1)
|
||||
db_session.add(user)
|
||||
db_session.commit()
|
||||
db_session.refresh(user)
|
||||
db_session.add(Project(name="Owned", owner_user_id=user.id))
|
||||
db_session.commit()
|
||||
|
||||
response = client.delete(f"/api/admin/users/{user.id}")
|
||||
assert response.status_code == 409
|
||||
|
||||
|
||||
def test_demo_factory_reset_leaves_admin_and_unparsed_demo_video(client, db_session, monkeypatch, tmp_path):
|
||||
video_path = tmp_path / "Data_MyVideo_1.mp4"
|
||||
video_path.write_bytes(b"demo-video")
|
||||
monkeypatch.setattr("routers.admin.settings.demo_video_path", str(video_path))
|
||||
uploaded = []
|
||||
monkeypatch.setattr("routers.admin.upload_file", lambda object_name, data, content_type, length: uploaded.append({
|
||||
"object_name": object_name,
|
||||
"data": data,
|
||||
"content_type": content_type,
|
||||
"length": length,
|
||||
}))
|
||||
|
||||
extra_user = User(username="doctor", password_hash=hash_password("secret123"), role="annotator", is_active=1)
|
||||
db_session.add(extra_user)
|
||||
db_session.commit()
|
||||
db_session.refresh(extra_user)
|
||||
old_project = Project(name="Old", owner_user_id=extra_user.id, video_path="uploads/old.mp4")
|
||||
db_session.add(old_project)
|
||||
db_session.commit()
|
||||
db_session.refresh(old_project)
|
||||
frame = Frame(project_id=old_project.id, frame_index=0, image_url="frames/old.jpg")
|
||||
db_session.add(frame)
|
||||
task = ProcessingTask(task_type="parse_video", project_id=old_project.id)
|
||||
private_template = Template(
|
||||
name="Private",
|
||||
description="private",
|
||||
color="#fff",
|
||||
z_index=1,
|
||||
owner_user_id=extra_user.id,
|
||||
)
|
||||
db_session.add_all([task, private_template])
|
||||
db_session.commit()
|
||||
db_session.refresh(frame)
|
||||
annotation = Annotation(project_id=old_project.id, frame_id=frame.id, mask_data={"label": "old"})
|
||||
db_session.add(annotation)
|
||||
db_session.commit()
|
||||
db_session.refresh(annotation)
|
||||
db_session.add(Mask(annotation_id=annotation.id, mask_url="masks/old.png"))
|
||||
db_session.add(AuditLog(actor_user_id=extra_user.id, action="old.audit"))
|
||||
db_session.commit()
|
||||
|
||||
response = client.post("/api/admin/demo-factory-reset", json={"confirmation": "RESET_DEMO_FACTORY"})
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["message"] == "演示环境已恢复出厂设置"
|
||||
assert data["admin_user"]["username"] == "admin"
|
||||
assert data["project"]["name"] == "Data_MyVideo_1"
|
||||
assert data["project"]["status"] == PROJECT_STATUS_PENDING
|
||||
assert data["project"]["frame_count"] == 0
|
||||
assert data["project"]["video_path"] == f"uploads/{data['project']['id']}/Data_MyVideo_1.mp4"
|
||||
assert uploaded == [{
|
||||
"object_name": data["project"]["video_path"],
|
||||
"data": b"demo-video",
|
||||
"content_type": "video/mp4",
|
||||
"length": len(b"demo-video"),
|
||||
}]
|
||||
|
||||
assert [user.username for user in db_session.query(User).all()] == ["admin"]
|
||||
assert db_session.query(Project).count() == 1
|
||||
assert db_session.query(Frame).count() == 0
|
||||
assert db_session.query(Annotation).count() == 0
|
||||
assert db_session.query(Mask).count() == 0
|
||||
assert db_session.query(ProcessingTask).count() == 0
|
||||
assert db_session.query(Template).filter(Template.owner_user_id.is_not(None)).count() == 0
|
||||
assert db_session.query(AuditLog).count() == 1
|
||||
assert db_session.query(AuditLog).first().action == "admin.demo_factory_reset"
|
||||
|
||||
|
||||
def test_demo_factory_reset_requires_exact_confirmation(client):
|
||||
response = client.post("/api/admin/demo-factory-reset", json={"confirmation": "reset"})
|
||||
|
||||
assert response.status_code == 400
|
||||
@@ -223,6 +223,88 @@ def test_analyze_mask_returns_backend_geometry_properties(client):
|
||||
assert body["message"] == "已从后端重新提取几何拓扑锚点"
|
||||
|
||||
|
||||
def test_analyze_mask_reports_actual_polygon_anchor_count(client):
|
||||
_, frame, _ = _create_project_and_frame(client)
|
||||
polygon = [[0.1 + index * 0.005, 0.1 + (0.01 if index % 2 else 0)] for index in range(80)]
|
||||
|
||||
response = client.post("/api/ai/analyze-mask", json={
|
||||
"frame_id": frame["id"],
|
||||
"mask_data": {
|
||||
"polygons": [polygon],
|
||||
"label": "AI Mask",
|
||||
"color": "#06b6d4",
|
||||
},
|
||||
"points": [[0.2, 0.2]],
|
||||
})
|
||||
|
||||
assert response.status_code == 200
|
||||
body = response.json()
|
||||
assert body["topology_anchor_count"] == len(polygon)
|
||||
assert len(body["topology_anchors"]) <= 64
|
||||
|
||||
|
||||
def test_smooth_mask_simplifies_noisy_ai_polygon(client):
|
||||
_, frame, _ = _create_project_and_frame(client)
|
||||
polygon = []
|
||||
for index in range(20):
|
||||
polygon.append([0.1 + index * 0.02, 0.1 + (0.01 if index % 2 else 0)])
|
||||
for index in range(20):
|
||||
polygon.append([0.5 + (0.01 if index % 2 else 0), 0.1 + index * 0.02])
|
||||
for index in range(20):
|
||||
polygon.append([0.5 - index * 0.02, 0.5 + (0.01 if index % 2 else 0)])
|
||||
for index in range(20):
|
||||
polygon.append([0.1 + (0.01 if index % 2 else 0), 0.5 - index * 0.02])
|
||||
|
||||
response = client.post("/api/ai/smooth-mask", json={
|
||||
"frame_id": frame["id"],
|
||||
"mask_data": {
|
||||
"polygons": [polygon],
|
||||
"label": "AI Mask",
|
||||
"color": "#06b6d4",
|
||||
},
|
||||
"strength": 80,
|
||||
})
|
||||
|
||||
assert response.status_code == 200
|
||||
body = response.json()
|
||||
assert body["topology_anchor_count"] == len(body["polygons"][0])
|
||||
assert len(body["polygons"][0]) < len(polygon)
|
||||
|
||||
|
||||
def test_smooth_mask_uses_eased_strength_curve(client):
|
||||
_, frame, _ = _create_project_and_frame(client)
|
||||
polygon = []
|
||||
for index in range(20):
|
||||
polygon.append([0.1 + index * 0.02, 0.1 + (0.01 if index % 2 else 0)])
|
||||
for index in range(20):
|
||||
polygon.append([0.5 + (0.01 if index % 2 else 0), 0.1 + index * 0.02])
|
||||
for index in range(20):
|
||||
polygon.append([0.5 - index * 0.02, 0.5 + (0.01 if index % 2 else 0)])
|
||||
for index in range(20):
|
||||
polygon.append([0.1 + (0.01 if index % 2 else 0), 0.5 - index * 0.02])
|
||||
|
||||
def smoothed_count(strength: int) -> int:
|
||||
response = client.post("/api/ai/smooth-mask", json={
|
||||
"frame_id": frame["id"],
|
||||
"mask_data": {
|
||||
"polygons": [polygon],
|
||||
"label": "AI Mask",
|
||||
"color": "#06b6d4",
|
||||
},
|
||||
"strength": strength,
|
||||
})
|
||||
assert response.status_code == 200
|
||||
return len(response.json()["polygons"][0])
|
||||
|
||||
low_count = smoothed_count(20)
|
||||
mid_count = smoothed_count(70)
|
||||
high_count = smoothed_count(95)
|
||||
|
||||
assert low_count <= len(polygon)
|
||||
assert mid_count < low_count
|
||||
assert high_count < mid_count
|
||||
|
||||
|
||||
def test_smooth_mask_returns_backend_smoothed_geometry(client):
|
||||
_, frame, _ = _create_project_and_frame(client)
|
||||
|
||||
@@ -311,6 +393,7 @@ def test_propagate_saves_tracked_annotations(client, monkeypatch):
|
||||
"color": "#ff0000",
|
||||
"class_metadata": {"id": "c1", "name": "胆囊", "color": "#ff0000", "zIndex": 20},
|
||||
"template_id": None,
|
||||
"smoothing": {"strength": 45, "method": "chaikin"},
|
||||
},
|
||||
})
|
||||
|
||||
@@ -327,6 +410,9 @@ def test_propagate_saves_tracked_annotations(client, monkeypatch):
|
||||
assert saved["mask_data"]["source"] == "sam2.1_hiera_tiny_propagation"
|
||||
assert saved["mask_data"]["class"]["name"] == "胆囊"
|
||||
assert saved["mask_data"]["score"] == 0.8
|
||||
assert saved["mask_data"]["geometry_smoothing"] == {"strength": 45.0, "method": "chaikin"}
|
||||
assert saved["mask_data"]["polygons"][0] != [[0.15, 0.15], [0.25, 0.15], [0.25, 0.25]]
|
||||
assert len(saved["mask_data"]["polygons"][0]) > 3
|
||||
|
||||
listing = client.get(f"/api/ai/annotations?project_id={project['id']}")
|
||||
assert len(listing.json()) == 1
|
||||
@@ -490,8 +576,10 @@ def test_propagation_task_runner_saves_annotations_and_progress(client, db_sessi
|
||||
listing = client.get(f"/api/ai/annotations?project_id={project['id']}")
|
||||
assert listing.json()[0]["frame_id"] == frames[1]["id"]
|
||||
assert listing.json()[0]["mask_data"]["source"] == "sam2.1_hiera_tiny_propagation"
|
||||
stored_polygon = listing.json()[0]["mask_data"]["polygons"][0]
|
||||
assert listing.json()[0]["mask_data"]["geometry_smoothing"] == {"strength": 40.0, "method": "chaikin"}
|
||||
assert len(listing.json()[0]["mask_data"]["polygons"][0]) > 3
|
||||
assert stored_polygon != [[0.15, 0.15], [0.25, 0.15], [0.25, 0.25]]
|
||||
assert len(stored_polygon) > 3
|
||||
|
||||
|
||||
def test_propagation_task_runner_skips_unchanged_seed_and_replaces_changed_seed(client, db_session, monkeypatch):
|
||||
@@ -1084,3 +1172,156 @@ def test_import_gt_mask_splits_label_values(client):
|
||||
assert [item["mask_data"]["gt_label_value"] for item in body] == [1, 2]
|
||||
assert [item["mask_data"]["label"] for item in body] == ["GT Class 1", "GT Class 2"]
|
||||
assert all(len(item["points"]) == 1 for item in body)
|
||||
|
||||
|
||||
def test_import_gt_mask_preserves_low_value_gtlabel_png(client):
|
||||
project, frame, _ = _create_project_and_frame(client)
|
||||
template = client.post("/api/templates", json={
|
||||
"name": "GTLabel Template",
|
||||
"color": "#06b6d4",
|
||||
"z_index": 0,
|
||||
"classes": [
|
||||
{"id": "tumor", "name": "肿瘤", "color": "#ff0000", "zIndex": 10, "maskId": 1},
|
||||
],
|
||||
"rules": [],
|
||||
}).json()
|
||||
mask = np.zeros((360, 640), dtype=np.uint16)
|
||||
cv2.rectangle(mask, (40, 40), (140, 140), 1, thickness=-1)
|
||||
ok, encoded = cv2.imencode(".png", mask)
|
||||
assert ok
|
||||
|
||||
response = client.post(
|
||||
"/api/ai/import-gt-mask",
|
||||
data={
|
||||
"project_id": str(project["id"]),
|
||||
"frame_id": str(frame["id"]),
|
||||
"template_id": str(template["id"]),
|
||||
"unknown_color_policy": "discard",
|
||||
},
|
||||
files={"file": ("GT_label.png", encoded.tobytes(), "image/png")},
|
||||
)
|
||||
|
||||
assert response.status_code == 201
|
||||
body = response.json()
|
||||
assert len(body) == 1
|
||||
assert body[0]["mask_data"]["gt_label_value"] == 1
|
||||
assert body[0]["mask_data"]["class"]["name"] == "肿瘤"
|
||||
assert body[0]["mask_data"]["class"]["maskId"] == 1
|
||||
|
||||
|
||||
def test_import_gt_mask_rejects_rgb_color_masks(client):
|
||||
project, frame, _ = _create_project_and_frame(client)
|
||||
template = client.post("/api/templates", json={
|
||||
"name": "Color Template",
|
||||
"color": "#06b6d4",
|
||||
"z_index": 0,
|
||||
"classes": [
|
||||
{"id": "known", "name": "已知类别", "color": "#ff0000", "zIndex": 10, "maskId": 1},
|
||||
],
|
||||
"rules": [],
|
||||
}).json()
|
||||
mask = np.zeros((80, 120, 3), dtype=np.uint8)
|
||||
mask[10:40, 10:40] = [0, 0, 255] # BGR red -> #ff0000
|
||||
mask[40:70, 70:110] = [0, 255, 0] # BGR green -> unknown #00ff00
|
||||
ok, encoded = cv2.imencode(".png", mask)
|
||||
assert ok
|
||||
|
||||
response = client.post(
|
||||
"/api/ai/import-gt-mask",
|
||||
data={
|
||||
"project_id": str(project["id"]),
|
||||
"frame_id": str(frame["id"]),
|
||||
"template_id": str(template["id"]),
|
||||
"unknown_color_policy": "discard",
|
||||
},
|
||||
files={"file": ("color-mask.png", encoded.tobytes(), "image/png")},
|
||||
)
|
||||
|
||||
assert response.status_code == 400
|
||||
assert "RGB 三通道完全相同" in response.json()["detail"]
|
||||
|
||||
|
||||
def test_import_gt_mask_reads_uint16_gt_label_and_maps_maskid_class(client):
|
||||
project, frame, _ = _create_project_and_frame(client)
|
||||
template = client.post("/api/templates", json={
|
||||
"name": "Label Template",
|
||||
"color": "#06b6d4",
|
||||
"z_index": 0,
|
||||
"classes": [{"id": "tumor", "name": "肿瘤", "color": "#ff0000", "zIndex": 10, "maskId": 1}],
|
||||
"rules": [],
|
||||
}).json()
|
||||
mask = np.zeros((360, 640), dtype=np.uint16)
|
||||
cv2.rectangle(mask, (20, 20), (120, 120), 1, thickness=-1)
|
||||
ok, encoded = cv2.imencode(".png", mask)
|
||||
assert ok
|
||||
|
||||
response = client.post(
|
||||
"/api/ai/import-gt-mask",
|
||||
data={
|
||||
"project_id": str(project["id"]),
|
||||
"frame_id": str(frame["id"]),
|
||||
"template_id": str(template["id"]),
|
||||
"unknown_color_policy": "discard",
|
||||
},
|
||||
files={"file": ("gt_label.png", encoded.tobytes(), "image/png")},
|
||||
)
|
||||
|
||||
assert response.status_code == 201
|
||||
body = response.json()
|
||||
assert len(body) == 1
|
||||
assert body[0]["mask_data"]["gt_label_value"] == 1
|
||||
assert body[0]["mask_data"]["label"] == "肿瘤"
|
||||
assert body[0]["mask_data"]["class"]["maskId"] == 1
|
||||
assert body[0]["mask_data"]["class"]["color"] == "#ff0000"
|
||||
|
||||
|
||||
def test_import_gt_mask_handles_unknown_maskid_policy_and_resizes_to_frame(client):
|
||||
project, frame, _ = _create_project_and_frame(client)
|
||||
template = client.post("/api/templates", json={
|
||||
"name": "Color Template",
|
||||
"color": "#06b6d4",
|
||||
"z_index": 0,
|
||||
"classes": [{"id": "known", "name": "已定义", "color": "#ff0000", "zIndex": 10, "maskId": 1}],
|
||||
"rules": [],
|
||||
}).json()
|
||||
mask = np.zeros((90, 160, 3), dtype=np.uint8)
|
||||
cv2.rectangle(mask, (5, 5), (40, 40), (1, 1, 1), thickness=-1)
|
||||
cv2.rectangle(mask, (80, 5), (120, 40), (2, 2, 2), thickness=-1)
|
||||
ok, encoded = cv2.imencode(".png", mask)
|
||||
assert ok
|
||||
|
||||
discard_response = client.post(
|
||||
"/api/ai/import-gt-mask",
|
||||
data={
|
||||
"project_id": str(project["id"]),
|
||||
"frame_id": str(frame["id"]),
|
||||
"template_id": str(template["id"]),
|
||||
"unknown_color_policy": "discard",
|
||||
},
|
||||
files={"file": ("colors.png", encoded.tobytes(), "image/png")},
|
||||
)
|
||||
|
||||
assert discard_response.status_code == 201
|
||||
assert [item["mask_data"]["label"] for item in discard_response.json()] == ["已定义"]
|
||||
assert discard_response.json()[0]["mask_data"]["gt_original_size"] == {"width": 160, "height": 90}
|
||||
assert discard_response.json()[0]["mask_data"]["gt_resized_to_frame"] is True
|
||||
assert discard_response.json()[0]["mask_data"]["image_size"] == {"width": 640, "height": 360}
|
||||
|
||||
undefined_response = client.post(
|
||||
"/api/ai/import-gt-mask",
|
||||
data={
|
||||
"project_id": str(project["id"]),
|
||||
"frame_id": str(frame["id"]),
|
||||
"template_id": str(template["id"]),
|
||||
"unknown_color_policy": "undefined",
|
||||
},
|
||||
files={"file": ("colors.png", encoded.tobytes(), "image/png")},
|
||||
)
|
||||
|
||||
assert undefined_response.status_code == 201
|
||||
labels = {item["mask_data"]["label"] for item in undefined_response.json()}
|
||||
assert labels == {"已定义", "未定义类别 2"}
|
||||
unknown = next(item for item in undefined_response.json() if item["mask_data"]["label"].startswith("未定义"))
|
||||
assert unknown["mask_data"]["gt_unknown_class"] is True
|
||||
assert unknown["mask_data"]["gt_label_value"] == 2
|
||||
assert unknown["mask_data"]["gt_resized_to_frame"] is True
|
||||
|
||||
@@ -2,10 +2,11 @@ def test_login_success(client):
|
||||
response = client.post("/api/auth/login", json={"username": "admin", "password": "123456"})
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {
|
||||
"token": "fake-jwt-token-for-admin",
|
||||
"username": "admin",
|
||||
}
|
||||
body = response.json()
|
||||
assert body["token"]
|
||||
assert body["token_type"] == "bearer"
|
||||
assert body["username"] == "admin"
|
||||
assert body["user"]["username"] == "admin"
|
||||
|
||||
|
||||
def test_login_rejects_invalid_credentials(client):
|
||||
@@ -13,3 +14,19 @@ def test_login_rejects_invalid_credentials(client):
|
||||
|
||||
assert response.status_code == 401
|
||||
assert response.json()["detail"] == "Invalid credentials"
|
||||
|
||||
|
||||
def test_me_returns_current_user(client):
|
||||
response = client.get("/api/auth/me")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json()["username"] == "admin"
|
||||
|
||||
|
||||
def test_business_routes_require_auth(app):
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
with TestClient(app) as unauthenticated:
|
||||
response = unauthenticated.get("/api/projects")
|
||||
|
||||
assert response.status_code == 401
|
||||
|
||||
@@ -1,19 +1,31 @@
|
||||
import zipfile
|
||||
import json
|
||||
from io import BytesIO
|
||||
from urllib.parse import unquote
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
|
||||
def _fake_image_bytes(width=100, height=50, color=(255, 255, 255)):
|
||||
image = np.full((height, width, 3), color, dtype=np.uint8)
|
||||
_, encoded = cv2.imencode(".jpg", image)
|
||||
return encoded.tobytes()
|
||||
|
||||
|
||||
def _seed_export_data(client):
|
||||
project = client.post("/api/projects", json={"name": "Export Project"}).json()
|
||||
project = client.post("/api/projects", json={
|
||||
"name": "Export Project",
|
||||
"video_path": "uploads/1/clip.mp4",
|
||||
}).json()
|
||||
frame = client.post(f"/api/projects/{project['id']}/frames", json={
|
||||
"project_id": project["id"],
|
||||
"frame_index": 0,
|
||||
"image_url": "frames/0.jpg",
|
||||
"width": 100,
|
||||
"height": 50,
|
||||
"timestamp_ms": 1250.0,
|
||||
"source_frame_number": 37,
|
||||
}).json()
|
||||
template = client.post("/api/templates", json={
|
||||
"name": "Category",
|
||||
@@ -113,6 +125,328 @@ def test_export_masks_uses_z_index_for_semantic_fusion(client):
|
||||
assert semantic[10, 10] == high_value
|
||||
|
||||
|
||||
def test_export_results_zip_contains_coco_original_images_and_selected_mask_outputs(client, monkeypatch):
|
||||
project, _, _, annotation = _seed_export_data(client)
|
||||
monkeypatch.setattr("routers.export.download_file", lambda object_name: _fake_image_bytes())
|
||||
|
||||
response = client.get(f"/api/export/{project['id']}/results?scope=all&mask_type=both")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.headers["content-type"].startswith("application/zip")
|
||||
with zipfile.ZipFile(BytesIO(response.content)) as archive:
|
||||
names = archive.namelist()
|
||||
frame_stem = "clip_0h00m01s250ms_frame000001"
|
||||
assert "annotations_coco.json" in names
|
||||
assert "maskid_GT像素值_类别映射.json" in names
|
||||
assert f"原始图片/{frame_stem}.jpg" in names
|
||||
assert f"分开Mask分割结果/{frame_stem}_分别导出/{frame_stem}_Category_maskid1.png" in names
|
||||
assert f"GT_label图/{frame_stem}.png" in names
|
||||
assert f"Pro_label彩色分割结果/{frame_stem}.png" in names
|
||||
assert f"Mix_label重叠覆盖彩色分割结果/{frame_stem}.png" in names
|
||||
coco = json.loads(archive.read("annotations_coco.json"))
|
||||
mapping = json.loads(archive.read("maskid_GT像素值_类别映射.json"))
|
||||
label_bytes = np.frombuffer(archive.read(f"GT_label图/{frame_stem}.png"), dtype=np.uint8)
|
||||
gt_label = cv2.imdecode(label_bytes, cv2.IMREAD_UNCHANGED)
|
||||
pro_label = cv2.imdecode(
|
||||
np.frombuffer(archive.read(f"Pro_label彩色分割结果/{frame_stem}.png"), dtype=np.uint8),
|
||||
cv2.IMREAD_COLOR,
|
||||
)
|
||||
mix_label = cv2.imdecode(
|
||||
np.frombuffer(archive.read(f"Mix_label重叠覆盖彩色分割结果/{frame_stem}.png"), dtype=np.uint8),
|
||||
cv2.IMREAD_COLOR,
|
||||
)
|
||||
|
||||
assert coco["images"][0]["frame_index"] == 0
|
||||
assert coco["annotations"][0]["image_id"] == annotation["frame_id"]
|
||||
assert mapping["classes"] == [{
|
||||
"gt_pixel_value": 1,
|
||||
"maskid": 1,
|
||||
"chineseName": "Category",
|
||||
"className": "Category",
|
||||
"categoryName": "Category",
|
||||
"rgb": [6, 182, 212],
|
||||
"color": "#06b6d4",
|
||||
"key": f"template:{annotation['template_id']}",
|
||||
"template_id": annotation["template_id"],
|
||||
}]
|
||||
assert gt_label[0, 0] == 0
|
||||
assert gt_label[20, 50] == 1
|
||||
assert pro_label[20, 50].tolist() == [212, 182, 6]
|
||||
assert pro_label[0, 0].tolist() == [0, 0, 0]
|
||||
assert mix_label[20, 50].tolist() != [255, 255, 255]
|
||||
|
||||
|
||||
def test_export_results_uses_internal_layer_order_for_gt_pro_and_mix_outputs(client, monkeypatch):
|
||||
monkeypatch.setattr("routers.export.download_file", lambda object_name: _fake_image_bytes(20, 20))
|
||||
project = client.post("/api/projects", json={
|
||||
"name": "Layered Export Project",
|
||||
"video_path": "uploads/2/layered.mp4",
|
||||
}).json()
|
||||
frame = client.post(f"/api/projects/{project['id']}/frames", json={
|
||||
"project_id": project["id"],
|
||||
"frame_index": 0,
|
||||
"image_url": "frames/layered.jpg",
|
||||
"width": 20,
|
||||
"height": 20,
|
||||
"timestamp_ms": 0,
|
||||
"source_frame_number": 0,
|
||||
}).json()
|
||||
client.post("/api/ai/annotate", json={
|
||||
"project_id": project["id"],
|
||||
"frame_id": frame["id"],
|
||||
"mask_data": {
|
||||
"polygons": [[[0.1, 0.1], [0.8, 0.1], [0.8, 0.8], [0.1, 0.8]]],
|
||||
"label": "Low",
|
||||
"color": "#00ff00",
|
||||
"class": {"id": "low", "name": "Low", "color": "#00ff00", "zIndex": 10},
|
||||
},
|
||||
})
|
||||
client.post("/api/ai/annotate", json={
|
||||
"project_id": project["id"],
|
||||
"frame_id": frame["id"],
|
||||
"mask_data": {
|
||||
"polygons": [[[0.4, 0.4], [0.9, 0.4], [0.9, 0.9], [0.4, 0.9]]],
|
||||
"label": "High",
|
||||
"color": "#ff0000",
|
||||
"class": {"id": "high", "name": "High", "color": "#ff0000", "zIndex": 20},
|
||||
},
|
||||
})
|
||||
|
||||
response = client.get(
|
||||
f"/api/export/{project['id']}/results?scope=all&outputs=gt_label,pro_label,mix_label&mix_opacity=0.5",
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
with zipfile.ZipFile(BytesIO(response.content)) as archive:
|
||||
mapping = json.loads(archive.read("maskid_GT像素值_类别映射.json"))
|
||||
high_value = next(item["maskid"] for item in mapping["classes"] if item["key"] == "class:high")
|
||||
stem = "layered_0h00m00s000ms_frame000001"
|
||||
gt_label = cv2.imdecode(
|
||||
np.frombuffer(archive.read(f"GT_label图/{stem}.png"), dtype=np.uint8),
|
||||
cv2.IMREAD_UNCHANGED,
|
||||
)
|
||||
pro_label = cv2.imdecode(
|
||||
np.frombuffer(archive.read(f"Pro_label彩色分割结果/{stem}.png"), dtype=np.uint8),
|
||||
cv2.IMREAD_COLOR,
|
||||
)
|
||||
mix_label = cv2.imdecode(
|
||||
np.frombuffer(archive.read(f"Mix_label重叠覆盖彩色分割结果/{stem}.png"), dtype=np.uint8),
|
||||
cv2.IMREAD_COLOR,
|
||||
)
|
||||
|
||||
assert gt_label[10, 10] == high_value
|
||||
assert pro_label[10, 10].tolist() == [0, 0, 255]
|
||||
assert mix_label[10, 10].tolist() == [127, 127, 255]
|
||||
|
||||
|
||||
def test_export_results_supports_range_and_current_scope(client, monkeypatch):
|
||||
monkeypatch.setattr("routers.export.download_file", lambda object_name: _fake_image_bytes(20, 20))
|
||||
project = client.post("/api/projects", json={
|
||||
"name": "Scoped Export Project",
|
||||
"video_path": "uploads/9/scope.mp4",
|
||||
"parse_fps": 2,
|
||||
}).json()
|
||||
template = client.post("/api/templates", json={
|
||||
"name": "Scoped Category",
|
||||
"color": "#06b6d4",
|
||||
"z_index": 0,
|
||||
"classes": [],
|
||||
"rules": [],
|
||||
}).json()
|
||||
frames = []
|
||||
annotations = []
|
||||
for idx in range(3):
|
||||
frame = client.post(f"/api/projects/{project['id']}/frames", json={
|
||||
"project_id": project["id"],
|
||||
"frame_index": idx,
|
||||
"image_url": f"frames/{idx}.jpg",
|
||||
"width": 20,
|
||||
"height": 20,
|
||||
"timestamp_ms": idx * 500.0,
|
||||
"source_frame_number": idx * 10,
|
||||
}).json()
|
||||
frames.append(frame)
|
||||
annotations.append(client.post("/api/ai/annotate", json={
|
||||
"project_id": project["id"],
|
||||
"frame_id": frame["id"],
|
||||
"template_id": template["id"],
|
||||
"mask_data": {"polygons": [[[0.1, 0.1], [0.8, 0.1], [0.8, 0.8], [0.1, 0.8]]]},
|
||||
}).json())
|
||||
|
||||
range_response = client.get(
|
||||
f"/api/export/{project['id']}/results?scope=range&start_frame=2&end_frame=3&mask_type=gt_label",
|
||||
)
|
||||
current_response = client.get(
|
||||
f"/api/export/{project['id']}/results?scope=current&frame_id={frames[1]['id']}&mask_type=separate",
|
||||
)
|
||||
|
||||
assert range_response.status_code == 200
|
||||
assert "Scoped_Export_Project_seg_T_0h00m00s500ms-0h00m01s000ms_P_2-3.zip" in unquote(
|
||||
range_response.headers["content-disposition"],
|
||||
)
|
||||
with zipfile.ZipFile(BytesIO(range_response.content)) as archive:
|
||||
names = archive.namelist()
|
||||
coco = json.loads(archive.read("annotations_coco.json"))
|
||||
assert "原始图片/scope_0h00m00s500ms_frame000002.jpg" in names
|
||||
assert "原始图片/scope_0h00m01s000ms_frame000003.jpg" in names
|
||||
assert "原始图片/scope_0h00m00s000ms_frame000001.jpg" not in names
|
||||
assert "GT_label图/scope_0h00m00s500ms_frame000002.png" in names
|
||||
assert "GT_label图/scope_0h00m01s000ms_frame000003.png" in names
|
||||
assert "GT_label图/scope_0h00m00s000ms_frame000001.png" not in names
|
||||
assert not any(name.startswith("分开Mask分割结果/") for name in names)
|
||||
assert not any(name.startswith("Pro_label彩色分割结果/") for name in names)
|
||||
assert not any(name.startswith("Mix_label重叠覆盖彩色分割结果/") for name in names)
|
||||
assert [image["frame_index"] for image in coco["images"]] == [1, 2]
|
||||
|
||||
assert current_response.status_code == 200
|
||||
with zipfile.ZipFile(BytesIO(current_response.content)) as archive:
|
||||
names = archive.namelist()
|
||||
coco = json.loads(archive.read("annotations_coco.json"))
|
||||
current_stem = "scope_0h00m00s500ms_frame000002"
|
||||
assert f"原始图片/{current_stem}.jpg" in names
|
||||
assert f"分开Mask分割结果/{current_stem}_分别导出/{current_stem}_Scoped_Category_maskid1.png" in names
|
||||
assert f"分开Mask分割结果/scope_0h00m00s000ms_frame000001_分别导出/scope_0h00m00s000ms_frame000001_Scoped_Category_maskid1.png" not in names
|
||||
assert not any(name.startswith("GT_label图/") for name in names)
|
||||
assert not any(name.startswith("Pro_label彩色分割结果/") for name in names)
|
||||
assert not any(name.startswith("Mix_label重叠覆盖彩色分割结果/") for name in names)
|
||||
assert [image["id"] for image in coco["images"]] == [frames[1]["id"]]
|
||||
|
||||
|
||||
def test_export_results_preserves_template_maskid_consistently_across_frames(client, monkeypatch):
|
||||
monkeypatch.setattr("routers.export.download_file", lambda object_name: _fake_image_bytes(20, 20))
|
||||
project = client.post("/api/projects", json={
|
||||
"name": "MaskId Export Project",
|
||||
"video_path": "uploads/8/maskid-demo.mp4",
|
||||
"parse_fps": 1,
|
||||
}).json()
|
||||
frames = []
|
||||
for idx in range(2):
|
||||
frames.append(client.post(f"/api/projects/{project['id']}/frames", json={
|
||||
"project_id": project["id"],
|
||||
"frame_index": idx,
|
||||
"image_url": f"frames/{idx}.jpg",
|
||||
"width": 20,
|
||||
"height": 20,
|
||||
"timestamp_ms": idx * 1000.0,
|
||||
"source_frame_number": idx,
|
||||
}).json())
|
||||
client.post("/api/ai/annotate", json={
|
||||
"project_id": project["id"],
|
||||
"frame_id": frames[-1]["id"],
|
||||
"mask_data": {
|
||||
"polygons": [[[0.1, 0.1], [0.8, 0.1], [0.8, 0.8], [0.1, 0.8]]],
|
||||
"label": "Tumor",
|
||||
"color": "#ff0000",
|
||||
"class": {"id": "tumor", "name": "Tumor", "color": "#ff0000", "maskId": 7, "zIndex": 30},
|
||||
},
|
||||
})
|
||||
|
||||
response = client.get(f"/api/export/{project['id']}/results?scope=all&mask_type=both")
|
||||
|
||||
assert response.status_code == 200
|
||||
with zipfile.ZipFile(BytesIO(response.content)) as archive:
|
||||
names = archive.namelist()
|
||||
mapping = json.loads(archive.read("maskid_GT像素值_类别映射.json"))
|
||||
first_stem = "maskid-demo_0h00m00s000ms_frame000001"
|
||||
second_stem = "maskid-demo_0h00m01s000ms_frame000002"
|
||||
assert f"分开Mask分割结果/{first_stem}_分别导出/{first_stem}_Tumor_maskid7.png" in names
|
||||
assert f"分开Mask分割结果/{second_stem}_分别导出/{second_stem}_Tumor_maskid7.png" in names
|
||||
first_label = cv2.imdecode(np.frombuffer(archive.read(f"GT_label图/{first_stem}.png"), dtype=np.uint8), cv2.IMREAD_UNCHANGED)
|
||||
second_label = cv2.imdecode(np.frombuffer(archive.read(f"GT_label图/{second_stem}.png"), dtype=np.uint8), cv2.IMREAD_UNCHANGED)
|
||||
|
||||
assert mapping["classes"] == [{
|
||||
"gt_pixel_value": 7,
|
||||
"maskid": 7,
|
||||
"chineseName": "Tumor",
|
||||
"className": "Tumor",
|
||||
"categoryName": "",
|
||||
"rgb": [255, 0, 0],
|
||||
"color": "#ff0000",
|
||||
"key": "class:tumor",
|
||||
"template_id": None,
|
||||
}]
|
||||
assert first_label[5, 5] == 7
|
||||
assert second_label[5, 5] == 7
|
||||
|
||||
|
||||
def test_exported_gtlabel_round_trips_through_gt_mask_import_with_template_maskid(client, monkeypatch):
|
||||
monkeypatch.setattr("routers.export.download_file", lambda object_name: _fake_image_bytes(20, 20))
|
||||
project = client.post("/api/projects", json={
|
||||
"name": "GT Roundtrip Project",
|
||||
"video_path": "uploads/8/roundtrip.mp4",
|
||||
}).json()
|
||||
template = client.post("/api/templates", json={
|
||||
"name": "Roundtrip Template",
|
||||
"color": "#06b6d4",
|
||||
"z_index": 0,
|
||||
"classes": [
|
||||
{"id": "tumor", "name": "Tumor", "color": "#ff0000", "zIndex": 30, "maskId": 7},
|
||||
],
|
||||
"rules": [],
|
||||
}).json()
|
||||
source_frame = client.post(f"/api/projects/{project['id']}/frames", json={
|
||||
"project_id": project["id"],
|
||||
"frame_index": 0,
|
||||
"image_url": "frames/source.jpg",
|
||||
"width": 20,
|
||||
"height": 20,
|
||||
"timestamp_ms": 0,
|
||||
}).json()
|
||||
target_frame = client.post(f"/api/projects/{project['id']}/frames", json={
|
||||
"project_id": project["id"],
|
||||
"frame_index": 1,
|
||||
"image_url": "frames/target.jpg",
|
||||
"width": 20,
|
||||
"height": 20,
|
||||
"timestamp_ms": 1000,
|
||||
}).json()
|
||||
client.post("/api/ai/annotate", json={
|
||||
"project_id": project["id"],
|
||||
"frame_id": source_frame["id"],
|
||||
"template_id": template["id"],
|
||||
"mask_data": {
|
||||
"polygons": [[[0.1, 0.1], [0.8, 0.1], [0.8, 0.8], [0.1, 0.8]]],
|
||||
"label": "Tumor",
|
||||
"color": "#ff0000",
|
||||
"class": {"id": "tumor", "name": "Tumor", "color": "#ff0000", "maskId": 7, "zIndex": 30},
|
||||
},
|
||||
})
|
||||
|
||||
export_response = client.get(
|
||||
f"/api/export/{project['id']}/results?scope=current&frame_id={source_frame['id']}&outputs=gt_label",
|
||||
)
|
||||
|
||||
assert export_response.status_code == 200
|
||||
with zipfile.ZipFile(BytesIO(export_response.content)) as archive:
|
||||
stem = "roundtrip_0h00m00s000ms_frame000001"
|
||||
exported_gt_label = archive.read(f"GT_label图/{stem}.png")
|
||||
gt_label = cv2.imdecode(np.frombuffer(exported_gt_label, dtype=np.uint8), cv2.IMREAD_UNCHANGED)
|
||||
mapping = json.loads(archive.read("maskid_GT像素值_类别映射.json"))
|
||||
|
||||
assert gt_label[5, 5] == 7
|
||||
assert mapping["classes"][0]["maskid"] == 7
|
||||
|
||||
import_response = client.post(
|
||||
"/api/ai/import-gt-mask",
|
||||
data={
|
||||
"project_id": str(project["id"]),
|
||||
"frame_id": str(target_frame["id"]),
|
||||
"template_id": str(template["id"]),
|
||||
"unknown_color_policy": "discard",
|
||||
},
|
||||
files={"file": ("exported_gt_label.png", exported_gt_label, "image/png")},
|
||||
)
|
||||
|
||||
assert import_response.status_code == 201
|
||||
imported = import_response.json()
|
||||
assert len(imported) == 1
|
||||
assert imported[0]["frame_id"] == target_frame["id"]
|
||||
assert imported[0]["mask_data"]["gt_label_value"] == 7
|
||||
assert imported[0]["mask_data"]["label"] == "Tumor"
|
||||
assert imported[0]["mask_data"]["class"]["maskId"] == 7
|
||||
|
||||
|
||||
def test_export_missing_project_returns_404(client):
|
||||
assert client.get("/api/export/999/coco").status_code == 404
|
||||
assert client.get("/api/export/999/masks").status_code == 404
|
||||
assert client.get("/api/export/999/results").status_code == 404
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from models import Annotation, Frame, Mask, ProcessingTask, Project
|
||||
from models import Annotation, Frame, Mask, ProcessingTask, Project, User
|
||||
from routers.auth import create_access_token, hash_password
|
||||
|
||||
|
||||
def test_project_crud_and_frames(client, monkeypatch):
|
||||
@@ -93,3 +94,33 @@ def test_project_and_frame_404s(client):
|
||||
}).status_code == 404
|
||||
assert client.get("/api/projects/999/frames").status_code == 404
|
||||
assert client.get("/api/projects/999/frames/1").status_code == 404
|
||||
|
||||
|
||||
def test_projects_are_scoped_to_authenticated_owner(client, db_session):
|
||||
owner_project = client.post("/api/projects", json={"name": "Owner Project"}).json()
|
||||
other_user = User(
|
||||
username="other",
|
||||
password_hash=hash_password("pass"),
|
||||
role="annotator",
|
||||
is_active=1,
|
||||
)
|
||||
db_session.add(other_user)
|
||||
db_session.commit()
|
||||
db_session.refresh(other_user)
|
||||
other_project = Project(name="Other Project", owner_user_id=other_user.id)
|
||||
db_session.add(other_project)
|
||||
db_session.commit()
|
||||
db_session.refresh(other_project)
|
||||
|
||||
listing = client.get("/api/projects")
|
||||
assert [project["id"] for project in listing.json()] == [owner_project["id"]]
|
||||
assert client.get(f"/api/projects/{other_project.id}").status_code == 404
|
||||
|
||||
original_auth = client.headers["Authorization"]
|
||||
client.headers.update({"Authorization": f"Bearer {create_access_token(other_user)}"})
|
||||
try:
|
||||
other_listing = client.get("/api/projects")
|
||||
assert [project["id"] for project in other_listing.json()] == [other_project.id]
|
||||
assert client.get(f"/api/projects/{owner_project['id']}").status_code == 404
|
||||
finally:
|
||||
client.headers.update({"Authorization": original_auth})
|
||||
|
||||
Reference in New Issue
Block a user