first commit

This commit is contained in:
admin
2026-05-20 15:05:35 +08:00
commit ac09b26253
2048 changed files with 189478 additions and 0 deletions

View File

@@ -0,0 +1,361 @@
import os, logging, shutil
import torch
import numpy as np
import pandas as pd
from tqdm import tqdm
from torch.amp import autocast
from typing import Dict, Union, List
import albumentations as A
from albumentations.pytorch import ToTensorV2
import segmentation_models_pytorch as smp
import matplotlib.pyplot as plt
from pathlib import Path
import config
# ------------------- 绘图函数 (全新升级) ------------------- #
def log_metrics_to_csv(metrics_dict: Dict[str, Union[int, float]], csv_path: str) -> None:
"""
使用 pandas 将一个指标字典记录到 CSV 文件中。
(此函数保持不变)
"""
df_new_row = pd.DataFrame([metrics_dict])
if not os.path.isfile(csv_path):
df_new_row.to_csv(csv_path, index=False)
else:
df_new_row.to_csv(csv_path, mode='a', header=False, index=False)
def _plot_metric_group(ax, df: pd.DataFrame, group_metrics: List[str], title: str):
"""
内部辅助函数,用于在给定的坐标轴上绘制一组指标。
(此函数保持不变)
"""
epochs_range = df['epoch']
for name in group_metrics:
if f"val_{name}" in df.columns:
metric_col = f"val_{name}"
label = name.replace("_", " ").title()
ax.plot(epochs_range, df[metric_col], 'o-', label=label)
elif f"{name}" in df.columns:
metric_col = f"{name}"
label = name.replace("_", " ").title()
ax.plot(epochs_range, df[metric_col], 'o-', label=label)
else:
print("没有搜索到:", f"val_{name}")
ax.set_title(title)
ax.set_xlabel('Epoch')
ax.set_ylabel('Score')
ax.legend(loc='best', fontsize='small')
ax.grid(True)
def plot_training_progress(csv_path: str, output_dir: str):
"""
从 metrics.csv 文件读取训练历史数据,并生成所有分组和总览图表。
(此函数已修正错误)
"""
os.makedirs(output_dir, exist_ok=True)
try:
df_metrics = pd.read_csv(csv_path)
except FileNotFoundError:
print(f"警告:在路径 {csv_path} 未找到 metrics.csv 文件。跳过绘图。")
return
if len(df_metrics) < 2:
print("警告:数据点不足 (少于2个),无法生成趋势图。")
return
epochs_range = df_metrics['epoch']
train_loss = df_metrics['train_loss']
val_loss = df_metrics['val_loss']
core_metrics = ["iou_score", "fbeta_score", "accuracy"]
pr_metrics = ["precision", "recall", "specificity", "negative_predictive_value"]
error_metrics = ["false_positive_rate", "false_negative_rate", "false_discovery_rate", "false_omission_rate"]
likelihood_metrics = ["positive_likelihood_ratio", "negative_likelihood_ratio"]
# --- 1. 单独绘制并保存每个指标组的图表 ---
# (这部分逻辑没有变化)
fig_loss, ax_loss = plt.subplots(figsize=(10, 6))
ax_loss.plot(epochs_range, train_loss, 'o-', label='Training Loss')
ax_loss.plot(epochs_range, val_loss, 'o-', label='Validation Loss')
ax_loss.set_title('Loss Curves')
ax_loss.set_xlabel('Epoch')
ax_loss.set_ylabel('Loss')
ax_loss.legend()
ax_loss.grid(True)
plt.tight_layout()
plt.savefig(os.path.join(output_dir, "loss_curves.png"))
plt.close(fig_loss)
metric_groups = {
"core_performance": core_metrics,
"precision_recall": pr_metrics,
"error_rates": error_metrics,
"likelihood_ratios": likelihood_metrics
}
for group_name, group_list in metric_groups.items():
fig_group, ax_group = plt.subplots(figsize=(10, 7))
_plot_metric_group(ax_group, df_metrics, group_list, f'{group_name.replace("_", " ").title()} Metrics')
plt.tight_layout()
plt.savefig(os.path.join(output_dir, f"metrics_{group_name}.png"))
plt.close(fig_group)
# --- 2. 绘制并保存 3x2 综合概览图 ---
fig_overview, axes = plt.subplots(3, 2, figsize=(20, 22))
fig_overview.suptitle('Training Progress Overview', fontsize=20)
ax_list = axes.ravel()
# 图 1: 损失曲线 (*** 已修正 ***)
# 使用 ax 而不是 ax
ax_list[0].plot(epochs_range, train_loss, 'o-', label='Training Loss')
ax_list[0].plot(epochs_range, val_loss, 'o-', label='Validation Loss')
ax_list[0].set_title('Loss Curves')
ax_list[0].set_xlabel('Epoch')
ax_list[0].set_ylabel('Loss')
ax_list[0].legend()
ax_list[0].grid(True)
# 图 2: 核心性能指标
# 这里使用 ax 是正确的
_plot_metric_group(ax_list[1], df_metrics, core_metrics, 'Core Performance Metrics')
# 图 3: 精确率/召回率族
# 这里使用 ax 是正确的
_plot_metric_group(ax_list[2], df_metrics, pr_metrics, 'Precision-Recall Family')
# 图 4: 错误率分析
# 这里使用 ax 是正确的
_plot_metric_group(ax_list[3], df_metrics, error_metrics, 'Error Rate Analysis')
# 图 5: 似然比
# 这里使用 ax 是正确的
_plot_metric_group(ax_list[4], df_metrics, likelihood_metrics, 'Likelihood Ratios')
# 隐藏最后一个未使用的子图 (*** 已修正 ***)
# 使用 ax 而不是 ax
ax_list[5].axis('off')
plt.tight_layout(rect=[0, 0, 1, 0.96])
plt.savefig(os.path.join(output_dir, "training_progress_overview.png"))
plt.close(fig_overview)
print(f"所有图表已更新并保存至目录: {output_dir}")
# ------------------- 训练与评估函数 ('vit'模型不适用amp) ------------------- #
def train_fn(loader, model, optimizer, loss_fn, scaler, device, model_architecture: str):
"""一个 epoch 的训练逻辑。"""
loop = tqdm(loader, desc="Training")
total_loss = 0.0
model.train()
# 检查是否为ViT模型决定是否启用AMP
encoder_name = config.ALL_MODEL_CONFIGS[model_architecture].get('encoder_name', '')
use_amp = 'vit' not in encoder_name.lower()
if not use_amp and loop.n == 0: # 只在第一个batch打印一次警告
logging.warning(f"检测到 ViT 编码器 ('{encoder_name}')。本次训练将禁用自动混合精度 (AMP) 以确保稳定性。")
for batch_idx, (data, targets) in enumerate(loop):
data = data.to(device)
targets = targets.to(device)
# 根据 use_amp 标志执行不同代码路径
if use_amp:
# 【AMP 路径】适用于 CNN 等稳定模型
with autocast(device_type='cuda'):
predictions = model(data)
loss = loss_fn(predictions, targets)
optimizer.zero_grad()
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
else:
# 【Float32 路径】适用于 ViT/DPT 等模型
predictions = model(data)
loss = loss_fn(predictions, targets)
optimizer.zero_grad()
loss.backward()
optimizer.step()
total_loss += loss.item()
loop.set_postfix(loss=loss.item())
return total_loss / len(loader)
# --- 【修改】 ---
def evaluate_fn(loader, model, loss_fn, device, num_classes, model_architecture: str, beta=1.0):
"""在验证集上评估模型性能并计算包括原始TP/TN/FP/FN在内的多种指标。"""
model.eval()
loop = tqdm(loader, desc="Validation")
total_loss = 0
all_tp, all_fp, all_fn, all_tn = [], [], [], []
# 检查是否为ViT模型决定是否启用AMP
encoder_name = config.ALL_MODEL_CONFIGS[model_architecture].get('encoder_name', '')
use_amp = 'vit' not in encoder_name.lower()
with torch.no_grad():
for batch_idx, (data, targets) in enumerate(loop):
data = data.to(device)
targets = targets.to(device)
# 根据 use_amp 标志执行不同代码路径
if use_amp:
# 【AMP 路径】
with autocast(device_type='cuda'):
predictions = model(data)
if config.SEG_MODE == 'multilabel':
loss = loss_fn(predictions, targets.float())
else:
loss = loss_fn(predictions, targets.long())
else:
# 【Float32 路径】
predictions = model(data)
if config.SEG_MODE == 'multilabel':
loss = loss_fn(predictions, targets.float())
else:
loss = loss_fn(predictions, targets.long())
total_loss += loss.item()
# --- 【核心修正】根据 SEG_MODE 条件处理 targets ---
# 1. 获取模型预测的类别索引
preds_indices = torch.argmax(predictions, dim=1)
# 2. 根据模式确定真实标签的类别索引
if config.SEG_MODE == 'multilabel':
# 在 multilabel 模式下, targets 是 one-hot, 需要 argmax
targets_indices = torch.argmax(targets, dim=1)
elif config.SEG_MODE == 'multiclass':
# 在 multiclass 模式下, targets 已经是类别索引, 无需处理
targets_indices = targets
# 现在 preds_indices 和 targets_indices 都有正确的形状 [N, H, W]
tp, fp, fn, tn = smp.metrics.get_stats(
preds_indices,
targets_indices.long(), # Ensure it's long type for get_stats
mode=config.SEG_MODE, # Metrics are always calculated in multiclass fashion after argmax
num_classes=num_classes
)
# get_stats in smp 0.3.3 might return per-image stats, let's sum over batch dimension if needed.
# Assuming get_stats returns (B-size, C,) tensor per image in batch, we stack and sum. -> (B-num, B-size, C,)
all_tp.append(tp)
all_fp.append(fp)
all_fn.append(fn)
all_tn.append(tn)
# Sum up stats from all batches (Batch-num, Batch-size, C,) -> (C,)
tp = torch.cat(all_tp, dim=0).sum(dim=0) # (N, B, C) -> (C,)
fp = torch.cat(all_fp, dim=0).sum(dim=0) # (N, B, C) -> (C,)
fn = torch.cat(all_fn, dim=0).sum(dim=0) # (N, B, C) -> (C,)
tn = torch.cat(all_tn, dim=0).sum(dim=0) # (N, B, C) -> (C,)
# ======================== 【新增代码段开始】 ======================== #
# 根据 config.py 中的设置,筛选出用于计算宏观指标的类别
# 1. 获取需要忽略的类别的索引
ignore_indices = [config.CLASSES.index(cls) for cls in config.EVALUATION_CLASSES_TO_IGNORE if cls in config.CLASSES]
# 2. 创建一个布尔掩码,标记哪些类别需要被保留
# 默认保留所有类别
keep_mask = torch.ones(num_classes, dtype=torch.bool, device=tp.device)
if ignore_indices:
# 将需要忽略的类别的标记设置为 False
keep_mask[ignore_indices] = False
logging.info(f"评估时将忽略类别: {config.EVALUATION_CLASSES_TO_IGNORE}")
# 3. 使用掩码过滤统计数据
tp_filtered = tp[keep_mask]
fp_filtered = fp[keep_mask]
fn_filtered = fn[keep_mask]
tn_filtered = tn[keep_mask]
# ======================== 【新增代码段结束】 ======================== #
# ======================== 【修改代码段开始】 ======================== #
# 使用过滤后的 tp_filtered, fp_filtered 等张量来计算所有宏观指标
metrics = {
"val_loss": total_loss / len(loader),
"iou_score": smp.metrics.iou_score(tp_filtered, fp_filtered, fn_filtered, tn_filtered, reduction='micro').item(),
"f1_score": smp.metrics.f1_score(tp_filtered, fp_filtered, fn_filtered, tn_filtered, reduction='micro').item(),
"fbeta_score": smp.metrics.fbeta_score(tp_filtered, fp_filtered, fn_filtered, tn_filtered, beta=beta, reduction='micro').item(),
"accuracy": smp.metrics.accuracy(tp_filtered, fp_filtered, fn_filtered, tn_filtered, reduction='micro').item(),
"recall": smp.metrics.recall(tp_filtered, fp_filtered, fn_filtered, tn_filtered, reduction='micro').item(),
"precision": smp.metrics.precision(tp_filtered, fp_filtered, fn_filtered, tn_filtered, reduction='micro').item(),
# ... 对所有使用 smp.metrics 的宏观指标进行同样修改 ...
"specificity": smp.metrics.specificity(tp_filtered, fp_filtered, fn_filtered, tn_filtered, reduction='micro').item(),
"negative_predictive_value": smp.metrics.negative_predictive_value(tp_filtered, fp_filtered, fn_filtered, tn_filtered, reduction='micro').item(),
"false_negative_rate": smp.metrics.false_negative_rate(tp_filtered, fp_filtered, fn_filtered, tn_filtered, reduction='micro').item(),
"false_positive_rate": smp.metrics.false_positive_rate(tp_filtered, fp_filtered, fn_filtered, tn_filtered, reduction='micro').item(),
"false_discovery_rate": smp.metrics.false_discovery_rate(tp_filtered, fp_filtered, fn_filtered, tn_filtered, reduction='micro').item(),
"false_omission_rate": smp.metrics.false_omission_rate(tp_filtered, fp_filtered, fn_filtered, tn_filtered, reduction='micro').item(),
"positive_likelihood_ratio": smp.metrics.positive_likelihood_ratio(tp_filtered, fp_filtered, fn_filtered, tn_filtered, reduction='micro').item(),
"negative_likelihood_ratio": smp.metrics.negative_likelihood_ratio(tp_filtered, fp_filtered, fn_filtered, tn_filtered, reduction='micro').item(),
}
# ======================== 【修改代码段结束】 ======================== #
cpa_per_class = smp.metrics.recall(tp, fp, fn, tn, reduction='none').cpu().numpy().flatten()
if len(cpa_per_class) > 1:
metrics["mpa"] = np.mean(cpa_per_class[1:]).item()
for i, cpa in enumerate(cpa_per_class):
metrics[f"cpa_class_{i}"] = cpa.item()
else:
metrics["mpa"] = cpa_per_class[0].item()
metrics[f"cpa_class_0"] = cpa_per_class[0].item()
for i in range(num_classes):
metrics[f'tp_class_{i}'] = tp[i].item()
metrics[f'tn_class_{i}'] = tn[i].item()
metrics[f'fp_class_{i}'] = fp[i].item()
metrics[f'fn_class_{i}'] = fn[i].item()
model.train()
return metrics
# ------------------- 数据增强 (保持不变) ------------------- #
def get_training_augmentation(height, width):
return A.Compose([
A.Resize(height, width),
A.HorizontalFlip(p=0.5),
A.VerticalFlip(p=0.5),
A.RandomRotate90(p=0.5),
A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
ToTensorV2(),
])
def get_validation_augmentation(height, width):
return A.Compose([
A.Resize(height, width),
A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
ToTensorV2(),
])
# TOOL - 获取用于预测的图像预处理流程
def get_preprocessing_transform(height: int, width: int) -> A.Compose:
"""获取用于预测的图像预处理流程。"""
return A.Compose([
A.Resize(height, width),
A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
ToTensorV2(),
])
# Tool - 为一个全新的训练任务清理并重新创建输出目录。
def setup_directories(output_dir: Path) -> None:
"""
为一个全新的训练任务清理并重新创建输出目录。
Args:
output_dir (Path): 指向主输出目录的路径对象。
"""
if output_dir.exists():
logging.info(f"输出目录 '{output_dir}' 已存在。正在删除...")
shutil.rmtree(output_dir)
logging.info(f"正在 '{output_dir}' 创建新的空目录。")
output_dir.mkdir(parents=True, exist_ok=True)