feat: 打通全栈标注闭环、异步拆帧与模型状态
后端能力: - 新增 Celery app、worker task、ProcessingTask 模型、/api/tasks 查询接口和 media_task_runner,将 /api/media/parse 改为创建后台任务并由 worker 执行 FFmpeg/OpenCV/pydicom 拆帧。 - 新增 Redis 进度事件模块和 FastAPI Redis pub/sub 订阅,将 worker 任务进度广播到 /ws/progress;Dashboard 后端概览接口改为聚合 projects/frames/annotations/templates/processing_tasks。 - 统一项目状态为 pending/parsing/ready/error,新增共享 status 常量,并让前端兼容归一化旧状态值。 - 扩展 AI 后端:新增 SAM registry、SAM2 真实运行状态、SAM3 状态检测与文本语义推理适配入口,以及 /api/ai/models/status GPU/模型状态接口。 - 补齐标注保存/更新/删除、COCO/PNG mask 导出相关后端契约和模板 mapping_rules 打包/解包行为。 前端能力: - 新增运行时 API/WS 地址推导配置,前端 API 封装对齐 FastAPI 路由、字段映射、任务轮询、标注归档、导出下载和 AI 预测响应转换。 - Dashboard 改为读取 /api/dashboard/overview,并订阅 WebSocket progress/complete/error/status 更新解析队列和实时流转记录。 - 项目库导入视频/DICOM 后创建项目、上传媒体、触发异步解析并刷新真实项目列表。 - 工作区加载真实帧、无帧时触发解析任务、回显已保存标注、保存未归档 mask、更新 dirty mask、清空当前帧后端标注、导出 COCO JSON。 - Canvas 支持当前帧点/框提示调用后端 AI、渲染推理/已保存 mask、应用模板分类并维护保存状态计数;时间轴按项目 fps 播放。 - AI 页面新增 SAM2/SAM3 模型选择,预测请求携带 model;侧边栏和工作区新增真实 GPU/SAM 状态徽标。 - 模板库和本体面板接入真实模板 CRUD、分类编辑、拖拽排序、JSON 导入、默认腹腔镜分类和本地自定义分类选择。 测试与文档: - 新增 Vitest 配置、前端测试 setup、API/config/websocket/store/组件测试,覆盖登录、项目库、Dashboard、Canvas、工作区、模型状态、时间轴、本体和模板库。 - 新增 pytest 后端测试夹具和 auth/projects/templates/media/AI/export/dashboard/tasks/progress 测试,使用 SQLite、fake MinIO、fake SAM registry 和 Redis monkeypatch 隔离外部服务。 - 新增 doc/ 文档结构,冻结当前需求、设计、接口契约、测试计划、前端逐元素审计、实现地图和后续实施计划,并同步更新 README 与 AGENTS。 验证: - conda run -n seg_server pytest backend/tests:27 passed。 - npm run test:run:54 passed。 - npm run lint、npm run build、compileall、git diff --check 均通过;Vite 仅提示大 chunk 警告。
This commit is contained in:
72
backend/tests/conftest.py
Normal file
72
backend/tests/conftest.py
Normal file
@@ -0,0 +1,72 @@
|
||||
"""Shared pytest fixtures for backend API tests."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Iterator
|
||||
|
||||
import pytest
|
||||
from fastapi import FastAPI
|
||||
from fastapi.testclient import TestClient
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
from sqlalchemy.pool import StaticPool
|
||||
|
||||
BACKEND_DIR = Path(__file__).resolve().parents[1]
|
||||
if str(BACKEND_DIR) not in sys.path:
|
||||
sys.path.insert(0, str(BACKEND_DIR))
|
||||
|
||||
from database import Base, get_db # noqa: E402
|
||||
from main import websocket_progress # noqa: E402
|
||||
from routers import ai, auth, dashboard, export, media, projects, tasks, templates # noqa: E402
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def db_session() -> Iterator[Session]:
|
||||
engine = create_engine(
|
||||
"sqlite://",
|
||||
connect_args={"check_same_thread": False},
|
||||
poolclass=StaticPool,
|
||||
)
|
||||
TestingSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||
Base.metadata.create_all(bind=engine)
|
||||
session = TestingSessionLocal()
|
||||
try:
|
||||
yield session
|
||||
finally:
|
||||
session.close()
|
||||
Base.metadata.drop_all(bind=engine)
|
||||
engine.dispose()
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def app(db_session: Session) -> FastAPI:
|
||||
test_app = FastAPI()
|
||||
|
||||
def override_get_db() -> Iterator[Session]:
|
||||
yield db_session
|
||||
|
||||
test_app.dependency_overrides[get_db] = override_get_db
|
||||
test_app.include_router(auth.router)
|
||||
test_app.include_router(projects.router)
|
||||
test_app.include_router(templates.router)
|
||||
test_app.include_router(media.router)
|
||||
test_app.include_router(ai.router)
|
||||
test_app.include_router(export.router)
|
||||
test_app.include_router(dashboard.router)
|
||||
test_app.include_router(tasks.router)
|
||||
|
||||
@test_app.get("/health")
|
||||
def health_check() -> dict[str, str]:
|
||||
return {"status": "ok", "service": "SegServer"}
|
||||
|
||||
test_app.add_api_websocket_route("/ws/progress", websocket_progress)
|
||||
|
||||
return test_app
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def client(app: FastAPI) -> Iterator[TestClient]:
|
||||
with TestClient(app) as test_client:
|
||||
yield test_client
|
||||
248
backend/tests/test_ai.py
Normal file
248
backend/tests/test_ai.py
Normal file
@@ -0,0 +1,248 @@
|
||||
import numpy as np
|
||||
|
||||
|
||||
def _create_project_and_frame(client):
|
||||
project = client.post("/api/projects", json={"name": "AI Project"}).json()
|
||||
frame = client.post(f"/api/projects/{project['id']}/frames", json={
|
||||
"project_id": project["id"],
|
||||
"frame_index": 0,
|
||||
"image_url": "frames/0.jpg",
|
||||
"width": 640,
|
||||
"height": 360,
|
||||
}).json()
|
||||
template = client.post("/api/templates", json={
|
||||
"name": "Template",
|
||||
"color": "#06b6d4",
|
||||
"z_index": 0,
|
||||
"classes": [],
|
||||
"rules": [],
|
||||
}).json()
|
||||
return project, frame, template
|
||||
|
||||
|
||||
def test_predict_accepts_point_object_with_labels(client, monkeypatch):
|
||||
_, frame, _ = _create_project_and_frame(client)
|
||||
calls = {}
|
||||
|
||||
monkeypatch.setattr("routers.ai._load_frame_image", lambda frame: np.zeros((10, 10, 3), dtype=np.uint8))
|
||||
|
||||
def fake_predict_points(image, points, labels):
|
||||
calls["args"] = (points, labels)
|
||||
return (
|
||||
[[[0.1, 0.1], [0.9, 0.1], [0.9, 0.9]]],
|
||||
[0.95],
|
||||
)
|
||||
|
||||
monkeypatch.setattr("routers.ai.sam_registry.predict_points", lambda model, image, points, labels: fake_predict_points(image, points, labels))
|
||||
|
||||
response = client.post("/api/ai/predict", json={
|
||||
"image_id": frame["id"],
|
||||
"prompt_type": "point",
|
||||
"prompt_data": {"points": [[0.5, 0.5], [0.1, 0.1]], "labels": [1, 0]},
|
||||
})
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json()["scores"] == [0.95]
|
||||
assert calls["args"] == ([[0.5, 0.5], [0.1, 0.1]], [1, 0])
|
||||
|
||||
|
||||
def test_predict_box_and_semantic_fallback(client, monkeypatch):
|
||||
_, frame, _ = _create_project_and_frame(client)
|
||||
monkeypatch.setattr("routers.ai._load_frame_image", lambda frame: np.zeros((10, 10, 3), dtype=np.uint8))
|
||||
monkeypatch.setattr("routers.ai.sam_registry.predict_box", lambda model, image, box: (
|
||||
[[[0.2, 0.2], [0.8, 0.2], [0.8, 0.8]]],
|
||||
[0.8],
|
||||
))
|
||||
monkeypatch.setattr("routers.ai.sam_registry.predict_semantic", lambda model, image, text: (
|
||||
[[[0.0, 0.0], [1.0, 0.0], [1.0, 1.0]]],
|
||||
[0.5],
|
||||
))
|
||||
|
||||
box_response = client.post("/api/ai/predict", json={
|
||||
"image_id": frame["id"],
|
||||
"prompt_type": "box",
|
||||
"prompt_data": [0.2, 0.2, 0.8, 0.8],
|
||||
})
|
||||
semantic_response = client.post("/api/ai/predict", json={
|
||||
"image_id": frame["id"],
|
||||
"prompt_type": "semantic",
|
||||
"prompt_data": "胆囊",
|
||||
})
|
||||
|
||||
assert box_response.status_code == 200
|
||||
assert box_response.json()["scores"] == [0.8]
|
||||
assert semantic_response.status_code == 200
|
||||
assert semantic_response.json()["scores"] == [0.5]
|
||||
|
||||
|
||||
def test_model_status_reports_runtime(client, monkeypatch):
|
||||
monkeypatch.setattr("routers.ai.sam_registry.runtime_status", lambda selected_model=None: {
|
||||
"selected_model": selected_model or "sam2",
|
||||
"gpu": {
|
||||
"available": False,
|
||||
"device": "cpu",
|
||||
"name": None,
|
||||
"torch_available": True,
|
||||
"torch_version": "2.x",
|
||||
"cuda_version": None,
|
||||
},
|
||||
"models": [
|
||||
{
|
||||
"id": "sam2",
|
||||
"label": "SAM 2",
|
||||
"available": True,
|
||||
"loaded": False,
|
||||
"device": "cpu",
|
||||
"supports": ["point", "box", "auto"],
|
||||
"message": "ready",
|
||||
"package_available": True,
|
||||
"checkpoint_exists": True,
|
||||
"checkpoint_path": "model.pt",
|
||||
"python_ok": True,
|
||||
"torch_ok": True,
|
||||
"cuda_required": False,
|
||||
},
|
||||
{
|
||||
"id": "sam3",
|
||||
"label": "SAM 3",
|
||||
"available": False,
|
||||
"loaded": False,
|
||||
"device": "unavailable",
|
||||
"supports": ["semantic"],
|
||||
"message": "missing Python 3.12+ runtime",
|
||||
"package_available": False,
|
||||
"checkpoint_exists": False,
|
||||
"checkpoint_path": None,
|
||||
"python_ok": False,
|
||||
"torch_ok": True,
|
||||
"cuda_required": True,
|
||||
},
|
||||
],
|
||||
})
|
||||
|
||||
response = client.get("/api/ai/models/status?selected_model=sam3")
|
||||
|
||||
assert response.status_code == 200
|
||||
body = response.json()
|
||||
assert body["selected_model"] == "sam3"
|
||||
assert body["models"][1]["id"] == "sam3"
|
||||
assert body["models"][1]["available"] is False
|
||||
|
||||
|
||||
def test_predict_validation_errors(client, monkeypatch):
|
||||
project, _, _ = _create_project_and_frame(client)
|
||||
|
||||
assert client.post("/api/ai/predict", json={
|
||||
"image_id": 999,
|
||||
"prompt_type": "point",
|
||||
"prompt_data": [[0.5, 0.5]],
|
||||
}).status_code == 404
|
||||
|
||||
frame = client.post(f"/api/projects/{project['id']}/frames", json={
|
||||
"project_id": project["id"],
|
||||
"frame_index": 1,
|
||||
"image_url": "frames/1.jpg",
|
||||
}).json()
|
||||
monkeypatch.setattr("routers.ai._load_frame_image", lambda frame: np.zeros((10, 10, 3), dtype=np.uint8))
|
||||
assert client.post("/api/ai/predict", json={
|
||||
"image_id": frame["id"],
|
||||
"prompt_type": "box",
|
||||
"prompt_data": [0.1, 0.2],
|
||||
}).status_code == 400
|
||||
|
||||
|
||||
def test_save_annotation_validates_project_and_frame(client):
|
||||
project, frame, template = _create_project_and_frame(client)
|
||||
|
||||
saved = client.post("/api/ai/annotate", json={
|
||||
"project_id": project["id"],
|
||||
"frame_id": frame["id"],
|
||||
"template_id": template["id"],
|
||||
"mask_data": {"polygons": [[[0.1, 0.1], [0.9, 0.1], [0.9, 0.9]]]},
|
||||
"points": [[0.5, 0.5]],
|
||||
"bbox": [0.1, 0.1, 0.8, 0.8],
|
||||
})
|
||||
assert saved.status_code == 201
|
||||
assert saved.json()["project_id"] == project["id"]
|
||||
|
||||
listing = client.get(f"/api/ai/annotations?project_id={project['id']}")
|
||||
assert listing.status_code == 200
|
||||
assert listing.json()[0]["id"] == saved.json()["id"]
|
||||
|
||||
frame_listing = client.get(f"/api/ai/annotations?project_id={project['id']}&frame_id={frame['id']}")
|
||||
assert frame_listing.status_code == 200
|
||||
assert len(frame_listing.json()) == 1
|
||||
|
||||
missing_project = client.post("/api/ai/annotate", json={"project_id": 999})
|
||||
assert missing_project.status_code == 404
|
||||
|
||||
missing_frame = client.post("/api/ai/annotate", json={
|
||||
"project_id": project["id"],
|
||||
"frame_id": 999,
|
||||
})
|
||||
assert missing_frame.status_code == 404
|
||||
|
||||
missing_project_list = client.get("/api/ai/annotations?project_id=999")
|
||||
assert missing_project_list.status_code == 404
|
||||
|
||||
|
||||
def test_update_and_delete_annotation(client):
|
||||
project, frame, template = _create_project_and_frame(client)
|
||||
saved = client.post("/api/ai/annotate", json={
|
||||
"project_id": project["id"],
|
||||
"frame_id": frame["id"],
|
||||
"template_id": template["id"],
|
||||
"mask_data": {
|
||||
"polygons": [[[0.1, 0.1], [0.9, 0.1], [0.9, 0.9]]],
|
||||
"label": "AI Mask",
|
||||
"color": "#06b6d4",
|
||||
},
|
||||
"points": [[0.5, 0.5]],
|
||||
"bbox": [0.1, 0.1, 0.8, 0.8],
|
||||
}).json()
|
||||
|
||||
updated = client.patch(f"/api/ai/annotations/{saved['id']}", json={
|
||||
"template_id": template["id"],
|
||||
"mask_data": {
|
||||
"polygons": [[[0.2, 0.2], [0.8, 0.2], [0.8, 0.8]]],
|
||||
"label": "胆囊",
|
||||
"color": "#ff0000",
|
||||
"class": {"id": "c1", "name": "胆囊", "color": "#ff0000", "zIndex": 20},
|
||||
},
|
||||
"points": [[0.4, 0.4]],
|
||||
"bbox": [0.2, 0.2, 0.6, 0.6],
|
||||
})
|
||||
|
||||
assert updated.status_code == 200
|
||||
body = updated.json()
|
||||
assert body["mask_data"]["label"] == "胆囊"
|
||||
assert body["mask_data"]["class"]["id"] == "c1"
|
||||
assert body["points"] == [[0.4, 0.4]]
|
||||
assert body["bbox"] == [0.2, 0.2, 0.6, 0.6]
|
||||
|
||||
listing = client.get(f"/api/ai/annotations?project_id={project['id']}")
|
||||
assert listing.status_code == 200
|
||||
assert listing.json()[0]["mask_data"]["class"]["name"] == "胆囊"
|
||||
|
||||
deleted = client.delete(f"/api/ai/annotations/{saved['id']}")
|
||||
assert deleted.status_code == 204
|
||||
|
||||
empty_listing = client.get(f"/api/ai/annotations?project_id={project['id']}")
|
||||
assert empty_listing.status_code == 200
|
||||
assert empty_listing.json() == []
|
||||
|
||||
|
||||
def test_update_and_delete_annotation_validation(client):
|
||||
project, frame, template = _create_project_and_frame(client)
|
||||
saved = client.post("/api/ai/annotate", json={
|
||||
"project_id": project["id"],
|
||||
"frame_id": frame["id"],
|
||||
"template_id": template["id"],
|
||||
}).json()
|
||||
|
||||
assert client.patch("/api/ai/annotations/999", json={"bbox": [0, 0, 1, 1]}).status_code == 404
|
||||
assert client.delete("/api/ai/annotations/999").status_code == 404
|
||||
assert client.patch(
|
||||
f"/api/ai/annotations/{saved['id']}",
|
||||
json={"template_id": 999},
|
||||
).status_code == 404
|
||||
15
backend/tests/test_auth.py
Normal file
15
backend/tests/test_auth.py
Normal file
@@ -0,0 +1,15 @@
|
||||
def test_login_success(client):
|
||||
response = client.post("/api/auth/login", json={"username": "admin", "password": "123456"})
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {
|
||||
"token": "fake-jwt-token-for-admin",
|
||||
"username": "admin",
|
||||
}
|
||||
|
||||
|
||||
def test_login_rejects_invalid_credentials(client):
|
||||
response = client.post("/api/auth/login", json={"username": "admin", "password": "wrong"})
|
||||
|
||||
assert response.status_code == 401
|
||||
assert response.json()["detail"] == "Invalid credentials"
|
||||
69
backend/tests/test_dashboard.py
Normal file
69
backend/tests/test_dashboard.py
Normal file
@@ -0,0 +1,69 @@
|
||||
def test_dashboard_overview_uses_persisted_records(client, db_session):
|
||||
from models import ProcessingTask
|
||||
|
||||
project_pending = client.post("/api/projects", json={
|
||||
"name": "Pending Project",
|
||||
"status": "pending",
|
||||
}).json()
|
||||
project_ready = client.post("/api/projects", json={
|
||||
"name": "Ready Project",
|
||||
"status": "ready",
|
||||
}).json()
|
||||
frame = client.post(f"/api/projects/{project_pending['id']}/frames", json={
|
||||
"project_id": project_pending["id"],
|
||||
"frame_index": 0,
|
||||
"image_url": "frames/0.jpg",
|
||||
"width": 640,
|
||||
"height": 360,
|
||||
}).json()
|
||||
template = client.post("/api/templates", json={
|
||||
"name": "Dashboard Template",
|
||||
"color": "#06b6d4",
|
||||
"z_index": 0,
|
||||
"classes": [],
|
||||
"rules": [],
|
||||
}).json()
|
||||
annotation = client.post("/api/ai/annotate", json={
|
||||
"project_id": project_pending["id"],
|
||||
"frame_id": frame["id"],
|
||||
"template_id": template["id"],
|
||||
"mask_data": {"polygons": [[[0.1, 0.1], [0.9, 0.1], [0.9, 0.9]]]},
|
||||
})
|
||||
assert annotation.status_code == 201
|
||||
task = ProcessingTask(
|
||||
task_type="parse_video",
|
||||
status="running",
|
||||
progress=35,
|
||||
message="正在使用 FFmpeg/OpenCV 拆帧",
|
||||
project_id=project_pending["id"],
|
||||
payload={"source_type": "video"},
|
||||
)
|
||||
db_session.add(task)
|
||||
db_session.commit()
|
||||
db_session.refresh(task)
|
||||
|
||||
response = client.get("/api/dashboard/overview")
|
||||
|
||||
assert response.status_code == 200
|
||||
body = response.json()
|
||||
assert body["summary"]["project_count"] == 2
|
||||
assert body["summary"]["frame_count"] == 1
|
||||
assert body["summary"]["annotation_count"] == 1
|
||||
assert body["summary"]["template_count"] == 1
|
||||
assert body["summary"]["parsing_task_count"] == 1
|
||||
assert body["tasks"] == [
|
||||
{
|
||||
"id": f"task-{task.id}",
|
||||
"task_id": task.id,
|
||||
"project_id": project_pending["id"],
|
||||
"name": "Pending Project",
|
||||
"progress": 35,
|
||||
"status": "正在使用 FFmpeg/OpenCV 拆帧",
|
||||
"frame_count": 0,
|
||||
"updated_at": body["tasks"][0]["updated_at"],
|
||||
},
|
||||
]
|
||||
assert any(item["kind"] == "task" for item in body["activity"])
|
||||
assert any(item["kind"] == "annotation" for item in body["activity"])
|
||||
assert any(item["kind"] == "template" for item in body["activity"])
|
||||
assert all(item["name"] != "Ready Project" for item in body["tasks"])
|
||||
66
backend/tests/test_export.py
Normal file
66
backend/tests/test_export.py
Normal file
@@ -0,0 +1,66 @@
|
||||
import zipfile
|
||||
from io import BytesIO
|
||||
|
||||
|
||||
def _seed_export_data(client):
|
||||
project = client.post("/api/projects", json={"name": "Export Project"}).json()
|
||||
frame = client.post(f"/api/projects/{project['id']}/frames", json={
|
||||
"project_id": project["id"],
|
||||
"frame_index": 0,
|
||||
"image_url": "frames/0.jpg",
|
||||
"width": 100,
|
||||
"height": 50,
|
||||
}).json()
|
||||
template = client.post("/api/templates", json={
|
||||
"name": "Category",
|
||||
"color": "#06b6d4",
|
||||
"z_index": 0,
|
||||
"classes": [],
|
||||
"rules": [],
|
||||
}).json()
|
||||
annotation = client.post("/api/ai/annotate", json={
|
||||
"project_id": project["id"],
|
||||
"frame_id": frame["id"],
|
||||
"template_id": template["id"],
|
||||
"mask_data": {"polygons": [[[0.1, 0.2], [0.9, 0.2], [0.9, 0.8], [0.1, 0.8]]]},
|
||||
"points": [[0.5, 0.5]],
|
||||
"bbox": [0.1, 0.2, 0.8, 0.6],
|
||||
}).json()
|
||||
return project, frame, template, annotation
|
||||
|
||||
|
||||
def test_export_coco_json_structure(client):
|
||||
project, frame, _, _ = _seed_export_data(client)
|
||||
|
||||
response = client.get(f"/api/export/{project['id']}/coco")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.headers["content-type"].startswith("application/json")
|
||||
data = response.json()
|
||||
assert data["info"]["description"] == "Annotations for Export Project"
|
||||
assert data["images"][0] == {
|
||||
"id": frame["id"],
|
||||
"file_name": "frames/0.jpg",
|
||||
"width": 100,
|
||||
"height": 50,
|
||||
"frame_index": 0,
|
||||
}
|
||||
assert data["annotations"][0]["segmentation"] == [[10.0, 10.0, 90.0, 10.0, 90.0, 40.0, 10.0, 40.0]]
|
||||
assert data["annotations"][0]["bbox"] == [10.0, 10.0, 80.0, 30.000000000000004]
|
||||
assert data["categories"][0]["name"] == "Category"
|
||||
|
||||
|
||||
def test_export_masks_zip(client):
|
||||
project, _, _, annotation = _seed_export_data(client)
|
||||
|
||||
response = client.get(f"/api/export/{project['id']}/masks")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.headers["content-type"].startswith("application/zip")
|
||||
with zipfile.ZipFile(BytesIO(response.content)) as archive:
|
||||
assert archive.namelist() == [f"mask_{annotation['id']:06d}.png"]
|
||||
|
||||
|
||||
def test_export_missing_project_returns_404(client):
|
||||
assert client.get("/api/export/999/coco").status_code == 404
|
||||
assert client.get("/api/export/999/masks").status_code == 404
|
||||
15
backend/tests/test_main.py
Normal file
15
backend/tests/test_main.py
Normal file
@@ -0,0 +1,15 @@
|
||||
def test_health_endpoint(client):
|
||||
response = client.get("/health")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"status": "ok", "service": "SegServer"}
|
||||
|
||||
|
||||
def test_websocket_progress_heartbeat(client):
|
||||
with client.websocket_connect("/ws/progress") as websocket:
|
||||
websocket.send_text("ping")
|
||||
data = websocket.receive_json()
|
||||
|
||||
assert data["type"] == "status"
|
||||
assert data["status"] == "connected"
|
||||
assert data["message"] == "Progress stream active"
|
||||
142
backend/tests/test_media.py
Normal file
142
backend/tests/test_media.py
Normal file
@@ -0,0 +1,142 @@
|
||||
def test_upload_rejects_unsupported_file_type(client):
|
||||
response = client.post(
|
||||
"/api/media/upload",
|
||||
files={"file": ("notes.txt", b"text", "text/plain")},
|
||||
)
|
||||
|
||||
assert response.status_code == 400
|
||||
assert "Unsupported file type" in response.json()["detail"]
|
||||
|
||||
|
||||
def test_upload_auto_creates_project(client, monkeypatch):
|
||||
uploaded = []
|
||||
monkeypatch.setattr("routers.media.upload_file", lambda object_name, data, content_type, length: uploaded.append(object_name))
|
||||
monkeypatch.setattr("routers.media.get_presigned_url", lambda object_name, expires=3600: f"http://storage/{object_name}")
|
||||
|
||||
response = client.post(
|
||||
"/api/media/upload",
|
||||
files={"file": ("clip.mp4", b"video", "video/mp4")},
|
||||
)
|
||||
|
||||
assert response.status_code == 201
|
||||
data = response.json()
|
||||
assert data["project_id"] is not None
|
||||
assert data["object_name"] == f"uploads/{data['project_id']}/clip.mp4"
|
||||
assert uploaded == ["uploads/general/clip.mp4", f"uploads/{data['project_id']}/clip.mp4"]
|
||||
|
||||
|
||||
def test_upload_links_existing_project(client, monkeypatch):
|
||||
project = client.post("/api/projects", json={"name": "Existing"}).json()
|
||||
monkeypatch.setattr("routers.media.upload_file", lambda *args, **kwargs: None)
|
||||
monkeypatch.setattr("routers.media.get_presigned_url", lambda object_name, expires=3600: f"http://storage/{object_name}")
|
||||
|
||||
response = client.post(
|
||||
"/api/media/upload",
|
||||
data={"project_id": str(project["id"])},
|
||||
files={"file": ("clip.mp4", b"video", "video/mp4")},
|
||||
)
|
||||
|
||||
assert response.status_code == 201
|
||||
detail = client.get(f"/api/projects/{project['id']}").json()
|
||||
assert detail["video_path"] == f"uploads/{project['id']}/clip.mp4"
|
||||
|
||||
|
||||
def test_upload_dicom_batch_filters_files_and_creates_project(client, monkeypatch):
|
||||
uploaded = []
|
||||
monkeypatch.setattr("routers.media.upload_file", lambda object_name, data, content_type, length: uploaded.append(object_name))
|
||||
|
||||
response = client.post(
|
||||
"/api/media/upload/dicom",
|
||||
files=[
|
||||
("files", ("a.dcm", b"dcm", "application/dicom")),
|
||||
("files", ("skip.txt", b"text", "text/plain")),
|
||||
],
|
||||
)
|
||||
|
||||
assert response.status_code == 201
|
||||
data = response.json()
|
||||
assert data["uploaded_count"] == 1
|
||||
assert uploaded == [f"uploads/{data['project_id']}/dicom/a.dcm"]
|
||||
|
||||
|
||||
def test_parse_media_queues_background_task(client, monkeypatch):
|
||||
project = client.post("/api/projects", json={
|
||||
"name": "Parse Me",
|
||||
"video_path": "uploads/1/clip.mp4",
|
||||
"source_type": "video",
|
||||
"parse_fps": 5,
|
||||
}).json()
|
||||
|
||||
class FakeAsyncResult:
|
||||
id = "celery-1"
|
||||
|
||||
queued = []
|
||||
monkeypatch.setattr("routers.media.parse_project_media.delay", lambda task_id: queued.append(task_id) or FakeAsyncResult())
|
||||
published = []
|
||||
monkeypatch.setattr("routers.media.publish_task_progress_event", lambda task: published.append(task.id))
|
||||
|
||||
response = client.post(f"/api/media/parse?project_id={project['id']}")
|
||||
|
||||
assert response.status_code == 202
|
||||
data = response.json()
|
||||
assert data["task_type"] == "parse_video"
|
||||
assert data["status"] == "queued"
|
||||
assert data["progress"] == 0
|
||||
assert data["project_id"] == project["id"]
|
||||
assert data["celery_task_id"] == "celery-1"
|
||||
assert queued == [data["id"]]
|
||||
assert published == [data["id"]]
|
||||
|
||||
detail = client.get(f"/api/tasks/{data['id']}")
|
||||
assert detail.status_code == 200
|
||||
assert detail.json()["status"] == "queued"
|
||||
project_detail = client.get(f"/api/projects/{project['id']}").json()
|
||||
assert project_detail["status"] == "parsing"
|
||||
|
||||
|
||||
def test_parse_task_runner_registers_frames(client, db_session, monkeypatch, tmp_path):
|
||||
from models import ProcessingTask
|
||||
from services.media_task_runner import run_parse_media_task
|
||||
|
||||
project = client.post("/api/projects", json={
|
||||
"name": "Parse Me",
|
||||
"video_path": "uploads/1/clip.mp4",
|
||||
"source_type": "video",
|
||||
"parse_fps": 5,
|
||||
}).json()
|
||||
task = ProcessingTask(
|
||||
task_type="parse_video",
|
||||
status="queued",
|
||||
progress=0,
|
||||
project_id=project["id"],
|
||||
payload={"source_type": "video"},
|
||||
)
|
||||
db_session.add(task)
|
||||
db_session.commit()
|
||||
db_session.refresh(task)
|
||||
frame_file = tmp_path / "frame_000001.jpg"
|
||||
frame_file.write_bytes(b"fake image")
|
||||
|
||||
monkeypatch.setattr("services.media_task_runner.download_file", lambda object_name: b"video")
|
||||
monkeypatch.setattr("services.media_task_runner.parse_video", lambda local_path, output_dir, fps: ([str(frame_file)], 25.0))
|
||||
monkeypatch.setattr("services.media_task_runner.extract_thumbnail", lambda local_path, thumbnail_path: open(thumbnail_path, "wb").write(b"thumb"))
|
||||
monkeypatch.setattr("services.media_task_runner.upload_file", lambda *args, **kwargs: None)
|
||||
monkeypatch.setattr("services.media_task_runner.upload_frames_to_minio", lambda frame_files, project_id: [f"projects/{project_id}/frames/frame_000001.jpg"])
|
||||
published = []
|
||||
monkeypatch.setattr(
|
||||
"services.media_task_runner.publish_task_progress_event",
|
||||
lambda event_task: published.append((event_task.status, event_task.progress, event_task.message)),
|
||||
)
|
||||
|
||||
result = run_parse_media_task(db_session, task.id)
|
||||
|
||||
assert result["frames_extracted"] == 1
|
||||
db_session.refresh(task)
|
||||
assert task.status == "success"
|
||||
assert task.progress == 100
|
||||
assert ("running", 5, "后台解析已启动") in published
|
||||
assert ("success", 100, "解析完成") in published
|
||||
project_detail = client.get(f"/api/projects/{project['id']}").json()
|
||||
assert project_detail["status"] == "ready"
|
||||
frames = client.get(f"/api/projects/{project['id']}/frames").json()
|
||||
assert "frame_000001.jpg" in frames[0]["image_url"]
|
||||
42
backend/tests/test_progress_events.py
Normal file
42
backend/tests/test_progress_events.py
Normal file
@@ -0,0 +1,42 @@
|
||||
from types import SimpleNamespace
|
||||
|
||||
from progress_events import PROGRESS_CHANNEL, publish_progress_event, task_progress_payload
|
||||
|
||||
|
||||
def test_task_progress_payload_uses_dashboard_task_id_and_project_name():
|
||||
task = SimpleNamespace(
|
||||
id=12,
|
||||
project_id=7,
|
||||
project=SimpleNamespace(name="demo.mp4"),
|
||||
status="success",
|
||||
progress=100,
|
||||
message="解析完成",
|
||||
error=None,
|
||||
updated_at=None,
|
||||
)
|
||||
|
||||
payload = task_progress_payload(task)
|
||||
|
||||
assert payload["type"] == "complete"
|
||||
assert payload["taskId"] == "task-12"
|
||||
assert payload["task_id"] == 12
|
||||
assert payload["project_id"] == 7
|
||||
assert payload["filename"] == "demo.mp4"
|
||||
assert payload["projectName"] == "demo.mp4"
|
||||
assert payload["status"] == "解析完成"
|
||||
|
||||
|
||||
def test_publish_progress_event_writes_json_to_redis(monkeypatch):
|
||||
calls = []
|
||||
|
||||
class FakeRedis:
|
||||
def publish(self, channel, payload):
|
||||
calls.append((channel, payload))
|
||||
|
||||
monkeypatch.setattr("progress_events.get_redis_client", lambda: FakeRedis())
|
||||
|
||||
publish_progress_event({"type": "progress", "message": "正在下载媒体文件"})
|
||||
|
||||
assert calls
|
||||
assert calls[0][0] == PROGRESS_CHANNEL
|
||||
assert "正在下载媒体文件" in calls[0][1]
|
||||
56
backend/tests/test_projects.py
Normal file
56
backend/tests/test_projects.py
Normal file
@@ -0,0 +1,56 @@
|
||||
def test_project_crud_and_frames(client, monkeypatch):
|
||||
monkeypatch.setattr("routers.projects.get_presigned_url", lambda key, expires=3600: f"http://storage/{key}")
|
||||
|
||||
created = client.post("/api/projects", json={
|
||||
"name": "Demo",
|
||||
"description": "desc",
|
||||
"thumbnail_url": "thumb.jpg",
|
||||
"parse_fps": 12,
|
||||
})
|
||||
assert created.status_code == 201
|
||||
project_id = created.json()["id"]
|
||||
|
||||
frame = client.post(f"/api/projects/{project_id}/frames", json={
|
||||
"project_id": project_id,
|
||||
"frame_index": 0,
|
||||
"image_url": "frames/0.jpg",
|
||||
"width": 640,
|
||||
"height": 360,
|
||||
})
|
||||
assert frame.status_code == 201
|
||||
frame_id = frame.json()["id"]
|
||||
|
||||
listing = client.get("/api/projects")
|
||||
assert listing.status_code == 200
|
||||
assert listing.json()[0]["frame_count"] == 1
|
||||
assert listing.json()[0]["thumbnail_url"] == "http://storage/thumb.jpg"
|
||||
|
||||
frames = client.get(f"/api/projects/{project_id}/frames")
|
||||
assert frames.status_code == 200
|
||||
assert frames.json()[0]["image_url"] == "http://storage/frames/0.jpg"
|
||||
|
||||
single_frame = client.get(f"/api/projects/{project_id}/frames/{frame_id}")
|
||||
assert single_frame.status_code == 200
|
||||
assert single_frame.json()["frame_index"] == 0
|
||||
|
||||
updated = client.patch(f"/api/projects/{project_id}", json={"name": "Renamed", "status": "ready"})
|
||||
assert updated.status_code == 200
|
||||
assert updated.json()["name"] == "Renamed"
|
||||
assert updated.json()["status"] == "ready"
|
||||
|
||||
deleted = client.delete(f"/api/projects/{project_id}")
|
||||
assert deleted.status_code == 204
|
||||
assert client.get(f"/api/projects/{project_id}").status_code == 404
|
||||
|
||||
|
||||
def test_project_and_frame_404s(client):
|
||||
assert client.get("/api/projects/999").status_code == 404
|
||||
assert client.patch("/api/projects/999", json={"name": "x"}).status_code == 404
|
||||
assert client.delete("/api/projects/999").status_code == 404
|
||||
assert client.post("/api/projects/999/frames", json={
|
||||
"project_id": 999,
|
||||
"frame_index": 0,
|
||||
"image_url": "missing.jpg",
|
||||
}).status_code == 404
|
||||
assert client.get("/api/projects/999/frames").status_code == 404
|
||||
assert client.get("/api/projects/999/frames/1").status_code == 404
|
||||
39
backend/tests/test_templates.py
Normal file
39
backend/tests/test_templates.py
Normal file
@@ -0,0 +1,39 @@
|
||||
def test_template_crud_packs_and_unpacks_mapping_rules(client):
|
||||
payload = {
|
||||
"name": "Template",
|
||||
"color": "#06b6d4",
|
||||
"z_index": 0,
|
||||
"classes": [{"id": "c1", "name": "胆囊", "color": "#ff0000", "zIndex": 10}],
|
||||
"rules": [{"id": "r1", "name": "rule"}],
|
||||
}
|
||||
|
||||
created = client.post("/api/templates", json=payload)
|
||||
assert created.status_code == 201
|
||||
template_id = created.json()["id"]
|
||||
assert created.json()["classes"][0]["name"] == "胆囊"
|
||||
assert created.json()["rules"][0]["id"] == "r1"
|
||||
|
||||
listing = client.get("/api/templates")
|
||||
assert listing.status_code == 200
|
||||
assert listing.json()[0]["classes"][0]["name"] == "胆囊"
|
||||
|
||||
detail = client.get(f"/api/templates/{template_id}")
|
||||
assert detail.status_code == 200
|
||||
assert detail.json()["name"] == "Template"
|
||||
|
||||
updated = client.patch(f"/api/templates/{template_id}", json={
|
||||
"classes": [{"id": "c2", "name": "肝脏", "color": "#00ff00", "zIndex": 20}],
|
||||
"rules": [],
|
||||
})
|
||||
assert updated.status_code == 200
|
||||
assert updated.json()["classes"][0]["name"] == "肝脏"
|
||||
|
||||
deleted = client.delete(f"/api/templates/{template_id}")
|
||||
assert deleted.status_code == 204
|
||||
assert client.get(f"/api/templates/{template_id}").status_code == 404
|
||||
|
||||
|
||||
def test_template_404s(client):
|
||||
assert client.get("/api/templates/999").status_code == 404
|
||||
assert client.patch("/api/templates/999", json={"name": "x"}).status_code == 404
|
||||
assert client.delete("/api/templates/999").status_code == 404
|
||||
Reference in New Issue
Block a user