Files
Seg_Data_Server/Seg_All_In_One_SegModel/1_predict.py
2026-05-20 15:05:35 +08:00

564 lines
27 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import logging, argparse, sys, utils
import shutil, tempfile
from pathlib import Path
from typing import Dict, List, Tuple
import albumentations as A
import cv2
import matplotlib.pyplot as plt
import numpy as np
import segmentation_models_pytorch as smp
import torch
from albumentations.pytorch import ToTensorV2
from scipy.special import softmax
from sklearn.metrics import (auc, average_precision_score,
precision_recall_curve, roc_curve)
from tqdm import tqdm
# 本地应用/库的导入
import config
from utils import log_metrics_to_csv, get_preprocessing_transform, setup_directories
# --- Logging setup ---
# Send INFO-and-above records to the console through a stream handler,
# using a timestamped "[LEVEL] message" layout.
logging.basicConfig(
    handlers=[logging.StreamHandler()],
    format="%(asctime)s [%(levelname)s] %(message)s",
    level=logging.INFO,
)
# Matplotlib font setup for CJK text: candidate sans-serif fonts are listed in
# order of preference (pick ones available on your system, e.g. SimHei or
# Microsoft YaHei). Disabling unicode_minus works around minus-sign rendering
# problems with such fonts.
plt.rcParams.update({
    'font.sans-serif': ['Noto Sans CJK JP', 'SimHei', 'Microsoft YaHei'],
    'axes.unicode_minus': False,
})
# --- Core utility functions ---
def colorize_mask(mask: np.ndarray, class_rgb_values: List[Tuple[int, int, int]]) -> np.ndarray:
    """Convert a single-channel class-index mask into an RGB colour mask.

    Colours are taken from a fixed built-in palette. ``class_rgb_values`` is
    only consulted for its length, which caps how many class indices are
    painted. Pixels whose value is not a painted class index stay black.

    Args:
        mask: Array whose first two dimensions are (height, width) and whose
            values are class indices (0, 1, 2, ...).
        class_rgb_values: Per-class RGB tuples; only ``len()`` is used here.

    Returns:
        A ``uint8`` array of shape (height, width, 3).
    """
    # Generic palette indexed by class id; extend it to support more classes.
    palette = [
        (0, 0, 0), (220, 20, 60), (0, 255, 0), (0, 0, 255),
        (255, 255, 0), (255, 0, 255), (0, 255, 255), (244, 164, 96),
    ]
    height, width = mask.shape[0], mask.shape[1]
    rgb = np.zeros((height, width, 3), dtype=np.uint8)
    # Paint only as many classes as both the palette and the class list cover.
    for label, rgb_triplet in enumerate(palette[:len(class_rgb_values)]):
        rgb[mask == label] = rgb_triplet
    return rgb
