完善项目导入、模板与分割工作区交互
- 增强 DICOM/视频项目导入与演示数据:DICOM 按文件名自然顺序处理,导入后展示上传与解析任务进度,恢复演示出厂设置保留演示视频和演示 DICOM 项目,并补充 demo media seed 逻辑。 - 完善项目管理:项目支持重命名、删除、复制,删除使用站内确认弹窗,复制支持新项目重置和全内容复制,DICOM 项目不显示生成帧入口。 - 完善 GT Mask 与导出链路:只支持 8-bit maskid 图导入,非法/全背景图明确拒绝,尺寸自动适配,高精度 polygon 回显;统一导出默认当前帧,GT_label 使用 uint8 和真实 maskid,待分类 maskid 0 与背景一致。 - 完善分割工作区交互:新增画笔和橡皮擦并支持尺寸控制,移除创建点/线段入口,工具栏按类别分隔,AI 智能分割使用明确 AI 图标,取消黄色 seed point,清空/删除传播 mask 后同步清理空帧时间轴状态。 - 完善传播与时间轴:自动传播使用 SAM 2.1 权重任务,参考帧无遮罩时提示,传播历史按同一蓝色系递进变暗,删除/清空传播链时保留人工或独立 AI 标注来源。 - 完善模板库:新增头颈部 CT 分割默认模板,所有模板保留 maskid 0 待分类,支持鼠标复制模板、拖拽层级、JSON 批量导入预览、删除 label 和站内删除确认。 - 完善用户与高风险确认:用户改密码、删除用户、恢复演示出厂设置和清空人工/AI 标注帧均改为站内确认交互,避免浏览器原生 prompt/confirm。 - 补充前后端测试与文档:更新项目、模板、GT 导入、导出、传播、DICOM、用户管理等测试,并同步 README、AGENTS 和 doc 下实现/契约/测试计划文档。
This commit is contained in:
@@ -39,6 +39,10 @@ from services.sam_registry import ModelUnavailableError, sam_registry
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter(prefix="/api/ai", tags=["AI"])
|
||||
GT_MASK_EMPTY_DETAIL = "GT Mask 图片中没有非背景 maskid 区域。"
|
||||
GT_IMPORT_MAX_CONTOUR_POINTS = 2048
|
||||
GT_IMPORT_CONTOUR_EPSILON_RATIO = 0.00075
|
||||
GT_IMPORT_MIN_CONTOUR_EPSILON = 0.35
|
||||
|
||||
|
||||
def _owned_project_or_404(project_id: int, db: Session, current_user: User) -> Project:
|
||||
@@ -152,13 +156,19 @@ def _load_frame_image(frame: Frame) -> np.ndarray:
|
||||
|
||||
|
||||
def _normalized_contour(contour: np.ndarray, width: int, height: int) -> list[list[float]]:
|
||||
"""Approximate a contour and convert it to normalized polygon coordinates."""
|
||||
"""Convert a contour to a detailed normalized polygon with a point-count cap."""
|
||||
arc_length = cv2.arcLength(contour, True)
|
||||
epsilon = max(1.0, arc_length * 0.01)
|
||||
epsilon = max(GT_IMPORT_MIN_CONTOUR_EPSILON, arc_length * GT_IMPORT_CONTOUR_EPSILON_RATIO)
|
||||
approx = cv2.approxPolyDP(contour, epsilon, True)
|
||||
while len(approx) > GT_IMPORT_MAX_CONTOUR_POINTS and epsilon < arc_length * 0.02:
|
||||
epsilon *= 1.5
|
||||
approx = cv2.approxPolyDP(contour, epsilon, True)
|
||||
points = approx.reshape(-1, 2)
|
||||
if len(points) < 3:
|
||||
points = contour.reshape(-1, 2)
|
||||
if len(points) > GT_IMPORT_MAX_CONTOUR_POINTS:
|
||||
step = int(math.ceil(len(points) / GT_IMPORT_MAX_CONTOUR_POINTS))
|
||||
points = points[::step]
|
||||
return [
|
||||
[
|
||||
min(max(float(x) / max(width, 1), 0.0), 1.0),
|
||||
@@ -977,6 +987,13 @@ async def import_gt_mask(
|
||||
if image is None:
|
||||
raise HTTPException(status_code=400, detail="Invalid mask image")
|
||||
|
||||
invalid_format_detail = (
|
||||
"GT Mask 图片不符合要求:仅支持 8-bit 灰度图,或 8-bit RGB 三通道完全相同的 maskid 图"
|
||||
"(背景 0,像素值为 1-255 的 maskid)。"
|
||||
)
|
||||
if image.dtype != np.uint8:
|
||||
raise HTTPException(status_code=400, detail=invalid_format_detail)
|
||||
|
||||
if image.ndim == 2:
|
||||
label_image = image
|
||||
elif image.ndim == 3 and image.shape[2] >= 3:
|
||||
@@ -984,16 +1001,10 @@ async def import_gt_mask(
|
||||
# GT label images are maskid maps: either grayscale or RGB/BGR where
|
||||
# all three color channels contain the same maskid value [X, X, X].
|
||||
if not (np.array_equal(channels[:, :, 0], channels[:, :, 1]) and np.array_equal(channels[:, :, 1], channels[:, :, 2])):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="GT Mask 图片不符合要求:请上传灰度图,或 RGB 三通道完全相同的 maskid 图(背景 0,像素值为 maskid)。",
|
||||
)
|
||||
raise HTTPException(status_code=400, detail=invalid_format_detail)
|
||||
label_image = channels[:, :, 0]
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="GT Mask 图片不符合要求:请上传灰度图,或 RGB 三通道完全相同的 maskid 图(背景 0,像素值为 maskid)。",
|
||||
)
|
||||
raise HTTPException(status_code=400, detail=invalid_format_detail)
|
||||
|
||||
width = int(frame.width or image.shape[1])
|
||||
height = int(frame.height or image.shape[0])
|
||||
@@ -1041,12 +1052,12 @@ async def import_gt_mask(
|
||||
if not import_items:
|
||||
if skipped_unknown > 0:
|
||||
raise HTTPException(status_code=400, detail="No matching GT mask classes found")
|
||||
raise HTTPException(status_code=400, detail="No foreground mask regions found")
|
||||
raise HTTPException(status_code=400, detail=GT_MASK_EMPTY_DETAIL)
|
||||
|
||||
annotations: list[Annotation] = []
|
||||
for item in import_items:
|
||||
binary = item["binary"]
|
||||
contours, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
||||
contours, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
|
||||
|
||||
for contour in contours:
|
||||
if cv2.contourArea(contour) < 1:
|
||||
@@ -1085,7 +1096,7 @@ async def import_gt_mask(
|
||||
annotations.append(annotation)
|
||||
|
||||
if not annotations:
|
||||
raise HTTPException(status_code=400, detail="No foreground mask regions found")
|
||||
raise HTTPException(status_code=400, detail=GT_MASK_EMPTY_DETAIL)
|
||||
|
||||
db.commit()
|
||||
for annotation in annotations:
|
||||
|
||||
Reference in New Issue
Block a user