first commit
This commit is contained in:
361
Seg_All_In_One_SegModel/utils.py
Normal file
361
Seg_All_In_One_SegModel/utils.py
Normal file
@@ -0,0 +1,361 @@
|
||||
import os, logging, shutil
|
||||
import torch
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from tqdm import tqdm
|
||||
from torch.amp import autocast
|
||||
from typing import Dict, Union, List
|
||||
import albumentations as A
|
||||
from albumentations.pytorch import ToTensorV2
|
||||
import segmentation_models_pytorch as smp
|
||||
import matplotlib.pyplot as plt
|
||||
from pathlib import Path
|
||||
|
||||
import config
|
||||
|
||||
# ------------------- Metric logging and plotting functions ------------------- #
|
||||
def log_metrics_to_csv(metrics_dict: Dict[str, Union[int, float]], csv_path: str) -> None:
    """Append one row of metrics to a CSV file, creating the file if needed.

    The row is aligned to the existing header by column *name*, so the order
    of keys in ``metrics_dict`` does not matter and metrics missing from a
    given call are written as empty cells.

    Args:
        metrics_dict: Mapping of metric name to value for a single row.
        csv_path: Destination CSV file. Created (with a header) when it does
            not exist or is zero bytes long.
    """
    new_row = pd.DataFrame([metrics_dict])

    # A missing or zero-byte file has no header yet, so write one.
    if not os.path.isfile(csv_path) or os.path.getsize(csv_path) == 0:
        new_row.to_csv(csv_path, index=False)
        return

    # Read only the header; the data rows are not needed to align columns.
    existing_columns = list(pd.read_csv(csv_path, nrows=0).columns)
    added_columns = [col for col in new_row.columns if col not in existing_columns]

    if not added_columns:
        # Appending without a header is positional, so reorder the row to the
        # header's column order first (identical output when already aligned).
        new_row.reindex(columns=existing_columns).to_csv(
            csv_path, mode='a', header=False, index=False
        )
        return

    # New metric names appeared: the header must be widened, which requires
    # rewriting the file. Existing cells are read as text so they are written
    # back verbatim instead of going through a float round-trip.
    history = pd.read_csv(csv_path, dtype=str, keep_default_na=False)
    pd.concat([history, new_row], ignore_index=True).to_csv(csv_path, index=False)
|
||||
|
||||
|
||||
def _plot_metric_group(ax, df: pd.DataFrame, group_metrics: List[str], title: str):
|
||||
"""
|
||||
内部辅助函数,用于在给定的坐标轴上绘制一组指标。
|
||||
(此函数保持不变)
|
||||
"""
|
||||
epochs_range = df['epoch']
|
||||
for name in group_metrics:
|
||||
if f"val_{name}" in df.columns:
|
||||
metric_col = f"val_{name}"
|
||||
label = name.replace("_", " ").title()
|
||||
ax.plot(epochs_range, df[metric_col], 'o-', label=label)
|
||||
elif f"{name}" in df.columns:
|
||||
metric_col = f"{name}"
|
||||
label = name.replace("_", " ").title()
|
||||
ax.plot(epochs_range, df[metric_col], 'o-', label=label)
|
||||
else:
|
||||
print("没有搜索到:", f"val_{name}")
|
||||
|
||||
ax.set_title(title)
|
||||
ax.set_xlabel('Epoch')
|
||||
ax.set_ylabel('Score')
|
||||
ax.legend(loc='best', fontsize='small')
|
||||
ax.grid(True)
|
||||
|
||||
|
||||
def _plot_loss_curves(ax, epochs, train_loss, val_loss) -> None:
    """Draw the training and validation loss curves on ``ax``."""
    ax.plot(epochs, train_loss, 'o-', label='Training Loss')
    ax.plot(epochs, val_loss, 'o-', label='Validation Loss')
    ax.set_title('Loss Curves')
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Loss')
    ax.legend()
    ax.grid(True)