# --- Analysis and visualisation functions ---
def generate_and_save_curves(y_true: np.ndarray, y_probs: np.ndarray, class_names: List[str], output_dir: Path) -> None:
    """Plot one-vs-rest ROC and PR curves per class and save them as PNG files.

    The background class (index 0) is skipped. A class whose binarised labels
    are all positive or all negative is also skipped with a warning, because
    ROC/PR values are undefined when one of the two outcomes is absent.

    Args:
        y_true: Flat array of ground-truth class indices.
        y_probs: Array indexed as ``y_probs[:, i]`` for the score of class
            ``i``; rows are expected to line up with ``y_true``.
        class_names: Class names; index ``i`` names class ``i``.
        output_dir: Directory that receives ``ROC_Curves.png`` and
            ``Precision_Recall_Curves.png``.
    """
    logging.info("正在生成 ROC 和 PR 曲线...")
    num_classes = len(class_names)
    plt.style.use('seaborn-v0_8-whitegrid')

    # --- Prepare the canvases ---
    fig_roc, ax_roc = plt.subplots(figsize=(10, 8))
    fig_pr, ax_pr = plt.subplots(figsize=(10, 8))
    try:
        # Diagonal reference line for the ROC plot.
        ax_roc.plot([0, 1], [0, 1], 'k--', label='随机猜测')

        # --- One curve per class, skipping the background ---
        for i in range(1, num_classes):
            class_name = class_names[i]
            y_true_binary = (y_true == i).astype(int)
            positives = int(y_true_binary.sum())
            if positives == 0 or positives == y_true_binary.size:
                # Without both outcomes present, TPR or FPR is undefined and
                # the resulting AUC/AP would be NaN or meaningless.
                logging.warning(
                    f"类别 '{class_name}' 在真实标签中缺少正样本或负样本,跳过其 ROC/PR 曲线。"
                )
                continue
            y_class_probs = y_probs[:, i]

            # ROC curve
            fpr, tpr, _ = roc_curve(y_true_binary, y_class_probs)
            roc_auc = auc(fpr, tpr)
            ax_roc.plot(fpr, tpr, label=f'类别 {class_name} (AUC = {roc_auc:.4f})')

            # PR curve
            precision, recall, _ = precision_recall_curve(y_true_binary, y_class_probs)
            avg_precision = average_precision_score(y_true_binary, y_class_probs)
            ax_pr.plot(recall, precision, label=f'类别 {class_name} (AP = {avg_precision:.4f})')

        # --- Finish and save the ROC figure ---
        ax_roc.set(xlabel='False positive rate (FPR)', ylabel='True positive rate (TPR)', title='Receiver operating characteristic (ROC) curve')
        ax_roc.legend(loc='lower right')
        ax_roc.set_aspect('equal', adjustable='box')
        fig_roc.tight_layout()
        roc_save_path = output_dir / "ROC_Curves.png"
        fig_roc.savefig(roc_save_path)
        logging.info(f"-> ROC 曲线图已保存至: {roc_save_path}")

        # --- Finish and save the PR figure ---
        ax_pr.set(xlabel='Recall', ylabel='Precision', title='Precision-Recall (PR) Curve')
        # The PR axes has no reference line, so it can be empty when every
        # class was skipped; only build a legend when there is something in it.
        if ax_pr.get_legend_handles_labels()[0]:
            ax_pr.legend(loc='lower left')
        ax_pr.set_aspect('equal', adjustable='box')
        fig_pr.tight_layout()
        pr_save_path = output_dir / "Precision_Recall_Curves.png"
        fig_pr.savefig(pr_save_path)
        logging.info(f"-> PR 曲线图已保存至: {pr_save_path}")
    finally:
        # Release both figures even if plotting or saving failed.
        plt.close(fig_roc)
        plt.close(fig_pr)
