first commit
This commit is contained in:
564
Seg_All_In_One_SegModel/1_predict.py
Normal file
564
Seg_All_In_One_SegModel/1_predict.py
Normal file
@@ -0,0 +1,564 @@
|
||||
import argparse
import logging
import shutil
import sys
import tempfile
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union

import albumentations as A
import cv2
import matplotlib.pyplot as plt
import numpy as np
import segmentation_models_pytorch as smp
import torch
from albumentations.pytorch import ToTensorV2
from scipy.special import softmax
from sklearn.metrics import (auc, average_precision_score,
                             precision_recall_curve, roc_curve)
from tqdm import tqdm

# Local application / library imports
import config
import utils
from utils import log_metrics_to_csv, get_preprocessing_transform, setup_directories
|
||||
|
||||
# --- Logging setup ---
# Route INFO-and-above records to the console with a timestamped format.
logging.basicConfig(
    format="%(asctime)s [%(levelname)s] %(message)s",
    level=logging.INFO,
    handlers=[logging.StreamHandler()],  # console output
)

# Matplotlib font configuration: list CJK-capable fonts first so Chinese
# labels render (pick whichever exists on the host system), and keep the
# minus sign displayable with those fonts.
plt.rcParams.update({
    'font.sans-serif': ['Noto Sans CJK JP', 'SimHei', 'Microsoft YaHei'],
    'axes.unicode_minus': False,
})
|
||||
|
||||
# --- 核心工具函数 ---
|
||||
|
||||
# 将单通道的类别索引掩码转换为 RGB 彩色掩码
|
||||
def colorize_mask(mask: np.ndarray, class_rgb_values: List[Tuple[int, int, int]]) -> np.ndarray:
    """Turn a single-channel class-index mask into an RGB colour mask.

    Pixels whose value equals class index ``k`` are painted with the k-th
    colour of a fixed built-in palette. ``class_rgb_values`` is consulted only
    for its length: it caps how many palette entries are applied. Pixels whose
    value has no applied palette entry stay black.

    Args:
        mask: 2-D array of class indices (0, 1, 2, ...).
        class_rgb_values: One entry per class; only ``len()`` of it is used.

    Returns:
        ``uint8`` array of shape ``(H, W, 3)``.
    """
    # Generic palette; extend it here to support more classes.
    palette = (
        (0, 0, 0), (220, 20, 60), (0, 255, 0), (0, 0, 255),
        (255, 255, 0), (255, 0, 255), (0, 255, 255), (244, 164, 96),
    )

    height, width = mask.shape[0], mask.shape[1]
    rgb = np.zeros((height, width, 3), dtype=np.uint8)

    # Only the first len(class_rgb_values) palette entries are ever used.
    for class_idx, rgb_triplet in enumerate(palette[:len(class_rgb_values)]):
        rgb[mask == class_idx] = rgb_triplet
    return rgb
