first commit

This commit is contained in:
admin
2026-05-20 15:05:35 +08:00
commit ac09b26253
2048 changed files with 189478 additions and 0 deletions

View File

@@ -0,0 +1,564 @@
import logging, argparse, sys, utils
import shutil, tempfile
from pathlib import Path
from typing import Dict, List, Tuple
import albumentations as A
import cv2
import matplotlib.pyplot as plt
import numpy as np
import segmentation_models_pytorch as smp
import torch
from albumentations.pytorch import ToTensorV2
from scipy.special import softmax
from sklearn.metrics import (auc, average_precision_score,
precision_recall_curve, roc_curve)
from tqdm import tqdm
# 本地应用/库的导入
import config
from utils import log_metrics_to_csv, get_preprocessing_transform, setup_directories
# --- Logging setup ---
# Configure the root logger: INFO level and above, timestamp + level prefix.
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s [%(levelname)s] %(message)s",
handlers=[logging.StreamHandler()] # send log records to the console
)
# Give matplotlib a list of CJK-capable fonts so Chinese labels render; pick whichever is installed on your system.
plt.rcParams['font.sans-serif'] = ['Noto Sans CJK JP', 'SimHei', 'Microsoft YaHei'] # e.g. SimHei, Microsoft YaHei
plt.rcParams['axes.unicode_minus'] = False # work around the minus-sign rendering problem with these fonts
# --- 核心工具函数 ---
# 将单通道的类别索引掩码转换为 RGB 彩色掩码
def colorize_mask(mask: np.ndarray, class_rgb_values: List[Tuple[int, int, int]]) -> np.ndarray:
    """Convert a single-channel class-index mask into an RGB colour mask.

    Pixel value ``k`` is painted with the ``k``-th entry of a fixed built-in
    palette. ``class_rgb_values`` is only used for its length: it limits how
    many palette entries are applied. Pixels whose value has no applicable
    palette entry stay black.

    Args:
        mask: Array whose first two axes are (height, width) and whose values
            are class indices.
        class_rgb_values: One entry per class; only ``len()`` is consulted.

    Returns:
        A ``uint8`` array of shape ``(height, width, 3)``.
    """
    # Generic palette; append further colours here to support more classes.
    palette = (
        (0, 0, 0), (220, 20, 60), (0, 255, 0), (0, 0, 255),
        (255, 255, 0), (255, 0, 255), (0, 255, 255), (244, 164, 96),
    )
    height, width = mask.shape[0], mask.shape[1]
    rgb = np.zeros((height, width, 3), dtype=np.uint8)

    # Only the first len(class_rgb_values) palette entries are in play.
    active_colors = palette[:len(class_rgb_values)]
    for label, rgb_triplet in enumerate(active_colors):
        rgb[mask == label] = rgb_triplet
    return rgb
# --- 分析与可视化函数 ---
def generate_and_save_curves(y_true: np.ndarray, y_probs: np.ndarray, class_names: List[str], output_dir: Path) -> None:
    """Plot per-class ROC and precision-recall curves and save them as PNGs.

    Each class except index 0 is evaluated one-vs-rest: the binary target is
    ``y_true == i`` and the score is column ``i`` of ``y_probs``.

    Args:
        y_true: Flat array of ground-truth class indices.
        y_probs: Array indexed as ``[sample, class]`` holding class scores.
        class_names: Class names; index 0 is skipped.
        output_dir: Directory that receives ``ROC_Curves.png`` and
            ``Precision_Recall_Curves.png``.
    """
    logging.info("正在生成 ROC 和 PR 曲线...")
    num_classes = len(class_names)

    # Remember the font list configured at import time: the style sheet
    # applied below may replace font.sans-serif, which would break CJK labels.
    configured_fonts = list(plt.rcParams['font.sans-serif'])

    # Apply the style only for these two figures instead of mutating the
    # global matplotlib state for everything plotted afterwards.
    with plt.style.context('seaborn-v0_8-whitegrid'):
        plt.rcParams['font.sans-serif'] = configured_fonts

        fig_roc, ax_roc = plt.subplots(figsize=(10, 8))
        fig_pr, ax_pr = plt.subplots(figsize=(10, 8))
        try:
            ax_roc.plot([0, 1], [0, 1], 'k--', label='随机猜测')

            # --- One curve per class (background, index 0, is skipped) ---
            for i in range(1, num_classes):
                class_name = class_names[i]
                y_true_binary = (y_true == i).astype(int)

                # ROC/PR are undefined when the class has no positive or no
                # negative samples; skip instead of plotting NaN curves.
                positives = int(y_true_binary.sum())
                if positives == 0 or positives == y_true_binary.size:
                    logging.warning(
                        f"类别 {class_name} 在真实标签中缺少正样本或负样本,已跳过其 ROC/PR 曲线。"
                    )
                    continue

                y_class_probs = y_probs[:, i]

                # ROC curve
                fpr, tpr, _ = roc_curve(y_true_binary, y_class_probs)
                roc_auc = auc(fpr, tpr)
                ax_roc.plot(fpr, tpr, label=f'类别 {class_name} (AUC = {roc_auc:.4f})')

                # PR curve
                precision, recall, _ = precision_recall_curve(y_true_binary, y_class_probs)
                avg_precision = average_precision_score(y_true_binary, y_class_probs)
                ax_pr.plot(recall, precision, label=f'类别 {class_name} (AP = {avg_precision:.4f})')

            # --- Decorate and save the ROC figure ---
            ax_roc.set(xlabel='False positive rate (FPR)', ylabel='True positive rate (TPR)', title='Receiver operating characteristic (ROC) curve')
            ax_roc.legend(loc='lower right')
            ax_roc.set_aspect('equal', adjustable='box')
            fig_roc.tight_layout()
            roc_save_path = output_dir / "ROC_Curves.png"
            fig_roc.savefig(roc_save_path)
            logging.info(f"-> ROC 曲线图已保存至: {roc_save_path}")

            # --- Decorate and save the PR figure ---
            ax_pr.set(xlabel='Recall', ylabel='Precision', title='Precision-Recall (PR) Curve')
            ax_pr.legend(loc='lower left')
            ax_pr.set_aspect('equal', adjustable='box')
            fig_pr.tight_layout()
            pr_save_path = output_dir / "Precision_Recall_Curves.png"
            fig_pr.savefig(pr_save_path)
            logging.info(f"-> PR 曲线图已保存至: {pr_save_path}")
        finally:
            # Release both figures even if plotting or saving failed.
            plt.close(fig_roc)
            plt.close(fig_pr)
def save_visual_comparison(
    image_name: str, original_image: np.ndarray, pred_mask: np.ndarray,
    gt_mask: np.ndarray = None, stats: Dict = None, stats_text: str = "",
    save_dir: str = ""
) -> None:
    """Save a side-by-side figure: original image, predicted mask and, optionally, ground truth.

    A third panel with the ground-truth mask is drawn only when both
    ``gt_mask`` and ``stats`` are given; in that case the IoU and accuracy
    are also embedded in the output file name.

    Args:
        image_name: Source image file name; its stem becomes the output name.
        original_image: Image array passed straight to ``imshow``.
        pred_mask: Predicted class-index mask.
        gt_mask: Optional ground-truth class-index mask.
        stats: Optional dict with ``'iou'`` and ``'acc'`` entries.
        stats_text: Optional multi-line text drawn below the panels.
        save_dir: Output directory (``str`` or ``Path``).
    """
    pred_color_mask = colorize_mask(pred_mask, config.CLASS_RGB_VALUES)
    stem = Path(image_name).stem

    with_ground_truth = gt_mask is not None and stats is not None
    if with_ground_truth:
        num_plots = 3
        figsize = (22, 8)
        gt_color_mask = colorize_mask(gt_mask, config.CLASS_RGB_VALUES)
        save_name = f"{stem}-IOU_{stats['iou']:.4f}-ACC_{stats['acc']:.4f}.png"
    else:
        num_plots = 2
        figsize = (14, 6)
        save_name = f"{stem}.png"

    fig, axes = plt.subplots(1, num_plots, figsize=figsize)
    try:
        fig.suptitle(f"Predict: {image_name}", fontsize=16)

        axes[0].imshow(original_image)
        axes[0].set_title("ORI_IMG")
        axes[0].axis('off')

        pred_title = "Predicted Mask"
        if stats:
            pred_title += f"\nIoU: {stats['iou']:.4f} | Acc: {stats['acc']:.4f}"
        axes[1].imshow(pred_color_mask)
        axes[1].set_title(pred_title)
        axes[1].axis('off')

        if with_ground_truth:
            axes[2].imshow(gt_color_mask)
            axes[2].set_title("Ground Truth")
            axes[2].axis('off')

        # Run tight_layout first: calling it after subplots_adjust would
        # discard the bottom margin reserved for the statistics text.
        fig.tight_layout()
        if stats_text:
            fig.text(0.5, 0.01, stats_text, ha="center", fontsize=10,
                     bbox={"facecolor": "white", "alpha": 0.7, "pad": 5}, family='monospace')
            fig.subplots_adjust(bottom=0.3)

        # Path() accepts both str and Path, so the str default/annotation works.
        fig.savefig(Path(save_dir) / save_name)
    finally:
        # Always release the figure, even if drawing or saving failed.
        plt.close(fig)
# --- 主要逻辑函数 ---
def load_model(model_architecture: str, model_save_path: Path) -> torch.nn.Module:
    """Rebuild a segmentation model from ``config`` and load trained weights into it.

    Handles checkpoints saved from ``torch.nn.DataParallel`` (keys prefixed
    with ``'module.'``) as well as plain state dicts.

    Args:
        model_architecture: Key of ``config.ALL_MODEL_CONFIGS`` and name of
            the model class in ``segmentation_models_pytorch``.
        model_save_path: Path of the saved state dict.

    Returns:
        The model on ``config.DEVICE``, switched to eval mode.

    Raises:
        KeyError: ``model_architecture`` is not in ``config.ALL_MODEL_CONFIGS``.
        AttributeError: ``segmentation_models_pytorch`` has no such class.
    """
    logging.info(f"正在从以下路径加载模型: {model_save_path}")
    num_classes = len(config.CLASSES)

    # 1. Rebuild the architecture from the project configuration.
    logging.info(f"正在重建模型架构: '{model_architecture}'")
    try:
        params = config.ALL_MODEL_CONFIGS[model_architecture].copy()
    except KeyError:
        logging.error(f"模型 '{model_architecture}' 未在 config.ALL_MODEL_CONFIGS 中配置!")
        raise
    try:
        model_class = getattr(smp, model_architecture)
    except AttributeError:
        logging.error(f"模型 '{model_architecture}' 在 segmentation_models_pytorch 库中不存在!")
        raise

    # 2. Constructor arguments. Weights come from the checkpoint, so the
    #    encoder must not download pretrained weights again.
    params['in_channels'] = 3
    params['classes'] = num_classes
    params['encoder_weights'] = None

    # Encoders whose name contains 'vit' get dynamic_img_size=True.
    encoder_name = params.get('encoder_name', '')
    if 'vit' in encoder_name.lower():
        params['dynamic_img_size'] = True
        logging.info(f"检测到 ViT 编码器 ('{encoder_name}')。自动设置 dynamic_img_size=True。")

    # 3. Instantiate the model.
    model = model_class(**params).to(config.DEVICE)

    # 4. Read the state dict from disk.
    # NOTE(review): torch.load unpickles the file; only load checkpoints from
    # a trusted source (consider weights_only=True where the torch version allows).
    state_dict = torch.load(model_save_path, map_location=torch.device(config.DEVICE))

    # 5. Strip the DataParallel prefix. Only a leading 'module.' is removed so
    #    that keys merely containing 'module.' further in are left intact.
    prefix = 'module.'
    if any(key.startswith(prefix) for key in state_dict):
        logging.info("检测到模型由 DataParallel 保存,将移除 'module.' 前缀。")
        state_dict = {
            (key[len(prefix):] if key.startswith(prefix) else key): value
            for key, value in state_dict.items()
        }
    else:
        logging.info("直接加载标准模型权重。")

    model.load_state_dict(state_dict)
    model.eval()
    return model
def calculate_and_save_final_metrics(stats_list: List[Dict], num_classes: int, save_dir: str) -> None:
    """Aggregate per-image statistics, compute overall metrics and write them to CSV.

    NOTE(review): a later ``def calculate_and_save_final_metrics(tp, fp, fn, tn, ...)``
    in this module rebinds the same name, so this list-based variant is
    unreachable by name once the module has finished loading.

    Args:
        stats_list: Dicts with ``'tp'``, ``'fp'``, ``'fn'`` and ``'tn'``
            tensors; they are concatenated on dim 0 and summed over it.
        num_classes: Number of classes (length of the per-class counters).
        save_dir: Output directory (``str`` or ``Path``) for
            ``test_set_metrics.csv``.
    """
    if not stats_list:
        logging.warning("未收集到统计数据,跳过最终指标计算。")
        return

    logging.info("\n正在聚合结果以计算最终指标...")

    # Concatenate every image's counters and sum over dim 0 -> one count per class.
    tp, fp, fn, tn = (
        torch.cat([entry[name] for entry in stats_list], dim=0).sum(dim=0)
        for name in ('tp', 'fp', 'fn', 'tn')
    )

    # Classes listed in config.EVALUATION_CLASSES_TO_IGNORE are excluded from
    # the overall metrics (but still reported per class below).
    ignore_indices = [config.CLASSES.index(cls) for cls in config.EVALUATION_CLASSES_TO_IGNORE if cls in config.CLASSES]
    keep_mask = torch.ones(num_classes, dtype=torch.bool, device=tp.device)
    if ignore_indices:
        keep_mask[ignore_indices] = False
        logging.info(f"最终评估时将忽略类别: {config.EVALUATION_CLASSES_TO_IGNORE}")
    tp_filtered = tp[keep_mask]
    fp_filtered = fp[keep_mask]
    fn_filtered = fn[keep_mask]
    tn_filtered = tn[keep_mask]

    # Overall metrics over the kept classes.
    metrics = {
        "iou_score": smp.metrics.iou_score(tp_filtered, fp_filtered, fn_filtered, tn_filtered, reduction='micro').item(),
        "f1_score": smp.metrics.f1_score(tp_filtered, fp_filtered, fn_filtered, tn_filtered, reduction='micro').item(),
        "accuracy": smp.metrics.accuracy(tp_filtered, fp_filtered, fn_filtered, tn_filtered, reduction='micro').item(),
        "recall": smp.metrics.recall(tp_filtered, fp_filtered, fn_filtered, tn_filtered, reduction='micro').item(),
        "precision": smp.metrics.precision(tp_filtered, fp_filtered, fn_filtered, tn_filtered, reduction='micro').item(),
    }

    # Raw per-class counters, taken from the unfiltered tensors.
    for i in range(num_classes):
        metrics[f'tp_class_{i}'] = tp[i].item()
        metrics[f'tn_class_{i}'] = tn[i].item()
        metrics[f'fp_class_{i}'] = fp[i].item()
        metrics[f'fn_class_{i}'] = fn[i].item()

    # Path() makes the '/' join work for the annotated str as well as for Path.
    csv_save_path = Path(save_dir) / "test_set_metrics.csv"
    log_metrics_to_csv(metrics, str(csv_save_path))
    logging.info(f"-> 整体测试集指标已保存至: {csv_save_path}")
# 2. 【新增】一个帮助函数,用于交互式选择模型目录
def select_trained_model_path(base_dir: Path, model_architecture: str) -> Path:
    """Let the user pick one training-run directory for an architecture.

    Run directories are the sub-directories of ``base_dir`` whose name starts
    with ``"<model_architecture>_"``. They are listed in sorted order and the
    user chooses one by number.

    Args:
        base_dir: Directory containing the training-run folders.
        model_architecture: Architecture name used as the folder-name prefix.

    Returns:
        The chosen directory, or ``None`` when ``base_dir`` is not a
        directory, no matching run exists, the user presses Enter without a
        choice, or standard input is closed.
    """
    logging.info(f"正在 '{base_dir}' 中搜索模型 '{model_architecture}' 的训练记录...")

    # A missing base directory is reported like "no runs found" instead of
    # escaping as FileNotFoundError from iterdir().
    if not base_dir.is_dir():
        logging.error(f"模型目录 '{base_dir}' 不存在或不是文件夹。")
        return None

    run_prefix = model_architecture + '_'
    run_dirs = sorted(d for d in base_dir.iterdir() if d.is_dir() and d.name.startswith(run_prefix))
    if not run_dirs:
        logging.error(f"未找到任何 '{model_architecture}' 模型的训练记录。请先运行 train.py。")
        return None

    print("\n请选择要用于预测的模型:")
    for number, dir_path in enumerate(run_dirs, start=1):
        print(f" [{number}] {dir_path.name}")

    while True:
        try:
            choice = input(f"请输入选项编号 (1-{len(run_dirs)}) 或按 Enter 退出: ")
        except EOFError:
            # No interactive stdin (e.g. piped/closed): treat as "quit".
            return None
        if not choice:
            return None

        try:
            choice_idx = int(choice) - 1
        except ValueError:
            print("无效的输入,请输入一个数字。")
            continue

        if 0 <= choice_idx < len(run_dirs):
            selected_dir = run_dirs[choice_idx]
            logging.info(f"已选择模型: {selected_dir}")
            return selected_dir
        print("无效的选项,请重新输入。")
# 【修改】函数签名和内部逻辑,不再接收一个列表,而是直接接收聚合后的统计数据
def calculate_and_save_final_metrics(tp, fp, fn, tn, num_classes: int, save_dir: Path) -> None:
    """Compute overall metrics from aggregated counters and write them to CSV.

    Args:
        tp, fp, fn, tn: Per-class counters, indexable by class index.
        num_classes: Number of classes.
        save_dir: Directory that receives ``test_set_metrics.csv``.
    """
    logging.info("\n正在计算最终的整体指标...")

    # Build a boolean mask that drops the classes configured as ignored.
    known_classes = config.CLASSES
    ignore_indices = [
        known_classes.index(name)
        for name in config.EVALUATION_CLASSES_TO_IGNORE
        if name in known_classes
    ]
    keep_mask = torch.ones(num_classes, dtype=torch.bool, device=tp.device)
    if ignore_indices:
        keep_mask[ignore_indices] = False
        logging.info(f"最终评估时将忽略类别: {config.EVALUATION_CLASSES_TO_IGNORE}")

    # Same (tp, fp, fn, tn) argument order that the smp metric functions take.
    kept_counts = tuple(counter[keep_mask] for counter in (tp, fp, fn, tn))

    # Overall metrics over the kept classes, in the CSV column order.
    metric_functions = (
        ("iou_score", smp.metrics.iou_score),
        ("f1_score", smp.metrics.f1_score),
        ("accuracy", smp.metrics.accuracy),
        ("recall", smp.metrics.recall),
        ("precision", smp.metrics.precision),
    )
    metrics = {
        metric_name: metric_fn(*kept_counts, reduction='micro').item()
        for metric_name, metric_fn in metric_functions
    }

    # Raw per-class counters come from the unfiltered tensors.
    raw_counters = (('tp', tp), ('tn', tn), ('fp', fp), ('fn', fn))
    for class_idx in range(num_classes):
        for tag, counter in raw_counters:
            metrics[f'{tag}_class_{class_idx}'] = counter[class_idx].item()

    csv_save_path = save_dir / "test_set_metrics.csv"
    log_metrics_to_csv(metrics, str(csv_save_path))
    logging.info(f"-> 整体测试集指标已保存至: {csv_save_path}")
