first commit
This commit is contained in:
564
Seg_All_In_One_SegModel/1_predict.py
Normal file
564
Seg_All_In_One_SegModel/1_predict.py
Normal file
@@ -0,0 +1,564 @@
|
||||
import logging, argparse, sys, utils
|
||||
import shutil, tempfile
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
import albumentations as A
|
||||
import cv2
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import segmentation_models_pytorch as smp
|
||||
import torch
|
||||
from albumentations.pytorch import ToTensorV2
|
||||
from scipy.special import softmax
|
||||
from sklearn.metrics import (auc, average_precision_score,
|
||||
precision_recall_curve, roc_curve)
|
||||
from tqdm import tqdm
|
||||
|
||||
# 本地应用/库的导入
|
||||
import config
|
||||
from utils import log_metrics_to_csv, get_preprocessing_transform, setup_directories
|
||||
|
||||
# --- Logging setup ---
# Emit INFO-and-above records to the console with a timestamped format.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s",
    handlers=[logging.StreamHandler()],
)

# Matplotlib font setup for CJK text: list candidate sans-serif fonts (adjust
# for the fonts installed on your system) and disable the Unicode minus sign
# so negative numbers render correctly with those fonts.
plt.rcParams.update({
    'font.sans-serif': ['Noto Sans CJK JP', 'SimHei', 'Microsoft YaHei'],
    'axes.unicode_minus': False,
})
|
||||
|
||||
# --- 核心工具函数 ---
|
||||
|
||||
# 将单通道的类别索引掩码转换为 RGB 彩色掩码
|
||||
def colorize_mask(mask: np.ndarray, class_rgb_values: List[Tuple[int, int, int]]) -> np.ndarray:
    """Convert a single-channel class-index mask into an RGB colour mask.

    Pixels equal to class index ``k`` are painted with the ``k``-th colour of
    a fixed built-in palette. ``class_rgb_values`` only limits how many
    classes are coloured (through its length); the colours it contains are
    not read. Pixels whose value has no palette entry, or is not below
    ``len(class_rgb_values)``, stay black.

    Args:
        mask: Array of class indices (0, 1, 2, ...); its first two dimensions
            define the output height and width.
        class_rgb_values: One entry per class.

    Returns:
        ``uint8`` array of shape ``(mask.shape[0], mask.shape[1], 3)``.
    """
    # Generic palette; append more colours here to support more classes.
    palette = (
        (0, 0, 0), (220, 20, 60), (0, 255, 0), (0, 0, 255),
        (255, 255, 0), (255, 0, 255), (0, 255, 255), (244, 164, 96),
    )
    rows, cols = mask.shape[0], mask.shape[1]
    rgb = np.zeros((rows, cols, 3), dtype=np.uint8)

    # Only as many palette entries as there are configured classes are used.
    active_colors = palette[:len(class_rgb_values)]
    for label, rgb_value in enumerate(active_colors):
        rgb[mask == label] = rgb_value
    return rgb
|
||||
|
||||
# --- 分析与可视化函数 ---
|
||||
|
||||
def generate_and_save_curves(y_true: np.ndarray, y_probs: np.ndarray, class_names: List[str], output_dir: Path) -> None:
    """Plot per-class ROC and precision-recall curves and save them as PNGs.

    Every class except index 0 (skipped as background) is treated
    one-vs-rest: samples whose label equals the class index are positives and
    column ``class_index`` of ``y_probs`` supplies the scores.

    Args:
        y_true: Array of class indices, one per sample.
        y_probs: Array indexed as ``y_probs[:, class_index]``.
        class_names: Class names; position 0 is not plotted.
        output_dir: Directory receiving ``ROC_Curves.png`` and
            ``Precision_Recall_Curves.png``.
    """
    logging.info("正在生成 ROC 和 PR 曲线...")
    # NOTE: this changes the global matplotlib style for the whole process.
    plt.style.use('seaborn-v0_8-whitegrid')

    # --- Canvases ---
    roc_figure, roc_axes = plt.subplots(figsize=(10, 8))
    roc_axes.plot([0, 1], [0, 1], 'k--', label='随机猜测')  # chance diagonal
    pr_figure, pr_axes = plt.subplots(figsize=(10, 8))

    # --- One curve per foreground class ---
    for class_index, class_name in enumerate(class_names[1:], start=1):
        positives = (y_true == class_index).astype(int)
        scores = y_probs[:, class_index]

        # ROC curve and its area.
        false_pos_rate, true_pos_rate, _ = roc_curve(positives, scores)
        roc_area = auc(false_pos_rate, true_pos_rate)
        roc_axes.plot(false_pos_rate, true_pos_rate, label=f'类别 {class_name} (AUC = {roc_area:.4f})')

        # Precision-recall curve and average precision.
        precision_values, recall_values, _ = precision_recall_curve(positives, scores)
        mean_precision = average_precision_score(positives, scores)
        pr_axes.plot(recall_values, precision_values, label=f'类别 {class_name} (AP = {mean_precision:.4f})')

    # --- Decorate and save the ROC figure ---
    roc_axes.set(
        xlabel='False positive rate (FPR)',
        ylabel='True positive rate (TPR)',
        title='Receiver operating characteristic (ROC) curve',
    )
    roc_axes.legend(loc='lower right')
    roc_axes.set_aspect('equal', adjustable='box')
    roc_figure.tight_layout()
    roc_save_path = output_dir / "ROC_Curves.png"
    roc_figure.savefig(roc_save_path)
    plt.close(roc_figure)
    logging.info(f"-> ROC 曲线图已保存至: {roc_save_path}")

    # --- Decorate and save the PR figure ---
    pr_axes.set(xlabel='Recall', ylabel='Precision', title='Precision-Recall (PR) Curve')
    pr_axes.legend(loc='lower left')
    pr_axes.set_aspect('equal', adjustable='box')
    pr_figure.tight_layout()
    pr_save_path = output_dir / "Precision_Recall_Curves.png"
    pr_figure.savefig(pr_save_path)
    plt.close(pr_figure)
    logging.info(f"-> PR 曲线图已保存至: {pr_save_path}")
|
||||
|
||||
def save_visual_comparison(
    image_name: str, original_image: np.ndarray, pred_mask: np.ndarray,
    gt_mask: np.ndarray = None, stats: Dict = None, stats_text: str = "",
    save_dir: str = ""
) -> None:
    """Save a side-by-side figure of the image, the prediction and (optionally) the ground truth.

    When both ``gt_mask`` and ``stats`` are given, three panels are drawn, the
    per-class statistics text is placed under them, and the IoU / accuracy
    values are embedded in the output file name. Otherwise only the image and
    the predicted mask are drawn.

    Args:
        image_name: File name of the source image; its stem names the output.
        original_image: Image array shown in the first panel.
        pred_mask: Class-index mask predicted by the model.
        gt_mask: Optional ground-truth class-index mask.
        stats: Optional dict with ``'iou'`` and ``'acc'`` entries.
        stats_text: Text block drawn below the panels in three-panel mode.
        save_dir: Output directory, as a ``str`` or ``pathlib.Path``.
    """
    pred_color_mask = colorize_mask(pred_mask, config.CLASS_RGB_VALUES)
    image_stem = Path(image_name).stem
    with_ground_truth = gt_mask is not None and stats is not None

    if with_ground_truth:
        num_plots = 3
        figsize = (22, 8)
        gt_color_mask = colorize_mask(gt_mask, config.CLASS_RGB_VALUES)
        save_name = f"{image_stem}-IOU_{stats['iou']:.4f}-ACC_{stats['acc']:.4f}.png"
    else:
        num_plots = 2
        figsize = (14, 6)
        save_name = f"{image_stem}.png"

    # The parameter is annotated as str (default "") but joined with "/";
    # coercing to Path makes both str and Path arguments work.
    output_path = Path(save_dir) / save_name

    fig, axes = plt.subplots(1, num_plots, figsize=figsize)
    try:
        fig.suptitle(f"Predict: {image_name}", fontsize=16)

        axes[0].imshow(original_image)
        axes[0].set_title("ORI_IMG")
        axes[0].axis('off')

        pred_title = "Predicted Mask"
        if stats:
            pred_title += f"\nIoU: {stats['iou']:.4f} | Acc: {stats['acc']:.4f}"
        axes[1].imshow(pred_color_mask)
        axes[1].set_title(pred_title)
        axes[1].axis('off')

        if with_ground_truth:
            axes[2].imshow(gt_color_mask)
            axes[2].set_title("Ground Truth")
            axes[2].axis('off')
            fig.text(0.5, 0.01, stats_text, ha="center", fontsize=10,
                     bbox={"facecolor": "white", "alpha": 0.7, "pad": 5}, family='monospace')
            # tight_layout recomputes the subplot parameters and would undo a
            # preceding subplots_adjust(bottom=0.3); reserve the bottom 30 %
            # for the statistics text through `rect` instead.
            fig.tight_layout(rect=(0, 0.3, 1, 1))
        else:
            fig.tight_layout()

        fig.savefig(output_path)
    finally:
        # Always release the figure, even if plotting or saving failed.
        plt.close(fig)
|
||||
|
||||
# --- 主要逻辑函数 ---
|
||||
|
||||
def load_model(model_architecture: str, model_save_path: Path) -> torch.nn.Module:
    """Rebuild a segmentation model and load trained weights into it.

    The architecture is looked up by name in ``segmentation_models_pytorch``
    and instantiated with the parameters stored in
    ``config.ALL_MODEL_CONFIGS``. Weights saved from a
    ``torch.nn.DataParallel`` wrapper (keys prefixed with ``'module.'``) and
    plainly saved weights are both supported.

    Args:
        model_architecture: Name of the model class inside ``smp``.
        model_save_path: Path of the saved ``state_dict`` file.

    Returns:
        The model on ``config.DEVICE``, switched to eval mode.

    Raises:
        AttributeError: If ``smp`` has no class named ``model_architecture``.
    """
    logging.info(f"正在从以下路径加载模型: {model_save_path}")
    num_classes = len(config.CLASSES)

    # 1. Resolve the configured parameters and the model class by name.
    logging.info(f"正在重建模型架构: '{model_architecture}'")
    try:
        # Copy so the shared config dict is not mutated below.
        params = config.ALL_MODEL_CONFIGS[model_architecture].copy()
        model_class = getattr(smp, model_architecture)
    except AttributeError:
        logging.error(f"模型 '{model_architecture}' 在 segmentation_models_pytorch 库中不存在!")
        raise

    # 2. Constructor arguments, presumably mirroring the training script
    #    (verify against train.py). encoder_weights=None avoids downloading
    #    pretrained encoder weights that the checkpoint overwrites anyway.
    params['in_channels'] = 3
    params['classes'] = num_classes
    params['encoder_weights'] = None

    # Encoders whose name contains "vit" get dynamic_img_size=True.
    encoder_name = params.get('encoder_name', '')
    if 'vit' in encoder_name.lower():
        params['dynamic_img_size'] = True
        logging.info(f"检测到 ViT 编码器 ('{encoder_name}')。自动设置 dynamic_img_size=True。")

    # 3. Instantiate the model on the target device.
    model = model_class(**params).to(config.DEVICE)

    # 4. Read the saved state_dict.
    # NOTE(security): depending on the torch version, torch.load may unpickle
    # arbitrary objects; only load checkpoint files from a trusted source.
    state_dict = torch.load(model_save_path, map_location=torch.device(config.DEVICE))

    # 5. Strip the DataParallel prefix if present. Only the *leading*
    #    'module.' is removed: str.replace would also delete the substring
    #    when it occurs in the middle of a parameter name.
    prefix = 'module.'
    if any(key.startswith(prefix) for key in state_dict):
        logging.info("检测到模型由 DataParallel 保存,将移除 'module.' 前缀。")
        state_dict = {
            (key[len(prefix):] if key.startswith(prefix) else key): value
            for key, value in state_dict.items()
        }
    else:
        logging.info("直接加载标准模型权重。")

    model.load_state_dict(state_dict)
    model.eval()
    return model
|
||||
|
||||
def calculate_and_save_final_metrics(stats_list: List[Dict], num_classes: int, save_dir: str) -> None:
    """Aggregate per-image statistics, compute overall metrics and write them to CSV.

    NOTE(review): a function with the same name but a different signature
    (``tp, fp, fn, tn, num_classes, save_dir``) is defined later in this file.
    At import time that later definition replaces this one, so this version
    is not reachable by name once the module has loaded.

    Args:
        stats_list: One dict per image with ``'tp'``, ``'fp'``, ``'fn'`` and
            ``'tn'`` tensors (assumed to be ``(batch, classes)`` as returned
            by ``smp.metrics.get_stats`` - TODO confirm).
        num_classes: Number of classes, including ignored ones.
        save_dir: Directory (``str`` or ``Path``) that receives
            ``test_set_metrics.csv``.
    """
    if not stats_list:
        logging.warning("未收集到统计数据,跳过最终指标计算。")
        return

    logging.info("\n正在聚合结果以计算最终指标...")
    # Concatenate along dim 0, then sum over it to get one total per class.
    tp, fp, fn, tn = (
        torch.cat([image_stats[key] for image_stats in stats_list], dim=0).sum(dim=0)
        for key in ('tp', 'fp', 'fn', 'tn')
    )

    # Classes listed in config.EVALUATION_CLASSES_TO_IGNORE are excluded from
    # the overall metrics (but still reported per class below).
    ignore_indices = [
        config.CLASSES.index(cls)
        for cls in config.EVALUATION_CLASSES_TO_IGNORE
        if cls in config.CLASSES
    ]
    keep_mask = torch.ones(num_classes, dtype=torch.bool, device=tp.device)
    if ignore_indices:
        keep_mask[ignore_indices] = False
        logging.info(f"最终评估时将忽略类别: {config.EVALUATION_CLASSES_TO_IGNORE}")

    filtered = (tp[keep_mask], fp[keep_mask], fn[keep_mask], tn[keep_mask])

    # Overall metrics from the filtered statistics (micro reduction).
    metrics = {
        "iou_score": smp.metrics.iou_score(*filtered, reduction='micro').item(),
        "f1_score": smp.metrics.f1_score(*filtered, reduction='micro').item(),
        "accuracy": smp.metrics.accuracy(*filtered, reduction='micro').item(),
        "recall": smp.metrics.recall(*filtered, reduction='micro').item(),
        "precision": smp.metrics.precision(*filtered, reduction='micro').item(),
    }

    # Raw per-class counts, taken from the unfiltered totals.
    for i in range(num_classes):
        metrics[f'tp_class_{i}'] = tp[i].item()
        metrics[f'tn_class_{i}'] = tn[i].item()
        metrics[f'fp_class_{i}'] = fp[i].item()
        metrics[f'fn_class_{i}'] = fn[i].item()

    # save_dir is annotated as str; coerce so the "/" join also works for it.
    csv_save_path = Path(save_dir) / "test_set_metrics.csv"
    log_metrics_to_csv(metrics, str(csv_save_path))
    logging.info(f"-> 整体测试集指标已保存至: {csv_save_path}")
|
||||
|
||||
# 2. 【新增】一个帮助函数,用于交互式选择模型目录
|
||||
def select_trained_model_path(base_dir: Path, model_architecture: str) -> Path:
    """Let the user pick one training-run directory of the given architecture.

    Sub-directories of ``base_dir`` whose name starts with
    ``"<model_architecture>_"`` are listed in sorted order and the user is
    prompted for a 1-based number.

    Args:
        base_dir: Directory holding one sub-directory per training run.
        model_architecture: Architecture name used as directory-name prefix.

    Returns:
        The chosen run directory, or ``None`` (despite the ``Path``
        annotation) when ``base_dir`` does not exist, no matching run is
        found, the user presses Enter without a choice, or standard input is
        closed.
    """
    logging.info(f"正在 '{base_dir}' 中搜索模型 '{model_architecture}' 的训练记录...")

    # A missing base directory is reported like "no runs found" instead of
    # letting iterdir() raise FileNotFoundError.
    if not base_dir.is_dir():
        logging.error(f"目录 '{base_dir}' 不存在,无法搜索训练记录。")
        return None

    # All folders named "<architecture>_...".
    prefix = model_architecture + '_'
    run_dirs = sorted(d for d in base_dir.iterdir() if d.is_dir() and d.name.startswith(prefix))

    if not run_dirs:
        logging.error(f"未找到任何 '{model_architecture}' 模型的训练记录。请先运行 train.py。")
        return None

    print("\n请选择要用于预测的模型:")
    for number, dir_path in enumerate(run_dirs, start=1):
        print(f" [{number}] {dir_path.name}")

    while True:
        try:
            choice = input(f"请输入选项编号 (1-{len(run_dirs)}) 或按 Enter 退出: ")
        except EOFError:
            # Non-interactive run (stdin closed): treat as "no selection".
            return None
        if not choice:
            return None

        try:
            choice_idx = int(choice) - 1
        except ValueError:
            print("无效的输入,请输入一个数字。")
            continue

        if 0 <= choice_idx < len(run_dirs):
            selected_dir = run_dirs[choice_idx]
            logging.info(f"已选择模型: {selected_dir}")
            return selected_dir
        print("无效的选项,请重新输入。")
|
||||
|
||||
# 【修改】函数签名和内部逻辑,不再接收一个列表,而是直接接收聚合后的统计数据
|
||||
def calculate_and_save_final_metrics(tp, fp, fn, tn, num_classes: int, save_dir: Path) -> None:
    """Compute overall metrics from aggregated statistics and write them to CSV.

    Args:
        tp, fp, fn, tn: Per-class totals, indexable by class index and by a
            boolean mask of length ``num_classes``.
        num_classes: Number of classes, including ignored ones.
        save_dir: Directory that receives ``test_set_metrics.csv``.
    """
    logging.info("\n正在计算最终的整体指标...")

    # Indices of the classes that config.py excludes from the overall metrics.
    ignore_indices = [
        config.CLASSES.index(class_name)
        for class_name in config.EVALUATION_CLASSES_TO_IGNORE
        if class_name in config.CLASSES
    ]

    keep_mask = torch.ones(num_classes, dtype=torch.bool, device=tp.device)
    if ignore_indices:
        keep_mask[ignore_indices] = False
        logging.info(f"最终评估时将忽略类别: {config.EVALUATION_CLASSES_TO_IGNORE}")

    kept_stats = tuple(stat[keep_mask] for stat in (tp, fp, fn, tn))

    # Overall metrics use the filtered statistics with micro reduction.
    metric_functions = (
        ("iou_score", smp.metrics.iou_score),
        ("f1_score", smp.metrics.f1_score),
        ("accuracy", smp.metrics.accuracy),
        ("recall", smp.metrics.recall),
        ("precision", smp.metrics.precision),
    )
    metrics = {
        metric_name: metric_fn(*kept_stats, reduction='micro').item()
        for metric_name, metric_fn in metric_functions
    }

    # Raw per-class counts come from the unfiltered totals.
    raw_stats = (('tp', tp), ('tn', tn), ('fp', fp), ('fn', fn))
    for class_idx in range(num_classes):
        for stat_name, stat in raw_stats:
            metrics[f'{stat_name}_class_{class_idx}'] = stat[class_idx].item()

    csv_save_path = save_dir / "test_set_metrics.csv"
    log_metrics_to_csv(metrics, str(csv_save_path))
    logging.info(f"-> 整体测试集指标已保存至: {csv_save_path}")