|
||||
|
||||
# --- 分析与可视化函数 ---
|
||||
|
||||
def generate_and_save_curves(y_true: np.ndarray, y_probs: np.ndarray, class_names: List[str], output_dir: Path) -> None:
    """Plot and save one-vs-rest ROC and PR curves for every non-background class.

    Class index 0 is skipped (treated as background by this script). Classes
    whose binarised ground truth contains only one label (no positive pixels,
    or nothing but positive pixels) are skipped with a warning, because ROC /
    PR curves are undefined for them.

    Args:
        y_true: Flat array of ground-truth class indices, one per pixel.
        y_probs: Array indexed as ``[pixel, class]`` holding class scores.
        class_names: Class names; position ``i`` names class index ``i``.
        output_dir: Directory that receives ``ROC_Curves.png`` and
            ``Precision_Recall_Curves.png``.
    """
    logging.info("正在生成 ROC 和 PR 曲线...")
    num_classes = len(class_names)
    plt.style.use('seaborn-v0_8-whitegrid')

    # --- Prepare the two canvases ---
    fig_roc, ax_roc = plt.subplots(figsize=(10, 8))
    ax_roc.plot([0, 1], [0, 1], 'k--', label='随机猜测')
    fig_pr, ax_pr = plt.subplots(figsize=(10, 8))

    # --- One curve per class (background at index 0 is skipped) ---
    pr_curve_drawn = False
    for i in range(1, num_classes):
        class_name = class_names[i]
        y_true_binary = (y_true == i).astype(int)
        y_class_probs = y_probs[:, i]

        # A one-vs-rest curve needs both positive and negative samples.
        positives = int(y_true_binary.sum())
        if positives == 0 or positives == y_true_binary.size:
            logging.warning(f"类别 {class_name} 在真实标签中只有单一取值,跳过其 ROC/PR 曲线。")
            continue

        # ROC curve
        fpr, tpr, _ = roc_curve(y_true_binary, y_class_probs)
        roc_auc = auc(fpr, tpr)
        ax_roc.plot(fpr, tpr, label=f'类别 {class_name} (AUC = {roc_auc:.4f})')

        # PR curve
        precision, recall, _ = precision_recall_curve(y_true_binary, y_class_probs)
        avg_precision = average_precision_score(y_true_binary, y_class_probs)
        ax_pr.plot(recall, precision, label=f'类别 {class_name} (AP = {avg_precision:.4f})')
        pr_curve_drawn = True

    # --- Finish and save the ROC figure ---
    ax_roc.set(xlabel='False positive rate (FPR)', ylabel='True positive rate (TPR)', title='Receiver operating characteristic (ROC) curve')
    ax_roc.legend(loc='lower right')
    ax_roc.set_aspect('equal', adjustable='box')
    fig_roc.tight_layout()
    roc_save_path = output_dir / "ROC_Curves.png"
    fig_roc.savefig(roc_save_path)
    plt.close(fig_roc)
    logging.info(f"-> ROC 曲线图已保存至: {roc_save_path}")

    # --- Finish and save the PR figure ---
    ax_pr.set(xlabel='Recall', ylabel='Precision', title='Precision-Recall (PR) Curve')
    if pr_curve_drawn:
        # Avoid matplotlib's "no labelled artists" warning on an empty plot.
        ax_pr.legend(loc='lower left')
    ax_pr.set_aspect('equal', adjustable='box')
    fig_pr.tight_layout()
    pr_save_path = output_dir / "Precision_Recall_Curves.png"
    fig_pr.savefig(pr_save_path)
    plt.close(fig_pr)
    logging.info(f"-> PR 曲线图已保存至: {pr_save_path}")
|
||||
|
||||
def save_visual_comparison(
    image_name: str, original_image: np.ndarray, pred_mask: np.ndarray,
    gt_mask: Optional[np.ndarray] = None, stats: Optional[Dict] = None, stats_text: str = "",
    save_dir: Union[str, Path] = ""
) -> None:
    """Save a side-by-side figure of the input image, the predicted mask and,
    when available, the ground-truth mask.

    A three-panel figure (with ``stats_text`` printed underneath and the IoU /
    accuracy embedded in the file name) is produced only when both ``gt_mask``
    and ``stats`` are given; otherwise a two-panel figure is saved.

    Args:
        image_name: File name of the image; its stem names the output PNG.
        original_image: Image array handed to ``imshow`` as-is.
        pred_mask: Predicted class-index mask.
        gt_mask: Optional ground-truth class-index mask.
        stats: Optional dict; the keys ``'iou'`` and ``'acc'`` are read.
        stats_text: Free text rendered below the panels (three-panel case).
        save_dir: Output directory, as ``str`` or ``Path``. The default ``""``
            resolves to the current working directory.
    """
    pred_color_mask = colorize_mask(pred_mask, config.CLASS_RGB_VALUES)

    show_ground_truth = gt_mask is not None and stats is not None
    if show_ground_truth:
        num_plots = 3
        figsize = (22, 8)
        gt_color_mask = colorize_mask(gt_mask, config.CLASS_RGB_VALUES)
        save_name = f"{Path(image_name).stem}-IOU_{stats['iou']:.4f}-ACC_{stats['acc']:.4f}.png"
    else:
        num_plots = 2
        figsize = (14, 6)
        save_name = f"{Path(image_name).stem}.png"

    fig, axes = plt.subplots(1, num_plots, figsize=figsize)
    try:
        fig.suptitle(f"Predict: {image_name}", fontsize=16)

        axes[0].imshow(original_image)
        axes[0].set_title("ORI_IMG")
        axes[0].axis('off')

        pred_title = "Predicted Mask"
        if stats:
            pred_title += f"\nIoU: {stats['iou']:.4f} | Acc: {stats['acc']:.4f}"
        axes[1].imshow(pred_color_mask)
        axes[1].set_title(pred_title)
        axes[1].axis('off')

        if show_ground_truth:
            axes[2].imshow(gt_color_mask)
            axes[2].set_title("Ground Truth")
            axes[2].axis('off')
            fig.text(0.5, 0.01, stats_text, ha="center", fontsize=10,
                     bbox={"facecolor": "white", "alpha": 0.7, "pad": 5}, family='monospace')

        # tight_layout() recomputes the subplot parameters, so it must run
        # BEFORE the bottom margin is reserved; the other way round the
        # reservation is discarded and the stats text overlaps the panels.
        fig.tight_layout()
        if show_ground_truth:
            fig.subplots_adjust(bottom=0.3)

        # Path() makes a plain-string save_dir (including the default "") work.
        fig.savefig(Path(save_dir) / save_name)
    finally:
        # Always release the figure, even if drawing or saving failed.
        plt.close(fig)