def main(model_architecture: str) -> None:
    """Load a trained model, predict every test image, evaluate and save the analysis.

    Steps: choose a training run interactively, load ``best_model.pth`` from
    it, predict each file in ``config.TEST_IMAGE_DIR``, write the raw masks
    and comparison figures into the run directory and, when
    ``config.TEST_MASK_DIR`` exists, compute overall metrics plus ROC/PR curves.

    Args:
        model_architecture: Architecture name used to find the run directory
            and to rebuild the model.
    """
    # --- 1. Choose the run and derive every path from it ---
    selected_run_dir = select_trained_model_path(Path(config.PREDICT_BEST_MODEL_DIR), model_architecture)
    if not selected_run_dir:
        logging.info("未选择任何模型,程序退出。")
        sys.exit()

    best_model_save_path = selected_run_dir / "best_model.pth"
    raw_mask_dir = selected_run_dir / "predicted_raw_masks"
    analysis_results_dir = selected_run_dir / "prediction_analysis"
    if not best_model_save_path.exists():
        logging.error(f"错误: 在 '{selected_run_dir}' 中未找到 'best_model.pth' 文件!")
        # A missing checkpoint is a failure: report it through the exit status.
        sys.exit(1)

    utils.setup_directories(raw_mask_dir)
    utils.setup_directories(analysis_results_dir)
    model = load_model(model_architecture, best_model_save_path)

    # --- 2. Probe the encoder for a mandatory input size ---
    logging.info("正在探测模型以确定预测时所需的输入尺寸...")
    target_height, target_width = config.IMAGE_HEIGHT, config.IMAGE_WIDTH
    try:
        if model.encoder.is_fixed_input_size:
            required_size = model.encoder.input_size
            target_height = required_size[1]
            target_width = required_size[2]
            logging.warning(
                f"模型 '{model_architecture}' 的编码器需要固定的输入尺寸 "
                f"({target_height}x{target_width})。"
                f"将忽略 config.py 中的尺寸设置进行预测。"
            )
        else:
            logging.info(f"模型 '{model_architecture}' 支持动态输入尺寸。将使用 config.py 中定义的尺寸 ({target_height}x{target_width})。")
    except AttributeError:
        # The encoder does not expose these attributes; keep the config size.
        logging.info("无法自动检测模型尺寸要求。将使用 config.py 中定义的尺寸。")

    # --- 3. Preprocessing pipeline for the chosen size ---
    transform = get_preprocessing_transform(target_height, target_width)

    # --- 4. Collect the test files and decide the mode ---
    # Sub-directories are ignored; unreadable files are skipped in the loop.
    image_filenames = sorted(p for p in config.TEST_IMAGE_DIR.iterdir() if p.is_file())
    evaluate_mode = config.TEST_MASK_DIR.is_dir()
    if evaluate_mode:
        logging.info("正在以 [评估模式] 运行。将使用真实掩码进行评估。")
    else:
        logging.info("正在以 [仅预测模式] 运行。未找到真实掩码。")

    # --- 5. Running totals of the per-class confusion counters ---
    num_classes = len(config.CLASSES)
    total_tp = torch.zeros(num_classes, dtype=torch.long, device=config.DEVICE)
    total_fp = torch.zeros(num_classes, dtype=torch.long, device=config.DEVICE)
    total_fn = torch.zeros(num_classes, dtype=torch.long, device=config.DEVICE)
    total_tn = torch.zeros(num_classes, dtype=torch.long, device=config.DEVICE)

    # Curve data is spilled to temporary files instead of being held in lists.
    # Only the names are needed, so the creating handles are closed right away.
    temp_paths = []
    for _ in range(2):
        handle = tempfile.NamedTemporaryFile(delete=False, suffix='.npy')
        handle.close()
        temp_paths.append(Path(handle.name))
    gt_tmp_path, probs_tmp_path = temp_paths
    logging.info(f"为ROC/PR曲线创建临时文件: {gt_tmp_path}, {probs_tmp_path}")

    try:
        # --- 6. Inference loop ---
        for image_path in tqdm(image_filenames, desc="正在处理测试图像"):
            # 6.1 Read and preprocess.
            bgr_image = cv2.imread(str(image_path))
            if bgr_image is None:
                # cv2.imread returns None for files it cannot decode.
                logging.warning(f"无法读取图像,已跳过: {image_path}")
                continue
            original_image = cv2.cvtColor(bgr_image, cv2.COLOR_BGR2RGB)
            original_h, original_w = original_image.shape[:2]
            image_tensor = transform(image=original_image)['image'].unsqueeze(0).to(config.DEVICE)

            # 6.2 Predict and store the raw single-channel mask at the
            #     original resolution.
            with torch.no_grad():
                pred_logits_tensor = model(image_tensor)
            pred_mask_tensor = torch.argmax(pred_logits_tensor, dim=1)
            pred_mask_numpy = pred_mask_tensor.squeeze().cpu().numpy().astype(np.uint8)
            pred_mask_resized_raw = cv2.resize(pred_mask_numpy, (original_w, original_h), interpolation=cv2.INTER_NEAREST)
            cv2.imwrite(str(raw_mask_dir / f"{image_path.stem}.png"), pred_mask_resized_raw)

            # 6.3 Evaluate against the ground truth when it is available.
            per_image_stats, stats_text_display, gt_mask_raw = None, "", None
            if evaluate_mode:
                mask_path = config.TEST_MASK_DIR / image_path.name
                if mask_path.exists():
                    gt_mask_raw = cv2.imread(str(mask_path), cv2.IMREAD_GRAYSCALE)
                    if gt_mask_raw is None:
                        logging.warning(f"无法读取真实掩码,已跳过该图像的评估: {mask_path}")

                if gt_mask_raw is not None:
                    # Compare at the model resolution, where the prediction lives.
                    gt_mask_numpy = cv2.resize(
                        gt_mask_raw, (target_width, target_height),
                        interpolation=cv2.INTER_NEAREST
                    )
                    gt_mask_tensor = torch.from_numpy(gt_mask_numpy).long().to(config.DEVICE).unsqueeze(0)

                    tp, fp, fn, tn = smp.metrics.get_stats(
                        pred_mask_tensor, gt_mask_tensor,
                        mode=config.SEG_MODE, num_classes=num_classes
                    )

                    # Accumulate for the overall metrics.
                    total_tp += tp.squeeze(0).to(config.DEVICE)
                    total_fp += fp.squeeze(0).to(config.DEVICE)
                    total_fn += fn.squeeze(0).to(config.DEVICE)
                    total_tn += tn.squeeze(0).to(config.DEVICE)

                    # Append this image's labels and probabilities to the temp files.
                    probs_tensor = torch.softmax(pred_logits_tensor, dim=1)
                    gt_flat = gt_mask_tensor.cpu().numpy().flatten()
                    probs_flat = probs_tensor.permute(0, 2, 3, 1).reshape(-1, num_classes).cpu().numpy()
                    with open(gt_tmp_path, 'ab') as f_gt, open(probs_tmp_path, 'ab') as f_probs:
                        np.save(f_gt, gt_flat)
                        np.save(f_probs, probs_flat)

                    # Per-image metrics shown in the figure and its file name.
                    per_image_stats = {
                        'iou': smp.metrics.iou_score(tp, fp, fn, tn, reduction='micro').item(),
                        'acc': smp.metrics.accuracy(tp, fp, fn, tn, reduction='micro').item()
                    }

                    # The counters are indexed [0, class]; class 0 is not listed.
                    stats_text_display = "Each Class TP / FP / FN / TN:\n" + "-"*35 + "\n"
                    for i in range(1, tp.shape[1]):
                        if i < len(config.CLASSES):
                            class_name = config.CLASSES[i]
                            stats_text_display += (
                                f" {class_name:<10}: "
                                f"{tp[0, i]:>5} / {fp[0, i]:>5} / {fn[0, i]:>5} / {tn[0, i]:>7}\n"
                            )

            save_visual_comparison(
                image_name=image_path.name,
                original_image=original_image,
                pred_mask=pred_mask_resized_raw,
                gt_mask=gt_mask_raw,
                stats=per_image_stats,
                stats_text=stats_text_display,
                save_dir=analysis_results_dir
            )

        logging.info("\n预测流程已完成!")
        logging.info(f"原始预测掩码已保存至: {raw_mask_dir}")
        logging.info(f"分析图像已保存至: {analysis_results_dir}")

        # --- 7. Final analysis over the whole test set ---
        if evaluate_mode:
            calculate_and_save_final_metrics(total_tp.cpu(), total_fp.cpu(), total_fn.cpu(), total_tn.cpu(), num_classes, save_dir=analysis_results_dir)

            logging.info("\n正在从临时文件加载数据以生成 ROC 和 PR 曲线...")
            all_gt_parts, all_probs_parts = [], []
            with open(gt_tmp_path, 'rb') as f_gt, open(probs_tmp_path, 'rb') as f_probs:
                while True:
                    try:
                        gt_part = np.load(f_gt)
                        probs_part = np.load(f_probs)
                    except (IOError, ValueError, EOFError):
                        break  # no more arrays stored in the files
                    # Append only complete pairs so both lists stay aligned.
                    all_gt_parts.append(gt_part)
                    all_probs_parts.append(probs_part)

            if all_gt_parts:
                full_gt = np.concatenate(all_gt_parts)
                full_probs = np.concatenate(all_probs_parts)
                generate_and_save_curves(
                    y_true=full_gt, y_probs=full_probs,
                    class_names=config.CLASSES, output_dir=analysis_results_dir
                )
            else:
                logging.warning("临时文件中没有数据,跳过 ROC/PR 曲线生成。")
    finally:
        # Remove the temp files on every exit path (error, predict-only, success).
        for tmp_path in (gt_tmp_path, probs_tmp_path):
            try:
                tmp_path.unlink()
            except FileNotFoundError:
                pass
        logging.info("已清理临时文件。")
if __name__ == "__main__":
    # Command-line entry point of the prediction script. The description and
    # help text describe prediction (they previously said "train").
    parser = argparse.ArgumentParser(description="使用已训练的图像分割模型进行预测与评估。")

    # --architecture is mandatory; the allowed values are the keys of
    # config.ALL_MODEL_CONFIGS.
    parser.add_argument(
        "-a", "--architecture",
        type=str,
        choices=list(config.ALL_MODEL_CONFIGS.keys()),
        required=True,
        help="选择用于预测的模型架构。"
    )

    args = parser.parse_args()
    # Hand the chosen architecture name to main().
    main(model_architecture=args.architecture)

View File

@@ -0,0 +1,137 @@
# predict_check.py
import sys
from pathlib import Path
from typing import Set, Tuple
try:
import config
except ImportError:
print("错误:无法导入 'config.py'。请确保此脚本与 config.py 在同一目录下。")
sys.exit(1)
def get_file_info(directory: Path) -> Tuple[Set[str], Set[str]]:
    """Collect the stems and suffixes of the regular files directly inside a directory.

    Args:
        directory: Directory to scan (not recursive).

    Returns:
        ``(stems, suffixes)`` as two sets. Both are empty when ``directory``
        is not a directory.
    """
    if not directory.is_dir():
        return set(), set()

    # Only direct children that are files count; sub-directories are ignored.
    files = [entry for entry in directory.iterdir() if entry.is_file()]
    return {entry.stem for entry in files}, {entry.suffix for entry in files}
def main():
    """Check every run's ``predicted_raw_masks`` folder against the source test images.

    For each sub-directory of ``config.PREDICT_BEST_MODEL_DIR`` that contains
    a ``predicted_raw_masks`` folder, three things are compared with
    ``config.TEST_IMAGE_DIR``: the number of distinct file stems, the stems
    themselves, and whether the predictions share a single file suffix.
    Exits with status 1 when a required directory is missing or any run fails
    a check.
    """
    print("--- 开始详细检查预测输出 ---")

    # 1. Source directory and its file stems.
    source_dir = config.TEST_IMAGE_DIR
    if not source_dir.is_dir():
        print(f"错误:在 '{source_dir}' 找不到源测试图片目录。")
        sys.exit(1)
    source_stems, _ = get_file_info(source_dir)
    source_file_count = len(source_stems)
    if source_file_count == 0:
        print(f"警告:源目录 '{source_dir}' 为空或不包含任何文件。检查中止。")
        sys.exit(0)
    print(f"源数据集: 在 '{source_dir}' 中找到 {source_file_count} 张图片。")
    print("-" * 50)

    # 2. Base directory holding one folder per training run.
    predictions_base_dir = Path(config.PREDICT_BEST_MODEL_DIR)
    if not predictions_base_dir.is_dir():
        print(f"错误:在 '{predictions_base_dir}' 找不到预测结果的基础目录。")
        sys.exit(1)

    # 3. Inspect each run directory.
    found_predictions = False
    global_mismatch_found = False
    run_dirs = sorted(d for d in predictions_base_dir.iterdir() if d.is_dir())
    if not run_dirs:
        print("信息:在基础目录中没有找到任何模型的运行记录文件夹。")
        sys.exit(0)

    for run_dir in run_dirs:
        target_dir = run_dir / "predicted_raw_masks"
        # Runs without a prediction folder are not part of the check.
        if not target_dir.is_dir():
            continue

        print(f"正在检查: '{target_dir}'...")
        found_predictions = True
        is_current_dir_ok = True
        predicted_stems, predicted_suffixes = get_file_info(target_dir)
        predicted_file_count = len(predicted_stems)

        # --- Check 1: number of files ---
        # This also covers an empty prediction folder, because the source
        # count is known to be greater than zero at this point.
        if source_file_count != predicted_file_count:
            print(f" [✗ 数量不匹配] 源目录有 {source_file_count} 个文件,但此目录有 {predicted_file_count} 个。")
            is_current_dir_ok = False

        # --- Check 2: file names ---
        missing_files = source_stems - predicted_stems
        if missing_files:
            # Show at most three examples; sorted so the output is reproducible.
            examples = sorted(missing_files)[:3]
            print(f" [✗ 文件名缺失] 预测结果中缺少 {len(missing_files)} 个文件。例如: {examples}...")
            is_current_dir_ok = False
        extra_files = predicted_stems - source_stems
        if extra_files:
            examples = sorted(extra_files)[:3]
            print(f" [✗ 文件名多余] 预测结果中多出 {len(extra_files)} 个文件。例如: {examples}...")
            is_current_dir_ok = False

        # --- Check 3: suffix consistency (only meaningful for a non-empty folder) ---
        if predicted_file_count > 0 and len(predicted_suffixes) > 1:
            print(f" [✗ 后缀不一致] 预测目录中存在多种文件后缀: {sorted(list(predicted_suffixes))}")
            is_current_dir_ok = False

        # --- Summary for this run ---
        if is_current_dir_ok:
            suffix_info = list(predicted_suffixes)[0] if len(predicted_suffixes) == 1 else 'N/A'
            print(f" [✓ OK] 所有检查通过 (数量: {predicted_file_count}, 后缀: {suffix_info})")
        else:
            global_mismatch_found = True
        print("-" * 25)

    # 4. Global summary.
    print("\n--- 检查摘要 ---")
    if not found_predictions:
        print("结果: 没有找到任何 'predicted_raw_masks' 目录进行检查。")
    elif global_mismatch_found:
        print("结论: 检查不通过。发现至少一个预测目录未通过所有检查,请查看上方日志。")
        sys.exit(1)
    else:
        print("结论: 检查通过!所有已找到的预测目录均通过了数量、文件名和后缀一致性检查。")
    print("--- 检查完成 ---")


if __name__ == "__main__":
    main()

View File

@@ -0,0 +1,261 @@
import os
import glob
import logging
import argparse
import csv
from typing import Dict, Optional, Tuple, List
import torch
from fvcore.nn import FlopCountAnalysis
from train import initialize_components
import config # Assumes config.py is in the same directory or accessible
import time
import numpy as np
from pathlib import Path
import sys
# Emit INFO-level (and above) records with a timestamp and level prefix.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
# Usage
# Default invocation (recommended):
# python 2_predict_params_and_FLOPs.py
# With explicit arguments (main() still prompts for the dataset and algorithm via input()):
# python 2_predict_params_and_FLOPs.py --input_dir '../BestMode_Predict_Results_DataSet_Public'
# python 2_predict_params_and_FLOPs.py --shape 512 512 # --shape 1080 1920
def get_flops_and_params_smp(architecture: str, shape: Tuple[int, int]) -> Optional[Dict[str, str]]:
    """Measure FLOPs (via fvcore) and parameter count for one model architecture.

    The model is built with ``initialize_components`` and analysed on a random
    input of size ``(1, 3, shape[0], shape[1])``.

    Args:
        architecture: Name of the model architecture.
        shape: Input image shape as (H, W).

    Returns:
        ``{'flops': '<x.xx> G', 'params': '<x.xx> M'}``, or ``None`` when no
        model could be built or the analysis raised.
    """
    logging.info(f"Analyzing model: '{architecture}' with input shape {shape}...")

    class_count = len(config.CLASSES)
    model, _, _, _ = initialize_components(architecture, class_count, config.SEG_MODE)
    if model is None:
        return None

    model.to(config.DEVICE)
    model.eval()
    probe = torch.randn(1, 3, shape[0], shape[1]).to(config.DEVICE)

    try:
        # FLOPs as counted by fvcore, reported in units of 1e9.
        gflops = FlopCountAnalysis(model, probe).total() / 1e9

        # Parameter count, reported in units of 1e6.
        param_count = sum(tensor.numel() for tensor in model.parameters())
        m_params = param_count / 1e6

        logging.info(f"✅ Analysis successful for '{architecture}': {gflops:.2f} GFLOPs, {m_params:.2f} M Params")
        summary = {
            'flops': f"{gflops:.2f} G",
            'params': f"{m_params:.2f} M"
        }
        return summary
    except Exception as e:
        # Best effort: a model that cannot be analysed is reported, not fatal.
        logging.error(f"❌ Failed to analyze FLOPs for '{architecture}'. Error: {e}")
        return None
def get_shape_from_path(path: str) -> Optional[Tuple[int, int]]:
    """Return the analysis input shape taken from ``config``.

    Despite the name, ``path`` is currently not inspected: the shape comes
    from ``config.IMAGE_HEIGHT`` and ``config.IMAGE_WIDTH``.

    Args:
        path: Dataset directory path (unused).

    Returns:
        ``(height, width)`` when both configured values are positive,
        otherwise ``None``.
    """
    height = config.IMAGE_HEIGHT
    width = config.IMAGE_WIDTH
    if not (width > 0 and height > 0):
        return None
    return (height, width)  # H, W order