def save_visual_comparison(
    image_name: str, original_image: np.ndarray, pred_mask: np.ndarray,
    gt_mask: np.ndarray = None, stats: Dict = None, stats_text: str = "",
    save_dir: str = ""
) -> None:
    """Save a side-by-side figure of the image, the predicted mask and, optionally, the ground truth.

    A three-panel figure is produced when both ``gt_mask`` and ``stats`` are
    given, and the file name then embeds the IoU and accuracy. Otherwise a
    two-panel figure named after the image stem is produced.

    Args:
        image_name: File name of the source image; used for the title and the
            output file stem.
        original_image: Image array shown in the first panel.
        pred_mask: Class-index mask, colourised via ``colorize_mask``.
        gt_mask: Optional ground-truth class-index mask.
        stats: Optional dict with ``'iou'`` and ``'acc'`` entries.
        stats_text: Optional multi-line text rendered below the panels.
        save_dir: Output directory, as a ``str`` or ``Path``.
    """
    pred_color_mask = colorize_mask(pred_mask, config.CLASS_RGB_VALUES)
    gt_color_mask = None
    if gt_mask is not None and stats is not None:
        num_plots = 3
        figsize = (22, 8)
        gt_color_mask = colorize_mask(gt_mask, config.CLASS_RGB_VALUES)
        save_name = f"{Path(image_name).stem}-IOU_{stats['iou']:.4f}-ACC_{stats['acc']:.4f}.png"
    else:
        num_plots = 2
        figsize = (14, 6)
        save_name = f"{Path(image_name).stem}.png"

    fig, axes = plt.subplots(1, num_plots, figsize=figsize)
    try:
        fig.suptitle(f"Predict: {image_name}", fontsize=16)

        axes[0].imshow(original_image)
        axes[0].set_title("ORI_IMG")
        axes[0].axis('off')

        pred_title = "Predicted Mask"
        if stats:
            pred_title += f"\nIoU: {stats['iou']:.4f} | Acc: {stats['acc']:.4f}"
        axes[1].imshow(pred_color_mask)
        axes[1].set_title(pred_title)
        axes[1].axis('off')

        if num_plots == 3:
            axes[2].imshow(gt_color_mask)
            axes[2].set_title("Ground Truth")
            axes[2].axis('off')

        # Run tight_layout first: it recomputes the subplot margins, so a
        # manual bottom margin set before it would be discarded.
        fig.tight_layout()
        if stats_text:
            # Reserve room under the panels, then place the stats text there.
            fig.subplots_adjust(bottom=0.3)
            fig.text(0.5, 0.01, stats_text, ha="center", fontsize=10,
                     bbox={"facecolor": "white", "alpha": 0.7, "pad": 5}, family='monospace')

        # Path() lets callers pass either a str or a Path for save_dir.
        fig.savefig(Path(save_dir) / save_name)
    finally:
        # Always release the figure, even if drawing or saving failed.
        plt.close(fig)
# --- Main logic functions ---
def load_model(model_architecture: str, model_save_path: Path) -> torch.nn.Module:
    """Rebuild a segmentation model and load trained weights into it.

    The architecture class is looked up by name on ``segmentation_models_pytorch``
    and instantiated with the parameters stored under the same name in
    ``config.ALL_MODEL_CONFIGS``. Checkpoints whose keys carry a leading
    ``'module.'`` (as written by ``torch.nn.DataParallel``) are accepted as
    well as plain ones.

    Args:
        model_architecture: Name of the smp model class, also the key into
            ``config.ALL_MODEL_CONFIGS``.
        model_save_path: Path of the saved ``state_dict`` file.

    Returns:
        The model on ``config.DEVICE``, switched to eval mode.

    Raises:
        KeyError: ``model_architecture`` has no entry in ``config.ALL_MODEL_CONFIGS``.
        AttributeError: ``segmentation_models_pytorch`` has no such class.
    """
    logging.info(f"正在从以下路径加载模型: {model_save_path}")
    num_classes = len(config.CLASSES)

    # 1. Rebuild the base architecture from the configuration.
    logging.info(f"正在重建模型架构: '{model_architecture}'")
    try:
        # Copy so the shared config entry is not mutated below.
        params = config.ALL_MODEL_CONFIGS[model_architecture].copy()
        model_class = getattr(smp, model_architecture)
    except KeyError:
        logging.error(f"模型 '{model_architecture}' 在 config.ALL_MODEL_CONFIGS 中不存在!")
        raise
    except AttributeError:
        logging.error(f"模型 '{model_architecture}' 在 segmentation_models_pytorch 库中不存在!")
        raise

    # 2. Set the constructor arguments used for prediction.
    params['in_channels'] = 3
    params['classes'] = num_classes
    # The trained weights are loaded right after, so skip the pretrained
    # encoder weights (avoids re-downloading them at prediction time).
    params['encoder_weights'] = None

    # Encoders whose name contains 'vit' are given dynamic_img_size=True.
    encoder_name = params.get('encoder_name', '')
    if 'vit' in encoder_name.lower():
        params['dynamic_img_size'] = True
        logging.info(f"检测到 ViT 编码器 ('{encoder_name}')。自动设置 dynamic_img_size=True。")

    # 3. Instantiate the model.
    model = model_class(**params).to(config.DEVICE)

    # 4. Load the state_dict from disk.
    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # from trusted sources.
    state_dict = torch.load(model_save_path, map_location=torch.device(config.DEVICE))

    # 5. Strip the DataParallel prefix if any key carries it. Only a *leading*
    #    'module.' is removed, so keys containing 'module.' further along the
    #    path keep their name.
    prefix = 'module.'
    if any(key.startswith(prefix) for key in state_dict.keys()):
        logging.info("检测到模型由 DataParallel 保存,将移除 'module.' 前缀。")
        state_dict = {
            (key[len(prefix):] if key.startswith(prefix) else key): value
            for key, value in state_dict.items()
        }
    else:
        logging.info("直接加载标准模型权重。")
    model.load_state_dict(state_dict)

    model.eval()
    return model
