完善项目导入、模板与分割工作区交互

- 增强 DICOM/视频项目导入与演示数据:DICOM 按文件名自然顺序处理,导入后展示上传与解析任务进度,恢复演示出厂设置保留演示视频和演示 DICOM 项目,并补充 demo media seed 逻辑。

- 完善项目管理:项目支持重命名、删除、复制,删除使用站内确认弹窗,复制支持新项目重置和全内容复制,DICOM 项目不显示生成帧入口。

- 完善 GT Mask 与导出链路:只支持 8-bit maskid 图导入,非法/全背景图明确拒绝,尺寸自动适配,高精度 polygon 回显;统一导出默认当前帧,GT_label 使用 uint8 和真实 maskid,待分类 maskid 0 与背景一致。

- 完善分割工作区交互:新增画笔和橡皮擦并支持尺寸控制,移除创建点/线段入口,工具栏按类别分隔,AI 智能分割使用明确 AI 图标,取消黄色 seed point,清空/删除传播 mask 后同步清理空帧时间轴状态。

- 完善传播与时间轴:自动传播使用 SAM 2.1 权重任务,参考帧无遮罩时提示,传播历史按同一蓝色系递进变暗,删除/清空传播链时保留人工或独立 AI 标注来源。

- 完善模板库:新增头颈部 CT 分割默认模板,所有模板保留 maskid 0 待分类,支持鼠标复制模板、拖拽层级、JSON 批量导入预览、删除 label 和站内删除确认。

- 完善用户与高风险确认:用户改密码、删除用户、恢复演示出厂设置和清空人工/AI 标注帧均改为站内确认交互,避免浏览器原生 prompt/confirm。

- 补充前后端测试与文档:更新项目、模板、GT 导入、导出、传播、DICOM、用户管理等测试,并同步 README、AGENTS 和 doc 下实现/契约/测试计划文档。
This commit is contained in:
2026-05-03 17:11:59 +08:00
parent afcddfaeb9
commit 481ffa5b67
47 changed files with 3650 additions and 676 deletions

View File

@@ -39,6 +39,7 @@ class Settings(BaseSettings):
default_admin_username: str = "admin"
default_admin_password: str = "123456"
demo_video_path: str = "/home/wkmgc/Desktop/Seg_Server/Data_MyVideo_1.mp4"
demo_dicom_dir: str = "/home/wkmgc/Desktop/Seg_Server/2024_2_5_王芳/※2F458C45CFAA4C7CB76A39AA2BFE436B"
class Config:
env_file = ".env"

View File