# --- 主函数 ---
def main(args):
    """Interactive workflow: pick a dataset and algorithm(s), measure FLOPs/params, write a CSV.

    Args:
        args: Parsed command-line namespace providing ``input_dir`` and ``shape``.
    """
    input_root = args.input_dir
    if not os.path.isdir(input_root):
        logging.error(f"Input directory does not exist: {input_root}")
        return

    # 1. Dataset folders are the directories matching '*_outputs-SegModel'.
    dataset_pattern = os.path.join(input_root, '*_outputs-SegModel')
    existing_dataset_dirs = sorted(d for d in glob.glob(dataset_pattern) if os.path.isdir(d))
    if not existing_dataset_dirs:
        logging.error(f"No valid dataset folders (e.g., '*_outputs-SegModel') found in {input_root}.")
        return

    # 2. First menu: choose the dataset.
    banner = "=" * 50
    dataset_map = {str(number): path for number, path in enumerate(existing_dataset_dirs, start=1)}
    print("\n" + banner)
    print("--- Step 1: Please select the dataset to process ---")
    for key, path in dataset_map.items():
        print(f"{key}: {os.path.basename(path)}")
    print(banner)
    choice1 = input("Enter the dataset number and press Enter: ").strip()

    if choice1 not in dataset_map:
        logging.error("Invalid dataset selection. Exiting.")
        return
    selected_dataset_dir = dataset_map[choice1]
    logging.info(f"You have selected the dataset: [{os.path.basename(selected_dataset_dir)}]")

    # 3. Algorithm folders: keep the part of each folder name before the first
    #    '_' (e.g. 'Unet' from 'Unet_2025-09-24_14-24-05') when it is a
    #    configured architecture; duplicates are dropped, first occurrence wins.
    alg_dirs_full_path = sorted(
        d for d in glob.glob(os.path.join(selected_dataset_dir, '*')) if os.path.isdir(d)
    )
    base_names = (os.path.basename(path).split('_')[0] for path in alg_dirs_full_path)
    valid_alg_archs = list(dict.fromkeys(
        name for name in base_names if name in config.ALL_MODEL_CONFIGS
    ))
    if not valid_alg_archs:
        logging.warning(f"No valid algorithm subfolders found in {os.path.basename(selected_dataset_dir)}. Please check folder names match keys in config.ALL_MODEL_CONFIGS.")
        return

    # 4. Second menu: choose one algorithm or all of them.
    alg_map = {str(number): name for number, name in enumerate(valid_alg_archs, start=1)}
    print("\n" + banner)
    print("--- Step 2: Please select the algorithm to process ---")
    print("0: Process [ALL] algorithms under the current dataset")
    for key, name in alg_map.items():
        print(f"{key}: {name}")
    print(banner)
    choice2 = input("Enter the algorithm number (or '0' for all) and press Enter: ").strip()

    # 5. Resolve the choice into the list of architectures.
    if choice2 == '0':
        model_archs_to_process = valid_alg_archs
        logging.info(f"You have chosen to process all {len(model_archs_to_process)} algorithms.")
    elif choice2 in alg_map:
        model_archs_to_process = [alg_map[choice2]]
        logging.info(f"You have chosen to process a single algorithm: {model_archs_to_process[0]}")
    else:
        logging.error("Invalid algorithm selection. Exiting.")
        return

    # --- Input shape: command line first, then get_shape_from_path, then config ---
    if args.shape:
        input_shape = (args.shape[0], args.shape[1])
        logging.info(f"Using user-provided shape (H, W): {input_shape} for calculation.")
    else:
        logging.info("No --shape argument provided. Attempting to detect from folder name...")
        input_shape = get_shape_from_path(selected_dataset_dir)
        if input_shape:
            logging.info(f"Detected input shape (H, W): {input_shape} from folder name.")
        else:
            logging.warning(f"Could not automatically detect resolution from folder '{os.path.basename(selected_dataset_dir)}'.")
            logging.info(f"Using default shape from config.py: ({config.IMAGE_HEIGHT}, {config.IMAGE_WIDTH})")
            input_shape = (config.IMAGE_HEIGHT, config.IMAGE_WIDTH)

    def _round_up(value: int, multiple: int) -> int:
        """Round ``value`` up to the nearest multiple of ``multiple``."""
        return ((value + multiple - 1) // multiple) * multiple

    # --- Analyse each selected architecture ---
    results: List[Dict[str, str]] = []
    for architecture in model_archs_to_process:
        logging.info(f"--- Processing model: {architecture} ---")

        # DPT input must be divisible by the patch size; 16 is hard-coded
        # here to match the encoder name ('...patch16...').
        analysis_shape = input_shape
        if architecture == 'DPT':
            patch_size = 16
            h, w = analysis_shape
            if h % patch_size != 0 or w % patch_size != 0:
                analysis_shape = (_round_up(h, patch_size), _round_up(w, patch_size))
                logging.warning(f"DPT requires dimensions divisible by {patch_size}. "
                                f"Adjusting analysis shape from {input_shape} to {analysis_shape}.")

        stats = get_flops_and_params_smp(architecture, analysis_shape)
        if not stats:
            logging.warning(f"Failed to get stats for model {architecture}.")
            continue
        results.append({
            'Model': architecture,
            'Params': stats['params'],
            'FLOPs': stats['flops'],
            'Input_Shape (HxW)': f"{analysis_shape[0]}x{analysis_shape[1]}"
        })

    # --- Write the summary CSV into the dataset folder ---
    if not results:
        logging.info("No statistics were successfully generated, CSV file will not be created.")
        return

    dataset_name = os.path.basename(selected_dataset_dir).replace('_outputs-SegModel', '')
    output_csv_path = os.path.join(selected_dataset_dir, f'{dataset_name}_flops_params_summary.csv')
    try:
        with open(output_csv_path, 'w', newline='', encoding='utf-8') as csvfile:
            writer = csv.DictWriter(
                csvfile, fieldnames=['Model', 'Params', 'FLOPs', 'Input_Shape (HxW)']
            )
            writer.writeheader()
            writer.writerows(results)
        logging.info(f"=== All processing complete! Results saved to: {output_csv_path} ===")
    except IOError as e:
        logging.error(f"Could not write to CSV file: {output_csv_path}. Error: {e}")
if __name__ == '__main__':
    # Command-line interface: an optional root directory and an optional
    # explicit (height, width) input shape.
    cli = argparse.ArgumentParser(description="SMP Model Automation and Evaluation Script")
    cli.add_argument(
        '--input_dir',
        type=str,
        default=str(config.PREDICT_ALL_BEST_MODELS_DIR),
        help="Root directory containing trained model folders (e.g., 'DATASET_outputs-SegModel')."
    )
    cli.add_argument(
        '--shape',
        type=int,
        nargs=2,
        metavar=('HEIGHT', 'WIDTH'),
        help="Specify input shape (height width) for analysis, overriding automatic detection."
    )
    main(cli.parse_args())

View File

@@ -0,0 +1,325 @@
import os
import glob
import logging
import argparse
import csv
from typing import Dict, Optional, Tuple, List
import time
import numpy as np
from pathlib import Path
import sys
import torch
from fvcore.nn import FlopCountAnalysis
import segmentation_models_pytorch as smp
import config # Assumes config.py is in the same directory or accessible
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
# TODO 使用方法 TODO
# 交互模式(推荐)
# python 2_predict_params_and_FLOPs.py
# 非交互模式(用于脚本)
# python 2_predict_params_and_FLOPs_V2.py --input_dir '../BestMode_Predict_Results_DataSet_Public'
# python 2_predict_params_and_FLOPs_V2.py --shape 512 512 # --shape 1080 1920
def load_model(model_architecture: str, checkpoint_path: Optional[Path]) -> Optional[torch.nn.Module]:
    """
    Rebuild a model from ``config`` and optionally load weights from a checkpoint.

    The architecture class is looked up by name on ``segmentation_models_pytorch``
    and its keyword arguments are copied from ``config.ALL_MODEL_CONFIGS``.
    This function does not depend on train.py.

    Args:
        model_architecture: Key into ``config.ALL_MODEL_CONFIGS`` and attribute
            name on ``smp`` (for example ``'Unet'``).
        checkpoint_path: Path of the state-dict file, or ``None``.

    Returns:
        The model moved to ``config.DEVICE`` and switched to eval mode, or
        ``None`` when the architecture is unknown. When the checkpoint is
        missing or fails to load, the model is still returned with its
        freshly initialised weights (this is logged).
    """
    # 1. Resolve the model class and its configured keyword arguments.
    logging.info(f"Rebuilding model architecture: '{model_architecture}'")
    try:
        model_params = config.ALL_MODEL_CONFIGS[model_architecture].copy()
        model_class = getattr(smp, model_architecture)
    except (AttributeError, KeyError):
        logging.error(f"Model '{model_architecture}' not found in smp library or config.py!")
        return None

    # 2. Prepare constructor arguments.
    params = model_params
    params['classes'] = len(config.CLASSES)
    # Weights come from the checkpoint, so pretrained encoder weights are not requested.
    params['encoder_weights'] = None
    # Encoders whose name contains 'vit' get dynamic_img_size=True.
    encoder_name = model_params.get('encoder_name', '')
    if 'vit' in encoder_name.lower():
        params['dynamic_img_size'] = True
        logging.info(f"ViT encoder ('{encoder_name}') detected. Setting dynamic_img_size=True.")

    # 3. Instantiate the model.
    model = model_class(**params).to(config.DEVICE)

    # 4. Load weights when a usable checkpoint path was given.
    if checkpoint_path is not None and checkpoint_path.exists():
        logging.info(f"Loading checkpoint: {checkpoint_path}")
        try:
            # NOTE(review): torch.load unpickles the file; only load checkpoints
            # from a trusted source.
            state_dict = torch.load(checkpoint_path, map_location=torch.device(config.DEVICE))
            # Checkpoints saved from a DataParallel wrapper prefix every key with 'module.'.
            prefix = 'module.'
            if any(key.startswith(prefix) for key in state_dict.keys()):
                logging.info("DataParallel weights detected, removing 'module.' prefix.")
                # Strip only the leading prefix; a blanket str.replace would also
                # mangle keys that contain 'module.' further inside the name.
                state_dict = {
                    (key[len(prefix):] if key.startswith(prefix) else key): value
                    for key, value in state_dict.items()
                }
            model.load_state_dict(state_dict)
            logging.info("Checkpoint loaded successfully.")
        except Exception as e:
            # Deliberate best-effort: Params/FLOPs/FPS do not depend on weight values.
            logging.error(f"Failed to load checkpoint. Using model with random weights. Error: {e}")
    else:
        logging.info("No valid checkpoint provided. Using model with random weights.")

    model.eval()
    return model
# --- 核心功能函数 ---
def find_latest_model_run_path(base_dir: Path, model_architecture: str) -> Optional[Path]:
    """
    Return the newest run directory for an architecture, without prompting.

    Candidates are the sub-directories of ``base_dir`` whose name starts with
    ``"<model_architecture>_"``. "Newest" means the one that sorts last.

    Returns:
        The selected directory, or ``None`` when there is no candidate.
    """
    logging.info(f"Searching for latest '{model_architecture}' model run in '{base_dir}'...")

    wanted_prefix = model_architecture + '_'
    candidates = []
    for entry in base_dir.iterdir():
        if entry.is_dir() and entry.name.startswith(wanted_prefix):
            candidates.append(entry)

    if not candidates:
        logging.info(f"No trained model folder found for '{model_architecture}'.")
        return None

    # The largest path in sort order is the last element of the sorted list.
    newest = max(candidates)
    logging.info(f"Automatically selected latest model run: {newest.name}")
    return newest
def calculate_flops_params(model: torch.nn.Module, shape: Tuple[int, int]) -> Optional[Dict[str, str]]:
    """
    Measure FLOPs and parameter count of an already-built model.

    Args:
        model: The model to analyse; it is moved to ``config.DEVICE`` and put in eval mode.
        shape: ``(height, width)`` of the random single-image, 3-channel probe input.

    Returns:
        ``{'flops': '<x.xx> G', 'params': '<x.xx> M'}``, or ``None`` if the analysis raised.
    """
    logging.info(f"Analyzing model Params/FLOPs with input shape {shape}...")
    model.to(config.DEVICE)
    model.eval()
    height, width = shape[0], shape[1]
    probe = torch.randn(1, 3, height, width).to(config.DEVICE)

    try:
        analysis = FlopCountAnalysis(model, probe)
        giga_flops = analysis.total() / 1e9
        param_total = 0
        for tensor in model.parameters():
            param_total += tensor.numel()
        mega_params = param_total / 1e6
        logging.info(f"✅ Params/FLOPs Analysis successful: {giga_flops:.2f} GFLOPs, {mega_params:.2f} M Params")
        return {'flops': f"{giga_flops:.2f} G", 'params': f"{mega_params:.2f} M"}
    except Exception as e:
        logging.error(f"❌ Failed to analyze FLOPs/Params. Error: {e}")
        return None
def benchmark_model(model: torch.nn.Module, shape: Tuple[int, int]) -> Optional[Dict[str, float]]:
    """
    Measure forward-pass throughput (images per second) on a random input.

    Three runs are made; each run does 5 untimed warm-up passes followed by
    100 timed passes on one ``(1, 3, H, W)`` tensor.

    Args:
        model: The model to time; it is moved to ``config.DEVICE`` and put in eval mode.
        shape: ``(height, width)`` of the input.

    Returns:
        ``{'avg_fps': mean, 'fps_variance': variance}`` over the runs, or
        ``None`` if no run produced a value.
    """
    logging.info("Starting FPS benchmark...")
    model.to(config.DEVICE)
    model.eval()
    dummy_input = torch.randn(1, 3, shape[0], shape[1]).to(config.DEVICE)

    repeat_times = 3
    num_warmup = 5
    total_iters = 100
    fps_per_run = []

    with torch.no_grad():
        for run_index in range(repeat_times):
            timed_seconds = 0
            for step in range(total_iters + num_warmup):
                # Synchronise before and after so queued CUDA work is included in the timing.
                if torch.cuda.is_available():
                    torch.cuda.synchronize()
                tick = time.perf_counter()
                _ = model(dummy_input)
                if torch.cuda.is_available():
                    torch.cuda.synchronize()
                step_seconds = time.perf_counter() - tick
                if step < num_warmup:
                    continue  # warm-up passes are not counted
                timed_seconds += step_seconds

            # Guard the division below against a zero total.
            if timed_seconds == 0:
                timed_seconds = 1e-6
            run_fps = total_iters / timed_seconds
            fps_per_run.append(run_fps)
            print(f" Run {run_index + 1}/{repeat_times}, FPS: {run_fps:.2f} img/s")

    if not fps_per_run:
        return None

    avg_fps = np.mean(fps_per_run)
    fps_variance = np.var(fps_per_run)
    logging.info(f"✅ FPS Benchmark successful: Average FPS: {avg_fps:.2f} img/s")
    return {'avg_fps': avg_fps, 'fps_variance': fps_variance}
def get_shape_from_path(path: str) -> Optional[Tuple[int, int]]:
    """
    Extract an image resolution from the last component of a folder path.

    The folder name is searched for the first ``<digits>x<digits>`` token,
    read as ``WIDTHxHEIGHT`` (for example ``...-1920x1080_outputs-SegModel``).

    Args:
        path: Folder path; a trailing path separator is tolerated.

    Returns:
        ``(height, width)``, or ``None`` when the name has no such token or
        either dimension is zero, so the caller can fall back to a default.
    """
    import re

    # normpath drops trailing separators, which would otherwise make basename() empty.
    folder_name = os.path.basename(os.path.normpath(path))
    match = re.search(r'(\d+)x(\d+)', folder_name)
    if match is None:
        return None

    width, height = int(match.group(1)), int(match.group(2))
    # A zero-sized dimension is not a usable input shape.
    if width <= 0 or height <= 0:
        return None
    return (height, width)  # returned as (H, W)
# --- 主函数 ---
def main(args):
    """
    Script entry point.

    Interactively asks for a dataset folder (``*_outputs-SegModel`` under
    ``args.input_dir``) and one or all algorithms found inside it. For each
    selected architecture it loads the latest run's checkpoint, measures
    Params/FLOPs and FPS, and writes
    ``<dataset>_flops_params_fps_summary.csv`` into the dataset folder.

    Args:
        args: Parsed CLI namespace with ``input_dir`` (str) and ``shape``
            (``[height, width]`` or ``None``).
    """
    input_root = args.input_dir
    if not os.path.isdir(input_root):
        logging.error(f"Input directory does not exist: {input_root}")
        return

    # 1. Interactive selection of dataset and algorithm(s).
    existing_dataset_dirs = sorted(
        d for d in glob.glob(os.path.join(input_root, '*_outputs-SegModel')) if os.path.isdir(d)
    )
    if not existing_dataset_dirs:
        logging.error(f"No valid dataset folders found in {input_root}.")
        return

    dataset_map = {str(i + 1): path for i, path in enumerate(existing_dataset_dirs)}
    print("\n" + "="*50)
    print("--- Step 1: Please select the dataset to process ---")
    for key, path in dataset_map.items():
        print(f"{key}: {os.path.basename(path)}")
    print("="*50)
    choice1 = input("Enter the dataset number and press Enter: ").strip()

    if choice1 not in dataset_map:
        logging.error("Invalid dataset selection. Exiting.")
        return
    selected_dataset_dir = dataset_map[choice1]
    logging.info(f"You have selected the dataset: [{os.path.basename(selected_dataset_dir)}]")

    # Architecture name = folder name up to the first '_'; keep only names known to config.
    alg_dirs_full_path = sorted(
        d for d in glob.glob(os.path.join(selected_dataset_dir, '*')) if os.path.isdir(d)
    )
    valid_alg_archs = []
    for path in alg_dirs_full_path:
        base_name = os.path.basename(path).split('_')[0]
        if base_name in config.ALL_MODEL_CONFIGS and base_name not in valid_alg_archs:
            valid_alg_archs.append(base_name)
    if not valid_alg_archs:
        logging.warning(f"No valid algorithm subfolders found in {os.path.basename(selected_dataset_dir)}.")
        return

    alg_map = {str(i + 1): name for i, name in enumerate(valid_alg_archs)}
    print("\n" + "="*50)
    print("--- Step 2: Please select the algorithm to process ---")
    print("0: Process [ALL] algorithms under the current dataset")
    for key, name in alg_map.items():
        print(f"{key}: {name}")
    print("="*50)
    choice2 = input("Enter the algorithm number (or '0' for all) and press Enter: ").strip()

    if choice2 == '0':
        model_archs_to_process = valid_alg_archs
    elif choice2 in alg_map:
        model_archs_to_process = [alg_map[choice2]]
    else:
        # Previously this case returned without telling the user why.
        logging.error("Invalid algorithm selection. Exiting.")
        return

    # 2. Decide the input shape: CLI override, then folder name, then config default.
    if args.shape:
        analysis_shape = (args.shape[0], args.shape[1])
        logging.info(f"Using user-provided shape (H, W): {analysis_shape} for calculation.")
    else:
        analysis_shape = get_shape_from_path(selected_dataset_dir) or (config.IMAGE_HEIGHT, config.IMAGE_WIDTH)
        logging.info(f"Using detected/default shape (H, W): {analysis_shape} for calculation.")

    # 3. Main processing loop.
    ori_analysis_shape = analysis_shape
    results: List[Dict[str, str]] = []
    for architecture in model_archs_to_process:
        logging.info(f"--- Processing model: {architecture} ---")

        # Round both dimensions up to a multiple of the patch size.
        # NOTE(review): the message names DPT, but the rounding is applied to
        # every architecture, and the CSV below records the un-rounded shape.
        # Confirm both are intended.
        analysis_shape = ori_analysis_shape
        patch_size = 16  # From the encoder name '...patch16...'
        h, w = analysis_shape
        if h % patch_size != 0 or w % patch_size != 0:
            new_h = ((h + patch_size - 1) // patch_size) * patch_size
            new_w = ((w + patch_size - 1) // patch_size) * patch_size
            analysis_shape = (new_h, new_w)
            logging.warning(f"DPT requires dimensions divisible by {patch_size}. "
                            f"Adjusting analysis shape from {ori_analysis_shape} to {analysis_shape}.")

        # 3.1: Locate the latest run directory and its checkpoint file.
        latest_run_dir = find_latest_model_run_path(Path(selected_dataset_dir), architecture)
        checkpoint_path = None
        if latest_run_dir:
            checkpoint_path = latest_run_dir / config.BEST_MODEL_SAVE_NAME

        # 3.2: Build the model and load its weights.
        model = load_model(architecture, checkpoint_path)
        if model is None:
            logging.warning(f"Skipping all analysis for {architecture} due to model initialization failure.")
            continue

        # 3.3: Run both measurements on the same model instance.
        stats_flops_params = calculate_flops_params(model, analysis_shape)
        stats_fps = benchmark_model(model, analysis_shape)

        # 3.4: Collect the row; a failed measurement is written as "N/A".
        results.append({
            'Model': architecture,
            'Params': stats_flops_params['params'] if stats_flops_params else "N/A",
            'FLOPs': stats_flops_params['flops'] if stats_flops_params else "N/A",
            'Input_Shape (HxW)': f"{ori_analysis_shape[0]}x{ori_analysis_shape[1]}",
            'Average_FPS': f"{stats_fps['avg_fps']:.2f}" if stats_fps else "N/A",
            'FPS_Variance': f"{stats_fps['fps_variance']:.4f}" if stats_fps else "N/A"
        })

    # 4. Write the CSV file.
    if not results:
        logging.info("No statistics were successfully generated, CSV file will not be created.")
        return

    dataset_name = os.path.basename(selected_dataset_dir).replace('_outputs-SegModel', '')
    output_csv_path = os.path.join(selected_dataset_dir, f'{dataset_name}_flops_params_fps_summary.csv')
    try:
        with open(output_csv_path, 'w', newline='', encoding='utf-8') as csvfile:
            fieldnames = ['Model', 'Params', 'FLOPs', 'Input_Shape (HxW)', 'Average_FPS', 'FPS_Variance']
            writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
            writer.writeheader()
            writer.writerows(results)
        logging.info(f"=== All processing complete! Results saved to: {output_csv_path} ===")
    except IOError as e:
        logging.error(f"Could not write to CSV file: {output_csv_path}. Error: {e}")
# Command-line entry point: declare the two options, parse them, and run main().
if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description="SMP Model Automation and Evaluation Script",
    )
    # Root folder holding the per-dataset output directories.
    input_dir_options = dict(
        type=str,
        default=str(config.PREDICT_ALL_BEST_MODELS_DIR),
        help="Root directory containing trained model folders.",
    )
    # Optional explicit (height, width) pair for the analysis input.
    shape_options = dict(
        type=int,
        nargs=2,
        metavar=('HEIGHT', 'WIDTH'),
        help="Specify input shape (height width), overriding automatic detection.",
    )
    parser.add_argument('--input_dir', **input_dir_options)
    parser.add_argument('--shape', **shape_options)
    args = parser.parse_args()
    main(args)

View File

@@ -0,0 +1,271 @@
import os
import glob
import logging
import argparse
import re
import csv
import numpy as np
from typing import Dict, Optional, List
# --- Configure Logging ---
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
# --- Helper Functions (No changes in this section) ---
def find_smp_csv_file(algorithm_dir: str) -> Optional[str]:
    """
    Locate ``training_metrics.csv`` directly inside an algorithm directory.

    Returns:
        The joined path when it exists; otherwise ``None`` (a warning is logged).
    """
    candidate = os.path.join(algorithm_dir, 'training_metrics.csv')
    if not os.path.exists(candidate):
        logging.warning(f"Could not find 'training_metrics.csv' in {algorithm_dir}")
        return None
    return candidate
def parse_smp_metrics(csv_path: str) -> Optional[Dict]:
    """
    Parse ``training_metrics.csv`` and extract the metrics of the best epoch.

    The best epoch is the first row with the highest numeric ``iou_score``.
    Rows whose ``iou_score`` is missing, blank or non-numeric (for example a
    truncated last line) are skipped instead of invalidating the whole file.

    Per-class IoU is computed from the ``tp_class_i`` / ``fp_class_i`` /
    ``fn_class_i`` columns; per-class accuracy is read from ``cpa_class_i``.

    Returns:
        A dict with keys ``'epoch'``, ``'summary'`` (``mIoU``, ``mAcc``,
        ``aAcc``) and ``'class_wise'`` (list of ``Class``/``IoU``/``Acc``
        dicts), or ``None`` when the file cannot be read, has no usable
        ``iou_score``, has a best score of 0, or has no per-class metrics.
    """
    try:
        with open(csv_path, 'r', encoding='utf-8') as f:
            rows = list(csv.DictReader(f))
    except (IOError, csv.Error) as e:
        logging.error(f"Cannot read or parse CSV file: {csv_path}. Error: {e}")
        return None

    if not rows:
        logging.warning(f"❌ CSV file is empty: {os.path.basename(csv_path)}")
        return None

    # Pick the first row with the highest finite iou_score.
    best_row = None
    best_iou = None
    for row in rows:
        try:
            iou = float(row.get('iou_score'))
        except (TypeError, ValueError):
            continue  # short rows give None, blank cells give ''
        if not np.isfinite(iou):
            continue
        if best_iou is None or iou > best_iou:
            best_row, best_iou = row, iou

    if best_row is None:
        logging.error(f"Could not find a valid 'iou_score' in {csv_path}.")
        return None
    if best_iou == 0:
        logging.warning(f"❌ Best 'iou_score' is 0 in {os.path.basename(csv_path)}. Check training.")
        return None

    results = {
        'epoch': best_row.get('epoch', 'N/A'),
        'summary': {
            'mIoU': best_row.get('iou_score', 'N/A'),
            'mAcc': 'N/A',
            'aAcc': best_row.get('accuracy', 'N/A')
        },
        'class_wise': []
    }

    # Number of classes = highest index among the tp_class_<i> columns, plus one.
    # DictReader stores surplus cells under a None key, hence the isinstance check.
    class_indices = []
    for key in best_row:
        if isinstance(key, str) and key.startswith('tp_class_'):
            suffix = key.split('_')[-1]
            if suffix.isdigit():
                class_indices.append(int(suffix))
    num_classes = max(class_indices) + 1 if class_indices else 0

    per_class_accuracies = []
    for i in range(num_classes):
        try:
            tp = float(best_row.get(f'tp_class_{i}', 0))
            fp = float(best_row.get(f'fp_class_{i}', 0))
            fn = float(best_row.get(f'fn_class_{i}', 0))
            class_iou = tp / (tp + fp + fn + 1e-6)  # epsilon avoids division by zero
            class_acc_str = best_row.get(f'cpa_class_{i}', '0')
            per_class_accuracies.append(float(class_acc_str))
            # Class 0 is labelled as background ('背景'); other classes use their index.
            class_name = '背景' if str(i) == '0' else str(i)
            results['class_wise'].append({
                'Class': class_name,
                'IoU': f"{class_iou:.4f}",
                'Acc': class_acc_str
            })
        except (ValueError, TypeError) as e:
            logging.warning(f"Could not process metrics for class {i}. Error: {e}")
            continue

    if per_class_accuracies:
        results['summary']['mAcc'] = f"{np.mean(per_class_accuracies):.4f}"

    if not results['class_wise']:
        logging.warning(f"❌ Could not parse any per-class metrics from {os.path.basename(csv_path)}.")
        return None

    logging.info(f"✅ Successfully parsed metrics for epoch '{results['epoch']}' from {os.path.basename(csv_path)}.")
    return results
# --- Main Function ---
def _format_percent(value) -> str:
    """Turn a 0-1 fraction (string or number) into a 2-decimal percentage string; 'N/A' if not numeric."""
    try:
        return f"{round(float(value), 4)*100:.2f}"
    except (TypeError, ValueError):
        return 'N/A'


def main(args):
    """
    Main script entry point to orchestrate the analysis workflow.

    Interactively asks for a dataset folder (``*_outputs-SegModel`` under
    ``args.input_dir``) and one or all algorithm sub-folders. It parses each
    ``training_metrics.csv`` and writes one wide CSV row per algorithm to
    ``<args.output_dir>/<dataset folder>/<dataset>_metrics_summary_wide.csv``.

    Args:
        args: Parsed CLI namespace with ``input_dir`` and ``output_dir``.
    """
    input_root = args.input_dir
    output_root = args.output_dir
    if not os.path.isdir(input_root):
        logging.error(f"Input directory does not exist: {input_root}")
        return

    # --- Interactive Menu ---
    all_dataset_dirs = sorted(
        d for d in glob.glob(os.path.join(input_root, '*_outputs-SegModel')) if os.path.isdir(d)
    )
    if not all_dataset_dirs:
        logging.error(f"No valid dataset folders ending in '_outputs-SegModel' found in {input_root}.")
        return

    dataset_map = {str(i + 1): path for i, path in enumerate(all_dataset_dirs)}
    print("\n" + "="*50)
    print("--- Step 1: Please select a dataset to process ---")
    for key, path in dataset_map.items():
        print(f"{key}: {os.path.basename(path)}")
    print("="*50)
    choice1 = input("Enter the dataset number and press Enter: ").strip()

    if choice1 not in dataset_map:
        logging.error("Invalid dataset selection. The program will exit.")
        return
    selected_dataset_dir = dataset_map[choice1]
    logging.info(f"You have selected the dataset: [{os.path.basename(selected_dataset_dir)}]")

    alg_dirs = sorted(
        d for d in glob.glob(os.path.join(selected_dataset_dir, '*')) if os.path.isdir(d)
    )
    if not alg_dirs:
        logging.warning(f"No algorithm subfolders found in {os.path.basename(selected_dataset_dir)}.")
        return

    alg_map = {str(i + 1): path for i, path in enumerate(alg_dirs)}
    print("\n" + "="*50)
    print("--- Step 2: Please select the algorithm(s) to process ---")
    print("0: Process [ALL] algorithms under the current dataset")
    for key, path in alg_map.items():
        print(f"{key}: {os.path.basename(path)}")
    print("="*50)
    choice2 = input("Enter the algorithm number (or '0' for all) and press Enter: ").strip()

    if choice2 == '0':
        model_dirs = alg_dirs
        logging.info(f"You have chosen to batch process all {len(model_dirs)} algorithms.")
    elif choice2 in alg_map:
        model_dirs = [alg_map[choice2]]
        logging.info(f"You have chosen to process a single algorithm: {os.path.basename(model_dirs[0])}")
    else:
        logging.error("Invalid algorithm selection. The program will exit.")
        return

    # --- Start processing the selected algorithms ---
    csv_rows = []
    for model_dir in model_dirs:
        model_name_full = os.path.basename(model_dir)
        logging.info(f"\n--- Processing algorithm: {model_name_full} ---")

        csv_file_path = find_smp_csv_file(model_dir)
        if not csv_file_path:
            logging.warning(f"Skipping {model_name_full} as no 'training_metrics.csv' was found.")
            continue

        metrics = parse_smp_metrics(csv_file_path)
        if not metrics:
            logging.warning(f"❌❌❌ Skipping {model_name_full} due to a failure in parsing metrics. ❌❌❌")
            continue

        summary = metrics['summary']
        # Algorithm name = folder name up to the first '_'.
        short_model_name = model_name_full.split('_')[0]
        # Summary values may be the placeholder 'N/A' (e.g. no 'accuracy'
        # column); _format_percent keeps that placeholder instead of raising.
        row_data = {
            'Algorithm': short_model_name,
            'Epoch': metrics['epoch'],
            'mIoU': _format_percent(summary['mIoU']),
            'mAcc': _format_percent(summary['mAcc']),
            'aAcc': _format_percent(summary['aAcc'])
        }
        for class_data in metrics['class_wise']:
            class_name = class_data['Class']
            row_data[f'{class_name}_IoU'] = _format_percent(class_data['IoU'])
            row_data[f'{class_name}_Acc'] = _format_percent(class_data['Acc'])
        csv_rows.append(row_data)

    # --- Write results to the final CSV file ---
    if not csv_rows:
        logging.info("No data was successfully collected. No CSV file will be generated.")
        return

    # Per-class columns vary between rows, so the header is built from the data.
    base_fieldnames = ['Algorithm', 'Epoch', 'mIoU', 'mAcc', 'aAcc']
    extra_fieldnames = {key for row in csv_rows for key in row if key not in base_fieldnames}

    # Natural ordering: '1', '2', '10' rather than '1', '10', '2'.
    def natural_sort_key(key_name):
        parts = key_name.split('_')
        class_part = parts[0]
        metric_part = parts[1] if len(parts) > 1 else ''
        if class_part == '背景':
            return (-1, metric_part)  # background columns go first
        try:
            return (int(class_part), metric_part)
        except ValueError:
            return (float('inf'), key_name)  # non-numeric class names go last

    final_fieldnames = base_fieldnames + sorted(extra_fieldnames, key=natural_sort_key)

    # Build the output path.
    output_dataset_folder = os.path.basename(selected_dataset_dir)
    final_output_dir = os.path.join(output_root, output_dataset_folder)
    os.makedirs(final_output_dir, exist_ok=True)
    dataset_name = output_dataset_folder.replace('_outputs-SegModel', '')
    output_csv_path = os.path.join(final_output_dir, f'{dataset_name}_metrics_summary_wide.csv')

    try:
        # The file is written with the 'utf-8-sig' encoding (UTF-8 with a BOM).
        with open(output_csv_path, 'w', newline='', encoding='utf-8-sig') as csvfile:
            writer = csv.DictWriter(csvfile, fieldnames=final_fieldnames)
            writer.writeheader()
            writer.writerows(csv_rows)
        logging.info(f"\n=== All processing is complete! Results have been saved to: {output_csv_path} ===")
    except IOError as e:
        logging.error(f"Failed to write to CSV file: {output_csv_path}. Error: {e}")
# Command-line entry point: declare the input/output roots, parse them, and run main().
if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description="SMP (Segmentation Models Pytorch) Final Metrics Extraction Script",
    )
    # Where the training outputs are read from.
    input_dir_options = dict(
        type=str,
        default='../Hardisk',
        help="The root directory containing dataset output folders (e.g., '..._outputs-SegModel').",
    )
    # Where the summary CSV files are written.
    output_dir_options = dict(
        type=str,
        default='../BestMode_Predict_Results_DataSet_Public',
        help="The root directory where analysis results will be stored.",
    )
    parser.add_argument('--input_dir', **input_dir_options)
    parser.add_argument('--output_dir', **output_dir_options)
    args = parser.parse_args()
    main(args)

View File

@@ -0,0 +1,120 @@
#!/bin/bash
# +----------------------------------------------------------------------------+
# | 脚本: 将 best_model.pth 从 Hardisk 复制到 BestMode_Predict_Results_DataSet_Public |
# | 说明: 此版本优化了复制逻辑,仅在文件不存在或内容不一致时才复制。 |
# | 同时包含一个可视化的进度条。 (已修正语法错误) |
# | 用法: ./copy_models.sh |
# +----------------------------------------------------------------------------+
# Base paths of the source and destination trees.
# realpath turns them into absolute paths so later prefix-stripping does not
# depend on how the relative path was written.
SOURCE_BASE_DIR=$(realpath "../Hardisk")
DEST_BASE_DIR=$(realpath "../BestMode_Predict_Results_DataSet_Public")
# Abort if the source directory does not exist.
if [ ! -d "$SOURCE_BASE_DIR" ]; then
    echo "错误: 源目录 '$SOURCE_BASE_DIR' 不存在。"
    exit 1
fi
# Make sure the destination base directory exists.
mkdir -p "$DEST_BASE_DIR"
echo "正在准备操作..."
# --- Step 1: find all files once and store them in an array ---
# find -print0 combined with readarray -d '' handles file names that contain
# special characters; find's error output is discarded.
echo "正在查找所有 'best_model.pth' 文件..."
readarray -d '' files_to_process < <(find "$SOURCE_BASE_DIR" -path "*-SegModel/*" -name "best_model.pth" -type f -print0 2>/dev/null)
TOTAL_FILES=${#files_to_process[@]}
# Nothing to do: exit successfully.
if [ "$TOTAL_FILES" -eq 0 ]; then
    echo "在源目录中没有找到任何 'best_model.pth' 文件。脚本退出。"
    exit 0
fi
echo "总共找到 $TOTAL_FILES 个模型文件需要处理。"
echo "--------------------------------------------------"
# --- Progress bar ---
# Usage: print_progress CURRENT TOTAL
#   CURRENT  number of files handled so far
#   TOTAL    total number of files
# Redraws a single-line progress bar on stdout.
print_progress() {
    local current=$1
    local total=$2
    # Declare every working variable (including the loop counter) local so the
    # function does not clobber globals of the same name.
    local term_width bar_width percent filled_len bar i

    # Fall back to 80 columns when tput cannot determine the terminal width.
    term_width=$(tput cols 2>/dev/null) || term_width=80
    # Leave room for the text printed around the bar.
    bar_width=$((term_width - 30))
    # Keep a minimum bar width on narrow terminals.
    if [ "$bar_width" -lt 10 ]; then
        bar_width=10
    fi

    # Percentage done; guarded so a TOTAL of 0 cannot divide by zero.
    percent=0
    if [ "$total" -gt 0 ]; then
        percent=$((current * 100 / total))
    fi
    # Length of the filled part of the bar.
    filled_len=$((bar_width * percent / 100))

    # Build the bar: '=' for the filled part, spaces for the rest.
    bar=""
    for ((i=0; i<filled_len; i++)); do bar+="="; done
    for ((i=filled_len; i<bar_width; i++)); do bar+=" "; done

    # \r returns the cursor to the start of the line so the bar refreshes in place.
    printf "\r处理进度: %3d%% [%s] %d/%d" "$percent" "$bar" "$current" "$total"
}
# --- Step 2: process the files and show the progress bar ---
CURRENT_FILE=0
COPIED_COUNT=0
SKIPPED_COUNT=0
FAILED_COUNT=0
# Walk the list of files found earlier.
for source_file_path in "${files_to_process[@]}"; do
    # Safety check: skip any empty array element.
    if [[ -z "$source_file_path" ]]; then
        continue
    fi
    CURRENT_FILE=$((CURRENT_FILE + 1))

    # Strip the source base directory to get the relative path, then append it
    # to the destination base directory.
    relative_path="${source_file_path#"$SOURCE_BASE_DIR"}"
    dest_file_path="${DEST_BASE_DIR}${relative_path}"
    dest_dir_path=$(dirname "$dest_file_path")

    # Copy only when the destination is missing or its content differs
    # (cmp -s returns 0 when the files are identical).
    if [ ! -f "$dest_file_path" ] || ! cmp -s "$source_file_path" "$dest_file_path"; then
        # Count the file as copied only if both mkdir and cp succeeded.
        if mkdir -p "$dest_dir_path" && cp "$source_file_path" "$dest_file_path"; then
            COPIED_COUNT=$((COPIED_COUNT + 1))
        else
            FAILED_COUNT=$((FAILED_COUNT + 1))
            # Start on a fresh line so the message does not overwrite the progress bar.
            printf '\n错误: 复制失败: %s\n' "$source_file_path" >&2
        fi
    else
        # Destination exists with identical content: skip.
        SKIPPED_COUNT=$((SKIPPED_COUNT + 1))
    fi

    # Refresh the progress bar.
    print_progress $CURRENT_FILE $TOTAL_FILES
done
# Move the cursor to the next line after the progress bar.
echo ""
echo "--------------------------------------------------"
echo "所有操作已完成!"
echo " - 总共处理文件: $TOTAL_FILES"
echo " - 复制或更新的文件: $COPIED_COUNT"
echo " - 已存在且内容一致而跳过的文件: $SKIPPED_COUNT"
# Report failures and signal them through the exit status.
if [ "$FAILED_COUNT" -gt 0 ]; then
    echo " - 复制失败的文件: $FAILED_COUNT"
    exit 1
fi

View File

@@ -0,0 +1,189 @@
# benchmark_smp.py
import torch
import time
import argparse
import numpy as np
from pathlib import Path
from mmengine.model.utils import revert_sync_batchnorm
# 假设 train.py 在同一目录下,以便导入其函数
# 如果不在,请确保路径正确
from train import initialize_components
import config
# TODO 使用方法 TODO
# 交互模式(推荐)
# python Tool_benchmark_smp.py -a Unet --shape 512 512
# 非交互模式(用于脚本)
# python Tool_benchmark_smp.py -a Unet -c "outputs/Unet_2025-10-13_19-30-00/best_model.pth" --shape 512 512
# --- Function copied from predict.py ---
def select_trained_model_path(base_dir: Path, model_architecture: str) -> Path:
    """
    List the training-run directories of an architecture and let the user pick one.

    Candidates are the sub-directories of ``base_dir`` whose name starts with
    ``"<model_architecture>_"``, shown in sorted order with 1-based numbers.

    Returns:
        The chosen directory. Note that ``None`` is returned (despite the
        annotation) when no candidate exists or the user presses Enter to cancel.
    """
    print(f"\nSearching for '{model_architecture}' model runs in '{base_dir}'...")

    name_prefix = model_architecture+'_'
    run_dirs = [entry for entry in base_dir.iterdir()
                if entry.is_dir() and entry.name.startswith(name_prefix)]
    run_dirs.sort()
    if not run_dirs:
        print(f"ERROR: No trained models found for architecture '{model_architecture}'. Please run train_smp.py first.")
        return None

    print("\nPlease select a trained model to benchmark:")
    for number, dir_path in enumerate(run_dirs, start=1):
        print(f" [{number}] {dir_path.name}")

    total = len(run_dirs)
    # Keep prompting until a valid number is entered or the prompt is cancelled.
    while True:
        try:
            answer = input(f"Enter selection (1-{total}) or press Enter to cancel: ")
            if answer == '':
                return None
            position = int(answer) - 1
            if position < 0 or position >= total:
                print("Invalid selection. Please try again.")
                continue
            chosen = run_dirs[position]
            print(f"Selected model: {chosen.name}")
            return chosen
        except (ValueError, IndexError):
            print("Invalid input. Please enter a number.")
def parse_args():
    """
    Parse the benchmark script's command-line options.

    Options: ``-a/--architecture`` (required, one of the keys of
    ``config.ALL_MODEL_CONFIGS``), ``-c/--checkpoint`` (optional path),
    ``--shape`` (one or more ints, default ``[512, 512]``),
    ``--log-interval`` (default 50) and ``--repeat-times`` (default 3).

    Returns:
        The parsed ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(description='Benchmark a segmentation model from SMP')
    known_architectures = list(config.ALL_MODEL_CONFIGS.keys())

    parser.add_argument('-a', '--architecture', type=str, required=True,
                        choices=known_architectures,
                        help="The model architecture to benchmark.")
    # Optional: without it main() falls back to an interactive selection.
    parser.add_argument('-c', '--checkpoint', type=str, required=False,
                        help='(Optional) Path to the checkpoint file. If omitted, an interactive selection will be shown.')
    parser.add_argument('--shape', type=int, nargs='+', default=[512, 512],
                        help='Input image size for benchmarking, e.g., --shape 512 512')
    parser.add_argument('--log-interval', type=int, default=50,
                        help='Interval of logging.')
    parser.add_argument('--repeat-times', type=int, default=3,
                        help='Number of times to repeat the benchmark for averaging.')

    return parser.parse_args()
def main():
    """
    Benchmark the inference speed (FPS) of one trained SMP model.

    Steps: resolve the checkpoint (``--checkpoint`` or an interactive pick
    under ``config.PREDICT_BEST_MODEL_DIR``), build the model with
    ``initialize_components``, load the weights, then time forward passes on
    a random ``(1, 3, H, W)`` input. Each of the ``--repeat-times`` runs does
    5 untimed warm-up passes and 200 timed passes.

    Raises:
        ValueError: if ``--shape`` has more than two integers.
        SystemExit: if no model is selected or the checkpoint file is missing.
    """
    # Imported here because this script's top-level imports do not include sys.
    import sys

    args = parse_args()

    # --- Resolve the checkpoint path ---
    if args.checkpoint:
        checkpoint_path = Path(args.checkpoint)
    else:
        # Runs are looked up in the directory specified in config.py.
        selected_run_dir = select_trained_model_path(Path(config.PREDICT_BEST_MODEL_DIR), args.architecture)
        if not selected_run_dir:
            print("No model selected. Exiting.")
            sys.exit()
        checkpoint_path = selected_run_dir / config.BEST_MODEL_SAVE_NAME
    if not checkpoint_path.exists():
        print(f"ERROR: Checkpoint file not found at '{checkpoint_path}'")
        sys.exit(1)  # non-zero status: this is an error, not a cancel

    # One integer means a square input; two mean (height, width).
    if len(args.shape) == 1:
        h, w = args.shape[0], args.shape[0]
    elif len(args.shape) == 2:
        h, w = args.shape[0], args.shape[1]
    else:
        raise ValueError('Invalid input shape. Use one or two integers.')

    print(f"\n--- Model Benchmarking ---")
    print(f"Architecture: {args.architecture}")
    print(f"Checkpoint: {checkpoint_path}")
    print(f"Input Shape: (1, 3, {h}, {w})")
    print(f"Device: {config.DEVICE}")
    print("-" * 28)

    # 1. Build the model.
    num_classes = len(config.CLASSES)
    model, _, _, _ = initialize_components(args.architecture, num_classes, config.SEG_MODE)

    # 2. Load the weights.
    print("Loading checkpoint...")
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    state_dict = torch.load(checkpoint_path, map_location=config.DEVICE)
    # Checkpoints saved from a DataParallel wrapper prefix every key with 'module.'.
    prefix = 'module.'
    if any(key.startswith(prefix) for key in state_dict.keys()):
        print("DataParallel model detected, removing 'module.' prefix.")
        # Strip only the leading prefix so keys containing 'module.' elsewhere stay intact.
        state_dict = {
            (key[len(prefix):] if key.startswith(prefix) else key): value
            for key, value in state_dict.items()
        }
    model.load_state_dict(state_dict)
    model.to(config.DEVICE)
    model.eval()

    # 3. Prepare the dummy input.
    dummy_input = torch.randn(1, 3, h, w).to(config.DEVICE)

    # 4. Run the benchmark.
    overall_fps_list = []
    for run_index in range(args.repeat_times):
        print(f"\n>>> Running benchmark iteration {run_index + 1}/{args.repeat_times}...")
        num_warmup = 5
        pure_inf_time = 0
        total_iters = 200
        with torch.no_grad():
            for i in range(total_iters + num_warmup):
                # Synchronise before and after so queued CUDA work is included in the timing.
                if torch.cuda.is_available():
                    torch.cuda.synchronize()
                start_time = time.perf_counter()
                _ = model(dummy_input)
                if torch.cuda.is_available():
                    torch.cuda.synchronize()
                elapsed = time.perf_counter() - start_time

                if i >= num_warmup:
                    pure_inf_time += elapsed
                    # A non-positive --log-interval disables progress logging
                    # instead of raising ZeroDivisionError on the modulo.
                    if args.log_interval > 0 and (i + 1) % args.log_interval == 0:
                        fps = (i + 1 - num_warmup) / pure_inf_time
                        print(f'Done image [{i + 1 - num_warmup:<4}/ {total_iters}], '
                              f'current FPS: {fps:.2f} img/s')

        overall_fps = total_iters / pure_inf_time
        overall_fps_list.append(overall_fps)
        print(f'Overall FPS for this run: {overall_fps:.2f} img/s')

    # 5. Summarise.
    print("\n--- Benchmark Summary ---")
    avg_fps = np.mean(overall_fps_list)
    var_fps = np.var(overall_fps_list)
    print(f'Average FPS of {args.repeat_times} runs: {avg_fps:.2f} img/s')
    print(f'FPS Variance of {args.repeat_times} runs: {var_fps:.4f}')
# Run the benchmark only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()

View File

@@ -0,0 +1,89 @@
# get_flops_smp.py
import torch
import argparse
from fvcore.nn import FlopCountAnalysis, flop_count_table
# 同样,需要能够导入 train_smp.py 中的函数
from train import initialize_components
import config
# TODO 使用方法 TODO
# python Tool_get_params_and_FLOPs.py -a Unet --shape 512 512
def parse_args():
    """
    Parse the FLOPs/params script's command-line options.

    Options: ``-a/--architecture`` (required, one of the keys of
    ``config.ALL_MODEL_CONFIGS``) and ``--shape`` (one or more ints,
    default ``[512, 512]``).

    Returns:
        The parsed ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(description='Get FLOPs and Params of an SMP model')
    known_architectures = list(config.ALL_MODEL_CONFIGS.keys())

    parser.add_argument('-a', '--architecture', type=str, required=True,
                        choices=known_architectures,
                        help="The model architecture to analyze.")
    parser.add_argument('--shape', type=int, nargs='+', default=[512, 512],
                        help='Input image size for analysis, e.g., --shape 512 512')

    return parser.parse_args()
def main():
    """
    Print the FLOPs table and parameter counts of one SMP architecture.

    The model is built with ``initialize_components`` (no checkpoint is
    loaded) and analysed with fvcore on a random ``(1, 3, H, W)`` input.

    Raises:
        ValueError: if ``--shape`` has more than two integers.
    """
    args = parse_args()

    # One integer means a square input; two mean (height, width).
    dims = args.shape
    if len(dims) not in (1, 2):
        raise ValueError('Invalid input shape. Use one or two integers.')
    h, w = dims[0], dims[-1]

    # 1. Build the model.
    print(f"Initializing model: '{args.architecture}'...")
    model, _, _, _ = initialize_components(args.architecture, len(config.CLASSES), config.SEG_MODE)
    # Unwrap DataParallel so the underlying module is analysed.
    if isinstance(model, torch.nn.DataParallel):
        model = model.module
    model.to(config.DEVICE)
    model.eval()

    # 2. Create the dummy input.
    dummy_input = torch.randn(1, 3, h, w).to(config.DEVICE)

    # 3. Count all parameters, and separately those that require gradients.
    total_params = 0
    trainable_params = 0
    for tensor in model.parameters():
        count = tensor.numel()
        total_params += count
        if tensor.requires_grad:
            trainable_params += count

    # 4. Set up the fvcore FLOPs analysis.
    print("Analyzing model FLOPs... (This may take a moment)")
    flops = FlopCountAnalysis(model, dummy_input)

    # 5. Print the per-module table followed by the totals.
    split_line = '=' * 30
    print(f"\n{split_line}")
    print(f"Model: {args.architecture}")
    print(f"Input shape: (1, 3, {h}, {w})")
    print(f"{split_line}")
    print(flop_count_table(flops, max_depth=4))

    gflops = flops.total() / 1e9
    print(f"{split_line}")
    print(f"Total FLOPs: {gflops:.2f} GFLOPs")
    print(f"Total Parameters: {total_params / 1e6:.2f} M")
    print(f"Trainable Parameters: {trainable_params / 1e6:.2f} M")
    print(f"{split_line}")
    print('!!! Please be cautious: FLOPs computation may not be perfectly accurate for all custom ops.')
# Run the analysis only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()

View File

@@ -0,0 +1,153 @@
import torch
from pathlib import Path
# --- 1. Core directories ---
# Paths are managed with pathlib. They are relative, so they resolve against
# the current working directory of the process that imports this module.
HARDISK_DIR = Path('../Hardisk')  # archive root; train.py moves finished runs under it
DATA_SETS_DIR = Path('../DataSet_Public_outputs')  # where trained models are saved
PREDICT_ALL_BEST_MODELS_DIR = Path('../BestMode_Predict_Results_DataSet_Public')  # where the models used for prediction are stored

# Dataset profiles. Each profile defines three names:
#   DATA_DIR               - dataset root
#   OUTPUTS_DIR            - root of all training outputs (train.py uses OUTPUTS_DIR / <run name>)
#   PREDICT_BEST_MODEL_DIR - location of the best models used for prediction
# Keep exactly one profile uncommented.
# (Deriving the root from Path(__file__).parent.parent was abandoned because
# the path contains Chinese characters.)
# V1: 1_CholecSeg8k-13Type-1920x1080
# DATA_DIR = Path('../DataSet_Public/1_CholecSeg8k-13Type-1920x1080')
# OUTPUTS_DIR = DATA_SETS_DIR / "1_CholecSeg8k-13Type-1920x1080_outputs-SegModel"
# PREDICT_BEST_MODEL_DIR = PREDICT_ALL_BEST_MODELS_DIR / "1_CholecSeg8k-13Type-1920x1080_outputs-SegModel"
# V2: 2_AutoLaparo-10Type-1920x1080
# DATA_DIR = Path('../DataSet_Public/2_AutoLaparo-10Type-1920x1080')
# OUTPUTS_DIR = DATA_SETS_DIR / "2_AutoLaparo-10Type-1920x1080_outputs-SegModel"
# PREDICT_BEST_MODEL_DIR = PREDICT_ALL_BEST_MODELS_DIR / "2_AutoLaparo-10Type-1920x1080_outputs-SegModel"
# V3: 3_1_Endovis_2017-8Type-512x512
# DATA_DIR = Path('../DataSet_Public/3_1_Endovis_2017-8Type-512x512')
# OUTPUTS_DIR = DATA_SETS_DIR / "3_1_Endovis_2017-8Type-512x512_outputs-SegModel"
# PREDICT_BEST_MODEL_DIR = PREDICT_ALL_BEST_MODELS_DIR / "3_1_Endovis_2017-8Type-512x512_outputs-SegModel"
# V4: 3_2_Endovis_2018-8Type-512x512
# DATA_DIR = Path('../DataSet_Public/3_2_Endovis_2018-8Type-512x512')
# OUTPUTS_DIR = DATA_SETS_DIR / "3_2_Endovis_2018-8Type-512x512_outputs-SegModel"
# PREDICT_BEST_MODEL_DIR = PREDICT_ALL_BEST_MODELS_DIR / "3_2_Endovis_2018-8Type-512x512_outputs-SegModel"
# V5: 4_Dresden-11Type-512x512
# DATA_DIR = Path('../DataSet_Public/4_Dresden-11Type-512x512')
# OUTPUTS_DIR = DATA_SETS_DIR / "4_Dresden-11Type-512x512_outputs-SegModel"
# PREDICT_BEST_MODEL_DIR = PREDICT_ALL_BEST_MODELS_DIR / "4_Dresden-11Type-512x512_outputs-SegModel"
# Test_V1: 5_Predict_Video (active profile)
DATA_DIR = Path('../DataSet_Public/5_Predict_Video/LC_Video_1')
# NOTE(review): train.py evaluates Path(config.OUTPUTS_DIR), which raises
# TypeError for None, and train.sh would print the literal string "None".
# This profile therefore looks prediction-only - confirm before training with it.
OUTPUTS_DIR = None
PREDICT_BEST_MODEL_DIR = PREDICT_ALL_BEST_MODELS_DIR / "LC_Video_1_outputs-SegModel"

# --- 2. Training & validation data paths (used by train.py) ---
TRAIN_IMAGE_DIR = DATA_DIR / "images" / "train"
TRAIN_MASK_DIR = DATA_DIR / "labels_GT" / "train"
VAL_IMAGE_DIR = DATA_DIR / "images" / "val"  # TODO "val_images"
VAL_MASK_DIR = DATA_DIR / "labels_GT" / "val"  # TODO "val_masks"

# --- 3. Prediction data paths (for predict.py) ---
# These currently point at the same "val" folders as the validation paths.
TEST_IMAGE_DIR = DATA_DIR / "images" / "val"  # test images  # TODO "test_images"
TEST_MASK_DIR = DATA_DIR / "labels_GT" / "val"  # test masks, used for evaluation  # TODO "test_masks"

# --- 4. Output file and folder names ---
RAW_MASK_FOLDER = "predicted_raw_masks"  # predicted single-channel raw masks
ANALYSIS_RESULTS_FOLDER = "prediction_analysis"  # comparison images, curves and metric CSVs
# Files written during training
BEST_MODEL_SAVE_NAME = "best_model.pth"
METRICS_CSV_NAME = "training_metrics.csv"

# --- 5. Model & data parameters ---
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# --- 5.1. Model selection and per-model parameters ---
# The model to train is NOT chosen here; train / predict select it at run time.
# The mapping below is a library of constructor settings for every supported model.
# Reference: https://smp.readthedocs.io/en/latest/models.html
# Constructor keyword arguments for each supported architecture. The key is
# the model name selected on the command line; the value holds the settings
# passed to that model's constructor.
ALL_MODEL_CONFIGS = {
    'Unet': dict(
        encoder_name='resnet34',
        encoder_weights='imagenet',
        decoder_channels=(256, 128, 64, 32, 16),
        decoder_attention_type=None,  # optional: 'scse'
    ),
    'UnetPlusPlus': dict(
        encoder_name='resnet34',
        encoder_weights='imagenet',
        decoder_channels=(256, 128, 64, 32, 16),
        decoder_attention_type=None,  # optional: 'scse'
    ),
    'FPN': dict(encoder_name='resnet34', encoder_weights='imagenet'),
    'PSPNet': dict(encoder_name='resnet34', encoder_weights='imagenet'),
    'DeepLabV3': dict(encoder_name='resnet34', encoder_weights='imagenet'),
    'DeepLabV3Plus': dict(encoder_name='resnet34', encoder_weights='imagenet'),
    'Linknet': dict(encoder_name='resnet34', encoder_weights='imagenet'),
    'MAnet': dict(encoder_name='resnet34', encoder_weights='imagenet'),
    'PAN': dict(encoder_name='resnet34', encoder_weights='imagenet'),
    'UPerNet': dict(encoder_name='resnet34', encoder_weights='imagenet'),
    'Segformer': dict(encoder_name='resnet34', encoder_weights='imagenet'),
    # The only entry with a ViT encoder; code elsewhere keys off 'vit' in the name.
    'DPT': dict(
        encoder_name='tu-vit_base_patch16_224.augreg_in21k',
        encoder_weights='imagenet',
    ),
}
SEG_MODE = "multiclass"  # segmentation mode: 'multiclass' or 'multilabel'

# TODO: classes to exclude when computing evaluation metrics (IoU, F1-score, ...).
EVALUATION_CLASSES_TO_IGNORE = []  # an empty list means every class is evaluated
# EVALUATION_CLASSES_TO_IGNORE = ['background']  # e.g. exclude the background class from the overall metrics
IGNORE_INDEX = -100  # mask value for background pixels when background is excluded; the loss ignores it

# Class definitions. CLASSES holds the class names; CLASS_RGB_VALUES holds the
# grayscale pixel value that marks each class in the mask files (position i in
# both lists describes class i). Adjust both for your dataset.
# NOTE(review): the active pair below is labelled for CholecSeg8k; confirm it
# matches whichever DATA_DIR profile is currently active.
# V1: 1_CholecSeg8k-13Type-1920x1080
CLASSES = ['background', '1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '11', '12']
CLASS_RGB_VALUES = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]
# V2: 2_AutoLaparo-10Type-1920x1080
# CLASSES = ['background', '1', '2', '3', '4', '5', '6', '7', '8', '9']
# CLASS_RGB_VALUES = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
# V3: 3_1_Endovis_2017-8Type-512x512
# CLASSES = ['background', '1', '2', '3', '4', '5', '6', '7']
# CLASS_RGB_VALUES = [0, 1, 2, 3, 4, 5, 6, 7]
# V4: 3_2_Endovis_2018-8Type-512x512
# CLASSES = ['background', '1', '2', '3', '4', '5', '6', '7']
# CLASS_RGB_VALUES = [0, 1, 2, 3, 4, 5, 6, 7]
# V5: 4_Dresden-11Type-512x512
# CLASSES = ['background', '1', '2', '3', '4', '5', '6', '7', '8', '9', '10']
# CLASS_RGB_VALUES = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]

# Image size fed to the model (train.py may override it for encoders that
# require a fixed input size).
IMAGE_HEIGHT = 256  # previously 512
IMAGE_WIDTH = 256  # previously 512

# --- 6. Training hyperparameters ---
BATCH_SIZE = 16
NUM_WORKERS = 8
EPOCHS = 300
LEARNING_RATE = 1e-4  # was 1e-3; lowered to 1e-4 for MAnet
WEIGHT_DECAY = 1e-4  # L2 weight decay
PIN_MEMORY = True  # load batches into pinned memory to speed up transfer to the GPU
FBETA_BETA = 1.0  # beta of the F-beta score; beta=1 is the F1-score
EARLY_STOPPING_PATIENCE = 100  # epochs without validation-loss improvement before stopping
CKPT_SAVE_INTERVAL = 10  # keep a checkpoint every this many epochs

View File

@@ -0,0 +1,99 @@
import os
import cv2
import numpy as np
import torch
from torch.utils.data import Dataset
class SegmentationDataset(Dataset):
    """Image/mask dataset for semantic segmentation.

    Images and masks are paired by file name: the mask for an image is the
    file with the same name inside ``mask_dir``. Masks are read as
    single-channel grayscale images and converted according to ``seg_mode``:

    * ``'multiclass'`` - a class-index mask of shape (H, W), returned as long.
    * ``'multilabel'`` - a one-hot mask, returned as float in (C, H, W) order.
    """

    def __init__(self, image_dir, mask_dir, classes, class_rgb_values, seg_mode, augmentation=None):
        """
        Args:
            image_dir: Directory containing the images (str or path-like).
            mask_dir: Directory containing the masks (str or path-like).
            classes (list): Class names.
            class_rgb_values (list): Grayscale pixel value that marks each
                class in the mask files; position i describes class i.
            seg_mode (str): 'multiclass' or 'multilabel'.
            augmentation (callable, optional): Called as
                ``augmentation(image=..., mask=...)`` and expected to return a
                mapping with 'image' and 'mask' entries.

        Raises:
            ValueError: If ``seg_mode`` is not one of the two supported modes.
        """
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        # Sorted so that indices are stable between runs.
        self.image_filenames = sorted(os.listdir(image_dir))
        self.augmentation = augmentation
        self.classes = classes
        self.class_rgb_values = class_rgb_values

        self.seg_mode = seg_mode
        if self.seg_mode not in ['multiclass', 'multilabel']:
            raise ValueError(f"seg_mode must be 'multiclass' or 'multilabel', but got {self.seg_mode}")
        print(f"Found {len(self.image_filenames)} images in {image_dir}. Dataset mode: '{self.seg_mode}'")

    def __len__(self):
        """Return the number of files found in ``image_dir``."""
        return len(self.image_filenames)

    def __getitem__(self, idx):
        """Load, convert and (optionally) augment the sample at ``idx``.

        Returns:
            tuple: ``(image, mask)``. ``mask`` is always a torch tensor. The
            image is whatever the augmentation returns, or the raw RGB array
            when no augmentation is configured.

        Raises:
            FileNotFoundError: If the image or its mask cannot be read.
        """
        filename = self.image_filenames[idx]

        img_path = os.path.join(self.image_dir, filename)
        image = cv2.imread(img_path)
        # cv2.imread returns None instead of raising; fail with a clear message
        # rather than letting cvtColor raise an opaque assertion error.
        if image is None:
            raise FileNotFoundError(f"Image file not found or could not be read: {img_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        mask_path = os.path.join(self.mask_dir, filename)
        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        if mask is None:
            raise FileNotFoundError(f"Mask file not found or could not be read: {mask_path}")

        if self.seg_mode == 'multilabel':
            processed_mask = self._create_one_hot_mask(mask)
        else:  # 'multiclass'
            processed_mask = self._create_class_index_mask(mask)

        if self.augmentation:
            # The same call handles both (H, W, C) and (H, W) masks.
            sample = self.augmentation(image=image, mask=processed_mask)
            image = sample['image']
            mask = sample['mask']
        else:
            mask = processed_mask

        return image, self._to_mask_tensor(mask)

    def _to_mask_tensor(self, mask):
        """Convert a processed mask into the tensor layout the loss expects.

        Accepts either a NumPy array (no augmentation, or an augmentation that
        does not convert to tensors) or a torch tensor. Previously a NumPy
        mask reached ``.long()`` / ``.permute()`` and raised AttributeError.
        """
        if isinstance(mask, np.ndarray):
            # ascontiguousarray guards against negative strides left by flips.
            mask = torch.from_numpy(np.ascontiguousarray(mask))
        if self.seg_mode == 'multilabel':
            # (H, W, C) -> (C, H, W), float targets.
            return mask.permute(2, 0, 1).float()
        # 'multiclass': class indices, no layout change.
        return mask.long()

    def _create_one_hot_mask(self, mask):
        """Turn a grayscale mask (H, W) into a one-hot uint8 array (H, W, C).

        Channel i is 1 where the mask equals ``class_rgb_values[i]``. Pixels
        that match no listed value are 0 in every channel.
        """
        values = np.asarray(self.class_rgb_values)
        return (mask[:, :, None] == values).astype(np.uint8)

    def _create_class_index_mask(self, mask):
        """Turn a grayscale mask (H, W) into a uint8 class-index array (H, W).

        Pixels equal to ``class_rgb_values[i]`` become i. Pixels that match no
        listed value keep the initial index 0.
        """
        class_index_map = np.zeros(mask.shape, dtype=np.uint8)
        for i, value in enumerate(self.class_rgb_values):
            class_index_map[mask == value] = i
        return class_index_map

View File

@@ -0,0 +1,86 @@
# In utils.py or a new losses.py
import config
import torch.nn as nn
import segmentation_models_pytorch as smp
class UNetPlusPlusLoss(nn.Module):
    """Weighted sum of BCE-with-logits and Dice loss.

    Modelled on the loss used by the official UNet++ implementation and meant
    for the raw (logit) outputs of segmentation-models-pytorch models.

    Args:
        mode (str): Mode passed to ``smp.losses.DiceLoss``, e.g. 'multilabel'
            or 'multiclass'. Defaults to ``config.SEG_MODE``.
        bce_weight (float): Weight of the BCE term in the total loss.
        dice_weight (float): Weight of the Dice term in the total loss.
    """

    def __init__(self, mode=config.SEG_MODE, bce_weight=0.5, dice_weight=0.5):
        super().__init__()
        # BCEWithLogitsLoss is the numerically stable form of sigmoid + BCE.
        self.bce_loss = nn.BCEWithLogitsLoss()
        self.dice_loss = smp.losses.DiceLoss(mode=mode)
        self.bce_weight, self.dice_weight = bce_weight, dice_weight

    def forward(self, y_pred, y_true):
        """Return ``bce_weight * BCE + dice_weight * Dice``.

        Args:
            y_pred: Model output (logits).
            y_true: Ground-truth mask.
        """
        bce_term = self.bce_weight * self.bce_loss(y_pred, y_true)
        dice_term = self.dice_weight * self.dice_loss(y_pred, y_true)
        return bce_term + dice_term
class MultiClassLoss(nn.Module):
    """Weighted sum of CrossEntropyLoss and DiceLoss for multi-class masks.

    Args:
        mode (str): DiceLoss mode; should be 'multiclass'.
        ce_weight (float): Weight of the cross-entropy term.
        dice_weight (float): Weight of the Dice term.
        ignore_index (int): Target value excluded from BOTH terms. Pixels
            carrying this value contribute neither to the cross-entropy nor
            to the Dice score.
    """

    def __init__(self, mode='multiclass', ce_weight=0.5, dice_weight=0.5, ignore_index=-100):
        super().__init__()
        # CrossEntropyLoss combines LogSoftmax and NLLLoss.
        self.ce_loss = nn.CrossEntropyLoss(ignore_index=ignore_index)
        # Pass ignore_index to the Dice term as well. Previously only the
        # cross-entropy term honoured it, so ignored pixels would still have
        # been one-hot encoded (and counted) by the Dice loss.
        self.dice_loss = smp.losses.DiceLoss(mode=mode, ignore_index=ignore_index)
        self.ce_weight = ce_weight
        self.dice_weight = dice_weight

    def forward(self, y_pred, y_true):
        """Return ``ce_weight * CE + dice_weight * Dice``.

        Args:
            y_pred: Model predictions (logits), shape (N, C, H, W).
            y_true: Ground-truth class indices, shape (N, H, W); cast to long
                here because CrossEntropyLoss requires integer targets.
        """
        y_true = y_true.long()
        ce = self.ce_loss(y_pred, y_true)
        dice = self.dice_loss(y_pred, y_true)
        return self.ce_weight * ce + self.dice_weight * dice

View File

@@ -0,0 +1,109 @@
#!/bin/bash
# Launch 1_predict.py for every architecture, one background job per model.
# Models are split into four groups; each group is pinned to its own GPU
# through CUDA_VISIBLE_DEVICES. Job launches are staggered by LAUNCH_INTERVAL
# seconds, and the script waits for all jobs before exiting.

# --- 1. Conda environment ---
CONDA_BASE_PATH="/home/wkmgc/miniconda3" # <--- change this to your own conda install
CONDA_ENV_NAME="${SEG_CONDA_ENV:-seg_smp}" # override with: SEG_CONDA_ENV=SMP bash predict.sh

# Initialise conda for this shell and activate the environment.
if [ -f "${CONDA_BASE_PATH}/etc/profile.d/conda.sh" ]; then
    source "${CONDA_BASE_PATH}/etc/profile.d/conda.sh"
    if ! conda activate "${CONDA_ENV_NAME}"; then
        echo "错误: 激活 Conda 环境 '${CONDA_ENV_NAME}' 失败!"
        exit 1
    fi
    echo "Conda 环境 '${CONDA_ENV_NAME}' 已成功激活。"
else
    echo "错误: 找不到 conda.sh 脚本。请检查您的 CONDA_BASE_PATH 设置是否正确。"
    exit 1
fi

# --- 2. Model and GPU configuration ---
GPUS_GROUP_1="2"
GPUS_GROUP_2="3"
GPUS_GROUP_3="4"
GPUS_GROUP_4="5"
# Keep these lists in sync with train.sh so that a trained model exists for
# every architecture listed here.
GROUP_1_ARCHS=("PSPNet" "Unet" "UnetPlusPlus" )
GROUP_2_ARCHS=("Linknet" "MAnet" "DeepLabV3" )
GROUP_3_ARCHS=("UPerNet" "Segformer" "DPT" )
GROUP_4_ARCHS=("FPN" "DeepLabV3Plus" "PAN")

# Seconds to wait after launching each job. Every loop already slept 60 s;
# the progress message now reports the real value instead of "50".
LAUNCH_INTERVAL=60

# Read PREDICT_BEST_MODEL_DIR from config.py.
PREDICT_BEST_MODEL_DIR=$(python -c "from config import PREDICT_BEST_MODEL_DIR; print(PREDICT_BEST_MODEL_DIR)")
# An empty value means the import failed; "None" means the setting is unset.
if [ -z "$PREDICT_BEST_MODEL_DIR" ] || [ "$PREDICT_BEST_MODEL_DIR" = "None" ]; then
    echo "PREDICT_BEST_MODEL_DIR: $PREDICT_BEST_MODEL_DIR"
    echo "Error: Could not read PREDICT_BEST_MODEL_DIR from config.py. Exiting."
    exit 1
fi

# Timestamped log directory under the prediction output root.
LOG_DIR_NAME="predict_logs_parallel_$(date +%Y-%m-%d_%H-%M-%S)"
LOG_DIR="$PREDICT_BEST_MODEL_DIR/$LOG_DIR_NAME"
mkdir -p "${LOG_DIR}"
echo "所有模型的预测日志将保存在 ./${LOG_DIR}/ 目录中。"
echo "----------------------------------------------------"

# --- 3. Launch all prediction jobs ---
# launch_group LABEL GPUS ARCH...
# Starts one background prediction job per ARCH on the given GPUs and sleeps
# LAUNCH_INTERVAL seconds after each launch. Jobs are NOT run sequentially;
# they keep running in the background while later jobs are started.
launch_group() {
    local label="$1"
    local gpus="$2"
    shift 2
    local arch

    echo ">>> 准备启动第${label}组预测任务 (后台运行)..."
    for arch in "$@"; do
        echo " -> 正在后台启动模型: ${arch} on GPUs: ${gpus}"
        # 'echo "1" |' answers the interactive prompt of 1_predict.py,
        # selecting the first model it finds.
        echo "1" | CUDA_VISIBLE_DEVICES="${gpus}" python 1_predict.py -a "${arch}" > "${LOG_DIR}/${arch}.log" 2>&1 &
        echo " - 模型 ${arch} 预测已在后台启动。日志文件: ${LOG_DIR}/${arch}.log"
        echo " - 等待 ${LAUNCH_INTERVAL} 秒..."
        sleep "${LAUNCH_INTERVAL}"
    done
    echo ">>> 第${label}组所有模型均已启动。"
    echo "----------------------------------------------------"
}

launch_group "一" "${GPUS_GROUP_1}" "${GROUP_1_ARCHS[@]}"
launch_group "二" "${GPUS_GROUP_2}" "${GROUP_2_ARCHS[@]}"
launch_group "三" "${GPUS_GROUP_3}" "${GROUP_3_ARCHS[@]}"
launch_group "四" "${GPUS_GROUP_4}" "${GROUP_4_ARCHS[@]}"

# --- 4. Wait for all background jobs ---
echo ""
echo "--- 所有模型均已在后台启动。现在等待所有预测任务完成... ---"
# 'wait' blocks until every background job started by this script has exited.
wait
echo "--- 所有后台预测任务已全部完成! ---"

# Deactivate the environment before exiting.
conda deactivate

View File

@@ -0,0 +1,10 @@
opencv-python
albumentations
matplotlib
numpy
pandas
scikit_learn
scipy
segmentation_models_pytorch==0.5.1.dev0
torch
tqdm

View File

@@ -0,0 +1,312 @@
import logging, argparse, shutil
from datetime import datetime
from pathlib import Path
from typing import Dict, Tuple
import gc
import numpy as np
import segmentation_models_pytorch as smp
import torch
import torch.optim as optim
from torch.optim import lr_scheduler # 学习率调度器
from torch.amp import GradScaler
from torch.utils.data import DataLoader
# 本地应用/库的导入
import config, os
import utils
from dataset import SegmentationDataset
from loss import MultiClassLoss, UNetPlusPlusLoss
# --- Logging setup ---
# Emit INFO-and-above records with a timestamp and level prefix.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s",
    handlers=[
        logging.StreamHandler()  # console output
    ]
)
# --- 辅助函数:移动训练结果文件夹从源输出目录移动到配置的硬盘目录 ---
def move_results_to_hardisk(project_folder: str):
    """Move one finished run from the outputs directory to the archive disk.

    The run folder is copied to
    ``config.HARDISK_DIR / <name of config.OUTPUTS_DIR> / project_folder`` and
    the original is removed only after the copy succeeded. Any failure is
    logged and the original folder is left in place.

    Args:
        project_folder (str): Name of the run folder to move (normally the
            model architecture plus a timestamp).
    """
    outputs_root = Path(config.OUTPUTS_DIR)
    src = outputs_root / project_folder
    dst = Path(config.HARDISK_DIR) / outputs_root.name / project_folder

    # Only the parent is created: copytree must create the final folder itself.
    dst.parent.mkdir(parents=True, exist_ok=True)
    logging.info(f"准备将结果从 {src} 移动到 {dst}...")

    try:
        # Step 1: copy the whole tree.
        shutil.copytree(src, dst)
        logging.info(f"成功复制结果到: {dst}")
        # Step 2: the copy succeeded, so the source can be removed.
        logging.info(f"正在删除原文件夹: {src}")
        shutil.rmtree(src)
        logging.info("成功删除原文件夹。")
        logging.info(f"任务完成，最终结果已保存至: {dst}")
    except Exception as e:
        logging.error(f"移动文件夹时发生错误: {e}")
        logging.warning(f"原始训练结果仍保留在: {src}")
# 1. 根据配置文件初始化模型、损失函数和优化器。
def initialize_components(model_architecture:str, num_classes: int, seg_mode: str) -> Tuple:
    """Build the model, loss function, optimizer and gradient scaler.

    Args:
        model_architecture: Key of ``config.ALL_MODEL_CONFIGS`` and name of
            the model class inside ``segmentation_models_pytorch``.
        num_classes: Number of output classes.
        seg_mode: 'multiclass' or 'multilabel'; selects the loss function.

    Returns:
        ``(model, loss_fn, optimizer, scaler)``.

    Raises:
        KeyError: If the architecture has no entry in ``ALL_MODEL_CONFIGS``.
        AttributeError: If the library has no class with that name.
        ValueError: If ``seg_mode`` is not supported.
    """
    # --- Model ---
    logging.info(f"正在初始化模型: '{model_architecture}'...")
    try:
        model_params = config.ALL_MODEL_CONFIGS[model_architecture]
        model_class = getattr(smp, model_architecture)
    except KeyError:
        logging.error(f"模型 '{model_architecture}' 的配置未在 config.py 的 ALL_MODEL_CONFIGS 中定义！")
        raise
    except AttributeError:
        logging.error(f"模型 '{model_architecture}' 在 segmentation_models_pytorch 库中不存在！")
        raise

    # Configured settings plus the two fixed arguments. The config entry
    # itself is left untouched.
    params = {**model_params, 'in_channels': 3, 'classes': num_classes}

    # ViT encoders are detected by name and get dynamic_img_size=True.
    encoder_name = model_params.get('encoder_name', '')
    if 'vit' in encoder_name.lower():
        params['dynamic_img_size'] = True
        logging.info(f"检测到 ViT 编码器 ('{encoder_name}')。自动设置 dynamic_img_size=True。")

    model = model_class(**params).to(config.DEVICE)

    # --- Loss function ---
    logging.info(f"'{seg_mode}' 模式设置损失函数。")
    loss_classes = {'multiclass': MultiClassLoss, 'multilabel': UNetPlusPlusLoss}
    if seg_mode not in loss_classes:
        raise ValueError(f"无效的 SEG_MODE: '{seg_mode}'。必须是 'multiclass''multilabel'")
    loss_fn = loss_classes[seg_mode](mode=seg_mode)

    # --- Optimizer ---
    optimizer = optim.Adam(
        model.parameters(),
        lr=config.LEARNING_RATE,
        weight_decay=config.WEIGHT_DECAY,
    )

    # TODO: learning-rate scheduler. T_max is the number of iterations in one
    # cosine cycle and is set to the total number of epochs.
    # NOTE(review): the scheduler is created but neither returned nor stepped
    # in this function; callers unpack exactly four values, so wiring it in
    # needs a coordinated change.
    scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=config.EPOCHS, eta_min=1e-6)

    # --- Gradient scaler ---
    scaler = GradScaler('cuda')

    return model, loss_fn, optimizer, scaler
def main(model_architecture: str) -> str:
    """Run the full training pipeline for one segmentation architecture.

    Steps:
        1. Create a timestamped run directory under ``config.OUTPUTS_DIR``.
        2. Build a temporary model to find out whether its encoder requires a
           fixed input size, and pick the target image size accordingly.
        3. Build the training/validation datasets and DataLoaders.
        4. Initialize model, loss, optimizer and gradient scaler.
        5. For every epoch: train, evaluate, append the metrics to the CSV,
           save the best model on improved validation loss, apply early
           stopping, save a rolling checkpoint and redraw the progress plots.

    Args:
        model_architecture: Key of ``config.ALL_MODEL_CONFIGS`` naming the
            model to train.

    Returns:
        str: Folder name of this training run (``<architecture>_<timestamp>``).
    """
    logging.info(f"使用设备: {config.DEVICE}")

    # --- 1.1. Run directory and output file paths ---
    timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    run_name = f"{model_architecture}_{timestamp}"
    output_dir = Path(config.OUTPUTS_DIR) / run_name
    metrics_csv_path = output_dir / config.METRICS_CSV_NAME
    best_model_path = output_dir / config.BEST_MODEL_SAVE_NAME
    utils.setup_directories(output_dir)

    # --- 1.2. Probe the model for its input-size requirement ---
    logging.info("正在探测模型以确定输入尺寸要求...")
    num_classes = len(config.CLASSES)
    # Temporary instance, used only to inspect encoder attributes.
    probe_model, _, _, _ = initialize_components(
        model_architecture, num_classes, config.SEG_MODE
    )
    target_height, target_width = config.IMAGE_HEIGHT, config.IMAGE_WIDTH
    try:
        # Does the encoder demand a fixed input size?
        if probe_model.encoder.is_fixed_input_size:
            required_size = probe_model.encoder.input_size
            # Elements 1 and 2 are read as height and width
            # (presumably a (C, H, W) triple - TODO confirm).
            target_height = required_size[1]
            target_width = required_size[2]
            logging.warning(
                f"模型 '{model_architecture}' 的编码器需要固定的输入尺寸 "
                f"({target_height}x{target_width})。"
                f"将忽略 config.py 中的尺寸设置。"
            )
        else:
            logging.info(f"模型 '{model_architecture}' 支持动态输入尺寸。将使用 config.py 中定义的尺寸 ({target_height}x{target_width})。")
    except AttributeError:
        # Encoders without these attributes fall back to the configured size.
        logging.info("无法自动检测模型尺寸要求。将使用 config.py 中定义的尺寸。")
    del probe_model  # drop the temporary model

    # --- 1.3. Data loading ---
    logging.info(f"正在设置数据增强和加载器，目标尺寸为 {target_height}x{target_width}...")
    # Use the probed (or configured) target size for both pipelines.
    train_transform = utils.get_training_augmentation(target_height, target_width)
    val_transform = utils.get_validation_augmentation(target_height, target_width)
    train_dataset = SegmentationDataset(
        image_dir=Path(config.TRAIN_IMAGE_DIR),
        mask_dir=Path(config.TRAIN_MASK_DIR),
        classes=config.CLASSES,
        class_rgb_values=config.CLASS_RGB_VALUES,
        augmentation=train_transform,
        seg_mode=config.SEG_MODE
    )
    val_dataset = SegmentationDataset(
        # Validation data directories from the config file.
        image_dir=Path(config.VAL_IMAGE_DIR),
        mask_dir=Path(config.VAL_MASK_DIR),
        classes=config.CLASSES,
        class_rgb_values=config.CLASS_RGB_VALUES,
        augmentation=val_transform,
        seg_mode=config.SEG_MODE
    )
    train_loader = DataLoader(train_dataset, batch_size=config.BATCH_SIZE, shuffle=True, pin_memory=config.PIN_MEMORY, num_workers=config.NUM_WORKERS)
    val_loader = DataLoader(val_dataset, batch_size=config.BATCH_SIZE, shuffle=False, pin_memory=config.PIN_MEMORY, num_workers=config.NUM_WORKERS)

    # --- 2. Model, loss function, optimizer, gradient scaler ---
    num_classes = len(config.CLASSES)
    model, loss_fn, optimizer, scaler = initialize_components(model_architecture, num_classes, config.SEG_MODE)

    # --- Multi-GPU training ---
    # The model is wrapped whenever at least one CUDA device is visible.
    if torch.cuda.device_count() >= 1:
        logging.info(f"正在使用 {torch.cuda.device_count()} 块 GPU 进行并行训练。")
        model = torch.nn.DataParallel(model)

    # --- 3. Training loop ---
    best_val_loss = float('inf')
    epochs_no_improve = 0
    for epoch in range(config.EPOCHS):
        logging.info(f"\n--- 第 {epoch+1}/{config.EPOCHS} 轮 ---")

        # --- 3.1. Train for one epoch ---
        train_loss = utils.train_fn(train_loader, model, optimizer, loss_fn, scaler, config.DEVICE, model_architecture)

        # --- 3.2. Evaluate on the validation set ---
        val_metrics_dict = utils.evaluate_fn(val_loader, model, loss_fn, config.DEVICE, num_classes, model_architecture, beta=config.FBETA_BETA)

        # --- 3.3. Record and report the metrics ---
        log_data = {'epoch': epoch + 1, 'train_loss': train_loss}
        log_data.update(val_metrics_dict)
        utils.log_metrics_to_csv(log_data, metrics_csv_path)
        val_loss = val_metrics_dict.get('val_loss')
        iou_score = val_metrics_dict.get('iou_score', 0)
        fbeta_score = val_metrics_dict.get('fbeta_score', 0)
        logging.info(f"训练集Loss: {train_loss:.4f} | 验证集Loss: {val_loss:.4f} | IoU: {iou_score:.4f} | F-beta: {fbeta_score:.4f}")

        # --- 3.4. Best-model checkpoint and early stopping ---
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            epochs_no_improve = 0
            # Save the bare module's weights, without the DataParallel wrapper.
            if isinstance(model, torch.nn.DataParallel):
                torch.save(model.module.state_dict(), best_model_path)
            else:
                torch.save(model.state_dict(), best_model_path)
            logging.info(f"验证损失改善至 {best_val_loss:.4f}。正在保存最佳模型至 '{best_model_path}'")
        else:
            epochs_no_improve += 1
            logging.info(f"验证损失已连续 {epochs_no_improve} 轮未改善。")
        # A non-int patience value disables early stopping.
        if isinstance(config.EARLY_STOPPING_PATIENCE, int) and epochs_no_improve >= config.EARLY_STOPPING_PATIENCE:
            logging.info(f"\n训练在 {epoch + 1} 轮后触发早停。")
            logging.warning(f"验证损失已连续 {config.EARLY_STOPPING_PATIENCE} 轮未改善。")
            break

        # --- 3.5. Rolling checkpoint and progress plots ---
        checkpoint_path = output_dir / f"epoch_{epoch+1}.pth"
        if isinstance(model, torch.nn.DataParallel):
            torch.save(model.module.state_dict(), checkpoint_path)
        else:
            torch.save(model.state_dict(), checkpoint_path)
        logging.info(f"{epoch+1} 轮的检查点已保存至 '{checkpoint_path}'")

        # ---------- Remove the previous checkpoint ----------
        # Nothing to remove after the very first epoch.
        if epoch > 0:
            # `epoch` (0-based) equals the 1-based number of the previous epoch.
            last_epoch = epoch
            # Keep the previous checkpoint only when its number is a multiple
            # of CKPT_SAVE_INTERVAL; otherwise delete it.
            if last_epoch % config.CKPT_SAVE_INTERVAL != 0:
                old_ckpt = output_dir / f"epoch_{last_epoch}.pth"
                if old_ckpt.exists():
                    old_ckpt.unlink()  # delete the file

        utils.plot_training_progress(metrics_csv_path, output_dir)
        logging.info("所有训练曲线图已更新并保存。")
    return run_name
if __name__ == "__main__":
    # Command-line interface.
    parser = argparse.ArgumentParser(description="训练图像分割模型。")
    # --architecture is required; its choices are the model names defined in
    # config.ALL_MODEL_CONFIGS. There is no default model.
    parser.add_argument(
        "-a", "--architecture",
        type=str,
        choices=list(config.ALL_MODEL_CONFIGS.keys()),
        required=True,
        help="选择要训练的模型架构。"
    )
    args = parser.parse_args()

    # Stays None when main() raises before returning a run name.
    project_folder = None
    try:
        project_folder = main(model_architecture=args.architecture)
    finally:
        # Force a garbage-collection pass to release Python objects held in RAM.
        gc.collect()
        logging.info("已执行垃圾回收 (gc.collect())。")
        # Release the CUDA memory cached by PyTorch.
        torch.cuda.empty_cache()
        logging.info("训练结束，已清理 CUDA 缓存。")
        # After training and cleanup, move the results to the archive disk.
        if project_folder:
            move_results_to_hardisk(project_folder)
        else:
            logging.error("由于训练失败或未生成项目文件夹，未执行结果移动操作。")

View File

@@ -0,0 +1,105 @@
#!/bin/bash
# Launch train.py for every architecture, one background job per model.
# Models are split into four groups; each group is pinned to its own GPU
# through CUDA_VISIBLE_DEVICES. Job launches are staggered by LAUNCH_INTERVAL
# seconds, and the script waits for all jobs before exiting.

# --- 1. Conda environment ---
CONDA_BASE_PATH="/home/wkmgc/miniconda3" # <--- change this to your own conda install
CONDA_ENV_NAME="${SEG_CONDA_ENV:-seg_smp}" # override with: SEG_CONDA_ENV=SMP bash train.sh

# Initialise conda for this shell and activate the environment.
if [ -f "${CONDA_BASE_PATH}/etc/profile.d/conda.sh" ]; then
    source "${CONDA_BASE_PATH}/etc/profile.d/conda.sh"
    if ! conda activate "${CONDA_ENV_NAME}"; then
        echo "错误: 激活 Conda 环境 '${CONDA_ENV_NAME}' 失败!"
        exit 1
    fi
    echo "Conda 环境 '${CONDA_ENV_NAME}' 已成功激活。"
else
    echo "错误: 找不到 conda.sh 脚本。请检查您的 CONDA_BASE_PATH 设置是否正确。"
    exit 1
fi

# --- 2. Model and GPU configuration ---
GPUS_GROUP_1="3"
GPUS_GROUP_2="2"
GPUS_GROUP_3="1"
GPUS_GROUP_4="4"
GROUP_1_ARCHS=("PSPNet" "Unet" "UnetPlusPlus" )
GROUP_2_ARCHS=("Linknet" "MAnet" "DeepLabV3" )
GROUP_3_ARCHS=("UPerNet" "Segformer" "DPT" )
GROUP_4_ARCHS=("FPN" "DeepLabV3Plus" "PAN")

# Seconds to wait after launching each job.
LAUNCH_INTERVAL=50

# Read OUTPUTS_DIR from config.py.
OUTPUTS_DIR=$(python -c "from config import OUTPUTS_DIR; print(OUTPUTS_DIR)")
# An empty value means the import failed. "None" means config.py leaves
# OUTPUTS_DIR unset; without this check logs would land in a directory
# literally named "None".
if [ -z "$OUTPUTS_DIR" ] || [ "$OUTPUTS_DIR" = "None" ]; then
    echo "OUTPUTS_DIR: $OUTPUTS_DIR"
    echo "Error: OUTPUTS_DIR could not be read from config.py or is set to None. Exiting."
    exit 1
fi

# Timestamped log directory under the training output root. The prefix used
# to be "predict_logs_parallel_", copied over from predict.sh.
LOG_DIR_NAME="train_logs_parallel_$(date +%Y-%m-%d_%H-%M-%S)"
LOG_DIR="$OUTPUTS_DIR/$LOG_DIR_NAME"
mkdir -p "${LOG_DIR}"
echo "所有模型的日志将保存在 ./${LOG_DIR}/ 目录中。"
echo "----------------------------------------------------"

# --- 3. Launch all training jobs ---
# launch_group LABEL GPUS ARCH...
# Starts one background training job per ARCH on the given GPUs and sleeps
# LAUNCH_INTERVAL seconds after each launch. Jobs are NOT run sequentially;
# they keep running in the background while later jobs are started.
launch_group() {
    local label="$1"
    local gpus="$2"
    shift 2
    local arch

    echo ">>> 准备启动第${label}组训练任务 (后台运行)..."
    for arch in "$@"; do
        echo " -> 正在后台启动模型: ${arch} on GPUs: ${gpus}"
        # '&' puts the job in the background.
        CUDA_VISIBLE_DEVICES="${gpus}" python train.py -a "${arch}" > "${LOG_DIR}/${arch}.log" 2>&1 &
        echo " - 模型 ${arch} 已在后台启动。日志文件: ${LOG_DIR}/${arch}.log"
        echo " - 等待 ${LAUNCH_INTERVAL} 秒..."
        sleep "${LAUNCH_INTERVAL}"
    done
    echo ">>> 第${label}组所有模型均已启动。"
    echo "----------------------------------------------------"
}

launch_group "一" "${GPUS_GROUP_1}" "${GROUP_1_ARCHS[@]}"
launch_group "二" "${GPUS_GROUP_2}" "${GROUP_2_ARCHS[@]}"
launch_group "三" "${GPUS_GROUP_3}" "${GROUP_3_ARCHS[@]}"
launch_group "四" "${GPUS_GROUP_4}" "${GROUP_4_ARCHS[@]}"

# --- 4. Wait for all background jobs ---
echo ""
echo "--- 所有模型均已在后台启动。现在等待所有训练任务完成... ---"
# 'wait' blocks until every background job started by this script has exited.
wait
echo "--- 所有后台训练任务已全部完成! ---"

# Deactivate the environment before exiting.
conda deactivate

View File

@@ -0,0 +1,361 @@
import os, logging, shutil
import torch
import numpy as np
import pandas as pd
from tqdm import tqdm
from torch.amp import autocast
from typing import Dict, Union, List
import albumentations as A
from albumentations.pytorch import ToTensorV2
import segmentation_models_pytorch as smp
import matplotlib.pyplot as plt
from pathlib import Path
import config
# ------------------- 绘图函数 (全新升级) ------------------- #
def log_metrics_to_csv(metrics_dict: Dict[str, Union[int, float]], csv_path: Union[str, Path]) -> None:
    """Append one row of metrics to a CSV file, creating it when needed.

    The first call writes the header. Later calls align the new row to the
    header already in the file, so a dictionary whose keys arrive in a
    different order (or with some keys missing) no longer lands in the wrong
    columns. If the row brings keys the file has not seen yet, the file is
    rewritten with the extra columns appended.

    Args:
        metrics_dict: Metric name -> value for one epoch.
        csv_path: Target CSV file.
    """
    df_new_row = pd.DataFrame([metrics_dict])

    # A missing or empty file has no header yet: write header and row together.
    if not os.path.isfile(csv_path) or os.path.getsize(csv_path) == 0:
        df_new_row.to_csv(csv_path, index=False)
        return

    existing_columns = pd.read_csv(csv_path, nrows=0).columns
    if set(df_new_row.columns) <= set(existing_columns):
        # Common case: plain append, with values placed under their own header.
        aligned = df_new_row.reindex(columns=existing_columns)
        aligned.to_csv(csv_path, mode='a', header=False, index=False)
    else:
        # New metric names appeared: widen the table instead of misaligning it.
        combined = pd.concat([pd.read_csv(csv_path), df_new_row], ignore_index=True)
        combined.to_csv(csv_path, index=False)
def _plot_metric_group(ax, df: pd.DataFrame, group_metrics: List[str], title: str):
"""
内部辅助函数,用于在给定的坐标轴上绘制一组指标。
(此函数保持不变)
"""
epochs_range = df['epoch']
for name in group_metrics:
if f"val_{name}" in df.columns:
metric_col = f"val_{name}"
label = name.replace("_", " ").title()
ax.plot(epochs_range, df[metric_col], 'o-', label=label)
elif f"{name}" in df.columns:
metric_col = f"{name}"
label = name.replace("_", " ").title()
ax.plot(epochs_range, df[metric_col], 'o-', label=label)
else:
print("没有搜索到:", f"val_{name}")
ax.set_title(title)
ax.set_xlabel('Epoch')
ax.set_ylabel('Score')
ax.legend(loc='best', fontsize='small')
ax.grid(True)
def plot_training_progress(csv_path: str, output_dir: str):
    """Render the training-history plots from the metrics CSV.

    Writes into ``output_dir``:
      * ``loss_curves.png`` - training and validation loss;
      * ``metrics_<group>.png`` - one figure per metric group;
      * ``training_progress_overview.png`` - a 3x2 grid combining them.

    Nothing is drawn when the CSV is missing or holds fewer than two rows.

    Args:
        csv_path: Path of the metrics CSV.
        output_dir: Directory for the PNG files (created if absent).
    """
    os.makedirs(output_dir, exist_ok=True)
    try:
        df_metrics = pd.read_csv(csv_path)
    except FileNotFoundError:
        print(f"警告：在路径 {csv_path} 未找到 metrics.csv 文件。跳过绘图。")
        return
    if len(df_metrics) < 2:
        print("警告：数据点不足 (少于2个)，无法生成趋势图。")
        return

    epochs_range = df_metrics['epoch']
    train_loss = df_metrics['train_loss']
    val_loss = df_metrics['val_loss']

    # (file-name key, metric names, title used in the overview grid)
    groups = [
        ("core_performance",
         ["iou_score", "fbeta_score", "accuracy"],
         'Core Performance Metrics'),
        ("precision_recall",
         ["precision", "recall", "specificity", "negative_predictive_value"],
         'Precision-Recall Family'),
        ("error_rates",
         ["false_positive_rate", "false_negative_rate", "false_discovery_rate", "false_omission_rate"],
         'Error Rate Analysis'),
        ("likelihood_ratios",
         ["positive_likelihood_ratio", "negative_likelihood_ratio"],
         'Likelihood Ratios'),
    ]

    def draw_loss(ax):
        # Shared by the standalone loss figure and the overview grid.
        ax.plot(epochs_range, train_loss, 'o-', label='Training Loss')
        ax.plot(epochs_range, val_loss, 'o-', label='Validation Loss')
        ax.set_title('Loss Curves')
        ax.set_xlabel('Epoch')
        ax.set_ylabel('Loss')
        ax.legend()
        ax.grid(True)

    # --- 1. One standalone figure per plot ---
    fig_loss, ax_loss = plt.subplots(figsize=(10, 6))
    draw_loss(ax_loss)
    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, "loss_curves.png"))
    plt.close(fig_loss)

    for group_name, group_list, _ in groups:
        fig_group, ax_group = plt.subplots(figsize=(10, 7))
        _plot_metric_group(ax_group, df_metrics, group_list, f'{group_name.replace("_", " ").title()} Metrics')
        plt.tight_layout()
        plt.savefig(os.path.join(output_dir, f"metrics_{group_name}.png"))
        plt.close(fig_group)

    # --- 2. 3x2 overview: loss first, then the four metric groups ---
    fig_overview, axes = plt.subplots(3, 2, figsize=(20, 22))
    fig_overview.suptitle('Training Progress Overview', fontsize=20)
    ax_list = axes.ravel()
    draw_loss(ax_list[0])
    for position, (_, group_list, overview_title) in enumerate(groups, start=1):
        _plot_metric_group(ax_list[position], df_metrics, group_list, overview_title)
    # The sixth cell has nothing to show.
    ax_list[5].axis('off')
    # Leave room at the top for the figure title.
    plt.tight_layout(rect=[0, 0, 1, 0.96])
    plt.savefig(os.path.join(output_dir, "training_progress_overview.png"))
    plt.close(fig_overview)
    print(f"所有图表已更新并保存至目录: {output_dir}")