def calculate_and_save_final_metrics(stats_list: List[Dict], num_classes: int, save_dir: str) -> None:
    """Aggregate per-image statistics, compute overall metrics and write them to CSV.

    NOTE(review): a second function with this same name is defined later in
    this file and rebinds the name at import time, so this list-based variant
    is not reachable through ``calculate_and_save_final_metrics`` once the
    module has loaded. It is kept for reference.

    Args:
        stats_list: One dict per image/batch with ``'tp'``, ``'fp'``, ``'fn'``
            and ``'tn'`` tensors that are concatenated along dim 0.
        num_classes: Number of classes (length of the per-class vectors).
        save_dir: Directory (``str`` or ``Path``) for ``test_set_metrics.csv``.
    """
    if not stats_list:
        logging.warning("未收集到统计数据,跳过最终指标计算。")
        return
    logging.info("\n正在聚合结果以计算最终指标...")

    # Concatenate all per-image tensors and sum over the first dimension,
    # leaving one count per class.
    tp = torch.cat([s['tp'] for s in stats_list], dim=0).sum(dim=0)
    fp = torch.cat([s['fp'] for s in stats_list], dim=0).sum(dim=0)
    fn = torch.cat([s['fn'] for s in stats_list], dim=0).sum(dim=0)
    tn = torch.cat([s['tn'] for s in stats_list], dim=0).sum(dim=0)

    # Drop the classes listed in config.EVALUATION_CLASSES_TO_IGNORE from the
    # overall metrics.
    ignore_indices = [config.CLASSES.index(cls) for cls in config.EVALUATION_CLASSES_TO_IGNORE if cls in config.CLASSES]
    keep_mask = torch.ones(num_classes, dtype=torch.bool, device=tp.device)
    if ignore_indices:
        keep_mask[ignore_indices] = False
        logging.info(f"最终评估时将忽略类别: {config.EVALUATION_CLASSES_TO_IGNORE}")
    tp_filtered = tp[keep_mask]
    fp_filtered = fp[keep_mask]
    fn_filtered = fn[keep_mask]
    tn_filtered = tn[keep_mask]

    # Overall metrics use the filtered counts.
    metrics = {
        "iou_score": smp.metrics.iou_score(tp_filtered, fp_filtered, fn_filtered, tn_filtered, reduction='micro').item(),
        "f1_score": smp.metrics.f1_score(tp_filtered, fp_filtered, fn_filtered, tn_filtered, reduction='micro').item(),
        "accuracy": smp.metrics.accuracy(tp_filtered, fp_filtered, fn_filtered, tn_filtered, reduction='micro').item(),
        "recall": smp.metrics.recall(tp_filtered, fp_filtered, fn_filtered, tn_filtered, reduction='micro').item(),
        "precision": smp.metrics.precision(tp_filtered, fp_filtered, fn_filtered, tn_filtered, reduction='micro').item(),
    }

    # Raw per-class counts are reported for every class, ignored ones included.
    for i in range(num_classes):
        metrics[f'tp_class_{i}'] = tp[i].item()
        metrics[f'tn_class_{i}'] = tn[i].item()
        metrics[f'fp_class_{i}'] = fp[i].item()
        metrics[f'fn_class_{i}'] = fn[i].item()

    # save_dir is annotated as str; Path() makes the "/" join valid for both
    # str and Path arguments.
    csv_save_path = Path(save_dir) / "test_set_metrics.csv"
    log_metrics_to_csv(metrics, str(csv_save_path))
    logging.info(f"-> 整体测试集指标已保存至: {csv_save_path}")
