first commit

This commit is contained in:
admin
2026-05-20 15:05:35 +08:00
commit ac09b26253
2048 changed files with 189478 additions and 0 deletions

View File

@@ -0,0 +1,564 @@
import logging, argparse, sys, utils
import shutil, tempfile
from pathlib import Path
from typing import Dict, List, Tuple
import albumentations as A
import cv2
import matplotlib.pyplot as plt
import numpy as np
import segmentation_models_pytorch as smp
import torch
from albumentations.pytorch import ToTensorV2
from scipy.special import softmax
from sklearn.metrics import (auc, average_precision_score,
precision_recall_curve, roc_curve)
from tqdm import tqdm
# 本地应用/库的导入
import config
from utils import log_metrics_to_csv, get_preprocessing_transform, setup_directories
# --- 日志设置 ---
# 配置日志记录器
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s [%(levelname)s] %(message)s",
handlers=[logging.StreamHandler()] # 输出日志到控制台
)
# 为 matplotlib 设置中文字体,请根据您的系统环境选择合适的字体
plt.rcParams['font.sans-serif'] = ['Noto Sans CJK JP', 'SimHei', 'Microsoft YaHei'] # 例如SimHei, Microsoft YaHei
plt.rcParams['axes.unicode_minus'] = False # 解决负号显示问题
# --- 核心工具函数 ---
# 将单通道的类别索引掩码转换为 RGB 彩色掩码
def colorize_mask(mask: np.ndarray, class_rgb_values: List[Tuple[int, int, int]]) -> np.ndarray:
"""将单通道的类别索引掩码转换为 RGB 彩色掩码。"""
color_mask = np.zeros((mask.shape[0], mask.shape[1], 3), dtype=np.uint8)
# 一个通用的调色板,可以轻松扩展
palette = [
(0, 0, 0), (220, 20, 60), (0, 255, 0), (0, 0, 255),
(255, 255, 0), (255, 0, 255), (0, 255, 255), (244, 164, 96),
]
for class_idx, color in enumerate(palette):
if class_idx < len(class_rgb_values):
# 假设掩码中的像素值 [0, 1, 2...] 对应类别的索引
pixels_to_color = (mask == class_idx)
color_mask[pixels_to_color] = color
return color_mask
# --- 分析与可视化函数 ---
def generate_and_save_curves(y_true: np.ndarray, y_probs: np.ndarray, class_names: List[str], output_dir: Path) -> None:
"""
根据整个测试集的聚合预测结果,为每个类别生成并保存 ROC 和 PR 曲线。
"""
logging.info("正在生成 ROC 和 PR 曲线...")
num_classes = len(class_names)
plt.style.use('seaborn-v0_8-whitegrid')
# --- 准备画布 ---
fig_roc, ax_roc = plt.subplots(figsize=(10, 8))
ax_roc.plot([0, 1], [0, 1], 'k--', label='随机猜测')
fig_pr, ax_pr = plt.subplots(figsize=(10, 8))
# --- 为每个类别(跳过背景)生成曲线 ---
for i in range(1, num_classes):
class_name = class_names[i]
y_true_binary = (y_true == i).astype(int)
y_class_probs = y_probs[:, i]
# ROC 曲线
fpr, tpr, _ = roc_curve(y_true_binary, y_class_probs)
roc_auc = auc(fpr, tpr)
ax_roc.plot(fpr, tpr, label=f'类别 {class_name} (AUC = {roc_auc:.4f})')
# PR 曲线
precision, recall, _ = precision_recall_curve(y_true_binary, y_class_probs)
avg_precision = average_precision_score(y_true_binary, y_class_probs)
ax_pr.plot(recall, precision, label=f'类别 {class_name} (AP = {avg_precision:.4f})')
# --- 美化并保存 ROC 曲线图 ---
ax_roc.set(xlabel='False positive rate (FPR)', ylabel='True positive rate (TPR)', title='Receiver operating characteristic (ROC) curve')
ax_roc.legend(loc='lower right')
ax_roc.set_aspect('equal', adjustable='box')
fig_roc.tight_layout()
roc_save_path = output_dir / "ROC_Curves.png"
fig_roc.savefig(roc_save_path)
plt.close(fig_roc)
logging.info(f"-> ROC 曲线图已保存至: {roc_save_path}")
# --- 美化并保存 PR 曲线图 ---
ax_pr.set(xlabel='Recall', ylabel='Precision', title='Precision-Recall (PR) Curve')
ax_pr.legend(loc='lower left')
ax_pr.set_aspect('equal', adjustable='box')
fig_pr.tight_layout()
pr_save_path = output_dir / "Precision_Recall_Curves.png"
fig_pr.savefig(pr_save_path)
plt.close(fig_pr)
logging.info(f"-> PR 曲线图已保存至: {pr_save_path}")
def save_visual_comparison(
image_name: str, original_image: np.ndarray, pred_mask: np.ndarray,
gt_mask: np.ndarray = None, stats: Dict = None, stats_text: str = "",
save_dir: str = ""
) -> None:
"""保存原始图像、预测掩码以及可选的真实掩码的对比图。"""
pred_color_mask = colorize_mask(pred_mask, config.CLASS_RGB_VALUES)
if gt_mask is not None and stats is not None:
num_plots = 3
figsize = (22, 8)
gt_color_mask = colorize_mask(gt_mask, config.CLASS_RGB_VALUES)
save_name = f"{Path(image_name).stem}-IOU_{stats['iou']:.4f}-ACC_{stats['acc']:.4f}.png"
else:
num_plots = 2
figsize = (14, 6)
save_name = f"{Path(image_name).stem}.png"
fig, axes = plt.subplots(1, num_plots, figsize=figsize)
fig.suptitle(f"Predict: {image_name}", fontsize=16)
axes[0].imshow(original_image)
axes[0].set_title("ORI_IMG")
axes[0].axis('off')
pred_title = "Predicted Mask"
if stats:
pred_title += f"\nIoU: {stats['iou']:.4f} | Acc: {stats['acc']:.4f}"
axes[1].imshow(pred_color_mask)
axes[1].set_title(pred_title)
axes[1].axis('off')
if num_plots == 3:
axes[2].imshow(gt_color_mask)
axes[2].set_title("Ground Truth")
axes[2].axis('off')
plt.figtext(0.5, 0.01, stats_text, ha="center", fontsize=10,
bbox={"facecolor":"white", "alpha":0.7, "pad":5}, family='monospace')
fig.subplots_adjust(bottom=0.3)
plt.tight_layout()
plt.savefig(save_dir / save_name)
plt.close(fig)
# --- 主要逻辑函数 ---
def load_model(model_architecture: str, model_save_path: Path) -> torch.nn.Module:
"""
加载已训练的分割模型。
此函数经过优化,可以智能处理由 torch.nn.DataParallel 保存的
(带有 'module.' 前缀) 和常规保存的模型权重。
"""
logging.info(f"正在从以下路径加载模型: {model_save_path}")
num_classes = len(config.CLASSES)
# 1. 【修改】根据 config 动态创建基础模型架构
logging.info(f"正在重建模型架构: '{model_architecture}'")
try:
model_params = config.ALL_MODEL_CONFIGS[model_architecture].copy() #
model_class = getattr(smp, model_architecture)
except AttributeError:
logging.error(f"模型 '{model_architecture}' 在 segmentation_models_pytorch 库中不存在!")
raise
# 2. 【修改】准备参数,与训练时保持一致
# 注意:加载权重时,我们不希望再次下载预训练权重,所以设为 None
params = model_params
params['in_channels'] = 3
params['classes'] = num_classes
params['encoder_weights'] = None # 非常重要避免在预测时重新下载ImageNet权重
# ======================== 【新增代码段开始】 ======================== #
# 自动检测 encoder_name 是否包含 'vit'
encoder_name = model_params.get('encoder_name', '')
if 'vit' in encoder_name.lower():
params['dynamic_img_size'] = True
logging.info(f"检测到 ViT 编码器 ('{encoder_name}')。自动设置 dynamic_img_size=True。")
# ======================== 【新增代码段结束】 ======================== #
# 3. 【修改】使用 **kwargs 创建模型实例
model = model_class(**params).to(config.DEVICE)
# 4. 从文件加载 state_dict
state_dict = torch.load(model_save_path, map_location=torch.device(config.DEVICE))
# 3. 检查是否存在 'module.' 前缀 (判断是否为 DataParallel 保存的权重)
# 我们通过检查字典中任何一个键是否以 'module.' 开头来判断
is_dataparallel = any(key.startswith('module.') for key in state_dict.keys())
if is_dataparallel:
logging.info("检测到模型由 DataParallel 保存,将移除 'module.' 前缀。")
# 创建一个新的、不带前缀的 state_dict
new_state_dict = {key.replace('module.', ''): value for key, value in state_dict.items()}
model.load_state_dict(new_state_dict)
else:
# 如果没有 'module.' 前缀,直接加载
logging.info("直接加载标准模型权重。")
model.load_state_dict(state_dict)
model.eval()
return model
def calculate_and_save_final_metrics(stats_list: List[Dict], num_classes: int, save_dir: str) -> None:
"""聚合每张图像的统计数据,计算总体指标,并保存到 CSV 文件。"""
if not stats_list:
logging.warning("未收集到统计数据,跳过最终指标计算。")
return
logging.info("\n正在聚合结果以计算最终指标...")
# 堆叠所有统计张量并按批次维度求和
tp = torch.cat([s['tp'] for s in stats_list], dim=0).sum(dim=0) # (N×B, C) -> (C,)
fp = torch.cat([s['fp'] for s in stats_list], dim=0).sum(dim=0) # (N×B, C) -> (C,)
fn = torch.cat([s['fn'] for s in stats_list], dim=0).sum(dim=0) # (N×B, C) -> (C,)
tn = torch.cat([s['tn'] for s in stats_list], dim=0).sum(dim=0) # (N×B, C) -> (C,)
# ======================== 【新增代码段开始】 ======================== #
# 根据 config.py 中的设置,筛选出用于计算宏观指标的类别
ignore_indices = [config.CLASSES.index(cls) for cls in config.EVALUATION_CLASSES_TO_IGNORE if cls in config.CLASSES]
keep_mask = torch.ones(num_classes, dtype=torch.bool, device=tp.device)
if ignore_indices:
keep_mask[ignore_indices] = False
logging.info(f"最终评估时将忽略类别: {config.EVALUATION_CLASSES_TO_IGNORE}")
tp_filtered = tp[keep_mask]
fp_filtered = fp[keep_mask]
fn_filtered = fn[keep_mask]
tn_filtered = tn[keep_mask]
# ======================== 【新增代码段结束】 ======================== #
# ======================== 【修改代码段开始】 ======================== #
# 使用过滤后的统计数据计算总体指标
metrics = {
"iou_score": smp.metrics.iou_score(tp_filtered, fp_filtered, fn_filtered, tn_filtered, reduction='micro').item(),
"f1_score": smp.metrics.f1_score(tp_filtered, fp_filtered, fn_filtered, tn_filtered, reduction='micro').item(),
"accuracy": smp.metrics.accuracy(tp_filtered, fp_filtered, fn_filtered, tn_filtered, reduction='micro').item(),
"recall": smp.metrics.recall(tp_filtered, fp_filtered, fn_filtered, tn_filtered, reduction='micro').item(),
"precision": smp.metrics.precision(tp_filtered, fp_filtered, fn_filtered, tn_filtered, reduction='micro').item(),
}
# ======================== 【修改代码段结束】 ======================== #
# 添加每个类别的原始统计数据(使用未经过滤的数据)
for i in range(num_classes):
metrics[f'tp_class_{i}'] = tp[i].item()
metrics[f'tn_class_{i}'] = tn[i].item()
metrics[f'fp_class_{i}'] = fp[i].item()
metrics[f'fn_class_{i}'] = fn[i].item()
csv_save_path = save_dir / "test_set_metrics.csv"
log_metrics_to_csv(metrics, str(csv_save_path))
logging.info(f"-> 整体测试集指标已保存至: {csv_save_path}")
# 2. 【新增】一个帮助函数,用于交互式选择模型目录
def select_trained_model_path(base_dir: Path, model_architecture: str) -> Path:
"""
查找指定架构的所有训练运行目录,并让用户选择一个。
"""
logging.info(f"正在 '{base_dir}' 中搜索模型 '{model_architecture}' 的训练记录...")
# 查找所有以架构名开头的文件夹
run_dirs = sorted([d for d in base_dir.iterdir() if d.is_dir() and d.name.startswith(model_architecture+'_')])
if not run_dirs:
logging.error(f"未找到任何 '{model_architecture}' 模型的训练记录。请先运行 train.py。")
return None
print("\n请选择要用于预测的模型:")
for i, dir_path in enumerate(run_dirs):
print(f" [{i+1}] {dir_path.name}")
while True:
try:
choice = input(f"请输入选项编号 (1-{len(run_dirs)}) 或按 Enter 退出: ")
if not choice:
return None
choice_idx = int(choice) - 1
if 0 <= choice_idx < len(run_dirs):
selected_dir = run_dirs[choice_idx]
logging.info(f"已选择模型: {selected_dir}")
return selected_dir
else:
print("无效的选项,请重新输入。")
except (ValueError, IndexError):
print("无效的输入,请输入一个数字。")
# 【修改】函数签名和内部逻辑,不再接收一个列表,而是直接接收聚合后的统计数据
def calculate_and_save_final_metrics(tp, fp, fn, tn, num_classes: int, save_dir: Path) -> None:
"""根据聚合后的统计数据,计算总体指标,并保存到 CSV 文件。"""
logging.info("\n正在计算最终的整体指标...")
# 根据 config.py 中的设置,筛选出用于计算宏观指标的类别
ignore_indices = [config.CLASSES.index(cls) for cls in config.EVALUATION_CLASSES_TO_IGNORE if cls in config.CLASSES]
keep_mask = torch.ones(num_classes, dtype=torch.bool, device=tp.device)
if ignore_indices:
keep_mask[ignore_indices] = False
logging.info(f"最终评估时将忽略类别: {config.EVALUATION_CLASSES_TO_IGNORE}")
tp_filtered = tp[keep_mask]
fp_filtered = fp[keep_mask]
fn_filtered = fn[keep_mask]
tn_filtered = tn[keep_mask]
# 使用过滤后的统计数据计算总体指标
metrics = {
"iou_score": smp.metrics.iou_score(tp_filtered, fp_filtered, fn_filtered, tn_filtered, reduction='micro').item(),
"f1_score": smp.metrics.f1_score(tp_filtered, fp_filtered, fn_filtered, tn_filtered, reduction='micro').item(),
"accuracy": smp.metrics.accuracy(tp_filtered, fp_filtered, fn_filtered, tn_filtered, reduction='micro').item(),
"recall": smp.metrics.recall(tp_filtered, fp_filtered, fn_filtered, tn_filtered, reduction='micro').item(),
"precision": smp.metrics.precision(tp_filtered, fp_filtered, fn_filtered, tn_filtered, reduction='micro').item(),
}
# 添加每个类别的原始统计数据(使用未经过滤的数据)
for i in range(num_classes):
metrics[f'tp_class_{i}'] = tp[i].item()
metrics[f'tn_class_{i}'] = tn[i].item()
metrics[f'fp_class_{i}'] = fp[i].item()
metrics[f'fn_class_{i}'] = fn[i].item()
csv_save_path = save_dir / "test_set_metrics.csv"
log_metrics_to_csv(metrics, str(csv_save_path))
logging.info(f"-> 整体测试集指标已保存至: {csv_save_path}")
def main(model_architecture: str) -> None:
"""
主函数,执行模型加载、图像预测、评估,并生成分析结果。
"""
# --- 1. 【修改】选择模型并动态设置路径 ---
selected_run_dir = select_trained_model_path(Path(config.PREDICT_BEST_MODEL_DIR), model_architecture)
if not selected_run_dir:
logging.info("未选择任何模型,程序退出。")
sys.exit()
# 【重要】动态覆写 config 中的路径,使其指向所选模型的目录
# 这样,后续所有代码都会自动在正确的文件夹中加载模型和保存结果
best_model_save_path = selected_run_dir / "best_model.pth"
raw_mask_dir = selected_run_dir / "predicted_raw_masks"
analysis_results_dir = selected_run_dir / "prediction_analysis"
if not best_model_save_path.exists():
logging.error(f"错误: 在 '{selected_run_dir}' 中未找到 'best_model.pth' 文件!")
sys.exit()
# setup_directories 函数现在用于创建预测结果的子目录
utils.setup_directories(raw_mask_dir)
utils.setup_directories(analysis_results_dir)
model = load_model(model_architecture, best_model_save_path) # 1. 加载模型
# ======================== 【新增/修改的代码段开始】 ======================== #
# --- 2. 探测模型以确定正确的图像尺寸 ---
logging.info("正在探测模型以确定预测时所需的输入尺寸...") #
target_height, target_width = config.IMAGE_HEIGHT, config.IMAGE_WIDTH #
try:
# 检查编码器是否需要固定输入尺寸
if model.encoder.is_fixed_input_size:
required_size = model.encoder.input_size
target_height = required_size[1]
target_width = required_size[2]
logging.warning(
f"模型 '{model_architecture}' 的编码器需要固定的输入尺寸 "
f"({target_height}x{target_width})。"
f"将忽略 config.py 中的尺寸设置进行预测。"
)
else:
logging.info(f"模型 '{model_architecture}' 支持动态输入尺寸。将使用 config.py 中定义的尺寸 ({target_height}x{target_width})。")
except AttributeError:
logging.info("无法自动检测模型尺寸要求。将使用 config.py 中定义的尺寸。")
# --- 3. 获取用于预测的图像预处理流程 ---
# 使用探测到的或配置中指定的目标尺寸
transform = get_preprocessing_transform(target_height, target_width) #
# ======================== 【新增/修改的代码段结束】 ======================== #
# --- 4. 获取图像列表并判断模式 ---
image_filenames = sorted(list(config.TEST_IMAGE_DIR.iterdir()))
evaluate_mode = config.TEST_MASK_DIR.is_dir()
if evaluate_mode:
logging.info("正在以 [评估模式] 运行。将使用真实掩码进行评估。")
else:
logging.info("正在以 [仅预测模式] 运行。未找到真实掩码。")
# --- 3. 初始化用于增量计算的变量 ---
# 【修改】不再使用 list 来累积,而是直接累加 TP, FP, FN, TN
num_classes = len(config.CLASSES)
total_tp = torch.zeros(num_classes, dtype=torch.long, device=config.DEVICE)
total_fp = torch.zeros(num_classes, dtype=torch.long, device=config.DEVICE)
total_fn = torch.zeros(num_classes, dtype=torch.long, device=config.DEVICE)
total_tn = torch.zeros(num_classes, dtype=torch.long, device=config.DEVICE)
# 【新增】为 ROC/PR 曲线创建临时文件,用磁盘空间换内存
# 使用 tempfile 模块来安全地创建临时文件
temp_gt_file = tempfile.NamedTemporaryFile(delete=False, suffix='.npy')
temp_probs_file = tempfile.NamedTemporaryFile(delete=False, suffix='.npy')
logging.info(f"为ROC/PR曲线创建临时文件: {temp_gt_file.name}, {temp_probs_file.name}")
# --- 4. 主推理循环 - 预测图片 ---
for image_path in tqdm(image_filenames, desc="正在处理测试图像"):
# --- 4.1. 读取图片并预处理 ---
original_image = cv2.cvtColor(cv2.imread(str(image_path)), cv2.COLOR_BGR2RGB)
original_h, original_w = original_image.shape[:2]
image_tensor = transform(image=original_image)['image'].unsqueeze(0).to(config.DEVICE)
# --- 4.2. 图片推理并保存 ---
with torch.no_grad():
pred_logits_tensor = model(image_tensor)
pred_mask_tensor = torch.argmax(pred_logits_tensor, dim=1)
pred_mask_numpy = pred_mask_tensor.squeeze().cpu().numpy().astype(np.uint8)
pred_mask_resized_raw = cv2.resize(pred_mask_numpy, (original_w, original_h), interpolation=cv2.INTER_NEAREST)
# 保存原始(单通道)的预测掩码
cv2.imwrite(str(raw_mask_dir / f"{image_path.stem}.png"), pred_mask_resized_raw)
# --- 4.3. 评估与可视化 ---
gt_mask_numpy, per_image_stats, stats_text_display, gt_mask_raw = None, None, "", None
if evaluate_mode:
mask_path = config.TEST_MASK_DIR / image_path.name
if mask_path.exists():
gt_mask_raw = cv2.imread(str(mask_path), cv2.IMREAD_GRAYSCALE)
gt_mask_numpy = cv2.resize(
gt_mask_raw, (target_width, target_height),
interpolation=cv2.INTER_NEAREST
)
gt_mask_tensor = torch.from_numpy(gt_mask_numpy).long().to(config.DEVICE).unsqueeze(0)
# ======================== 【修复代码段开始】 ======================== #
# --- 1. [修复] 计算单张图像的统计数据 ---
# 这个代码块是必需的,它为下面的 per_image_stats 和 stats_text_display 定义了 tp, fp, fn, tn
tp, fp, fn, tn = smp.metrics.get_stats(
pred_mask_tensor, gt_mask_tensor,
mode=config.SEG_MODE, num_classes=num_classes
)
# --- 2. 累加到总数,用于最终的整体评估 ---
total_tp += tp.squeeze(0).to(config.DEVICE)
total_fp += fp.squeeze(0).to(config.DEVICE)
total_fn += fn.squeeze(0).to(config.DEVICE)
total_tn += tn.squeeze(0).to(config.DEVICE)
# --- 3. 将曲线数据写入临时文件 (这部分逻辑正确,无需修改) ---
probs_tensor = torch.softmax(pred_logits_tensor, dim=1)
gt_flat = gt_mask_tensor.cpu().numpy().flatten()
probs_flat = probs_tensor.permute(0, 2, 3, 1).reshape(-1, num_classes).cpu().numpy()
with open(temp_gt_file.name, 'ab') as f_gt, open(temp_probs_file.name, 'ab') as f_probs:
np.save(f_gt, gt_flat)
np.save(f_probs, probs_flat)
# --- 4. [修复] 现在可以安全地计算单图指标用于可视化 ---
# 因为上面的代码块已经定义了 tp, fp, fn, tn所以这里不再报错
per_image_stats = {
'iou': smp.metrics.iou_score(tp, fp, fn, tn, reduction='micro').item(),
'acc': smp.metrics.accuracy(tp, fp, fn, tn, reduction='micro').item()
}
# --- 5. [修复] 修正生成统计文本时的 Bug ---
# `get_stats` 返回的 tp 张量形状是 (1, num_classes),
# 因此我们应该遍历第1维度 (类别),并使用 [0, i] 进行索引
stats_text_display = "Each Class TP / FP / FN / TN:\n" + "-"*35 + "\n"
for i in range(1, tp.shape[1]): # 修正: 遍历类别维度 tp.shape[1]
if i < len(config.CLASSES):
class_name = config.CLASSES[i]
stats_text_display += (
f" {class_name:<10}: "
# 修正: 使用正确的索引 tp[0, i]
f"{tp[0, i]:>5} / {fp[0, i]:>5} / {fn[0, i]:>5} / {tn[0, i]:>7}\n"
)
save_visual_comparison(
image_name=image_path.name,
original_image=original_image,
pred_mask=pred_mask_resized_raw,
gt_mask=gt_mask_raw,
stats=per_image_stats,
stats_text=stats_text_display,
save_dir = analysis_results_dir
)
logging.info("\n预测流程已完成!")
logging.info(f"原始预测掩码已保存至: {raw_mask_dir}")
logging.info(f"分析图像已保存至: {analysis_results_dir}")
# --- 5. 最终分析(循环结束后) ---
if evaluate_mode:
# calculate_and_save_final_metrics(all_stats_for_metrics, len(config.CLASSES), save_dir = analysis_results_dir)
calculate_and_save_final_metrics(total_tp.cpu(), total_fp.cpu(), total_fn.cpu(), total_tn.cpu(), num_classes, save_dir=analysis_results_dir)
# if all_gt_for_curves:
# logging.info("\n正在聚合结果以生成 ROC 和 PR 曲线...")
# full_gt = torch.cat(all_gt_for_curves).numpy().flatten()
# full_logits = torch.cat(all_logits_for_curves)
# # 调整维度并计算概率
# num_classes = len(config.CLASSES)
# full_probs = softmax(
# full_logits.permute(0, 2, 3, 1).reshape(-1, num_classes).numpy(),
# axis=1
# )
# generate_and_save_curves(
# y_true=full_gt, y_probs=full_probs,
# class_names=config.CLASSES, output_dir=analysis_results_dir
# )
# 【新增】从临时文件中分块读取数据来生成曲线
try:
logging.info("\n正在从临时文件加载数据以生成 ROC 和 PR 曲线...")
temp_gt_file.close()
temp_probs_file.close()
all_gt_parts, all_probs_parts = [], []
with open(temp_gt_file.name, 'rb') as f_gt, open(temp_probs_file.name, 'rb') as f_probs:
while True:
try:
all_gt_parts.append(np.load(f_gt))
all_probs_parts.append(np.load(f_probs))
# 【修复】将 EOFError 添加到要捕获的异常列表中
except (IOError, ValueError, EOFError):
break # 文件读取完毕
if all_gt_parts:
full_gt = np.concatenate(all_gt_parts)
full_probs = np.concatenate(all_probs_parts)
generate_and_save_curves(
y_true=full_gt, y_probs=full_probs,
class_names=config.CLASSES, output_dir=analysis_results_dir
)
else:
logging.warning("临时文件中没有数据,跳过 ROC/PR 曲线生成。")
finally:
# 【新增】确保在最后删除临时文件
import os
os.remove(temp_gt_file.name)
os.remove(temp_probs_file.name)
logging.info("已清理临时文件。")
if __name__ == "__main__":
# 创建一个参数解析器
parser = argparse.ArgumentParser(description="训练图像分割模型。")
# 添加 --architecture 参数
# - `choices` 会自动从你的 config 文件中获取所有可用的模型名称
# - `default` 设置了在不提供参数时的默认模型
parser.add_argument(
"-a", "--architecture",
type=str,
# default='Unet',
choices=list(config.ALL_MODEL_CONFIGS.keys()),
required=True,
help="选择要训练的模型架构。"
)
# 解析命令行传入的参数
args = parser.parse_args()
# 将解析出的架构名称传入 main 函数并执行
main(model_architecture=args.architecture)