# ------------------- 训练与评估函数 ('vit'模型不适用amp) ------------------- #
def train_fn(loader, model, optimizer, loss_fn, scaler, device, model_architecture: str):
    """Run a single training epoch and return the mean per-batch loss.

    Args:
        loader: Iterable yielding ``(data, targets)`` batches; must support ``len()``.
        model: Network being optimised; it is put into training mode here.
        optimizer: Optimizer that is zeroed and stepped once per batch.
        loss_fn: Callable invoked as ``loss_fn(predictions, targets)``.
        scaler: Gradient scaler; only used on the mixed-precision path.
        device: Device that every batch is moved to.
        model_architecture: Key into ``config.ALL_MODEL_CONFIGS``.

    Returns:
        The accumulated batch loss divided by ``len(loader)``.
    """
    progress = tqdm(loader, desc="Training")
    running_loss = 0.0
    model.train()

    # Encoders whose name contains "vit" are trained in plain float32;
    # every other encoder goes through the autocast + scaler path.
    encoder_name = config.ALL_MODEL_CONFIGS[model_architecture].get('encoder_name', '')
    use_amp = 'vit' not in encoder_name.lower()
    if not use_amp and progress.n == 0:
        logging.warning(f"检测到 ViT 编码器 ('{encoder_name}')。本次训练将禁用自动混合精度 (AMP) 以确保稳定性。")

    for data, targets in progress:
        data = data.to(device)
        targets = targets.to(device)

        if not use_amp:
            # Full-precision path: ordinary forward / backward / step.
            predictions = model(data)
            loss = loss_fn(predictions, targets)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        else:
            # Mixed-precision path: forward under autocast, then let the
            # scaler drive the backward pass and the optimizer step.
            with autocast(device_type='cuda'):
                predictions = model(data)
                loss = loss_fn(predictions, targets)
            optimizer.zero_grad()
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

        batch_loss = loss.item()
        running_loss += batch_loss
        progress.set_postfix(loss=batch_loss)

    return running_loss / len(loader)