def plot_training_progress(csv_path: str, output_dir: str) -> None:
    """Render all training-history charts from a metrics CSV.

    Writes ``loss_curves.png``, one ``metrics_<group>.png`` per metric group,
    and a 3x2 ``training_progress_overview.png`` into ``output_dir``. Plotting
    is skipped with a printed warning when the CSV is missing, empty, has
    fewer than two rows, or lacks the ``epoch`` / ``train_loss`` / ``val_loss``
    columns.

    Args:
        csv_path: Path to the metrics CSV.
        output_dir: Directory for the PNG files; created if absent.
    """
    os.makedirs(output_dir, exist_ok=True)

    try:
        df_metrics = pd.read_csv(csv_path)
    except FileNotFoundError:
        print(f"警告:在路径 {csv_path} 未找到 metrics.csv 文件。跳过绘图。")
        return
    except pd.errors.EmptyDataError:
        # A zero-byte CSV (e.g. created but never written) is not an error
        # worth aborting for; treat it like a missing file.
        print(f"警告:路径 {csv_path} 的 metrics.csv 文件为空。跳过绘图。")
        return

    if len(df_metrics) < 2:
        print("警告:数据点不足 (少于2个),无法生成趋势图。")
        return

    missing_columns = [
        col for col in ('epoch', 'train_loss', 'val_loss') if col not in df_metrics.columns
    ]
    if missing_columns:
        print(f"警告:metrics.csv 缺少必需的列 {missing_columns}。跳过绘图。")
        return

    epochs_range = df_metrics['epoch']
    train_loss = df_metrics['train_loss']
    val_loss = df_metrics['val_loss']

    core_metrics = ["iou_score", "fbeta_score", "accuracy"]
    pr_metrics = ["precision", "recall", "specificity", "negative_predictive_value"]
    error_metrics = ["false_positive_rate", "false_negative_rate", "false_discovery_rate", "false_omission_rate"]
    likelihood_metrics = ["positive_likelihood_ratio", "negative_likelihood_ratio"]

    # --- 1. One standalone figure per chart ---
    # Figure methods are used instead of pyplot's implicit "current figure"
    # so each save is explicitly tied to the figure it belongs to.
    fig_loss, ax_loss = plt.subplots(figsize=(10, 6))
    _plot_loss_curves(ax_loss, epochs_range, train_loss, val_loss)
    fig_loss.tight_layout()
    fig_loss.savefig(os.path.join(output_dir, "loss_curves.png"))
    plt.close(fig_loss)

    metric_groups = {
        "core_performance": core_metrics,
        "precision_recall": pr_metrics,
        "error_rates": error_metrics,
        "likelihood_ratios": likelihood_metrics,
    }
    for group_name, group_list in metric_groups.items():
        fig_group, ax_group = plt.subplots(figsize=(10, 7))
        _plot_metric_group(ax_group, df_metrics, group_list, f'{group_name.replace("_", " ").title()} Metrics')
        fig_group.tight_layout()
        fig_group.savefig(os.path.join(output_dir, f"metrics_{group_name}.png"))
        plt.close(fig_group)

    # --- 2. Combined 3x2 overview figure ---
    fig_overview, axes = plt.subplots(3, 2, figsize=(20, 22))
    fig_overview.suptitle('Training Progress Overview', fontsize=20)
    ax_list = axes.ravel()

    # Panel 0: loss curves; panels 1-4: one metric group each.
    _plot_loss_curves(ax_list[0], epochs_range, train_loss, val_loss)
    overview_panels = [
        (core_metrics, 'Core Performance Metrics'),
        (pr_metrics, 'Precision-Recall Family'),
        (error_metrics, 'Error Rate Analysis'),
        (likelihood_metrics, 'Likelihood Ratios'),
    ]
    for panel_ax, (panel_metrics, panel_title) in zip(ax_list[1:], overview_panels):
        _plot_metric_group(panel_ax, df_metrics, panel_metrics, panel_title)

    # Panel 5 is unused; hide it.
    ax_list[5].axis('off')

    # Leave the top 4% of the canvas for the suptitle.
    fig_overview.tight_layout(rect=[0, 0, 1, 0.96])
    fig_overview.savefig(os.path.join(output_dir, "training_progress_overview.png"))
    plt.close(fig_overview)

    print(f"所有图表已更新并保存至目录: {output_dir}")
