361 lines
16 KiB
Python
361 lines
16 KiB
Python
import os, logging, shutil
|
||
import torch
|
||
import numpy as np
|
||
import pandas as pd
|
||
from tqdm import tqdm
|
||
from torch.amp import autocast
|
||
from typing import Dict, Union, List
|
||
import albumentations as A
|
||
from albumentations.pytorch import ToTensorV2
|
||
import segmentation_models_pytorch as smp
|
||
import matplotlib.pyplot as plt
|
||
from pathlib import Path
|
||
|
||
import config
|
||
|
||
# ------------------- 绘图函数 (全新升级) ------------------- #
|
||
def log_metrics_to_csv(metrics_dict: Dict[str, Union[int, float]], csv_path: str) -> None:
    """Append one row of metrics to a CSV file, creating the file when needed.

    The header row is written when the target file is missing or has zero
    bytes; otherwise the row is appended without a header. Rows are appended
    positionally, so callers should pass the same keys in the same order on
    every call.

    Args:
        metrics_dict: Mapping of column name to scalar value for a single row.
        csv_path: Path of the CSV file to create or extend.
    """
    row = pd.DataFrame([metrics_dict])

    # An existing but empty file (pre-created, or truncated by an interrupted
    # run) still needs the header; without it the CSV cannot be read back by
    # column name later.
    needs_header = not os.path.isfile(csv_path) or os.path.getsize(csv_path) == 0

    row.to_csv(
        csv_path,
        mode='w' if needs_header else 'a',
        header=needs_header,
        index=False,
    )
|
||
|
||
|
||
def _plot_metric_group(ax, df: pd.DataFrame, group_metrics: List[str], title: str):
|
||
"""
|
||
内部辅助函数,用于在给定的坐标轴上绘制一组指标。
|
||
(此函数保持不变)
|
||
"""
|
||
epochs_range = df['epoch']
|
||
for name in group_metrics:
|
||
if f"val_{name}" in df.columns:
|
||
metric_col = f"val_{name}"
|
||
label = name.replace("_", " ").title()
|
||
ax.plot(epochs_range, df[metric_col], 'o-', label=label)
|
||
elif f"{name}" in df.columns:
|
||
metric_col = f"{name}"
|
||
label = name.replace("_", " ").title()
|
||
ax.plot(epochs_range, df[metric_col], 'o-', label=label)
|
||
else:
|
||
print("没有搜索到:", f"val_{name}")
|
||
|
||
ax.set_title(title)
|
||
ax.set_xlabel('Epoch')
|
||
ax.set_ylabel('Score')
|
||
ax.legend(loc='best', fontsize='small')
|
||
ax.grid(True)
|
||
|
||
|
||
def plot_training_progress(csv_path: str, output_dir: str):
    """Render loss and metric charts from a metrics CSV into ``output_dir``.

    Produces one loss-curve image, one image per metric group, and a 3x2
    overview figure combining all of them. ``output_dir`` is created if it
    does not exist. The function prints a warning and returns without
    plotting when the CSV is missing, is empty, or holds fewer than two rows.

    Args:
        csv_path: Path to the metrics CSV. It must provide the ``epoch``,
            ``train_loss`` and ``val_loss`` columns; metric columns are looked
            up by ``_plot_metric_group``.
        output_dir: Directory that receives the PNG files.

    Raises:
        KeyError: If ``epoch``, ``train_loss`` or ``val_loss`` is missing from
            a CSV that has at least two rows.
    """
    os.makedirs(output_dir, exist_ok=True)

    try:
        df_metrics = pd.read_csv(csv_path)
    except FileNotFoundError:
        print(f"警告:在路径 {csv_path} 未找到 metrics.csv 文件。跳过绘图。")
        return
    except pd.errors.EmptyDataError:
        # A zero-byte CSV has no header to parse; treat it like "no data yet"
        # instead of crashing the caller.
        print(f"警告:文件 {csv_path} 为空,没有可绘制的数据。跳过绘图。")
        return

    if len(df_metrics) < 2:
        print("警告:数据点不足 (少于2个),无法生成趋势图。")
        return

    epochs_range = df_metrics['epoch']
    train_loss = df_metrics['train_loss']
    val_loss = df_metrics['val_loss']

    def _draw_loss_curves(ax):
        """Draw the training/validation loss curves on ``ax``."""
        ax.plot(epochs_range, train_loss, 'o-', label='Training Loss')
        ax.plot(epochs_range, val_loss, 'o-', label='Validation Loss')
        ax.set_title('Loss Curves')
        ax.set_xlabel('Epoch')
        ax.set_ylabel('Loss')
        ax.legend()
        ax.grid(True)

    def _save_and_close(fig, filename, **layout_kwargs):
        """Lay out and save ``fig`` under ``output_dir``, always closing it."""
        try:
            fig.tight_layout(**layout_kwargs)
            fig.savefig(os.path.join(output_dir, filename))
        finally:
            # Close even when saving fails so figures do not accumulate.
            plt.close(fig)

    # group key -> (metric names, title used in the overview figure)
    metric_groups = {
        "core_performance": (
            ["iou_score", "fbeta_score", "accuracy"],
            'Core Performance Metrics',
        ),
        "precision_recall": (
            ["precision", "recall", "specificity", "negative_predictive_value"],
            'Precision-Recall Family',
        ),
        "error_rates": (
            ["false_positive_rate", "false_negative_rate",
             "false_discovery_rate", "false_omission_rate"],
            'Error Rate Analysis',
        ),
        "likelihood_ratios": (
            ["positive_likelihood_ratio", "negative_likelihood_ratio"],
            'Likelihood Ratios',
        ),
    }

    # --- 1. One standalone figure per chart ---
    fig_loss, ax_loss = plt.subplots(figsize=(10, 6))
    _draw_loss_curves(ax_loss)
    _save_and_close(fig_loss, "loss_curves.png")

    for group_name, (group_list, _) in metric_groups.items():
        fig_group, ax_group = plt.subplots(figsize=(10, 7))
        standalone_title = f'{group_name.replace("_", " ").title()} Metrics'
        _plot_metric_group(ax_group, df_metrics, group_list, standalone_title)
        _save_and_close(fig_group, f"metrics_{group_name}.png")

    # --- 2. Combined 3x2 overview figure ---
    fig_overview, axes = plt.subplots(3, 2, figsize=(20, 22))
    fig_overview.suptitle('Training Progress Overview', fontsize=20)
    ax_list = axes.ravel()

    _draw_loss_curves(ax_list[0])
    for ax, (group_list, overview_title) in zip(ax_list[1:], metric_groups.values()):
        _plot_metric_group(ax, df_metrics, group_list, overview_title)

    # Five charts on a six-cell grid: hide the unused last cell.
    ax_list[5].axis('off')

    # Leave room at the top of the figure for the suptitle.
    _save_and_close(fig_overview, "training_progress_overview.png", rect=[0, 0, 1, 0.96])

    print(f"所有图表已更新并保存至目录: {output_dir}")