|
||||
|
||||
# --- 主要逻辑函数 ---
|
||||
|
||||
def load_model(model_architecture: str, model_save_path: Path) -> torch.nn.Module:
    """Rebuild a segmentation model and load its trained weights.

    The architecture class is looked up by name on
    ``segmentation_models_pytorch`` and instantiated with the keyword
    arguments stored in ``config.ALL_MODEL_CONFIGS[model_architecture]``.
    Checkpoints saved from a ``torch.nn.DataParallel`` wrapper (keys starting
    with ``'module.'``) and plain checkpoints are both accepted.

    Args:
        model_architecture: Attribute name of the model class on ``smp`` and
            key into ``config.ALL_MODEL_CONFIGS``.
        model_save_path: Path of the saved ``state_dict`` file.

    Returns:
        The model on ``config.DEVICE``, switched to eval mode.

    Raises:
        KeyError: ``model_architecture`` has no entry in the config.
        AttributeError: ``smp`` has no class of that name.
    """
    logging.info(f"正在从以下路径加载模型: {model_save_path}")
    num_classes = len(config.CLASSES)

    # 1. Rebuild the base architecture from the config entry.
    logging.info(f"正在重建模型架构: '{model_architecture}'")
    try:
        # .copy() so the shared config dict is not mutated below.
        params = config.ALL_MODEL_CONFIGS[model_architecture].copy()
        model_class = getattr(smp, model_architecture)
    except KeyError:
        logging.error(f"模型 '{model_architecture}' 在 config.ALL_MODEL_CONFIGS 中没有对应的配置!")
        raise
    except AttributeError:
        logging.error(f"模型 '{model_architecture}' 在 segmentation_models_pytorch 库中不存在!")
        raise

    # 2. Constructor arguments; presumably these mirror the training script
    #    (not visible from here).
    params['in_channels'] = 3
    params['classes'] = num_classes
    # The checkpoint supplies every weight, so do not fetch pretrained
    # encoder weights again at prediction time.
    params['encoder_weights'] = None

    # ViT encoders: enable dynamic input size automatically.
    encoder_name = params.get('encoder_name', '')
    if 'vit' in encoder_name.lower():
        params['dynamic_img_size'] = True
        logging.info(f"检测到 ViT 编码器 ('{encoder_name}')。自动设置 dynamic_img_size=True。")

    # 3. Instantiate the model.
    model = model_class(**params).to(config.DEVICE)

    # 4. Read the state_dict from disk.
    # NOTE(review): torch.load unpickles the file; only load checkpoints from
    # a trusted source (consider weights_only=True where the installed torch
    # version supports it).
    state_dict = torch.load(model_save_path, map_location=torch.device(config.DEVICE))

    # 5. DataParallel checkpoints prefix every key with 'module.'.
    prefix = 'module.'
    is_dataparallel = any(key.startswith(prefix) for key in state_dict.keys())

    if is_dataparallel:
        logging.info("检测到模型由 DataParallel 保存,将移除 'module.' 前缀。")
        # Strip only the LEADING prefix; a blanket str.replace would also
        # mangle keys that contain 'module.' further inside the name.
        state_dict = {
            (key[len(prefix):] if key.startswith(prefix) else key): value
            for key, value in state_dict.items()
        }
    else:
        logging.info("直接加载标准模型权重。")

    model.load_state_dict(state_dict)
    model.eval()
    return model