# Helper for interactively choosing which training run to predict with.
def select_trained_model_path(base_dir: Path, model_architecture: str) -> Path:
    """List the training-run directories of an architecture and let the user pick one.

    Run directories are the sub-directories of ``base_dir`` whose name starts
    with ``"<model_architecture>_"``; they are offered in sorted order.

    Args:
        base_dir: Directory that holds one sub-directory per training run.
        model_architecture: Architecture name used as the directory prefix.

    Returns:
        The chosen run directory, or ``None`` when ``base_dir`` is not a
        directory, no matching run exists, the user presses Enter without a
        choice, or standard input is closed.
    """
    logging.info(f"正在 '{base_dir}' 中搜索模型 '{model_architecture}' 的训练记录...")

    # A missing base directory would make iterdir() raise; report it instead.
    if not base_dir.is_dir():
        logging.error(f"模型目录 '{base_dir}' 不存在或不是文件夹。")
        return None

    # Collect every sub-directory whose name starts with "<architecture>_".
    run_dirs = sorted(
        d for d in base_dir.iterdir()
        if d.is_dir() and d.name.startswith(model_architecture + '_')
    )
    if not run_dirs:
        logging.error(f"未找到任何 '{model_architecture}' 模型的训练记录。请先运行 train.py。")
        return None

    print("\n请选择要用于预测的模型:")
    for number, dir_path in enumerate(run_dirs, start=1):
        print(f" [{number}] {dir_path.name}")

    while True:
        try:
            choice = input(f"请输入选项编号 (1-{len(run_dirs)}) 或按 Enter 退出: ")
        except EOFError:
            # No interactive stdin (e.g. piped or closed): treat as "quit".
            return None
        if not choice:
            return None
        try:
            choice_idx = int(choice) - 1
        except ValueError:
            print("无效的输入,请输入一个数字。")
            continue
        if 0 <= choice_idx < len(run_dirs):
            selected_dir = run_dirs[choice_idx]
            logging.info(f"已选择模型: {selected_dir}")
            return selected_dir
        print("无效的选项,请重新输入。")
# Variant that takes already-aggregated statistics instead of a list of
# per-image dicts.
def calculate_and_save_final_metrics(tp, fp, fn, tn, num_classes: int, save_dir: Path) -> None:
    """Compute overall metrics from aggregated TP/FP/FN/TN counts and write them to CSV.

    Classes named in ``config.EVALUATION_CLASSES_TO_IGNORE`` are left out of
    the overall (micro-reduced) metrics; the raw per-class counts are written
    for every class.

    Args:
        tp, fp, fn, tn: Per-class count tensors, indexable by class index.
        num_classes: Number of classes.
        save_dir: Directory that receives ``test_set_metrics.csv``.
    """
    logging.info("\n正在计算最终的整体指标...")

    # Build a boolean mask that is False for every class to be ignored.
    ignore_indices = [
        config.CLASSES.index(name)
        for name in config.EVALUATION_CLASSES_TO_IGNORE
        if name in config.CLASSES
    ]
    keep_mask = torch.ones(num_classes, dtype=torch.bool, device=tp.device)
    if ignore_indices:
        keep_mask[ignore_indices] = False
        logging.info(f"最终评估时将忽略类别: {config.EVALUATION_CLASSES_TO_IGNORE}")
    kept_counts = (tp[keep_mask], fp[keep_mask], fn[keep_mask], tn[keep_mask])

    # Overall metrics, computed on the filtered counts.
    metrics = {}
    metrics["iou_score"] = smp.metrics.iou_score(*kept_counts, reduction='micro').item()
    metrics["f1_score"] = smp.metrics.f1_score(*kept_counts, reduction='micro').item()
    metrics["accuracy"] = smp.metrics.accuracy(*kept_counts, reduction='micro').item()
    metrics["recall"] = smp.metrics.recall(*kept_counts, reduction='micro').item()
    metrics["precision"] = smp.metrics.precision(*kept_counts, reduction='micro').item()

    # Raw per-class counts, taken from the unfiltered tensors.
    for idx in range(num_classes):
        metrics.update({
            f'tp_class_{idx}': tp[idx].item(),
            f'tn_class_{idx}': tn[idx].item(),
            f'fp_class_{idx}': fp[idx].item(),
            f'fn_class_{idx}': fn[idx].item(),
        })

    csv_save_path = save_dir / "test_set_metrics.csv"
    log_metrics_to_csv(metrics, str(csv_save_path))
    logging.info(f"-> 整体测试集指标已保存至: {csv_save_path}")
