feat: 完善分割工作区导入导出与管理流程

- 新增基于 JWT 当前用户的登录恢复、角色权限、用户管理、审计日志和演示出厂重置后台接口与前端管理页。

- 重串 GT_label 导出和 GT Mask 导入逻辑:导出保留类别真实 maskid,导入仅接受灰度或 RGB 等通道 maskid 图,支持未知 maskid 策略、尺寸最近邻拉伸和导入预览。

- 统一分割结果导出体验:默认当前帧,按项目抽帧顺序和 XhXXmXXsXXXms 时间戳命名 ZIP 与图片,补齐 GT/Pro/Mix/分开 Mask 输出和映射 JSON。

- 调整工作区左侧工具栏:移除创建点/线段入口,新增画笔、橡皮擦及尺寸控制,并按绘制、布尔、导入/AI 工具分组分隔。

- 扩展 Canvas 编辑能力:画笔按语义分类绘制并可自动并入连通选中 mask,橡皮擦对选中区域扣除,优化布尔操作、选区、撤销重做和保存状态联动。

- 优化自动传播时间轴显示:同一蓝色系按传播新旧递进变暗,老传播记录达到阈值后统一旧记录色,并维护范围选择与清空后的历史显示。

- 将 AI 智能分割入口替换为更明确的 AI 元素图标,并同步侧栏、工作区和 AI 页面入口表现。

- 完善模板分类、maskid 工具函数、分类树联动、遮罩透明度、边缘平滑和传播链同步相关前端状态。

- 扩展后端项目、媒体、任务、Dashboard、模板和传播 runner 的用户隔离、任务控制、进度事件与兼容处理。

- 补充前后端测试,覆盖用户管理、GT_label 往返导入导出、GT Mask 校验和预览、画笔/橡皮擦、时间轴传播历史、导出范围、WebSocket 与 API 封装。

- 更新 AGENTS、README 和 doc 文档,记录当前接口契约、实现状态、测试计划、安装说明和 maskid/GT_label 规则。
This commit is contained in:
2026-05-03 03:52:32 +08:00
parent 4c1d3dba73
commit afcddfaeb9
62 changed files with 6572 additions and 849 deletions

View File

@@ -19,7 +19,7 @@ if str(BACKEND_DIR) not in sys.path:
from database import Base, get_db # noqa: E402
from main import websocket_progress # noqa: E402
from routers import ai, auth, dashboard, export, media, projects, tasks, templates # noqa: E402
from routers import admin, ai, auth, dashboard, export, media, projects, tasks, templates # noqa: E402
@pytest.fixture()
@@ -32,6 +32,7 @@ def db_session() -> Iterator[Session]:
TestingSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
Base.metadata.create_all(bind=engine)
session = TestingSessionLocal()
auth.ensure_default_admin(session)
try:
yield session
finally:
@@ -56,6 +57,7 @@ def app(db_session: Session) -> FastAPI:
test_app.include_router(export.router)
test_app.include_router(dashboard.router)
test_app.include_router(tasks.router)
test_app.include_router(admin.router)
@test_app.get("/health")
def health_check() -> dict[str, str]:
@@ -67,6 +69,10 @@ def app(db_session: Session) -> FastAPI:
@pytest.fixture()
def client(app: FastAPI) -> Iterator[TestClient]:
def client(app: FastAPI, db_session: Session) -> Iterator[TestClient]:
with TestClient(app) as test_client:
admin = auth.ensure_default_admin(db_session)
test_client.headers.update({
"Authorization": f"Bearer {auth.create_access_token(admin)}"
})
yield test_client

158
backend/tests/test_admin.py Normal file
View File