|
||||
|
||||
def main(model_architecture: str) -> None:
    """Predict every test image with a chosen trained model and, if possible, evaluate.

    Workflow:
      1. Let the user choose a training run and derive all paths from it.
      2. Load the model and determine the input size to use.
      3. For each image in ``config.TEST_IMAGE_DIR``: predict, save the raw
         single-channel mask, and save a comparison figure.
      4. If ``config.TEST_MASK_DIR`` exists (evaluation mode): accumulate
         TP/FP/FN/TN, write the overall metrics CSV and draw ROC / PR curves.

    Args:
        model_architecture: Key of ``config.ALL_MODEL_CONFIGS`` to predict with.
    """
    # --- 1. Choose the run directory and derive paths from it ---
    selected_run_dir = select_trained_model_path(Path(config.PREDICT_BEST_MODEL_DIR), model_architecture)
    if not selected_run_dir:
        logging.info("未选择任何模型,程序退出。")
        sys.exit()

    best_model_save_path = selected_run_dir / "best_model.pth"
    raw_mask_dir = selected_run_dir / "predicted_raw_masks"
    analysis_results_dir = selected_run_dir / "prediction_analysis"

    if not best_model_save_path.exists():
        logging.error(f"错误: 在 '{selected_run_dir}' 中未找到 'best_model.pth' 文件!")
        sys.exit(1)  # non-zero status: this is an error, not a user choice

    # Prepare the prediction output sub-directories.
    utils.setup_directories(raw_mask_dir)
    utils.setup_directories(analysis_results_dir)

    model = load_model(model_architecture, best_model_save_path)

    # --- 2. Probe the model for the input size to use ---
    logging.info("正在探测模型以确定预测时所需的输入尺寸...")
    target_height, target_width = config.IMAGE_HEIGHT, config.IMAGE_WIDTH
    try:
        if model.encoder.is_fixed_input_size:
            # input_size is indexed as [1] = height, [2] = width here.
            required_size = model.encoder.input_size
            target_height = required_size[1]
            target_width = required_size[2]
            logging.warning(
                f"模型 '{model_architecture}' 的编码器需要固定的输入尺寸 "
                f"({target_height}x{target_width})。"
                f"将忽略 config.py 中的尺寸设置进行预测。"
            )
        else:
            logging.info(f"模型 '{model_architecture}' 支持动态输入尺寸。将使用 config.py 中定义的尺寸 ({target_height}x{target_width})。")
    except AttributeError:
        # Encoder does not expose these attributes: keep the config size.
        logging.info("无法自动检测模型尺寸要求。将使用 config.py 中定义的尺寸。")

    # --- 3. Preprocessing pipeline for the chosen size ---
    transform = get_preprocessing_transform(target_height, target_width)

    # --- 4. Image list and run mode ---
    image_filenames = sorted(config.TEST_IMAGE_DIR.iterdir())
    evaluate_mode = config.TEST_MASK_DIR.is_dir()
    if evaluate_mode:
        logging.info("正在以 [评估模式] 运行。将使用真实掩码进行评估。")
    else:
        logging.info("正在以 [仅预测模式] 运行。未找到真实掩码。")

    # --- 5. Running totals for the overall metrics ---
    num_classes = len(config.CLASSES)
    total_tp = torch.zeros(num_classes, dtype=torch.long, device=config.DEVICE)
    total_fp = torch.zeros(num_classes, dtype=torch.long, device=config.DEVICE)
    total_fn = torch.zeros(num_classes, dtype=torch.long, device=config.DEVICE)
    total_tn = torch.zeros(num_classes, dtype=torch.long, device=config.DEVICE)

    # Per-pixel labels and probabilities for the ROC / PR curves are spooled
    # to temporary files during the loop. They are only created in evaluation
    # mode and are always removed in the `finally` block below.
    temp_paths = []
    if evaluate_mode:
        for _ in range(2):
            with tempfile.NamedTemporaryFile(delete=False, suffix='.npy') as handle:
                temp_paths.append(Path(handle.name))
        logging.info(f"为ROC/PR曲线创建临时文件: {temp_paths[0]}, {temp_paths[1]}")

    try:
        # --- 6. Main inference loop ---
        for image_path in tqdm(image_filenames, desc="正在处理测试图像"):
            # 6.1 Read and preprocess; cv2.imread returns None for entries it
            # cannot decode, which are skipped instead of crashing cvtColor.
            bgr_image = cv2.imread(str(image_path))
            if bgr_image is None:
                logging.warning(f"无法读取图像,已跳过: {image_path}")
                continue
            original_image = cv2.cvtColor(bgr_image, cv2.COLOR_BGR2RGB)
            original_h, original_w = original_image.shape[:2]
            image_tensor = transform(image=original_image)['image'].unsqueeze(0).to(config.DEVICE)

            # 6.2 Inference; the mask is resized back to the original size.
            with torch.no_grad():
                pred_logits_tensor = model(image_tensor)
            pred_mask_tensor = torch.argmax(pred_logits_tensor, dim=1)
            pred_mask_numpy = pred_mask_tensor.squeeze().cpu().numpy().astype(np.uint8)
            pred_mask_resized_raw = cv2.resize(
                pred_mask_numpy, (original_w, original_h), interpolation=cv2.INTER_NEAREST
            )

            # Save the raw single-channel predicted mask.
            cv2.imwrite(str(raw_mask_dir / f"{image_path.stem}.png"), pred_mask_resized_raw)

            # 6.3 Evaluation (when a ground-truth mask exists) and visualisation.
            per_image_stats, stats_text_display, gt_mask_raw = None, "", None
            if evaluate_mode:
                mask_path = config.TEST_MASK_DIR / image_path.name
                if mask_path.exists():
                    gt_mask_raw = cv2.imread(str(mask_path), cv2.IMREAD_GRAYSCALE)
                    if gt_mask_raw is None:
                        logging.warning(f"无法读取真实掩码,已跳过评估: {mask_path}")

            if gt_mask_raw is not None:
                # Ground truth is compared at the model input size.
                gt_mask_numpy = cv2.resize(
                    gt_mask_raw, (target_width, target_height),
                    interpolation=cv2.INTER_NEAREST
                )
                gt_mask_tensor = torch.from_numpy(gt_mask_numpy).long().to(config.DEVICE).unsqueeze(0)

                # Per-image confusion statistics.
                tp, fp, fn, tn = smp.metrics.get_stats(
                    pred_mask_tensor, gt_mask_tensor,
                    mode=config.SEG_MODE, num_classes=num_classes
                )

                # Accumulate into the dataset-level totals.
                total_tp += tp.squeeze(0).to(config.DEVICE)
                total_fp += fp.squeeze(0).to(config.DEVICE)
                total_fn += fn.squeeze(0).to(config.DEVICE)
                total_tn += tn.squeeze(0).to(config.DEVICE)

                # Append this image's labels / class probabilities to the
                # temporary files (one np.save record per image).
                probs_tensor = torch.softmax(pred_logits_tensor, dim=1)
                gt_flat = gt_mask_tensor.cpu().numpy().flatten()
                probs_flat = probs_tensor.permute(0, 2, 3, 1).reshape(-1, num_classes).cpu().numpy()
                with open(temp_paths[0], 'ab') as f_gt, open(temp_paths[1], 'ab') as f_probs:
                    np.save(f_gt, gt_flat)
                    np.save(f_probs, probs_flat)

                # Per-image metrics shown in the figure and its file name.
                per_image_stats = {
                    'iou': smp.metrics.iou_score(tp, fp, fn, tn, reduction='micro').item(),
                    'acc': smp.metrics.accuracy(tp, fp, fn, tn, reduction='micro').item()
                }

                # Per-class counts; tp is indexed as [0, class] and class 0 is
                # left out of the text.
                text_lines = ["Each Class TP / FP / FN / TN:\n" + "-" * 35 + "\n"]
                for i in range(1, tp.shape[1]):
                    if i < len(config.CLASSES):
                        class_name = config.CLASSES[i]
                        text_lines.append(
                            f" {class_name:<10}: "
                            f"{tp[0, i]:>5} / {fp[0, i]:>5} / {fn[0, i]:>5} / {tn[0, i]:>7}\n"
                        )
                stats_text_display = "".join(text_lines)

            save_visual_comparison(
                image_name=image_path.name,
                original_image=original_image,
                pred_mask=pred_mask_resized_raw,
                gt_mask=gt_mask_raw,
                stats=per_image_stats,
                stats_text=stats_text_display,
                save_dir=analysis_results_dir
            )

        logging.info("\n预测流程已完成!")
        logging.info(f"原始预测掩码已保存至: {raw_mask_dir}")
        logging.info(f"分析图像已保存至: {analysis_results_dir}")

        # --- 7. Final analysis (evaluation mode only) ---
        if evaluate_mode:
            calculate_and_save_final_metrics(
                total_tp.cpu(), total_fp.cpu(), total_fn.cpu(), total_tn.cpu(),
                num_classes, save_dir=analysis_results_dir
            )

            logging.info("\n正在从临时文件加载数据以生成 ROC 和 PR 曲线...")
            all_gt_parts, all_probs_parts = [], []
            with open(temp_paths[0], 'rb') as f_gt, open(temp_paths[1], 'rb') as f_probs:
                # Read record after record; np.load raises once a file is
                # exhausted, which ends the loop.
                while True:
                    try:
                        all_gt_parts.append(np.load(f_gt))
                        all_probs_parts.append(np.load(f_probs))
                    except (IOError, ValueError, EOFError):
                        break

            if all_gt_parts:
                # NOTE: the concatenation holds the whole test set in memory.
                full_gt = np.concatenate(all_gt_parts)
                full_probs = np.concatenate(all_probs_parts)
                generate_and_save_curves(
                    y_true=full_gt, y_probs=full_probs,
                    class_names=config.CLASSES, output_dir=analysis_results_dir
                )
            else:
                logging.warning("临时文件中没有数据,跳过 ROC/PR 曲线生成。")
    finally:
        # Remove the temporary files on success and on failure alike.
        for temp_path in temp_paths:
            try:
                temp_path.unlink()
            except OSError as exc:
                logging.warning(f"无法删除临时文件 {temp_path}: {exc}")
        if temp_paths:
            logging.info("已清理临时文件。")
|
||||
|
||||
if __name__ == "__main__":
    # Command-line entry point of the prediction script.
    parser = argparse.ArgumentParser(description="使用已训练的图像分割模型进行预测。")

    # --architecture is mandatory (no default); the allowed values are the
    # model names defined in config.ALL_MODEL_CONFIGS.
    parser.add_argument(
        "-a", "--architecture",
        type=str,
        choices=list(config.ALL_MODEL_CONFIGS.keys()),
        required=True,
        help="选择用于预测的模型架构。",
    )

    args = parser.parse_args()

    # Run the prediction workflow for the requested architecture.
    main(model_architecture=args.architecture)
|
||||
137
Seg_All_In_One_SegModel/1_predict_raw_masks_check.py
Normal file
137
Seg_All_In_One_SegModel/1_predict_raw_masks_check.py
Normal file
@@ -0,0 +1,137 @@
|
||||
# predict_check.py
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Set, Tuple
|
||||
|
||||
try:
|
||||
import config
|
||||
except ImportError:
|
||||
print("错误:无法导入 'config.py'。请确保此脚本与 config.py 在同一目录下。")
|
||||
sys.exit(1)
|
||||
|
||||
def get_file_info(directory: Path) -> Tuple[Set[str], Set[str]]:
    """Collect the stems and suffixes of the files directly inside a directory.

    Sub-directories are ignored and nothing is read recursively.

    Args:
        directory: Directory to inspect.

    Returns:
        A ``(stems, suffixes)`` tuple of sets; both are empty when
        ``directory`` is not an existing directory.
    """
    if not directory.is_dir():
        return set(), set()

    files = [entry for entry in directory.iterdir() if entry.is_file()]
    stems = {entry.stem for entry in files}
    suffixes = {entry.suffix for entry in files}
    return stems, suffixes
|
||||
|
||||
def main():
    """Check every ``predicted_raw_masks`` directory against the source test images.

    For each run directory under ``config.PREDICT_BEST_MODEL_DIR`` that
    contains a ``predicted_raw_masks`` folder, the file count, the file stems
    and the consistency of the file suffixes are compared with
    ``config.TEST_IMAGE_DIR``. The process exits with status 1 when a required
    directory is missing or any check fails, and with status 0 when there is
    nothing to check.
    """
    print("--- 开始详细检查预测输出 ---")

    # 1. Source directory from config.py and its file stems.
    source_dir = config.TEST_IMAGE_DIR
    if not source_dir.is_dir():
        print(f"错误:在 '{source_dir}' 找不到源测试图片目录。")
        sys.exit(1)

    source_stems, _ = get_file_info(source_dir)
    source_file_count = len(source_stems)  # number of distinct stems

    if source_file_count == 0:
        print(f"警告:源目录 '{source_dir}' 为空或不包含任何文件。检查中止。")
        sys.exit(0)

    print(f"源数据集: 在 '{source_dir}' 中找到 {source_file_count} 张图片。")
    print("-" * 50)

    # 2. Base directory of the prediction results.
    predictions_base_dir = Path(config.PREDICT_BEST_MODEL_DIR)
    if not predictions_base_dir.is_dir():
        print(f"错误:在 '{predictions_base_dir}' 找不到预测结果的基础目录。")
        sys.exit(1)

    # 3. Check every run directory that has prediction output.
    found_predictions = False
    global_mismatch_found = False

    run_dirs = sorted(d for d in predictions_base_dir.iterdir() if d.is_dir())
    if not run_dirs:
        print("信息:在基础目录中没有找到任何模型的运行记录文件夹。")
        sys.exit(0)

    for run_dir in run_dirs:
        target_dir = run_dir / "predicted_raw_masks"

        # Runs without prediction output are skipped.
        if not target_dir.is_dir():
            continue

        print(f"正在检查: '{target_dir}'...")
        found_predictions = True
        is_current_dir_ok = True

        predicted_stems, predicted_suffixes = get_file_info(target_dir)
        predicted_file_count = len(predicted_stems)

        # --- Check 1: file count ---
        if source_file_count != predicted_file_count:
            print(f" [✗ 数量不匹配] 源目录有 {source_file_count} 个文件,但此目录有 {predicted_file_count} 个。")
            is_current_dir_ok = False

        # --- Check 2: file names (at most 3 examples, sorted so the output
        # is reproducible) ---
        missing_files = source_stems - predicted_stems
        if missing_files:
            examples = sorted(missing_files)[:3]
            print(f" [✗ 文件名缺失] 预测结果中缺少 {len(missing_files)} 个文件。例如: {examples}...")
            is_current_dir_ok = False

        extra_files = predicted_stems - source_stems
        if extra_files:
            examples = sorted(extra_files)[:3]
            print(f" [✗ 文件名多余] 预测结果中多出 {len(extra_files)} 个文件。例如: {examples}...")
            is_current_dir_ok = False

        # --- Check 3: suffix consistency (only for non-empty directories) ---
        if predicted_file_count > 0 and len(predicted_suffixes) > 1:
            print(f" [✗ 后缀不一致] 预测目录中存在多种文件后缀: {sorted(predicted_suffixes)}")
            is_current_dir_ok = False

        # --- Summary for this directory ---
        # An empty prediction directory cannot reach the OK branch: the source
        # count is positive here, so check 1 has already flagged it.
        if is_current_dir_ok:
            suffix_info = next(iter(predicted_suffixes)) if len(predicted_suffixes) == 1 else 'N/A'
            print(f" [✓ OK] 所有检查通过 (数量: {predicted_file_count}, 后缀: {suffix_info})")
        else:
            global_mismatch_found = True

        print("-" * 25)

    # 4. Global summary.
    print("\n--- 检查摘要 ---")
    if not found_predictions:
        print("结果: 没有找到任何 'predicted_raw_masks' 目录进行检查。")
    elif global_mismatch_found:
        print("结论: 检查不通过。发现至少一个预测目录未通过所有检查,请查看上方日志。")
        sys.exit(1)
    else:
        print("结论: 检查通过!所有已找到的预测目录均通过了数量、文件名和后缀一致性检查。")

    print("--- 检查完成 ---")
|
||||
|
||||
|
||||
# Run the consistency check only when this file is executed as a script.
if __name__ == "__main__":
    main()
|
||||
261
Seg_All_In_One_SegModel/2_predict_params_and_FLOPs_V1.py
Normal file
261
Seg_All_In_One_SegModel/2_predict_params_and_FLOPs_V1.py
Normal file
@@ -0,0 +1,261 @@
|
||||
import os
|
||||
import glob
|
||||
import logging
|
||||
import argparse
|
||||
import csv
|
||||
from typing import Dict, Optional, Tuple, List
|
||||
|
||||
import torch
|
||||
from fvcore.nn import FlopCountAnalysis
|
||||
|
||||
from train import initialize_components
|
||||
import config # Assumes config.py is in the same directory or accessible
|
||||
|
||||
import time
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
import sys
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
||||
|
||||
# TODO 使用方法 TODO
|
||||
# 交互模式(推荐)
|
||||
# python 2_predict_params_and_FLOPs.py
|
||||
# 非交互模式(用于脚本)
|
||||
# python 2_predict_params_and_FLOPs.py --input_dir '../BestMode_Predict_Results_DataSet_Public'
|
||||
# python 2_predict_params_and_FLOPs.py --shape 512 512 # --shape 1080 1920
|
||||
|
||||
def get_flops_and_params_smp(architecture: str, shape: Tuple[int, int]) -> Optional[Dict[str, str]]:
    """Measure FLOPs (via fvcore) and parameter count of one model architecture.

    The model is built with ``initialize_components`` and analysed with a
    random input of shape ``(1, 3, shape[0], shape[1])``.

    Args:
        architecture: Name of the model architecture.
        shape: Input image shape as ``(H, W)``.

    Returns:
        ``{'flops': '<x.xx> G', 'params': '<x.xx> M'}``, or ``None`` when the
        model could not be built or the analysis raised.
    """
    logging.info(f"Analyzing model: '{architecture}' with input shape {shape}...")

    class_count = len(config.CLASSES)
    model, _, _, _ = initialize_components(architecture, class_count, config.SEG_MODE)
    if model is None:
        return None

    model.to(config.DEVICE)
    model.eval()

    height, width = shape[0], shape[1]
    dummy_input = torch.randn(1, 3, height, width).to(config.DEVICE)

    try:
        # FLOPs from fvcore, reported in units of 1e9.
        gflops = FlopCountAnalysis(model, dummy_input).total() / 1e9

        # Parameter count, reported in units of 1e6.
        m_params = sum(param.numel() for param in model.parameters()) / 1e6

        logging.info(f"✅ Analysis successful for '{architecture}': {gflops:.2f} GFLOPs, {m_params:.2f} M Params")
        return {
            'flops': f"{gflops:.2f} G",
            'params': f"{m_params:.2f} M"
        }
    except Exception as e:
        # Best effort: report the failure and let the caller continue.
        logging.error(f"❌ Failed to analyze FLOPs for '{architecture}'. Error: {e}")
        return None
|
||||
|
||||
def get_shape_from_path(path: str) -> Optional[Tuple[int, int]]:
    """Return the analysis input shape as ``(height, width)``.

    The shape is taken from ``config.IMAGE_HEIGHT`` and
    ``config.IMAGE_WIDTH``; ``path`` is accepted but not used.

    Args:
        path: The path to the dataset directory (unused).

    Returns:
        ``(height, width)`` when both configured values are positive,
        otherwise ``None``.
    """
    height, width = config.IMAGE_HEIGHT, config.IMAGE_WIDTH
    if not (width > 0 and height > 0):
        return None
    return (height, width)  # H, W order