|
||||
|
||||
def calculate_and_save_final_metrics(stats_list: List[Dict], num_classes: int, save_dir: str) -> None:
    """Aggregate per-image statistics, compute overall metrics and save them to CSV.

    NOTE(review): a second function with this exact name (taking already
    summed ``tp, fp, fn, tn`` tensors) is defined further down in this file.
    Python binds the name to the last definition, so this list-based variant
    is shadowed once the module has loaded. Rename or remove one of the two.

    Args:
        stats_list: One dict per image with tensors under the keys ``'tp'``,
            ``'fp'``, ``'fn'`` and ``'tn'``; presumably each is shaped
            (batch, classes) -- TODO confirm against the producer.
        num_classes: Number of classes, used for the keep-mask and the
            per-class CSV columns.
        save_dir: Directory (``str`` or ``Path``) receiving
            ``test_set_metrics.csv``.
    """
    if not stats_list:
        logging.warning("未收集到统计数据,跳过最终指标计算。")
        return

    logging.info("\n正在聚合结果以计算最终指标...")
    # Concatenate along the first axis, then sum it away: one count per class.
    tp = torch.cat([s['tp'] for s in stats_list], dim=0).sum(dim=0)
    fp = torch.cat([s['fp'] for s in stats_list], dim=0).sum(dim=0)
    fn = torch.cat([s['fn'] for s in stats_list], dim=0).sum(dim=0)
    tn = torch.cat([s['tn'] for s in stats_list], dim=0).sum(dim=0)

    # Classes listed in config.EVALUATION_CLASSES_TO_IGNORE are left out of
    # the overall (micro-averaged) metrics.
    ignore_indices = [config.CLASSES.index(cls) for cls in config.EVALUATION_CLASSES_TO_IGNORE if cls in config.CLASSES]

    keep_mask = torch.ones(num_classes, dtype=torch.bool, device=tp.device)
    if ignore_indices:
        keep_mask[ignore_indices] = False
        logging.info(f"最终评估时将忽略类别: {config.EVALUATION_CLASSES_TO_IGNORE}")

    tp_filtered = tp[keep_mask]
    fp_filtered = fp[keep_mask]
    fn_filtered = fn[keep_mask]
    tn_filtered = tn[keep_mask]

    # Overall metrics from the filtered counts.
    metrics = {
        "iou_score": smp.metrics.iou_score(tp_filtered, fp_filtered, fn_filtered, tn_filtered, reduction='micro').item(),
        "f1_score": smp.metrics.f1_score(tp_filtered, fp_filtered, fn_filtered, tn_filtered, reduction='micro').item(),
        "accuracy": smp.metrics.accuracy(tp_filtered, fp_filtered, fn_filtered, tn_filtered, reduction='micro').item(),
        "recall": smp.metrics.recall(tp_filtered, fp_filtered, fn_filtered, tn_filtered, reduction='micro').item(),
        "precision": smp.metrics.precision(tp_filtered, fp_filtered, fn_filtered, tn_filtered, reduction='micro').item(),
    }

    # Raw per-class counts use the UNFILTERED tensors.
    for i in range(num_classes):
        metrics[f'tp_class_{i}'] = tp[i].item()
        metrics[f'tn_class_{i}'] = tn[i].item()
        metrics[f'fp_class_{i}'] = fp[i].item()
        metrics[f'fn_class_{i}'] = fn[i].item()

    # save_dir is annotated as str; coerce so the '/' join cannot raise.
    csv_save_path = Path(save_dir) / "test_set_metrics.csv"
    log_metrics_to_csv(metrics, str(csv_save_path))
    logging.info(f"-> 整体测试集指标已保存至: {csv_save_path}")