@@ -0,0 +1,158 @@
from models import Annotation, AuditLog, Frame, Mask, ProcessingTask, Project, Template, User
from routers.auth import create_access_token, hash_password
from statuses import PROJECT_STATUS_PENDING
def test_admin_user_management_and_audit_logs(client, db_session):
created = client.post("/api/admin/users", json={
"username": "doctor",
"password": "secret123",
"role": "annotator",
"is_active": True,
})
assert created.status_code == 201
user_id = created.json()["id"]
updated = client.patch(f"/api/admin/users/{user_id}", json={
"role": "viewer",
"password": "newsecret",
"is_active": False,
})
assert updated.status_code == 200
assert updated.json()["role"] == "viewer"
assert updated.json()["is_active"] == 0
users = client.get("/api/admin/users")
assert users.status_code == 200
assert any(user["username"] == "doctor" for user in users.json())
deleted = client.delete(f"/api/admin/users/{user_id}")
assert deleted.status_code == 204
logs = client.get("/api/admin/audit-logs")
assert logs.status_code == 200
actions = [log["action"] for log in logs.json()]
assert "admin.user_created" in actions
assert "admin.user_updated" in actions
assert "admin.user_deleted" in actions
def test_admin_routes_require_admin_role(client, db_session):
user = User(username="viewer", password_hash=hash_password("secret123"), role="viewer", is_active=1)
db_session.add(user)
db_session.commit()
db_session.refresh(user)
original_auth = client.headers["Authorization"]
client.headers.update({"Authorization": f"Bearer {create_access_token(user)}"})
try:
response = client.get("/api/admin/users")
assert response.status_code == 403
finally:
client.headers.update({"Authorization": original_auth})
def test_viewer_role_is_read_only_for_business_mutations(client, db_session):
project = client.post("/api/projects", json={"name": "Readonly Check"}).json()
user = User(username="readonly", password_hash=hash_password("secret123"), role="viewer", is_active=1)
db_session.add(user)
db_session.commit()
db_session.refresh(user)
original_auth = client.headers["Authorization"]
client.headers.update({"Authorization": f"Bearer {create_access_token(user)}"})
try:
assert client.get("/api/projects").status_code == 200
assert client.post("/api/projects", json={"name": "Nope"}).status_code == 403
assert client.patch(f"/api/projects/{project['id']}", json={"name": "Nope"}).status_code == 403
assert client.post("/api/ai/annotate", json={"project_id": project["id"]}).status_code == 403
finally:
client.headers.update({"Authorization": original_auth})
def test_admin_cannot_delete_self_or_user_with_projects(client, db_session):
me = client.get("/api/auth/me").json()
assert client.delete(f"/api/admin/users/{me['id']}").status_code == 400
user = User(username="owner", password_hash=hash_password("secret123"), role="annotator", is_active=1)
db_session.add(user)
db_session.commit()
db_session.refresh(user)
db_session.add(Project(name="Owned", owner_user_id=user.id))
db_session.commit()
response = client.delete(f"/api/admin/users/{user.id}")
assert response.status_code == 409
def test_demo_factory_reset_leaves_admin_and_unparsed_demo_video(client, db_session, monkeypatch, tmp_path):
video_path = tmp_path / "Data_MyVideo_1.mp4"
video_path.write_bytes(b"demo-video")
monkeypatch.setattr("routers.admin.settings.demo_video_path", str(video_path))
uploaded = []
monkeypatch.setattr("routers.admin.upload_file", lambda object_name, data, content_type, length: uploaded.append({
"object_name": object_name,
"data": data,
"content_type": content_type,
"length": length,
}))
extra_user = User(username="doctor", password_hash=hash_password("secret123"), role="annotator", is_active=1)
db_session.add(extra_user)
db_session.commit()
db_session.refresh(extra_user)
old_project = Project(name="Old", owner_user_id=extra_user.id, video_path="uploads/old.mp4")
db_session.add(old_project)
db_session.commit()
db_session.refresh(old_project)
frame = Frame(project_id=old_project.id, frame_index=0, image_url="frames/old.jpg")
db_session.add(frame)
task = ProcessingTask(task_type="parse_video", project_id=old_project.id)
private_template = Template(
name="Private",
description="private",
color="#fff",
z_index=1,
owner_user_id=extra_user.id,
)
db_session.add_all([task, private_template])
db_session.commit()
db_session.refresh(frame)
annotation = Annotation(project_id=old_project.id, frame_id=frame.id, mask_data={"label": "old"})
db_session.add(annotation)
db_session.commit()
db_session.refresh(annotation)
db_session.add(Mask(annotation_id=annotation.id, mask_url="masks/old.png"))
db_session.add(AuditLog(actor_user_id=extra_user.id, action="old.audit"))
db_session.commit()
response = client.post("/api/admin/demo-factory-reset", json={"confirmation": "RESET_DEMO_FACTORY"})
assert response.status_code == 200
data = response.json()
assert data["message"] == "演示环境已恢复出厂设置"
assert data["admin_user"]["username"] == "admin"
assert data["project"]["name"] == "Data_MyVideo_1"
assert data["project"]["status"] == PROJECT_STATUS_PENDING
assert data["project"]["frame_count"] == 0
assert data["project"]["video_path"] == f"uploads/{data['project']['id']}/Data_MyVideo_1.mp4"
assert uploaded == [{
"object_name": data["project"]["video_path"],
"data": b"demo-video",
"content_type": "video/mp4",
"length": len(b"demo-video"),
}]
assert [user.username for user in db_session.query(User).all()] == ["admin"]
assert db_session.query(Project).count() == 1
assert db_session.query(Frame).count() == 0
assert db_session.query(Annotation).count() == 0
assert db_session.query(Mask).count() == 0
assert db_session.query(ProcessingTask).count() == 0
assert db_session.query(Template).filter(Template.owner_user_id.is_not(None)).count() == 0
assert db_session.query(AuditLog).count() == 1
assert db_session.query(AuditLog).first().action == "admin.demo_factory_reset"
def test_demo_factory_reset_requires_exact_confirmation(client):
response = client.post("/api/admin/demo-factory-reset", json={"confirmation": "reset"})
assert response.status_code == 400

View File

