312 lines
14 KiB
Python
312 lines
14 KiB
Python
import logging, argparse, shutil
|
|
from datetime import datetime
|
|
from pathlib import Path
|
|
from typing import Dict, Tuple
|
|
import gc
|
|
|
|
import numpy as np
|
|
import segmentation_models_pytorch as smp
|
|
import torch
|
|
import torch.optim as optim
|
|
from torch.optim import lr_scheduler # 学习率调度器
|
|
from torch.amp import GradScaler
|
|
from torch.utils.data import DataLoader
|
|
|
|
# 本地应用/库的导入
|
|
import config, os
|
|
import utils
|
|
from dataset import SegmentationDataset
|
|
from loss import MultiClassLoss, UNetPlusPlusLoss
|
|
|
|
# --- 日志设置 ---
|
|
logging.basicConfig(
|
|
level=logging.INFO,
|
|
format="%(asctime)s [%(levelname)s] %(message)s",
|
|
handlers=[
|
|
logging.StreamHandler() # 输出日志到控制台
|
|
]
|
|
)
|
|
|
|
# --- 辅助函数:移动训练结果文件夹从源输出目录移动到配置的硬盘目录 ---
|
|
def move_results_to_hardisk(project_folder: str):
|
|
"""
|
|
将指定的训练结果文件夹从源输出目录移动到配置的硬盘目录。
|
|
移动操作包括复制整个文件夹树,然后在复制成功后删除原始文件夹。
|
|
|
|
参数:
|
|
project_folder (str): 要移动的项目文件夹的名称。
|
|
该名称通常由模型架构和时间戳构成。
|
|
"""
|
|
source_dir = Path(config.OUTPUTS_DIR) / project_folder
|
|
# 确保目标硬盘目录存在
|
|
outputs_folder_name = Path(config.OUTPUTS_DIR).name
|
|
destination_dir = Path(config.HARDISK_DIR) / outputs_folder_name / project_folder
|
|
destination_dir.parent.mkdir(parents=True, exist_ok=True)
|
|
|
|
logging.info(f"准备将结果从 {source_dir} 移动到 {destination_dir}...")
|
|
|
|
try:
|
|
# 步骤 1: 复制文件夹
|
|
shutil.copytree(source_dir, destination_dir)
|
|
logging.info(f"成功复制结果到: {destination_dir}")
|
|
|
|
# 步骤 2: 复制成功后,删除原文件夹
|
|
logging.info(f"正在删除原文件夹: {source_dir}")
|
|
shutil.rmtree(source_dir)
|
|
logging.info("成功删除原文件夹。")
|
|
logging.info(f"任务完成,最终结果已保存至: {destination_dir}")
|
|
|
|
except Exception as e:
|
|
logging.error(f"移动文件夹时发生错误: {e}")
|
|
logging.warning(f"原始训练结果仍保留在: {source_dir}")
|
|
|
|
# 1. 根据配置文件初始化模型、损失函数和优化器。
|
|
def initialize_components(model_architecture:str, num_classes: int, seg_mode: str) -> Tuple:
|
|
"""
|
|
根据配置文件动态初始化模型、损失函数和优化器。
|
|
"""
|
|
# --- 初始化模型 ---
|
|
logging.info(f"正在初始化模型: '{model_architecture}'...")
|
|
|
|
# 使用传入的 model_architecture 作为 key 来获取参数
|
|
try:
|
|
model_params = config.ALL_MODEL_CONFIGS[model_architecture]
|
|
model_class = getattr(smp, model_architecture)
|
|
except KeyError:
|
|
logging.error(f"模型 '{model_architecture}' 的配置未在 config.py 的 ALL_MODEL_CONFIGS 中定义!")
|
|
raise
|
|
except AttributeError:
|
|
logging.error(f"模型 '{model_architecture}' 在 segmentation_models_pytorch 库中不存在!")
|
|
raise
|
|
|
|
# 2. 准备参数字典
|
|
# 首先复制 config 中的参数,然后添加固定的 `in_channels` 和 `classes`
|
|
params = model_params.copy()
|
|
params['in_channels'] = 3
|
|
params['classes'] = num_classes
|
|
|
|
# ======================== 【新增代码段开始】 ======================== #
|
|
# 自动检测 encoder_name 是否包含 'vit'
|
|
encoder_name = model_params.get('encoder_name', '')
|
|
if 'vit' in encoder_name.lower():
|
|
params['dynamic_img_size'] = True
|
|
logging.info(f"检测到 ViT 编码器 ('{encoder_name}')。自动设置 dynamic_img_size=True。")
|
|
# ======================== 【新增代码段结束】 ======================== #
|
|
|
|
# 3. 使用字典解包 (**) 将所有参数传递给模型构造函数
|
|
# 这种方式非常灵活,无论你在 config 中定义了多少参数,都能正确传递
|
|
model = model_class(**params).to(config.DEVICE)
|
|
|
|
# --- 设置损失函数 (保持不变) ---
|
|
logging.info(f"为 '{seg_mode}' 模式设置损失函数。")
|
|
if seg_mode == 'multiclass':
|
|
loss_fn = MultiClassLoss(mode=seg_mode)
|
|
elif seg_mode == 'multilabel':
|
|
loss_fn = UNetPlusPlusLoss(mode=seg_mode)
|
|
else:
|
|
raise ValueError(f"无效的 SEG_MODE: '{seg_mode}'。必须是 'multiclass' 或 'multilabel'。")
|
|
|
|
# --- 优化器和梯度缩放器 (保持不变) ---
|
|
optimizer = optim.Adam(model.parameters(), lr=config.LEARNING_RATE, weight_decay=config.WEIGHT_DECAY)
|
|
# --- TODO 添加学习率调度器 TODO ---
|
|
# T_max 是调度器周期的最大迭代次数,通常设置为总的 epoch 数量
|
|
scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=config.EPOCHS, eta_min=1e-6)
|
|
# -----------------------
|
|
scaler = GradScaler('cuda')
|
|
|
|
return model, loss_fn, optimizer, scaler
|
|
|
|
def main(model_architecture: str) -> None:
|
|
"""
|
|
编排图像分割模型的主要训练流程。
|
|
|
|
该函数执行以下步骤:
|
|
1. 设置输出目录。
|
|
2. 准备数据增强,并为训练集和验证集创建 DataLoader 实例。
|
|
3. 初始化 U-Net++ 模型、损失函数、优化器和梯度缩放器。
|
|
4. 运行主训练循环,其中包括:
|
|
- 训练一个 epoch。
|
|
- 在验证集上评估模型。
|
|
- 将各项指标记录到 CSV 文件。
|
|
- 基于验证损失实现早停机制。
|
|
- 保存性能最佳的模型和定期的断点。
|
|
- 绘制训练进度曲线图。
|
|
返回:
|
|
str: 本次训练运行的文件夹名称 (run_name)。
|
|
"""
|
|
logging.info(f"使用设备: {config.DEVICE}")
|
|
|
|
# --- 1.1. 使用 pathlib 进行现代化的路径管理 ---
|
|
timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
|
|
run_name = f"{model_architecture}_{timestamp}"
|
|
output_dir = Path(config.OUTPUTS_DIR) / run_name # <-- 使用 / 操作符
|
|
metrics_csv_path = output_dir / config.METRICS_CSV_NAME
|
|
best_model_path = output_dir / config.BEST_MODEL_SAVE_NAME
|
|
|
|
utils.setup_directories(output_dir)
|
|
|
|
# --- 1.2. 探测模型以确定最佳图像尺寸 ---
|
|
logging.info("正在探测模型以确定输入尺寸要求...") #
|
|
num_classes = len(config.CLASSES) #
|
|
|
|
# 创建一个临时模型实例,仅用于检查其属性
|
|
# 注意:这里的调用现在更简单,不需要传递 dynamic_img_size
|
|
probe_model, _, _, _ = initialize_components(
|
|
model_architecture, num_classes, config.SEG_MODE #
|
|
)
|
|
|
|
target_height, target_width = config.IMAGE_HEIGHT, config.IMAGE_WIDTH #
|
|
|
|
try:
|
|
# 检查编码器是否需要固定输入尺寸
|
|
if probe_model.encoder.is_fixed_input_size:
|
|
required_size = probe_model.encoder.input_size
|
|
target_height = required_size[1]
|
|
target_width = required_size[2]
|
|
logging.warning(
|
|
f"模型 '{model_architecture}' 的编码器需要固定的输入尺寸 "
|
|
f"({target_height}x{target_width})。"
|
|
f"将忽略 config.py 中的尺寸设置。"
|
|
)
|
|
else:
|
|
logging.info(f"模型 '{model_architecture}' 支持动态输入尺寸。将使用 config.py 中定义的尺寸 ({target_height}x{target_width})。")
|
|
except AttributeError:
|
|
logging.info("无法自动检测模型尺寸要求。将使用 config.py 中定义的尺寸。")
|
|
|
|
del probe_model # 删除临时模型,释放内存
|
|
|
|
# --- 1.3. 数据加载 ---
|
|
logging.info(f"正在设置数据增强和加载器,目标尺寸为 {target_height}x{target_width}...") #
|
|
# 使用探测到的或配置中指定的目标尺寸
|
|
train_transform = utils.get_training_augmentation(target_height, target_width) #
|
|
val_transform = utils.get_validation_augmentation(target_height, target_width) #
|
|
|
|
train_dataset = SegmentationDataset(
|
|
image_dir=Path(config.TRAIN_IMAGE_DIR),
|
|
mask_dir=Path(config.TRAIN_MASK_DIR),
|
|
classes=config.CLASSES,
|
|
class_rgb_values=config.CLASS_RGB_VALUES,
|
|
augmentation=train_transform,
|
|
seg_mode=config.SEG_MODE
|
|
)
|
|
val_dataset = SegmentationDataset(
|
|
# 已修正路径,使用配置文件中的验证数据目录
|
|
image_dir=Path(config.VAL_IMAGE_DIR),
|
|
mask_dir=Path(config.VAL_MASK_DIR),
|
|
classes=config.CLASSES,
|
|
class_rgb_values=config.CLASS_RGB_VALUES,
|
|
augmentation=val_transform,
|
|
seg_mode=config.SEG_MODE
|
|
)
|
|
|
|
train_loader = DataLoader(train_dataset, batch_size=config.BATCH_SIZE, shuffle=True, pin_memory=config.PIN_MEMORY, num_workers=config.NUM_WORKERS)
|
|
val_loader = DataLoader(val_dataset, batch_size=config.BATCH_SIZE, shuffle=False, pin_memory=config.PIN_MEMORY, num_workers=config.NUM_WORKERS)
|
|
|
|
# --- 2. 模型、损失函数、优化器、梯度缩放器 ---
|
|
num_classes = len(config.CLASSES)
|
|
model, loss_fn, optimizer, scaler = initialize_components(model_architecture, num_classes, config.SEG_MODE)
|
|
|
|
# --- 启用多 GPU 训练 ---
|
|
if torch.cuda.device_count() >= 1:
|
|
logging.info(f"正在使用 {torch.cuda.device_count()} 块 GPU 进行并行训练。")
|
|
model = torch.nn.DataParallel(model)
|
|
|
|
# --- 3. 训练循环 ---
|
|
best_val_loss = float('inf')
|
|
epochs_no_improve = 0
|
|
|
|
for epoch in range(config.EPOCHS):
|
|
logging.info(f"\n--- 第 {epoch+1}/{config.EPOCHS} 轮 ---")
|
|
# --- 3.1. 训练集训练数据 ---
|
|
train_loss = utils.train_fn(train_loader, model, optimizer, loss_fn, scaler, config.DEVICE, model_architecture)
|
|
# --- 3.2. 验证集测试数据【获得指标】 ---
|
|
val_metrics_dict = utils.evaluate_fn(val_loader, model, loss_fn, config.DEVICE, num_classes, model_architecture, beta=config.FBETA_BETA)
|
|
|
|
# --- 3.3. 日志记录 及 显示 ---
|
|
log_data = {'epoch': epoch + 1, 'train_loss': train_loss}
|
|
log_data.update(val_metrics_dict)
|
|
utils.log_metrics_to_csv(log_data, metrics_csv_path)
|
|
|
|
val_loss = val_metrics_dict.get('val_loss')
|
|
iou_score = val_metrics_dict.get('iou_score', 0)
|
|
fbeta_score = val_metrics_dict.get('fbeta_score', 0)
|
|
|
|
logging.info(f"训练集Loss: {train_loss:.4f} | 验证集Loss: {val_loss:.4f} | IoU: {iou_score:.4f} | F-beta: {fbeta_score:.4f}")
|
|
|
|
# --- 3.4. 早停与模型检查点 ---
|
|
if val_loss < best_val_loss:
|
|
best_val_loss = val_loss
|
|
epochs_no_improve = 0
|
|
if isinstance(model, torch.nn.DataParallel):
|
|
torch.save(model.module.state_dict(), best_model_path)
|
|
else:
|
|
torch.save(model.state_dict(), best_model_path)
|
|
logging.info(f"验证损失改善至 {best_val_loss:.4f}。正在保存最佳模型至 '{best_model_path}'。")
|
|
else:
|
|
epochs_no_improve += 1
|
|
logging.info(f"验证损失已连续 {epochs_no_improve} 轮未改善。")
|
|
|
|
if isinstance(config.EARLY_STOPPING_PATIENCE, int) and epochs_no_improve >= config.EARLY_STOPPING_PATIENCE:
|
|
logging.info(f"\n训练在 {epoch + 1} 轮后触发早停。")
|
|
logging.warning(f"验证损失已连续 {config.EARLY_STOPPING_PATIENCE} 轮未改善。")
|
|
break
|
|
|
|
# --- 3.5. 检查点及曲线图保存 ---
|
|
checkpoint_path = output_dir / f"epoch_{epoch+1}.pth"
|
|
if isinstance(model, torch.nn.DataParallel):
|
|
torch.save(model.module.state_dict(), checkpoint_path)
|
|
else:
|
|
torch.save(model.state_dict(), checkpoint_path)
|
|
logging.info(f"第 {epoch+1} 轮的检查点已保存至 '{checkpoint_path}'。")
|
|
|
|
# ---------- 清理旧检查点 ----------
|
|
# 只在 epoch > 0 时才开始删,避免第一轮就误删
|
|
if epoch > 0:
|
|
# 要尝试删除的 epoch 编号 = 上一轮
|
|
last_epoch = epoch
|
|
# 如果上一轮不是“10 的整数倍”就删掉
|
|
if last_epoch % config.CKPT_SAVE_INTERVAL != 0:
|
|
old_ckpt = output_dir / f"epoch_{last_epoch}.pth"
|
|
if old_ckpt.exists():
|
|
old_ckpt.unlink() # 文件删除
|
|
|
|
utils.plot_training_progress(metrics_csv_path, output_dir)
|
|
logging.info("所有训练曲线图已更新并保存。")
|
|
|
|
return run_name
|
|
|
|
if __name__ == "__main__":
|
|
# 创建一个参数解析器
|
|
parser = argparse.ArgumentParser(description="训练图像分割模型。")
|
|
|
|
# 添加 --architecture 参数
|
|
# - `choices` 会自动从你的 config 文件中获取所有可用的模型名称
|
|
# - `default` 设置了在不提供参数时的默认模型
|
|
parser.add_argument(
|
|
"-a", "--architecture",
|
|
type=str,
|
|
choices=list(config.ALL_MODEL_CONFIGS.keys()),
|
|
required=True,
|
|
help="选择要训练的模型架构。"
|
|
)
|
|
|
|
# 解析命令行传入的参数
|
|
args = parser.parse_args()
|
|
|
|
# 将解析出的架构名称传入 main 函数并执行
|
|
project_folder = None
|
|
try:
|
|
project_folder = main(model_architecture=args.architecture)
|
|
finally:
|
|
# 强制进行垃圾回收,释放 Python 对象占用的系统内存 (RAM)
|
|
gc.collect()
|
|
logging.info("已执行垃圾回收 (gc.collect())。")
|
|
# 清理 PyTorch 在 CUDA 上缓存的显存
|
|
torch.cuda.empty_cache()
|
|
logging.info("训练结束,已清理 CUDA 缓存。")
|
|
|
|
# 训练和清理完成后,移动结果到硬盘
|
|
if project_folder:
|
|
move_results_to_hardisk(project_folder)
|
|
else:
|
|
logging.error("由于训练失败或未生成项目文件夹,未执行结果移动操作。") |