|
||||
|
||||
# 2. 【新增】一个帮助函数,用于交互式选择模型目录
|
||||
def select_trained_model_path(base_dir: Path, model_architecture: str) -> Optional[Path]:
    """Let the user pick one training-run directory of the given architecture.

    Candidate runs are the sub-directories of ``base_dir`` whose name starts
    with ``"<model_architecture>_"``. They are listed on stdout and the user
    is prompted for a 1-based number until a valid choice is made.

    Args:
        base_dir: Directory containing the training-run folders.
        model_architecture: Architecture name used as the folder-name prefix.

    Returns:
        The chosen run directory, or ``None`` when ``base_dir`` is not a
        directory, no matching run exists, the user presses Enter on an empty
        line, or stdin is closed.
    """
    logging.info(f"正在 '{base_dir}' 中搜索模型 '{model_architecture}' 的训练记录...")

    # iterdir() raises on a missing directory; report it instead of crashing.
    if not base_dir.is_dir():
        logging.error(f"模型目录 '{base_dir}' 不存在或不是文件夹。请先运行 train.py。")
        return None

    # The trailing '_' keeps e.g. 'Unet' from matching 'UnetPlusPlus_...'.
    prefix = model_architecture + '_'
    run_dirs = sorted(d for d in base_dir.iterdir() if d.is_dir() and d.name.startswith(prefix))

    if not run_dirs:
        logging.error(f"未找到任何 '{model_architecture}' 模型的训练记录。请先运行 train.py。")
        return None

    print("\n请选择要用于预测的模型:")
    for number, dir_path in enumerate(run_dirs, start=1):
        print(f" [{number}] {dir_path.name}")

    while True:
        try:
            choice = input(f"请输入选项编号 (1-{len(run_dirs)}) 或按 Enter 退出: ")
        except EOFError:
            # Non-interactive run (stdin closed): same as choosing to quit.
            return None

        if not choice:
            return None

        try:
            choice_idx = int(choice) - 1
        except ValueError:
            print("无效的输入,请输入一个数字。")
            continue

        if 0 <= choice_idx < len(run_dirs):
            selected_dir = run_dirs[choice_idx]
            logging.info(f"已选择模型: {selected_dir}")
            return selected_dir

        print("无效的选项,请重新输入。")
|
||||
|
||||
# 【修改】函数签名和内部逻辑,不再接收一个列表,而是直接接收聚合后的统计数据
|
||||
def calculate_and_save_final_metrics(tp, fp, fn, tn, num_classes: int, save_dir: Path) -> None:
    """Compute overall metrics from aggregated per-class counts and save them to CSV.

    The overall scores (IoU, F1, accuracy, recall, precision) are
    micro-averaged over every class that is not listed in
    ``config.EVALUATION_CLASSES_TO_IGNORE``. The raw per-class counts written
    next to them always cover all classes.

    Args:
        tp, fp, fn, tn: Per-class count tensors, indexable by class index.
        num_classes: Number of classes.
        save_dir: Directory receiving ``test_set_metrics.csv``.
    """
    logging.info("\n正在计算最终的整体指标...")

    # Indices of the classes excluded from the overall scores.
    ignored = [
        config.CLASSES.index(name)
        for name in config.EVALUATION_CLASSES_TO_IGNORE
        if name in config.CLASSES
    ]

    keep_mask = torch.ones(num_classes, dtype=torch.bool, device=tp.device)
    if ignored:
        keep_mask[ignored] = False
        logging.info(f"最终评估时将忽略类别: {config.EVALUATION_CLASSES_TO_IGNORE}")

    # Filtered counts, in the (tp, fp, fn, tn) order the smp metrics expect.
    kept_counts = tuple(counts[keep_mask] for counts in (tp, fp, fn, tn))

    # Overall metrics; the key order defines the CSV column order.
    metrics = {}
    for label in ("iou_score", "f1_score", "accuracy", "recall", "precision"):
        scorer = getattr(smp.metrics, label)
        metrics[label] = scorer(*kept_counts, reduction='micro').item()

    # Raw per-class counts come from the unfiltered tensors.
    raw_counts = (("tp", tp), ("tn", tn), ("fp", fp), ("fn", fn))
    for class_idx in range(num_classes):
        for tag, counts in raw_counts:
            metrics[f'{tag}_class_{class_idx}'] = counts[class_idx].item()

    csv_save_path = save_dir / "test_set_metrics.csv"
    log_metrics_to_csv(metrics, str(csv_save_path))
    logging.info(f"-> 整体测试集指标已保存至: {csv_save_path}")