|
||||
|
||||
# --- 主函数 ---
|
||||
def main(args):
    """
    Run the interactive Params/FLOPs analysis workflow.

    The user first picks a dataset folder (``*_outputs-SegModel``) below
    ``args.input_dir`` and then one or all architectures found inside it.
    Every selected architecture is analysed with ``get_flops_and_params_smp``
    and the collected statistics are written to a CSV file inside the chosen
    dataset folder.

    Args:
        args: Parsed command-line arguments exposing ``input_dir`` and the
            optional ``shape`` (height, width) override.
    """
    root_dir = args.input_dir
    if not os.path.isdir(root_dir):
        logging.error(f"Input directory does not exist: {root_dir}")
        return

    # 1. Collect every dataset directory that follows the naming convention.
    dataset_dirs = sorted(
        candidate
        for candidate in glob.glob(os.path.join(root_dir, '*_outputs-SegModel'))
        if os.path.isdir(candidate)
    )
    if not dataset_dirs:
        logging.error(f"No valid dataset folders (e.g., '*_outputs-SegModel') found in {root_dir}.")
        return

    # 2. First-level menu: choose the dataset.
    dataset_map = {str(number): folder for number, folder in enumerate(dataset_dirs, start=1)}
    separator = "=" * 50

    print("\n" + separator)
    print("--- Step 1: Please select the dataset to process ---")
    for key, folder in dataset_map.items():
        print(f"{key}: {os.path.basename(folder)}")
    print(separator)

    dataset_choice = input("Enter the dataset number and press Enter: ").strip()
    if dataset_choice not in dataset_map:
        logging.error("Invalid dataset selection. Exiting.")
        return

    selected_dataset_dir = dataset_map[dataset_choice]
    dataset_label = os.path.basename(selected_dataset_dir)
    logging.info(f"You have selected the dataset: [{dataset_label}]")

    # 3. Derive the base architecture names from the run sub-folders,
    #    e.g. 'Unet' from 'Unet_2025-09-24_14-24-05'. First-seen order is
    #    kept and duplicates are ignored.
    run_dirs = sorted(
        entry
        for entry in glob.glob(os.path.join(selected_dataset_dir, '*'))
        if os.path.isdir(entry)
    )
    available_archs = []
    for run_dir in run_dirs:
        arch_name = os.path.basename(run_dir).split('_')[0]
        if arch_name in config.ALL_MODEL_CONFIGS and arch_name not in available_archs:
            available_archs.append(arch_name)

    if not available_archs:
        logging.warning(f"No valid algorithm subfolders found in {dataset_label}. Please check folder names match keys in config.ALL_MODEL_CONFIGS.")
        return

    # 4. Second-level menu: choose the algorithm by its base name.
    alg_map = {str(number): name for number, name in enumerate(available_archs, start=1)}
    print("\n" + separator)
    print("--- Step 2: Please select the algorithm to process ---")
    print("0: Process [ALL] algorithms under the current dataset")
    for key, name in alg_map.items():
        print(f"{key}: {name}")
    print(separator)

    algorithm_choice = input("Enter the algorithm number (or '0' for all) and press Enter: ").strip()

    # 5. Resolve the final list of architectures.
    if algorithm_choice == '0':
        model_archs_to_process = available_archs
        logging.info(f"You have chosen to process all {len(model_archs_to_process)} algorithms.")
    elif algorithm_choice in alg_map:
        model_archs_to_process = [alg_map[algorithm_choice]]
        logging.info(f"You have chosen to process a single algorithm: {model_archs_to_process[0]}")
    else:
        logging.error("Invalid algorithm selection. Exiting.")
        return

    # --- Determine the input shape used for the analysis ---
    if args.shape:
        # The command-line override has the highest priority.
        input_shape = (args.shape[0], args.shape[1])
        logging.info(f"Using user-provided shape (H, W): {input_shape} for calculation.")
    else:
        logging.info("No --shape argument provided. Attempting to detect from folder name...")
        input_shape = get_shape_from_path(selected_dataset_dir)
        if input_shape:
            logging.info(f"Detected input shape (H, W): {input_shape} from folder name.")
        else:
            # Last resort: the defaults stored in config.py.
            logging.warning(f"Could not automatically detect resolution from folder '{dataset_label}'.")
            logging.info(f"Using default shape from config.py: ({config.IMAGE_HEIGHT}, {config.IMAGE_WIDTH})")
            input_shape = (config.IMAGE_HEIGHT, config.IMAGE_WIDTH)

    # --- Analyse every selected architecture ---
    results: List[Dict[str, str]] = []
    for architecture in model_archs_to_process:
        logging.info(f"--- Processing model: {architecture} ---")

        analysis_shape = input_shape
        if architecture == 'DPT':
            # DPT input is rounded up to a multiple of the patch size
            # (16, presumably from the '...patch16...' encoder name).
            patch_size = 16
            height, width = analysis_shape
            if height % patch_size != 0 or width % patch_size != 0:
                padded_height = ((height + patch_size - 1) // patch_size) * patch_size
                padded_width = ((width + patch_size - 1) // patch_size) * patch_size
                analysis_shape = (padded_height, padded_width)
                logging.warning(f"DPT requires dimensions divisible by {patch_size}. "
                                f"Adjusting analysis shape from {input_shape} to {analysis_shape}.")

        stats = get_flops_and_params_smp(architecture, analysis_shape)
        if not stats:
            logging.warning(f"Failed to get stats for model {architecture}.")
            continue
        results.append({
            'Model': architecture,
            'Params': stats['params'],
            'FLOPs': stats['flops'],
            'Input_Shape (HxW)': f"{analysis_shape[0]}x{analysis_shape[1]}",
        })

    # --- Write the results to a CSV file next to the model runs ---
    if not results:
        logging.info("No statistics were successfully generated, CSV file will not be created.")
        return

    dataset_name = dataset_label.replace('_outputs-SegModel', '')
    output_csv_path = os.path.join(selected_dataset_dir, f'{dataset_name}_flops_params_summary.csv')
    fieldnames = ['Model', 'Params', 'FLOPs', 'Input_Shape (HxW)']

    try:
        with open(output_csv_path, 'w', newline='', encoding='utf-8') as csvfile:
            writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
            writer.writeheader()
            writer.writerows(results)
        logging.info(f"=== All processing complete! Results saved to: {output_csv_path} ===")
    except IOError as e:
        logging.error(f"Could not write to CSV file: {output_csv_path}. Error: {e}")
|
||||
|
||||
if __name__ == '__main__':
    # Command-line interface; both options are optional.
    parser = argparse.ArgumentParser(description="SMP Model Automation and Evaluation Script")

    # Root folder that is scanned for '*_outputs-SegModel' dataset folders.
    parser.add_argument('--input_dir', type=str, default=str(config.PREDICT_ALL_BEST_MODELS_DIR),
                        help="Root directory containing trained model folders (e.g., 'DATASET_outputs-SegModel').")

    # Explicit (height, width) override for the analysis input.
    parser.add_argument('--shape', type=int, nargs=2, metavar=('HEIGHT', 'WIDTH'),
                        help="Specify input shape (height width) for analysis, overriding automatic detection.")

    args = parser.parse_args()
    main(args)
|
||||
325
Seg_All_In_One_SegModel/2_predict_params_and_FLOPs_V2.py
Normal file
325
Seg_All_In_One_SegModel/2_predict_params_and_FLOPs_V2.py
Normal file
@@ -0,0 +1,325 @@
|
||||
import os
|
||||
import glob
|
||||
import logging
|
||||
import argparse
|
||||
import csv
|
||||
from typing import Dict, Optional, Tuple, List
|
||||
import time
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
import sys
|
||||
|
||||
import torch
|
||||
from fvcore.nn import FlopCountAnalysis
|
||||
import segmentation_models_pytorch as smp
|
||||
import config # Assumes config.py is in the same directory or accessible
|
||||
|
||||
# Configure root logging: INFO level and above, formatted as
# "<timestamp> - <level name> - <message>".
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
||||
|
||||
# Usage
# Interactive mode (recommended):
#   python 2_predict_params_and_FLOPs_V2.py
# Non-interactive mode (for scripts):
#   python 2_predict_params_and_FLOPs_V2.py --input_dir '../BestMode_Predict_Results_DataSet_Public'
#   python 2_predict_params_and_FLOPs_V2.py --shape 512 512  # or: --shape 1080 1920
|
||||
|
||||
def load_model(model_architecture: str, checkpoint_path: Optional[Path]) -> Optional[torch.nn.Module]:
    """
    Rebuild a model architecture from ``config.py`` and optionally load weights.

    The model class is looked up by name on ``segmentation_models_pytorch``
    and instantiated with the keyword arguments stored in
    ``config.ALL_MODEL_CONFIGS``. This function does not depend on train.py.

    Args:
        model_architecture: Key in ``config.ALL_MODEL_CONFIGS`` that is also
            the attribute name of the model class on ``smp``.
        checkpoint_path: Path to a saved state dict, or ``None``.

    Returns:
        The model on ``config.DEVICE`` in eval mode, or ``None`` when the
        architecture is missing from ``smp`` or from the config. When no
        checkpoint exists, or loading it fails, the model is returned with
        the weights it was constructed with (``encoder_weights=None``).
    """
    # 1. Resolve the constructor arguments and the model class.
    logging.info(f"Rebuilding model architecture: '{model_architecture}'")
    try:
        params = config.ALL_MODEL_CONFIGS[model_architecture].copy()
        model_class = getattr(smp, model_architecture)
    except (AttributeError, KeyError):
        logging.error(f"Model '{model_architecture}' not found in smp library or config.py!")
        return None

    # 2. Prepare the constructor arguments.
    params['classes'] = len(config.CLASSES)
    # Weights come from the checkpoint, so pretrained encoder weights must
    # not be downloaded again.
    params['encoder_weights'] = None

    # ViT encoders get an extra flag; tolerate a missing or None encoder name.
    encoder_name = params.get('encoder_name') or ''
    if 'vit' in encoder_name.lower():
        params['dynamic_img_size'] = True
        logging.info(f"ViT encoder ('{encoder_name}') detected. Setting dynamic_img_size=True.")

    # 3. Instantiate the model.
    model = model_class(**params).to(config.DEVICE)

    # 4. Load the checkpoint when a valid path was provided.
    if checkpoint_path and checkpoint_path.exists():
        logging.info(f"Loading checkpoint: {checkpoint_path}")
        try:
            # NOTE(review): torch.load unpickles the file; only load
            # checkpoints from trusted sources.
            state_dict = torch.load(checkpoint_path, map_location=torch.device(config.DEVICE))

            # Checkpoints saved from a DataParallel wrapper prefix every key
            # with 'module.'. Strip only that leading prefix: a plain
            # str.replace would also mangle keys that merely contain the
            # substring (e.g. 'attention_module.weight').
            prefix = 'module.'
            if any(key.startswith(prefix) for key in state_dict):
                logging.info("DataParallel weights detected, removing 'module.' prefix.")
                state_dict = {
                    (key[len(prefix):] if key.startswith(prefix) else key): value
                    for key, value in state_dict.items()
                }
            model.load_state_dict(state_dict)
            logging.info("Checkpoint loaded successfully.")
        except Exception as e:
            # Best effort: keep going with the freshly constructed model.
            logging.error(f"Failed to load checkpoint. Using model with random weights. Error: {e}")
    else:
        logging.info("No valid checkpoint provided. Using model with random weights.")

    model.eval()
    return model
|
||||
|
||||
# --- 核心功能函数 ---
|
||||
def find_latest_model_run_path(base_dir: Path, model_architecture: str) -> Optional[Path]:
    """
    Return the most recent run directory of an architecture, without prompting.

    Run directories are the sub-directories of ``base_dir`` whose name starts
    with ``"<model_architecture>_"``. "Most recent" means the greatest path in
    sort order, which is the newest run when the folder names embed a
    sortable timestamp.

    Args:
        base_dir: Directory that contains the run folders.
        model_architecture: Architecture name used as the folder-name prefix.

    Returns:
        The selected run directory, or ``None`` when there is no match.
    """
    logging.info(f"Searching for latest '{model_architecture}' model run in '{base_dir}'...")

    prefix = model_architecture + '_'
    candidates = [
        entry
        for entry in base_dir.iterdir()
        if entry.is_dir() and entry.name.startswith(prefix)
    ]

    if candidates:
        # The greatest path equals the last element of the sorted list.
        newest = max(candidates)
        logging.info(f"Automatically selected latest model run: {newest.name}")
        return newest

    logging.info(f"No trained model folder found for '{model_architecture}'.")
    return None
|
||||
|
||||
def calculate_flops_params(model: torch.nn.Module, shape: Tuple[int, int]) -> Optional[Dict[str, str]]:
    """
    Compute FLOPs and parameter count for an already initialised model.

    A random ``1 x 3 x shape[0] x shape[1]`` tensor on ``config.DEVICE`` is
    fed to ``FlopCountAnalysis``; parameters are counted with ``numel()``.

    Args:
        model: The model to analyse; it is moved to ``config.DEVICE`` and put
            into eval mode.
        shape: Input size, used as ``(shape[0], shape[1])`` for the last two
            tensor dimensions.

    Returns:
        ``{'flops': '<x.xx> G', 'params': '<x.xx> M'}``, or ``None`` when the
        analysis raises.
    """
    logging.info(f"Analyzing model Params/FLOPs with input shape {shape}...")

    model.to(config.DEVICE)
    model.eval()
    rows, cols = shape[0], shape[1]
    dummy_input = torch.randn(1, 3, rows, cols).to(config.DEVICE)

    try:
        total_flops = FlopCountAnalysis(model, dummy_input).total()
        gflops = total_flops / 1e9

        total_params = sum(param.numel() for param in model.parameters())
        m_params = total_params / 1e6

        logging.info(f"✅ Params/FLOPs Analysis successful: {gflops:.2f} GFLOPs, {m_params:.2f} M Params")
        return {'flops': f"{gflops:.2f} G", 'params': f"{m_params:.2f} M"}
    except Exception as e:
        logging.error(f"❌ Failed to analyze FLOPs/Params. Error: {e}")
        return None
|
||||
|
||||
def benchmark_model(model: torch.nn.Module, shape: Tuple[int, int]) -> Optional[Dict[str, float]]:
    """
    Measure the inference speed (FPS) of a model on a random input.

    Three runs are performed. Each run executes 5 warm-up forward passes
    followed by 100 timed forward passes on a ``1 x 3 x shape[0] x shape[1]``
    tensor. When CUDA is available the device is synchronised before and
    after every forward pass so that the wall-clock time covers the whole
    computation.

    Args:
        model: The model to benchmark; it is moved to ``config.DEVICE`` and
            put into eval mode.
        shape: Input size used for the last two tensor dimensions.

    Returns:
        ``{'avg_fps': ..., 'fps_variance': ...}`` computed over the runs, or
        ``None`` when no run produced a measurement.
    """
    logging.info("Starting FPS benchmark...")
    model.to(config.DEVICE)
    model.eval()
    dummy_input = torch.randn(1, 3, shape[0], shape[1]).to(config.DEVICE)

    def _sync_if_cuda():
        # Wait for pending GPU work so timestamps bracket the real compute.
        if torch.cuda.is_available():
            torch.cuda.synchronize()

    fps_per_run = []
    repeat_times = 3

    with torch.no_grad():
        for run_index in range(repeat_times):
            num_warmup = 5
            total_iters = 100
            timed_seconds = 0

            for step in range(num_warmup + total_iters):
                _sync_if_cuda()
                tick = time.perf_counter()
                _ = model(dummy_input)
                _sync_if_cuda()
                elapsed = time.perf_counter() - tick

                # Warm-up iterations are executed but not counted.
                if step < num_warmup:
                    continue
                timed_seconds += elapsed

            # Guard against a zero denominator on extremely fast runs.
            if timed_seconds == 0:
                timed_seconds = 1e-6

            overall_fps = total_iters / timed_seconds
            fps_per_run.append(overall_fps)
            print(f" Run {run_index + 1}/{repeat_times}, FPS: {overall_fps:.2f} img/s")

    if not fps_per_run:
        return None

    avg_fps = np.mean(fps_per_run)
    fps_variance = np.var(fps_per_run)
    logging.info(f"✅ FPS Benchmark successful: Average FPS: {avg_fps:.2f} img/s")
    return {'avg_fps': avg_fps, 'fps_variance': fps_variance}
|
||||
|
||||
def get_shape_from_path(path: str) -> Optional[Tuple[int, int]]:
    """
    Extract the resolution encoded in the last component of a folder path.

    The base name of ``path`` is searched for the first ``<digits>x<digits>``
    pattern, which is read as ``<width>x<height>``.

    Args:
        path: Path to the dataset directory, e.g. ``.../Foo_1920x1080_outputs-SegModel``.

    Returns:
        ``(height, width)`` when the pattern is present, otherwise ``None``.
    """
    # Local import: the module's top-level import block does not import `re`.
    import re

    match = re.search(r'(\d+)x(\d+)', os.path.basename(path))
    if match is None:
        return None

    width, height = int(match.group(1)), int(match.group(2))
    return (height, width)  # Return as H, W
|
||||
|
||||
# --- 主函数 ---
|
||||
def main(args):
    """
    Script entry point: measure Params, FLOPs and FPS of trained models.

    The user interactively selects a dataset folder (``*_outputs-SegModel``)
    under ``args.input_dir`` and one or all architectures found in it. For
    each architecture the latest run folder is located, the model is rebuilt
    (loading ``config.BEST_MODEL_SAVE_NAME`` when a run folder exists),
    analysed, and a row is added to a CSV summary written into the dataset
    folder.

    Args:
        args: Parsed command-line arguments exposing ``input_dir`` and the
            optional ``shape`` (height, width) override.
    """
    input_root = args.input_dir
    if not os.path.isdir(input_root):
        logging.error(f"Input directory does not exist: {input_root}")
        return

    # 1. Interactive selection of the dataset and the algorithm.
    existing_dataset_dirs = sorted(
        d for d in glob.glob(os.path.join(input_root, '*_outputs-SegModel')) if os.path.isdir(d)
    )
    if not existing_dataset_dirs:
        logging.error(f"No valid dataset folders found in {input_root}.")
        return

    dataset_map = {str(i + 1): path for i, path in enumerate(existing_dataset_dirs)}
    print("\n" + "="*50)
    print("--- Step 1: Please select the dataset to process ---")
    for key, path in dataset_map.items():
        print(f"{key}: {os.path.basename(path)}")
    print("="*50)
    choice1 = input("Enter the dataset number and press Enter: ").strip()

    if choice1 not in dataset_map:
        logging.error("Invalid dataset selection. Exiting.")
        return

    selected_dataset_dir = dataset_map[choice1]
    logging.info(f"You have selected the dataset: [{os.path.basename(selected_dataset_dir)}]")

    # Base architecture names are the part of each run-folder name before the
    # first underscore; only names known to the config are offered.
    alg_dirs_full_path = sorted(
        d for d in glob.glob(os.path.join(selected_dataset_dir, '*')) if os.path.isdir(d)
    )
    valid_alg_archs = []
    for path in alg_dirs_full_path:
        base_name = os.path.basename(path).split('_')[0]
        if base_name in config.ALL_MODEL_CONFIGS and base_name not in valid_alg_archs:
            valid_alg_archs.append(base_name)

    if not valid_alg_archs:
        logging.warning(f"No valid algorithm subfolders found in {os.path.basename(selected_dataset_dir)}.")
        return

    alg_map = {str(i + 1): name for i, name in enumerate(valid_alg_archs)}
    print("\n" + "="*50)
    print("--- Step 2: Please select the algorithm to process ---")
    print("0: Process [ALL] algorithms under the current dataset")
    for key, name in alg_map.items():
        print(f"{key}: {name}")
    print("="*50)
    choice2 = input("Enter the algorithm number (or '0' for all) and press Enter: ").strip()

    if choice2 == '0':
        model_archs_to_process = valid_alg_archs
    elif choice2 in alg_map:
        model_archs_to_process = [alg_map[choice2]]
    else:
        # Report the invalid choice explicitly instead of exiting silently.
        logging.error("Invalid algorithm selection. Exiting.")
        return

    # 2. Determine the input size: command-line override first, then the
    #    folder name, then the config defaults.
    if args.shape:
        ori_analysis_shape = (args.shape[0], args.shape[1])
        logging.info(f"Using user-provided shape (H, W): {ori_analysis_shape} for calculation.")
    else:
        ori_analysis_shape = get_shape_from_path(selected_dataset_dir) or (config.IMAGE_HEIGHT, config.IMAGE_WIDTH)
        logging.info(f"Using detected/default shape (H, W): {ori_analysis_shape} for calculation.")

    # 3. Main processing loop.
    results: List[Dict[str, str]] = []
    for architecture in model_archs_to_process:
        logging.info(f"--- Processing model: {architecture} ---")

        # Round both dimensions up to a multiple of the patch size. This is
        # applied to every architecture, although the log message names DPT.
        # The patch size of 16 is presumably taken from the '...patch16...'
        # encoder name.
        analysis_shape = ori_analysis_shape
        patch_size = 16
        h, w = analysis_shape
        if h % patch_size != 0 or w % patch_size != 0:
            new_h = ((h + patch_size - 1) // patch_size) * patch_size
            new_w = ((w + patch_size - 1) // patch_size) * patch_size
            analysis_shape = (new_h, new_w)
            logging.warning(f"DPT requires dimensions divisible by {patch_size}. "
                            f"Adjusting analysis shape from {ori_analysis_shape} to {analysis_shape}.")

        # 3.1 Locate the latest run folder and its best-model checkpoint.
        latest_run_dir = find_latest_model_run_path(Path(selected_dataset_dir), architecture)
        checkpoint_path = None
        if latest_run_dir:
            checkpoint_path = latest_run_dir / config.BEST_MODEL_SAVE_NAME

        # 3.2 Build the model and load the weights in one call.
        model = load_model(architecture, checkpoint_path)
        if model is None:
            logging.warning(f"Skipping all analysis for {architecture} due to model initialization failure.")
            continue

        # 3.3 Run every measurement on the same model instance.
        stats_flops_params = calculate_flops_params(model, analysis_shape)
        stats_fps = benchmark_model(model, analysis_shape)

        # 3.4 Collect the row; failed measurements are reported as "N/A".
        # The shape column holds the requested shape, not the rounded-up one.
        results.append({
            'Model': architecture,
            'Params': stats_flops_params['params'] if stats_flops_params else "N/A",
            'FLOPs': stats_flops_params['flops'] if stats_flops_params else "N/A",
            'Input_Shape (HxW)': f"{ori_analysis_shape[0]}x{ori_analysis_shape[1]}",
            'Average_FPS': f"{stats_fps['avg_fps']:.2f}" if stats_fps else "N/A",
            'FPS_Variance': f"{stats_fps['fps_variance']:.4f}" if stats_fps else "N/A"
        })

    # 4. Write the CSV file.
    if results:
        dataset_name = os.path.basename(selected_dataset_dir).replace('_outputs-SegModel', '')
        output_csv_path = os.path.join(selected_dataset_dir, f'{dataset_name}_flops_params_fps_summary.csv')

        try:
            with open(output_csv_path, 'w', newline='', encoding='utf-8') as csvfile:
                fieldnames = ['Model', 'Params', 'FLOPs', 'Input_Shape (HxW)', 'Average_FPS', 'FPS_Variance']
                writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
                writer.writeheader()
                writer.writerows(results)
            logging.info(f"=== All processing complete! Results saved to: {output_csv_path} ===")
        except IOError as e:
            logging.error(f"Could not write to CSV file: {output_csv_path}. Error: {e}")
    else:
        logging.info("No statistics were successfully generated, CSV file will not be created.")
|
||||
|
||||
if __name__ == '__main__':
    # Command-line interface; both options are optional.
    parser = argparse.ArgumentParser(description="SMP Model Automation and Evaluation Script")

    # Root folder that is scanned for '*_outputs-SegModel' dataset folders.
    parser.add_argument('--input_dir', type=str, default=str(config.PREDICT_ALL_BEST_MODELS_DIR),
                        help="Root directory containing trained model folders.")

    # Explicit (height, width) override for the analysis input.
    parser.add_argument('--shape', type=int, nargs=2, metavar=('HEIGHT', 'WIDTH'),
                        help="Specify input shape (height width), overriding automatic detection.")

    args = parser.parse_args()
    main(args)
|
||||
271
Seg_All_In_One_SegModel/3_predict_matrics_from_log.py
Normal file
271
Seg_All_In_One_SegModel/3_predict_matrics_from_log.py
Normal file
@@ -0,0 +1,271 @@
|
||||
import os
|
||||
import glob
|
||||
import logging
|
||||
import argparse
|
||||
import re
|
||||
import csv
|
||||
import numpy as np
|
||||
from typing import Dict, Optional, List
|
||||
|
||||
# --- Configure Logging ---
|
||||
# Root logger: INFO level and above, formatted as
# "<timestamp> - <level name> - <message>".
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
||||
|
||||
# --- Helper Functions (No changes in this section) ---
|
||||
|
||||
def find_smp_csv_file(algorithm_dir: str) -> Optional[str]:
    """
    Locate ``training_metrics.csv`` directly inside an algorithm directory.

    Args:
        algorithm_dir: Directory of a single algorithm run.

    Returns:
        The joined path when it exists; otherwise ``None`` (a warning is
        logged in that case).
    """
    candidate = os.path.join(algorithm_dir, 'training_metrics.csv')
    if not os.path.exists(candidate):
        logging.warning(f"Could not find 'training_metrics.csv' in {algorithm_dir}")
        return None
    return candidate
|
||||
|
||||
def parse_smp_metrics(csv_path: str) -> Optional[Dict]:
    """
    Extract the metrics of the best epoch from a ``training_metrics.csv`` file.

    The best epoch is the row with the highest ``iou_score``. Rows whose
    ``iou_score`` cannot be parsed as a number (for example a blank or
    truncated row) are ignored instead of invalidating the whole file.

    Args:
        csv_path: Path to the CSV file.

    Returns:
        ``None`` when the file cannot be read, is empty, has no usable
        ``iou_score``, has a best ``iou_score`` of 0, or yields no per-class
        metrics. Otherwise a dict with:

        * ``'epoch'``: the best row's ``epoch`` value (or ``'N/A'``),
        * ``'summary'``: ``{'mIoU', 'mAcc', 'aAcc'}`` where ``mIoU`` and
          ``aAcc`` are the raw ``iou_score`` / ``accuracy`` cells and ``mAcc``
          is the mean of the per-class ``cpa_class_<i>`` values,
        * ``'class_wise'``: a list of ``{'Class', 'IoU', 'Acc'}`` dicts, with
          ``IoU`` computed as ``tp / (tp + fp + fn + 1e-6)``.
    """
    try:
        # newline='' is the mode recommended by the csv module for readers.
        with open(csv_path, 'r', encoding='utf-8', newline='') as f:
            rows = list(csv.DictReader(f))
    except (OSError, UnicodeDecodeError, csv.Error) as e:
        logging.error(f"Cannot read or parse CSV file: {csv_path}. Error: {e}")
        return None

    if not rows:
        logging.warning(f"❌ CSV file is empty: {os.path.basename(csv_path)}")
        return None

    # Pick the first row with the highest parseable IoU; skip broken rows.
    best_row = None
    best_iou = None
    last_error = None
    for row in rows:
        try:
            iou = float(row.get('iou_score', 0))
        except (ValueError, TypeError) as e:
            last_error = e
            continue
        if best_iou is None or iou > best_iou:
            best_row, best_iou = row, iou

    if best_row is None:
        logging.error(f"Could not find a valid 'iou_score' in {csv_path}. Error: {last_error}")
        return None
    if best_iou == 0:
        logging.warning(f"❌ Best 'iou_score' is 0 in {os.path.basename(csv_path)}. Check training.")
        return None

    results = {
        'epoch': best_row.get('epoch', 'N/A'),
        'summary': {
            'mIoU': best_row.get('iou_score', 'N/A'),
            'mAcc': 'N/A',
            'aAcc': best_row.get('accuracy', 'N/A')
        },
        'class_wise': []
    }

    # The number of classes is derived from the highest 'tp_class_<i>' column.
    # DictReader may add a None key for over-long rows, hence the type check.
    class_indices = []
    for key in best_row:
        if not (isinstance(key, str) and key.startswith('tp_class_')):
            continue
        try:
            class_indices.append(int(key.split('_')[-1]))
        except ValueError:
            logging.warning(f"Ignoring column with a non-numeric class index: {key}")
    num_classes = max(class_indices) + 1 if class_indices else 0

    per_class_accuracies = []
    for i in range(num_classes):
        try:
            tp = float(best_row.get(f'tp_class_{i}', 0))
            fp = float(best_row.get(f'fp_class_{i}', 0))
            fn = float(best_row.get(f'fn_class_{i}', 0))

            # The small epsilon avoids a division by zero for absent classes.
            class_iou = tp / (tp + fp + fn + 1e-6)
            class_acc_str = best_row.get(f'cpa_class_{i}', '0')
            per_class_accuracies.append(float(class_acc_str))

            # Class 0 is labelled '背景' (background); others use their index.
            class_name = '背景' if i == 0 else str(i)

            results['class_wise'].append({
                'Class': class_name,
                'IoU': f"{class_iou:.4f}",
                'Acc': class_acc_str
            })
        except (ValueError, TypeError) as e:
            logging.warning(f"Could not process metrics for class {i}. Error: {e}")
            continue

    if per_class_accuracies:
        results['summary']['mAcc'] = f"{np.mean(per_class_accuracies):.4f}"

    if results['class_wise']:
        logging.info(f"✅ Successfully parsed metrics for epoch '{results['epoch']}' from {os.path.basename(csv_path)}.")
        return results

    logging.warning(f"❌ Could not parse any per-class metrics from {os.path.basename(csv_path)}.")
    return None
|
||||
|
||||
|
||||
# --- Main Function ---
|
||||
def _format_percent(value) -> str:
    """Render a fractional metric (e.g. '0.8123') as a percentage string ('81.23'), or 'N/A' if it is not numeric."""
    try:
        return f"{round(float(value), 4)*100:.2f}"
    except (TypeError, ValueError):
        return 'N/A'


def main(args):
    """
    Main script entry point to orchestrate the metrics-summary workflow.

    The user interactively selects a dataset folder (``*_outputs-SegModel``)
    under ``args.input_dir`` and one or all algorithm run folders inside it.
    For each run the best-epoch metrics are parsed from its
    ``training_metrics.csv`` and one wide-format row is produced. All rows
    are written to ``<dataset>_metrics_summary_wide.csv`` under
    ``args.output_dir/<dataset folder>``.

    Args:
        args: Parsed command-line arguments exposing ``input_dir`` and
            ``output_dir``.
    """
    input_root = args.input_dir
    output_root = args.output_dir

    if not os.path.isdir(input_root):
        logging.error(f"Input directory does not exist: {input_root}")
        return

    # --- Interactive Menu ---
    all_dataset_dirs = sorted(
        d for d in glob.glob(os.path.join(input_root, '*_outputs-SegModel')) if os.path.isdir(d)
    )
    if not all_dataset_dirs:
        logging.error(f"No valid dataset folders ending in '_outputs-SegModel' found in {input_root}.")
        return

    dataset_map = {str(i + 1): path for i, path in enumerate(all_dataset_dirs)}

    print("\n" + "="*50)
    print("--- Step 1: Please select a dataset to process ---")
    for key, path in dataset_map.items():
        print(f"{key}: {os.path.basename(path)}")
    print("="*50)

    choice1 = input("Enter the dataset number and press Enter: ").strip()
    if choice1 not in dataset_map:
        logging.error("Invalid dataset selection. The program will exit.")
        return

    selected_dataset_dir = dataset_map[choice1]
    logging.info(f"You have selected the dataset: [{os.path.basename(selected_dataset_dir)}]")

    alg_dirs = sorted(
        d for d in glob.glob(os.path.join(selected_dataset_dir, '*')) if os.path.isdir(d)
    )
    if not alg_dirs:
        logging.warning(f"No algorithm subfolders found in {os.path.basename(selected_dataset_dir)}.")
        return

    alg_map = {str(i + 1): path for i, path in enumerate(alg_dirs)}
    print("\n" + "="*50)
    print("--- Step 2: Please select the algorithm(s) to process ---")
    print("0: Process [ALL] algorithms under the current dataset")
    for key, path in alg_map.items():
        print(f"{key}: {os.path.basename(path)}")
    print("="*50)

    choice2 = input("Enter the algorithm number (or '0' for all) and press Enter: ").strip()

    if choice2 == '0':
        model_dirs = alg_dirs
        logging.info(f"You have chosen to batch process all {len(model_dirs)} algorithms.")
    elif choice2 in alg_map:
        model_dirs = [alg_map[choice2]]
        logging.info(f"You have chosen to process a single algorithm: {os.path.basename(model_dirs[0])}")
    else:
        logging.error("Invalid algorithm selection. The program will exit.")
        return

    # --- Start processing the selected algorithms ---
    csv_rows = []
    for model_dir in model_dirs:
        model_name_full = os.path.basename(model_dir)
        logging.info(f"\n--- Processing algorithm: {model_name_full} ---")

        csv_file_path = find_smp_csv_file(model_dir)
        if not csv_file_path:
            logging.warning(f"Skipping {model_name_full} as no 'training_metrics.csv' was found.")
            continue

        metrics = parse_smp_metrics(csv_file_path)
        if not metrics:
            logging.warning(f"❌❌❌ Skipping {model_name_full} due to a failure in parsing metrics. ❌❌❌")
            continue

        summary = metrics['summary']
        # The algorithm name is the run-folder name up to the first underscore.
        short_model_name = model_name_full.split('_')[0]

        # Summary values may be the placeholder 'N/A' (e.g. when the
        # 'accuracy' column is missing), so format them tolerantly instead of
        # letting float() abort the whole export.
        row_data = {
            'Algorithm': short_model_name,
            'Epoch': metrics['epoch'],
            'mIoU': _format_percent(summary['mIoU']),
            'mAcc': _format_percent(summary['mAcc']),
            'aAcc': _format_percent(summary['aAcc'])
        }

        for class_data in metrics['class_wise']:
            class_name = class_data['Class']
            row_data[f'{class_name}_IoU'] = _format_percent(class_data['IoU'])
            row_data[f'{class_name}_Acc'] = _format_percent(class_data['Acc'])

        csv_rows.append(row_data)

    # --- Write results to the final CSV file ---
    if not csv_rows:
        logging.info("No data was successfully collected. No CSV file will be generated.")
        return

    # Field names: fixed summary columns first, then every per-class column
    # that occurs in any row.
    base_fieldnames = ['Algorithm', 'Epoch', 'mIoU', 'mAcc', 'aAcc']
    extra_fieldnames = {
        key for row in csv_rows for key in row if key not in base_fieldnames
    }

    def natural_sort_key(key_name):
        # Natural ordering of '<class>_<metric>' names: '背景' (background)
        # first, numeric classes by value ('2' before '10'), others last.
        parts = key_name.split('_')
        class_part = parts[0]
        metric_part = parts[1] if len(parts) > 1 else ''
        if class_part == '背景':
            return (-1, metric_part)
        try:
            return (int(class_part), metric_part)
        except ValueError:
            return (float('inf'), key_name)

    final_fieldnames = base_fieldnames + sorted(extra_fieldnames, key=natural_sort_key)

    # Build the output path: <output_root>/<dataset folder>/<dataset>_metrics_summary_wide.csv
    output_dataset_folder = os.path.basename(selected_dataset_dir)
    final_output_dir = os.path.join(output_root, output_dataset_folder)
    os.makedirs(final_output_dir, exist_ok=True)

    dataset_name = output_dataset_folder.replace('_outputs-SegModel', '')
    output_csv_path = os.path.join(final_output_dir, f'{dataset_name}_metrics_summary_wide.csv')

    try:
        # The file is written as UTF-8 with a BOM ('utf-8-sig').
        with open(output_csv_path, 'w', newline='', encoding='utf-8-sig') as csvfile:
            writer = csv.DictWriter(csvfile, fieldnames=final_fieldnames)
            writer.writeheader()
            writer.writerows(csv_rows)

        logging.info(f"\n=== All processing is complete! Results have been saved to: {output_csv_path} ===")
    except IOError as e:
        logging.error(f"Failed to write to CSV file: {output_csv_path}. Error: {e}")
|
||||
|
||||
if __name__ == '__main__':
    # Command-line interface; both options have defaults.
    parser = argparse.ArgumentParser(description="SMP (Segmentation Models Pytorch) Final Metrics Extraction Script")

    # Where the '*_outputs-SegModel' dataset folders are read from.
    parser.add_argument('--input_dir', type=str, default='../Hardisk',
                        help="The root directory containing dataset output folders (e.g., '..._outputs-SegModel').")

    # Where the summary CSV files are written.
    parser.add_argument('--output_dir', type=str, default='../BestMode_Predict_Results_DataSet_Public',
                        help="The root directory where analysis results will be stored.")

    args = parser.parse_args()
    main(args)
|
||||
120
Seg_All_In_One_SegModel/Tool_Copy_Best_Model.sh
Normal file
120
Seg_All_In_One_SegModel/Tool_Copy_Best_Model.sh
Normal file
@@ -0,0 +1,120 @@
|
||||
#!/bin/bash
|
||||
|
||||
# +----------------------------------------------------------------------------+
|
||||
# | 脚本: 将 best_model.pth 从 Hardisk 复制到 BestMode_Predict_Results_DataSet_Public |
|
||||
# | 说明: 此版本优化了复制逻辑,仅在文件不存在或内容不一致时才复制。 |
|
||||
# | 同时包含一个可视化的进度条。 (已修正语法错误) |
|
||||
# | 用法: ./Tool_Copy_Best_Model.sh                                             |
|
||||
# +----------------------------------------------------------------------------+
|
||||
|
||||
# 设置源目录和目标目录的基础路径
|
||||
# realpath确保我们获得的是绝对路径,避免相对路径可能带来的问题
|
||||
# Resolve the source and destination roots to absolute paths so that the
# prefix-stripping done later in the copy loop is reliable.
SOURCE_BASE_DIR=$(realpath "../Hardisk")
DEST_BASE_DIR=$(realpath "../BestMode_Predict_Results_DataSet_Public")

# The source root must exist.
if [ ! -d "$SOURCE_BASE_DIR" ]; then
    echo "错误: 源目录 '$SOURCE_BASE_DIR' 不存在。"
    exit 1
fi

# If realpath failed, DEST_BASE_DIR is empty and every destination path would
# become an absolute path directly under '/'. Refuse to continue in that case.
if [ -z "$DEST_BASE_DIR" ]; then
    echo "错误: 无法解析目标目录 '../BestMode_Predict_Results_DataSet_Public'。" >&2
    exit 1
fi

# Make sure the destination root exists.
if ! mkdir -p "$DEST_BASE_DIR"; then
    echo "错误: 无法创建目标目录 '$DEST_BASE_DIR'。" >&2
    exit 1
fi

echo "正在准备操作..."

# --- Step 1: find every model file once and keep the list in an array ---
# find -print0 together with readarray -d '' keeps file names containing
# spaces or other special characters intact.
echo "正在查找所有 'best_model.pth' 文件..."
readarray -d '' files_to_process < <(find "$SOURCE_BASE_DIR" -path "*-SegModel/*" -name "best_model.pth" -type f -print0 2>/dev/null)
TOTAL_FILES=${#files_to_process[@]}

if [ "$TOTAL_FILES" -eq 0 ]; then
    echo "在源目录中没有找到任何 'best_model.pth' 文件。脚本退出。"
    exit 0
fi

echo "总共找到 $TOTAL_FILES 个模型文件需要处理。"
echo "--------------------------------------------------"
|
||||
|
||||
# --- 进度条函数 ---
|
||||
# 参数1: 当前文件数
|
||||
# 参数2: 文件总数
|
||||
# Draw a single-line progress bar that is refreshed in place.
#   $1: number of files processed so far
#   $2: total number of files
print_progress() {
    local current=$1
    local total=$2
    local term_width bar_width percent filled_len i
    local bar=""

    # tput fails when TERM is unset or there is no terminal (e.g. output is
    # redirected); fall back to 80 columns instead of printing an error.
    term_width=$(tput cols 2>/dev/null) || term_width=80
    [[ "$term_width" =~ ^[0-9]+$ ]] || term_width=80

    # Leave room for the percentage and the "current/total" counters.
    bar_width=$((term_width - 30))
    if [ "$bar_width" -lt 10 ]; then
        bar_width=10
    fi

    # Guard against a zero total so the division below cannot fail.
    if [ "$total" -gt 0 ]; then
        percent=$((current * 100 / total))
    else
        percent=100
    fi

    # Length of the completed part of the bar.
    filled_len=$((bar_width * percent / 100))

    for ((i = 0; i < filled_len; i++)); do bar+="="; done
    for ((i = filled_len; i < bar_width; i++)); do bar+=" "; done

    # \r returns the cursor to the start of the line so the bar overwrites itself.
    printf "\r处理进度: %3d%% [%s] %d/%d" "$percent" "$bar" "$current" "$total"
}
|
||||
|
||||
|
||||
# --- 第2步: 执行处理并显示进度条 ---
|
||||
# --- Step 2: copy the files and show progress ---
CURRENT_FILE=0
COPIED_COUNT=0
SKIPPED_COUNT=0
FAILED_COUNT=0

for source_file_path in "${files_to_process[@]}"; do
    # Safety check: skip an empty array element should one ever appear.
    if [[ -z "$source_file_path" ]]; then
        continue
    fi

    # Plain arithmetic assignment: ((var++)) returns a non-zero status when the
    # old value is 0, which would abort the script under `set -e`.
    CURRENT_FILE=$((CURRENT_FILE + 1))

    # Strip the source root from the full path to get the relative path, then
    # re-root it under the destination directory.
    relative_path="${source_file_path#"$SOURCE_BASE_DIR"}"
    dest_file_path="${DEST_BASE_DIR}${relative_path}"
    dest_dir_path=$(dirname "$dest_file_path")

    # Copy only when the destination is missing or its content differs
    # (cmp -s exits 0 for identical files, non-zero otherwise).
    if [ ! -f "$dest_file_path" ] || ! cmp -s "$source_file_path" "$dest_file_path"; then
        # Count the file as copied only if both mkdir and cp succeeded.
        if mkdir -p "$dest_dir_path" && cp "$source_file_path" "$dest_file_path"; then
            COPIED_COUNT=$((COPIED_COUNT + 1))
        else
            FAILED_COUNT=$((FAILED_COUNT + 1))
            printf '\n错误: 复制失败: %s\n' "$source_file_path" >&2
        fi
    else
        # Destination already exists with identical content.
        SKIPPED_COUNT=$((SKIPPED_COUNT + 1))
    fi

    # Refresh the progress bar.
    print_progress "$CURRENT_FILE" "$TOTAL_FILES"
done

# Move the cursor to a new line after the in-place progress bar.
echo ""
echo "--------------------------------------------------"
echo "所有操作已完成!"
echo "  - 总共处理文件: $TOTAL_FILES"
echo "  - 复制或更新的文件: $COPIED_COUNT"
echo "  - 已存在且内容一致而跳过的文件: $SKIPPED_COUNT"
echo "  - 复制失败的文件: $FAILED_COUNT"

# Report failures to the caller through the exit status.
if [ "$FAILED_COUNT" -ne 0 ]; then
    exit 1
fi
|
||||
189
Seg_All_In_One_SegModel/Tool_benchmark_smp.py
Normal file
189
Seg_All_In_One_SegModel/Tool_benchmark_smp.py
Normal file
@@ -0,0 +1,189 @@
|
||||
# benchmark_smp.py
|
||||
import torch
|
||||
import time
|
||||
import argparse
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
from mmengine.model.utils import revert_sync_batchnorm
|
||||
|
||||
# 假设 train.py 在同一目录下,以便导入其函数
|
||||
# 如果不在,请确保路径正确
|
||||
from train import initialize_components
|
||||
import config
|
||||
|
||||
# TODO 使用方法 TODO
|
||||
# 交互模式(推荐)
|
||||
# python Tool_benchmark_smp.py -a Unet --shape 512 512
|
||||
# 非交互模式(用于脚本)
|
||||
# python Tool_benchmark_smp.py -a Unet -c "outputs/Unet_2025-10-13_19-30-00/best_model.pth" --shape 512 512
|
||||
|
||||
# --- Function copied from predict.py ---
|
||||
def select_trained_model_path(base_dir: Path, model_architecture: str) -> "Path | None":
    """Let the user pick one training-run directory of the given architecture.

    Run directories are the sub-directories of ``base_dir`` whose name starts
    with ``"<model_architecture>_"``. They are listed in sorted order and the
    user is prompted for a 1-based index.

    Args:
        base_dir: Directory that contains the training-run folders.
        model_architecture: Architecture name used as the folder-name prefix.

    Returns:
        The selected run directory, or ``None`` when ``base_dir`` does not
        exist, no matching run is found, the user presses Enter to cancel, or
        standard input is exhausted.
    """
    print(f"\nSearching for '{model_architecture}' model runs in '{base_dir}'...")

    # iterdir() raises on a missing directory; report it like "nothing found".
    if not base_dir.is_dir():
        print(f"ERROR: Model directory '{base_dir}' does not exist.")
        return None

    # Keep only the folders whose name starts with "<architecture>_".
    prefix = model_architecture + '_'
    run_dirs = sorted(d for d in base_dir.iterdir() if d.is_dir() and d.name.startswith(prefix))

    if not run_dirs:
        print(f"ERROR: No trained models found for architecture '{model_architecture}'. Please run train_smp.py first.")
        return None

    print("\nPlease select a trained model to benchmark:")
    for position, dir_path in enumerate(run_dirs, start=1):
        print(f"  [{position}] {dir_path.name}")

    prompt = f"Enter selection (1-{len(run_dirs)}) or press Enter to cancel: "
    while True:
        try:
            choice = input(prompt).strip()
        except EOFError:
            # stdin was closed or piped input ran out: treat it as a cancel
            # instead of crashing with a traceback.
            print("\nNo input available. Cancelling.")
            return None

        if not choice:
            return None

        try:
            choice_idx = int(choice) - 1
        except ValueError:
            print("Invalid input. Please enter a number.")
            continue

        if 0 <= choice_idx < len(run_dirs):
            selected_dir = run_dirs[choice_idx]
            print(f"Selected model: {selected_dir.name}")
            return selected_dir

        print("Invalid selection. Please try again.")
|
||||
|
||||
def parse_args():
    """Parse the command-line options of the benchmark script.

    Returns:
        argparse.Namespace with ``architecture``, ``checkpoint``, ``shape``,
        ``log_interval`` and ``repeat_times``.
    """
    cli = argparse.ArgumentParser(description='Benchmark a segmentation model from SMP')

    # Architecture name; restricted to the keys of config.ALL_MODEL_CONFIGS.
    cli.add_argument(
        '-a', '--architecture',
        type=str, required=True,
        choices=list(config.ALL_MODEL_CONFIGS),
        help="The model architecture to benchmark.",
    )
    # Optional checkpoint; when omitted the caller falls back to an
    # interactive selection.
    cli.add_argument(
        '-c', '--checkpoint',
        type=str, required=False,
        help='(Optional) Path to the checkpoint file. If omitted, an interactive selection will be shown.',
    )
    # One or more integers describing the input size.
    cli.add_argument(
        '--shape',
        type=int, nargs='+', default=[512, 512],
        help='Input image size for benchmarking, e.g., --shape 512 512',
    )
    cli.add_argument('--log-interval', type=int, default=50, help='Interval of logging.')
    cli.add_argument(
        '--repeat-times',
        type=int, default=3,
        help='Number of times to repeat the benchmark for averaging.',
    )
    return cli.parse_args()
|
||||
|
||||
def main():
    """Benchmark the inference speed (FPS) of one trained SMP model.

    Steps: resolve the checkpoint (given with ``-c`` or chosen interactively),
    build the model through ``initialize_components``, load the weights, then
    time forward passes on a random ``(1, 3, H, W)`` input. Each run performs
    5 untimed warm-up iterations followed by 200 timed ones, and the run is
    repeated ``--repeat-times`` times.

    Raises:
        ValueError: If ``--shape`` has more than two integers.
        SystemExit: If no model is selected or the checkpoint file is missing.
    """
    # Local import: the module's import block does not bring in `sys`, so the
    # sys.exit() calls below would otherwise fail with NameError.
    import sys

    args = parse_args()

    # --- Resolve the checkpoint path ---
    if args.checkpoint:
        checkpoint_path = Path(args.checkpoint)
    else:
        # No explicit checkpoint: pick a run directory under the configured
        # best-model directory.
        selected_run_dir = select_trained_model_path(Path(config.PREDICT_BEST_MODEL_DIR), args.architecture)
        if not selected_run_dir:
            print("No model selected. Exiting.")
            sys.exit()
        checkpoint_path = selected_run_dir / config.BEST_MODEL_SAVE_NAME

    if not checkpoint_path.exists():
        print(f"ERROR: Checkpoint file not found at '{checkpoint_path}'")
        sys.exit(1)  # non-zero status so calling scripts can detect the failure

    # --- Input size: one integer means a square input ---
    if len(args.shape) not in (1, 2):
        raise ValueError('Invalid input shape. Use one or two integers.')
    h, w = args.shape[0], args.shape[-1]

    print(f"\n--- Model Benchmarking ---")
    print(f"Architecture: {args.architecture}")
    print(f"Checkpoint:   {checkpoint_path}")
    print(f"Input Shape:  (1, 3, {h}, {w})")
    print(f"Device:       {config.DEVICE}")
    print("-" * 28)

    # 1. Build the model.
    num_classes = len(config.CLASSES)
    model, _, _, _ = initialize_components(args.architecture, num_classes, config.SEG_MODE)

    # 2. Load the weights.
    print("Loading checkpoint...")
    state_dict = torch.load(checkpoint_path, map_location=config.DEVICE)

    # Checkpoints saved from a DataParallel wrapper prefix every key with
    # 'module.'. Strip only that leading prefix: str.replace would also remove
    # 'module.' occurring in the middle of a key.
    prefix = 'module.'
    if any(key.startswith(prefix) for key in state_dict):
        print("DataParallel model detected, removing 'module.' prefix.")
        state_dict = {
            (key[len(prefix):] if key.startswith(prefix) else key): value
            for key, value in state_dict.items()
        }
    model.load_state_dict(state_dict)

    model.to(config.DEVICE)
    model.eval()

    # 3. Random dummy input.
    dummy_input = torch.randn(1, 3, h, w).to(config.DEVICE)

    # 4. Benchmark.
    num_warmup = 5
    total_iters = 200
    use_cuda_sync = torch.cuda.is_available()
    overall_fps_list = []

    for run_index in range(args.repeat_times):
        print(f"\n>>> Running benchmark iteration {run_index + 1}/{args.repeat_times}...")
        pure_inf_time = 0

        with torch.no_grad():
            for i in range(total_iters + num_warmup):
                # CUDA kernels run asynchronously; synchronise before and
                # after the forward pass so the timer covers the real work.
                if use_cuda_sync:
                    torch.cuda.synchronize()
                start_time = time.perf_counter()

                _ = model(dummy_input)

                if use_cuda_sync:
                    torch.cuda.synchronize()
                elapsed = time.perf_counter() - start_time

                # Warm-up iterations are excluded from the measurement.
                if i >= num_warmup:
                    pure_inf_time += elapsed
                    # A non-positive interval disables progress logging
                    # instead of raising ZeroDivisionError.
                    if args.log_interval > 0 and (i + 1) % args.log_interval == 0:
                        fps = (i + 1 - num_warmup) / pure_inf_time
                        print(f'Done image [{i + 1 - num_warmup:<4}/ {total_iters}], '
                              f'current FPS: {fps:.2f} img/s')

        overall_fps = total_iters / pure_inf_time
        overall_fps_list.append(overall_fps)
        print(f'Overall FPS for this run: {overall_fps:.2f} img/s')

    # 5. Summary.
    print("\n--- Benchmark Summary ---")
    if not overall_fps_list:
        # --repeat-times < 1: there is nothing to average.
        print("No benchmark runs were executed.")
        return
    avg_fps = np.mean(overall_fps_list)
    var_fps = np.var(overall_fps_list)
    print(f'Average FPS of {args.repeat_times} runs: {avg_fps:.2f} img/s')
    print(f'FPS Variance of {args.repeat_times} runs: {var_fps:.4f}')


if __name__ == '__main__':
    main()
|
||||
89
Seg_All_In_One_SegModel/Tool_get_params_and_FLOPs.py
Normal file
89
Seg_All_In_One_SegModel/Tool_get_params_and_FLOPs.py
Normal file
@@ -0,0 +1,89 @@
|
||||
# get_flops_smp.py
|
||||
import torch
|
||||
import argparse
|
||||
from fvcore.nn import FlopCountAnalysis, flop_count_table
|
||||
|
||||
# 同样,需要能够导入 train_smp.py 中的函数
|
||||
from train import initialize_components
|
||||
import config
|
||||
|
||||
# TODO 使用方法 TODO
|
||||
# python Tool_get_params_and_FLOPs.py -a Unet --shape 512 512
|
||||
|
||||
def parse_args():
    """Parse the command-line options of the FLOPs/parameter analysis script.

    Returns:
        argparse.Namespace with ``architecture`` and ``shape``.
    """
    cli = argparse.ArgumentParser(description='Get FLOPs and Params of an SMP model')

    # Architecture name; restricted to the keys of config.ALL_MODEL_CONFIGS.
    cli.add_argument(
        '-a', '--architecture',
        type=str, required=True,
        choices=list(config.ALL_MODEL_CONFIGS),
        help="The model architecture to analyze.",
    )
    # One or more integers describing the input size.
    cli.add_argument(
        '--shape',
        type=int, nargs='+', default=[512, 512],
        help='Input image size for analysis, e.g., --shape 512 512',
    )
    return cli.parse_args()
|
||||
|
||||
def main():
    """Print parameter counts and an fvcore FLOP analysis for one SMP model.

    The model is built through ``initialize_components`` (no checkpoint is
    loaded) and analysed on a random ``(1, 3, H, W)`` input.

    Raises:
        ValueError: If ``--shape`` has more than two integers.
    """
    args = parse_args()

    # One integer means a square input; two integers are height and width.
    shape = args.shape
    if len(shape) not in (1, 2):
        raise ValueError('Invalid input shape. Use one or two integers.')
    h, w = shape[0], shape[-1]

    # 1. Build the model.
    print(f"Initializing model: '{args.architecture}'...")
    model, _, _, _ = initialize_components(args.architecture, len(config.CLASSES), config.SEG_MODE)

    # Unwrap the network if it came back inside a DataParallel wrapper.
    if isinstance(model, torch.nn.DataParallel):
        model = model.module

    model.to(config.DEVICE)
    model.eval()

    # 2. Random dummy input.
    dummy_input = torch.randn(1, 3, h, w).to(config.DEVICE)

    # 3. Parameter counts: all parameters, and those with requires_grad set.
    param_sizes = [(p.numel(), p.requires_grad) for p in model.parameters()]
    total_params = sum(size for size, _ in param_sizes)
    trainable_params = sum(size for size, trainable in param_sizes if trainable)

    # 4. FLOP analysis with fvcore.
    print("Analyzing model FLOPs... (This may take a moment)")
    flops = FlopCountAnalysis(model, dummy_input)

    # 5. Report.
    split_line = '=' * 30
    print(f"\n{split_line}")
    print(f"Model: {args.architecture}")
    print(f"Input shape: (1, 3, {h}, {w})")
    print(split_line)

    # Per-module table produced by fvcore.
    print(flop_count_table(flops, max_depth=4))

    # Totals, scaled to GFLOPs and millions of parameters.
    gflops = flops.total() / 1e9
    total_params_m = total_params / 1e6
    trainable_params_m = trainable_params / 1e6

    print(split_line)
    print(f"Total FLOPs: {gflops:.2f} GFLOPs")
    print(f"Total Parameters: {total_params_m:.2f} M")
    print(f"Trainable Parameters: {trainable_params_m:.2f} M")
    print(split_line)
    print('!!! Please be cautious: FLOPs computation may not be perfectly accurate for all custom ops.')


if __name__ == '__main__':
    main()
|
||||
153
Seg_All_In_One_SegModel/config.py
Normal file
153
Seg_All_In_One_SegModel/config.py
Normal file
@@ -0,0 +1,153 @@
|
||||
import torch
|
||||
from pathlib import Path
|
||||
|
||||
# --- 1. 核心目录设置 (Core Directories) ---
|
||||
# 使用 pathlib 进行路径管理,更现代、更健壮
|
||||
# --- 1. Core directories ---
# All paths are pathlib.Path objects.
HARDISK_DIR = Path('../Hardisk')
DATA_SETS_DIR = Path('../DataSet_Public_outputs')  # where trained models are saved
PREDICT_ALL_BEST_MODELS_DIR = Path('../BestMode_Predict_Results_DataSet_Public')  # where models used for prediction are stored

# Dataset selection. Exactly one DATA_DIR / OUTPUTS_DIR / PREDICT_BEST_MODEL_DIR
# triple is active; the others stay commented out for quick switching.
#   DATA_DIR               - dataset root (Path(__file__).parent.parent was
#                            abandoned because the path contains Chinese characters)
#   OUTPUTS_DIR            - root of all training outputs; original note:
#                            train.py uses Path(config.OUTPUTS_DIR / architecture)
#   PREDICT_BEST_MODEL_DIR - location of the best models; original note:
#                            Path(config.PREDICT_BEST_MODEL_DIR / architecture)
_DATASETS_ROOT = Path('../DataSet_Public')

# V1: 1_CholecSeg8k-13Type-1920x1080
# DATA_DIR = Path('../DataSet_Public/1_CholecSeg8k-13Type-1920x1080')
# OUTPUTS_DIR = DATA_SETS_DIR / "1_CholecSeg8k-13Type-1920x1080_outputs-SegModel"
# PREDICT_BEST_MODEL_DIR = PREDICT_ALL_BEST_MODELS_DIR / "1_CholecSeg8k-13Type-1920x1080_outputs-SegModel"
# V2: 2_AutoLaparo-10Type-1920x1080
# DATA_DIR = Path('../DataSet_Public/2_AutoLaparo-10Type-1920x1080')
# OUTPUTS_DIR = DATA_SETS_DIR / "2_AutoLaparo-10Type-1920x1080_outputs-SegModel"
# PREDICT_BEST_MODEL_DIR = PREDICT_ALL_BEST_MODELS_DIR / "2_AutoLaparo-10Type-1920x1080_outputs-SegModel"
# V3: 3_1_Endovis_2017-8Type-512x512
# DATA_DIR = Path('../DataSet_Public/3_1_Endovis_2017-8Type-512x512')
# OUTPUTS_DIR = DATA_SETS_DIR / "3_1_Endovis_2017-8Type-512x512_outputs-SegModel"
# PREDICT_BEST_MODEL_DIR = PREDICT_ALL_BEST_MODELS_DIR / "3_1_Endovis_2017-8Type-512x512_outputs-SegModel"
# V4: 3_2_Endovis_2018-8Type-512x512
# DATA_DIR = Path('../DataSet_Public/3_2_Endovis_2018-8Type-512x512')
# OUTPUTS_DIR = DATA_SETS_DIR / "3_2_Endovis_2018-8Type-512x512_outputs-SegModel"
# PREDICT_BEST_MODEL_DIR = PREDICT_ALL_BEST_MODELS_DIR / "3_2_Endovis_2018-8Type-512x512_outputs-SegModel"
# V5: 4_Dresden-11Type-512x512
# DATA_DIR = Path('../DataSet_Public/4_Dresden-11Type-512x512')
# OUTPUTS_DIR = DATA_SETS_DIR / "4_Dresden-11Type-512x512_outputs-SegModel"
# PREDICT_BEST_MODEL_DIR = PREDICT_ALL_BEST_MODELS_DIR / "4_Dresden-11Type-512x512_outputs-SegModel"

# Test_V1: 5_Predict_Video (active)
DATA_DIR = _DATASETS_ROOT / '5_Predict_Video' / 'LC_Video_1'
# NOTE(review): OUTPUTS_DIR is None in this prediction-only setup; code that
# builds a path from it (e.g. Path(config.OUTPUTS_DIR)) will fail until a
# training dataset block is re-enabled.
OUTPUTS_DIR = None
PREDICT_BEST_MODEL_DIR = PREDICT_ALL_BEST_MODELS_DIR / "LC_Video_1_outputs-SegModel"

_IMAGES_ROOT = DATA_DIR / "images"
_LABELS_ROOT = DATA_DIR / "labels_GT"

# --- 2. Training and validation data (train.py) ---
TRAIN_IMAGE_DIR = _IMAGES_ROOT / "train"
TRAIN_MASK_DIR = _LABELS_ROOT / "train"
VAL_IMAGE_DIR = _IMAGES_ROOT / "val"  # TODO "val_images"
VAL_MASK_DIR = _LABELS_ROOT / "val"  # TODO "val_masks"

# --- 3. Prediction data (predict.py) ---
TEST_IMAGE_DIR = _IMAGES_ROOT / "val"  # test images  # TODO "test_images"
TEST_MASK_DIR = _LABELS_ROOT / "val"  # test masks, used for evaluation  # TODO "test_masks"

# --- 4. Output files and folders ---
RAW_MASK_FOLDER = "predicted_raw_masks"  # predicted single-channel raw masks
ANALYSIS_RESULTS_FOLDER = "prediction_analysis"  # comparison images, curves and metric CSVs

# Files written during training
BEST_MODEL_SAVE_NAME = "best_model.pth"
METRICS_CSV_NAME = "training_metrics.csv"
|
||||
|
||||
# --- 5. 模型与数据参数 (Model & Data Parameters) ---
|
||||
# --- 5. Model and data parameters ---
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# --- 5.1. Model selection and per-architecture parameters ---
# The architecture to train is NOT chosen here; train / predict select it.


def _encoder_config(encoder_name='resnet34', **decoder_options):
    """Return a fresh config dict: encoder settings first, then decoder options."""
    return {'encoder_name': encoder_name, 'encoder_weights': 'imagenet', **decoder_options}


# Library of parameter sets, one entry per architecture.
# See https://smp.readthedocs.io/en/latest/models.html for the options.
ALL_MODEL_CONFIGS = {
    'Unet': _encoder_config(
        decoder_channels=(256, 128, 64, 32, 16),
        decoder_attention_type=None,  # optional: 'scse'
    ),
    'UnetPlusPlus': _encoder_config(
        decoder_channels=(256, 128, 64, 32, 16),
        decoder_attention_type=None,  # optional: 'scse'
    ),
    'FPN': _encoder_config(),
    'PSPNet': _encoder_config(),
    'DeepLabV3': _encoder_config(),
    'DeepLabV3Plus': _encoder_config(),
    'Linknet': _encoder_config(),
    'MAnet': _encoder_config(),
    'PAN': _encoder_config(),
    'UPerNet': _encoder_config(),
    'Segformer': _encoder_config(),
    'DPT': _encoder_config('tu-vit_base_patch16_224.augreg_in21k'),
}
|
||||
|
||||
SEG_MODE = "multiclass"  # segmentation mode: 'multiclass' or 'multilabel'

# TODO: classes to leave out when computing evaluation metrics (IoU, F1-score, ...).
# An empty list evaluates every class; e.g. ['background'] excludes the
# background class from the overall metrics.
EVALUATION_CLASSES_TO_IGNORE = []
# EVALUATION_CLASSES_TO_IGNORE = ['background']
IGNORE_INDEX = -100  # mask value of background pixels when background is excluded; the loss ignores it

# Class names and the grayscale mask value of each class. Adjust to the dataset.
# Number of classes (background included) per dataset:
#   V1  1_CholecSeg8k-13Type-1920x1080  -> 13 (active)
#   V2  2_AutoLaparo-10Type-1920x1080   -> 10
#   V3  3_1_Endovis_2017-8Type-512x512  -> 8
#   V4  3_2_Endovis_2018-8Type-512x512  -> 8
#   V5  4_Dresden-11Type-512x512        -> 11
_NUM_CLASSES = 13
CLASSES = ['background'] + [str(class_id) for class_id in range(1, _NUM_CLASSES)]
CLASS_RGB_VALUES = list(range(_NUM_CLASSES))

# Image size
IMAGE_HEIGHT = 256  # 512
IMAGE_WIDTH = 256  # 512

# --- 6. Training hyperparameters ---
BATCH_SIZE = 16
NUM_WORKERS = 8
EPOCHS = 300
LEARNING_RATE = 1e-4  # 1e-3  # lowered to 1e-4 for MAnet
WEIGHT_DECAY = 1e-4  # L2 regularisation weight decay
PIN_MEMORY = True  # load batches into pinned memory to speed up transfer to the GPU
FBETA_BETA = 1.0  # beta of the F-beta score; beta=1 is the F1-score
EARLY_STOPPING_PATIENCE = 100  # patience of the early-stopping mechanism
CKPT_SAVE_INTERVAL = 10  # save a checkpoint every this many epochs
|
||||
99
Seg_All_In_One_SegModel/dataset.py
Normal file
99
Seg_All_In_One_SegModel/dataset.py
Normal file
@@ -0,0 +1,99 @@
|
||||
import os
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
class SegmentationDataset(Dataset):
    """Image-segmentation dataset backed by an image folder and a mask folder.

    Each mask is read as a single-channel grayscale image whose pixel values
    identify classes. Depending on ``seg_mode`` the mask is turned into a
    class-index map ('multiclass') or a one-hot map ('multilabel').
    """

    def __init__(self, image_dir, mask_dir, classes, class_rgb_values, seg_mode, augmentation=None):
        """
        Args:
            image_dir (str): Directory containing the images.
            mask_dir (str): Directory containing the masks. A mask is looked
                up under the same file name as its image.
            classes (list): Class names.
            class_rgb_values (list): Grayscale mask value of each class; the
                position in the list is the class index.
            seg_mode (str): 'multiclass' or 'multilabel'.
            augmentation (albumentations.Compose, optional): Transform called
                as ``augmentation(image=..., mask=...)``.

        Raises:
            ValueError: If ``seg_mode`` is not one of the two supported modes.
        """
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.image_filenames = sorted(os.listdir(image_dir))
        self.augmentation = augmentation
        self.classes = classes
        self.class_rgb_values = class_rgb_values

        # Store and validate the segmentation mode.
        self.seg_mode = seg_mode
        if self.seg_mode not in ['multiclass', 'multilabel']:
            raise ValueError(f"seg_mode must be 'multiclass' or 'multilabel', but got {self.seg_mode}")

        print(f"Found {len(self.image_filenames)} images in {image_dir}. Dataset mode: '{self.seg_mode}'")

    def __len__(self):
        """Return the number of image files found in ``image_dir``."""
        return len(self.image_filenames)

    def __getitem__(self, idx):
        """Load, encode and (optionally) augment the sample at ``idx``.

        Returns:
            tuple: ``(image, mask)``. The mask is a torch tensor: float with
            layout (C, H, W) in 'multilabel' mode, long with layout (H, W) in
            'multiclass' mode. The image is whatever the augmentation returns,
            or the RGB array read from disk when no augmentation is set.

        Raises:
            FileNotFoundError: If the image or the mask cannot be read.
        """
        filename = self.image_filenames[idx]

        img_path = os.path.join(self.image_dir, filename)
        image = cv2.imread(img_path)
        # cv2.imread returns None instead of raising; fail with a clear message
        # rather than an opaque error inside cvtColor.
        if image is None:
            raise FileNotFoundError(f"Image file not found or could not be read: {img_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        mask_path = os.path.join(self.mask_dir, filename)
        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        if mask is None:
            raise FileNotFoundError(f"Mask file not found or could not be read: {mask_path}")

        # Encode the mask according to the segmentation mode.
        if self.seg_mode == 'multilabel':
            processed_mask = self._create_one_hot_mask(mask)
        else:  # 'multiclass'
            processed_mask = self._create_class_index_mask(mask)

        if self.augmentation:
            # The same transform is applied to image and mask together.
            sample = self.augmentation(image=image, mask=processed_mask)
            image = sample['image']
            mask = sample['mask']
        else:
            mask = processed_mask

        # The mask is still a NumPy array when there is no augmentation (or the
        # augmentation does not convert to tensors); .permute()/.long() below
        # exist only on tensors. as_tensor leaves an existing tensor untouched.
        mask = torch.as_tensor(mask)

        if self.seg_mode == 'multilabel':
            # (H, W, C) -> (C, H, W), as float.
            mask = mask.permute(2, 0, 1).float()
        else:  # 'multiclass'
            # Class indices stay (H, W) and become long.
            mask = mask.long()

        return image, mask

    def _create_one_hot_mask(self, mask):
        """Convert a grayscale mask (H, W) into a one-hot mask (H, W, C).

        Channel ``i`` is 1 where the mask equals ``class_rgb_values[i]``.
        Used for 'multilabel' mode. Returns a uint8 NumPy array.
        """
        semantic_map = np.zeros((mask.shape[0], mask.shape[1], len(self.class_rgb_values)), dtype=np.uint8)
        for i, value in enumerate(self.class_rgb_values):
            semantic_map[:, :, i] = (mask == value).astype(np.uint8)
        return semantic_map

    def _create_class_index_mask(self, mask):
        """Convert a grayscale mask (H, W) into a class-index mask (H, W).

        A pixel equal to ``class_rgb_values[i]`` gets index ``i``; pixels that
        match no listed value keep the initial index 0. Used for 'multiclass'
        mode. Returns a uint8 NumPy array.
        """
        class_index_map = np.zeros(mask.shape, dtype=np.uint8)
        for i, value in enumerate(self.class_rgb_values):
            class_index_map[mask == value] = i
        return class_index_map
|
||||
86
Seg_All_In_One_SegModel/loss.py
Normal file
86
Seg_All_In_One_SegModel/loss.py
Normal file
@@ -0,0 +1,86 @@
|
||||
# In utils.py or a new losses.py
|
||||
import config
|
||||
import torch.nn as nn
|
||||
import segmentation_models_pytorch as smp
|
||||
|
||||
class UNetPlusPlusLoss(nn.Module):
    """Weighted sum of BCE-with-logits and Dice loss, modelled on UNet++.

    Intended for the outputs of segmentation-models-pytorch models.

    Args:
        mode (str): Mode passed to ``smp.losses.DiceLoss``,
            e.g. 'multilabel' or 'multiclass'.
        bce_weight (float): Weight of the BCE term in the total loss.
        dice_weight (float): Weight of the Dice term in the total loss.
    """

    def __init__(self, mode=config.SEG_MODE, bce_weight=0.5, dice_weight=0.5):
        super().__init__()
        self.bce_weight = bce_weight
        self.dice_weight = dice_weight
        # BCEWithLogitsLoss works on raw logits, which is numerically more
        # stable than a separate sigmoid followed by BCELoss.
        self.bce_loss = nn.BCEWithLogitsLoss()
        self.dice_loss = smp.losses.DiceLoss(mode=mode)

    def forward(self, y_pred, y_true):
        """Return the combined loss.

        Args:
            y_pred: Model output (logits).
            y_true: Ground-truth mask.

        Returns:
            ``bce_weight * BCE + dice_weight * Dice``.
        """
        bce_term = self.bce_loss(y_pred, y_true)
        dice_term = self.dice_loss(y_pred, y_true)
        return self.bce_weight * bce_term + self.dice_weight * dice_term
|
||||
|
||||
class MultiClassLoss(nn.Module):
    """Combined CrossEntropy + Dice loss for multi-class segmentation.

    Args:
        mode (str): DiceLoss mode, should be 'multiclass'.
        ce_weight (float): Weight of the CrossEntropyLoss term.
        dice_weight (float): Weight of the DiceLoss term.
        ignore_index (int): Target value that is ignored by BOTH terms and
            does not contribute to the gradient.
    """

    def __init__(self, mode='multiclass', ce_weight=0.5, dice_weight=0.5, ignore_index=-100):
        super().__init__()

        # CrossEntropyLoss combines LogSoftmax and NLLLoss in one module.
        self.ce_loss = nn.CrossEntropyLoss(ignore_index=ignore_index)

        # Pass the same ignore_index to the Dice term. Previously only the CE
        # term skipped ignored pixels, so targets holding ignore_index reached
        # DiceLoss unmasked. When no target pixel equals ignore_index the
        # result is unchanged.
        self.dice_loss = smp.losses.DiceLoss(mode=mode, ignore_index=ignore_index)

        self.ce_weight = ce_weight
        self.dice_weight = dice_weight

    def forward(self, y_pred, y_true):
        """Return the combined loss.

        Args:
            y_pred: Model predictions (logits), shape (N, C, H, W).
            y_true: Ground-truth class indices, shape (N, H, W).

        Returns:
            ``ce_weight * CE + dice_weight * Dice``.
        """
        # CrossEntropyLoss requires long targets.
        y_true = y_true.long()

        ce = self.ce_loss(y_pred, y_true)
        dice = self.dice_loss(y_pred, y_true)

        return self.ce_weight * ce + self.dice_weight * dice
|
||||
109
Seg_All_In_One_SegModel/predict.sh
Normal file
109
Seg_All_In_One_SegModel/predict.sh
Normal file
@@ -0,0 +1,109 @@
|
||||
#!/bin/bash
|
||||
# --- 1. Conda 环境设置 ---
|
||||
# --- 1. Conda environment ---
CONDA_BASE_PATH="/home/wkmgc/miniconda3"   # <--- change this to your own installation path
CONDA_ENV_NAME="${SEG_CONDA_ENV:-seg_smp}"  # override temporarily with: SEG_CONDA_ENV=SMP bash predict.sh

# Initialise conda for this shell, then activate the environment.
CONDA_INIT_SCRIPT="${CONDA_BASE_PATH}/etc/profile.d/conda.sh"
if [ ! -f "${CONDA_INIT_SCRIPT}" ]; then
    echo "错误: 找不到 conda.sh 脚本。请检查您的 CONDA_BASE_PATH 设置是否正确。"
    exit 1
fi

source "${CONDA_INIT_SCRIPT}"
if ! conda activate "${CONDA_ENV_NAME}"; then
    echo "错误: 激活 Conda 环境 '${CONDA_ENV_NAME}' 失败!"
    exit 1
fi
echo "Conda 环境 '${CONDA_ENV_NAME}' 已成功激活。"
|
||||
|
||||
# --- 2. 模型与 GPU 配置 ---
|
||||
# --- 2. Model and GPU configuration ---
# One GPU id (as seen by CUDA_VISIBLE_DEVICES) per group of architectures.
GPUS_GROUP_1="2"
GPUS_GROUP_2="3"
GPUS_GROUP_3="4"
GPUS_GROUP_4="5"

# NOTE: keep these lists in sync with train.sh so that a trained model exists
# for every architecture listed here.
GROUP_1_ARCHS=("PSPNet" "Unet" "UnetPlusPlus")
GROUP_2_ARCHS=("Linknet" "MAnet" "DeepLabV3")
GROUP_3_ARCHS=("UPerNet" "Segformer" "DPT")
GROUP_4_ARCHS=("FPN" "DeepLabV3Plus" "PAN")

# 1. Read PREDICT_BEST_MODEL_DIR from config.py.
PREDICT_BEST_MODEL_DIR=$(python -c "from config import PREDICT_BEST_MODEL_DIR; print(PREDICT_BEST_MODEL_DIR)")
# An empty value means the import failed or the setting is missing.
if [ -z "$PREDICT_BEST_MODEL_DIR" ]; then
    echo "Error: Could not read PREDICT_BEST_MODEL_DIR from config.py. Exiting."
    exit 1
fi

# 2. Timestamped name of the log directory.
LOG_DIR_NAME="predict_logs_parallel_$(date +%Y-%m-%d_%H-%M-%S)"
# 3. Full path of the log directory.
LOG_DIR="$PREDICT_BEST_MODEL_DIR/$LOG_DIR_NAME"
if ! mkdir -p "${LOG_DIR}"; then
    echo "Error: Could not create log directory '${LOG_DIR}'. Exiting."
    exit 1
fi
echo "所有模型的预测日志将保存在 ${LOG_DIR}/ 目录中。"
echo "----------------------------------------------------"
|
||||
|
||||
|
||||
# --- 3. Launch all prediction jobs ---
# Every model is started in the BACKGROUND; the script only pauses between
# launches and does not wait for a job to finish before starting the next one.
LAUNCH_INTERVAL_SECONDS=60

# Start one background prediction job per architecture of a group.
#   $1: group label used in the messages (e.g. "第一组")
#   $2: value for CUDA_VISIBLE_DEVICES
#   $3...: architecture names
launch_group() {
    local group_label=$1
    local gpus=$2
    shift 2
    local arch

    echo ">>> 准备启动${group_label}预测任务 (后台运行)..."
    for arch in "$@"; do
        echo "  -> 正在后台启动模型: ${arch} on GPUs: ${gpus}"
        # 'echo "1" |' answers the interactive prompt of 1_predict.py and
        # selects the first model it lists.
        echo "1" | CUDA_VISIBLE_DEVICES=${gpus} python 1_predict.py -a "${arch}" > "${LOG_DIR}/${arch}.log" 2>&1 &
        echo "     - 模型 ${arch} 预测已在后台启动。日志文件: ${LOG_DIR}/${arch}.log"
        # The message is derived from the same variable as the sleep, so the
        # announced and actual waiting time cannot drift apart.
        echo "     - 等待 ${LAUNCH_INTERVAL_SECONDS} 秒..."
        sleep "${LAUNCH_INTERVAL_SECONDS}"
    done
    echo ">>> ${group_label}所有模型均已启动。"
    echo "----------------------------------------------------"
}

launch_group "第一组" "${GPUS_GROUP_1}" "${GROUP_1_ARCHS[@]}"
launch_group "第二组" "${GPUS_GROUP_2}" "${GROUP_2_ARCHS[@]}"
launch_group "第三组" "${GPUS_GROUP_3}" "${GROUP_3_ARCHS[@]}"
launch_group "第四组" "${GPUS_GROUP_4}" "${GROUP_4_ARCHS[@]}"

# --- 4. Wait for every background job ---
echo ""
echo "--- 所有模型均已在后台启动。现在等待所有预测任务完成... ---"
# 'wait' blocks until all background jobs started by this script have exited.
wait
echo "--- 所有后台预测任务已全部完成! ---"

# Deactivate the conda environment before exiting.
conda deactivate
|
||||
10
Seg_All_In_One_SegModel/requirements.txt
Normal file
10
Seg_All_In_One_SegModel/requirements.txt
Normal file
@@ -0,0 +1,10 @@
|
||||
opencv-python
|
||||
albumentations
|
||||
matplotlib
|
||||
numpy
|
||||
pandas
|
||||
scikit_learn
|
||||
scipy
|
||||
segmentation_models_pytorch==0.5.1.dev0
|
||||
torch
|
||||
tqdm
|
||||
312
Seg_All_In_One_SegModel/train.py
Normal file
312
Seg_All_In_One_SegModel/train.py
Normal file
@@ -0,0 +1,312 @@
|
||||
import logging, argparse, shutil
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Dict, Tuple
|
||||
import gc
|
||||
|
||||
import numpy as np
|
||||
import segmentation_models_pytorch as smp
|
||||
import torch
|
||||
import torch.optim as optim
|
||||
from torch.optim import lr_scheduler # 学习率调度器
|
||||
from torch.amp import GradScaler
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
# 本地应用/库的导入
|
||||
import config, os
|
||||
import utils
|
||||
from dataset import SegmentationDataset
|
||||
from loss import MultiClassLoss, UNetPlusPlusLoss
|
||||
|
||||
# --- 日志设置 ---
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s [%(levelname)s] %(message)s",
|
||||
handlers=[
|
||||
logging.StreamHandler() # 输出日志到控制台
|
||||
]
|
||||
)
|
||||
|
||||
# --- 辅助函数:移动训练结果文件夹从源输出目录移动到配置的硬盘目录 ---
|
||||
def move_results_to_hardisk(project_folder: str) -> None:
    """Move one training-run folder from the output directory to the archive disk.

    ``config.OUTPUTS_DIR/<project_folder>`` is copied to
    ``config.HARDISK_DIR/<name of OUTPUTS_DIR>/<project_folder>``; the source
    folder is removed only after the copy finished without raising.

    This is a best-effort step: every failure is logged and swallowed so that
    a missing or full archive disk never discards a finished training run.

    Args:
        project_folder: Name of the run folder inside ``config.OUTPUTS_DIR``
            (built by the caller from the architecture name and a timestamp).
    """
    outputs_root = Path(config.OUTPUTS_DIR)
    source_dir = outputs_root / project_folder
    destination_dir = Path(config.HARDISK_DIR) / outputs_root.name / project_folder

    logging.info(f"准备将结果从 {source_dir} 移动到 {destination_dir}...")

    try:
        # Creating the parent directory is part of the guarded section: an
        # unmounted or read-only archive disk must take the same logged,
        # non-fatal path as a failed copy.
        destination_dir.parent.mkdir(parents=True, exist_ok=True)

        # Step 1: copy the whole tree.
        shutil.copytree(source_dir, destination_dir)
        logging.info(f"成功复制结果到: {destination_dir}")

        # Step 2: delete the original only once the copy has succeeded.
        logging.info(f"正在删除原文件夹: {source_dir}")
        shutil.rmtree(source_dir)
        logging.info("成功删除原文件夹。")
        logging.info(f"任务完成,最终结果已保存至: {destination_dir}")
    except Exception as e:  # deliberate catch-all: archiving must never crash the run
        logging.error(f"移动文件夹时发生错误: {e}")
        logging.warning(f"原始训练结果仍保留在: {source_dir}")
|
||||
|
||||
# 1. 根据配置文件初始化模型、损失函数和优化器。
|
||||
def initialize_components(model_architecture: str, num_classes: int, seg_mode: str) -> Tuple:
    """Build the model, loss function, optimizer and gradient scaler.

    Args:
        model_architecture: Key into ``config.ALL_MODEL_CONFIGS`` and, at the
            same time, the name of the model class looked up on
            ``segmentation_models_pytorch``.
        num_classes: Number of output classes passed to the model as ``classes``.
        seg_mode: ``'multiclass'`` or ``'multilabel'``; selects the loss.

    Returns:
        ``(model, loss_fn, optimizer, scaler)``; the model has already been
        moved to ``config.DEVICE``.

    Raises:
        KeyError: ``model_architecture`` has no entry in ``ALL_MODEL_CONFIGS``.
        AttributeError: ``segmentation_models_pytorch`` has no such class.
        ValueError: ``seg_mode`` is not one of the two supported modes.
    """
    # --- Model ---
    logging.info(f"正在初始化模型: '{model_architecture}'...")

    try:
        model_params = config.ALL_MODEL_CONFIGS[model_architecture]
        model_class = getattr(smp, model_architecture)
    except KeyError:
        logging.error(f"模型 '{model_architecture}' 的配置未在 config.py 的 ALL_MODEL_CONFIGS 中定义!")
        raise
    except AttributeError:
        logging.error(f"模型 '{model_architecture}' 在 segmentation_models_pytorch 库中不存在!")
        raise

    # Work on a copy so the shared config dict is never mutated; the fixed
    # `in_channels` / `classes` entries override anything set in the config.
    params = {**model_params, 'in_channels': 3, 'classes': num_classes}

    # ViT encoders are asked to accept variable input sizes.
    encoder_name = model_params.get('encoder_name', '')
    if 'vit' in encoder_name.lower():
        params['dynamic_img_size'] = True
        logging.info(f"检测到 ViT 编码器 ('{encoder_name}')。自动设置 dynamic_img_size=True。")

    # Every configured key is forwarded to the constructor as a keyword argument.
    model = model_class(**params).to(config.DEVICE)

    # --- Loss ---
    logging.info(f"为 '{seg_mode}' 模式设置损失函数。")
    if seg_mode == 'multiclass':
        loss_fn = MultiClassLoss(mode=seg_mode)
    elif seg_mode == 'multilabel':
        loss_fn = UNetPlusPlusLoss(mode=seg_mode)
    else:
        raise ValueError(f"无效的 SEG_MODE: '{seg_mode}'。必须是 'multiclass' 或 'multilabel'。")

    # --- Optimizer and gradient scaler ---
    optimizer = optim.Adam(model.parameters(), lr=config.LEARNING_RATE, weight_decay=config.WEIGHT_DECAY)

    # TODO: learning-rate scheduling. A CosineAnnealingLR(T_max=config.EPOCHS,
    # eta_min=1e-6) used to be constructed here, but it was neither returned
    # nor stepped, so it never changed the learning rate. It was removed as
    # dead code; wiring a scheduler in requires returning it to the caller and
    # calling `scheduler.step()` once per epoch in the training loop.

    scaler = GradScaler('cuda')

    return model, loss_fn, optimizer, scaler
|
||||
|
||||
def main(model_architecture: str) -> str:
    """Run the complete training pipeline for one segmentation architecture.

    Steps:
        1. Create a fresh, timestamped output directory.
        2. Build a throw-away model to find out whether the encoder demands a
           fixed input size, then build augmentations and data loaders.
        3. Build the real model, loss, optimizer and gradient scaler.
        4. Train epoch by epoch: evaluate on the validation set, append the
           metrics to a CSV file, keep the best weights (lowest validation
           loss), apply early stopping, write per-epoch checkpoints (pruning
           those that are not on the save interval) and refresh the plots.

    Args:
        model_architecture: Key of ``config.ALL_MODEL_CONFIGS`` to train.

    Returns:
        The run folder name (``<architecture>_<timestamp>``) that was created
        inside ``config.OUTPUTS_DIR``.
    """
    logging.info(f"使用设备: {config.DEVICE}")

    # --- 1.1. Paths for this run ---
    timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    run_name = f"{model_architecture}_{timestamp}"
    output_dir = Path(config.OUTPUTS_DIR) / run_name
    metrics_csv_path = output_dir / config.METRICS_CSV_NAME
    best_model_path = output_dir / config.BEST_MODEL_SAVE_NAME

    utils.setup_directories(output_dir)

    # --- 1.2. Probe the model for its input-size requirement ---
    logging.info("正在探测模型以确定输入尺寸要求...")
    num_classes = len(config.CLASSES)

    # Temporary instance, used only to inspect encoder attributes.
    probe_model, _, _, _ = initialize_components(
        model_architecture, num_classes, config.SEG_MODE
    )

    target_height, target_width = config.IMAGE_HEIGHT, config.IMAGE_WIDTH

    try:
        if probe_model.encoder.is_fixed_input_size:
            # input_size is indexed as [1] = height, [2] = width here.
            required_size = probe_model.encoder.input_size
            target_height = required_size[1]
            target_width = required_size[2]
            logging.warning(
                f"模型 '{model_architecture}' 的编码器需要固定的输入尺寸 "
                f"({target_height}x{target_width})。"
                f"将忽略 config.py 中的尺寸设置。"
            )
        else:
            logging.info(f"模型 '{model_architecture}' 支持动态输入尺寸。将使用 config.py 中定义的尺寸 ({target_height}x{target_width})。")
    except AttributeError:
        # Encoders without these attributes simply keep the configured size.
        logging.info("无法自动检测模型尺寸要求。将使用 config.py 中定义的尺寸。")

    del probe_model  # drop the temporary model before building the real one

    # --- 1.3. Data loading ---
    logging.info(f"正在设置数据增强和加载器,目标尺寸为 {target_height}x{target_width}...")
    train_transform = utils.get_training_augmentation(target_height, target_width)
    val_transform = utils.get_validation_augmentation(target_height, target_width)

    train_dataset = SegmentationDataset(
        image_dir=Path(config.TRAIN_IMAGE_DIR),
        mask_dir=Path(config.TRAIN_MASK_DIR),
        classes=config.CLASSES,
        class_rgb_values=config.CLASS_RGB_VALUES,
        augmentation=train_transform,
        seg_mode=config.SEG_MODE
    )
    val_dataset = SegmentationDataset(
        image_dir=Path(config.VAL_IMAGE_DIR),
        mask_dir=Path(config.VAL_MASK_DIR),
        classes=config.CLASSES,
        class_rgb_values=config.CLASS_RGB_VALUES,
        augmentation=val_transform,
        seg_mode=config.SEG_MODE
    )

    train_loader = DataLoader(train_dataset, batch_size=config.BATCH_SIZE, shuffle=True, pin_memory=config.PIN_MEMORY, num_workers=config.NUM_WORKERS)
    val_loader = DataLoader(val_dataset, batch_size=config.BATCH_SIZE, shuffle=False, pin_memory=config.PIN_MEMORY, num_workers=config.NUM_WORKERS)

    # --- 2. Model, loss, optimizer, gradient scaler ---
    model, loss_fn, optimizer, scaler = initialize_components(model_architecture, num_classes, config.SEG_MODE)

    # --- Multi-GPU training ---
    # DataParallel only pays off with more than one visible GPU; wrapping a
    # single-GPU model adds scatter/gather overhead and a misleading log line.
    if torch.cuda.device_count() > 1:
        logging.info(f"正在使用 {torch.cuda.device_count()} 块 GPU 进行并行训练。")
        model = torch.nn.DataParallel(model)

    def _save_weights(destination: Path) -> None:
        """Save the bare model's state_dict, unwrapping DataParallel if present."""
        bare_model = model.module if isinstance(model, torch.nn.DataParallel) else model
        torch.save(bare_model.state_dict(), destination)

    # --- 3. Training loop ---
    best_val_loss = float('inf')
    epochs_no_improve = 0

    for epoch in range(config.EPOCHS):
        epoch_number = epoch + 1  # 1-based number used in logs and file names
        logging.info(f"\n--- 第 {epoch_number}/{config.EPOCHS} 轮 ---")

        # --- 3.1. Train for one epoch ---
        train_loss = utils.train_fn(train_loader, model, optimizer, loss_fn, scaler, config.DEVICE, model_architecture)
        # --- 3.2. Evaluate on the validation set ---
        val_metrics_dict = utils.evaluate_fn(val_loader, model, loss_fn, config.DEVICE, num_classes, model_architecture, beta=config.FBETA_BETA)

        # --- 3.3. Record and report metrics ---
        log_data = {'epoch': epoch_number, 'train_loss': train_loss}
        log_data.update(val_metrics_dict)
        utils.log_metrics_to_csv(log_data, metrics_csv_path)

        val_loss = val_metrics_dict.get('val_loss')
        iou_score = val_metrics_dict.get('iou_score', 0)
        fbeta_score = val_metrics_dict.get('fbeta_score', 0)

        logging.info(f"训练集Loss: {train_loss:.4f} | 验证集Loss: {val_loss:.4f} | IoU: {iou_score:.4f} | F-beta: {fbeta_score:.4f}")

        # --- 3.4. Best-model tracking and early stopping ---
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            epochs_no_improve = 0
            _save_weights(best_model_path)
            logging.info(f"验证损失改善至 {best_val_loss:.4f}。正在保存最佳模型至 '{best_model_path}'。")
        else:
            epochs_no_improve += 1
            logging.info(f"验证损失已连续 {epochs_no_improve} 轮未改善。")

        # A non-int patience (e.g. None) disables early stopping.
        if isinstance(config.EARLY_STOPPING_PATIENCE, int) and epochs_no_improve >= config.EARLY_STOPPING_PATIENCE:
            logging.info(f"\n训练在 {epoch_number} 轮后触发早停。")
            logging.warning(f"验证损失已连续 {config.EARLY_STOPPING_PATIENCE} 轮未改善。")
            # Refresh the plots before leaving so they include this final
            # epoch, which has already been written to the CSV file.
            utils.plot_training_progress(metrics_csv_path, output_dir)
            logging.info("所有训练曲线图已更新并保存。")
            break

        # --- 3.5. Per-epoch checkpoint and plots ---
        checkpoint_path = output_dir / f"epoch_{epoch_number}.pth"
        _save_weights(checkpoint_path)
        logging.info(f"第 {epoch_number} 轮的检查点已保存至 '{checkpoint_path}'。")

        # Prune the previous epoch's checkpoint unless its (1-based) number is
        # a multiple of the save interval. `epoch` is 0-based, so it equals the
        # previous epoch's 1-based number; nothing to prune in the first epoch.
        if epoch > 0 and epoch % config.CKPT_SAVE_INTERVAL != 0:
            old_ckpt = output_dir / f"epoch_{epoch}.pth"
            if old_ckpt.exists():
                old_ckpt.unlink()

        utils.plot_training_progress(metrics_csv_path, output_dir)
        logging.info("所有训练曲线图已更新并保存。")

    return run_name
|
||||
|
||||
if __name__ == "__main__":
    # Command-line interface: the architecture is mandatory (no default) and
    # must be one of the keys defined in config.ALL_MODEL_CONFIGS.
    parser = argparse.ArgumentParser(description="训练图像分割模型。")
    parser.add_argument(
        "-a", "--architecture",
        type=str,
        choices=list(config.ALL_MODEL_CONFIGS.keys()),
        required=True,
        help="选择要训练的模型架构。"
    )
    args = parser.parse_args()

    # Stays None when main() raises, which selects the error branch below.
    project_folder = None
    try:
        project_folder = main(model_architecture=args.architecture)
    finally:
        # Release Python-side objects first, then PyTorch's cached GPU memory.
        gc.collect()
        logging.info("已执行垃圾回收 (gc.collect())。")
        torch.cuda.empty_cache()
        logging.info("训练结束,已清理 CUDA 缓存。")

        # Kept inside `finally` so the failure message is actually emitted
        # when training raised; the original exception still propagates after
        # this block finishes.
        if project_folder:
            move_results_to_hardisk(project_folder)
        else:
            logging.error("由于训练失败或未生成项目文件夹,未执行结果移动操作。")
|
||||
105
Seg_All_In_One_SegModel/train.sh
Normal file
105
Seg_All_In_One_SegModel/train.sh
Normal file
@@ -0,0 +1,105 @@
|
||||
#!/bin/bash
# Launch train.py for every configured architecture as background jobs.
# Architectures are split into groups; each group is pinned to its own GPU
# set. Launches are staggered by LAUNCH_DELAY_SECONDS, after which the script
# waits for all jobs and reports any that exited with a non-zero status.

# --- 1. Conda environment setup ---
CONDA_BASE_PATH="/home/wkmgc/miniconda3" # <--- change this to your own conda install path
CONDA_ENV_NAME="${SEG_CONDA_ENV:-seg_smp}" # override per run, e.g. SEG_CONDA_ENV=SMP bash train.sh

# Initialise conda for this shell and activate the environment.
if [ -f "${CONDA_BASE_PATH}/etc/profile.d/conda.sh" ]; then
    source "${CONDA_BASE_PATH}/etc/profile.d/conda.sh"
    if ! conda activate "${CONDA_ENV_NAME}"; then
        echo "错误: 激活 Conda 环境 '${CONDA_ENV_NAME}' 失败!"
        exit 1
    fi
    echo "Conda 环境 '${CONDA_ENV_NAME}' 已成功激活。"
else
    echo "错误: 找不到 conda.sh 脚本。请检查您的 CONDA_BASE_PATH 设置是否正确。"
    exit 1
fi

# --- 2. Model and GPU configuration ---
GPUS_GROUP_1="3"
GPUS_GROUP_2="2"
GPUS_GROUP_3="1"
GPUS_GROUP_4="4"

GROUP_1_ARCHS=("PSPNet" "Unet" "UnetPlusPlus")
GROUP_2_ARCHS=("Linknet" "MAnet" "DeepLabV3")
GROUP_3_ARCHS=("UPerNet" "Segformer" "DPT")
GROUP_4_ARCHS=("FPN" "DeepLabV3Plus" "PAN")

# Pause between two consecutive launches, in seconds.
LAUNCH_DELAY_SECONDS=50

# Read OUTPUTS_DIR from config.py so logs land next to the training outputs.
OUTPUTS_DIR=$(python -c "from config import OUTPUTS_DIR; print(OUTPUTS_DIR)")
if [ -z "$OUTPUTS_DIR" ]; then
    echo "OUTPUTS_DIR: $OUTPUTS_DIR"
    echo "Error 1: Could not read OUTPUTS_DIR from config.py. Exiting."
    echo "Error 2: Or the directory specified by OUTPUTS_DIR does not exist. Please create it first."
    exit 1
fi

# Timestamped log directory. The prefix used to be "predict_logs_parallel_",
# copied from the prediction launcher, which made training logs look like
# prediction logs inside the same OUTPUTS_DIR.
LOG_DIR_NAME="logs_parallel_$(date +%Y-%m-%d_%H-%M-%S)"
LOG_DIR="$OUTPUTS_DIR/$LOG_DIR_NAME"
mkdir -p "${LOG_DIR}"
echo "所有模型的日志将保存在 ./${LOG_DIR}/ 目录中。"
echo "----------------------------------------------------"

# --- 3. Launch all training jobs ---
# Jobs run concurrently in the background; only the launches are staggered.

# PIDs and architecture names of every launched job, index-aligned.
PIDS=()
PID_ARCHS=()

# launch_group LABEL GPUS ARCH...
#   Start one background train.py per ARCH with CUDA_VISIBLE_DEVICES=GPUS.
#   LABEL is the group name used in the progress messages.
launch_group() {
    local label="$1"
    local gpus="$2"
    shift 2
    local arch

    echo ">>> 准备启动${label}训练任务 (后台运行)..."
    for arch in "$@"; do
        echo " -> 正在后台启动模型: ${arch} on GPUs: ${gpus}"
        # '&' sends the job to the background; stdout and stderr go to the log.
        CUDA_VISIBLE_DEVICES=${gpus} python train.py -a "${arch}" > "${LOG_DIR}/${arch}.log" 2>&1 &
        PIDS+=("$!")
        PID_ARCHS+=("${arch}")
        echo " - 模型 ${arch} 已在后台启动。日志文件: ${LOG_DIR}/${arch}.log"
        echo " - 等待 ${LAUNCH_DELAY_SECONDS} 秒..."
        sleep "${LAUNCH_DELAY_SECONDS}"
    done
    echo ">>> ${label}所有模型均已启动。"
    echo "----------------------------------------------------"
}

launch_group "第一组" "${GPUS_GROUP_1}" "${GROUP_1_ARCHS[@]}"
launch_group "第二组" "${GPUS_GROUP_2}" "${GROUP_2_ARCHS[@]}"
launch_group "第三组" "${GPUS_GROUP_3}" "${GROUP_3_ARCHS[@]}"
launch_group "第四组" "${GPUS_GROUP_4}" "${GROUP_4_ARCHS[@]}"

# --- 4. Wait for every background job ---
echo ""
echo "--- 所有模型均已在后台启动。现在等待所有训练任务完成... ---"
# Wait on each PID individually so a crashed job is reported instead of being
# hidden behind an unconditional "all done" message.
FAILED=0
for i in "${!PIDS[@]}"; do
    if ! wait "${PIDS[$i]}"; then
        echo "警告: 模型 ${PID_ARCHS[$i]} 的训练进程以非零状态退出,请检查日志: ${LOG_DIR}/${PID_ARCHS[$i]}.log"
        FAILED=$((FAILED + 1))
    fi
done
if [ "${FAILED}" -eq 0 ]; then
    echo "--- 所有后台训练任务已全部完成! ---"
else
    echo "--- 后台训练任务已结束,其中 ${FAILED} 个任务失败。 ---"
fi

# Deactivate the environment before exiting.
conda deactivate
|
||||
361
Seg_All_In_One_SegModel/utils.py
Normal file
361
Seg_All_In_One_SegModel/utils.py
Normal file
@@ -0,0 +1,361 @@
|
||||
import os, logging, shutil
|
||||
import torch
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from tqdm import tqdm
|
||||
from torch.amp import autocast
|
||||
from typing import Dict, Union, List
|
||||
import albumentations as A
|
||||
from albumentations.pytorch import ToTensorV2
|
||||
import segmentation_models_pytorch as smp
|
||||
import matplotlib.pyplot as plt
|
||||
from pathlib import Path
|
||||
|
||||
import config
|
||||
|
||||
# ------------------- 绘图函数 (全新升级) ------------------- #
|
||||
def log_metrics_to_csv(metrics_dict: Dict[str, Union[int, float]], csv_path: Union[str, Path]) -> None:
    """Append one row of metrics to a CSV file, creating the file if needed.

    The header row is written when the file does not exist yet *or* exists
    but is empty (zero bytes); otherwise the row is appended without a header.
    Appended rows are not re-aligned to the existing header, so callers must
    pass the same keys in the same order on every call.

    Args:
        metrics_dict: Mapping of column name to value for this row.
        csv_path: Target CSV file, as a string or ``Path``.
    """
    row = pd.DataFrame([metrics_dict])
    # A zero-byte file (e.g. created by `touch`) would otherwise receive data
    # rows with no header, and the first row would later be read as the header.
    needs_header = not os.path.isfile(csv_path) or os.path.getsize(csv_path) == 0
    row.to_csv(csv_path, mode='w' if needs_header else 'a', header=needs_header, index=False)
|
||||
|
||||
|
||||
def _plot_metric_group(ax, df: pd.DataFrame, group_metrics: List[str], title: str):
    """Draw one curve per available metric of a group onto ``ax``.

    For every name in ``group_metrics`` the column ``val_<name>`` is used if
    present, otherwise the column ``<name>``. Metrics found under neither
    name are skipped with a warning.

    Args:
        ax: Axes-like object providing ``plot``, ``set_title``, ``set_xlabel``,
            ``set_ylabel``, ``legend`` and ``grid``.
        df: Metrics table; must contain an ``epoch`` column used as x-axis.
        group_metrics: Metric names (without the ``val_`` prefix) to draw.
        title: Title for the axes.
    """
    epochs_range = df['epoch']
    plotted_any = False

    for name in group_metrics:
        # The prefixed column wins when both spellings exist.
        metric_col = next((col for col in (f"val_{name}", name) if col in df.columns), None)
        if metric_col is None:
            logging.warning("没有搜索到: %s", f"val_{name}")
            continue
        label = name.replace("_", " ").title()
        ax.plot(epochs_range, df[metric_col], 'o-', label=label)
        plotted_any = True

    ax.set_title(title)
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Score')
    # Only request a legend when at least one labelled curve was drawn; an
    # empty legend makes matplotlib emit a "no artists with labels" warning.
    if plotted_any:
        ax.legend(loc='best', fontsize='small')
    ax.grid(True)
|
||||
|
||||
|
||||
def plot_training_progress(csv_path: str, output_dir: str):
    """Render all training-history figures from the metrics CSV file.

    Writes into ``output_dir``: ``loss_curves.png``, one
    ``metrics_<group>.png`` per metric group, and the 3x2 summary
    ``training_progress_overview.png``. Nothing is drawn when the CSV file is
    missing or holds fewer than two rows.

    Args:
        csv_path: Metrics CSV with at least the columns ``epoch``,
            ``train_loss`` and ``val_loss``.
        output_dir: Directory receiving the PNG files; created if absent.
    """
    os.makedirs(output_dir, exist_ok=True)

    try:
        history = pd.read_csv(csv_path)
    except FileNotFoundError:
        print(f"警告:在路径 {csv_path} 未找到 metrics.csv 文件。跳过绘图。")
        return

    if len(history) < 2:
        print("警告:数据点不足 (少于2个),无法生成趋势图。")
        return

    epochs = history['epoch']
    train_curve = history['train_loss']
    val_curve = history['val_loss']

    # (file-name key, metric names, title used in the overview figure)
    panels = [
        ("core_performance", ["iou_score", "fbeta_score", "accuracy"], 'Core Performance Metrics'),
        ("precision_recall", ["precision", "recall", "specificity", "negative_predictive_value"], 'Precision-Recall Family'),
        ("error_rates", ["false_positive_rate", "false_negative_rate", "false_discovery_rate", "false_omission_rate"], 'Error Rate Analysis'),
        ("likelihood_ratios", ["positive_likelihood_ratio", "negative_likelihood_ratio"], 'Likelihood Ratios'),
    ]

    def draw_loss(target_ax):
        """Draw the training and validation loss curves on one axes."""
        target_ax.plot(epochs, train_curve, 'o-', label='Training Loss')
        target_ax.plot(epochs, val_curve, 'o-', label='Validation Loss')
        target_ax.set_title('Loss Curves')
        target_ax.set_xlabel('Epoch')
        target_ax.set_ylabel('Loss')
        target_ax.legend()
        target_ax.grid(True)

    # --- 1. One stand-alone figure per topic ---
    loss_fig, loss_ax = plt.subplots(figsize=(10, 6))
    draw_loss(loss_ax)
    loss_fig.tight_layout()
    loss_fig.savefig(os.path.join(output_dir, "loss_curves.png"))
    plt.close(loss_fig)

    for key, names, _ in panels:
        group_fig, group_ax = plt.subplots(figsize=(10, 7))
        # Stand-alone titles are derived from the key, e.g. "Error Rates Metrics".
        _plot_metric_group(group_ax, history, names, f'{key.replace("_", " ").title()} Metrics')
        group_fig.tight_layout()
        group_fig.savefig(os.path.join(output_dir, f"metrics_{key}.png"))
        plt.close(group_fig)

    # --- 2. Combined 3x2 overview ---
    overview_fig, grid = plt.subplots(3, 2, figsize=(20, 22))
    overview_fig.suptitle('Training Progress Overview', fontsize=20)
    cells = grid.ravel()

    draw_loss(cells[0])
    for cell, (_, names, overview_title) in zip(cells[1:5], panels):
        _plot_metric_group(cell, history, names, overview_title)
    cells[5].axis('off')  # sixth cell is unused

    # Leave room at the top for the figure-level title.
    overview_fig.tight_layout(rect=[0, 0, 1, 0.96])
    overview_fig.savefig(os.path.join(output_dir, "training_progress_overview.png"))
    plt.close(overview_fig)

    print(f"所有图表已更新并保存至目录: {output_dir}")
|
||||
|
||||
# ------------------- 训练与评估函数 ('vit'模型不适用amp) ------------------- #
|
||||
def train_fn(loader, model, optimizer, loss_fn, scaler, device, model_architecture: str):
    """Train ``model`` for one epoch and return the mean batch loss.

    Automatic mixed precision (AMP) is used unless the configured encoder
    name for ``model_architecture`` contains ``'vit'``; in that case the
    forward and backward passes run without autocast and without the scaler.

    Args:
        loader: Iterable of ``(data, targets)`` batches; must support ``len``.
        model: Model to optimise.
        optimizer: Optimizer updating ``model``'s parameters.
        loss_fn: Callable ``loss_fn(predictions, targets)`` returning a scalar.
        scaler: Gradient scaler, used only on the AMP path.
        device: Device the batches are moved to.
        model_architecture: Key into ``config.ALL_MODEL_CONFIGS``.

    Returns:
        Sum of the per-batch losses divided by ``len(loader)``.
    """
    loop = tqdm(loader, desc="Training")
    total_loss = 0.0
    model.train()

    # ViT encoders are trained in full precision.
    encoder_name = config.ALL_MODEL_CONFIGS[model_architecture].get('encoder_name', '')
    use_amp = 'vit' not in encoder_name.lower()

    # Warn once per process. The previous `loop.n == 0` guard was always true
    # at this point, so the warning was repeated at the start of every epoch.
    if not use_amp and not getattr(train_fn, "_amp_warning_emitted", False):
        logging.warning(f"检测到 ViT 编码器 ('{encoder_name}')。本次训练将禁用自动混合精度 (AMP) 以确保稳定性。")
        train_fn._amp_warning_emitted = True

    for data, targets in loop:
        data = data.to(device)
        targets = targets.to(device)

        # With enabled=False the autocast context is a no-op, so one forward
        # block serves both the AMP and the full-precision path.
        with autocast(device_type='cuda', enabled=use_amp):
            predictions = model(data)
            loss = loss_fn(predictions, targets)

        optimizer.zero_grad()
        if use_amp:
            # AMP path: scale the loss, step through the scaler.
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            # Full-precision path.
            loss.backward()
            optimizer.step()

        # Read the scalar once; each .item() forces a device synchronisation.
        batch_loss = loss.item()
        total_loss += batch_loss
        loop.set_postfix(loss=batch_loss)

    return total_loss / len(loader)
|
||||
|
||||
# --- 【修改】 ---
|
||||
def evaluate_fn(loader, model, loss_fn, device, num_classes, model_architecture: str, beta=1.0):
    """Evaluate ``model`` on a validation loader and return a flat metrics dict.

    Predictions and targets are both reduced to per-pixel class indices and
    compared class by class. Aggregate scores use micro reduction over the
    classes that are not listed in ``config.EVALUATION_CLASSES_TO_IGNORE``;
    per-class pixel accuracy and raw TP/TN/FP/FN counts cover every class.

    Args:
        loader: Iterable of ``(data, targets)`` batches; must support ``len``.
        model: Model to evaluate; switched to eval mode for the duration and
            back to train mode before returning.
        loss_fn: Callable ``loss_fn(predictions, targets)`` returning a scalar.
        device: Device the batches are moved to.
        num_classes: Number of classes predicted by the model.
        model_architecture: Key into ``config.ALL_MODEL_CONFIGS``; decides
            whether autocast is used (disabled for ``'vit'`` encoders).
        beta: Beta of the F-beta score.

    Returns:
        Dict whose keys are, in insertion order: ``val_loss``, the aggregate
        metrics, ``mpa``, ``cpa_class_<i>`` and then
        ``tp/tn/fp/fn_class_<i>``. The order matters because rows are appended
        to a CSV file without a header.

    Raises:
        ValueError: ``config.SEG_MODE`` is neither ``'multiclass'`` nor
            ``'multilabel'``.
    """
    if config.SEG_MODE not in ('multiclass', 'multilabel'):
        # Previously an unknown mode surfaced later as a NameError.
        raise ValueError(f"无效的 SEG_MODE: '{config.SEG_MODE}'。必须是 'multiclass' 或 'multilabel'。")
    is_multilabel = config.SEG_MODE == 'multilabel'

    model.eval()
    loop = tqdm(loader, desc="Validation")

    total_loss = 0
    all_tp, all_fp, all_fn, all_tn = [], [], [], []

    # ViT encoders are evaluated in full precision.
    encoder_name = config.ALL_MODEL_CONFIGS[model_architecture].get('encoder_name', '')
    use_amp = 'vit' not in encoder_name.lower()

    with torch.no_grad():
        for data, targets in loop:
            data = data.to(device)
            targets = targets.to(device)

            # The loss is fed float targets in multilabel mode, integer class
            # indices in multiclass mode.
            loss_targets = targets.float() if is_multilabel else targets.long()

            # With enabled=False the autocast context is a no-op.
            with autocast(device_type='cuda', enabled=use_amp):
                predictions = model(data)
                loss = loss_fn(predictions, loss_targets)

            total_loss += loss.item()

            # Predicted class index per pixel.
            preds_indices = torch.argmax(predictions, dim=1)
            # Multilabel targets carry a class dimension at dim 1 and are
            # reduced the same way; multiclass targets are already indices.
            targets_indices = torch.argmax(targets, dim=1) if is_multilabel else targets

            # Both tensors are class-index maps at this point, so the counts
            # must be taken in multiclass mode whatever SEG_MODE is. Passing
            # 'multilabel' here made get_stats treat dim 1 (an image axis) as
            # the class axis.
            tp, fp, fn, tn = smp.metrics.get_stats(
                preds_indices,
                targets_indices.long(),
                mode='multiclass',
                num_classes=num_classes
            )

            all_tp.append(tp)
            all_fp.append(fp)
            all_fn.append(fn)
            all_tn.append(tn)

    # Concatenate the per-batch counts along the sample axis and sum it away,
    # leaving one count per class.
    tp = torch.cat(all_tp, dim=0).sum(dim=0)
    fp = torch.cat(all_fp, dim=0).sum(dim=0)
    fn = torch.cat(all_fn, dim=0).sum(dim=0)
    tn = torch.cat(all_tn, dim=0).sum(dim=0)

    # --- Drop the classes excluded from the aggregate metrics ---
    ignore_indices = [config.CLASSES.index(cls) for cls in config.EVALUATION_CLASSES_TO_IGNORE if cls in config.CLASSES]

    keep_mask = torch.ones(num_classes, dtype=torch.bool, device=tp.device)
    if ignore_indices:
        keep_mask[ignore_indices] = False
        logging.info(f"评估时将忽略类别: {config.EVALUATION_CLASSES_TO_IGNORE}")

    kept_stats = (tp[keep_mask], fp[keep_mask], fn[keep_mask], tn[keep_mask])

    def micro(metric_fn, **kwargs):
        """Evaluate one smp metric with micro reduction over the kept classes."""
        return metric_fn(*kept_stats, reduction='micro', **kwargs).item()

    # Key order is part of the contract (see Returns).
    metrics = {
        "val_loss": total_loss / len(loader),
        "iou_score": micro(smp.metrics.iou_score),
        "f1_score": micro(smp.metrics.f1_score),
        "fbeta_score": micro(smp.metrics.fbeta_score, beta=beta),
        "accuracy": micro(smp.metrics.accuracy),
        "recall": micro(smp.metrics.recall),
        "precision": micro(smp.metrics.precision),
        "specificity": micro(smp.metrics.specificity),
        "negative_predictive_value": micro(smp.metrics.negative_predictive_value),
        "false_negative_rate": micro(smp.metrics.false_negative_rate),
        "false_positive_rate": micro(smp.metrics.false_positive_rate),
        "false_discovery_rate": micro(smp.metrics.false_discovery_rate),
        "false_omission_rate": micro(smp.metrics.false_omission_rate),
        "positive_likelihood_ratio": micro(smp.metrics.positive_likelihood_ratio),
        "negative_likelihood_ratio": micro(smp.metrics.negative_likelihood_ratio),
    }

    # Per-class pixel accuracy (recall per class) over ALL classes; the mean
    # skips class 0 whenever there is more than one class.
    cpa_per_class = smp.metrics.recall(tp, fp, fn, tn, reduction='none').cpu().numpy().flatten()
    if len(cpa_per_class) > 1:
        metrics["mpa"] = np.mean(cpa_per_class[1:]).item()
    else:
        metrics["mpa"] = cpa_per_class[0].item()
    for i, cpa in enumerate(cpa_per_class):
        metrics[f"cpa_class_{i}"] = cpa.item()

    # Raw confusion counts per class.
    for i in range(num_classes):
        metrics[f'tp_class_{i}'] = tp[i].item()
        metrics[f'tn_class_{i}'] = tn[i].item()
        metrics[f'fp_class_{i}'] = fp[i].item()
        metrics[f'fn_class_{i}'] = fn[i].item()

    model.train()
    return metrics
|
||||
|
||||
# ------------------- 数据增强 (保持不变) ------------------- #
|
||||
def get_training_augmentation(height, width):
    """Return the training pipeline: resize, random flips/rotation, normalize, to tensor."""
    geometric_steps = [
        A.Resize(height, width),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        A.RandomRotate90(p=0.5),
    ]
    tensor_steps = [
        A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ToTensorV2(),
    ]
    return A.Compose(geometric_steps + tensor_steps)
|
||||
|
||||
def get_validation_augmentation(height, width):
    """Return the validation pipeline: resize, normalize, to tensor (no random steps)."""
    resize = A.Resize(height, width)
    normalize = A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    pipeline = [resize, normalize, ToTensorV2()]
    return A.Compose(pipeline)
|
||||
|
||||
# TOOL - 获取用于预测的图像预处理流程
|
||||
def get_preprocessing_transform(height: int, width: int) -> A.Compose:
    """Return the prediction-time preprocessing: resize, normalize, to tensor."""
    steps = []
    steps.append(A.Resize(height, width))
    steps.append(A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]))
    steps.append(ToTensorV2())
    return A.Compose(steps)
|
||||
|
||||
# Tool - 为一个全新的训练任务清理并重新创建输出目录。
|
||||
def setup_directories(output_dir: Union[str, Path]) -> None:
    """Make ``output_dir`` an empty directory, discarding whatever was there.

    Anything already present at the path is removed first: a directory is
    deleted recursively, a regular file or a symlink is unlinked. Missing
    parent directories are created.

    Args:
        output_dir: Path of the run's output directory (``str`` or ``Path``).
    """
    output_dir = Path(output_dir)

    if output_dir.exists():
        logging.info(f"输出目录 '{output_dir}' 已存在。正在删除...")
        if output_dir.is_dir() and not output_dir.is_symlink():
            shutil.rmtree(output_dir)
        else:
            # rmtree refuses plain files and symlinks; remove the entry itself.
            output_dir.unlink()

    logging.info(f"正在 '{output_dir}' 创建新的空目录。")
    output_dir.mkdir(parents=True, exist_ok=True)
|
||||
62
Seg_All_In_One_SegModel/※2025_9_21_使用手册
Normal file
62
Seg_All_In_One_SegModel/※2025_9_21_使用手册
Normal file
@@ -0,0 +1,62 @@
|
||||
############## A. 创建conda环境 ##############
|
||||
sudo apt install git
|
||||
1.
|
||||
conda create -n SMP python==3.9
|
||||
conda activate SMP
|
||||
2.
|
||||
# pytorch安装网址:https://pytorch.org/get-started/locally/
|
||||
e.g. pip3 install torch torchvision --index-url https://download.pytorch.org/whl/cu129
|
||||
3.
|
||||
proxychains pip install git+https://github.com/qubvel/segmentation_models.pytorch
|
||||
pip show segmentation-models-pytorch
|
||||
4.
|
||||
python3 -m pip install -r requirements.txt
|
||||
|
||||
############## B. Train 训练程序 ##############
|
||||
# A. 第一次训练开一个梯子【需要下载内容】
|
||||
export https_proxy=http://127.0.0.1:1080 http_proxy=http://127.0.0.1:1080
|
||||
CUDA_VISIBLE_DEVICES=4,5,6,7 python train.py -a Unet
|
||||
{Unet,UnetPlusPlus,FPN,PSPNet,DeepLabV3,DeepLabV3Plus,Linknet,MAnet,PAN,UPerNet,Segformer,DPT}
|
||||
|
||||
# B. 运行单个训练程序
|
||||
1. 在config.py 中修改 训练数据集;CUDA_VISIBLE_DEVICES修改使用显卡;-a 修改算法;
|
||||
2. CUDA_VISIBLE_DEVICES=0 python ./train.py -a Unet
|
||||
|
||||
# C. 批量运行训练程序
|
||||
1. train.sh 修改想让它使用的显卡、算法;
|
||||
2. bash train.sh
|
||||
3. 会生成 ./logs_parallel_DATE 的终端记录文件;结果先存储在 ../DataSet_Public_outputs/DATASET_outputs-SegModel 后自动移动到 ../Hardisk/DATASET_outputs-SegModel 中
|
||||
|
||||
############## C. Predict 推理程序 ##############
|
||||
# A1. 预先步骤:将模型同步到 Nas_BackUp_Seg 或 ./Hardisk 文件夹中【cd .. && bash ./Back_Up.sh】
|
||||
# A2. 如果需要处理自定义数据集,请保证数据仍然是 DATASET/images/[val/test] 格式
|
||||
python ../Seg_Predict_Own_Video_V2/1_Save_Frame_V2.py --video ./Video_Name.mp4 --resize "1920x1080" --output_dir "../DataSet_Public/5_Predict_Video" --interval 0.5
|
||||
# B1. 将最优模型文件从 ./Nas_BackUp_Seg 移动到 ./BestMode_Predict_Results_DataSet_Public 指定文件夹中
|
||||
bash ./Tool_Copy_Best_Model.sh # 修改里面的路径
|
||||
# B2. 如果需要处理自定义数据集,请将模型文件夹手动复制为 ./BestMode_Predict_Results_DataSet_Public/DATASET-Yolo 中
|
||||
可以先对原有 ./BestMode_Predict_Results_DataSet_Public/ORI_DATASET_outputs-SegModel 改名
|
||||
运行 bash ./Tool_Yolo_Copy_Best_Model.sh
|
||||
再将生成的 ./BestMode_Predict_Results_DataSet_Public/ORI_DATASET_outputs-SegModel 改为 ./BestMode_Predict_Results_DataSet_Public/DATASET_outputs-SegModel
|
||||
# C. 运行单个推理程序
|
||||
1. 在config.py 中修改 预测数据集 及 val/test 、 模型保存路径PREDICT_ALL_BEST_MODELS_DIR/PREDICT_BEST_MODEL_DIR;CUDA_VISIBLE_DEVICES修改使用显卡;-a 修改算法;
|
||||
2. CUDA_VISIBLE_DEVICES=0 python 1_predict.py -a Unet/UnetPlusPlus/FPN/PSPNet/DeepLabV3/DeepLabV3Plus/Linknet/MAnet/PAN/UPerNet/Segformer/DPT
|
||||
# D. 批量运行推理程序
|
||||
1. 在 1_predict.sh 中修改想让它使用的显卡、算法;
|
||||
2. bash 1_predict.sh
|
||||
3. 会生成 ./predict_logs_parallel_DATE 的终端记录文件;结果存储在 ../BestMode_Predict_Results_DataSet_Public/DATASET_outputs-SegModel 中
|
||||
# E. 预测模型 参数量、FLOPs、FPS
|
||||
CUDA_VISIBLE_DEVICES=5 python 2_predict_params_and_FLOPs_V2.py # --shape 512 512
|
||||
|
||||
############## D. Predict_raw_img_Check 检测推理输出图片是否齐全 ##############
|
||||
# 检测 config.TEST_IMAGE_DIR 和 config.PREDICT_BEST_MODEL_DIR/****/predicted_raw_masks 中图片是否匹配
|
||||
1. 在config.py 中修改 config.DATA_DIR(即config.TEST_IMAGE_DIR) 及 config.PREDICT_BEST_MODEL_DIR;CUDA_VISIBLE_DEVICES修改使用显卡;
|
||||
2. python 1_predict_raw_masks_check.py
|
||||
|
||||
############## E. 其他辅助脚本 ##############
|
||||
1. 解决--resume 不能运行的问题
|
||||
您可以直接修改mmengine源文件,使其torch.load可以加载完整的检查点
|
||||
vim /home/wkmgc/miniconda3/envs/SMP/lib/python3.9/site-packages/mmengine/runner/checkpoint.py
|
||||
转到第 347 行。您将看到以下代码:
|
||||
checkpoint = torch.load(filename, map_location=map_location)
|
||||
更改该行以添加weights_only=False参数
|
||||
checkpoint = torch.load(filename, map_location=map_location, weights_only=False)
|
||||
Reference in New Issue
Block a user