|
||
|
||
# ------------------- 训练与评估函数 ('vit'模型不适用amp) ------------------- #
|
||
def train_fn(loader, model, optimizer, loss_fn, scaler, device, model_architecture: str):
    """Run one training epoch and return the mean per-batch loss.

    Automatic mixed precision (AMP) is used only when the configured encoder
    name does not contain ``'vit'`` and ``device`` is a CUDA device. In every
    other case the batch runs in plain float32 and ``scaler`` is not touched.

    Args:
        loader: Iterable of ``(data, targets)`` batches that supports ``len()``.
        model: Model to train; it is switched to training mode.
        optimizer: Optimizer updating the model parameters.
        loss_fn: Callable invoked as ``loss_fn(predictions, targets)``.
        scaler: Gradient scaler used on the AMP path (``scale``/``step``/``update``).
        device: Target device (string or ``torch.device``) for each batch.
        model_architecture: Key into ``config.ALL_MODEL_CONFIGS`` used to look
            up the encoder name.

    Returns:
        The sum of per-batch losses divided by ``len(loader)``.

    Raises:
        ZeroDivisionError: If ``len(loader)`` is zero.
    """
    loop = tqdm(loader, desc="Training")
    total_loss = 0.0
    model.train()

    encoder_name = config.ALL_MODEL_CONFIGS[model_architecture].get('encoder_name', '')
    is_vit = 'vit' in encoder_name.lower()
    if is_vit:
        # Emitted once per call (i.e. once per epoch), before the first batch.
        logging.warning(f"检测到 ViT 编码器 ('{encoder_name}')。本次训练将禁用自动混合精度 (AMP) 以确保稳定性。")

    # CUDA autocast and gradient scaling only make sense on a CUDA device;
    # anywhere else fall back to the float32 path instead of requesting
    # autocast for a device that is not in use.
    use_amp = not is_vit and torch.device(device).type == 'cuda'

    for data, targets in loop:
        data = data.to(device)
        targets = targets.to(device)

        if use_amp:
            # AMP path: forward pass under autocast, scaled backward pass.
            with autocast(device_type='cuda'):
                predictions = model(data)
                loss = loss_fn(predictions, targets)

            optimizer.zero_grad()
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            # Float32 path (ViT encoders and non-CUDA devices).
            predictions = model(data)
            loss = loss_fn(predictions, targets)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Read the scalar once and reuse it for both the running total and
        # the progress bar.
        batch_loss = loss.item()
        total_loss += batch_loss
        loop.set_postfix(loss=batch_loss)

    return total_loss / len(loader)