|
||||
|
||||
# ------------------- Training and evaluation functions (AMP is not used for 'vit' encoders) ------------------- #
|
||||
def train_fn(loader, model, optimizer, loss_fn, scaler, device, model_architecture: str):
    """Run one training epoch and return the mean per-batch loss.

    Automatic mixed precision (AMP) is used unless the configured encoder name
    for ``model_architecture`` contains ``"vit"`` (case-insensitive), in which
    case the forward/backward pass runs without autocast or gradient scaling.

    Args:
        loader: Iterable yielding ``(data, targets)`` batches; must support
            ``len()``.
        model: Model to train; switched to train mode.
        optimizer: Optimizer stepped once per batch.
        loss_fn: Callable ``loss_fn(predictions, targets)`` returning a scalar
            tensor.
        scaler: Gradient scaler used on the AMP path (``scale`` / ``step`` /
            ``update``); unused on the float32 path.
        device: Device the batches are moved to.
        model_architecture: Key into ``config.ALL_MODEL_CONFIGS``.

    Returns:
        Sum of the batch losses divided by ``len(loader)``.

    Raises:
        ZeroDivisionError: If ``loader`` is empty.
    """
    progress = tqdm(loader, desc="Training")
    running_loss = 0.0
    model.train()

    # ViT encoders are trained without AMP.
    encoder_name = config.ALL_MODEL_CONFIGS[model_architecture].get('encoder_name', '')
    use_amp = 'vit' not in encoder_name.lower()
    if not use_amp:
        # Emitted once per call (i.e. once per epoch), before the first batch.
        logging.warning(f"检测到 ViT 编码器 ('{encoder_name}')。本次训练将禁用自动混合精度 (AMP) 以确保稳定性。")

    for data, targets in progress:
        data = data.to(device)
        targets = targets.to(device)

        if use_amp:
            # AMP path: autocast forward pass, scaled backward pass.
            with autocast(device_type='cuda'):
                predictions = model(data)
                loss = loss_fn(predictions, targets)

            optimizer.zero_grad()
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            # Float32 path: plain forward/backward without scaling.
            predictions = model(data)
            loss = loss_fn(predictions, targets)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Read the scalar once; each .item() forces a host/device sync.
        batch_loss = loss.item()
        running_loss += batch_loss
        progress.set_postfix(loss=batch_loss)

    return running_loss / len(loader)