@@ -223,6 +223,88 @@ def test_analyze_mask_returns_backend_geometry_properties(client):
assert body["message"] == "已从后端重新提取几何拓扑锚点"
def test_analyze_mask_reports_actual_polygon_anchor_count(client):
_, frame, _ = _create_project_and_frame(client)
polygon = [[0.1 + index * 0.005, 0.1 + (0.01 if index % 2 else 0)] for index in range(80)]
response = client.post("/api/ai/analyze-mask", json={
"frame_id": frame["id"],
"mask_data": {
"polygons": [polygon],
"label": "AI Mask",
"color": "#06b6d4",
},
"points": [[0.2, 0.2]],
})
assert response.status_code == 200
body = response.json()
assert body["topology_anchor_count"] == len(polygon)
assert len(body["topology_anchors"]) <= 64
def test_smooth_mask_simplifies_noisy_ai_polygon(client):
_, frame, _ = _create_project_and_frame(client)
polygon = []
for index in range(20):
polygon.append([0.1 + index * 0.02, 0.1 + (0.01 if index % 2 else 0)])
for index in range(20):
polygon.append([0.5 + (0.01 if index % 2 else 0), 0.1 + index * 0.02])
for index in range(20):
polygon.append([0.5 - index * 0.02, 0.5 + (0.01 if index % 2 else 0)])
for index in range(20):
polygon.append([0.1 + (0.01 if index % 2 else 0), 0.5 - index * 0.02])
response = client.post("/api/ai/smooth-mask", json={
"frame_id": frame["id"],
"mask_data": {
"polygons": [polygon],
"label": "AI Mask",
"color": "#06b6d4",
},
"strength": 80,
})
assert response.status_code == 200
body = response.json()
assert body["topology_anchor_count"] == len(body["polygons"][0])
assert len(body["polygons"][0]) < len(polygon)
def test_smooth_mask_uses_eased_strength_curve(client):
_, frame, _ = _create_project_and_frame(client)
polygon = []
for index in range(20):
polygon.append([0.1 + index * 0.02, 0.1 + (0.01 if index % 2 else 0)])
for index in range(20):
polygon.append([0.5 + (0.01 if index % 2 else 0), 0.1 + index * 0.02])
for index in range(20):
polygon.append([0.5 - index * 0.02, 0.5 + (0.01 if index % 2 else 0)])
for index in range(20):
polygon.append([0.1 + (0.01 if index % 2 else 0), 0.5 - index * 0.02])
def smoothed_count(strength: int) -> int:
response = client.post("/api/ai/smooth-mask", json={
"frame_id": frame["id"],
"mask_data": {
"polygons": [polygon],
"label": "AI Mask",
"color": "#06b6d4",
},
"strength": strength,
})
assert response.status_code == 200
return len(response.json()["polygons"][0])
low_count = smoothed_count(20)
mid_count = smoothed_count(70)
high_count = smoothed_count(95)
assert low_count <= len(polygon)
assert mid_count < low_count
assert high_count < mid_count
def test_smooth_mask_returns_backend_smoothed_geometry(client):
_, frame, _ = _create_project_and_frame(client)
@@ -311,6 +393,7 @@ def test_propagate_saves_tracked_annotations(client, monkeypatch):
"color": "#ff0000",
"class_metadata": {"id": "c1", "name": "胆囊", "color": "#ff0000", "zIndex": 20},
"template_id": None,
"smoothing": {"strength": 45, "method": "chaikin"},
},
})
@@ -327,6 +410,9 @@ def test_propagate_saves_tracked_annotations(client, monkeypatch):
assert saved["mask_data"]["source"] == "sam2.1_hiera_tiny_propagation"
assert saved["mask_data"]["class"]["name"] == "胆囊"
assert saved["mask_data"]["score"] == 0.8
assert saved["mask_data"]["geometry_smoothing"] == {"strength": 45.0, "method": "chaikin"}
assert saved["mask_data"]["polygons"][0] != [[0.15, 0.15], [0.25, 0.15], [0.25, 0.25]]
assert len(saved["mask_data"]["polygons"][0]) > 3
listing = client.get(f"/api/ai/annotations?project_id={project['id']}")
assert len(listing.json()) == 1
@@ -490,8 +576,10 @@ def test_propagation_task_runner_saves_annotations_and_progress(client, db_sessi
listing = client.get(f"/api/ai/annotations?project_id={project['id']}")
assert listing.json()[0]["frame_id"] == frames[1]["id"]
assert listing.json()[0]["mask_data"]["source"] == "sam2.1_hiera_tiny_propagation"
stored_polygon = listing.json()[0]["mask_data"]["polygons"][0]
assert listing.json()[0]["mask_data"]["geometry_smoothing"] == {"strength": 40.0, "method": "chaikin"}
assert len(listing.json()[0]["mask_data"]["polygons"][0]) > 3
assert stored_polygon != [[0.15, 0.15], [0.25, 0.15], [0.25, 0.25]]
assert len(stored_polygon) > 3
def test_propagation_task_runner_skips_unchanged_seed_and_replaces_changed_seed(client, db_session, monkeypatch):
@@ -1084,3 +1172,156 @@ def test_import_gt_mask_splits_label_values(client):
assert [item["mask_data"]["gt_label_value"] for item in body] == [1, 2]
assert [item["mask_data"]["label"] for item in body] == ["GT Class 1", "GT Class 2"]
assert all(len(item["points"]) == 1 for item in body)
def test_import_gt_mask_preserves_low_value_gtlabel_png(client):
project, frame, _ = _create_project_and_frame(client)
template = client.post("/api/templates", json={
"name": "GTLabel Template",
"color": "#06b6d4",
"z_index": 0,
"classes": [
{"id": "tumor", "name": "肿瘤", "color": "#ff0000", "zIndex": 10, "maskId": 1},
],
"rules": [],
}).json()
mask = np.zeros((360, 640), dtype=np.uint16)
cv2.rectangle(mask, (40, 40), (140, 140), 1, thickness=-1)
ok, encoded = cv2.imencode(".png", mask)
assert ok
response = client.post(
"/api/ai/import-gt-mask",
data={
"project_id": str(project["id"]),
"frame_id": str(frame["id"]),
"template_id": str(template["id"]),
"unknown_color_policy": "discard",
},
files={"file": ("GT_label.png", encoded.tobytes(), "image/png")},
)
assert response.status_code == 201
body = response.json()
assert len(body) == 1
assert body[0]["mask_data"]["gt_label_value"] == 1
assert body[0]["mask_data"]["class"]["name"] == "肿瘤"
assert body[0]["mask_data"]["class"]["maskId"] == 1
def test_import_gt_mask_rejects_rgb_color_masks(client):
project, frame, _ = _create_project_and_frame(client)
template = client.post("/api/templates", json={
"name": "Color Template",
"color": "#06b6d4",
"z_index": 0,
"classes": [
{"id": "known", "name": "已知类别", "color": "#ff0000", "zIndex": 10, "maskId": 1},
],
"rules": [],
}).json()
mask = np.zeros((80, 120, 3), dtype=np.uint8)
mask[10:40, 10:40] = [0, 0, 255] # BGR red -> #ff0000
mask[40:70, 70:110] = [0, 255, 0] # BGR green -> unknown #00ff00
ok, encoded = cv2.imencode(".png", mask)
assert ok
response = client.post(
"/api/ai/import-gt-mask",
data={
"project_id": str(project["id"]),
"frame_id": str(frame["id"]),
"template_id": str(template["id"]),
"unknown_color_policy": "discard",
},
files={"file": ("color-mask.png", encoded.tobytes(), "image/png")},
)
assert response.status_code == 400
assert "RGB 三通道完全相同" in response.json()["detail"]
def test_import_gt_mask_reads_uint16_gt_label_and_maps_maskid_class(client):
project, frame, _ = _create_project_and_frame(client)
template = client.post("/api/templates", json={
"name": "Label Template",
"color": "#06b6d4",
"z_index": 0,
"classes": [{"id": "tumor", "name": "肿瘤", "color": "#ff0000", "zIndex": 10, "maskId": 1}],
"rules": [],
}).json()
mask = np.zeros((360, 640), dtype=np.uint16)
cv2.rectangle(mask, (20, 20), (120, 120), 1, thickness=-1)
ok, encoded = cv2.imencode(".png", mask)
assert ok
response = client.post(
"/api/ai/import-gt-mask",
data={
"project_id": str(project["id"]),
"frame_id": str(frame["id"]),
"template_id": str(template["id"]),
"unknown_color_policy": "discard",
},
files={"file": ("gt_label.png", encoded.tobytes(), "image/png")},
)
assert response.status_code == 201
body = response.json()
assert len(body) == 1
assert body[0]["mask_data"]["gt_label_value"] == 1
assert body[0]["mask_data"]["label"] == "肿瘤"
assert body[0]["mask_data"]["class"]["maskId"] == 1
assert body[0]["mask_data"]["class"]["color"] == "#ff0000"
def test_import_gt_mask_handles_unknown_maskid_policy_and_resizes_to_frame(client):
project, frame, _ = _create_project_and_frame(client)
template = client.post("/api/templates", json={
"name": "Color Template",
"color": "#06b6d4",
"z_index": 0,
"classes": [{"id": "known", "name": "已定义", "color": "#ff0000", "zIndex": 10, "maskId": 1}],
"rules": [],
}).json()
mask = np.zeros((90, 160, 3), dtype=np.uint8)
cv2.rectangle(mask, (5, 5), (40, 40), (1, 1, 1), thickness=-1)
cv2.rectangle(mask, (80, 5), (120, 40), (2, 2, 2), thickness=-1)
ok, encoded = cv2.imencode(".png", mask)
assert ok
discard_response = client.post(
"/api/ai/import-gt-mask",
data={
"project_id": str(project["id"]),
"frame_id": str(frame["id"]),
"template_id": str(template["id"]),
"unknown_color_policy": "discard",
},
files={"file": ("colors.png", encoded.tobytes(), "image/png")},
)
assert discard_response.status_code == 201
assert [item["mask_data"]["label"] for item in discard_response.json()] == ["已定义"]
assert discard_response.json()[0]["mask_data"]["gt_original_size"] == {"width": 160, "height": 90}
assert discard_response.json()[0]["mask_data"]["gt_resized_to_frame"] is True
assert discard_response.json()[0]["mask_data"]["image_size"] == {"width": 640, "height": 360}
undefined_response = client.post(
"/api/ai/import-gt-mask",
data={
"project_id": str(project["id"]),
"frame_id": str(frame["id"]),
"template_id": str(template["id"]),
"unknown_color_policy": "undefined",
},
files={"file": ("colors.png", encoded.tobytes(), "image/png")},
)
assert undefined_response.status_code == 201
labels = {item["mask_data"]["label"] for item in undefined_response.json()}
assert labels == {"已定义", "未定义类别 2"}
unknown = next(item for item in undefined_response.json() if item["mask_data"]["label"].startswith("未定义"))
assert unknown["mask_data"]["gt_unknown_class"] is True
assert unknown["mask_data"]["gt_label_value"] == 2
assert unknown["mask_data"]["gt_resized_to_frame"] is True

View File

@@ -2,10 +2,11 @@ def test_login_success(client):
response = client.post("/api/auth/login", json={"username": "admin", "password": "123456"})
assert response.status_code == 200
assert response.json() == {
"token": "fake-jwt-token-for-admin",
"username": "admin",
}
body = response.json()
assert body["token"]
assert body["token_type"] == "bearer"
assert body["username"] == "admin"
assert body["user"]["username"] == "admin"
def test_login_rejects_invalid_credentials(client):
@@ -13,3 +14,19 @@ def test_login_rejects_invalid_credentials(client):
assert response.status_code == 401
assert response.json()["detail"] == "Invalid credentials"
def test_me_returns_current_user(client):
response = client.get("/api/auth/me")
assert response.status_code == 200
assert response.json()["username"] == "admin"
def test_business_routes_require_auth(app):
from fastapi.testclient import TestClient
with TestClient(app) as unauthenticated:
response = unauthenticated.get("/api/projects")
assert response.status_code == 401

View File

@@ -1,19 +1,31 @@
import zipfile
import json
from io import BytesIO
from urllib.parse import unquote
import cv2
import numpy as np
def _fake_image_bytes(width=100, height=50, color=(255, 255, 255)):
image = np.full((height, width, 3), color, dtype=np.uint8)
_, encoded = cv2.imencode(".jpg", image)
return encoded.tobytes()
def _seed_export_data(client):
project = client.post("/api/projects", json={"name": "Export Project"}).json()
project = client.post("/api/projects", json={
"name": "Export Project",
"video_path": "uploads/1/clip.mp4",
}).json()
frame = client.post(f"/api/projects/{project['id']}/frames", json={
"project_id": project["id"],
"frame_index": 0,
"image_url": "frames/0.jpg",
"width": 100,
"height": 50,
"timestamp_ms": 1250.0,
"source_frame_number": 37,
}).json()
template = client.post("/api/templates", json={
"name": "Category",
@@ -113,6 +125,328 @@ def test_export_masks_uses_z_index_for_semantic_fusion(client):
assert semantic[10, 10] == high_value
def test_export_results_zip_contains_coco_original_images_and_selected_mask_outputs(client, monkeypatch):
project, _, _, annotation = _seed_export_data(client)
monkeypatch.setattr("routers.export.download_file", lambda object_name: _fake_image_bytes())
response = client.get(f"/api/export/{project['id']}/results?scope=all&mask_type=both")
assert response.status_code == 200
assert response.headers["content-type"].startswith("application/zip")
with zipfile.ZipFile(BytesIO(response.content)) as archive:
names = archive.namelist()
frame_stem = "clip_0h00m01s250ms_frame000001"
assert "annotations_coco.json" in names
assert "maskid_GT像素值_类别映射.json" in names
assert f"原始图片/{frame_stem}.jpg" in names
assert f"分开Mask分割结果/{frame_stem}_分别导出/{frame_stem}_Category_maskid1.png" in names
assert f"GT_label图/{frame_stem}.png" in names
assert f"Pro_label彩色分割结果/{frame_stem}.png" in names
assert f"Mix_label重叠覆盖彩色分割结果/{frame_stem}.png" in names
coco = json.loads(archive.read("annotations_coco.json"))
mapping = json.loads(archive.read("maskid_GT像素值_类别映射.json"))
label_bytes = np.frombuffer(archive.read(f"GT_label图/{frame_stem}.png"), dtype=np.uint8)
gt_label = cv2.imdecode(label_bytes, cv2.IMREAD_UNCHANGED)
pro_label = cv2.imdecode(
np.frombuffer(archive.read(f"Pro_label彩色分割结果/{frame_stem}.png"), dtype=np.uint8),
cv2.IMREAD_COLOR,
)
mix_label = cv2.imdecode(
np.frombuffer(archive.read(f"Mix_label重叠覆盖彩色分割结果/{frame_stem}.png"), dtype=np.uint8),
cv2.IMREAD_COLOR,
)
assert coco["images"][0]["frame_index"] == 0
assert coco["annotations"][0]["image_id"] == annotation["frame_id"]
assert mapping["classes"] == [{
"gt_pixel_value": 1,
"maskid": 1,
"chineseName": "Category",
"className": "Category",
"categoryName": "Category",
"rgb": [6, 182, 212],
"color": "#06b6d4",
"key": f"template:{annotation['template_id']}",
"template_id": annotation["template_id"],
}]
assert gt_label[0, 0] == 0
assert gt_label[20, 50] == 1
assert pro_label[20, 50].tolist() == [212, 182, 6]
assert pro_label[0, 0].tolist() == [0, 0, 0]
assert mix_label[20, 50].tolist() != [255, 255, 255]
def test_export_results_uses_internal_layer_order_for_gt_pro_and_mix_outputs(client, monkeypatch):
monkeypatch.setattr("routers.export.download_file", lambda object_name: _fake_image_bytes(20, 20))
project = client.post("/api/projects", json={
"name": "Layered Export Project",
"video_path": "uploads/2/layered.mp4",
}).json()
frame = client.post(f"/api/projects/{project['id']}/frames", json={
"project_id": project["id"],
"frame_index": 0,
"image_url": "frames/layered.jpg",
"width": 20,
"height": 20,
"timestamp_ms": 0,
"source_frame_number": 0,
}).json()
client.post("/api/ai/annotate", json={
"project_id": project["id"],
"frame_id": frame["id"],
"mask_data": {
"polygons": [[[0.1, 0.1], [0.8, 0.1], [0.8, 0.8], [0.1, 0.8]]],
"label": "Low",
"color": "#00ff00",
"class": {"id": "low", "name": "Low", "color": "#00ff00", "zIndex": 10},
},
})
client.post("/api/ai/annotate", json={
"project_id": project["id"],
"frame_id": frame["id"],
"mask_data": {
"polygons": [[[0.4, 0.4], [0.9, 0.4], [0.9, 0.9], [0.4, 0.9]]],
"label": "High",
"color": "#ff0000",
"class": {"id": "high", "name": "High", "color": "#ff0000", "zIndex": 20},
},
})
response = client.get(
f"/api/export/{project['id']}/results?scope=all&outputs=gt_label,pro_label,mix_label&mix_opacity=0.5",
)
assert response.status_code == 200
with zipfile.ZipFile(BytesIO(response.content)) as archive:
mapping = json.loads(archive.read("maskid_GT像素值_类别映射.json"))
high_value = next(item["maskid"] for item in mapping["classes"] if item["key"] == "class:high")
stem = "layered_0h00m00s000ms_frame000001"
gt_label = cv2.imdecode(
np.frombuffer(archive.read(f"GT_label图/{stem}.png"), dtype=np.uint8),
cv2.IMREAD_UNCHANGED,
)
pro_label = cv2.imdecode(
np.frombuffer(archive.read(f"Pro_label彩色分割结果/{stem}.png"), dtype=np.uint8),
cv2.IMREAD_COLOR,
)
mix_label = cv2.imdecode(
np.frombuffer(archive.read(f"Mix_label重叠覆盖彩色分割结果/{stem}.png"), dtype=np.uint8),
cv2.IMREAD_COLOR,
)
assert gt_label[10, 10] == high_value
assert pro_label[10, 10].tolist() == [0, 0, 255]
assert mix_label[10, 10].tolist() == [127, 127, 255]
def test_export_results_supports_range_and_current_scope(client, monkeypatch):
monkeypatch.setattr("routers.export.download_file", lambda object_name: _fake_image_bytes(20, 20))
project = client.post("/api/projects", json={
"name": "Scoped Export Project",
"video_path": "uploads/9/scope.mp4",
"parse_fps": 2,
}).json()
template = client.post("/api/templates", json={
"name": "Scoped Category",
"color": "#06b6d4",
"z_index": 0,
"classes": [],
"rules": [],
}).json()
frames = []
annotations = []
for idx in range(3):
frame = client.post(f"/api/projects/{project['id']}/frames", json={
"project_id": project["id"],
"frame_index": idx,
"image_url": f"frames/{idx}.jpg",
"width": 20,
"height": 20,
"timestamp_ms": idx * 500.0,
"source_frame_number": idx * 10,
}).json()
frames.append(frame)
annotations.append(client.post("/api/ai/annotate", json={
"project_id": project["id"],
"frame_id": frame["id"],
"template_id": template["id"],
"mask_data": {"polygons": [[[0.1, 0.1], [0.8, 0.1], [0.8, 0.8], [0.1, 0.8]]]},
}).json())
range_response = client.get(
f"/api/export/{project['id']}/results?scope=range&start_frame=2&end_frame=3&mask_type=gt_label",
)
current_response = client.get(
f"/api/export/{project['id']}/results?scope=current&frame_id={frames[1]['id']}&mask_type=separate",
)
assert range_response.status_code == 200
assert "Scoped_Export_Project_seg_T_0h00m00s500ms-0h00m01s000ms_P_2-3.zip" in unquote(
range_response.headers["content-disposition"],
)
with zipfile.ZipFile(BytesIO(range_response.content)) as archive:
names = archive.namelist()
coco = json.loads(archive.read("annotations_coco.json"))
assert "原始图片/scope_0h00m00s500ms_frame000002.jpg" in names
assert "原始图片/scope_0h00m01s000ms_frame000003.jpg" in names
assert "原始图片/scope_0h00m00s000ms_frame000001.jpg" not in names
assert "GT_label图/scope_0h00m00s500ms_frame000002.png" in names
assert "GT_label图/scope_0h00m01s000ms_frame000003.png" in names
assert "GT_label图/scope_0h00m00s000ms_frame000001.png" not in names
assert not any(name.startswith("分开Mask分割结果/") for name in names)
assert not any(name.startswith("Pro_label彩色分割结果/") for name in names)
assert not any(name.startswith("Mix_label重叠覆盖彩色分割结果/") for name in names)
assert [image["frame_index"] for image in coco["images"]] == [1, 2]
assert current_response.status_code == 200
with zipfile.ZipFile(BytesIO(current_response.content)) as archive:
names = archive.namelist()
coco = json.loads(archive.read("annotations_coco.json"))
current_stem = "scope_0h00m00s500ms_frame000002"
assert f"原始图片/{current_stem}.jpg" in names
assert f"分开Mask分割结果/{current_stem}_分别导出/{current_stem}_Scoped_Category_maskid1.png" in names
assert f"分开Mask分割结果/scope_0h00m00s000ms_frame000001_分别导出/scope_0h00m00s000ms_frame000001_Scoped_Category_maskid1.png" not in names
assert not any(name.startswith("GT_label图/") for name in names)
assert not any(name.startswith("Pro_label彩色分割结果/") for name in names)
assert not any(name.startswith("Mix_label重叠覆盖彩色分割结果/") for name in names)
assert [image["id"] for image in coco["images"]] == [frames[1]["id"]]
def test_export_results_preserves_template_maskid_consistently_across_frames(client, monkeypatch):
monkeypatch.setattr("routers.export.download_file", lambda object_name: _fake_image_bytes(20, 20))
project = client.post("/api/projects", json={
"name": "MaskId Export Project",
"video_path": "uploads/8/maskid-demo.mp4",
"parse_fps": 1,
}).json()
frames = []
for idx in range(2):
frames.append(client.post(f"/api/projects/{project['id']}/frames", json={
"project_id": project["id"],
"frame_index": idx,
"image_url": f"frames/{idx}.jpg",
"width": 20,
"height": 20,
"timestamp_ms": idx * 1000.0,
"source_frame_number": idx,
}).json())
client.post("/api/ai/annotate", json={
"project_id": project["id"],
"frame_id": frames[-1]["id"],
"mask_data": {
"polygons": [[[0.1, 0.1], [0.8, 0.1], [0.8, 0.8], [0.1, 0.8]]],
"label": "Tumor",
"color": "#ff0000",
"class": {"id": "tumor", "name": "Tumor", "color": "#ff0000", "maskId": 7, "zIndex": 30},
},
})
response = client.get(f"/api/export/{project['id']}/results?scope=all&mask_type=both")
assert response.status_code == 200
with zipfile.ZipFile(BytesIO(response.content)) as archive:
names = archive.namelist()
mapping = json.loads(archive.read("maskid_GT像素值_类别映射.json"))
first_stem = "maskid-demo_0h00m00s000ms_frame000001"
second_stem = "maskid-demo_0h00m01s000ms_frame000002"
assert f"分开Mask分割结果/{first_stem}_分别导出/{first_stem}_Tumor_maskid7.png" in names
assert f"分开Mask分割结果/{second_stem}_分别导出/{second_stem}_Tumor_maskid7.png" in names
first_label = cv2.imdecode(np.frombuffer(archive.read(f"GT_label图/{first_stem}.png"), dtype=np.uint8), cv2.IMREAD_UNCHANGED)
second_label = cv2.imdecode(np.frombuffer(archive.read(f"GT_label图/{second_stem}.png"), dtype=np.uint8), cv2.IMREAD_UNCHANGED)
assert mapping["classes"] == [{
"gt_pixel_value": 7,
"maskid": 7,
"chineseName": "Tumor",
"className": "Tumor",
"categoryName": "",
"rgb": [255, 0, 0],
"color": "#ff0000",
"key": "class:tumor",
"template_id": None,
}]
assert first_label[5, 5] == 7
assert second_label[5, 5] == 7
def test_exported_gtlabel_round_trips_through_gt_mask_import_with_template_maskid(client, monkeypatch):
monkeypatch.setattr("routers.export.download_file", lambda object_name: _fake_image_bytes(20, 20))
project = client.post("/api/projects", json={
"name": "GT Roundtrip Project",
"video_path": "uploads/8/roundtrip.mp4",
}).json()
template = client.post("/api/templates", json={
"name": "Roundtrip Template",
"color": "#06b6d4",
"z_index": 0,
"classes": [
{"id": "tumor", "name": "Tumor", "color": "#ff0000", "zIndex": 30, "maskId": 7},
],
"rules": [],
}).json()
source_frame = client.post(f"/api/projects/{project['id']}/frames", json={
"project_id": project["id"],
"frame_index": 0,
"image_url": "frames/source.jpg",
"width": 20,
"height": 20,
"timestamp_ms": 0,
}).json()
target_frame = client.post(f"/api/projects/{project['id']}/frames", json={
"project_id": project["id"],
"frame_index": 1,
"image_url": "frames/target.jpg",
"width": 20,
"height": 20,
"timestamp_ms": 1000,
}).json()
client.post("/api/ai/annotate", json={
"project_id": project["id"],
"frame_id": source_frame["id"],
"template_id": template["id"],
"mask_data": {
"polygons": [[[0.1, 0.1], [0.8, 0.1], [0.8, 0.8], [0.1, 0.8]]],
"label": "Tumor",
"color": "#ff0000",
"class": {"id": "tumor", "name": "Tumor", "color": "#ff0000", "maskId": 7, "zIndex": 30},
},
})
export_response = client.get(
f"/api/export/{project['id']}/results?scope=current&frame_id={source_frame['id']}&outputs=gt_label",
)
assert export_response.status_code == 200
with zipfile.ZipFile(BytesIO(export_response.content)) as archive:
stem = "roundtrip_0h00m00s000ms_frame000001"
exported_gt_label = archive.read(f"GT_label图/{stem}.png")
gt_label = cv2.imdecode(np.frombuffer(exported_gt_label, dtype=np.uint8), cv2.IMREAD_UNCHANGED)
mapping = json.loads(archive.read("maskid_GT像素值_类别映射.json"))
assert gt_label[5, 5] == 7
assert mapping["classes"][0]["maskid"] == 7
import_response = client.post(
"/api/ai/import-gt-mask",
data={
"project_id": str(project["id"]),
"frame_id": str(target_frame["id"]),
"template_id": str(template["id"]),
"unknown_color_policy": "discard",
},
files={"file": ("exported_gt_label.png", exported_gt_label, "image/png")},
)
assert import_response.status_code == 201
imported = import_response.json()
assert len(imported) == 1
assert imported[0]["frame_id"] == target_frame["id"]
assert imported[0]["mask_data"]["gt_label_value"] == 7
assert imported[0]["mask_data"]["label"] == "Tumor"
assert imported[0]["mask_data"]["class"]["maskId"] == 7
def test_export_missing_project_returns_404(client):
assert client.get("/api/export/999/coco").status_code == 404
assert client.get("/api/export/999/masks").status_code == 404
assert client.get("/api/export/999/results").status_code == 404

View File

@@ -1,4 +1,5 @@
from models import Annotation, Frame, Mask, ProcessingTask, Project
from models import Annotation, Frame, Mask, ProcessingTask, Project, User
from routers.auth import create_access_token, hash_password
def test_project_crud_and_frames(client, monkeypatch):
@@ -93,3 +94,33 @@ def test_project_and_frame_404s(client):
}).status_code == 404
assert client.get("/api/projects/999/frames").status_code == 404
assert client.get("/api/projects/999/frames/1").status_code == 404
def test_projects_are_scoped_to_authenticated_owner(client, db_session):
owner_project = client.post("/api/projects", json={"name": "Owner Project"}).json()
other_user = User(
username="other",
password_hash=hash_password("pass"),
role="annotator",
is_active=1,
)
db_session.add(other_user)
db_session.commit()
db_session.refresh(other_user)
other_project = Project(name="Other Project", owner_user_id=other_user.id)
db_session.add(other_project)
db_session.commit()
db_session.refresh(other_project)
listing = client.get("/api/projects")
assert [project["id"] for project in listing.json()] == [owner_project["id"]]
assert client.get(f"/api/projects/{other_project.id}").status_code == 404
original_auth = client.headers["Authorization"]
client.headers.update({"Authorization": f"Bearer {create_access_token(other_user)}"})
try:
other_listing = client.get("/api/projects")
assert [project["id"] for project in other_listing.json()] == [other_project.id]
assert client.get(f"/api/projects/{owner_project['id']}").status_code == 404
finally:
client.headers.update({"Authorization": original_auth})