完善项目导入、模板与分割工作区交互
- 增强 DICOM/视频项目导入与演示数据:DICOM 按文件名自然顺序处理,导入后展示上传与解析任务进度,恢复演示出厂设置保留演示视频和演示 DICOM 项目,并补充 demo media seed 逻辑。 - 完善项目管理:项目支持重命名、删除、复制,删除使用站内确认弹窗,复制支持新项目重置和全内容复制,DICOM 项目不显示生成帧入口。 - 完善 GT Mask 与导出链路:只支持 8-bit maskid 图导入,非法/全背景图明确拒绝,尺寸自动适配,高精度 polygon 回显;统一导出默认当前帧,GT_label 使用 uint8 和真实 maskid,待分类 maskid 0 与背景一致。 - 完善分割工作区交互:新增画笔和橡皮擦并支持尺寸控制,移除创建点/线段入口,工具栏按类别分隔,AI 智能分割使用明确 AI 图标,取消黄色 seed point,清空/删除传播 mask 后同步清理空帧时间轴状态。 - 完善传播与时间轴:自动传播使用 SAM 2.1 权重任务,参考帧无遮罩时提示,传播历史按同一蓝色系递进变暗,删除/清空传播链时保留人工或独立 AI 标注来源。 - 完善模板库:新增头颈部 CT 分割默认模板,所有模板保留 maskid 0 待分类,支持鼠标复制模板、拖拽层级、JSON 批量导入预览、删除 label 和站内删除确认。 - 完善用户与高风险确认:用户改密码、删除用户、恢复演示出厂设置和清空人工/AI 标注帧均改为站内确认交互,避免浏览器原生 prompt/confirm。 - 补充前后端测试与文档:更新项目、模板、GT 导入、导出、传播、DICOM、用户管理等测试,并同步 README、AGENTS 和 doc 下实现/契约/测试计划文档。
This commit is contained in:
@@ -39,6 +39,7 @@ class Settings(BaseSettings):
|
||||
default_admin_username: str = "admin"
|
||||
default_admin_password: str = "123456"
|
||||
demo_video_path: str = "/home/wkmgc/Desktop/Seg_Server/Data_MyVideo_1.mp4"
|
||||
demo_dicom_dir: str = "/home/wkmgc/Desktop/Seg_Server/2024_2_5_王芳/※2F458C45CFAA4C7CB76A39AA2BFE436B"
|
||||
|
||||
class Config:
|
||||
env_file = ".env"
|
||||
|
||||
280
backend/main.py
280
backend/main.py
@@ -4,8 +4,6 @@ import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
import tempfile
|
||||
from contextlib import asynccontextmanager, suppress
|
||||
from datetime import datetime, timezone
|
||||
|
||||
@@ -15,10 +13,9 @@ from sqlalchemy import inspect, text
|
||||
|
||||
from config import settings
|
||||
from database import Base, engine, SessionLocal
|
||||
from minio_client import ensure_bucket_exists, upload_file
|
||||
from minio_client import ensure_bucket_exists
|
||||
from progress_events import PROGRESS_CHANNEL
|
||||
from redis_client import get_redis_client, ping as redis_ping
|
||||
from statuses import PROJECT_STATUS_PENDING, PROJECT_STATUS_READY
|
||||
|
||||
from routers import projects, templates, media, ai, export, auth, dashboard, tasks, admin
|
||||
|
||||
@@ -27,9 +24,24 @@ logging.basicConfig(
|
||||
format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
RESERVED_UNCLASSIFIED_CLASS = {
|
||||
"id": "reserved-unclassified",
|
||||
"name": "待分类",
|
||||
"color": "#000000",
|
||||
"zIndex": 0,
|
||||
"maskId": 0,
|
||||
"category": "系统保留",
|
||||
}
|
||||
|
||||
DEFAULT_VIDEO_PATH = settings.demo_video_path
|
||||
|
||||
def _with_reserved_unclassified_class(classes: list[dict]) -> list[dict]:
|
||||
filtered = [
|
||||
item for item in classes
|
||||
if item.get("id") != RESERVED_UNCLASSIFIED_CLASS["id"]
|
||||
and item.get("name") != RESERVED_UNCLASSIFIED_CLASS["name"]
|
||||
and item.get("maskId") != 0
|
||||
]
|
||||
return [*filtered, dict(RESERVED_UNCLASSIFIED_CLASS)]
|
||||
|
||||
def _ensure_runtime_schema_columns() -> None:
|
||||
"""Add nullable columns introduced after initial create_all deployments."""
|
||||
@@ -72,92 +84,51 @@ def _seed_default_admin_and_ownership_sync() -> None:
|
||||
|
||||
|
||||
def _seed_default_project_sync() -> None:
|
||||
"""Synchronously seed the default video project on first startup."""
|
||||
import cv2
|
||||
from models import Project, Frame
|
||||
"""Synchronously seed the bundled demo video and DICOM projects on first startup."""
|
||||
from models import Project
|
||||
from routers.auth import ensure_default_admin
|
||||
from services.frame_parser import parse_video, upload_frames_to_minio, extract_thumbnail
|
||||
from services.demo_media import (
|
||||
DEMO_DICOM_PROJECT_NAME,
|
||||
DEMO_VIDEO_PROJECT_NAME,
|
||||
create_parsed_dicom_demo_project,
|
||||
create_unparsed_video_demo_project,
|
||||
demo_dicom_files,
|
||||
)
|
||||
|
||||
db = SessionLocal()
|
||||
try:
|
||||
admin = ensure_default_admin(db)
|
||||
existing = db.query(Project).filter(Project.name == "Data_MyVideo_1").first()
|
||||
if existing is not None:
|
||||
if existing.owner_user_id is None:
|
||||
existing.owner_user_id = admin.id
|
||||
existing_video = db.query(Project).filter(Project.name == DEMO_VIDEO_PROJECT_NAME).first()
|
||||
if existing_video is not None and existing_video.owner_user_id is None:
|
||||
existing_video.owner_user_id = admin.id
|
||||
db.commit()
|
||||
elif existing_video is None and os.path.exists(settings.demo_video_path):
|
||||
video_project = create_unparsed_video_demo_project(
|
||||
db,
|
||||
owner=admin,
|
||||
video_path=settings.demo_video_path,
|
||||
project_name=DEMO_VIDEO_PROJECT_NAME,
|
||||
)
|
||||
logger.info("Seeded default video project id=%s", video_project.id)
|
||||
|
||||
existing_dicom = db.query(Project).filter(Project.name == DEMO_DICOM_PROJECT_NAME).first()
|
||||
if existing_dicom is not None:
|
||||
if existing_dicom.owner_user_id is None:
|
||||
existing_dicom.owner_user_id = admin.id
|
||||
db.commit()
|
||||
return
|
||||
|
||||
if not os.path.exists(DEFAULT_VIDEO_PATH):
|
||||
logger.warning("Default video not found at %s", DEFAULT_VIDEO_PATH)
|
||||
if not demo_dicom_files(settings.demo_dicom_dir):
|
||||
logger.warning("Default DICOM series not found at %s", settings.demo_dicom_dir)
|
||||
return
|
||||
|
||||
project = Project(
|
||||
name="Data_MyVideo_1",
|
||||
description="默认演示视频",
|
||||
status=PROJECT_STATUS_PENDING,
|
||||
source_type="video",
|
||||
parse_fps=30.0,
|
||||
owner_user_id=admin.id,
|
||||
project = create_parsed_dicom_demo_project(
|
||||
db,
|
||||
owner=admin,
|
||||
dicom_dir=settings.demo_dicom_dir,
|
||||
project_name=DEMO_DICOM_PROJECT_NAME,
|
||||
)
|
||||
db.add(project)
|
||||
db.commit()
|
||||
db.refresh(project)
|
||||
|
||||
with open(DEFAULT_VIDEO_PATH, "rb") as f:
|
||||
data = f.read()
|
||||
object_name = f"uploads/{project.id}/Data_MyVideo_1.mp4"
|
||||
upload_file(object_name, data, content_type="video/mp4", length=len(data))
|
||||
|
||||
project.video_path = object_name
|
||||
db.commit()
|
||||
|
||||
# Parse frames
|
||||
tmp_dir = tempfile.mkdtemp(prefix=f"seg_seed_{project.id}_")
|
||||
try:
|
||||
local_path = os.path.join(tmp_dir, "video.mp4")
|
||||
with open(local_path, "wb") as f:
|
||||
f.write(data)
|
||||
output_dir = os.path.join(tmp_dir, "frames")
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
frame_files, original_fps = parse_video(local_path, output_dir, fps=30, max_frames=100)
|
||||
project.original_fps = original_fps
|
||||
|
||||
# Extract thumbnail
|
||||
thumbnail_path = os.path.join(tmp_dir, "thumbnail.jpg")
|
||||
try:
|
||||
extract_thumbnail(local_path, thumbnail_path)
|
||||
with open(thumbnail_path, "rb") as f:
|
||||
thumb_data = f.read()
|
||||
thumb_object = f"projects/{project.id}/thumbnail.jpg"
|
||||
upload_file(thumb_object, thumb_data, content_type="image/jpeg", length=len(thumb_data))
|
||||
project.thumbnail_url = thumb_object
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.warning("Thumbnail extraction failed: %s", exc)
|
||||
|
||||
object_names = upload_frames_to_minio(frame_files, project.id)
|
||||
|
||||
for idx, obj_name in enumerate(object_names):
|
||||
img = cv2.imread(frame_files[idx])
|
||||
h, w = img.shape[:2] if img is not None else (None, None)
|
||||
timestamp_ms = idx * 1000.0 / 30.0
|
||||
source_frame_number = int(round(idx * original_fps / 30.0)) if original_fps else None
|
||||
frame = Frame(
|
||||
project_id=project.id,
|
||||
frame_index=idx,
|
||||
image_url=obj_name,
|
||||
width=w,
|
||||
height=h,
|
||||
timestamp_ms=timestamp_ms,
|
||||
source_frame_number=source_frame_number,
|
||||
)
|
||||
db.add(frame)
|
||||
|
||||
project.status = PROJECT_STATUS_READY
|
||||
db.commit()
|
||||
logger.info("Seeded default project id=%s with %d frames", project.id, len(object_names))
|
||||
finally:
|
||||
shutil.rmtree(tmp_dir, ignore_errors=True)
|
||||
logger.info("Seeded default DICOM project id=%s with %d frames", project.id, len(project.frames))
|
||||
except Exception as exc:
|
||||
logger.error("Failed to seed default project: %s", exc)
|
||||
finally:
|
||||
@@ -170,53 +141,120 @@ def _seed_default_templates_sync() -> None:
|
||||
|
||||
db = SessionLocal()
|
||||
try:
|
||||
if db.query(Template).first() is not None:
|
||||
return
|
||||
|
||||
# Laparoscopic cholecystectomy template (35 classes)
|
||||
colors = [
|
||||
(134, 124, 118), (0, 157, 142), (245, 161, 0), (255, 172, 159), (146, 175, 236), (155, 62, 0),
|
||||
(255, 91, 0), (255, 234, 0), (85, 111, 181), (155, 132, 0), (181, 227, 14), (72, 0, 255),
|
||||
(255, 0, 255), (29, 32, 136), (240, 16, 116), (160, 15, 95), (0, 155, 33), (0, 160, 233),
|
||||
(52, 184, 178), (66, 115, 82), (90, 120, 41), (255, 0, 0), (117, 0, 0), (167, 24, 233),
|
||||
(42, 8, 66), (112, 113, 150), (0, 255, 0), (255, 255, 255), (0, 255, 255), (181, 85, 105),
|
||||
(113, 102, 140), (202, 202, 200), (197, 83, 181), (136, 162, 196), (138, 251, 213),
|
||||
]
|
||||
names = [
|
||||
'针', '线', '肿瘤', '血管阻断夹', '棉球', '双极电凝',
|
||||
'肝脏', '胆囊', '分离钳', '脂肪', '止血海绵', '肝总管',
|
||||
'吸引器', '剪刀', '超声刀', '止血纱布', '胆总管', '生物夹',
|
||||
'无损伤钳', '钳夹', '喷洒', '胆囊管', '动脉', '电凝',
|
||||
'静脉', '标本袋', '引流管', '纱布', '金属钛夹', '韧带',
|
||||
'肝蒂', '推结器', '乳胶管-血管阻断', '吻合器', '术中超声',
|
||||
]
|
||||
classes = []
|
||||
for idx, (rgb, name) in enumerate(zip(colors, names)):
|
||||
color_hex = f"#{rgb[0]:02x}{rgb[1]:02x}{rgb[2]:02x}"
|
||||
classes.append({
|
||||
"id": f"cls-lap-{idx}",
|
||||
"name": name,
|
||||
"color": color_hex,
|
||||
"zIndex": (len(names) - idx) * 10,
|
||||
"category": "腹腔镜胆囊切除术",
|
||||
})
|
||||
|
||||
template = Template(
|
||||
name="腹腔镜胆囊切除术",
|
||||
description="腹腔镜胆囊切除术(LC)手术器械与解剖结构语义分割模板,共35个分类",
|
||||
color="#06b6d4",
|
||||
z_index=0,
|
||||
mapping_rules={"classes": classes, "rules": []},
|
||||
)
|
||||
db.add(template)
|
||||
db.commit()
|
||||
logger.info("Seeded default template '腹腔镜胆囊切除术' with %d classes", len(classes))
|
||||
ensure_default_templates(db)
|
||||
except Exception as exc:
|
||||
logger.error("Failed to seed default templates: %s", exc)
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
def _template_classes(
|
||||
template_name: str,
|
||||
names: list[str],
|
||||
colors: list[tuple[int, int, int]],
|
||||
*,
|
||||
id_prefix: str,
|
||||
) -> list[dict]:
|
||||
classes = []
|
||||
for idx, (rgb, name) in enumerate(zip(colors, names)):
|
||||
color_hex = f"#{rgb[0]:02x}{rgb[1]:02x}{rgb[2]:02x}"
|
||||
classes.append({
|
||||
"id": f"{id_prefix}-{idx}",
|
||||
"name": name,
|
||||
"color": color_hex,
|
||||
"zIndex": (len(names) - idx) * 10,
|
||||
"maskId": idx + 1,
|
||||
"category": template_name,
|
||||
})
|
||||
return classes
|
||||
|
||||
|
||||
def ensure_default_templates(db) -> None:
|
||||
"""Ensure all bundled system templates exist."""
|
||||
from models import Template
|
||||
|
||||
default_templates = [
|
||||
{
|
||||
"name": "腹腔镜胆囊切除术",
|
||||
"description": "腹腔镜胆囊切除术(LC)手术器械与解剖结构语义分割模板,共35个分类",
|
||||
"color": "#06b6d4",
|
||||
"z_index": 0,
|
||||
"classes": _with_reserved_unclassified_class(_template_classes(
|
||||
"腹腔镜胆囊切除术",
|
||||
[
|
||||
'针', '线', '肿瘤', '血管阻断夹', '棉球', '双极电凝',
|
||||
'肝脏', '胆囊', '分离钳', '脂肪', '止血海绵', '肝总管',
|
||||
'吸引器', '剪刀', '超声刀', '止血纱布', '胆总管', '生物夹',
|
||||
'无损伤钳', '钳夹', '喷洒', '胆囊管', '动脉', '电凝',
|
||||
'静脉', '标本袋', '引流管', '纱布', '金属钛夹', '韧带',
|
||||
'肝蒂', '推结器', '乳胶管-血管阻断', '吻合器', '术中超声',
|
||||
],
|
||||
[
|
||||
(134, 124, 118), (0, 157, 142), (245, 161, 0), (255, 172, 159), (146, 175, 236), (155, 62, 0),
|
||||
(255, 91, 0), (255, 234, 0), (85, 111, 181), (155, 132, 0), (181, 227, 14), (72, 0, 255),
|
||||
(255, 0, 255), (29, 32, 136), (240, 16, 116), (160, 15, 95), (0, 155, 33), (0, 160, 233),
|
||||
(52, 184, 178), (66, 115, 82), (90, 120, 41), (255, 0, 0), (117, 0, 0), (167, 24, 233),
|
||||
(42, 8, 66), (112, 113, 150), (0, 255, 0), (255, 255, 255), (0, 255, 255), (181, 85, 105),
|
||||
(113, 102, 140), (202, 202, 200), (197, 83, 181), (136, 162, 196), (138, 251, 213),
|
||||
],
|
||||
id_prefix="cls-lap",
|
||||
)),
|
||||
},
|
||||
{
|
||||
"name": "头颈部CT分割",
|
||||
"description": "头颈部CT分割",
|
||||
"color": "#ef4444",
|
||||
"z_index": 10,
|
||||
"classes": _with_reserved_unclassified_class(_template_classes(
|
||||
"头颈部CT分割",
|
||||
[
|
||||
"肿瘤/结节 (Tumor/Nodule)",
|
||||
"下颌骨 (Mandible)",
|
||||
"甲状腺 (Thyroid)",
|
||||
"气管 (Trachea)",
|
||||
"颈椎 (Cervical Spine)",
|
||||
"颈动脉 (Carotid Artery)",
|
||||
"颈静脉 (Jugular Vein)",
|
||||
"腮腺 (Parotid Gland)",
|
||||
"下颌下腺 (Submandibular Gland)",
|
||||
"舌骨 (Hyoid Bone)",
|
||||
],
|
||||
[
|
||||
(255, 0, 0),
|
||||
(0, 255, 0),
|
||||
(0, 0, 255),
|
||||
(255, 255, 0),
|
||||
(255, 0, 255),
|
||||
(0, 255, 255),
|
||||
(255, 128, 0),
|
||||
(128, 0, 128),
|
||||
(0, 128, 128),
|
||||
(128, 128, 0),
|
||||
],
|
||||
id_prefix="cls-head-neck-ct",
|
||||
)),
|
||||
},
|
||||
]
|
||||
|
||||
for definition in default_templates:
|
||||
existing = db.query(Template).filter(
|
||||
Template.name == definition["name"],
|
||||
Template.owner_user_id.is_(None),
|
||||
).first()
|
||||
if existing is not None:
|
||||
continue
|
||||
template = Template(
|
||||
name=definition["name"],
|
||||
description=definition["description"],
|
||||
color=definition["color"],
|
||||
z_index=definition["z_index"],
|
||||
mapping_rules={"classes": definition["classes"], "rules": []},
|
||||
owner_user_id=None,
|
||||
)
|
||||
db.add(template)
|
||||
logger.info("Seeded default template '%s' with %d classes", definition["name"], len(definition["classes"]))
|
||||
db.commit()
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
"""Application lifespan: startup and shutdown hooks."""
|
||||
|
||||
@@ -9,7 +9,6 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from config import settings
|
||||
from database import get_db
|
||||
from minio_client import upload_file
|
||||
from models import Annotation, AuditLog, Frame, Mask, ProcessingTask, Project, Template, User
|
||||
from routers.auth import ensure_default_admin, hash_password, require_admin, write_audit_log
|
||||
from schemas import (
|
||||
@@ -20,13 +19,19 @@ from schemas import (
|
||||
DemoFactoryResetRequest,
|
||||
UserOut,
|
||||
)
|
||||
from statuses import PROJECT_STATUS_PENDING
|
||||
from services.demo_media import (
|
||||
DEMO_DICOM_PROJECT_NAME,
|
||||
DEMO_VIDEO_PROJECT_NAME,
|
||||
create_parsed_dicom_demo_project,
|
||||
create_unparsed_video_demo_project,
|
||||
demo_dicom_files,
|
||||
)
|
||||
|
||||
router = APIRouter(prefix="/api/admin", tags=["Admin"])
|
||||
|
||||
VALID_ROLES = {"admin", "annotator", "viewer"}
|
||||
DEMO_RESET_CONFIRMATION = "RESET_DEMO_FACTORY"
|
||||
DEMO_PROJECT_NAME = "Data_MyVideo_1"
|
||||
DEMO_PROJECT_NAME = DEMO_DICOM_PROJECT_NAME
|
||||
|
||||
|
||||
def _normalize_role(role: str | None) -> str:
|
||||
@@ -191,7 +196,7 @@ def reset_demo_factory(
|
||||
db: Session = Depends(get_db),
|
||||
admin_user: User = Depends(require_admin),
|
||||
) -> dict:
|
||||
"""Reset a demo deployment to one admin account and one unparsed demo video project."""
|
||||
"""Reset a demo deployment to one admin account, the demo video, and the demo DICOM project."""
|
||||
if payload.confirmation != DEMO_RESET_CONFIRMATION:
|
||||
raise HTTPException(status_code=400, detail="Invalid reset confirmation")
|
||||
|
||||
@@ -200,6 +205,11 @@ def reset_demo_factory(
|
||||
status_code=409,
|
||||
detail=f"Demo video not found: {settings.demo_video_path}",
|
||||
)
|
||||
if not demo_dicom_files(settings.demo_dicom_dir):
|
||||
raise HTTPException(
|
||||
status_code=409,
|
||||
detail=f"Demo DICOM series not found: {settings.demo_dicom_dir}",
|
||||
)
|
||||
|
||||
requested_by = admin_user.username
|
||||
preserved_admin = ensure_default_admin(db)
|
||||
@@ -226,37 +236,39 @@ def reset_demo_factory(
|
||||
if not preserved_admin:
|
||||
raise HTTPException(status_code=500, detail="Default admin was not preserved")
|
||||
|
||||
project = Project(
|
||||
name=DEMO_PROJECT_NAME,
|
||||
description="默认演示视频,尚未生成帧",
|
||||
status=PROJECT_STATUS_PENDING,
|
||||
source_type="video",
|
||||
parse_fps=30.0,
|
||||
owner_user_id=preserved_admin.id,
|
||||
video_project = create_unparsed_video_demo_project(
|
||||
db,
|
||||
owner=preserved_admin,
|
||||
video_path=settings.demo_video_path,
|
||||
project_name=DEMO_VIDEO_PROJECT_NAME,
|
||||
)
|
||||
db.add(project)
|
||||
db.flush()
|
||||
video_project.frame_count = 0
|
||||
|
||||
with open(settings.demo_video_path, "rb") as file_obj:
|
||||
data = file_obj.read()
|
||||
object_name = f"uploads/{project.id}/{os.path.basename(settings.demo_video_path)}"
|
||||
upload_file(object_name, data, content_type="video/mp4", length=len(data))
|
||||
project.video_path = object_name
|
||||
project.thumbnail_url = None
|
||||
project.original_fps = None
|
||||
db.commit()
|
||||
dicom_project = create_parsed_dicom_demo_project(
|
||||
db,
|
||||
owner=preserved_admin,
|
||||
dicom_dir=settings.demo_dicom_dir,
|
||||
project_name=DEMO_PROJECT_NAME,
|
||||
)
|
||||
db.refresh(preserved_admin)
|
||||
db.refresh(project)
|
||||
db.refresh(video_project)
|
||||
db.refresh(dicom_project)
|
||||
video_project.frame_count = len(video_project.frames)
|
||||
dicom_project.frame_count = len(dicom_project.frames)
|
||||
projects = [video_project, dicom_project]
|
||||
|
||||
write_audit_log(
|
||||
db,
|
||||
actor=preserved_admin,
|
||||
action="admin.demo_factory_reset",
|
||||
target_type="project",
|
||||
target_id=project.id,
|
||||
target_id=dicom_project.id,
|
||||
detail={
|
||||
"project_name": project.name,
|
||||
"video_path": project.video_path,
|
||||
"project_names": [project.name for project in projects],
|
||||
"video_path": video_project.video_path,
|
||||
"dicom_path": dicom_project.video_path,
|
||||
"source_types": [project.source_type for project in projects],
|
||||
"frame_counts": {project.name: len(project.frames) for project in projects},
|
||||
"deleted_counts": deleted_counts,
|
||||
"requested_by": requested_by,
|
||||
},
|
||||
@@ -264,7 +276,8 @@ def reset_demo_factory(
|
||||
|
||||
return {
|
||||
"admin_user": preserved_admin,
|
||||
"project": project,
|
||||
"project": dicom_project,
|
||||
"projects": projects,
|
||||
"deleted_counts": deleted_counts,
|
||||
"message": "演示环境已恢复出厂设置",
|
||||
}
|
||||
|
||||
@@ -39,6 +39,10 @@ from services.sam_registry import ModelUnavailableError, sam_registry
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter(prefix="/api/ai", tags=["AI"])
|
||||
GT_MASK_EMPTY_DETAIL = "GT Mask 图片中没有非背景 maskid 区域。"
|
||||
GT_IMPORT_MAX_CONTOUR_POINTS = 2048
|
||||
GT_IMPORT_CONTOUR_EPSILON_RATIO = 0.00075
|
||||
GT_IMPORT_MIN_CONTOUR_EPSILON = 0.35
|
||||
|
||||
|
||||
def _owned_project_or_404(project_id: int, db: Session, current_user: User) -> Project:
|
||||
@@ -152,13 +156,19 @@ def _load_frame_image(frame: Frame) -> np.ndarray:
|
||||
|
||||
|
||||
def _normalized_contour(contour: np.ndarray, width: int, height: int) -> list[list[float]]:
|
||||
"""Approximate a contour and convert it to normalized polygon coordinates."""
|
||||
"""Convert a contour to a detailed normalized polygon with a point-count cap."""
|
||||
arc_length = cv2.arcLength(contour, True)
|
||||
epsilon = max(1.0, arc_length * 0.01)
|
||||
epsilon = max(GT_IMPORT_MIN_CONTOUR_EPSILON, arc_length * GT_IMPORT_CONTOUR_EPSILON_RATIO)
|
||||
approx = cv2.approxPolyDP(contour, epsilon, True)
|
||||
while len(approx) > GT_IMPORT_MAX_CONTOUR_POINTS and epsilon < arc_length * 0.02:
|
||||
epsilon *= 1.5
|
||||
approx = cv2.approxPolyDP(contour, epsilon, True)
|
||||
points = approx.reshape(-1, 2)
|
||||
if len(points) < 3:
|
||||
points = contour.reshape(-1, 2)
|
||||
if len(points) > GT_IMPORT_MAX_CONTOUR_POINTS:
|
||||
step = int(math.ceil(len(points) / GT_IMPORT_MAX_CONTOUR_POINTS))
|
||||
points = points[::step]
|
||||
return [
|
||||
[
|
||||
min(max(float(x) / max(width, 1), 0.0), 1.0),
|
||||
@@ -977,6 +987,13 @@ async def import_gt_mask(
|
||||
if image is None:
|
||||
raise HTTPException(status_code=400, detail="Invalid mask image")
|
||||
|
||||
invalid_format_detail = (
|
||||
"GT Mask 图片不符合要求:仅支持 8-bit 灰度图,或 8-bit RGB 三通道完全相同的 maskid 图"
|
||||
"(背景 0,像素值为 1-255 的 maskid)。"
|
||||
)
|
||||
if image.dtype != np.uint8:
|
||||
raise HTTPException(status_code=400, detail=invalid_format_detail)
|
||||
|
||||
if image.ndim == 2:
|
||||
label_image = image
|
||||
elif image.ndim == 3 and image.shape[2] >= 3:
|
||||
@@ -984,16 +1001,10 @@ async def import_gt_mask(
|
||||
# GT label images are maskid maps: either grayscale or RGB/BGR where
|
||||
# all three color channels contain the same maskid value [X, X, X].
|
||||
if not (np.array_equal(channels[:, :, 0], channels[:, :, 1]) and np.array_equal(channels[:, :, 1], channels[:, :, 2])):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="GT Mask 图片不符合要求:请上传灰度图,或 RGB 三通道完全相同的 maskid 图(背景 0,像素值为 maskid)。",
|
||||
)
|
||||
raise HTTPException(status_code=400, detail=invalid_format_detail)
|
||||
label_image = channels[:, :, 0]
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="GT Mask 图片不符合要求:请上传灰度图,或 RGB 三通道完全相同的 maskid 图(背景 0,像素值为 maskid)。",
|
||||
)
|
||||
raise HTTPException(status_code=400, detail=invalid_format_detail)
|
||||
|
||||
width = int(frame.width or image.shape[1])
|
||||
height = int(frame.height or image.shape[0])
|
||||
@@ -1041,12 +1052,12 @@ async def import_gt_mask(
|
||||
if not import_items:
|
||||
if skipped_unknown > 0:
|
||||
raise HTTPException(status_code=400, detail="No matching GT mask classes found")
|
||||
raise HTTPException(status_code=400, detail="No foreground mask regions found")
|
||||
raise HTTPException(status_code=400, detail=GT_MASK_EMPTY_DETAIL)
|
||||
|
||||
annotations: list[Annotation] = []
|
||||
for item in import_items:
|
||||
binary = item["binary"]
|
||||
contours, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
||||
contours, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
|
||||
|
||||
for contour in contours:
|
||||
if cv2.contourArea(contour) < 1:
|
||||
@@ -1085,7 +1096,7 @@ async def import_gt_mask(
|
||||
annotations.append(annotation)
|
||||
|
||||
if not annotations:
|
||||
raise HTTPException(status_code=400, detail="No foreground mask regions found")
|
||||
raise HTTPException(status_code=400, detail=GT_MASK_EMPTY_DETAIL)
|
||||
|
||||
db.commit()
|
||||
for annotation in annotations:
|
||||
|
||||
@@ -64,7 +64,7 @@ def _annotation_mask_id(annotation: Annotation) -> int | None:
|
||||
value = int(class_meta[key])
|
||||
except (TypeError, ValueError):
|
||||
continue
|
||||
if value > 0:
|
||||
if value >= 0:
|
||||
return value
|
||||
return None
|
||||
|
||||
@@ -361,7 +361,7 @@ def _build_gt_class_mapping(annotations: list[Annotation]) -> tuple[dict[str, in
|
||||
ordered = sorted(
|
||||
entries_by_key.values(),
|
||||
key=lambda item: (
|
||||
item["maskidHint"] if isinstance(item.get("maskidHint"), int) and item["maskidHint"] > 0 else 10_000_000,
|
||||
item["maskidHint"] if isinstance(item.get("maskidHint"), int) and item["maskidHint"] >= 0 else 10_000_000,
|
||||
str(item["className"]),
|
||||
str(item["key"]),
|
||||
),
|
||||
@@ -375,6 +375,8 @@ def _build_gt_class_mapping(annotations: list[Annotation]) -> tuple[dict[str, in
|
||||
nonlocal next_maskid
|
||||
while next_maskid in used_maskids:
|
||||
next_maskid += 1
|
||||
if next_maskid > 255:
|
||||
raise HTTPException(status_code=400, detail="GT_label 仅支持 8-bit maskid,类别值必须在 1-255 之间")
|
||||
value = next_maskid
|
||||
used_maskids.add(value)
|
||||
next_maskid += 1
|
||||
@@ -382,7 +384,12 @@ def _build_gt_class_mapping(annotations: list[Annotation]) -> tuple[dict[str, in
|
||||
|
||||
for entry in ordered:
|
||||
hinted_maskid = entry.get("maskidHint")
|
||||
if isinstance(hinted_maskid, int) and hinted_maskid > 0 and hinted_maskid not in used_maskids:
|
||||
if isinstance(hinted_maskid, int) and hinted_maskid > 255:
|
||||
raise HTTPException(status_code=400, detail="GT_label 仅支持 8-bit maskid,类别值必须在 1-255 之间")
|
||||
if isinstance(hinted_maskid, int) and hinted_maskid == 0:
|
||||
maskid = 0
|
||||
used_maskids.add(maskid)
|
||||
elif isinstance(hinted_maskid, int) and 0 < hinted_maskid <= 255 and hinted_maskid not in used_maskids:
|
||||
maskid = hinted_maskid
|
||||
used_maskids.add(maskid)
|
||||
else:
|
||||
@@ -513,7 +520,7 @@ def _write_result_mask_outputs(
|
||||
)
|
||||
|
||||
needs_fused_output = include_semantic or include_pro_label or include_mix_label
|
||||
semantic = np.zeros((height, width), dtype=np.uint16) if needs_fused_output else None
|
||||
semantic = np.zeros((height, width), dtype=np.uint8) if needs_fused_output else None
|
||||
pro_label = np.zeros((height, width, 3), dtype=np.uint8) if (include_pro_label or include_mix_label) else None
|
||||
|
||||
if needs_fused_output:
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""Media upload and parsing endpoints."""
|
||||
|
||||
import logging
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import List, Optional
|
||||
|
||||
@@ -22,6 +23,13 @@ router = APIRouter(prefix="/api/media", tags=["Media"])
|
||||
ALLOWED_EXTENSIONS = {".mp4", ".avi", ".mov", ".mkv", ".webm", ".png", ".jpg", ".jpeg", ".dcm"}
|
||||
|
||||
|
||||
def natural_filename_key(filename: str) -> tuple[object, ...]:
|
||||
return tuple(
|
||||
int(part) if part.isdigit() else part.casefold()
|
||||
for part in re.split(r"(\d+)", Path(filename).name)
|
||||
)
|
||||
|
||||
|
||||
def _get_ext(filename: str) -> str:
|
||||
return Path(filename).suffix.lower()
|
||||
|
||||
@@ -124,6 +132,12 @@ async def upload_dicom_batch(
|
||||
if not files:
|
||||
raise HTTPException(status_code=400, detail="No files uploaded")
|
||||
|
||||
sorted_files = sorted(
|
||||
[file for file in files if file.filename and file.filename.lower().endswith(".dcm")],
|
||||
key=lambda file: natural_filename_key(file.filename or ""),
|
||||
)
|
||||
if not sorted_files:
|
||||
raise HTTPException(status_code=400, detail="No valid DICOM files uploaded")
|
||||
uploaded = []
|
||||
|
||||
if project_id:
|
||||
@@ -135,10 +149,10 @@ async def upload_dicom_batch(
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
else:
|
||||
# Create new DICOM project
|
||||
first_name = files[0].filename or "DICOM_Series"
|
||||
first_name = sorted_files[0].filename or "DICOM_Series"
|
||||
project = Project(
|
||||
name=first_name,
|
||||
description=f"DICOM series with {len(files)} files",
|
||||
description=f"DICOM series with {len(sorted_files)} files",
|
||||
status=PROJECT_STATUS_PENDING,
|
||||
source_type="dicom",
|
||||
owner_user_id=current_user.id,
|
||||
@@ -149,9 +163,7 @@ async def upload_dicom_batch(
|
||||
project_id = project.id
|
||||
logger.info("Auto-created DICOM project id=%s", project_id)
|
||||
|
||||
for file in files:
|
||||
if not file.filename or not file.filename.lower().endswith(".dcm"):
|
||||
continue
|
||||
for file in sorted_files:
|
||||
data = await file.read()
|
||||
object_name = f"uploads/{project_id}/dicom/{file.filename}"
|
||||
try:
|
||||
|
||||
@@ -7,15 +7,38 @@ from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from database import get_db
|
||||
from models import Project, Frame, User
|
||||
from models import Annotation, Mask, Project, Frame, User
|
||||
from routers.auth import get_current_user, require_editor
|
||||
from schemas import ProjectCreate, ProjectOut, ProjectUpdate, FrameCreate, FrameOut
|
||||
from schemas import ProjectCopyRequest, ProjectCreate, ProjectOut, ProjectUpdate, FrameCreate, FrameOut
|
||||
from minio_client import get_presigned_url
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter(prefix="/api/projects", tags=["Projects"])
|
||||
|
||||
|
||||
def _next_project_copy_name(db: Session, owner_user_id: int, source_name: str) -> str:
|
||||
base_name = f"{source_name} 副本"
|
||||
existing_names = {
|
||||
row[0]
|
||||
for row in db.query(Project.name)
|
||||
.filter(Project.owner_user_id == owner_user_id, Project.name.like(f"{base_name}%"))
|
||||
.all()
|
||||
}
|
||||
if base_name not in existing_names:
|
||||
return base_name
|
||||
suffix = 2
|
||||
while f"{base_name} {suffix}" in existing_names:
|
||||
suffix += 1
|
||||
return f"{base_name} {suffix}"
|
||||
|
||||
|
||||
def _prepare_project_response(project: Project) -> Project:
|
||||
project.frame_count = len(project.frames)
|
||||
if project.thumbnail_url:
|
||||
project.thumbnail_url = get_presigned_url(project.thumbnail_url, expires=3600)
|
||||
return project
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Projects
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -59,9 +82,7 @@ def list_projects(
|
||||
.all()
|
||||
)
|
||||
for p in projects:
|
||||
p.frame_count = len(p.frames)
|
||||
if p.thumbnail_url:
|
||||
p.thumbnail_url = get_presigned_url(p.thumbnail_url, expires=3600)
|
||||
_prepare_project_response(p)
|
||||
return projects
|
||||
|
||||
|
||||
@@ -82,10 +103,85 @@ def get_project(
|
||||
).first()
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
project.frame_count = len(project.frames)
|
||||
if project.thumbnail_url:
|
||||
project.thumbnail_url = get_presigned_url(project.thumbnail_url, expires=3600)
|
||||
return project
|
||||
return _prepare_project_response(project)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/{project_id}/copy",
|
||||
response_model=ProjectOut,
|
||||
status_code=status.HTTP_201_CREATED,
|
||||
summary="Copy a project",
|
||||
)
|
||||
def copy_project(
|
||||
project_id: int,
|
||||
payload: ProjectCopyRequest,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_editor),
|
||||
) -> Project:
|
||||
"""Copy a project. Reset copies media/frame sequence; full also copies annotations and mask metadata."""
|
||||
source = db.query(Project).filter(
|
||||
Project.id == project_id,
|
||||
Project.owner_user_id == current_user.id,
|
||||
).first()
|
||||
if not source:
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
|
||||
next_name = (payload.name or "").strip() if payload.name is not None else ""
|
||||
if not next_name:
|
||||
next_name = _next_project_copy_name(db, current_user.id, source.name)
|
||||
|
||||
copied = Project(
|
||||
name=next_name,
|
||||
description=source.description,
|
||||
video_path=source.video_path,
|
||||
thumbnail_url=source.thumbnail_url,
|
||||
status=source.status,
|
||||
source_type=source.source_type,
|
||||
original_fps=source.original_fps,
|
||||
parse_fps=source.parse_fps,
|
||||
owner_user_id=current_user.id,
|
||||
)
|
||||
db.add(copied)
|
||||
db.flush()
|
||||
|
||||
frame_id_map: dict[int, int] = {}
|
||||
for frame in sorted(source.frames, key=lambda item: item.frame_index):
|
||||
copied_frame = Frame(
|
||||
project_id=copied.id,
|
||||
frame_index=frame.frame_index,
|
||||
image_url=frame.image_url,
|
||||
width=frame.width,
|
||||
height=frame.height,
|
||||
timestamp_ms=frame.timestamp_ms,
|
||||
source_frame_number=frame.source_frame_number,
|
||||
)
|
||||
db.add(copied_frame)
|
||||
db.flush()
|
||||
frame_id_map[frame.id] = copied_frame.id
|
||||
|
||||
if payload.mode == "full":
|
||||
for annotation in sorted(source.annotations, key=lambda item: item.id):
|
||||
copied_annotation = Annotation(
|
||||
project_id=copied.id,
|
||||
frame_id=frame_id_map.get(annotation.frame_id) if annotation.frame_id is not None else None,
|
||||
template_id=annotation.template_id,
|
||||
mask_data=annotation.mask_data,
|
||||
points=annotation.points,
|
||||
bbox=annotation.bbox,
|
||||
)
|
||||
db.add(copied_annotation)
|
||||
db.flush()
|
||||
for mask in annotation.masks:
|
||||
db.add(Mask(
|
||||
annotation_id=copied_annotation.id,
|
||||
mask_url=mask.mask_url,
|
||||
format=mask.format,
|
||||
))
|
||||
|
||||
db.commit()
|
||||
db.refresh(copied)
|
||||
logger.info("Copied project id=%s to id=%s mode=%s", project_id, copied.id, payload.mode)
|
||||
return _prepare_project_response(copied)
|
||||
|
||||
|
||||
@router.patch(
|
||||
@@ -108,6 +204,10 @@ def update_project(
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
|
||||
for key, value in payload.model_dump(exclude_unset=True).items():
|
||||
if key == "name":
|
||||
value = (value or "").strip()
|
||||
if not value:
|
||||
raise HTTPException(status_code=400, detail="Project name is required")
|
||||
setattr(project, key, value)
|
||||
|
||||
db.commit()
|
||||
|
||||
@@ -14,15 +14,38 @@ from schemas import TemplateCreate, TemplateOut, TemplateUpdate
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter(prefix="/api/templates", tags=["Templates"])
|
||||
RESERVED_UNCLASSIFIED_CLASS = {
|
||||
"id": "reserved-unclassified",
|
||||
"name": "待分类",
|
||||
"color": "#000000",
|
||||
"zIndex": 0,
|
||||
"maskId": 0,
|
||||
"category": "系统保留",
|
||||
}
|
||||
|
||||
|
||||
def _is_reserved_class(item: dict) -> bool:
|
||||
return (
|
||||
item.get("id") == RESERVED_UNCLASSIFIED_CLASS["id"]
|
||||
or item.get("name") == RESERVED_UNCLASSIFIED_CLASS["name"]
|
||||
or item.get("maskId") == 0
|
||||
)
|
||||
|
||||
|
||||
def _normalize_template_classes(classes: list[dict] | None) -> list[dict]:
|
||||
normalized = [item for item in (classes or []) if not _is_reserved_class(item)]
|
||||
return [*normalized, dict(RESERVED_UNCLASSIFIED_CLASS)]
|
||||
|
||||
|
||||
def _pack_mapping_rules(data: dict) -> dict:
|
||||
"""Pack classes/rules into mapping_rules for DB storage."""
|
||||
mapping = data.get("mapping_rules") or {}
|
||||
if "classes" in data and data["classes"] is not None:
|
||||
mapping["classes"] = data.pop("classes")
|
||||
mapping["classes"] = _normalize_template_classes(data.pop("classes"))
|
||||
if "rules" in data and data["rules"] is not None:
|
||||
mapping["rules"] = data.pop("rules")
|
||||
if "classes" in mapping:
|
||||
mapping["classes"] = _normalize_template_classes(mapping.get("classes"))
|
||||
data["mapping_rules"] = mapping
|
||||
return data
|
||||
|
||||
@@ -31,7 +54,7 @@ def _unpack_template(template: Template) -> Template:
|
||||
"""Unpack mapping_rules into classes/rules for response."""
|
||||
mapping = template.mapping_rules or {}
|
||||
# Set as attributes so Pydantic from_attributes can pick them up
|
||||
template.classes = mapping.get("classes", [])
|
||||
template.classes = _normalize_template_classes(mapping.get("classes", []))
|
||||
template.rules = mapping.get("rules", [])
|
||||
return template
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
"""Pydantic schemas for request/response validation."""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Optional, Any
|
||||
from typing import Literal, Optional, Any
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
|
||||
@@ -83,6 +83,11 @@ class ProjectUpdate(BaseModel):
|
||||
parse_fps: Optional[float] = None
|
||||
|
||||
|
||||
class ProjectCopyRequest(BaseModel):
|
||||
mode: Literal["reset", "full"] = "reset"
|
||||
name: Optional[str] = None
|
||||
|
||||
|
||||
class ProjectOut(ProjectBase):
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
@@ -96,6 +101,7 @@ class ProjectOut(ProjectBase):
|
||||
class DemoFactoryResetOut(BaseModel):
|
||||
admin_user: UserOut
|
||||
project: ProjectOut
|
||||
projects: list[ProjectOut]
|
||||
deleted_counts: dict[str, int]
|
||||
message: str
|
||||
|
||||
|
||||
128
backend/services/demo_media.py
Normal file
128
backend/services/demo_media.py
Normal file
@@ -0,0 +1,128 @@
|
||||
"""Helpers for seeding the bundled demo media project."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import shutil
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
import cv2
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from minio_client import upload_file
|
||||
from models import Frame, Project, User
|
||||
from services.frame_parser import natural_filename_key, parse_dicom, upload_frames_to_minio
|
||||
from statuses import PROJECT_STATUS_PENDING, PROJECT_STATUS_READY
|
||||
|
||||
DEMO_DICOM_PROJECT_NAME = "演示DICOM序列"
|
||||
DEMO_DICOM_PARSE_FPS = 30.0
|
||||
DEMO_VIDEO_PROJECT_NAME = "Data_MyVideo_1"
|
||||
|
||||
|
||||
def demo_dicom_files(dicom_dir: str) -> list[Path]:
|
||||
"""Return .dcm files in natural file-name order."""
|
||||
root = Path(dicom_dir)
|
||||
if not root.exists() or not root.is_dir():
|
||||
return []
|
||||
return sorted(
|
||||
[path for path in root.iterdir() if path.is_file() and path.name.lower().endswith(".dcm")],
|
||||
key=lambda path: natural_filename_key(path.name),
|
||||
)
|
||||
|
||||
|
||||
def create_unparsed_video_demo_project(
|
||||
db: Session,
|
||||
*,
|
||||
owner: User,
|
||||
video_path: str,
|
||||
project_name: str = DEMO_VIDEO_PROJECT_NAME,
|
||||
) -> Project:
|
||||
"""Create the bundled demo video project without extracting frames."""
|
||||
source = Path(video_path)
|
||||
if not source.exists() or not source.is_file():
|
||||
raise FileNotFoundError(f"Demo video not found: {video_path}")
|
||||
|
||||
project = Project(
|
||||
name=project_name,
|
||||
description="默认演示视频,尚未生成帧",
|
||||
status=PROJECT_STATUS_PENDING,
|
||||
source_type="video",
|
||||
parse_fps=30.0,
|
||||
original_fps=None,
|
||||
owner_user_id=owner.id,
|
||||
)
|
||||
db.add(project)
|
||||
db.flush()
|
||||
|
||||
data = source.read_bytes()
|
||||
object_name = f"uploads/{project.id}/{source.name}"
|
||||
upload_file(object_name, data, content_type="video/mp4", length=len(data))
|
||||
project.video_path = object_name
|
||||
project.thumbnail_url = None
|
||||
db.commit()
|
||||
db.refresh(project)
|
||||
return project
|
||||
|
||||
|
||||
def create_parsed_dicom_demo_project(
|
||||
db: Session,
|
||||
*,
|
||||
owner: User,
|
||||
dicom_dir: str,
|
||||
project_name: str = DEMO_DICOM_PROJECT_NAME,
|
||||
) -> Project:
|
||||
"""Create the demo DICOM project, upload the series, and register parsed frames."""
|
||||
dcm_files = demo_dicom_files(dicom_dir)
|
||||
if not dcm_files:
|
||||
raise FileNotFoundError(f"Demo DICOM series not found: {dicom_dir}")
|
||||
|
||||
project = Project(
|
||||
name=project_name,
|
||||
description=f"默认演示 DICOM 序列,已按文件名自然顺序生成 {len(dcm_files)} 帧",
|
||||
status=PROJECT_STATUS_PENDING,
|
||||
source_type="dicom",
|
||||
parse_fps=DEMO_DICOM_PARSE_FPS,
|
||||
original_fps=None,
|
||||
owner_user_id=owner.id,
|
||||
)
|
||||
db.add(project)
|
||||
db.flush()
|
||||
|
||||
dicom_prefix = f"uploads/{project.id}/dicom"
|
||||
for dcm_file in dcm_files:
|
||||
data = dcm_file.read_bytes()
|
||||
upload_file(
|
||||
f"{dicom_prefix}/{dcm_file.name}",
|
||||
data,
|
||||
content_type="application/dicom",
|
||||
length=len(data),
|
||||
)
|
||||
project.video_path = dicom_prefix
|
||||
|
||||
tmp_dir = tempfile.mkdtemp(prefix=f"seg_demo_dicom_{project.id}_")
|
||||
try:
|
||||
output_dir = os.path.join(tmp_dir, "frames")
|
||||
frame_files = parse_dicom(dicom_dir, output_dir)
|
||||
object_names = upload_frames_to_minio(frame_files, project.id)
|
||||
|
||||
for idx, obj_name in enumerate(object_names):
|
||||
image = cv2.imread(frame_files[idx])
|
||||
height, width = image.shape[:2] if image is not None else (None, None)
|
||||
db.add(Frame(
|
||||
project_id=project.id,
|
||||
frame_index=idx,
|
||||
image_url=obj_name,
|
||||
width=width,
|
||||
height=height,
|
||||
timestamp_ms=idx * 1000.0 / DEMO_DICOM_PARSE_FPS,
|
||||
source_frame_number=idx,
|
||||
))
|
||||
if object_names:
|
||||
project.thumbnail_url = object_names[0]
|
||||
project.status = PROJECT_STATUS_READY
|
||||
db.commit()
|
||||
db.refresh(project)
|
||||
return project
|
||||
finally:
|
||||
shutil.rmtree(tmp_dir, ignore_errors=True)
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
@@ -16,6 +17,14 @@ from minio_client import upload_file, BUCKET_NAME
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def natural_filename_key(filename: str) -> Tuple[object, ...]:
|
||||
"""Sort file names by their visible numeric order instead of pure lexicographic order."""
|
||||
return tuple(
|
||||
int(part) if part.isdigit() else part.casefold()
|
||||
for part in re.split(r"(\d+)", Path(filename).name)
|
||||
)
|
||||
|
||||
|
||||
def get_video_fps(video_path: str) -> float:
|
||||
"""Read the original frame rate of a video file."""
|
||||
cap = cv2.VideoCapture(video_path)
|
||||
@@ -150,7 +159,8 @@ def parse_dicom(
|
||||
"""
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
dcm_files = sorted(
|
||||
[f for f in os.listdir(dicom_dir) if f.lower().endswith(".dcm")]
|
||||
[f for f in os.listdir(dicom_dir) if f.lower().endswith(".dcm")],
|
||||
key=natural_filename_key,
|
||||
)
|
||||
|
||||
frame_paths: List[str] = []
|
||||
|
||||
@@ -15,6 +15,7 @@ from models import Frame, ProcessingTask, Project
|
||||
from progress_events import publish_task_progress_event
|
||||
from services.frame_parser import (
|
||||
extract_thumbnail,
|
||||
natural_filename_key,
|
||||
parse_dicom,
|
||||
parse_video,
|
||||
upload_frames_to_minio,
|
||||
@@ -188,7 +189,10 @@ def run_parse_media_task(db: Session, task_id: int) -> dict[str, Any]:
|
||||
os.makedirs(dcm_dir, exist_ok=True)
|
||||
|
||||
client = get_minio_client()
|
||||
objects = list(client.list_objects(BUCKET_NAME, prefix=project.video_path, recursive=True))
|
||||
objects = sorted(
|
||||
list(client.list_objects(BUCKET_NAME, prefix=project.video_path, recursive=True)),
|
||||
key=lambda obj: natural_filename_key(obj.object_name),
|
||||
)
|
||||
for obj in objects:
|
||||
_ensure_not_cancelled(db, task)
|
||||
if obj.object_name.lower().endswith(".dcm"):
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from models import Annotation, AuditLog, Frame, Mask, ProcessingTask, Project, Template, User
|
||||
from routers.auth import create_access_token, hash_password
|
||||
from statuses import PROJECT_STATUS_PENDING
|
||||
from statuses import PROJECT_STATUS_READY
|
||||
|
||||
|
||||
def test_admin_user_management_and_audit_logs(client, db_session):
|
||||
@@ -83,17 +83,34 @@ def test_admin_cannot_delete_self_or_user_with_projects(client, db_session):
|
||||
assert response.status_code == 409
|
||||
|
||||
|
||||
def test_demo_factory_reset_leaves_admin_and_unparsed_demo_video(client, db_session, monkeypatch, tmp_path):
|
||||
def test_demo_factory_reset_leaves_admin_and_parsed_demo_dicom(client, db_session, monkeypatch, tmp_path):
|
||||
video_path = tmp_path / "Data_MyVideo_1.mp4"
|
||||
video_path.write_bytes(b"demo-video")
|
||||
monkeypatch.setattr("routers.admin.settings.demo_video_path", str(video_path))
|
||||
dicom_dir = tmp_path / "dicom"
|
||||
dicom_dir.mkdir()
|
||||
for name in ["10.dcm", "2.dcm", "1.dcm"]:
|
||||
(dicom_dir / name).write_bytes(name.encode())
|
||||
monkeypatch.setattr("routers.admin.settings.demo_dicom_dir", str(dicom_dir))
|
||||
|
||||
parsed_frame_paths = []
|
||||
for idx in range(3):
|
||||
frame_path = tmp_path / f"frame_{idx:06d}.jpg"
|
||||
frame_path.write_bytes(b"frame")
|
||||
parsed_frame_paths.append(str(frame_path))
|
||||
|
||||
uploaded = []
|
||||
monkeypatch.setattr("routers.admin.upload_file", lambda object_name, data, content_type, length: uploaded.append({
|
||||
monkeypatch.setattr("services.demo_media.upload_file", lambda object_name, data, content_type, length: uploaded.append({
|
||||
"object_name": object_name,
|
||||
"data": data,
|
||||
"content_type": content_type,
|
||||
"length": length,
|
||||
}))
|
||||
monkeypatch.setattr("services.demo_media.parse_dicom", lambda dicom_dir_arg, output_dir: parsed_frame_paths)
|
||||
monkeypatch.setattr(
|
||||
"services.demo_media.upload_frames_to_minio",
|
||||
lambda frame_files, project_id: [f"projects/{project_id}/frames/frame_{idx:06d}.jpg" for idx, _ in enumerate(frame_files)],
|
||||
)
|
||||
|
||||
extra_user = User(username="doctor", password_hash=hash_password("secret123"), role="annotator", is_active=1)
|
||||
db_session.add(extra_user)
|
||||
@@ -113,7 +130,15 @@ def test_demo_factory_reset_leaves_admin_and_unparsed_demo_video(client, db_sess
|
||||
z_index=1,
|
||||
owner_user_id=extra_user.id,
|
||||
)
|
||||
db_session.add_all([task, private_template])
|
||||
system_template = Template(
|
||||
name="头颈部CT分割",
|
||||
description="头颈部CT分割",
|
||||
color="#ef4444",
|
||||
z_index=10,
|
||||
owner_user_id=None,
|
||||
mapping_rules={"classes": [{"name": "肿瘤/结节 (Tumor/Nodule)", "color": "#ff0000", "maskId": 1}], "rules": []},
|
||||
)
|
||||
db_session.add_all([task, private_template, system_template])
|
||||
db_session.commit()
|
||||
db_session.refresh(frame)
|
||||
annotation = Annotation(project_id=old_project.id, frame_id=frame.id, mask_data={"label": "old"})
|
||||
@@ -130,24 +155,36 @@ def test_demo_factory_reset_leaves_admin_and_unparsed_demo_video(client, db_sess
|
||||
data = response.json()
|
||||
assert data["message"] == "演示环境已恢复出厂设置"
|
||||
assert data["admin_user"]["username"] == "admin"
|
||||
assert data["project"]["name"] == "Data_MyVideo_1"
|
||||
assert data["project"]["status"] == PROJECT_STATUS_PENDING
|
||||
assert data["project"]["frame_count"] == 0
|
||||
assert data["project"]["video_path"] == f"uploads/{data['project']['id']}/Data_MyVideo_1.mp4"
|
||||
assert uploaded == [{
|
||||
"object_name": data["project"]["video_path"],
|
||||
"data": b"demo-video",
|
||||
"content_type": "video/mp4",
|
||||
"length": len(b"demo-video"),
|
||||
}]
|
||||
assert data["project"]["name"] == "演示DICOM序列"
|
||||
assert data["project"]["status"] == PROJECT_STATUS_READY
|
||||
assert data["project"]["source_type"] == "dicom"
|
||||
assert data["project"]["frame_count"] == 3
|
||||
assert data["project"]["video_path"] == f"uploads/{data['project']['id']}/dicom"
|
||||
assert [project["name"] for project in data["projects"]] == ["Data_MyVideo_1", "演示DICOM序列"]
|
||||
assert data["projects"][0]["status"] == "pending"
|
||||
assert data["projects"][0]["source_type"] == "video"
|
||||
assert data["projects"][0]["frame_count"] == 0
|
||||
assert data["projects"][1]["status"] == PROJECT_STATUS_READY
|
||||
assert data["projects"][1]["source_type"] == "dicom"
|
||||
assert data["projects"][1]["frame_count"] == 3
|
||||
assert [item["object_name"] for item in uploaded] == [
|
||||
f"uploads/{data['projects'][0]['id']}/Data_MyVideo_1.mp4",
|
||||
f"uploads/{data['project']['id']}/dicom/1.dcm",
|
||||
f"uploads/{data['project']['id']}/dicom/2.dcm",
|
||||
f"uploads/{data['project']['id']}/dicom/10.dcm",
|
||||
]
|
||||
assert [item["content_type"] for item in uploaded] == ["video/mp4", "application/dicom", "application/dicom", "application/dicom"]
|
||||
|
||||
assert [user.username for user in db_session.query(User).all()] == ["admin"]
|
||||
assert db_session.query(Project).count() == 1
|
||||
assert db_session.query(Frame).count() == 0
|
||||
assert db_session.query(Project).count() == 2
|
||||
assert db_session.query(Frame).count() == 3
|
||||
assert [frame.source_frame_number for frame in db_session.query(Frame).order_by(Frame.frame_index).all()] == [0, 1, 2]
|
||||
assert db_session.query(Annotation).count() == 0
|
||||
assert db_session.query(Mask).count() == 0
|
||||
assert db_session.query(ProcessingTask).count() == 0
|
||||
assert db_session.query(Template).filter(Template.owner_user_id.is_not(None)).count() == 0
|
||||
preserved_templates = db_session.query(Template).filter(Template.owner_user_id.is_(None)).all()
|
||||
assert [template.name for template in preserved_templates] == ["头颈部CT分割"]
|
||||
assert db_session.query(AuditLog).count() == 1
|
||||
assert db_session.query(AuditLog).first().action == "admin.demo_factory_reset"
|
||||
|
||||
|
||||
@@ -1149,6 +1149,81 @@ def test_import_gt_mask_creates_annotations_with_seed_points(client):
|
||||
assert 0.0 <= body[0]["points"][0][1] <= 1.0
|
||||
|
||||
|
||||
def test_import_gt_mask_polygons_work_with_analysis_and_smoothing(client):
|
||||
project, frame, _ = _create_project_and_frame(client)
|
||||
mask = np.zeros((360, 640), dtype=np.uint8)
|
||||
cv2.ellipse(mask, (260, 160), (130, 70), 20, 0, 360, 1, thickness=-1)
|
||||
ok, encoded = cv2.imencode(".png", mask)
|
||||
assert ok
|
||||
|
||||
response = client.post(
|
||||
"/api/ai/import-gt-mask",
|
||||
data={
|
||||
"project_id": str(project["id"]),
|
||||
"frame_id": str(frame["id"]),
|
||||
"label": "Imported GT",
|
||||
"color": "#22c55e",
|
||||
},
|
||||
files={"file": ("mask.png", encoded.tobytes(), "image/png")},
|
||||
)
|
||||
|
||||
assert response.status_code == 201
|
||||
annotation = response.json()[0]
|
||||
assert annotation["mask_data"]["source"] == "gt_mask"
|
||||
|
||||
analysis = client.post("/api/ai/analyze-mask", json={
|
||||
"frame_id": frame["id"],
|
||||
"mask_data": annotation["mask_data"],
|
||||
"points": annotation["points"],
|
||||
"bbox": annotation["bbox"],
|
||||
})
|
||||
assert analysis.status_code == 200
|
||||
assert analysis.json()["topology_anchor_count"] == len(annotation["mask_data"]["polygons"][0])
|
||||
|
||||
smoothing = client.post("/api/ai/smooth-mask", json={
|
||||
"frame_id": frame["id"],
|
||||
"mask_data": annotation["mask_data"],
|
||||
"points": annotation["points"],
|
||||
"bbox": annotation["bbox"],
|
||||
"strength": 35,
|
||||
})
|
||||
assert smoothing.status_code == 200
|
||||
assert smoothing.json()["topology_anchor_count"] == len(smoothing.json()["polygons"][0])
|
||||
|
||||
|
||||
def test_import_gt_mask_preserves_detailed_contours(client):
|
||||
project, frame, _ = _create_project_and_frame(client)
|
||||
mask = np.zeros((360, 640), dtype=np.uint8)
|
||||
center = np.array([320, 180])
|
||||
vertices = []
|
||||
for index in range(96):
|
||||
angle = 2 * np.pi * index / 96
|
||||
radius = 120 if index % 2 == 0 else 88
|
||||
vertices.append([
|
||||
int(center[0] + np.cos(angle) * radius),
|
||||
int(center[1] + np.sin(angle) * radius),
|
||||
])
|
||||
cv2.fillPoly(mask, [np.array(vertices, dtype=np.int32)], 1)
|
||||
ok, encoded = cv2.imencode(".png", mask)
|
||||
assert ok
|
||||
|
||||
response = client.post(
|
||||
"/api/ai/import-gt-mask",
|
||||
data={
|
||||
"project_id": str(project["id"]),
|
||||
"frame_id": str(frame["id"]),
|
||||
"label": "Detailed GT",
|
||||
"color": "#22c55e",
|
||||
},
|
||||
files={"file": ("mask.png", encoded.tobytes(), "image/png")},
|
||||
)
|
||||
|
||||
assert response.status_code == 201
|
||||
polygon = response.json()[0]["mask_data"]["polygons"][0]
|
||||
assert len(polygon) > 80
|
||||
assert len(polygon) <= 2048
|
||||
|
||||
|
||||
def test_import_gt_mask_splits_label_values(client):
|
||||
project, frame, _ = _create_project_and_frame(client)
|
||||
mask = np.zeros((360, 640), dtype=np.uint8)
|
||||
@@ -1174,7 +1249,27 @@ def test_import_gt_mask_splits_label_values(client):
|
||||
assert all(len(item["points"]) == 1 for item in body)
|
||||
|
||||
|
||||
def test_import_gt_mask_preserves_low_value_gtlabel_png(client):
|
||||
def test_import_gt_mask_rejects_background_only_label_image(client):
|
||||
project, frame, _ = _create_project_and_frame(client)
|
||||
mask = np.zeros((360, 640), dtype=np.uint8)
|
||||
ok, encoded = cv2.imencode(".png", mask)
|
||||
assert ok
|
||||
|
||||
response = client.post(
|
||||
"/api/ai/import-gt-mask",
|
||||
data={
|
||||
"project_id": str(project["id"]),
|
||||
"frame_id": str(frame["id"]),
|
||||
"label": "GT Class",
|
||||
},
|
||||
files={"file": ("empty-gt-label.png", encoded.tobytes(), "image/png")},
|
||||
)
|
||||
|
||||
assert response.status_code == 400
|
||||
assert response.json()["detail"] == "GT Mask 图片中没有非背景 maskid 区域。"
|
||||
|
||||
|
||||
def test_import_gt_mask_accepts_uint8_low_value_gtlabel_png(client):
|
||||
project, frame, _ = _create_project_and_frame(client)
|
||||
template = client.post("/api/templates", json={
|
||||
"name": "GTLabel Template",
|
||||
@@ -1185,7 +1280,7 @@ def test_import_gt_mask_preserves_low_value_gtlabel_png(client):
|
||||
],
|
||||
"rules": [],
|
||||
}).json()
|
||||
mask = np.zeros((360, 640), dtype=np.uint16)
|
||||
mask = np.zeros((360, 640), dtype=np.uint8)
|
||||
cv2.rectangle(mask, (40, 40), (140, 140), 1, thickness=-1)
|
||||
ok, encoded = cv2.imencode(".png", mask)
|
||||
assert ok
|
||||
@@ -1241,7 +1336,7 @@ def test_import_gt_mask_rejects_rgb_color_masks(client):
|
||||
assert "RGB 三通道完全相同" in response.json()["detail"]
|
||||
|
||||
|
||||
def test_import_gt_mask_reads_uint16_gt_label_and_maps_maskid_class(client):
|
||||
def test_import_gt_mask_rejects_uint16_gt_label(client):
|
||||
project, frame, _ = _create_project_and_frame(client)
|
||||
template = client.post("/api/templates", json={
|
||||
"name": "Label Template",
|
||||
@@ -1266,13 +1361,8 @@ def test_import_gt_mask_reads_uint16_gt_label_and_maps_maskid_class(client):
|
||||
files={"file": ("gt_label.png", encoded.tobytes(), "image/png")},
|
||||
)
|
||||
|
||||
assert response.status_code == 201
|
||||
body = response.json()
|
||||
assert len(body) == 1
|
||||
assert body[0]["mask_data"]["gt_label_value"] == 1
|
||||
assert body[0]["mask_data"]["label"] == "肿瘤"
|
||||
assert body[0]["mask_data"]["class"]["maskId"] == 1
|
||||
assert body[0]["mask_data"]["class"]["color"] == "#ff0000"
|
||||
assert response.status_code == 400
|
||||
assert "仅支持 8-bit" in response.json()["detail"]
|
||||
|
||||
|
||||
def test_import_gt_mask_handles_unknown_maskid_policy_and_resizes_to_frame(client):
|
||||
|
||||
@@ -169,6 +169,7 @@ def test_export_results_zip_contains_coco_original_images_and_selected_mask_outp
|
||||
"key": f"template:{annotation['template_id']}",
|
||||
"template_id": annotation["template_id"],
|
||||
}]
|
||||
assert gt_label.dtype == np.uint8
|
||||
assert gt_label[0, 0] == 0
|
||||
assert gt_label[20, 50] == 1
|
||||
assert pro_label[20, 50].tolist() == [212, 182, 6]
|
||||
@@ -234,6 +235,7 @@ def test_export_results_uses_internal_layer_order_for_gt_pro_and_mix_outputs(cli
|
||||
cv2.IMREAD_COLOR,
|
||||
)
|
||||
|
||||
assert gt_label.dtype == np.uint8
|
||||
assert gt_label[10, 10] == high_value
|
||||
assert pro_label[10, 10].tolist() == [0, 0, 255]
|
||||
assert mix_label[10, 10].tolist() == [127, 127, 255]
|
||||
@@ -365,10 +367,74 @@ def test_export_results_preserves_template_maskid_consistently_across_frames(cli
|
||||
"key": "class:tumor",
|
||||
"template_id": None,
|
||||
}]
|
||||
assert first_label.dtype == np.uint8
|
||||
assert second_label.dtype == np.uint8
|
||||
assert first_label[5, 5] == 7
|
||||
assert second_label[5, 5] == 7
|
||||
|
||||
|
||||
def test_export_results_keeps_unclassified_maskid_zero_black_in_gt_and_pro(client, monkeypatch):
|
||||
monkeypatch.setattr("routers.export.download_file", lambda object_name: _fake_image_bytes(20, 20))
|
||||
project = client.post("/api/projects", json={
|
||||
"name": "Unclassified Export Project",
|
||||
"video_path": "uploads/8/unclassified.mp4",
|
||||
}).json()
|
||||
frame = client.post(f"/api/projects/{project['id']}/frames", json={
|
||||
"project_id": project["id"],
|
||||
"frame_index": 0,
|
||||
"image_url": "frames/source.jpg",
|
||||
"width": 20,
|
||||
"height": 20,
|
||||
"timestamp_ms": 0,
|
||||
}).json()
|
||||
client.post("/api/ai/annotate", json={
|
||||
"project_id": project["id"],
|
||||
"frame_id": frame["id"],
|
||||
"mask_data": {
|
||||
"polygons": [[[0.1, 0.1], [0.8, 0.1], [0.8, 0.8], [0.1, 0.8]]],
|
||||
"label": "待分类",
|
||||
"color": "#000000",
|
||||
"class": {
|
||||
"id": "reserved-unclassified",
|
||||
"name": "待分类",
|
||||
"color": "#000000",
|
||||
"maskId": 0,
|
||||
"zIndex": 0,
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
response = client.get(f"/api/export/{project['id']}/results?scope=all&outputs=gt_label,pro_label")
|
||||
|
||||
assert response.status_code == 200
|
||||
with zipfile.ZipFile(BytesIO(response.content)) as archive:
|
||||
mapping = json.loads(archive.read("maskid_GT像素值_类别映射.json"))
|
||||
stem = "unclassified_0h00m00s000ms_frame000001"
|
||||
gt_label = cv2.imdecode(
|
||||
np.frombuffer(archive.read(f"GT_label图/{stem}.png"), dtype=np.uint8),
|
||||
cv2.IMREAD_UNCHANGED,
|
||||
)
|
||||
pro_label = cv2.imdecode(
|
||||
np.frombuffer(archive.read(f"Pro_label彩色分割结果/{stem}.png"), dtype=np.uint8),
|
||||
cv2.IMREAD_COLOR,
|
||||
)
|
||||
|
||||
assert mapping["classes"] == [{
|
||||
"gt_pixel_value": 0,
|
||||
"maskid": 0,
|
||||
"chineseName": "待分类",
|
||||
"className": "待分类",
|
||||
"categoryName": "",
|
||||
"rgb": [0, 0, 0],
|
||||
"color": "#000000",
|
||||
"key": "class:reserved-unclassified",
|
||||
"template_id": None,
|
||||
}]
|
||||
assert gt_label.dtype == np.uint8
|
||||
assert gt_label[5, 5] == 0
|
||||
assert pro_label[5, 5].tolist() == [0, 0, 0]
|
||||
|
||||
|
||||
def test_exported_gtlabel_round_trips_through_gt_mask_import_with_template_maskid(client, monkeypatch):
|
||||
monkeypatch.setattr("routers.export.download_file", lambda object_name: _fake_image_bytes(20, 20))
|
||||
project = client.post("/api/projects", json={
|
||||
@@ -423,6 +489,7 @@ def test_exported_gtlabel_round_trips_through_gt_mask_import_with_template_maski
|
||||
gt_label = cv2.imdecode(np.frombuffer(exported_gt_label, dtype=np.uint8), cv2.IMREAD_UNCHANGED)
|
||||
mapping = json.loads(archive.read("maskid_GT像素值_类别映射.json"))
|
||||
|
||||
assert gt_label.dtype == np.uint8
|
||||
assert gt_label[5, 5] == 7
|
||||
assert mapping["classes"][0]["maskid"] == 7
|
||||
|
||||
@@ -446,6 +513,36 @@ def test_exported_gtlabel_round_trips_through_gt_mask_import_with_template_maski
|
||||
assert imported[0]["mask_data"]["class"]["maskId"] == 7
|
||||
|
||||
|
||||
def test_export_results_rejects_gtlabel_maskid_outside_uint8_range(client, monkeypatch):
|
||||
monkeypatch.setattr("routers.export.download_file", lambda object_name: _fake_image_bytes(20, 20))
|
||||
project = client.post("/api/projects", json={
|
||||
"name": "Large MaskId Project",
|
||||
"video_path": "uploads/8/large-maskid.mp4",
|
||||
}).json()
|
||||
frame = client.post(f"/api/projects/{project['id']}/frames", json={
|
||||
"project_id": project["id"],
|
||||
"frame_index": 0,
|
||||
"image_url": "frames/source.jpg",
|
||||
"width": 20,
|
||||
"height": 20,
|
||||
}).json()
|
||||
client.post("/api/ai/annotate", json={
|
||||
"project_id": project["id"],
|
||||
"frame_id": frame["id"],
|
||||
"mask_data": {
|
||||
"polygons": [[[0.1, 0.1], [0.8, 0.1], [0.8, 0.8], [0.1, 0.8]]],
|
||||
"label": "TooLarge",
|
||||
"color": "#ff0000",
|
||||
"class": {"id": "too-large", "name": "TooLarge", "color": "#ff0000", "maskId": 300, "zIndex": 30},
|
||||
},
|
||||
})
|
||||
|
||||
response = client.get(f"/api/export/{project['id']}/results?scope=all&outputs=gt_label")
|
||||
|
||||
assert response.status_code == 400
|
||||
assert "8-bit maskid" in response.json()["detail"]
|
||||
|
||||
|
||||
def test_export_missing_project_returns_404(client):
|
||||
assert client.get("/api/export/999/coco").status_code == 404
|
||||
assert client.get("/api/export/999/masks").status_code == 404
|
||||
|
||||
@@ -48,15 +48,37 @@ def test_upload_dicom_batch_filters_files_and_creates_project(client, monkeypatc
|
||||
response = client.post(
|
||||
"/api/media/upload/dicom",
|
||||
files=[
|
||||
("files", ("a.dcm", b"dcm", "application/dicom")),
|
||||
("files", ("10.dcm", b"dcm10", "application/dicom")),
|
||||
("files", ("skip.txt", b"text", "text/plain")),
|
||||
("files", ("2.dcm", b"dcm2", "application/dicom")),
|
||||
("files", ("1.dcm", b"dcm1", "application/dicom")),
|
||||
],
|
||||
)
|
||||
|
||||
assert response.status_code == 201
|
||||
data = response.json()
|
||||
assert data["uploaded_count"] == 1
|
||||
assert uploaded == [f"uploads/{data['project_id']}/dicom/a.dcm"]
|
||||
assert data["uploaded_count"] == 3
|
||||
assert uploaded == [
|
||||
f"uploads/{data['project_id']}/dicom/1.dcm",
|
||||
f"uploads/{data['project_id']}/dicom/2.dcm",
|
||||
f"uploads/{data['project_id']}/dicom/10.dcm",
|
||||
]
|
||||
project_detail = client.get(f"/api/projects/{data['project_id']}").json()
|
||||
assert project_detail["name"] == "1.dcm"
|
||||
|
||||
|
||||
def test_upload_dicom_batch_rejects_when_no_valid_dicom(client, monkeypatch):
|
||||
monkeypatch.setattr("routers.media.upload_file", lambda *args, **kwargs: None)
|
||||
|
||||
response = client.post(
|
||||
"/api/media/upload/dicom",
|
||||
files=[
|
||||
("files", ("notes.txt", b"text", "text/plain")),
|
||||
],
|
||||
)
|
||||
|
||||
assert response.status_code == 400
|
||||
assert response.json()["detail"] == "No valid DICOM files uploaded"
|
||||
|
||||
|
||||
def test_parse_media_queues_background_task(client, monkeypatch):
|
||||
@@ -194,6 +216,101 @@ def test_parse_task_runner_registers_frames(client, db_session, monkeypatch, tmp
|
||||
assert frames[0]["source_frame_number"] == 0
|
||||
|
||||
|
||||
def test_parse_dicom_reads_files_in_natural_filename_order(monkeypatch, tmp_path):
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
|
||||
from services.frame_parser import parse_dicom
|
||||
|
||||
dcm_dir = tmp_path / "dcm"
|
||||
output_dir = tmp_path / "frames"
|
||||
dcm_dir.mkdir()
|
||||
for name in ["10.dcm", "2.dcm", "1.dcm"]:
|
||||
(dcm_dir / name).write_bytes(b"dcm")
|
||||
|
||||
read_order = []
|
||||
|
||||
class FakeDicom:
|
||||
pixel_array = np.ones((2, 2), dtype=np.uint8)
|
||||
|
||||
def fake_dcmread(path):
|
||||
read_order.append(Path(path).name)
|
||||
return FakeDicom()
|
||||
|
||||
def fake_imwrite(path, image, params=None):
|
||||
Path(path).write_bytes(image.tobytes())
|
||||
return True
|
||||
|
||||
monkeypatch.setattr("services.frame_parser.dcmread", fake_dcmread)
|
||||
monkeypatch.setattr("services.frame_parser.cv2.imwrite", fake_imwrite)
|
||||
|
||||
frame_files = parse_dicom(str(dcm_dir), str(output_dir))
|
||||
|
||||
assert read_order == ["1.dcm", "2.dcm", "10.dcm"]
|
||||
assert [Path(path).name for path in frame_files] == ["frame_000000.jpg", "frame_000001.jpg", "frame_000002.jpg"]
|
||||
|
||||
|
||||
def test_parse_task_runner_downloads_dicom_objects_in_natural_filename_order(client, db_session, monkeypatch, tmp_path):
|
||||
from types import SimpleNamespace
|
||||
|
||||
from models import ProcessingTask
|
||||
from services.media_task_runner import run_parse_media_task
|
||||
|
||||
project = client.post("/api/projects", json={
|
||||
"name": "DICOM",
|
||||
"video_path": "uploads/1/dicom",
|
||||
"source_type": "dicom",
|
||||
"parse_fps": 30,
|
||||
}).json()
|
||||
task = ProcessingTask(
|
||||
task_type="parse_dicom",
|
||||
status="queued",
|
||||
progress=0,
|
||||
project_id=project["id"],
|
||||
payload={"source_type": "dicom"},
|
||||
)
|
||||
db_session.add(task)
|
||||
db_session.commit()
|
||||
db_session.refresh(task)
|
||||
|
||||
class FakeClient:
|
||||
def list_objects(self, bucket, prefix, recursive=True):
|
||||
return [
|
||||
SimpleNamespace(object_name=f"{prefix}/10.dcm"),
|
||||
SimpleNamespace(object_name=f"{prefix}/2.dcm"),
|
||||
SimpleNamespace(object_name=f"{prefix}/1.dcm"),
|
||||
]
|
||||
|
||||
downloaded = []
|
||||
frame_files = []
|
||||
for idx in range(3):
|
||||
frame_file = tmp_path / f"frame_{idx:06d}.jpg"
|
||||
frame_file.write_bytes(b"fake image")
|
||||
frame_files.append(str(frame_file))
|
||||
|
||||
monkeypatch.setattr("services.media_task_runner.get_minio_client", lambda: FakeClient())
|
||||
monkeypatch.setattr(
|
||||
"services.media_task_runner.download_file",
|
||||
lambda object_name: downloaded.append(object_name) or b"dcm",
|
||||
)
|
||||
monkeypatch.setattr("services.media_task_runner.parse_dicom", lambda *args, **kwargs: frame_files)
|
||||
monkeypatch.setattr(
|
||||
"services.media_task_runner.upload_frames_to_minio",
|
||||
lambda frames, project_id: [f"projects/{project_id}/frames/{idx}.jpg" for idx, _ in enumerate(frames)],
|
||||
)
|
||||
monkeypatch.setattr("services.media_task_runner.publish_task_progress_event", lambda task: None)
|
||||
|
||||
result = run_parse_media_task(db_session, task.id)
|
||||
|
||||
assert result["frames_extracted"] == 3
|
||||
assert downloaded == [
|
||||
"uploads/1/dicom/1.dcm",
|
||||
"uploads/1/dicom/2.dcm",
|
||||
"uploads/1/dicom/10.dcm",
|
||||
]
|
||||
|
||||
|
||||
def test_parse_task_runner_skips_already_cancelled_task(db_session):
|
||||
from models import ProcessingTask
|
||||
from services.media_task_runner import run_parse_media_task
|
||||
|
||||
@@ -42,6 +42,9 @@ def test_project_crud_and_frames(client, monkeypatch):
|
||||
assert updated.json()["name"] == "Renamed"
|
||||
assert updated.json()["status"] == "ready"
|
||||
|
||||
empty_name = client.patch(f"/api/projects/{project_id}", json={"name": " "})
|
||||
assert empty_name.status_code == 400
|
||||
|
||||
deleted = client.delete(f"/api/projects/{project_id}")
|
||||
assert deleted.status_code == 204
|
||||
assert client.get(f"/api/projects/{project_id}").status_code == 404
|
||||
@@ -83,10 +86,97 @@ def test_delete_project_cascades_related_records(client, db_session):
|
||||
assert db_session.query(ProcessingTask).filter(ProcessingTask.project_id == project_id).count() == 0
|
||||
|
||||
|
||||
def test_copy_project_reset_copies_frame_sequence_without_annotations(client, db_session):
|
||||
created = client.post("/api/projects", json={
|
||||
"name": "Reset Source",
|
||||
"description": "desc",
|
||||
"video_path": "uploads/source.mp4",
|
||||
"thumbnail_url": "thumb.jpg",
|
||||
"status": "ready",
|
||||
"parse_fps": 12,
|
||||
})
|
||||
assert created.status_code == 201
|
||||
project_id = created.json()["id"]
|
||||
frame = client.post(f"/api/projects/{project_id}/frames", json={
|
||||
"project_id": project_id,
|
||||
"frame_index": 0,
|
||||
"image_url": "frames/source/frame_000000.jpg",
|
||||
"width": 640,
|
||||
"height": 360,
|
||||
"timestamp_ms": 0,
|
||||
"source_frame_number": 0,
|
||||
})
|
||||
assert frame.status_code == 201
|
||||
annotation = client.post("/api/ai/annotate", json={
|
||||
"project_id": project_id,
|
||||
"frame_id": frame.json()["id"],
|
||||
"mask_data": {"label": "Tumor", "polygons": [[[0.1, 0.1], [0.2, 0.1], [0.2, 0.2]]]},
|
||||
})
|
||||
assert annotation.status_code == 201
|
||||
|
||||
copied = client.post(f"/api/projects/{project_id}/copy", json={"mode": "reset"})
|
||||
assert copied.status_code == 201
|
||||
copied_body = copied.json()
|
||||
assert copied_body["name"] == "Reset Source 副本"
|
||||
assert copied_body["frame_count"] == 1
|
||||
assert copied_body["video_path"] == "uploads/source.mp4"
|
||||
assert copied_body["parse_fps"] == 12
|
||||
|
||||
copied_frames = db_session.query(Frame).filter(Frame.project_id == copied_body["id"]).all()
|
||||
assert len(copied_frames) == 1
|
||||
assert copied_frames[0].image_url == "frames/source/frame_000000.jpg"
|
||||
assert db_session.query(Annotation).filter(Annotation.project_id == copied_body["id"]).count() == 0
|
||||
|
||||
|
||||
def test_copy_project_full_copies_annotations_and_mask_metadata(client, db_session):
|
||||
created = client.post("/api/projects", json={
|
||||
"name": "Full Source",
|
||||
"status": "ready",
|
||||
})
|
||||
assert created.status_code == 201
|
||||
project_id = created.json()["id"]
|
||||
frame = client.post(f"/api/projects/{project_id}/frames", json={
|
||||
"project_id": project_id,
|
||||
"frame_index": 0,
|
||||
"image_url": "frames/source/frame_000000.jpg",
|
||||
"width": 640,
|
||||
"height": 360,
|
||||
})
|
||||
assert frame.status_code == 201
|
||||
frame_id = frame.json()["id"]
|
||||
annotation = client.post("/api/ai/annotate", json={
|
||||
"project_id": project_id,
|
||||
"frame_id": frame_id,
|
||||
"mask_data": {"label": "Tumor", "polygons": [[[0.1, 0.1], [0.2, 0.1], [0.2, 0.2]]]},
|
||||
"points": [[0.1, 0.1]],
|
||||
"bbox": [0.1, 0.1, 0.1, 0.1],
|
||||
})
|
||||
assert annotation.status_code == 201
|
||||
annotation_id = annotation.json()["id"]
|
||||
db_session.add(Mask(annotation_id=annotation_id, mask_url="masks/source.png", format="png"))
|
||||
db_session.commit()
|
||||
|
||||
copied = client.post(f"/api/projects/{project_id}/copy", json={"mode": "full"})
|
||||
assert copied.status_code == 201
|
||||
copied_body = copied.json()
|
||||
copied_frames = db_session.query(Frame).filter(Frame.project_id == copied_body["id"]).all()
|
||||
copied_annotations = db_session.query(Annotation).filter(Annotation.project_id == copied_body["id"]).all()
|
||||
|
||||
assert copied_body["name"] == "Full Source 副本"
|
||||
assert len(copied_frames) == 1
|
||||
assert len(copied_annotations) == 1
|
||||
assert copied_annotations[0].id != annotation_id
|
||||
assert copied_annotations[0].frame_id == copied_frames[0].id
|
||||
assert copied_annotations[0].mask_data["label"] == "Tumor"
|
||||
assert copied_annotations[0].bbox == [0.1, 0.1, 0.1, 0.1]
|
||||
assert copied_annotations[0].masks[0].mask_url == "masks/source.png"
|
||||
|
||||
|
||||
def test_project_and_frame_404s(client):
|
||||
assert client.get("/api/projects/999").status_code == 404
|
||||
assert client.patch("/api/projects/999", json={"name": "x"}).status_code == 404
|
||||
assert client.delete("/api/projects/999").status_code == 404
|
||||
assert client.post("/api/projects/999/copy", json={"mode": "reset"}).status_code == 404
|
||||
assert client.post("/api/projects/999/frames", json={
|
||||
"project_id": 999,
|
||||
"frame_index": 0,
|
||||
|
||||
@@ -37,3 +37,55 @@ def test_template_404s(client):
|
||||
assert client.get("/api/templates/999").status_code == 404
|
||||
assert client.patch("/api/templates/999", json={"name": "x"}).status_code == 404
|
||||
assert client.delete("/api/templates/999").status_code == 404
|
||||
|
||||
|
||||
def test_default_head_neck_ct_template_is_seeded_and_visible(client, db_session):
|
||||
from main import ensure_default_templates
|
||||
from models import Template
|
||||
|
||||
ensure_default_templates(db_session)
|
||||
ensure_default_templates(db_session)
|
||||
|
||||
templates = db_session.query(Template).filter(Template.owner_user_id.is_(None)).all()
|
||||
names = [template.name for template in templates]
|
||||
assert names.count("头颈部CT分割") == 1
|
||||
|
||||
listing = client.get("/api/templates")
|
||||
assert listing.status_code == 200
|
||||
head_neck = next(template for template in listing.json() if template["name"] == "头颈部CT分割")
|
||||
assert head_neck["description"] == "头颈部CT分割"
|
||||
expected_names = [
|
||||
"肿瘤/结节 (Tumor/Nodule)",
|
||||
"下颌骨 (Mandible)",
|
||||
"甲状腺 (Thyroid)",
|
||||
"气管 (Trachea)",
|
||||
"颈椎 (Cervical Spine)",
|
||||
"颈动脉 (Carotid Artery)",
|
||||
"颈静脉 (Jugular Vein)",
|
||||
"腮腺 (Parotid Gland)",
|
||||
"下颌下腺 (Submandibular Gland)",
|
||||
"舌骨 (Hyoid Bone)",
|
||||
"待分类",
|
||||
]
|
||||
expected_colors = [
|
||||
"#ff0000",
|
||||
"#00ff00",
|
||||
"#0000ff",
|
||||
"#ffff00",
|
||||
"#ff00ff",
|
||||
"#00ffff",
|
||||
"#ff8000",
|
||||
"#800080",
|
||||
"#008080",
|
||||
"#808000",
|
||||
"#000000",
|
||||
]
|
||||
actual_names = [item["name"] for item in head_neck["classes"]]
|
||||
actual_colors = [item["color"] for item in head_neck["classes"]]
|
||||
actual_mask_ids = [item["maskId"] for item in head_neck["classes"]]
|
||||
if actual_names != expected_names:
|
||||
raise AssertionError(f"Unexpected head-neck classes: {actual_names}")
|
||||
if actual_colors != expected_colors:
|
||||
raise AssertionError(f"Unexpected head-neck colors: {actual_colors}")
|
||||
if actual_mask_ids != [*list(range(1, 11)), 0]:
|
||||
raise AssertionError(f"Unexpected head-neck mask IDs: {actual_mask_ids}")
|
||||
|
||||
Reference in New Issue
Block a user