|
||||
|
||||
def main(model_architecture: str) -> None:
    """Load a trained model, predict every test image, evaluate and report.

    Steps:
      1. Let the user pick a training run; all outputs go into that run's
         ``predicted_raw_masks`` and ``prediction_analysis`` sub-folders.
      2. Load the model and decide on the network input size.
      3. Predict each file in ``config.TEST_IMAGE_DIR`` and save the raw
         single-channel mask plus a comparison figure.
      4. If ``config.TEST_MASK_DIR`` exists (evaluation mode), accumulate
         TP/FP/FN/TN, write the overall metrics CSV and draw ROC / PR curves.

    Per-pixel ground truth and probabilities for the curves are appended to
    two temporary ``.npy`` files while looping; the files are removed when the
    function finishes, whether it succeeds or fails.

    Args:
        model_architecture: Model class name; key into ``config.ALL_MODEL_CONFIGS``.
    """
    # --- 1. Pick the training run and derive every path from it ---
    selected_run_dir = select_trained_model_path(Path(config.PREDICT_BEST_MODEL_DIR), model_architecture)

    if not selected_run_dir:
        logging.info("未选择任何模型,程序退出。")
        sys.exit()

    best_model_save_path = selected_run_dir / "best_model.pth"
    raw_mask_dir = selected_run_dir / "predicted_raw_masks"
    analysis_results_dir = selected_run_dir / "prediction_analysis"

    if not best_model_save_path.exists():
        logging.error(f"错误: 在 '{selected_run_dir}' 中未找到 'best_model.pth' 文件!")
        sys.exit(1)  # non-zero status: this is a failure, not a normal exit

    # Create the output sub-directories for this run.
    utils.setup_directories(raw_mask_dir)
    utils.setup_directories(analysis_results_dir)

    model = load_model(model_architecture, best_model_save_path)

    # --- 2. Probe the model for the input size to use ---
    logging.info("正在探测模型以确定预测时所需的输入尺寸...")
    target_height, target_width = config.IMAGE_HEIGHT, config.IMAGE_WIDTH

    try:
        # Encoders that expose a fixed input size override the config values.
        if model.encoder.is_fixed_input_size:
            required_size = model.encoder.input_size
            target_height = required_size[1]
            target_width = required_size[2]
            logging.warning(
                f"模型 '{model_architecture}' 的编码器需要固定的输入尺寸 "
                f"({target_height}x{target_width})。"
                f"将忽略 config.py 中的尺寸设置进行预测。"
            )
        else:
            logging.info(f"模型 '{model_architecture}' 支持动态输入尺寸。将使用 config.py 中定义的尺寸 ({target_height}x{target_width})。")
    except AttributeError:
        # The encoder does not expose these attributes; keep config values.
        logging.info("无法自动检测模型尺寸要求。将使用 config.py 中定义的尺寸。")

    # --- 3. Preprocessing pipeline for the chosen size ---
    transform = get_preprocessing_transform(target_height, target_width)

    # --- 4. Collect the images and decide on the mode ---
    # Only regular files: sub-directories cannot be decoded as images.
    image_filenames = sorted(p for p in config.TEST_IMAGE_DIR.iterdir() if p.is_file())
    evaluate_mode = config.TEST_MASK_DIR.is_dir()

    if evaluate_mode:
        logging.info("正在以 [评估模式] 运行。将使用真实掩码进行评估。")
    else:
        logging.info("正在以 [仅预测模式] 运行。未找到真实掩码。")

    # --- 5. Running totals, accumulated image by image ---
    num_classes = len(config.CLASSES)
    total_tp = torch.zeros(num_classes, dtype=torch.long, device=config.DEVICE)
    total_fp = torch.zeros(num_classes, dtype=torch.long, device=config.DEVICE)
    total_fn = torch.zeros(num_classes, dtype=torch.long, device=config.DEVICE)
    total_tn = torch.zeros(num_classes, dtype=torch.long, device=config.DEVICE)

    # Temp files for the curve data (disk instead of RAM during the loop).
    # They are needed in evaluation mode only.
    temp_gt_path = None
    temp_probs_path = None

    try:
        if evaluate_mode:
            # Close the handles right away; data is appended by name later.
            with tempfile.NamedTemporaryFile(delete=False, suffix='.npy') as temp_gt_file:
                temp_gt_path = Path(temp_gt_file.name)
            with tempfile.NamedTemporaryFile(delete=False, suffix='.npy') as temp_probs_file:
                temp_probs_path = Path(temp_probs_file.name)
            logging.info(f"为ROC/PR曲线创建临时文件: {temp_gt_path}, {temp_probs_path}")

        # --- 6. Main inference loop ---
        for image_path in tqdm(image_filenames, desc="正在处理测试图像"):
            # --- 6.1 Read and preprocess ---
            bgr_image = cv2.imread(str(image_path))
            if bgr_image is None:
                # cv2.imread returns None for files it cannot decode.
                logging.warning(f"无法读取图像,已跳过: {image_path}")
                continue
            original_image = cv2.cvtColor(bgr_image, cv2.COLOR_BGR2RGB)
            original_h, original_w = original_image.shape[:2]
            image_tensor = transform(image=original_image)['image'].unsqueeze(0).to(config.DEVICE)

            # --- 6.2 Inference and raw-mask export ---
            with torch.no_grad():
                pred_logits_tensor = model(image_tensor)
                pred_mask_tensor = torch.argmax(pred_logits_tensor, dim=1)
            pred_mask_numpy = pred_mask_tensor.squeeze().cpu().numpy().astype(np.uint8)
            # Nearest-neighbour keeps class indices intact when resizing.
            pred_mask_resized_raw = cv2.resize(pred_mask_numpy, (original_w, original_h), interpolation=cv2.INTER_NEAREST)

            # Save the raw single-channel predicted mask.
            cv2.imwrite(str(raw_mask_dir / f"{image_path.stem}.png"), pred_mask_resized_raw)

            # --- 6.3 Evaluation and visualisation ---
            per_image_stats, stats_text_display, gt_mask_raw = None, "", None
            if evaluate_mode:
                mask_path = config.TEST_MASK_DIR / image_path.name
                if mask_path.exists():
                    gt_mask_raw = cv2.imread(str(mask_path), cv2.IMREAD_GRAYSCALE)
                    if gt_mask_raw is None:
                        logging.warning(f"无法读取真实掩码,该图像仅保存预测结果: {mask_path}")

                if gt_mask_raw is not None:
                    # Ground truth is resized to the size used for the network input.
                    gt_mask_numpy = cv2.resize(
                        gt_mask_raw, (target_width, target_height),
                        interpolation=cv2.INTER_NEAREST
                    )
                    gt_mask_tensor = torch.from_numpy(gt_mask_numpy).long().to(config.DEVICE).unsqueeze(0)

                    # Per-image confusion counts; indexed below as [0, class].
                    tp, fp, fn, tn = smp.metrics.get_stats(
                        pred_mask_tensor, gt_mask_tensor,
                        mode=config.SEG_MODE, num_classes=num_classes
                    )

                    # Add to the running totals for the final evaluation.
                    total_tp += tp.squeeze(0).to(config.DEVICE)
                    total_fp += fp.squeeze(0).to(config.DEVICE)
                    total_fn += fn.squeeze(0).to(config.DEVICE)
                    total_tn += tn.squeeze(0).to(config.DEVICE)

                    # Append this image's curve data to the temp files.
                    probs_tensor = torch.softmax(pred_logits_tensor, dim=1)
                    gt_flat = gt_mask_tensor.cpu().numpy().flatten()
                    probs_flat = probs_tensor.permute(0, 2, 3, 1).reshape(-1, num_classes).cpu().numpy()
                    with open(temp_gt_path, 'ab') as f_gt, open(temp_probs_path, 'ab') as f_probs:
                        np.save(f_gt, gt_flat)
                        np.save(f_probs, probs_flat)

                    # Per-image metrics shown in the comparison figure.
                    per_image_stats = {
                        'iou': smp.metrics.iou_score(tp, fp, fn, tn, reduction='micro').item(),
                        'acc': smp.metrics.accuracy(tp, fp, fn, tn, reduction='micro').item()
                    }

                    # Per-class count table; class 0 (background) is skipped.
                    stats_text_display = "Each Class TP / FP / FN / TN:\n" + "-"*35 + "\n"
                    for i in range(1, tp.shape[1]):
                        if i < len(config.CLASSES):
                            class_name = config.CLASSES[i]
                            stats_text_display += (
                                f" {class_name:<10}: "
                                f"{tp[0, i]:>5} / {fp[0, i]:>5} / {fn[0, i]:>5} / {tn[0, i]:>7}\n"
                            )

            save_visual_comparison(
                image_name=image_path.name,
                original_image=original_image,
                pred_mask=pred_mask_resized_raw,
                gt_mask=gt_mask_raw,
                stats=per_image_stats,
                stats_text=stats_text_display,
                save_dir=analysis_results_dir
            )

        logging.info("\n预测流程已完成!")
        logging.info(f"原始预测掩码已保存至: {raw_mask_dir}")
        logging.info(f"分析图像已保存至: {analysis_results_dir}")

        # --- 7. Final analysis (after the loop) ---
        if evaluate_mode:
            calculate_and_save_final_metrics(total_tp.cpu(), total_fp.cpu(), total_fn.cpu(), total_tn.cpu(), num_classes, save_dir=analysis_results_dir)

            logging.info("\n正在从临时文件加载数据以生成 ROC 和 PR 曲线...")
            # Each np.save call appended one array; read them back in
            # lockstep until either file is exhausted.
            all_gt_parts, all_probs_parts = [], []
            with open(temp_gt_path, 'rb') as f_gt, open(temp_probs_path, 'rb') as f_probs:
                while True:
                    try:
                        all_gt_parts.append(np.load(f_gt))
                        all_probs_parts.append(np.load(f_probs))
                    except (IOError, ValueError, EOFError):
                        break  # end of file reached

            if all_gt_parts:
                full_gt = np.concatenate(all_gt_parts)
                full_probs = np.concatenate(all_probs_parts)

                generate_and_save_curves(
                    y_true=full_gt, y_probs=full_probs,
                    class_names=config.CLASSES, output_dir=analysis_results_dir
                )
            else:
                logging.warning("临时文件中没有数据,跳过 ROC/PR 曲线生成。")
    finally:
        # Remove the temp files on success and on failure alike.
        created_paths = [p for p in (temp_gt_path, temp_probs_path) if p is not None]
        for temp_path in created_paths:
            try:
                temp_path.unlink()
            except FileNotFoundError:
                pass
        if created_paths:
            logging.info("已清理临时文件。")
|
||||
|
||||
if __name__ == "__main__":
    # Command-line entry point of the prediction script.
    # The description and help text previously spoke of *training* (copied
    # from the training script); they now describe prediction.
    parser = argparse.ArgumentParser(description="使用已训练的图像分割模型进行预测。")

    # --architecture is mandatory; the valid choices are the keys of
    # config.ALL_MODEL_CONFIGS.
    parser.add_argument(
        "-a", "--architecture",
        type=str,
        choices=list(config.ALL_MODEL_CONFIGS.keys()),
        required=True,
        help="选择用于预测的模型架构。"
    )

    # Parse the command line and run the prediction with the chosen model.
    args = parser.parse_args()
    main(model_architecture=args.architecture)
|
||||
Reference in New Issue
Block a user