@@ -4,8 +4,6 @@ import asyncio
import json
import logging
import os
import shutil
import tempfile
from contextlib import asynccontextmanager, suppress
from datetime import datetime, timezone
@@ -15,10 +13,9 @@ from sqlalchemy import inspect, text
from config import settings
from database import Base, engine, SessionLocal
from minio_client import ensure_bucket_exists, upload_file
from minio_client import ensure_bucket_exists
from progress_events import PROGRESS_CHANNEL
from redis_client import get_redis_client, ping as redis_ping
from statuses import PROJECT_STATUS_PENDING, PROJECT_STATUS_READY
from routers import projects, templates, media, ai, export, auth, dashboard, tasks, admin
@@ -27,9 +24,24 @@ logging.basicConfig(
format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
)
logger = logging.getLogger(__name__)
RESERVED_UNCLASSIFIED_CLASS = {
"id": "reserved-unclassified",
"name": "待分类",
"color": "#000000",
"zIndex": 0,
"maskId": 0,
"category": "系统保留",
}
DEFAULT_VIDEO_PATH = settings.demo_video_path
def _with_reserved_unclassified_class(classes: list[dict]) -> list[dict]:
filtered = [
item for item in classes
if item.get("id") != RESERVED_UNCLASSIFIED_CLASS["id"]
and item.get("name") != RESERVED_UNCLASSIFIED_CLASS["name"]
and item.get("maskId") != 0
]
return [*filtered, dict(RESERVED_UNCLASSIFIED_CLASS)]
def _ensure_runtime_schema_columns() -> None:
"""Add nullable columns introduced after initial create_all deployments."""
@@ -72,92 +84,51 @@ def _seed_default_admin_and_ownership_sync() -> None:
def _seed_default_project_sync() -> None:
"""Synchronously seed the default video project on first startup."""
import cv2
from models import Project, Frame
"""Synchronously seed the bundled demo video and DICOM projects on first startup."""
from models import Project
from routers.auth import ensure_default_admin
from services.frame_parser import parse_video, upload_frames_to_minio, extract_thumbnail
from services.demo_media import (
DEMO_DICOM_PROJECT_NAME,
DEMO_VIDEO_PROJECT_NAME,
create_parsed_dicom_demo_project,
create_unparsed_video_demo_project,
demo_dicom_files,
)
db = SessionLocal()
try:
admin = ensure_default_admin(db)
existing = db.query(Project).filter(Project.name == "Data_MyVideo_1").first()
if existing is not None:
if existing.owner_user_id is None:
existing.owner_user_id = admin.id
existing_video = db.query(Project).filter(Project.name == DEMO_VIDEO_PROJECT_NAME).first()
if existing_video is not None and existing_video.owner_user_id is None:
existing_video.owner_user_id = admin.id
db.commit()
elif existing_video is None and os.path.exists(settings.demo_video_path):
video_project = create_unparsed_video_demo_project(
db,
owner=admin,
video_path=settings.demo_video_path,
project_name=DEMO_VIDEO_PROJECT_NAME,
)
logger.info("Seeded default video project id=%s", video_project.id)
existing_dicom = db.query(Project).filter(Project.name == DEMO_DICOM_PROJECT_NAME).first()
if existing_dicom is not None:
if existing_dicom.owner_user_id is None:
existing_dicom.owner_user_id = admin.id
db.commit()
return
if not os.path.exists(DEFAULT_VIDEO_PATH):
logger.warning("Default video not found at %s", DEFAULT_VIDEO_PATH)
if not demo_dicom_files(settings.demo_dicom_dir):
logger.warning("Default DICOM series not found at %s", settings.demo_dicom_dir)
return
project = Project(
name="Data_MyVideo_1",
description="默认演示视频",
status=PROJECT_STATUS_PENDING,
source_type="video",
parse_fps=30.0,
owner_user_id=admin.id,
project = create_parsed_dicom_demo_project(
db,
owner=admin,
dicom_dir=settings.demo_dicom_dir,
project_name=DEMO_DICOM_PROJECT_NAME,
)
db.add(project)
db.commit()
db.refresh(project)
with open(DEFAULT_VIDEO_PATH, "rb") as f:
data = f.read()
object_name = f"uploads/{project.id}/Data_MyVideo_1.mp4"
upload_file(object_name, data, content_type="video/mp4", length=len(data))
project.video_path = object_name
db.commit()
# Parse frames
tmp_dir = tempfile.mkdtemp(prefix=f"seg_seed_{project.id}_")
try:
local_path = os.path.join(tmp_dir, "video.mp4")
with open(local_path, "wb") as f:
f.write(data)
output_dir = os.path.join(tmp_dir, "frames")
os.makedirs(output_dir, exist_ok=True)
frame_files, original_fps = parse_video(local_path, output_dir, fps=30, max_frames=100)
project.original_fps = original_fps
# Extract thumbnail
thumbnail_path = os.path.join(tmp_dir, "thumbnail.jpg")
try:
extract_thumbnail(local_path, thumbnail_path)
with open(thumbnail_path, "rb") as f:
thumb_data = f.read()
thumb_object = f"projects/{project.id}/thumbnail.jpg"
upload_file(thumb_object, thumb_data, content_type="image/jpeg", length=len(thumb_data))
project.thumbnail_url = thumb_object
except Exception as exc: # noqa: BLE001
logger.warning("Thumbnail extraction failed: %s", exc)
object_names = upload_frames_to_minio(frame_files, project.id)
for idx, obj_name in enumerate(object_names):
img = cv2.imread(frame_files[idx])
h, w = img.shape[:2] if img is not None else (None, None)
timestamp_ms = idx * 1000.0 / 30.0
source_frame_number = int(round(idx * original_fps / 30.0)) if original_fps else None
frame = Frame(
project_id=project.id,
frame_index=idx,
image_url=obj_name,
width=w,
height=h,
timestamp_ms=timestamp_ms,
source_frame_number=source_frame_number,
)
db.add(frame)
project.status = PROJECT_STATUS_READY
db.commit()
logger.info("Seeded default project id=%s with %d frames", project.id, len(object_names))
finally:
shutil.rmtree(tmp_dir, ignore_errors=True)
logger.info("Seeded default DICOM project id=%s with %d frames", project.id, len(project.frames))
except Exception as exc:
logger.error("Failed to seed default project: %s", exc)
finally:
@@ -170,53 +141,120 @@ def _seed_default_templates_sync() -> None:
db = SessionLocal()
try:
if db.query(Template).first() is not None:
return
# Laparoscopic cholecystectomy template (35 classes)
colors = [
(134, 124, 118), (0, 157, 142), (245, 161, 0), (255, 172, 159), (146, 175, 236), (155, 62, 0),
(255, 91, 0), (255, 234, 0), (85, 111, 181), (155, 132, 0), (181, 227, 14), (72, 0, 255),
(255, 0, 255), (29, 32, 136), (240, 16, 116), (160, 15, 95), (0, 155, 33), (0, 160, 233),
(52, 184, 178), (66, 115, 82), (90, 120, 41), (255, 0, 0), (117, 0, 0), (167, 24, 233),
(42, 8, 66), (112, 113, 150), (0, 255, 0), (255, 255, 255), (0, 255, 255), (181, 85, 105),
(113, 102, 140), (202, 202, 200), (197, 83, 181), (136, 162, 196), (138, 251, 213),
]
names = [
'', '线', '肿瘤', '血管阻断夹', '棉球', '双极电凝',
'肝脏', '胆囊', '分离钳', '脂肪', '止血海绵', '肝总管',
'吸引器', '剪刀', '超声刀', '止血纱布', '胆总管', '生物夹',
'无损伤钳', '钳夹', '喷洒', '胆囊管', '动脉', '电凝',
'静脉', '标本袋', '引流管', '纱布', '金属钛夹', '韧带',
'肝蒂', '推结器', '乳胶管-血管阻断', '吻合器', '术中超声',
]
classes = []
for idx, (rgb, name) in enumerate(zip(colors, names)):
color_hex = f"#{rgb[0]:02x}{rgb[1]:02x}{rgb[2]:02x}"
classes.append({
"id": f"cls-lap-{idx}",
"name": name,
"color": color_hex,
"zIndex": (len(names) - idx) * 10,
"category": "腹腔镜胆囊切除术",
})
template = Template(
name="腹腔镜胆囊切除术",
description="腹腔镜胆囊切除术LC手术器械与解剖结构语义分割模板共35个分类",
color="#06b6d4",
z_index=0,
mapping_rules={"classes": classes, "rules": []},
)
db.add(template)
db.commit()
logger.info("Seeded default template '腹腔镜胆囊切除术' with %d classes", len(classes))
ensure_default_templates(db)
except Exception as exc:
logger.error("Failed to seed default templates: %s", exc)
finally:
db.close()
def _template_classes(
template_name: str,
names: list[str],
colors: list[tuple[int, int, int]],
*,
id_prefix: str,
) -> list[dict]:
classes = []
for idx, (rgb, name) in enumerate(zip(colors, names)):
color_hex = f"#{rgb[0]:02x}{rgb[1]:02x}{rgb[2]:02x}"
classes.append({
"id": f"{id_prefix}-{idx}",
"name": name,
"color": color_hex,
"zIndex": (len(names) - idx) * 10,
"maskId": idx + 1,
"category": template_name,
})
return classes
def ensure_default_templates(db) -> None:
"""Ensure all bundled system templates exist."""
from models import Template
default_templates = [
{
"name": "腹腔镜胆囊切除术",
"description": "腹腔镜胆囊切除术LC手术器械与解剖结构语义分割模板共35个分类",
"color": "#06b6d4",
"z_index": 0,
"classes": _with_reserved_unclassified_class(_template_classes(
"腹腔镜胆囊切除术",
[
'', '线', '肿瘤', '血管阻断夹', '棉球', '双极电凝',
'肝脏', '胆囊', '分离钳', '脂肪', '止血海绵', '肝总管',
'吸引器', '剪刀', '超声刀', '止血纱布', '胆总管', '生物夹',
'无损伤钳', '钳夹', '喷洒', '胆囊管', '动脉', '电凝',
'静脉', '标本袋', '引流管', '纱布', '金属钛夹', '韧带',
'肝蒂', '推结器', '乳胶管-血管阻断', '吻合器', '术中超声',
],
[
(134, 124, 118), (0, 157, 142), (245, 161, 0), (255, 172, 159), (146, 175, 236), (155, 62, 0),
(255, 91, 0), (255, 234, 0), (85, 111, 181), (155, 132, 0), (181, 227, 14), (72, 0, 255),
(255, 0, 255), (29, 32, 136), (240, 16, 116), (160, 15, 95), (0, 155, 33), (0, 160, 233),
(52, 184, 178), (66, 115, 82), (90, 120, 41), (255, 0, 0), (117, 0, 0), (167, 24, 233),
(42, 8, 66), (112, 113, 150), (0, 255, 0), (255, 255, 255), (0, 255, 255), (181, 85, 105),
(113, 102, 140), (202, 202, 200), (197, 83, 181), (136, 162, 196), (138, 251, 213),
],
id_prefix="cls-lap",
)),
},
{
"name": "头颈部CT分割",
"description": "头颈部CT分割",
"color": "#ef4444",
"z_index": 10,
"classes": _with_reserved_unclassified_class(_template_classes(
"头颈部CT分割",
[
"肿瘤/结节 (Tumor/Nodule)",
"下颌骨 (Mandible)",
"甲状腺 (Thyroid)",
"气管 (Trachea)",
"颈椎 (Cervical Spine)",
"颈动脉 (Carotid Artery)",
"颈静脉 (Jugular Vein)",
"腮腺 (Parotid Gland)",
"下颌下腺 (Submandibular Gland)",
"舌骨 (Hyoid Bone)",
],
[
(255, 0, 0),
(0, 255, 0),
(0, 0, 255),
(255, 255, 0),
(255, 0, 255),
(0, 255, 255),
(255, 128, 0),
(128, 0, 128),
(0, 128, 128),
(128, 128, 0),
],
id_prefix="cls-head-neck-ct",
)),
},
]
for definition in default_templates:
existing = db.query(Template).filter(
Template.name == definition["name"],
Template.owner_user_id.is_(None),
).first()
if existing is not None:
continue
template = Template(
name=definition["name"],
description=definition["description"],
color=definition["color"],
z_index=definition["z_index"],
mapping_rules={"classes": definition["classes"], "rules": []},
owner_user_id=None,
)
db.add(template)
logger.info("Seeded default template '%s' with %d classes", definition["name"], len(definition["classes"]))
db.commit()
@asynccontextmanager
async def lifespan(app: FastAPI):
"""Application lifespan: startup and shutdown hooks."""

View File

@@ -9,7 +9,6 @@ from sqlalchemy.orm import Session
from config import settings
from database import get_db
from minio_client import upload_file
from models import Annotation, AuditLog, Frame, Mask, ProcessingTask, Project, Template, User
from routers.auth import ensure_default_admin, hash_password, require_admin, write_audit_log
from schemas import (
@@ -20,13 +19,19 @@ from schemas import (
DemoFactoryResetRequest,
UserOut,
)
from statuses import PROJECT_STATUS_PENDING
from services.demo_media import (
DEMO_DICOM_PROJECT_NAME,
DEMO_VIDEO_PROJECT_NAME,
create_parsed_dicom_demo_project,
create_unparsed_video_demo_project,
demo_dicom_files,
)
router = APIRouter(prefix="/api/admin", tags=["Admin"])
VALID_ROLES = {"admin", "annotator", "viewer"}
DEMO_RESET_CONFIRMATION = "RESET_DEMO_FACTORY"
DEMO_PROJECT_NAME = "Data_MyVideo_1"
DEMO_PROJECT_NAME = DEMO_DICOM_PROJECT_NAME
def _normalize_role(role: str | None) -> str:
@@ -191,7 +196,7 @@ def reset_demo_factory(
db: Session = Depends(get_db),
admin_user: User = Depends(require_admin),
) -> dict:
"""Reset a demo deployment to one admin account and one unparsed demo video project."""
"""Reset a demo deployment to one admin account, the demo video, and the demo DICOM project."""
if payload.confirmation != DEMO_RESET_CONFIRMATION:
raise HTTPException(status_code=400, detail="Invalid reset confirmation")
@@ -200,6 +205,11 @@ def reset_demo_factory(
status_code=409,
detail=f"Demo video not found: {settings.demo_video_path}",
)
if not demo_dicom_files(settings.demo_dicom_dir):
raise HTTPException(
status_code=409,
detail=f"Demo DICOM series not found: {settings.demo_dicom_dir}",
)
requested_by = admin_user.username
preserved_admin = ensure_default_admin(db)
@@ -226,37 +236,39 @@ def reset_demo_factory(
if not preserved_admin:
raise HTTPException(status_code=500, detail="Default admin was not preserved")
project = Project(
name=DEMO_PROJECT_NAME,
description="默认演示视频,尚未生成帧",
status=PROJECT_STATUS_PENDING,
source_type="video",
parse_fps=30.0,
owner_user_id=preserved_admin.id,
video_project = create_unparsed_video_demo_project(
db,
owner=preserved_admin,
video_path=settings.demo_video_path,
project_name=DEMO_VIDEO_PROJECT_NAME,
)
db.add(project)
db.flush()
video_project.frame_count = 0
with open(settings.demo_video_path, "rb") as file_obj:
data = file_obj.read()
object_name = f"uploads/{project.id}/{os.path.basename(settings.demo_video_path)}"
upload_file(object_name, data, content_type="video/mp4", length=len(data))
project.video_path = object_name
project.thumbnail_url = None
project.original_fps = None
db.commit()
dicom_project = create_parsed_dicom_demo_project(
db,
owner=preserved_admin,
dicom_dir=settings.demo_dicom_dir,
project_name=DEMO_PROJECT_NAME,
)
db.refresh(preserved_admin)
db.refresh(project)
db.refresh(video_project)
db.refresh(dicom_project)
video_project.frame_count = len(video_project.frames)
dicom_project.frame_count = len(dicom_project.frames)
projects = [video_project, dicom_project]
write_audit_log(
db,
actor=preserved_admin,
action="admin.demo_factory_reset",
target_type="project",
target_id=project.id,
target_id=dicom_project.id,
detail={
"project_name": project.name,
"video_path": project.video_path,
"project_names": [project.name for project in projects],
"video_path": video_project.video_path,
"dicom_path": dicom_project.video_path,
"source_types": [project.source_type for project in projects],
"frame_counts": {project.name: len(project.frames) for project in projects},
"deleted_counts": deleted_counts,
"requested_by": requested_by,
},
@@ -264,7 +276,8 @@ def reset_demo_factory(
return {
"admin_user": preserved_admin,
"project": project,
"project": dicom_project,
"projects": projects,
"deleted_counts": deleted_counts,
"message": "演示环境已恢复出厂设置",
}

View File

@@ -39,6 +39,10 @@ from services.sam_registry import ModelUnavailableError, sam_registry
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/ai", tags=["AI"])
GT_MASK_EMPTY_DETAIL = "GT Mask 图片中没有非背景 maskid 区域。"
GT_IMPORT_MAX_CONTOUR_POINTS = 2048
GT_IMPORT_CONTOUR_EPSILON_RATIO = 0.00075
GT_IMPORT_MIN_CONTOUR_EPSILON = 0.35
def _owned_project_or_404(project_id: int, db: Session, current_user: User) -> Project:
@@ -152,13 +156,19 @@ def _load_frame_image(frame: Frame) -> np.ndarray:
def _normalized_contour(contour: np.ndarray, width: int, height: int) -> list[list[float]]:
"""Approximate a contour and convert it to normalized polygon coordinates."""
"""Convert a contour to a detailed normalized polygon with a point-count cap."""
arc_length = cv2.arcLength(contour, True)
epsilon = max(1.0, arc_length * 0.01)
epsilon = max(GT_IMPORT_MIN_CONTOUR_EPSILON, arc_length * GT_IMPORT_CONTOUR_EPSILON_RATIO)
approx = cv2.approxPolyDP(contour, epsilon, True)
while len(approx) > GT_IMPORT_MAX_CONTOUR_POINTS and epsilon < arc_length * 0.02:
epsilon *= 1.5
approx = cv2.approxPolyDP(contour, epsilon, True)
points = approx.reshape(-1, 2)
if len(points) < 3:
points = contour.reshape(-1, 2)
if len(points) > GT_IMPORT_MAX_CONTOUR_POINTS:
step = int(math.ceil(len(points) / GT_IMPORT_MAX_CONTOUR_POINTS))
points = points[::step]
return [
[
min(max(float(x) / max(width, 1), 0.0), 1.0),
@@ -977,6 +987,13 @@ async def import_gt_mask(
if image is None:
raise HTTPException(status_code=400, detail="Invalid mask image")
invalid_format_detail = (
"GT Mask 图片不符合要求:仅支持 8-bit 灰度图,或 8-bit RGB 三通道完全相同的 maskid 图"
"(背景 0像素值为 1-255 的 maskid"
)
if image.dtype != np.uint8:
raise HTTPException(status_code=400, detail=invalid_format_detail)
if image.ndim == 2:
label_image = image
elif image.ndim == 3 and image.shape[2] >= 3:
@@ -984,16 +1001,10 @@ async def import_gt_mask(
# GT label images are maskid maps: either grayscale or RGB/BGR where
# all three color channels contain the same maskid value [X, X, X].
if not (np.array_equal(channels[:, :, 0], channels[:, :, 1]) and np.array_equal(channels[:, :, 1], channels[:, :, 2])):
raise HTTPException(
status_code=400,
detail="GT Mask 图片不符合要求:请上传灰度图,或 RGB 三通道完全相同的 maskid 图(背景 0像素值为 maskid",
)
raise HTTPException(status_code=400, detail=invalid_format_detail)
label_image = channels[:, :, 0]
else:
raise HTTPException(
status_code=400,
detail="GT Mask 图片不符合要求:请上传灰度图,或 RGB 三通道完全相同的 maskid 图(背景 0像素值为 maskid",
)
raise HTTPException(status_code=400, detail=invalid_format_detail)
width = int(frame.width or image.shape[1])
height = int(frame.height or image.shape[0])
@@ -1041,12 +1052,12 @@ async def import_gt_mask(
if not import_items:
if skipped_unknown > 0:
raise HTTPException(status_code=400, detail="No matching GT mask classes found")
raise HTTPException(status_code=400, detail="No foreground mask regions found")
raise HTTPException(status_code=400, detail=GT_MASK_EMPTY_DETAIL)
annotations: list[Annotation] = []
for item in import_items:
binary = item["binary"]
contours, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
contours, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
for contour in contours:
if cv2.contourArea(contour) < 1:
@@ -1085,7 +1096,7 @@ async def import_gt_mask(
annotations.append(annotation)
if not annotations:
raise HTTPException(status_code=400, detail="No foreground mask regions found")
raise HTTPException(status_code=400, detail=GT_MASK_EMPTY_DETAIL)
db.commit()
for annotation in annotations:

View File

@@ -64,7 +64,7 @@ def _annotation_mask_id(annotation: Annotation) -> int | None:
value = int(class_meta[key])
except (TypeError, ValueError):
continue
if value > 0:
if value >= 0:
return value
return None
@@ -361,7 +361,7 @@ def _build_gt_class_mapping(annotations: list[Annotation]) -> tuple[dict[str, in
ordered = sorted(
entries_by_key.values(),
key=lambda item: (
item["maskidHint"] if isinstance(item.get("maskidHint"), int) and item["maskidHint"] > 0 else 10_000_000,
item["maskidHint"] if isinstance(item.get("maskidHint"), int) and item["maskidHint"] >= 0 else 10_000_000,
str(item["className"]),
str(item["key"]),
),
@@ -375,6 +375,8 @@ def _build_gt_class_mapping(annotations: list[Annotation]) -> tuple[dict[str, in
nonlocal next_maskid
while next_maskid in used_maskids:
next_maskid += 1
if next_maskid > 255:
raise HTTPException(status_code=400, detail="GT_label 仅支持 8-bit maskid类别值必须在 1-255 之间")
value = next_maskid
used_maskids.add(value)
next_maskid += 1
@@ -382,7 +384,12 @@ def _build_gt_class_mapping(annotations: list[Annotation]) -> tuple[dict[str, in
for entry in ordered:
hinted_maskid = entry.get("maskidHint")
if isinstance(hinted_maskid, int) and hinted_maskid > 0 and hinted_maskid not in used_maskids:
if isinstance(hinted_maskid, int) and hinted_maskid > 255:
raise HTTPException(status_code=400, detail="GT_label 仅支持 8-bit maskid类别值必须在 1-255 之间")
if isinstance(hinted_maskid, int) and hinted_maskid == 0:
maskid = 0
used_maskids.add(maskid)
elif isinstance(hinted_maskid, int) and 0 < hinted_maskid <= 255 and hinted_maskid not in used_maskids:
maskid = hinted_maskid
used_maskids.add(maskid)
else:
@@ -513,7 +520,7 @@ def _write_result_mask_outputs(
)
needs_fused_output = include_semantic or include_pro_label or include_mix_label
semantic = np.zeros((height, width), dtype=np.uint16) if needs_fused_output else None
semantic = np.zeros((height, width), dtype=np.uint8) if needs_fused_output else None
pro_label = np.zeros((height, width, 3), dtype=np.uint8) if (include_pro_label or include_mix_label) else None
if needs_fused_output:

View File

@@ -1,6 +1,7 @@
"""Media upload and parsing endpoints."""
import logging
import re
from pathlib import Path
from typing import List, Optional
@@ -22,6 +23,13 @@ router = APIRouter(prefix="/api/media", tags=["Media"])
ALLOWED_EXTENSIONS = {".mp4", ".avi", ".mov", ".mkv", ".webm", ".png", ".jpg", ".jpeg", ".dcm"}
def natural_filename_key(filename: str) -> tuple[object, ...]:
return tuple(
int(part) if part.isdigit() else part.casefold()
for part in re.split(r"(\d+)", Path(filename).name)
)
def _get_ext(filename: str) -> str:
return Path(filename).suffix.lower()
@@ -124,6 +132,12 @@ async def upload_dicom_batch(
if not files:
raise HTTPException(status_code=400, detail="No files uploaded")
sorted_files = sorted(
[file for file in files if file.filename and file.filename.lower().endswith(".dcm")],
key=lambda file: natural_filename_key(file.filename or ""),
)
if not sorted_files:
raise HTTPException(status_code=400, detail="No valid DICOM files uploaded")
uploaded = []
if project_id:
@@ -135,10 +149,10 @@ async def upload_dicom_batch(
raise HTTPException(status_code=404, detail="Project not found")
else:
# Create new DICOM project
first_name = files[0].filename or "DICOM_Series"
first_name = sorted_files[0].filename or "DICOM_Series"
project = Project(
name=first_name,
description=f"DICOM series with {len(files)} files",
description=f"DICOM series with {len(sorted_files)} files",
status=PROJECT_STATUS_PENDING,
source_type="dicom",
owner_user_id=current_user.id,
@@ -149,9 +163,7 @@ async def upload_dicom_batch(
project_id = project.id
logger.info("Auto-created DICOM project id=%s", project_id)
for file in files:
if not file.filename or not file.filename.lower().endswith(".dcm"):
continue
for file in sorted_files:
data = await file.read()
object_name = f"uploads/{project_id}/dicom/{file.filename}"
try:

View File

@@ -7,15 +7,38 @@ from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy.orm import Session
from database import get_db
from models import Project, Frame, User
from models import Annotation, Mask, Project, Frame, User
from routers.auth import get_current_user, require_editor
from schemas import ProjectCreate, ProjectOut, ProjectUpdate, FrameCreate, FrameOut
from schemas import ProjectCopyRequest, ProjectCreate, ProjectOut, ProjectUpdate, FrameCreate, FrameOut
from minio_client import get_presigned_url
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/projects", tags=["Projects"])
def _next_project_copy_name(db: Session, owner_user_id: int, source_name: str) -> str:
base_name = f"{source_name} 副本"
existing_names = {
row[0]
for row in db.query(Project.name)
.filter(Project.owner_user_id == owner_user_id, Project.name.like(f"{base_name}%"))
.all()
}
if base_name not in existing_names:
return base_name
suffix = 2
while f"{base_name} {suffix}" in existing_names:
suffix += 1
return f"{base_name} {suffix}"
def _prepare_project_response(project: Project) -> Project:
project.frame_count = len(project.frames)
if project.thumbnail_url:
project.thumbnail_url = get_presigned_url(project.thumbnail_url, expires=3600)
return project
# ---------------------------------------------------------------------------
# Projects
# ---------------------------------------------------------------------------
@@ -59,9 +82,7 @@ def list_projects(
.all()
)
for p in projects:
p.frame_count = len(p.frames)
if p.thumbnail_url:
p.thumbnail_url = get_presigned_url(p.thumbnail_url, expires=3600)
_prepare_project_response(p)
return projects
@@ -82,10 +103,85 @@ def get_project(
).first()
if not project:
raise HTTPException(status_code=404, detail="Project not found")
project.frame_count = len(project.frames)
if project.thumbnail_url:
project.thumbnail_url = get_presigned_url(project.thumbnail_url, expires=3600)
return project
return _prepare_project_response(project)
@router.post(
"/{project_id}/copy",
response_model=ProjectOut,
status_code=status.HTTP_201_CREATED,
summary="Copy a project",
)
def copy_project(
project_id: int,
payload: ProjectCopyRequest,
db: Session = Depends(get_db),
current_user: User = Depends(require_editor),
) -> Project:
"""Copy a project. Reset copies media/frame sequence; full also copies annotations and mask metadata."""
source = db.query(Project).filter(
Project.id == project_id,
Project.owner_user_id == current_user.id,
).first()
if not source:
raise HTTPException(status_code=404, detail="Project not found")
next_name = (payload.name or "").strip() if payload.name is not None else ""
if not next_name:
next_name = _next_project_copy_name(db, current_user.id, source.name)
copied = Project(
name=next_name,
description=source.description,
video_path=source.video_path,
thumbnail_url=source.thumbnail_url,
status=source.status,
source_type=source.source_type,
original_fps=source.original_fps,
parse_fps=source.parse_fps,
owner_user_id=current_user.id,
)
db.add(copied)
db.flush()
frame_id_map: dict[int, int] = {}
for frame in sorted(source.frames, key=lambda item: item.frame_index):
copied_frame = Frame(
project_id=copied.id,
frame_index=frame.frame_index,
image_url=frame.image_url,
width=frame.width,
height=frame.height,
timestamp_ms=frame.timestamp_ms,
source_frame_number=frame.source_frame_number,
)
db.add(copied_frame)
db.flush()
frame_id_map[frame.id] = copied_frame.id
if payload.mode == "full":
for annotation in sorted(source.annotations, key=lambda item: item.id):
copied_annotation = Annotation(
project_id=copied.id,
frame_id=frame_id_map.get(annotation.frame_id) if annotation.frame_id is not None else None,
template_id=annotation.template_id,
mask_data=annotation.mask_data,
points=annotation.points,
bbox=annotation.bbox,
)
db.add(copied_annotation)
db.flush()
for mask in annotation.masks:
db.add(Mask(
annotation_id=copied_annotation.id,
mask_url=mask.mask_url,
format=mask.format,
))
db.commit()
db.refresh(copied)
logger.info("Copied project id=%s to id=%s mode=%s", project_id, copied.id, payload.mode)
return _prepare_project_response(copied)
@router.patch(
@@ -108,6 +204,10 @@ def update_project(
raise HTTPException(status_code=404, detail="Project not found")
for key, value in payload.model_dump(exclude_unset=True).items():
if key == "name":
value = (value or "").strip()
if not value:
raise HTTPException(status_code=400, detail="Project name is required")
setattr(project, key, value)
db.commit()

View File

@@ -14,15 +14,38 @@ from schemas import TemplateCreate, TemplateOut, TemplateUpdate
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/templates", tags=["Templates"])
RESERVED_UNCLASSIFIED_CLASS = {
"id": "reserved-unclassified",
"name": "待分类",
"color": "#000000",
"zIndex": 0,
"maskId": 0,
"category": "系统保留",
}
def _is_reserved_class(item: dict) -> bool:
return (
item.get("id") == RESERVED_UNCLASSIFIED_CLASS["id"]
or item.get("name") == RESERVED_UNCLASSIFIED_CLASS["name"]
or item.get("maskId") == 0
)
def _normalize_template_classes(classes: list[dict] | None) -> list[dict]:
normalized = [item for item in (classes or []) if not _is_reserved_class(item)]
return [*normalized, dict(RESERVED_UNCLASSIFIED_CLASS)]
def _pack_mapping_rules(data: dict) -> dict:
"""Pack classes/rules into mapping_rules for DB storage."""
mapping = data.get("mapping_rules") or {}
if "classes" in data and data["classes"] is not None:
mapping["classes"] = data.pop("classes")
mapping["classes"] = _normalize_template_classes(data.pop("classes"))
if "rules" in data and data["rules"] is not None:
mapping["rules"] = data.pop("rules")
if "classes" in mapping:
mapping["classes"] = _normalize_template_classes(mapping.get("classes"))
data["mapping_rules"] = mapping
return data
@@ -31,7 +54,7 @@ def _unpack_template(template: Template) -> Template:
"""Unpack mapping_rules into classes/rules for response."""
mapping = template.mapping_rules or {}
# Set as attributes so Pydantic from_attributes can pick them up
template.classes = mapping.get("classes", [])
template.classes = _normalize_template_classes(mapping.get("classes", []))
template.rules = mapping.get("rules", [])
return template

View File

@@ -1,7 +1,7 @@
"""Pydantic schemas for request/response validation."""
from datetime import datetime
from typing import Optional, Any
from typing import Literal, Optional, Any
from pydantic import BaseModel, ConfigDict
@@ -83,6 +83,11 @@ class ProjectUpdate(BaseModel):
parse_fps: Optional[float] = None
class ProjectCopyRequest(BaseModel):
mode: Literal["reset", "full"] = "reset"
name: Optional[str] = None
class ProjectOut(ProjectBase):
model_config = ConfigDict(from_attributes=True)
@@ -96,6 +101,7 @@ class ProjectOut(ProjectBase):
class DemoFactoryResetOut(BaseModel):
admin_user: UserOut
project: ProjectOut
projects: list[ProjectOut]
deleted_counts: dict[str, int]
message: str

View File

@@ -0,0 +1,128 @@
"""Helpers for seeding the bundled demo media project."""
from __future__ import annotations
import os
import shutil
import tempfile
from pathlib import Path
import cv2
from sqlalchemy.orm import Session
from minio_client import upload_file
from models import Frame, Project, User
from services.frame_parser import natural_filename_key, parse_dicom, upload_frames_to_minio
from statuses import PROJECT_STATUS_PENDING, PROJECT_STATUS_READY
DEMO_DICOM_PROJECT_NAME = "演示DICOM序列"
DEMO_DICOM_PARSE_FPS = 30.0
DEMO_VIDEO_PROJECT_NAME = "Data_MyVideo_1"
def demo_dicom_files(dicom_dir: str) -> list[Path]:
"""Return .dcm files in natural file-name order."""
root = Path(dicom_dir)
if not root.exists() or not root.is_dir():
return []
return sorted(
[path for path in root.iterdir() if path.is_file() and path.name.lower().endswith(".dcm")],
key=lambda path: natural_filename_key(path.name),
)
def create_unparsed_video_demo_project(
db: Session,
*,
owner: User,
video_path: str,
project_name: str = DEMO_VIDEO_PROJECT_NAME,
) -> Project:
"""Create the bundled demo video project without extracting frames."""
source = Path(video_path)
if not source.exists() or not source.is_file():
raise FileNotFoundError(f"Demo video not found: {video_path}")
project = Project(
name=project_name,
description="默认演示视频,尚未生成帧",
status=PROJECT_STATUS_PENDING,
source_type="video",
parse_fps=30.0,
original_fps=None,
owner_user_id=owner.id,
)
db.add(project)
db.flush()
data = source.read_bytes()
object_name = f"uploads/{project.id}/{source.name}"
upload_file(object_name, data, content_type="video/mp4", length=len(data))
project.video_path = object_name
project.thumbnail_url = None
db.commit()
db.refresh(project)
return project
def create_parsed_dicom_demo_project(
db: Session,
*,
owner: User,
dicom_dir: str,
project_name: str = DEMO_DICOM_PROJECT_NAME,
) -> Project:
"""Create the demo DICOM project, upload the series, and register parsed frames."""
dcm_files = demo_dicom_files(dicom_dir)
if not dcm_files:
raise FileNotFoundError(f"Demo DICOM series not found: {dicom_dir}")
project = Project(
name=project_name,
description=f"默认演示 DICOM 序列,已按文件名自然顺序生成 {len(dcm_files)}",
status=PROJECT_STATUS_PENDING,
source_type="dicom",
parse_fps=DEMO_DICOM_PARSE_FPS,
original_fps=None,
owner_user_id=owner.id,
)
db.add(project)
db.flush()
dicom_prefix = f"uploads/{project.id}/dicom"
for dcm_file in dcm_files:
data = dcm_file.read_bytes()
upload_file(
f"{dicom_prefix}/{dcm_file.name}",
data,
content_type="application/dicom",
length=len(data),
)
project.video_path = dicom_prefix
tmp_dir = tempfile.mkdtemp(prefix=f"seg_demo_dicom_{project.id}_")
try:
output_dir = os.path.join(tmp_dir, "frames")
frame_files = parse_dicom(dicom_dir, output_dir)
object_names = upload_frames_to_minio(frame_files, project.id)
for idx, obj_name in enumerate(object_names):
image = cv2.imread(frame_files[idx])
height, width = image.shape[:2] if image is not None else (None, None)
db.add(Frame(
project_id=project.id,
frame_index=idx,
image_url=obj_name,
width=width,
height=height,
timestamp_ms=idx * 1000.0 / DEMO_DICOM_PARSE_FPS,
source_frame_number=idx,
))
if object_names:
project.thumbnail_url = object_names[0]
project.status = PROJECT_STATUS_READY
db.commit()
db.refresh(project)
return project
finally:
shutil.rmtree(tmp_dir, ignore_errors=True)

View File

@@ -2,6 +2,7 @@
import logging
import os
import re
import shutil
import subprocess
from pathlib import Path
@@ -16,6 +17,14 @@ from minio_client import upload_file, BUCKET_NAME
logger = logging.getLogger(__name__)
def natural_filename_key(filename: str) -> Tuple[object, ...]:
"""Sort file names by their visible numeric order instead of pure lexicographic order."""
return tuple(
int(part) if part.isdigit() else part.casefold()
for part in re.split(r"(\d+)", Path(filename).name)
)
def get_video_fps(video_path: str) -> float:
"""Read the original frame rate of a video file."""
cap = cv2.VideoCapture(video_path)
@@ -150,7 +159,8 @@ def parse_dicom(
"""
os.makedirs(output_dir, exist_ok=True)
dcm_files = sorted(
[f for f in os.listdir(dicom_dir) if f.lower().endswith(".dcm")]
[f for f in os.listdir(dicom_dir) if f.lower().endswith(".dcm")],
key=natural_filename_key,
)
frame_paths: List[str] = []

View File

@@ -15,6 +15,7 @@ from models import Frame, ProcessingTask, Project
from progress_events import publish_task_progress_event
from services.frame_parser import (
extract_thumbnail,
natural_filename_key,
parse_dicom,
parse_video,
upload_frames_to_minio,
@@ -188,7 +189,10 @@ def run_parse_media_task(db: Session, task_id: int) -> dict[str, Any]:
os.makedirs(dcm_dir, exist_ok=True)
client = get_minio_client()
objects = list(client.list_objects(BUCKET_NAME, prefix=project.video_path, recursive=True))
objects = sorted(
list(client.list_objects(BUCKET_NAME, prefix=project.video_path, recursive=True)),
key=lambda obj: natural_filename_key(obj.object_name),
)
for obj in objects:
_ensure_not_cancelled(db, task)
if obj.object_name.lower().endswith(".dcm"):

View File

@@ -1,6 +1,6 @@
from models import Annotation, AuditLog, Frame, Mask, ProcessingTask, Project, Template, User
from routers.auth import create_access_token, hash_password
from statuses import PROJECT_STATUS_PENDING
from statuses import PROJECT_STATUS_READY
def test_admin_user_management_and_audit_logs(client, db_session):
@@ -83,17 +83,34 @@ def test_admin_cannot_delete_self_or_user_with_projects(client, db_session):
assert response.status_code == 409
def test_demo_factory_reset_leaves_admin_and_unparsed_demo_video(client, db_session, monkeypatch, tmp_path):
def test_demo_factory_reset_leaves_admin_and_parsed_demo_dicom(client, db_session, monkeypatch, tmp_path):
video_path = tmp_path / "Data_MyVideo_1.mp4"
video_path.write_bytes(b"demo-video")
monkeypatch.setattr("routers.admin.settings.demo_video_path", str(video_path))
dicom_dir = tmp_path / "dicom"
dicom_dir.mkdir()
for name in ["10.dcm", "2.dcm", "1.dcm"]:
(dicom_dir / name).write_bytes(name.encode())
monkeypatch.setattr("routers.admin.settings.demo_dicom_dir", str(dicom_dir))
parsed_frame_paths = []
for idx in range(3):
frame_path = tmp_path / f"frame_{idx:06d}.jpg"
frame_path.write_bytes(b"frame")
parsed_frame_paths.append(str(frame_path))
uploaded = []
monkeypatch.setattr("routers.admin.upload_file", lambda object_name, data, content_type, length: uploaded.append({
monkeypatch.setattr("services.demo_media.upload_file", lambda object_name, data, content_type, length: uploaded.append({
"object_name": object_name,
"data": data,
"content_type": content_type,
"length": length,
}))
monkeypatch.setattr("services.demo_media.parse_dicom", lambda dicom_dir_arg, output_dir: parsed_frame_paths)
monkeypatch.setattr(
"services.demo_media.upload_frames_to_minio",
lambda frame_files, project_id: [f"projects/{project_id}/frames/frame_{idx:06d}.jpg" for idx, _ in enumerate(frame_files)],
)
extra_user = User(username="doctor", password_hash=hash_password("secret123"), role="annotator", is_active=1)
db_session.add(extra_user)
@@ -113,7 +130,15 @@ def test_demo_factory_reset_leaves_admin_and_unparsed_demo_video(client, db_sess
z_index=1,
owner_user_id=extra_user.id,
)
db_session.add_all([task, private_template])
system_template = Template(
name="头颈部CT分割",
description="头颈部CT分割",
color="#ef4444",
z_index=10,
owner_user_id=None,
mapping_rules={"classes": [{"name": "肿瘤/结节 (Tumor/Nodule)", "color": "#ff0000", "maskId": 1}], "rules": []},
)
db_session.add_all([task, private_template, system_template])
db_session.commit()
db_session.refresh(frame)
annotation = Annotation(project_id=old_project.id, frame_id=frame.id, mask_data={"label": "old"})
@@ -130,24 +155,36 @@ def test_demo_factory_reset_leaves_admin_and_unparsed_demo_video(client, db_sess
data = response.json()
assert data["message"] == "演示环境已恢复出厂设置"
assert data["admin_user"]["username"] == "admin"
assert data["project"]["name"] == "Data_MyVideo_1"
assert data["project"]["status"] == PROJECT_STATUS_PENDING
assert data["project"]["frame_count"] == 0
assert data["project"]["video_path"] == f"uploads/{data['project']['id']}/Data_MyVideo_1.mp4"
assert uploaded == [{
"object_name": data["project"]["video_path"],
"data": b"demo-video",
"content_type": "video/mp4",
"length": len(b"demo-video"),
}]
assert data["project"]["name"] == "演示DICOM序列"
assert data["project"]["status"] == PROJECT_STATUS_READY
assert data["project"]["source_type"] == "dicom"
assert data["project"]["frame_count"] == 3
assert data["project"]["video_path"] == f"uploads/{data['project']['id']}/dicom"
assert [project["name"] for project in data["projects"]] == ["Data_MyVideo_1", "演示DICOM序列"]
assert data["projects"][0]["status"] == "pending"
assert data["projects"][0]["source_type"] == "video"
assert data["projects"][0]["frame_count"] == 0
assert data["projects"][1]["status"] == PROJECT_STATUS_READY
assert data["projects"][1]["source_type"] == "dicom"
assert data["projects"][1]["frame_count"] == 3
assert [item["object_name"] for item in uploaded] == [
f"uploads/{data['projects'][0]['id']}/Data_MyVideo_1.mp4",
f"uploads/{data['project']['id']}/dicom/1.dcm",
f"uploads/{data['project']['id']}/dicom/2.dcm",
f"uploads/{data['project']['id']}/dicom/10.dcm",
]
assert [item["content_type"] for item in uploaded] == ["video/mp4", "application/dicom", "application/dicom", "application/dicom"]
assert [user.username for user in db_session.query(User).all()] == ["admin"]
assert db_session.query(Project).count() == 1
assert db_session.query(Frame).count() == 0
assert db_session.query(Project).count() == 2
assert db_session.query(Frame).count() == 3
assert [frame.source_frame_number for frame in db_session.query(Frame).order_by(Frame.frame_index).all()] == [0, 1, 2]
assert db_session.query(Annotation).count() == 0
assert db_session.query(Mask).count() == 0
assert db_session.query(ProcessingTask).count() == 0
assert db_session.query(Template).filter(Template.owner_user_id.is_not(None)).count() == 0
preserved_templates = db_session.query(Template).filter(Template.owner_user_id.is_(None)).all()
assert [template.name for template in preserved_templates] == ["头颈部CT分割"]
assert db_session.query(AuditLog).count() == 1
assert db_session.query(AuditLog).first().action == "admin.demo_factory_reset"

View File

@@ -1149,6 +1149,81 @@ def test_import_gt_mask_creates_annotations_with_seed_points(client):
assert 0.0 <= body[0]["points"][0][1] <= 1.0
def test_import_gt_mask_polygons_work_with_analysis_and_smoothing(client):
project, frame, _ = _create_project_and_frame(client)
mask = np.zeros((360, 640), dtype=np.uint8)
cv2.ellipse(mask, (260, 160), (130, 70), 20, 0, 360, 1, thickness=-1)
ok, encoded = cv2.imencode(".png", mask)
assert ok
response = client.post(
"/api/ai/import-gt-mask",
data={
"project_id": str(project["id"]),
"frame_id": str(frame["id"]),
"label": "Imported GT",
"color": "#22c55e",
},
files={"file": ("mask.png", encoded.tobytes(), "image/png")},
)
assert response.status_code == 201
annotation = response.json()[0]
assert annotation["mask_data"]["source"] == "gt_mask"
analysis = client.post("/api/ai/analyze-mask", json={
"frame_id": frame["id"],
"mask_data": annotation["mask_data"],
"points": annotation["points"],
"bbox": annotation["bbox"],
})
assert analysis.status_code == 200
assert analysis.json()["topology_anchor_count"] == len(annotation["mask_data"]["polygons"][0])
smoothing = client.post("/api/ai/smooth-mask", json={
"frame_id": frame["id"],
"mask_data": annotation["mask_data"],
"points": annotation["points"],
"bbox": annotation["bbox"],
"strength": 35,
})
assert smoothing.status_code == 200
assert smoothing.json()["topology_anchor_count"] == len(smoothing.json()["polygons"][0])
def test_import_gt_mask_preserves_detailed_contours(client):
project, frame, _ = _create_project_and_frame(client)
mask = np.zeros((360, 640), dtype=np.uint8)
center = np.array([320, 180])
vertices = []
for index in range(96):
angle = 2 * np.pi * index / 96
radius = 120 if index % 2 == 0 else 88
vertices.append([
int(center[0] + np.cos(angle) * radius),
int(center[1] + np.sin(angle) * radius),
])
cv2.fillPoly(mask, [np.array(vertices, dtype=np.int32)], 1)
ok, encoded = cv2.imencode(".png", mask)
assert ok
response = client.post(
"/api/ai/import-gt-mask",
data={
"project_id": str(project["id"]),
"frame_id": str(frame["id"]),
"label": "Detailed GT",
"color": "#22c55e",
},
files={"file": ("mask.png", encoded.tobytes(), "image/png")},
)
assert response.status_code == 201
polygon = response.json()[0]["mask_data"]["polygons"][0]
assert len(polygon) > 80
assert len(polygon) <= 2048
def test_import_gt_mask_splits_label_values(client):
project, frame, _ = _create_project_and_frame(client)
mask = np.zeros((360, 640), dtype=np.uint8)
@@ -1174,7 +1249,27 @@ def test_import_gt_mask_splits_label_values(client):
assert all(len(item["points"]) == 1 for item in body)
def test_import_gt_mask_preserves_low_value_gtlabel_png(client):
def test_import_gt_mask_rejects_background_only_label_image(client):
project, frame, _ = _create_project_and_frame(client)
mask = np.zeros((360, 640), dtype=np.uint8)
ok, encoded = cv2.imencode(".png", mask)
assert ok
response = client.post(
"/api/ai/import-gt-mask",
data={
"project_id": str(project["id"]),
"frame_id": str(frame["id"]),
"label": "GT Class",
},
files={"file": ("empty-gt-label.png", encoded.tobytes(), "image/png")},
)
assert response.status_code == 400
assert response.json()["detail"] == "GT Mask 图片中没有非背景 maskid 区域。"
def test_import_gt_mask_accepts_uint8_low_value_gtlabel_png(client):
project, frame, _ = _create_project_and_frame(client)
template = client.post("/api/templates", json={
"name": "GTLabel Template",
@@ -1185,7 +1280,7 @@ def test_import_gt_mask_preserves_low_value_gtlabel_png(client):
],
"rules": [],
}).json()
mask = np.zeros((360, 640), dtype=np.uint16)
mask = np.zeros((360, 640), dtype=np.uint8)
cv2.rectangle(mask, (40, 40), (140, 140), 1, thickness=-1)
ok, encoded = cv2.imencode(".png", mask)
assert ok
@@ -1241,7 +1336,7 @@ def test_import_gt_mask_rejects_rgb_color_masks(client):
assert "RGB 三通道完全相同" in response.json()["detail"]
def test_import_gt_mask_reads_uint16_gt_label_and_maps_maskid_class(client):
def test_import_gt_mask_rejects_uint16_gt_label(client):
project, frame, _ = _create_project_and_frame(client)
template = client.post("/api/templates", json={
"name": "Label Template",
@@ -1266,13 +1361,8 @@ def test_import_gt_mask_reads_uint16_gt_label_and_maps_maskid_class(client):
files={"file": ("gt_label.png", encoded.tobytes(), "image/png")},
)
assert response.status_code == 201
body = response.json()
assert len(body) == 1
assert body[0]["mask_data"]["gt_label_value"] == 1
assert body[0]["mask_data"]["label"] == "肿瘤"
assert body[0]["mask_data"]["class"]["maskId"] == 1
assert body[0]["mask_data"]["class"]["color"] == "#ff0000"
assert response.status_code == 400
assert "仅支持 8-bit" in response.json()["detail"]
def test_import_gt_mask_handles_unknown_maskid_policy_and_resizes_to_frame(client):

View File

@@ -169,6 +169,7 @@ def test_export_results_zip_contains_coco_original_images_and_selected_mask_outp
"key": f"template:{annotation['template_id']}",
"template_id": annotation["template_id"],
}]
assert gt_label.dtype == np.uint8
assert gt_label[0, 0] == 0
assert gt_label[20, 50] == 1
assert pro_label[20, 50].tolist() == [212, 182, 6]
@@ -234,6 +235,7 @@ def test_export_results_uses_internal_layer_order_for_gt_pro_and_mix_outputs(cli
cv2.IMREAD_COLOR,
)
assert gt_label.dtype == np.uint8
assert gt_label[10, 10] == high_value
assert pro_label[10, 10].tolist() == [0, 0, 255]
assert mix_label[10, 10].tolist() == [127, 127, 255]
@@ -365,10 +367,74 @@ def test_export_results_preserves_template_maskid_consistently_across_frames(cli
"key": "class:tumor",
"template_id": None,
}]
assert first_label.dtype == np.uint8
assert second_label.dtype == np.uint8
assert first_label[5, 5] == 7
assert second_label[5, 5] == 7
def test_export_results_keeps_unclassified_maskid_zero_black_in_gt_and_pro(client, monkeypatch):
monkeypatch.setattr("routers.export.download_file", lambda object_name: _fake_image_bytes(20, 20))
project = client.post("/api/projects", json={
"name": "Unclassified Export Project",
"video_path": "uploads/8/unclassified.mp4",
}).json()
frame = client.post(f"/api/projects/{project['id']}/frames", json={
"project_id": project["id"],
"frame_index": 0,
"image_url": "frames/source.jpg",
"width": 20,
"height": 20,
"timestamp_ms": 0,
}).json()
client.post("/api/ai/annotate", json={
"project_id": project["id"],
"frame_id": frame["id"],
"mask_data": {
"polygons": [[[0.1, 0.1], [0.8, 0.1], [0.8, 0.8], [0.1, 0.8]]],
"label": "待分类",
"color": "#000000",
"class": {
"id": "reserved-unclassified",
"name": "待分类",
"color": "#000000",
"maskId": 0,
"zIndex": 0,
},
},
})
response = client.get(f"/api/export/{project['id']}/results?scope=all&outputs=gt_label,pro_label")
assert response.status_code == 200
with zipfile.ZipFile(BytesIO(response.content)) as archive:
mapping = json.loads(archive.read("maskid_GT像素值_类别映射.json"))
stem = "unclassified_0h00m00s000ms_frame000001"
gt_label = cv2.imdecode(
np.frombuffer(archive.read(f"GT_label图/{stem}.png"), dtype=np.uint8),
cv2.IMREAD_UNCHANGED,
)
pro_label = cv2.imdecode(
np.frombuffer(archive.read(f"Pro_label彩色分割结果/{stem}.png"), dtype=np.uint8),
cv2.IMREAD_COLOR,
)
assert mapping["classes"] == [{
"gt_pixel_value": 0,
"maskid": 0,
"chineseName": "待分类",
"className": "待分类",
"categoryName": "",
"rgb": [0, 0, 0],
"color": "#000000",
"key": "class:reserved-unclassified",
"template_id": None,
}]
assert gt_label.dtype == np.uint8
assert gt_label[5, 5] == 0
assert pro_label[5, 5].tolist() == [0, 0, 0]
def test_exported_gtlabel_round_trips_through_gt_mask_import_with_template_maskid(client, monkeypatch):
monkeypatch.setattr("routers.export.download_file", lambda object_name: _fake_image_bytes(20, 20))
project = client.post("/api/projects", json={
@@ -423,6 +489,7 @@ def test_exported_gtlabel_round_trips_through_gt_mask_import_with_template_maski
gt_label = cv2.imdecode(np.frombuffer(exported_gt_label, dtype=np.uint8), cv2.IMREAD_UNCHANGED)
mapping = json.loads(archive.read("maskid_GT像素值_类别映射.json"))
assert gt_label.dtype == np.uint8
assert gt_label[5, 5] == 7
assert mapping["classes"][0]["maskid"] == 7
@@ -446,6 +513,36 @@ def test_exported_gtlabel_round_trips_through_gt_mask_import_with_template_maski
assert imported[0]["mask_data"]["class"]["maskId"] == 7
def test_export_results_rejects_gtlabel_maskid_outside_uint8_range(client, monkeypatch):
monkeypatch.setattr("routers.export.download_file", lambda object_name: _fake_image_bytes(20, 20))
project = client.post("/api/projects", json={
"name": "Large MaskId Project",
"video_path": "uploads/8/large-maskid.mp4",
}).json()
frame = client.post(f"/api/projects/{project['id']}/frames", json={
"project_id": project["id"],
"frame_index": 0,
"image_url": "frames/source.jpg",
"width": 20,
"height": 20,
}).json()
client.post("/api/ai/annotate", json={
"project_id": project["id"],
"frame_id": frame["id"],
"mask_data": {
"polygons": [[[0.1, 0.1], [0.8, 0.1], [0.8, 0.8], [0.1, 0.8]]],
"label": "TooLarge",
"color": "#ff0000",
"class": {"id": "too-large", "name": "TooLarge", "color": "#ff0000", "maskId": 300, "zIndex": 30},
},
})
response = client.get(f"/api/export/{project['id']}/results?scope=all&outputs=gt_label")
assert response.status_code == 400
assert "8-bit maskid" in response.json()["detail"]
def test_export_missing_project_returns_404(client):
assert client.get("/api/export/999/coco").status_code == 404
assert client.get("/api/export/999/masks").status_code == 404

View File

@@ -48,15 +48,37 @@ def test_upload_dicom_batch_filters_files_and_creates_project(client, monkeypatc
response = client.post(
"/api/media/upload/dicom",
files=[
("files", ("a.dcm", b"dcm", "application/dicom")),
("files", ("10.dcm", b"dcm10", "application/dicom")),
("files", ("skip.txt", b"text", "text/plain")),
("files", ("2.dcm", b"dcm2", "application/dicom")),
("files", ("1.dcm", b"dcm1", "application/dicom")),
],
)
assert response.status_code == 201
data = response.json()
assert data["uploaded_count"] == 1
assert uploaded == [f"uploads/{data['project_id']}/dicom/a.dcm"]
assert data["uploaded_count"] == 3
assert uploaded == [
f"uploads/{data['project_id']}/dicom/1.dcm",
f"uploads/{data['project_id']}/dicom/2.dcm",
f"uploads/{data['project_id']}/dicom/10.dcm",
]
project_detail = client.get(f"/api/projects/{data['project_id']}").json()
assert project_detail["name"] == "1.dcm"
def test_upload_dicom_batch_rejects_when_no_valid_dicom(client, monkeypatch):
monkeypatch.setattr("routers.media.upload_file", lambda *args, **kwargs: None)
response = client.post(
"/api/media/upload/dicom",
files=[
("files", ("notes.txt", b"text", "text/plain")),
],
)
assert response.status_code == 400
assert response.json()["detail"] == "No valid DICOM files uploaded"
def test_parse_media_queues_background_task(client, monkeypatch):
@@ -194,6 +216,101 @@ def test_parse_task_runner_registers_frames(client, db_session, monkeypatch, tmp
assert frames[0]["source_frame_number"] == 0
def test_parse_dicom_reads_files_in_natural_filename_order(monkeypatch, tmp_path):
from pathlib import Path
import numpy as np
from services.frame_parser import parse_dicom
dcm_dir = tmp_path / "dcm"
output_dir = tmp_path / "frames"
dcm_dir.mkdir()
for name in ["10.dcm", "2.dcm", "1.dcm"]:
(dcm_dir / name).write_bytes(b"dcm")
read_order = []
class FakeDicom:
pixel_array = np.ones((2, 2), dtype=np.uint8)
def fake_dcmread(path):
read_order.append(Path(path).name)
return FakeDicom()
def fake_imwrite(path, image, params=None):
Path(path).write_bytes(image.tobytes())
return True
monkeypatch.setattr("services.frame_parser.dcmread", fake_dcmread)
monkeypatch.setattr("services.frame_parser.cv2.imwrite", fake_imwrite)
frame_files = parse_dicom(str(dcm_dir), str(output_dir))
assert read_order == ["1.dcm", "2.dcm", "10.dcm"]
assert [Path(path).name for path in frame_files] == ["frame_000000.jpg", "frame_000001.jpg", "frame_000002.jpg"]
def test_parse_task_runner_downloads_dicom_objects_in_natural_filename_order(client, db_session, monkeypatch, tmp_path):
from types import SimpleNamespace
from models import ProcessingTask
from services.media_task_runner import run_parse_media_task
project = client.post("/api/projects", json={
"name": "DICOM",
"video_path": "uploads/1/dicom",
"source_type": "dicom",
"parse_fps": 30,
}).json()
task = ProcessingTask(
task_type="parse_dicom",
status="queued",
progress=0,
project_id=project["id"],
payload={"source_type": "dicom"},
)
db_session.add(task)
db_session.commit()
db_session.refresh(task)
class FakeClient:
def list_objects(self, bucket, prefix, recursive=True):
return [
SimpleNamespace(object_name=f"{prefix}/10.dcm"),
SimpleNamespace(object_name=f"{prefix}/2.dcm"),
SimpleNamespace(object_name=f"{prefix}/1.dcm"),
]
downloaded = []
frame_files = []
for idx in range(3):
frame_file = tmp_path / f"frame_{idx:06d}.jpg"
frame_file.write_bytes(b"fake image")
frame_files.append(str(frame_file))
monkeypatch.setattr("services.media_task_runner.get_minio_client", lambda: FakeClient())
monkeypatch.setattr(
"services.media_task_runner.download_file",
lambda object_name: downloaded.append(object_name) or b"dcm",
)
monkeypatch.setattr("services.media_task_runner.parse_dicom", lambda *args, **kwargs: frame_files)
monkeypatch.setattr(
"services.media_task_runner.upload_frames_to_minio",
lambda frames, project_id: [f"projects/{project_id}/frames/{idx}.jpg" for idx, _ in enumerate(frames)],
)
monkeypatch.setattr("services.media_task_runner.publish_task_progress_event", lambda task: None)
result = run_parse_media_task(db_session, task.id)
assert result["frames_extracted"] == 3
assert downloaded == [
"uploads/1/dicom/1.dcm",
"uploads/1/dicom/2.dcm",
"uploads/1/dicom/10.dcm",
]
def test_parse_task_runner_skips_already_cancelled_task(db_session):
from models import ProcessingTask
from services.media_task_runner import run_parse_media_task

View File

@@ -42,6 +42,9 @@ def test_project_crud_and_frames(client, monkeypatch):
assert updated.json()["name"] == "Renamed"
assert updated.json()["status"] == "ready"
empty_name = client.patch(f"/api/projects/{project_id}", json={"name": " "})
assert empty_name.status_code == 400
deleted = client.delete(f"/api/projects/{project_id}")
assert deleted.status_code == 204
assert client.get(f"/api/projects/{project_id}").status_code == 404
@@ -83,10 +86,97 @@ def test_delete_project_cascades_related_records(client, db_session):
assert db_session.query(ProcessingTask).filter(ProcessingTask.project_id == project_id).count() == 0
def test_copy_project_reset_copies_frame_sequence_without_annotations(client, db_session):
created = client.post("/api/projects", json={
"name": "Reset Source",
"description": "desc",
"video_path": "uploads/source.mp4",
"thumbnail_url": "thumb.jpg",
"status": "ready",
"parse_fps": 12,
})
assert created.status_code == 201
project_id = created.json()["id"]
frame = client.post(f"/api/projects/{project_id}/frames", json={
"project_id": project_id,
"frame_index": 0,
"image_url": "frames/source/frame_000000.jpg",
"width": 640,
"height": 360,
"timestamp_ms": 0,
"source_frame_number": 0,
})
assert frame.status_code == 201
annotation = client.post("/api/ai/annotate", json={
"project_id": project_id,
"frame_id": frame.json()["id"],
"mask_data": {"label": "Tumor", "polygons": [[[0.1, 0.1], [0.2, 0.1], [0.2, 0.2]]]},
})
assert annotation.status_code == 201
copied = client.post(f"/api/projects/{project_id}/copy", json={"mode": "reset"})
assert copied.status_code == 201
copied_body = copied.json()
assert copied_body["name"] == "Reset Source 副本"
assert copied_body["frame_count"] == 1
assert copied_body["video_path"] == "uploads/source.mp4"
assert copied_body["parse_fps"] == 12
copied_frames = db_session.query(Frame).filter(Frame.project_id == copied_body["id"]).all()
assert len(copied_frames) == 1
assert copied_frames[0].image_url == "frames/source/frame_000000.jpg"
assert db_session.query(Annotation).filter(Annotation.project_id == copied_body["id"]).count() == 0
def test_copy_project_full_copies_annotations_and_mask_metadata(client, db_session):
created = client.post("/api/projects", json={
"name": "Full Source",
"status": "ready",
})
assert created.status_code == 201
project_id = created.json()["id"]
frame = client.post(f"/api/projects/{project_id}/frames", json={
"project_id": project_id,
"frame_index": 0,
"image_url": "frames/source/frame_000000.jpg",
"width": 640,
"height": 360,
})
assert frame.status_code == 201
frame_id = frame.json()["id"]
annotation = client.post("/api/ai/annotate", json={
"project_id": project_id,
"frame_id": frame_id,
"mask_data": {"label": "Tumor", "polygons": [[[0.1, 0.1], [0.2, 0.1], [0.2, 0.2]]]},
"points": [[0.1, 0.1]],
"bbox": [0.1, 0.1, 0.1, 0.1],
})
assert annotation.status_code == 201
annotation_id = annotation.json()["id"]
db_session.add(Mask(annotation_id=annotation_id, mask_url="masks/source.png", format="png"))
db_session.commit()
copied = client.post(f"/api/projects/{project_id}/copy", json={"mode": "full"})
assert copied.status_code == 201
copied_body = copied.json()
copied_frames = db_session.query(Frame).filter(Frame.project_id == copied_body["id"]).all()
copied_annotations = db_session.query(Annotation).filter(Annotation.project_id == copied_body["id"]).all()
assert copied_body["name"] == "Full Source 副本"
assert len(copied_frames) == 1
assert len(copied_annotations) == 1
assert copied_annotations[0].id != annotation_id
assert copied_annotations[0].frame_id == copied_frames[0].id
assert copied_annotations[0].mask_data["label"] == "Tumor"
assert copied_annotations[0].bbox == [0.1, 0.1, 0.1, 0.1]
assert copied_annotations[0].masks[0].mask_url == "masks/source.png"
def test_project_and_frame_404s(client):
assert client.get("/api/projects/999").status_code == 404
assert client.patch("/api/projects/999", json={"name": "x"}).status_code == 404
assert client.delete("/api/projects/999").status_code == 404
assert client.post("/api/projects/999/copy", json={"mode": "reset"}).status_code == 404
assert client.post("/api/projects/999/frames", json={
"project_id": 999,
"frame_index": 0,

View File

@@ -37,3 +37,55 @@ def test_template_404s(client):
assert client.get("/api/templates/999").status_code == 404
assert client.patch("/api/templates/999", json={"name": "x"}).status_code == 404
assert client.delete("/api/templates/999").status_code == 404
def test_default_head_neck_ct_template_is_seeded_and_visible(client, db_session):
from main import ensure_default_templates
from models import Template
ensure_default_templates(db_session)
ensure_default_templates(db_session)
templates = db_session.query(Template).filter(Template.owner_user_id.is_(None)).all()
names = [template.name for template in templates]
assert names.count("头颈部CT分割") == 1
listing = client.get("/api/templates")
assert listing.status_code == 200
head_neck = next(template for template in listing.json() if template["name"] == "头颈部CT分割")
assert head_neck["description"] == "头颈部CT分割"
expected_names = [
"肿瘤/结节 (Tumor/Nodule)",
"下颌骨 (Mandible)",
"甲状腺 (Thyroid)",
"气管 (Trachea)",
"颈椎 (Cervical Spine)",
"颈动脉 (Carotid Artery)",
"颈静脉 (Jugular Vein)",
"腮腺 (Parotid Gland)",
"下颌下腺 (Submandibular Gland)",
"舌骨 (Hyoid Bone)",
"待分类",
]
expected_colors = [
"#ff0000",
"#00ff00",
"#0000ff",
"#ffff00",
"#ff00ff",
"#00ffff",
"#ff8000",
"#800080",
"#008080",
"#808000",
"#000000",
]
actual_names = [item["name"] for item in head_neck["classes"]]
actual_colors = [item["color"] for item in head_neck["classes"]]
actual_mask_ids = [item["maskId"] for item in head_neck["classes"]]
if actual_names != expected_names:
raise AssertionError(f"Unexpected head-neck classes: {actual_names}")
if actual_colors != expected_colors:
raise AssertionError(f"Unexpected head-neck colors: {actual_colors}")
if actual_mask_ids != [*list(range(1, 11)), 0]:
raise AssertionError(f"Unexpected head-neck mask IDs: {actual_mask_ids}")