# --- 【修改】 ---
def evaluate_fn(loader, model, loss_fn, device, num_classes, model_architecture: str, beta=1.0):
    """Evaluate the model on a validation loader and collect segmentation metrics.

    Predictions and targets are both reduced to class-index maps (argmax over
    the channel axis), so confusion statistics are always gathered in
    multiclass fashion, whatever ``config.SEG_MODE`` is.

    Args:
        loader: Iterable yielding ``(data, targets)`` batches; must support ``len()``.
        model: Network to evaluate. It is switched to eval mode for the pass
            and put back into training mode before returning.
        loss_fn: Callable invoked as ``loss_fn(predictions, targets)``; targets
            are cast to float in multilabel mode and to long otherwise.
        device: Device that every batch is moved to.
        num_classes: Number of classes used for the per-class statistics.
        model_architecture: Key into ``config.ALL_MODEL_CONFIGS``.
        beta: Beta value forwarded to the F-beta score.

    Returns:
        dict with ``val_loss``, micro-averaged metrics computed over the classes
        not listed in ``config.EVALUATION_CLASSES_TO_IGNORE``, ``mpa``,
        per-class ``cpa_class_{i}`` and raw ``tp/tn/fp/fn_class_{i}`` counts.

    Raises:
        ValueError: If ``config.SEG_MODE`` is neither ``'multilabel'`` nor
            ``'multiclass'``.
    """
    # Fail early with a clear message instead of hitting an unbound variable
    # inside the batch loop for an unsupported mode.
    if config.SEG_MODE not in ('multilabel', 'multiclass'):
        raise ValueError(
            f"Unsupported config.SEG_MODE {config.SEG_MODE!r}; "
            "expected 'multilabel' or 'multiclass'."
        )
    is_multilabel = config.SEG_MODE == 'multilabel'

    model.eval()
    loop = tqdm(loader, desc="Validation")
    total_loss = 0
    all_tp, all_fp, all_fn, all_tn = [], [], [], []

    # Encoders whose name contains "vit" are evaluated without autocast.
    encoder_name = config.ALL_MODEL_CONFIGS[model_architecture].get('encoder_name', '')
    use_amp = 'vit' not in encoder_name.lower()

    with torch.no_grad():
        for data, targets in loop:
            data = data.to(device)
            targets = targets.to(device)

            # The loss expects float targets in multilabel mode and integer
            # class indices otherwise.
            loss_targets = targets.float() if is_multilabel else targets.long()

            # A disabled autocast context is a no-op, so one code path serves
            # both the mixed-precision and the float32 case.
            with autocast(device_type='cuda', enabled=use_amp):
                predictions = model(data)
                loss = loss_fn(predictions, loss_targets)
            total_loss += loss.item()

            # Reduce predictions (and, in multilabel mode, the per-channel
            # targets) to class-index maps.
            preds_indices = torch.argmax(predictions, dim=1)
            if is_multilabel:
                targets_indices = torch.argmax(targets, dim=1)
            else:
                targets_indices = targets

            # Both tensors now hold class indices, so the statistics must be
            # gathered in 'multiclass' mode. Passing 'multilabel' here would
            # make smp interpret the index maps as per-channel binary masks.
            tp, fp, fn, tn = smp.metrics.get_stats(
                preds_indices,
                targets_indices.long(),
                mode='multiclass',
                num_classes=num_classes,
            )
            # Per the smp docs each tensor is expected to be
            # (batch_size, num_classes); batches are concatenated below.
            all_tp.append(tp)
            all_fp.append(fp)
            all_fn.append(fn)
            all_tn.append(tn)

    # Concatenate along the image axis and sum it away -> one count per class.
    tp = torch.cat(all_tp, dim=0).sum(dim=0)
    fp = torch.cat(all_fp, dim=0).sum(dim=0)
    fn = torch.cat(all_fn, dim=0).sum(dim=0)
    tn = torch.cat(all_tn, dim=0).sum(dim=0)

    # Drop the classes configured to be ignored from the aggregate metrics.
    ignore_indices = [
        config.CLASSES.index(cls)
        for cls in config.EVALUATION_CLASSES_TO_IGNORE
        if cls in config.CLASSES
    ]
    keep_mask = torch.ones(num_classes, dtype=torch.bool, device=tp.device)
    if ignore_indices:
        keep_mask[ignore_indices] = False
        logging.info(f"评估时将忽略类别: {config.EVALUATION_CLASSES_TO_IGNORE}")

    filtered_stats = (tp[keep_mask], fp[keep_mask], fn[keep_mask], tn[keep_mask])

    def micro(metric_fn, **kwargs):
        """Micro-average one smp metric over the kept classes as a float."""
        return metric_fn(*filtered_stats, reduction='micro', **kwargs).item()

    # Key order is kept stable in case downstream logging depends on it.
    metrics = {
        "val_loss": total_loss / len(loader),
        "iou_score": micro(smp.metrics.iou_score),
        "f1_score": micro(smp.metrics.f1_score),
        "fbeta_score": micro(smp.metrics.fbeta_score, beta=beta),
        "accuracy": micro(smp.metrics.accuracy),
        "recall": micro(smp.metrics.recall),
        "precision": micro(smp.metrics.precision),
        "specificity": micro(smp.metrics.specificity),
        "negative_predictive_value": micro(smp.metrics.negative_predictive_value),
        "false_negative_rate": micro(smp.metrics.false_negative_rate),
        "false_positive_rate": micro(smp.metrics.false_positive_rate),
        "false_discovery_rate": micro(smp.metrics.false_discovery_rate),
        "false_omission_rate": micro(smp.metrics.false_omission_rate),
        "positive_likelihood_ratio": micro(smp.metrics.positive_likelihood_ratio),
        "negative_likelihood_ratio": micro(smp.metrics.negative_likelihood_ratio),
    }

    # Per-class pixel accuracy is the per-class recall over ALL classes
    # (unfiltered); the mean skips index 0 when more than one class exists.
    cpa_per_class = smp.metrics.recall(tp, fp, fn, tn, reduction='none').cpu().numpy().flatten()
    if len(cpa_per_class) > 1:
        metrics["mpa"] = np.mean(cpa_per_class[1:]).item()
        for i, cpa in enumerate(cpa_per_class):
            metrics[f"cpa_class_{i}"] = cpa.item()
    else:
        metrics["mpa"] = cpa_per_class[0].item()
        metrics["cpa_class_0"] = cpa_per_class[0].item()

    # Raw confusion counts for every class, ignored ones included.
    for i in range(num_classes):
        metrics[f'tp_class_{i}'] = tp[i].item()
        metrics[f'tn_class_{i}'] = tn[i].item()
        metrics[f'fp_class_{i}'] = fp[i].item()
        metrics[f'fn_class_{i}'] = fn[i].item()

    # Hand the model back in training mode, as callers of this function expect.
    model.train()
    return metrics