|
||
|
||
# --- 【修改】 ---
|
||
def evaluate_fn(loader, model, loss_fn, device, num_classes, model_architecture: str, beta=1.0):
    """Evaluate the model on a validation loader and collect metrics.

    Predictions and targets are both reduced to class-index maps (argmax over
    dim 1; in ``'multilabel'`` mode the targets are argmax'd as well), and the
    confusion statistics are computed in multiclass fashion. Aggregate metrics
    use micro reduction over the classes that are not listed in
    ``config.EVALUATION_CLASSES_TO_IGNORE``; per-class recall and the raw
    TP/TN/FP/FN counts are reported for every class.

    The model is put into eval mode for the duration of the call and restored
    to whatever mode it was in beforehand.

    Args:
        loader: Iterable of ``(data, targets)`` batches that supports ``len()``.
        model: Model to evaluate.
        loss_fn: Callable invoked as ``loss_fn(predictions, targets)``; targets
            are cast to float in ``'multilabel'`` mode and to long otherwise.
        device: Target device (string or ``torch.device``) for each batch.
        num_classes: Number of classes passed to ``smp.metrics.get_stats``.
        model_architecture: Key into ``config.ALL_MODEL_CONFIGS`` used to look
            up the encoder name (AMP is skipped for names containing ``'vit'``).
        beta: Beta value for the F-beta score.

    Returns:
        Dict with ``val_loss``, the micro-reduced metrics, ``mpa``,
        ``cpa_class_<i>`` and ``tp/tn/fp/fn_class_<i>`` entries.

    Raises:
        ValueError: If ``config.SEG_MODE`` is neither ``'multilabel'`` nor
            ``'multiclass'``.
    """
    seg_mode = config.SEG_MODE
    if seg_mode not in ('multilabel', 'multiclass'):
        # Fail early with a clear message instead of hitting an unbound
        # variable after the first forward pass.
        raise ValueError(f"不支持的 SEG_MODE: {seg_mode!r} (仅支持 'multilabel' 或 'multiclass')")
    is_multilabel = seg_mode == 'multilabel'

    # AMP only for non-ViT encoders, and only when actually running on CUDA.
    encoder_name = config.ALL_MODEL_CONFIGS[model_architecture].get('encoder_name', '')
    use_amp = 'vit' not in encoder_name.lower() and torch.device(device).type == 'cuda'

    was_training = getattr(model, 'training', True)
    model.eval()
    loop = tqdm(loader, desc="Validation")

    total_loss = 0.0
    all_tp, all_fp, all_fn, all_tn = [], [], [], []

    try:
        with torch.no_grad():
            for data, targets in loop:
                data = data.to(device)
                targets = targets.to(device)
                loss_targets = targets.float() if is_multilabel else targets.long()

                if use_amp:
                    with autocast(device_type='cuda'):
                        predictions = model(data)
                        loss = loss_fn(predictions, loss_targets)
                else:
                    predictions = model(data)
                    loss = loss_fn(predictions, loss_targets)

                total_loss += loss.item()

                # Reduce both sides to class-index maps.
                preds_indices = torch.argmax(predictions, dim=1)
                targets_indices = torch.argmax(targets, dim=1) if is_multilabel else targets

                # After argmax both tensors hold class indices, so the stats
                # must always be computed in 'multiclass' mode, whatever
                # SEG_MODE is; 'multilabel' mode would misread the index maps.
                tp, fp, fn, tn = smp.metrics.get_stats(
                    preds_indices,
                    targets_indices.long(),
                    mode='multiclass',
                    num_classes=num_classes,
                )
                all_tp.append(tp)
                all_fp.append(fp)
                all_fn.append(fn)
                all_tn.append(tn)
    finally:
        # Restore the mode the caller had, instead of forcing train mode.
        model.train(was_training)

    # Per-image stats (presumably shaped (images, classes) per batch — see the
    # smp get_stats docs) are concatenated and summed into one count per class.
    tp = torch.cat(all_tp, dim=0).sum(dim=0)
    fp = torch.cat(all_fp, dim=0).sum(dim=0)
    fn = torch.cat(all_fn, dim=0).sum(dim=0)
    tn = torch.cat(all_tn, dim=0).sum(dim=0)

    # Drop the classes configured to be ignored from the aggregate metrics.
    ignore_indices = [
        config.CLASSES.index(cls)
        for cls in config.EVALUATION_CLASSES_TO_IGNORE
        if cls in config.CLASSES
    ]
    keep_mask = torch.ones(num_classes, dtype=torch.bool, device=tp.device)
    if ignore_indices:
        keep_mask[ignore_indices] = False
        logging.info(f"评估时将忽略类别: {config.EVALUATION_CLASSES_TO_IGNORE}")

    filtered_stats = (tp[keep_mask], fp[keep_mask], fn[keep_mask], tn[keep_mask])

    # Key order is kept stable because rows are written to CSV positionally.
    micro_metric_names = (
        "iou_score",
        "f1_score",
        "fbeta_score",
        "accuracy",
        "recall",
        "precision",
        "specificity",
        "negative_predictive_value",
        "false_negative_rate",
        "false_positive_rate",
        "false_discovery_rate",
        "false_omission_rate",
        "positive_likelihood_ratio",
        "negative_likelihood_ratio",
    )
    metrics = {"val_loss": total_loss / len(loader)}
    for metric_name in micro_metric_names:
        metric_fn = getattr(smp.metrics, metric_name)
        extra_kwargs = {"beta": beta} if metric_name == "fbeta_score" else {}
        metrics[metric_name] = metric_fn(*filtered_stats, reduction='micro', **extra_kwargs).item()

    # Per-class recall over ALL classes (not filtered by the ignore list).
    cpa_per_class = smp.metrics.recall(tp, fp, fn, tn, reduction='none').cpu().numpy().flatten()
    if len(cpa_per_class) > 1:
        # The mean skips index 0 (presumably the background class).
        metrics["mpa"] = np.mean(cpa_per_class[1:]).item()
        for i, cpa in enumerate(cpa_per_class):
            metrics[f"cpa_class_{i}"] = cpa.item()
    else:
        metrics["mpa"] = cpa_per_class[0].item()
        metrics["cpa_class_0"] = cpa_per_class[0].item()

    # Raw confusion counts for every class.
    for i in range(num_classes):
        metrics[f'tp_class_{i}'] = tp[i].item()
        metrics[f'tn_class_{i}'] = tn[i].item()
        metrics[f'fp_class_{i}'] = fp[i].item()
        metrics[f'fn_class_{i}'] = fn[i].item()

    return metrics