def main(model_architecture: str) -> None:
    """Load a trained model, predict every test image, and write evaluation artefacts.

    Steps:
      1. Let the user pick a training run and derive all paths from it.
      2. Load the model and decide the input size (the encoder's fixed size
         if it reports one, otherwise the size from ``config``).
      3. For each readable test image: predict, save the raw class-index
         mask at the original resolution, and save a comparison figure.
      4. When ``config.TEST_MASK_DIR`` exists (evaluation mode): accumulate
         TP/FP/FN/TN, write overall metrics to CSV and draw ROC/PR curves.

    Per-pixel labels and probabilities for the curves are appended to two
    temporary ``.npy`` files, which are always removed before returning.

    Args:
        model_architecture: Architecture name; selects the run directories and
            the model configuration.
    """
    # --- 1. Select the run and derive paths from it ---
    selected_run_dir = select_trained_model_path(Path(config.PREDICT_BEST_MODEL_DIR), model_architecture)
    if not selected_run_dir:
        logging.info("未选择任何模型,程序退出。")
        sys.exit()

    # Everything is read from / written into the selected run directory.
    best_model_save_path = selected_run_dir / "best_model.pth"
    raw_mask_dir = selected_run_dir / "predicted_raw_masks"
    analysis_results_dir = selected_run_dir / "prediction_analysis"
    if not best_model_save_path.exists():
        logging.error(f"错误: 在 '{selected_run_dir}' 中未找到 'best_model.pth' 文件!")
        # Missing weights is an error: exit with a non-zero status.
        sys.exit(1)

    # Create the output sub-directories for this prediction run.
    utils.setup_directories(raw_mask_dir)
    utils.setup_directories(analysis_results_dir)
    model = load_model(model_architecture, best_model_save_path)

    # --- 2. Probe the model for the input size to use ---
    logging.info("正在探测模型以确定预测时所需的输入尺寸...")
    target_height, target_width = config.IMAGE_HEIGHT, config.IMAGE_WIDTH
    try:
        # Encoders that report a fixed input size override the config values.
        if model.encoder.is_fixed_input_size:
            required_size = model.encoder.input_size
            target_height = required_size[1]
            target_width = required_size[2]
            logging.warning(
                f"模型 '{model_architecture}' 的编码器需要固定的输入尺寸 "
                f"({target_height}x{target_width})。"
                f"将忽略 config.py 中的尺寸设置进行预测。"
            )
        else:
            logging.info(f"模型 '{model_architecture}' 支持动态输入尺寸。将使用 config.py 中定义的尺寸 ({target_height}x{target_width})。")
    except AttributeError:
        # The encoder does not expose these attributes; keep the config size.
        logging.info("无法自动检测模型尺寸要求。将使用 config.py 中定义的尺寸。")

    # --- 3. Preprocessing pipeline for the chosen size ---
    transform = get_preprocessing_transform(target_height, target_width)

    # --- 4. Collect the images and decide the mode ---
    # Only regular files are candidates; sub-directories are ignored.
    image_filenames = sorted(p for p in config.TEST_IMAGE_DIR.iterdir() if p.is_file())
    evaluate_mode = config.TEST_MASK_DIR.is_dir()
    if evaluate_mode:
        logging.info("正在以 [评估模式] 运行。将使用真实掩码进行评估。")
    else:
        logging.info("正在以 [仅预测模式] 运行。未找到真实掩码。")

    # --- 5. Running totals, accumulated image by image ---
    num_classes = len(config.CLASSES)
    total_tp = torch.zeros(num_classes, dtype=torch.long, device=config.DEVICE)
    total_fp = torch.zeros(num_classes, dtype=torch.long, device=config.DEVICE)
    total_fn = torch.zeros(num_classes, dtype=torch.long, device=config.DEVICE)
    total_tn = torch.zeros(num_classes, dtype=torch.long, device=config.DEVICE)

    # Temporary files hold the ROC/PR inputs on disk instead of in memory
    # during the loop. The handles are closed right away; the files are
    # re-opened by name in append mode for every image.
    temp_gt_file = tempfile.NamedTemporaryFile(delete=False, suffix='.npy')
    temp_probs_file = tempfile.NamedTemporaryFile(delete=False, suffix='.npy')
    temp_gt_file.close()
    temp_probs_file.close()
    logging.info(f"为ROC/PR曲线创建临时文件: {temp_gt_file.name}, {temp_probs_file.name}")

    try:
        # --- 6. Main inference loop ---
        for image_path in tqdm(image_filenames, desc="正在处理测试图像"):
            # --- 6.1 Read and preprocess ---
            bgr_image = cv2.imread(str(image_path))
            if bgr_image is None:
                # cv2.imread returns None for files it cannot decode.
                logging.warning(f"无法读取图像,已跳过: {image_path}")
                continue
            original_image = cv2.cvtColor(bgr_image, cv2.COLOR_BGR2RGB)
            original_h, original_w = original_image.shape[:2]
            image_tensor = transform(image=original_image)['image'].unsqueeze(0).to(config.DEVICE)

            # --- 6.2 Predict and save the raw mask ---
            with torch.no_grad():
                pred_logits_tensor = model(image_tensor)
            pred_mask_tensor = torch.argmax(pred_logits_tensor, dim=1)
            pred_mask_numpy = pred_mask_tensor.squeeze().cpu().numpy().astype(np.uint8)
            # Nearest-neighbour keeps class indices intact when resizing back.
            pred_mask_resized_raw = cv2.resize(pred_mask_numpy, (original_w, original_h), interpolation=cv2.INTER_NEAREST)
            cv2.imwrite(str(raw_mask_dir / f"{image_path.stem}.png"), pred_mask_resized_raw)

            # --- 6.3 Evaluation and visualisation ---
            per_image_stats, stats_text_display, gt_mask_raw = None, "", None
            if evaluate_mode:
                mask_path = config.TEST_MASK_DIR / image_path.name
                if mask_path.exists():
                    gt_mask_raw = cv2.imread(str(mask_path), cv2.IMREAD_GRAYSCALE)
                    if gt_mask_raw is None:
                        logging.warning(f"无法读取真实掩码,跳过该图像的评估: {mask_path}")

            if gt_mask_raw is not None:
                # Metrics are computed at the model's input resolution.
                gt_mask_numpy = cv2.resize(
                    gt_mask_raw, (target_width, target_height),
                    interpolation=cv2.INTER_NEAREST
                )
                gt_mask_tensor = torch.from_numpy(gt_mask_numpy).long().to(config.DEVICE).unsqueeze(0)

                # Per-image confusion counts.
                tp, fp, fn, tn = smp.metrics.get_stats(
                    pred_mask_tensor, gt_mask_tensor,
                    mode=config.SEG_MODE, num_classes=num_classes
                )

                # Add them to the running totals for the overall evaluation.
                total_tp += tp.squeeze(0).to(config.DEVICE)
                total_fp += fp.squeeze(0).to(config.DEVICE)
                total_fn += fn.squeeze(0).to(config.DEVICE)
                total_tn += tn.squeeze(0).to(config.DEVICE)

                # Append flattened labels and per-class probabilities to the
                # temporary files as consecutive .npy records.
                probs_tensor = torch.softmax(pred_logits_tensor, dim=1)
                gt_flat = gt_mask_tensor.cpu().numpy().flatten()
                probs_flat = probs_tensor.permute(0, 2, 3, 1).reshape(-1, num_classes).cpu().numpy()
                with open(temp_gt_file.name, 'ab') as f_gt, open(temp_probs_file.name, 'ab') as f_probs:
                    np.save(f_gt, gt_flat)
                    np.save(f_probs, probs_flat)

                # Per-image metrics shown in the figure and the file name.
                per_image_stats = {
                    'iou': smp.metrics.iou_score(tp, fp, fn, tn, reduction='micro').item(),
                    'acc': smp.metrics.accuracy(tp, fp, fn, tn, reduction='micro').item()
                }

                # Per-class counts as text; tp is indexed as [0, class] and
                # class 0 is left out of the listing.
                stats_text_display = "Each Class TP / FP / FN / TN:\n" + "-"*35 + "\n"
                for i in range(1, tp.shape[1]):
                    if i < len(config.CLASSES):
                        class_name = config.CLASSES[i]
                        stats_text_display += (
                            f" {class_name:<10}: "
                            f"{tp[0, i]:>5} / {fp[0, i]:>5} / {fn[0, i]:>5} / {tn[0, i]:>7}\n"
                        )

            save_visual_comparison(
                image_name=image_path.name,
                original_image=original_image,
                pred_mask=pred_mask_resized_raw,
                gt_mask=gt_mask_raw,
                stats=per_image_stats,
                stats_text=stats_text_display,
                save_dir=analysis_results_dir
            )

        logging.info("\n预测流程已完成!")
        logging.info(f"原始预测掩码已保存至: {raw_mask_dir}")
        logging.info(f"分析图像已保存至: {analysis_results_dir}")

        # --- 7. Final analysis, after the loop ---
        if evaluate_mode:
            calculate_and_save_final_metrics(total_tp.cpu(), total_fp.cpu(), total_fn.cpu(), total_tn.cpu(), num_classes, save_dir=analysis_results_dir)

            # Read the .npy records back one pair at a time until the files
            # are exhausted.
            logging.info("\n正在从临时文件加载数据以生成 ROC 和 PR 曲线...")
            all_gt_parts, all_probs_parts = [], []
            with open(temp_gt_file.name, 'rb') as f_gt, open(temp_probs_file.name, 'rb') as f_probs:
                while True:
                    try:
                        gt_part = np.load(f_gt)
                        probs_part = np.load(f_probs)
                    except (IOError, ValueError, EOFError):
                        break  # end of file reached
                    all_gt_parts.append(gt_part)
                    all_probs_parts.append(probs_part)

            if all_gt_parts:
                full_gt = np.concatenate(all_gt_parts)
                full_probs = np.concatenate(all_probs_parts)
                generate_and_save_curves(
                    y_true=full_gt, y_probs=full_probs,
                    class_names=config.CLASSES, output_dir=analysis_results_dir
                )
            else:
                logging.warning("临时文件中没有数据,跳过 ROC/PR 曲线生成。")
    finally:
        # Remove the temporary files no matter how the run ended.
        for temp_file in (temp_gt_file, temp_probs_file):
            try:
                Path(temp_file.name).unlink()
            except FileNotFoundError:
                pass
        logging.info("已清理临时文件。")
if __name__ == "__main__":
    # Command-line entry point for running prediction with a trained model.
    parser = argparse.ArgumentParser(description="使用已训练的图像分割模型进行预测。")
    # --architecture is mandatory; the allowed values are the model names
    # defined in config.ALL_MODEL_CONFIGS.
    parser.add_argument(
        "-a", "--architecture",
        type=str,
        choices=list(config.ALL_MODEL_CONFIGS.keys()),
        required=True,
        help="选择用于预测的模型架构。"
    )
    # Parse the command line and hand the chosen architecture to main().
    args = parser.parse_args()
    main(model_architecture=args.architecture)