# ------------------- 数据增强 (保持不变) ------------------- #
def get_training_augmentation(height, width):
    """Build the augmentation pipeline applied to training samples.

    The pipeline resizes to ``(height, width)``, applies a horizontal flip,
    a vertical flip and a 90-degree rotation (each with probability 0.5),
    normalises with fixed per-channel mean/std (presumably the ImageNet
    statistics) and finally converts to a tensor.
    """
    steps = [
        A.Resize(height, width),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        A.RandomRotate90(p=0.5),
        A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ToTensorV2(),
    ]
    return A.Compose(steps)
def get_validation_augmentation(height, width):
    """Build the deterministic pipeline applied to validation samples.

    Only resizing to ``(height, width)``, normalisation with fixed per-channel
    mean/std (presumably the ImageNet statistics) and tensor conversion are
    performed; no random augmentation is applied.
    """
    steps = [
        A.Resize(height, width),
        A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ToTensorV2(),
    ]
    return A.Compose(steps)
# TOOL - 获取用于预测的图像预处理流程
def get_preprocessing_transform(height: int, width: int) -> A.Compose:
    """Return the image preprocessing pipeline used at prediction time.

    The pipeline resizes to ``(height, width)``, normalises with fixed
    per-channel mean/std (presumably the ImageNet statistics) and converts
    the result to a tensor.
    """
    resize = A.Resize(height, width)
    normalize = A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    to_tensor = ToTensorV2()
    return A.Compose([resize, normalize, to_tensor])