|
||||
|
||||
# --- 【修改】 ---
|
||||
def evaluate_fn(loader, model, loss_fn, device, num_classes, model_architecture: str, beta=1.0):
    """Evaluate the model on a validation loader and return a metrics dict.

    Predictions and targets are both reduced to class-index maps (argmax over
    dim 1 where needed) and compared with ``smp.metrics.get_stats`` in
    multiclass mode. Aggregate ("micro") metrics exclude the classes named in
    ``config.EVALUATION_CLASSES_TO_IGNORE``; per-class pixel accuracy and raw
    TP/TN/FP/FN counts are reported for every class.

    Args:
        loader: Iterable yielding ``(data, targets)`` batches; must support
            ``len()``.
        model: Model to evaluate. It is put in eval mode for the pass and
            switched back to train mode before returning.
        loss_fn: Callable ``loss_fn(predictions, targets)`` returning a scalar
            tensor.
        device: Device the batches are moved to.
        num_classes: Number of classes passed to ``get_stats`` and used for the
            per-class count keys.
        model_architecture: Key into ``config.ALL_MODEL_CONFIGS``; AMP is
            disabled when its ``encoder_name`` contains ``"vit"``.
        beta: Beta for the F-beta score.

    Returns:
        Dict with ``val_loss``, the micro-averaged metrics, ``mpa``,
        ``cpa_class_<i>`` and ``tp/tn/fp/fn_class_<i>`` entries, in that order.
    """
    model.eval()
    loop = tqdm(loader, desc="Validation")

    total_loss = 0
    all_tp, all_fp, all_fn, all_tn = [], [], [], []

    # ViT encoders are evaluated without AMP.
    encoder_name = config.ALL_MODEL_CONFIGS[model_architecture].get('encoder_name', '')
    use_amp = 'vit' not in encoder_name.lower()
    is_multilabel = config.SEG_MODE == 'multilabel'

    with torch.no_grad():
        for data, targets in loop:
            data = data.to(device)
            targets = targets.to(device)

            # The loss expects float targets in multilabel mode, long otherwise.
            loss_targets = targets.float() if is_multilabel else targets.long()

            # enabled=False makes autocast a no-op, so one code path serves
            # both the AMP and the float32 case.
            with autocast(device_type='cuda', enabled=use_amp):
                predictions = model(data)
                loss = loss_fn(predictions, loss_targets)

            total_loss += loss.item()

            # Predicted class index per pixel.
            preds_indices = torch.argmax(predictions, dim=1)

            # Ground-truth class index per pixel. Multilabel targets are
            # presumably one-hot along dim 1 (per the original comments), so
            # argmax recovers the index; otherwise targets are used as-is.
            if is_multilabel:
                targets_indices = torch.argmax(targets, dim=1)
            else:
                targets_indices = targets

            # Both tensors now hold class indices, so the stats must be
            # computed in 'multiclass' mode regardless of config.SEG_MODE;
            # 'multilabel' mode would treat dim 1 of an index map as the
            # class axis.
            tp, fp, fn, tn = smp.metrics.get_stats(
                preds_indices,
                targets_indices.long(),
                mode='multiclass',
                num_classes=num_classes,
            )

            # Reduce over the image dimension right away so only one small
            # per-class vector per batch is kept.
            all_tp.append(tp.sum(dim=0))
            all_fp.append(fp.sum(dim=0))
            all_fn.append(fn.sum(dim=0))
            all_tn.append(tn.sum(dim=0))

    # Total per-class counts over the whole loader: (num_batches, C) -> (C,)
    tp = torch.stack(all_tp, dim=0).sum(dim=0)
    fp = torch.stack(all_fp, dim=0).sum(dim=0)
    fn = torch.stack(all_fn, dim=0).sum(dim=0)
    tn = torch.stack(all_tn, dim=0).sum(dim=0)

    # Classes listed in config.EVALUATION_CLASSES_TO_IGNORE are dropped from
    # the aggregate metrics (names not present in config.CLASSES are skipped).
    ignore_indices = [config.CLASSES.index(cls) for cls in config.EVALUATION_CLASSES_TO_IGNORE if cls in config.CLASSES]

    keep_mask = torch.ones(num_classes, dtype=torch.bool, device=tp.device)
    if ignore_indices:
        keep_mask[ignore_indices] = False
        logging.info(f"评估时将忽略类别: {config.EVALUATION_CLASSES_TO_IGNORE}")

    filtered_stats = (tp[keep_mask], fp[keep_mask], fn[keep_mask], tn[keep_mask])

    # Key order is preserved deliberately: it defines the CSV column order
    # for callers that log this dict.
    micro_metric_names = (
        "iou_score",
        "f1_score",
        "fbeta_score",
        "accuracy",
        "recall",
        "precision",
        "specificity",
        "negative_predictive_value",
        "false_negative_rate",
        "false_positive_rate",
        "false_discovery_rate",
        "false_omission_rate",
        "positive_likelihood_ratio",
        "negative_likelihood_ratio",
    )
    metrics = {"val_loss": total_loss / len(loader)}
    for metric_name in micro_metric_names:
        metric_fn = getattr(smp.metrics, metric_name)
        extra_kwargs = {"beta": beta} if metric_name == "fbeta_score" else {}
        metrics[metric_name] = metric_fn(*filtered_stats, reduction='micro', **extra_kwargs).item()

    # Per-class pixel accuracy (recall) uses the unfiltered counts.
    cpa_per_class = smp.metrics.recall(tp, fp, fn, tn, reduction='none').cpu().numpy().flatten()
    # NOTE(review): with more than one class, index 0 is left out of the mean —
    # presumably it is the background class; confirm against config.CLASSES.
    mpa_source = cpa_per_class[1:] if len(cpa_per_class) > 1 else cpa_per_class
    metrics["mpa"] = np.mean(mpa_source).item()
    for i, cpa in enumerate(cpa_per_class):
        metrics[f"cpa_class_{i}"] = cpa.item()

    for i in range(num_classes):
        metrics[f'tp_class_{i}'] = tp[i].item()
        metrics[f'tn_class_{i}'] = tn[i].item()
        metrics[f'fp_class_{i}'] = fp[i].item()
        metrics[f'fn_class_{i}'] = fn[i].item()

    model.train()
    return metrics