|
||
|
||
# ------------------- 数据增强 (保持不变) ------------------- #
|
||
def get_training_augmentation(height, width):
    """Build the training-time transform pipeline.

    Steps, in order: resize to ``(height, width)``, horizontal flip, vertical
    flip and 90-degree rotation (each applied with probability 0.5),
    per-channel normalization, and conversion to a tensor.

    Args:
        height: Target image height passed to ``A.Resize``.
        width: Target image width passed to ``A.Resize``.

    Returns:
        The composed ``A.Compose`` pipeline.
    """
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]

    steps = [A.Resize(height, width)]
    # Random geometric augmentations, each with a 50% chance.
    steps.append(A.HorizontalFlip(p=0.5))
    steps.append(A.VerticalFlip(p=0.5))
    steps.append(A.RandomRotate90(p=0.5))
    steps.append(A.Normalize(mean=channel_mean, std=channel_std))
    steps.append(ToTensorV2())
    return A.Compose(steps)
|
||
|
||
def get_validation_augmentation(height, width):
    """Build the validation-time transform pipeline (no random augmentation).

    Steps, in order: resize to ``(height, width)``, per-channel normalization,
    and conversion to a tensor.

    Args:
        height: Target image height passed to ``A.Resize``.
        width: Target image width passed to ``A.Resize``.

    Returns:
        The composed ``A.Compose`` pipeline.
    """
    resize = A.Resize(height, width)
    normalize = A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    to_tensor = ToTensorV2()
    return A.Compose([resize, normalize, to_tensor])
|
||
|
||
# TOOL - 获取用于预测的图像预处理流程
|
||
def get_preprocessing_transform(height: int, width: int) -> A.Compose:
    """Build the image preprocessing pipeline used for prediction.

    Steps, in order: resize to ``(height, width)``, per-channel normalization,
    and conversion to a tensor.

    Args:
        height: Target image height passed to ``A.Resize``.
        width: Target image width passed to ``A.Resize``.

    Returns:
        The composed ``A.Compose`` pipeline.
    """
    normalization = dict(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    pipeline = [
        A.Resize(height, width),
        A.Normalize(**normalization),
        ToTensorV2(),
    ]
    return A.Compose(pipeline)
|
||
|
||
# Tool - 为一个全新的训练任务清理并重新创建输出目录。
|
||
def setup_directories(output_dir: Union[str, Path]) -> None:
    """Wipe and recreate the output directory for a brand-new training run.

    If ``output_dir`` already exists it is removed recursively together with
    everything inside it, then an empty directory (including any missing
    parents) is created in its place.

    Args:
        output_dir: Path of the main output directory. Both ``Path`` objects
            and plain strings are accepted.
    """
    # Normalize first so string paths work the same as Path objects.
    output_dir = Path(output_dir)

    if output_dir.exists():
        logging.info(f"输出目录 '{output_dir}' 已存在。正在删除...")
        shutil.rmtree(output_dir)

    logging.info(f"正在 '{output_dir}' 创建新的空目录。")
    output_dir.mkdir(parents=True, exist_ok=True)