# Tool - 为一个全新的训练任务清理并重新创建输出目录。
def setup_directories(output_dir: Path) -> None:
    """Wipe and recreate the output directory for a brand-new training run.

    Whatever currently sits at ``output_dir`` is removed first: a real
    directory is deleted recursively, while a regular file or a symlink
    (including a dangling one) is unlinked, because ``shutil.rmtree`` refuses
    to operate on those. An empty directory is then created, together with
    any missing parents.

    Args:
        output_dir (Path): Path of the main output directory.
    """
    # is_symlink() is checked explicitly because exists() is False for a
    # dangling symlink, which would otherwise make mkdir() fail below.
    if output_dir.is_symlink() or output_dir.is_file():
        logging.info(f"输出目录 '{output_dir}' 已存在。正在删除...")
        output_dir.unlink()
    elif output_dir.exists():
        logging.info(f"输出目录 '{output_dir}' 已存在。正在删除...")
        shutil.rmtree(output_dir)

    logging.info(f"正在 '{output_dir}' 创建新的空目录。")
    output_dir.mkdir(parents=True, exist_ok=True)

View File

@@ -0,0 +1,62 @@
############## A. 创建conda环境 ##############
sudo apt install git
1.
conda create -n SMP python==3.9
conda activate SMP
2.
# PyTorch 安装网址:https://pytorch.org/get-started/locally/
e.g. pip3 install torch torchvision --index-url https://download.pytorch.org/whl/cu129
3.
proxychains pip install git+https://github.com/qubvel/segmentation_models.pytorch
pip show segmentation-models-pytorch
4.
python3 -m pip install -r requirements.txt
############## B. Train 训练程序 ##############
# A. 第一次训练开一个梯子【需要下载内容】
export https_proxy=http://127.0.0.1:1080 http_proxy=http://127.0.0.1:1080
CUDA_VISIBLE_DEVICES=4,5,6,7 python train.py -a Unet
{Unet,UnetPlusPlus,FPN,PSPNet,DeepLabV3,DeepLabV3Plus,Linknet,MAnet,PAN,UPerNet,Segformer,DPT}
# B. 运行单个训练程序
1. 在 config.py 中修改训练数据集;用 CUDA_VISIBLE_DEVICES 修改使用的显卡;用 -a 修改算法;
2. CUDA_VISIBLE_DEVICES=0 python ./train.py -a Unet
# C. 批量运行训练程序
1. train.sh 修改想让它使用的显卡、算法;
2. bash train.sh
3. 会生成 ./logs_parallel_DATE 的终端记录文件;结果先存储在 ../DataSet_Public_outputs/DATASET_outputs-SegModel 后自动移动到 ../Hardisk/DATASET_outputs-SegModel 中
############## C. Predict 推理程序 ##############
# A1. 预先步骤:将模型同步到 Nas_BackUp_Seg 或 ./Hardisk 文件夹中【cd .. && bash ./Back_Up.sh】
# A2. 如果需要处理自定义数据集,请保证数据仍然是 DATASET/images/[val/test] 格式
python ../Seg_Predict_Own_Video_V2/1_Save_Frame_V2.py --video ./Video_Name.mp4 --resize "1920x1080" --output_dir "../DataSet_Public/5_Predict_Video" --interval 0.5
# B1. 将最优模型文件从 ./Nas_BackUp_Seg 移动到 ./BestMode_Predict_Results_DataSet_Public 指定文件夹中
bash ./Tool_Copy_Best_Model.sh # 修改里面的路径
# B2. 如果需要处理自定义数据集,请将模型文件夹手动复制为 ./BestMode_Predict_Results_DataSet_Public/DATASET-Yolo 中
可以先对原有 ./BestMode_Predict_Results_DataSet_Public/ORI_DATASET_outputs-SegModel 改名
运行 bash ./Tool_Yolo_Copy_Best_Model.sh
再将生成的 ./BestMode_Predict_Results_DataSet_Public/ORI_DATASET_outputs-SegModel 改为 ./BestMode_Predict_Results_DataSet_Public/DATASET_outputs-SegModel
# C. 运行单个推理程序
1. 在 config.py 中修改预测数据集及 val/test、模型保存路径(PREDICT_ALL_BEST_MODELS_DIR/PREDICT_BEST_MODEL_DIR);用 CUDA_VISIBLE_DEVICES 修改使用的显卡;用 -a 修改算法;
2. CUDA_VISIBLE_DEVICES=0 python 1_predict.py -a Unet/UnetPlusPlus/FPN/PSPNet/DeepLabV3/DeepLabV3Plus/Linknet/MAnet/PAN/UPerNet/Segformer/DPT
# D. 批量运行推理程序
1. 在 1_predict.sh 中修改想让它使用的显卡、算法;
2. bash 1_predict.sh
3. 会生成 ./predict_logs_parallel_DATE 的终端记录文件;结果存储在 ../BestMode_Predict_Results_DataSet_Public/DATASET_outputs-SegModel 中
# E. 预测模型 参数量、FLOPs、FPS
CUDA_VISIBLE_DEVICES=5 python 2_predict_params_and_FLOPs_V2.py # --shape 512 512
############## D. Predict_raw_img_Check 检测推理输出图片是否齐全 ##############
# 检测 config.TEST_IMAGE_DIR 和 config.PREDICT_BEST_MODEL_DIR/****/predicted_raw_masks 中图片是否匹配
1. 在 config.py 中修改 config.DATA_DIR(即 config.TEST_IMAGE_DIR)及 config.PREDICT_BEST_MODEL_DIR;用 CUDA_VISIBLE_DEVICES 修改使用的显卡;
2. python 1_predict_raw_masks_check.py
############## E. 其他辅助脚本 ##############
1. 解决 --resume 不能运行的问题
您可以直接修改 mmengine 源文件,使其 torch.load 可以加载完整的检查点:
vim /home/wkmgc/miniconda3/envs/SMP/lib/python3.9/site-packages/mmengine/runner/checkpoint.py
转到第 347 行。您将看到以下代码:
checkpoint = torch.load(filename, map_location=map_location)
更改该行以添加 weights_only=False 参数(注意:weights_only=False 会反序列化任意 pickle 对象,请仅用于加载可信来源的检查点):
checkpoint = torch.load(filename, map_location=map_location, weights_only=False)