|
||||
|
||||
# ------------------- Data augmentation ------------------- #
|
||||
def get_training_augmentation(
    height,
    width,
    mean=(0.485, 0.456, 0.406),
    std=(0.229, 0.224, 0.225),
):
    """Build the training-time augmentation pipeline.

    Steps, in order: resize to ``(height, width)``, random horizontal flip,
    random vertical flip, random 90-degree rotation (each with p=0.5),
    normalization, conversion to a tensor.

    Args:
        height: Target image height passed to ``A.Resize``.
        width: Target image width passed to ``A.Resize``.
        mean: Per-channel mean for ``A.Normalize``; the default keeps the
            values this function previously hard-coded.
        std: Per-channel std for ``A.Normalize``; same default note.

    Returns:
        The composed ``A.Compose`` pipeline.
    """
    return A.Compose([
        A.Resize(height, width),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        A.RandomRotate90(p=0.5),
        A.Normalize(mean=list(mean), std=list(std)),
        ToTensorV2(),
    ])
|
||||
|
||||
def get_validation_augmentation(
    height,
    width,
    mean=(0.485, 0.456, 0.406),
    std=(0.229, 0.224, 0.225),
):
    """Build the deterministic validation-time transform pipeline.

    Steps, in order: resize to ``(height, width)``, normalization, conversion
    to a tensor. No random augmentation is applied.

    Args:
        height: Target image height passed to ``A.Resize``.
        width: Target image width passed to ``A.Resize``.
        mean: Per-channel mean for ``A.Normalize``; the default keeps the
            values this function previously hard-coded.
        std: Per-channel std for ``A.Normalize``; same default note.

    Returns:
        The composed ``A.Compose`` pipeline.
    """
    return A.Compose([
        A.Resize(height, width),
        A.Normalize(mean=list(mean), std=list(std)),
        ToTensorV2(),
    ])
|
||||
|
||||
# TOOL - 获取用于预测的图像预处理流程
|
||||
def get_preprocessing_transform(
    height: int,
    width: int,
    mean=(0.485, 0.456, 0.406),
    std=(0.229, 0.224, 0.225),
) -> A.Compose:
    """Build the image preprocessing pipeline used for prediction.

    Steps, in order: resize to ``(height, width)``, normalization, conversion
    to a tensor.

    Args:
        height: Target image height passed to ``A.Resize``.
        width: Target image width passed to ``A.Resize``.
        mean: Per-channel mean for ``A.Normalize``; the default keeps the
            values this function previously hard-coded.
        std: Per-channel std for ``A.Normalize``; same default note.

    Returns:
        The composed ``A.Compose`` pipeline.
    """
    return A.Compose([
        A.Resize(height, width),
        A.Normalize(mean=list(mean), std=list(std)),
        ToTensorV2(),
    ])
|
||||
|
||||
# Tool - 为一个全新的训练任务清理并重新创建输出目录。
|
||||
def setup_directories(output_dir: Union[str, Path]) -> None:
    """Wipe and recreate the output directory for a fresh training run.

    Anything already at ``output_dir`` is removed (a directory tree, a regular
    file, or a symlink), then an empty directory is created there, including
    missing parents.

    Args:
        output_dir: Path of the main output directory. A plain string is
            accepted and converted to a ``Path``.
    """
    output_dir = Path(output_dir)

    if output_dir.is_symlink() or output_dir.is_file():
        # shutil.rmtree refuses files and symlinks, so remove those directly.
        logging.info(f"输出目录 '{output_dir}' 已存在。正在删除...")
        output_dir.unlink()
    elif output_dir.exists():
        logging.info(f"输出目录 '{output_dir}' 已存在。正在删除...")
        shutil.rmtree(output_dir)

    logging.info(f"正在 '{output_dir}' 创建新的空目录。")
    output_dir.mkdir(parents=True, exist_ok=True)
|
||||
Reference in New Issue
Block a user