完善项目导入、模板与分割工作区交互
- 增强 DICOM/视频项目导入与演示数据:DICOM 按文件名自然顺序处理,导入后展示上传与解析任务进度,恢复演示出厂设置保留演示视频和演示 DICOM 项目,并补充 demo media seed 逻辑。 - 完善项目管理:项目支持重命名、删除、复制,删除使用站内确认弹窗,复制支持新项目重置和全内容复制,DICOM 项目不显示生成帧入口。 - 完善 GT Mask 与导出链路:只支持 8-bit maskid 图导入,非法/全背景图明确拒绝,尺寸自动适配,高精度 polygon 回显;统一导出默认当前帧,GT_label 使用 uint8 和真实 maskid,待分类 maskid 0 与背景一致。 - 完善分割工作区交互:新增画笔和橡皮擦并支持尺寸控制,移除创建点/线段入口,工具栏按类别分隔,AI 智能分割使用明确 AI 图标,取消黄色 seed point,清空/删除传播 mask 后同步清理空帧时间轴状态。 - 完善传播与时间轴:自动传播使用 SAM 2.1 权重任务,参考帧无遮罩时提示,传播历史按同一蓝色系递进变暗,删除/清空传播链时保留人工或独立 AI 标注来源。 - 完善模板库:新增头颈部 CT 分割默认模板,所有模板保留 maskid 0 待分类,支持鼠标复制模板、拖拽层级、JSON 批量导入预览、删除 label 和站内删除确认。 - 完善用户与高风险确认:用户改密码、删除用户、恢复演示出厂设置和清空人工/AI 标注帧均改为站内确认交互,避免浏览器原生 prompt/confirm。 - 补充前后端测试与文档:更新项目、模板、GT 导入、导出、传播、DICOM、用户管理等测试,并同步 README、AGENTS 和 doc 下实现/契约/测试计划文档。
This commit is contained in:
@@ -9,7 +9,6 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from config import settings
|
||||
from database import get_db
|
||||
from minio_client import upload_file
|
||||
from models import Annotation, AuditLog, Frame, Mask, ProcessingTask, Project, Template, User
|
||||
from routers.auth import ensure_default_admin, hash_password, require_admin, write_audit_log
|
||||
from schemas import (
|
||||
@@ -20,13 +19,19 @@ from schemas import (
|
||||
DemoFactoryResetRequest,
|
||||
UserOut,
|
||||
)
|
||||
from statuses import PROJECT_STATUS_PENDING
|
||||
from services.demo_media import (
|
||||
DEMO_DICOM_PROJECT_NAME,
|
||||
DEMO_VIDEO_PROJECT_NAME,
|
||||
create_parsed_dicom_demo_project,
|
||||
create_unparsed_video_demo_project,
|
||||
demo_dicom_files,
|
||||
)
|
||||
|
||||
router = APIRouter(prefix="/api/admin", tags=["Admin"])
|
||||
|
||||
VALID_ROLES = {"admin", "annotator", "viewer"}
|
||||
DEMO_RESET_CONFIRMATION = "RESET_DEMO_FACTORY"
|
||||
DEMO_PROJECT_NAME = "Data_MyVideo_1"
|
||||
DEMO_PROJECT_NAME = DEMO_DICOM_PROJECT_NAME
|
||||
|
||||
|
||||
def _normalize_role(role: str | None) -> str:
|
||||
@@ -191,7 +196,7 @@ def reset_demo_factory(
|
||||
db: Session = Depends(get_db),
|
||||
admin_user: User = Depends(require_admin),
|
||||
) -> dict:
|
||||
"""Reset a demo deployment to one admin account and one unparsed demo video project."""
|
||||
"""Reset a demo deployment to one admin account, the demo video, and the demo DICOM project."""
|
||||
if payload.confirmation != DEMO_RESET_CONFIRMATION:
|
||||
raise HTTPException(status_code=400, detail="Invalid reset confirmation")
|
||||
|
||||
@@ -200,6 +205,11 @@ def reset_demo_factory(
|
||||
status_code=409,
|
||||
detail=f"Demo video not found: {settings.demo_video_path}",
|
||||
)
|
||||
if not demo_dicom_files(settings.demo_dicom_dir):
|
||||
raise HTTPException(
|
||||
status_code=409,
|
||||
detail=f"Demo DICOM series not found: {settings.demo_dicom_dir}",
|
||||
)
|
||||
|
||||
requested_by = admin_user.username
|
||||
preserved_admin = ensure_default_admin(db)
|
||||
@@ -226,37 +236,39 @@ def reset_demo_factory(
|
||||
if not preserved_admin:
|
||||
raise HTTPException(status_code=500, detail="Default admin was not preserved")
|
||||
|
||||
project = Project(
|
||||
name=DEMO_PROJECT_NAME,
|
||||
description="默认演示视频,尚未生成帧",
|
||||
status=PROJECT_STATUS_PENDING,
|
||||
source_type="video",
|
||||
parse_fps=30.0,
|
||||
owner_user_id=preserved_admin.id,
|
||||
video_project = create_unparsed_video_demo_project(
|
||||
db,
|
||||
owner=preserved_admin,
|
||||
video_path=settings.demo_video_path,
|
||||
project_name=DEMO_VIDEO_PROJECT_NAME,
|
||||
)
|
||||
db.add(project)
|
||||
db.flush()
|
||||
video_project.frame_count = 0
|
||||
|
||||
with open(settings.demo_video_path, "rb") as file_obj:
|
||||
data = file_obj.read()
|
||||
object_name = f"uploads/{project.id}/{os.path.basename(settings.demo_video_path)}"
|
||||
upload_file(object_name, data, content_type="video/mp4", length=len(data))
|
||||
project.video_path = object_name
|
||||
project.thumbnail_url = None
|
||||
project.original_fps = None
|
||||
db.commit()
|
||||
dicom_project = create_parsed_dicom_demo_project(
|
||||
db,
|
||||
owner=preserved_admin,
|
||||
dicom_dir=settings.demo_dicom_dir,
|
||||
project_name=DEMO_PROJECT_NAME,
|
||||
)
|
||||
db.refresh(preserved_admin)
|
||||
db.refresh(project)
|
||||
db.refresh(video_project)
|
||||
db.refresh(dicom_project)
|
||||
video_project.frame_count = len(video_project.frames)
|
||||
dicom_project.frame_count = len(dicom_project.frames)
|
||||
projects = [video_project, dicom_project]
|
||||
|
||||
write_audit_log(
|
||||
db,
|
||||
actor=preserved_admin,
|
||||
action="admin.demo_factory_reset",
|
||||
target_type="project",
|
||||
target_id=project.id,
|
||||
target_id=dicom_project.id,
|
||||
detail={
|
||||
"project_name": project.name,
|
||||
"video_path": project.video_path,
|
||||
"project_names": [project.name for project in projects],
|
||||
"video_path": video_project.video_path,
|
||||
"dicom_path": dicom_project.video_path,
|
||||
"source_types": [project.source_type for project in projects],
|
||||
"frame_counts": {project.name: len(project.frames) for project in projects},
|
||||
"deleted_counts": deleted_counts,
|
||||
"requested_by": requested_by,
|
||||
},
|
||||
@@ -264,7 +276,8 @@ def reset_demo_factory(
|
||||
|
||||
return {
|
||||
"admin_user": preserved_admin,
|
||||
"project": project,
|
||||
"project": dicom_project,
|
||||
"projects": projects,
|
||||
"deleted_counts": deleted_counts,
|
||||
"message": "演示环境已恢复出厂设置",
|
||||
}
|
||||
|
||||
@@ -39,6 +39,10 @@ from services.sam_registry import ModelUnavailableError, sam_registry
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter(prefix="/api/ai", tags=["AI"])
|
||||
GT_MASK_EMPTY_DETAIL = "GT Mask 图片中没有非背景 maskid 区域。"
|
||||
GT_IMPORT_MAX_CONTOUR_POINTS = 2048
|
||||
GT_IMPORT_CONTOUR_EPSILON_RATIO = 0.00075
|
||||
GT_IMPORT_MIN_CONTOUR_EPSILON = 0.35
|
||||
|
||||
|
||||
def _owned_project_or_404(project_id: int, db: Session, current_user: User) -> Project:
|
||||
@@ -152,13 +156,19 @@ def _load_frame_image(frame: Frame) -> np.ndarray:
|
||||
|
||||
|
||||
def _normalized_contour(contour: np.ndarray, width: int, height: int) -> list[list[float]]:
|
||||
"""Approximate a contour and convert it to normalized polygon coordinates."""
|
||||
"""Convert a contour to a detailed normalized polygon with a point-count cap."""
|
||||
arc_length = cv2.arcLength(contour, True)
|
||||
epsilon = max(1.0, arc_length * 0.01)
|
||||
epsilon = max(GT_IMPORT_MIN_CONTOUR_EPSILON, arc_length * GT_IMPORT_CONTOUR_EPSILON_RATIO)
|
||||
approx = cv2.approxPolyDP(contour, epsilon, True)
|
||||
while len(approx) > GT_IMPORT_MAX_CONTOUR_POINTS and epsilon < arc_length * 0.02:
|
||||
epsilon *= 1.5
|
||||
approx = cv2.approxPolyDP(contour, epsilon, True)
|
||||
points = approx.reshape(-1, 2)
|
||||
if len(points) < 3:
|
||||
points = contour.reshape(-1, 2)
|
||||
if len(points) > GT_IMPORT_MAX_CONTOUR_POINTS:
|
||||
step = int(math.ceil(len(points) / GT_IMPORT_MAX_CONTOUR_POINTS))
|
||||
points = points[::step]
|
||||
return [
|
||||
[
|
||||
min(max(float(x) / max(width, 1), 0.0), 1.0),
|
||||
@@ -977,6 +987,13 @@ async def import_gt_mask(
|
||||
if image is None:
|
||||
raise HTTPException(status_code=400, detail="Invalid mask image")
|
||||
|
||||
invalid_format_detail = (
|
||||
"GT Mask 图片不符合要求:仅支持 8-bit 灰度图,或 8-bit RGB 三通道完全相同的 maskid 图"
|
||||
"(背景 0,像素值为 1-255 的 maskid)。"
|
||||
)
|
||||
if image.dtype != np.uint8:
|
||||
raise HTTPException(status_code=400, detail=invalid_format_detail)
|
||||
|
||||
if image.ndim == 2:
|
||||
label_image = image
|
||||
elif image.ndim == 3 and image.shape[2] >= 3:
|
||||
@@ -984,16 +1001,10 @@ async def import_gt_mask(
|
||||
# GT label images are maskid maps: either grayscale or RGB/BGR where
|
||||
# all three color channels contain the same maskid value [X, X, X].
|
||||
if not (np.array_equal(channels[:, :, 0], channels[:, :, 1]) and np.array_equal(channels[:, :, 1], channels[:, :, 2])):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="GT Mask 图片不符合要求:请上传灰度图,或 RGB 三通道完全相同的 maskid 图(背景 0,像素值为 maskid)。",
|
||||
)
|
||||
raise HTTPException(status_code=400, detail=invalid_format_detail)
|
||||
label_image = channels[:, :, 0]
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="GT Mask 图片不符合要求:请上传灰度图,或 RGB 三通道完全相同的 maskid 图(背景 0,像素值为 maskid)。",
|
||||
)
|
||||
raise HTTPException(status_code=400, detail=invalid_format_detail)
|
||||
|
||||
width = int(frame.width or image.shape[1])
|
||||
height = int(frame.height or image.shape[0])
|
||||
@@ -1041,12 +1052,12 @@ async def import_gt_mask(
|
||||
if not import_items:
|
||||
if skipped_unknown > 0:
|
||||
raise HTTPException(status_code=400, detail="No matching GT mask classes found")
|
||||
raise HTTPException(status_code=400, detail="No foreground mask regions found")
|
||||
raise HTTPException(status_code=400, detail=GT_MASK_EMPTY_DETAIL)
|
||||
|
||||
annotations: list[Annotation] = []
|
||||
for item in import_items:
|
||||
binary = item["binary"]
|
||||
contours, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
||||
contours, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
|
||||
|
||||
for contour in contours:
|
||||
if cv2.contourArea(contour) < 1:
|
||||
@@ -1085,7 +1096,7 @@ async def import_gt_mask(
|
||||
annotations.append(annotation)
|
||||
|
||||
if not annotations:
|
||||
raise HTTPException(status_code=400, detail="No foreground mask regions found")
|
||||
raise HTTPException(status_code=400, detail=GT_MASK_EMPTY_DETAIL)
|
||||
|
||||
db.commit()
|
||||
for annotation in annotations:
|
||||
|
||||
@@ -64,7 +64,7 @@ def _annotation_mask_id(annotation: Annotation) -> int | None:
|
||||
value = int(class_meta[key])
|
||||
except (TypeError, ValueError):
|
||||
continue
|
||||
if value > 0:
|
||||
if value >= 0:
|
||||
return value
|
||||
return None
|
||||
|
||||
@@ -361,7 +361,7 @@ def _build_gt_class_mapping(annotations: list[Annotation]) -> tuple[dict[str, in
|
||||
ordered = sorted(
|
||||
entries_by_key.values(),
|
||||
key=lambda item: (
|
||||
item["maskidHint"] if isinstance(item.get("maskidHint"), int) and item["maskidHint"] > 0 else 10_000_000,
|
||||
item["maskidHint"] if isinstance(item.get("maskidHint"), int) and item["maskidHint"] >= 0 else 10_000_000,
|
||||
str(item["className"]),
|
||||
str(item["key"]),
|
||||
),
|
||||
@@ -375,6 +375,8 @@ def _build_gt_class_mapping(annotations: list[Annotation]) -> tuple[dict[str, in
|
||||
nonlocal next_maskid
|
||||
while next_maskid in used_maskids:
|
||||
next_maskid += 1
|
||||
if next_maskid > 255:
|
||||
raise HTTPException(status_code=400, detail="GT_label 仅支持 8-bit maskid,类别值必须在 1-255 之间")
|
||||
value = next_maskid
|
||||
used_maskids.add(value)
|
||||
next_maskid += 1
|
||||
@@ -382,7 +384,12 @@ def _build_gt_class_mapping(annotations: list[Annotation]) -> tuple[dict[str, in
|
||||
|
||||
for entry in ordered:
|
||||
hinted_maskid = entry.get("maskidHint")
|
||||
if isinstance(hinted_maskid, int) and hinted_maskid > 0 and hinted_maskid not in used_maskids:
|
||||
if isinstance(hinted_maskid, int) and hinted_maskid > 255:
|
||||
raise HTTPException(status_code=400, detail="GT_label 仅支持 8-bit maskid,类别值必须在 1-255 之间")
|
||||
if isinstance(hinted_maskid, int) and hinted_maskid == 0:
|
||||
maskid = 0
|
||||
used_maskids.add(maskid)
|
||||
elif isinstance(hinted_maskid, int) and 0 < hinted_maskid <= 255 and hinted_maskid not in used_maskids:
|
||||
maskid = hinted_maskid
|
||||
used_maskids.add(maskid)
|
||||
else:
|
||||
@@ -513,7 +520,7 @@ def _write_result_mask_outputs(
|
||||
)
|
||||
|
||||
needs_fused_output = include_semantic or include_pro_label or include_mix_label
|
||||
semantic = np.zeros((height, width), dtype=np.uint16) if needs_fused_output else None
|
||||
semantic = np.zeros((height, width), dtype=np.uint8) if needs_fused_output else None
|
||||
pro_label = np.zeros((height, width, 3), dtype=np.uint8) if (include_pro_label or include_mix_label) else None
|
||||
|
||||
if needs_fused_output:
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""Media upload and parsing endpoints."""
|
||||
|
||||
import logging
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import List, Optional
|
||||
|
||||
@@ -22,6 +23,13 @@ router = APIRouter(prefix="/api/media", tags=["Media"])
|
||||
ALLOWED_EXTENSIONS = {".mp4", ".avi", ".mov", ".mkv", ".webm", ".png", ".jpg", ".jpeg", ".dcm"}
|
||||
|
||||
|
||||
def natural_filename_key(filename: str) -> tuple[object, ...]:
|
||||
return tuple(
|
||||
int(part) if part.isdigit() else part.casefold()
|
||||
for part in re.split(r"(\d+)", Path(filename).name)
|
||||
)
|
||||
|
||||
|
||||
def _get_ext(filename: str) -> str:
|
||||
return Path(filename).suffix.lower()
|
||||
|
||||
@@ -124,6 +132,12 @@ async def upload_dicom_batch(
|
||||
if not files:
|
||||
raise HTTPException(status_code=400, detail="No files uploaded")
|
||||
|
||||
sorted_files = sorted(
|
||||
[file for file in files if file.filename and file.filename.lower().endswith(".dcm")],
|
||||
key=lambda file: natural_filename_key(file.filename or ""),
|
||||
)
|
||||
if not sorted_files:
|
||||
raise HTTPException(status_code=400, detail="No valid DICOM files uploaded")
|
||||
uploaded = []
|
||||
|
||||
if project_id:
|
||||
@@ -135,10 +149,10 @@ async def upload_dicom_batch(
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
else:
|
||||
# Create new DICOM project
|
||||
first_name = files[0].filename or "DICOM_Series"
|
||||
first_name = sorted_files[0].filename or "DICOM_Series"
|
||||
project = Project(
|
||||
name=first_name,
|
||||
description=f"DICOM series with {len(files)} files",
|
||||
description=f"DICOM series with {len(sorted_files)} files",
|
||||
status=PROJECT_STATUS_PENDING,
|
||||
source_type="dicom",
|
||||
owner_user_id=current_user.id,
|
||||
@@ -149,9 +163,7 @@ async def upload_dicom_batch(
|
||||
project_id = project.id
|
||||
logger.info("Auto-created DICOM project id=%s", project_id)
|
||||
|
||||
for file in files:
|
||||
if not file.filename or not file.filename.lower().endswith(".dcm"):
|
||||
continue
|
||||
for file in sorted_files:
|
||||
data = await file.read()
|
||||
object_name = f"uploads/{project_id}/dicom/{file.filename}"
|
||||
try:
|
||||
|
||||
@@ -7,15 +7,38 @@ from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from database import get_db
|
||||
from models import Project, Frame, User
|
||||
from models import Annotation, Mask, Project, Frame, User
|
||||
from routers.auth import get_current_user, require_editor
|
||||
from schemas import ProjectCreate, ProjectOut, ProjectUpdate, FrameCreate, FrameOut
|
||||
from schemas import ProjectCopyRequest, ProjectCreate, ProjectOut, ProjectUpdate, FrameCreate, FrameOut
|
||||
from minio_client import get_presigned_url
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter(prefix="/api/projects", tags=["Projects"])
|
||||
|
||||
|
||||
def _next_project_copy_name(db: Session, owner_user_id: int, source_name: str) -> str:
|
||||
base_name = f"{source_name} 副本"
|
||||
existing_names = {
|
||||
row[0]
|
||||
for row in db.query(Project.name)
|
||||
.filter(Project.owner_user_id == owner_user_id, Project.name.like(f"{base_name}%"))
|
||||
.all()
|
||||
}
|
||||
if base_name not in existing_names:
|
||||
return base_name
|
||||
suffix = 2
|
||||
while f"{base_name} {suffix}" in existing_names:
|
||||
suffix += 1
|
||||
return f"{base_name} {suffix}"
|
||||
|
||||
|
||||
def _prepare_project_response(project: Project) -> Project:
|
||||
project.frame_count = len(project.frames)
|
||||
if project.thumbnail_url:
|
||||
project.thumbnail_url = get_presigned_url(project.thumbnail_url, expires=3600)
|
||||
return project
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Projects
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -59,9 +82,7 @@ def list_projects(
|
||||
.all()
|
||||
)
|
||||
for p in projects:
|
||||
p.frame_count = len(p.frames)
|
||||
if p.thumbnail_url:
|
||||
p.thumbnail_url = get_presigned_url(p.thumbnail_url, expires=3600)
|
||||
_prepare_project_response(p)
|
||||
return projects
|
||||
|
||||
|
||||
@@ -82,10 +103,85 @@ def get_project(
|
||||
).first()
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
project.frame_count = len(project.frames)
|
||||
if project.thumbnail_url:
|
||||
project.thumbnail_url = get_presigned_url(project.thumbnail_url, expires=3600)
|
||||
return project
|
||||
return _prepare_project_response(project)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/{project_id}/copy",
|
||||
response_model=ProjectOut,
|
||||
status_code=status.HTTP_201_CREATED,
|
||||
summary="Copy a project",
|
||||
)
|
||||
def copy_project(
|
||||
project_id: int,
|
||||
payload: ProjectCopyRequest,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_editor),
|
||||
) -> Project:
|
||||
"""Copy a project. Reset copies media/frame sequence; full also copies annotations and mask metadata."""
|
||||
source = db.query(Project).filter(
|
||||
Project.id == project_id,
|
||||
Project.owner_user_id == current_user.id,
|
||||
).first()
|
||||
if not source:
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
|
||||
next_name = (payload.name or "").strip() if payload.name is not None else ""
|
||||
if not next_name:
|
||||
next_name = _next_project_copy_name(db, current_user.id, source.name)
|
||||
|
||||
copied = Project(
|
||||
name=next_name,
|
||||
description=source.description,
|
||||
video_path=source.video_path,
|
||||
thumbnail_url=source.thumbnail_url,
|
||||
status=source.status,
|
||||
source_type=source.source_type,
|
||||
original_fps=source.original_fps,
|
||||
parse_fps=source.parse_fps,
|
||||
owner_user_id=current_user.id,
|
||||
)
|
||||
db.add(copied)
|
||||
db.flush()
|
||||
|
||||
frame_id_map: dict[int, int] = {}
|
||||
for frame in sorted(source.frames, key=lambda item: item.frame_index):
|
||||
copied_frame = Frame(
|
||||
project_id=copied.id,
|
||||
frame_index=frame.frame_index,
|
||||
image_url=frame.image_url,
|
||||
width=frame.width,
|
||||
height=frame.height,
|
||||
timestamp_ms=frame.timestamp_ms,
|
||||
source_frame_number=frame.source_frame_number,
|
||||
)
|
||||
db.add(copied_frame)
|
||||
db.flush()
|
||||
frame_id_map[frame.id] = copied_frame.id
|
||||
|
||||
if payload.mode == "full":
|
||||
for annotation in sorted(source.annotations, key=lambda item: item.id):
|
||||
copied_annotation = Annotation(
|
||||
project_id=copied.id,
|
||||
frame_id=frame_id_map.get(annotation.frame_id) if annotation.frame_id is not None else None,
|
||||
template_id=annotation.template_id,
|
||||
mask_data=annotation.mask_data,
|
||||
points=annotation.points,
|
||||
bbox=annotation.bbox,
|
||||
)
|
||||
db.add(copied_annotation)
|
||||
db.flush()
|
||||
for mask in annotation.masks:
|
||||
db.add(Mask(
|
||||
annotation_id=copied_annotation.id,
|
||||
mask_url=mask.mask_url,
|
||||
format=mask.format,
|
||||
))
|
||||
|
||||
db.commit()
|
||||
db.refresh(copied)
|
||||
logger.info("Copied project id=%s to id=%s mode=%s", project_id, copied.id, payload.mode)
|
||||
return _prepare_project_response(copied)
|
||||
|
||||
|
||||
@router.patch(
|
||||
@@ -108,6 +204,10 @@ def update_project(
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
|
||||
for key, value in payload.model_dump(exclude_unset=True).items():
|
||||
if key == "name":
|
||||
value = (value or "").strip()
|
||||
if not value:
|
||||
raise HTTPException(status_code=400, detail="Project name is required")
|
||||
setattr(project, key, value)
|
||||
|
||||
db.commit()
|
||||
|
||||
@@ -14,15 +14,38 @@ from schemas import TemplateCreate, TemplateOut, TemplateUpdate
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter(prefix="/api/templates", tags=["Templates"])
|
||||
RESERVED_UNCLASSIFIED_CLASS = {
|
||||
"id": "reserved-unclassified",
|
||||
"name": "待分类",
|
||||
"color": "#000000",
|
||||
"zIndex": 0,
|
||||
"maskId": 0,
|
||||
"category": "系统保留",
|
||||
}
|
||||
|
||||
|
||||
def _is_reserved_class(item: dict) -> bool:
|
||||
return (
|
||||
item.get("id") == RESERVED_UNCLASSIFIED_CLASS["id"]
|
||||
or item.get("name") == RESERVED_UNCLASSIFIED_CLASS["name"]
|
||||
or item.get("maskId") == 0
|
||||
)
|
||||
|
||||
|
||||
def _normalize_template_classes(classes: list[dict] | None) -> list[dict]:
|
||||
normalized = [item for item in (classes or []) if not _is_reserved_class(item)]
|
||||
return [*normalized, dict(RESERVED_UNCLASSIFIED_CLASS)]
|
||||
|
||||
|
||||
def _pack_mapping_rules(data: dict) -> dict:
|
||||
"""Pack classes/rules into mapping_rules for DB storage."""
|
||||
mapping = data.get("mapping_rules") or {}
|
||||
if "classes" in data and data["classes"] is not None:
|
||||
mapping["classes"] = data.pop("classes")
|
||||
mapping["classes"] = _normalize_template_classes(data.pop("classes"))
|
||||
if "rules" in data and data["rules"] is not None:
|
||||
mapping["rules"] = data.pop("rules")
|
||||
if "classes" in mapping:
|
||||
mapping["classes"] = _normalize_template_classes(mapping.get("classes"))
|
||||
data["mapping_rules"] = mapping
|
||||
return data
|
||||
|
||||
@@ -31,7 +54,7 @@ def _unpack_template(template: Template) -> Template:
|
||||
"""Unpack mapping_rules into classes/rules for response."""
|
||||
mapping = template.mapping_rules or {}
|
||||
# Set as attributes so Pydantic from_attributes can pick them up
|
||||
template.classes = mapping.get("classes", [])
|
||||
template.classes = _normalize_template_classes(mapping.get("classes", []))
|
||||
template.rules = mapping.get("rules", [])
|
||||
return template
|
||||
|
||||
|
||||
Reference in New Issue
Block a user