Files
Seg_Data_Server/Seg_All_In_One_SegModel/train.py
2026-05-20 15:05:35 +08:00

312 lines
14 KiB
Python

import logging, argparse, shutil
from datetime import datetime
from pathlib import Path
from typing import Dict, Tuple
import gc
import numpy as np
import segmentation_models_pytorch as smp
import torch
import torch.optim as optim
from torch.optim import lr_scheduler # 学习率调度器
from torch.amp import GradScaler
from torch.utils.data import DataLoader
# 本地应用/库的导入
import config, os
import utils
from dataset import SegmentationDataset
from loss import MultiClassLoss, UNetPlusPlusLoss
# --- 日志设置 ---
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s [%(levelname)s] %(message)s",
handlers=[
logging.StreamHandler() # 输出日志到控制台
]
)
# --- 辅助函数:移动训练结果文件夹从源输出目录移动到配置的硬盘目录 ---
def move_results_to_hardisk(project_folder: str):
"""
将指定的训练结果文件夹从源输出目录移动到配置的硬盘目录。
移动操作包括复制整个文件夹树,然后在复制成功后删除原始文件夹。
参数:
project_folder (str): 要移动的项目文件夹的名称。
该名称通常由模型架构和时间戳构成。
"""
source_dir = Path(config.OUTPUTS_DIR) / project_folder
# 确保目标硬盘目录存在
outputs_folder_name = Path(config.OUTPUTS_DIR).name
destination_dir = Path(config.HARDISK_DIR) / outputs_folder_name / project_folder
destination_dir.parent.mkdir(parents=True, exist_ok=True)
logging.info(f"准备将结果从 {source_dir} 移动到 {destination_dir}...")
try:
# 步骤 1: 复制文件夹
shutil.copytree(source_dir, destination_dir)
logging.info(f"成功复制结果到: {destination_dir}")
# 步骤 2: 复制成功后,删除原文件夹
logging.info(f"正在删除原文件夹: {source_dir}")
shutil.rmtree(source_dir)
logging.info("成功删除原文件夹。")
logging.info(f"任务完成,最终结果已保存至: {destination_dir}")
except Exception as e:
logging.error(f"移动文件夹时发生错误: {e}")
logging.warning(f"原始训练结果仍保留在: {source_dir}")
# 1. 根据配置文件初始化模型、损失函数和优化器。
def initialize_components(model_architecture:str, num_classes: int, seg_mode: str) -> Tuple:
"""
根据配置文件动态初始化模型、损失函数和优化器。
"""
# --- 初始化模型 ---
logging.info(f"正在初始化模型: '{model_architecture}'...")
# 使用传入的 model_architecture 作为 key 来获取参数
try:
model_params = config.ALL_MODEL_CONFIGS[model_architecture]
model_class = getattr(smp, model_architecture)
except KeyError:
logging.error(f"模型 '{model_architecture}' 的配置未在 config.py 的 ALL_MODEL_CONFIGS 中定义!")
raise
except AttributeError:
logging.error(f"模型 '{model_architecture}' 在 segmentation_models_pytorch 库中不存在!")
raise
# 2. 准备参数字典
# 首先复制 config 中的参数,然后添加固定的 `in_channels` 和 `classes`
params = model_params.copy()
params['in_channels'] = 3
params['classes'] = num_classes
# ======================== 【新增代码段开始】 ======================== #
# 自动检测 encoder_name 是否包含 'vit'
encoder_name = model_params.get('encoder_name', '')
if 'vit' in encoder_name.lower():
params['dynamic_img_size'] = True
logging.info(f"检测到 ViT 编码器 ('{encoder_name}')。自动设置 dynamic_img_size=True。")
# ======================== 【新增代码段结束】 ======================== #
# 3. 使用字典解包 (**) 将所有参数传递给模型构造函数
# 这种方式非常灵活,无论你在 config 中定义了多少参数,都能正确传递
model = model_class(**params).to(config.DEVICE)
# --- 设置损失函数 (保持不变) ---
logging.info(f"'{seg_mode}' 模式设置损失函数。")
if seg_mode == 'multiclass':
loss_fn = MultiClassLoss(mode=seg_mode)
elif seg_mode == 'multilabel':
loss_fn = UNetPlusPlusLoss(mode=seg_mode)
else:
raise ValueError(f"无效的 SEG_MODE: '{seg_mode}'。必须是 'multiclass''multilabel'")
# --- 优化器和梯度缩放器 (保持不变) ---
optimizer = optim.Adam(model.parameters(), lr=config.LEARNING_RATE, weight_decay=config.WEIGHT_DECAY)
# --- TODO 添加学习率调度器 TODO ---
# T_max 是调度器周期的最大迭代次数,通常设置为总的 epoch 数量
scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=config.EPOCHS, eta_min=1e-6)
# -----------------------
scaler = GradScaler('cuda')
return model, loss_fn, optimizer, scaler
def main(model_architecture: str) -> None:
"""
编排图像分割模型的主要训练流程。
该函数执行以下步骤:
1. 设置输出目录。
2. 准备数据增强,并为训练集和验证集创建 DataLoader 实例。
3. 初始化 U-Net++ 模型、损失函数、优化器和梯度缩放器。
4. 运行主训练循环,其中包括:
- 训练一个 epoch。
- 在验证集上评估模型。
- 将各项指标记录到 CSV 文件。
- 基于验证损失实现早停机制。
- 保存性能最佳的模型和定期的断点。
- 绘制训练进度曲线图。
返回:
str: 本次训练运行的文件夹名称 (run_name)。
"""
logging.info(f"使用设备: {config.DEVICE}")
# --- 1.1. 使用 pathlib 进行现代化的路径管理 ---
timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
run_name = f"{model_architecture}_{timestamp}"
output_dir = Path(config.OUTPUTS_DIR) / run_name # <-- 使用 / 操作符
metrics_csv_path = output_dir / config.METRICS_CSV_NAME
best_model_path = output_dir / config.BEST_MODEL_SAVE_NAME
utils.setup_directories(output_dir)
# --- 1.2. 探测模型以确定最佳图像尺寸 ---
logging.info("正在探测模型以确定输入尺寸要求...") #
num_classes = len(config.CLASSES) #
# 创建一个临时模型实例,仅用于检查其属性
# 注意:这里的调用现在更简单,不需要传递 dynamic_img_size
probe_model, _, _, _ = initialize_components(
model_architecture, num_classes, config.SEG_MODE #
)
target_height, target_width = config.IMAGE_HEIGHT, config.IMAGE_WIDTH #
try:
# 检查编码器是否需要固定输入尺寸
if probe_model.encoder.is_fixed_input_size:
required_size = probe_model.encoder.input_size
target_height = required_size[1]
target_width = required_size[2]
logging.warning(
f"模型 '{model_architecture}' 的编码器需要固定的输入尺寸 "
f"({target_height}x{target_width})。"
f"将忽略 config.py 中的尺寸设置。"
)
else:
logging.info(f"模型 '{model_architecture}' 支持动态输入尺寸。将使用 config.py 中定义的尺寸 ({target_height}x{target_width})。")
except AttributeError:
logging.info("无法自动检测模型尺寸要求。将使用 config.py 中定义的尺寸。")
del probe_model # 删除临时模型,释放内存
# --- 1.3. 数据加载 ---
logging.info(f"正在设置数据增强和加载器,目标尺寸为 {target_height}x{target_width}...") #
# 使用探测到的或配置中指定的目标尺寸
train_transform = utils.get_training_augmentation(target_height, target_width) #
val_transform = utils.get_validation_augmentation(target_height, target_width) #
train_dataset = SegmentationDataset(
image_dir=Path(config.TRAIN_IMAGE_DIR),
mask_dir=Path(config.TRAIN_MASK_DIR),
classes=config.CLASSES,
class_rgb_values=config.CLASS_RGB_VALUES,
augmentation=train_transform,
seg_mode=config.SEG_MODE
)
val_dataset = SegmentationDataset(
# 已修正路径,使用配置文件中的验证数据目录
image_dir=Path(config.VAL_IMAGE_DIR),
mask_dir=Path(config.VAL_MASK_DIR),
classes=config.CLASSES,
class_rgb_values=config.CLASS_RGB_VALUES,
augmentation=val_transform,
seg_mode=config.SEG_MODE
)
train_loader = DataLoader(train_dataset, batch_size=config.BATCH_SIZE, shuffle=True, pin_memory=config.PIN_MEMORY, num_workers=config.NUM_WORKERS)
val_loader = DataLoader(val_dataset, batch_size=config.BATCH_SIZE, shuffle=False, pin_memory=config.PIN_MEMORY, num_workers=config.NUM_WORKERS)
# --- 2. 模型、损失函数、优化器、梯度缩放器 ---
num_classes = len(config.CLASSES)
model, loss_fn, optimizer, scaler = initialize_components(model_architecture, num_classes, config.SEG_MODE)
# --- 启用多 GPU 训练 ---
if torch.cuda.device_count() >= 1:
logging.info(f"正在使用 {torch.cuda.device_count()} 块 GPU 进行并行训练。")
model = torch.nn.DataParallel(model)
# --- 3. 训练循环 ---
best_val_loss = float('inf')
epochs_no_improve = 0
for epoch in range(config.EPOCHS):
logging.info(f"\n--- 第 {epoch+1}/{config.EPOCHS} 轮 ---")
# --- 3.1. 训练集训练数据 ---
train_loss = utils.train_fn(train_loader, model, optimizer, loss_fn, scaler, config.DEVICE, model_architecture)
# --- 3.2. 验证集测试数据【获得指标】 ---
val_metrics_dict = utils.evaluate_fn(val_loader, model, loss_fn, config.DEVICE, num_classes, model_architecture, beta=config.FBETA_BETA)
# --- 3.3. 日志记录 及 显示 ---
log_data = {'epoch': epoch + 1, 'train_loss': train_loss}
log_data.update(val_metrics_dict)
utils.log_metrics_to_csv(log_data, metrics_csv_path)
val_loss = val_metrics_dict.get('val_loss')
iou_score = val_metrics_dict.get('iou_score', 0)
fbeta_score = val_metrics_dict.get('fbeta_score', 0)
logging.info(f"训练集Loss: {train_loss:.4f} | 验证集Loss: {val_loss:.4f} | IoU: {iou_score:.4f} | F-beta: {fbeta_score:.4f}")
# --- 3.4. 早停与模型检查点 ---
if val_loss < best_val_loss:
best_val_loss = val_loss
epochs_no_improve = 0
if isinstance(model, torch.nn.DataParallel):
torch.save(model.module.state_dict(), best_model_path)
else:
torch.save(model.state_dict(), best_model_path)
logging.info(f"验证损失改善至 {best_val_loss:.4f}。正在保存最佳模型至 '{best_model_path}'")
else:
epochs_no_improve += 1
logging.info(f"验证损失已连续 {epochs_no_improve} 轮未改善。")
if isinstance(config.EARLY_STOPPING_PATIENCE, int) and epochs_no_improve >= config.EARLY_STOPPING_PATIENCE:
logging.info(f"\n训练在 {epoch + 1} 轮后触发早停。")
logging.warning(f"验证损失已连续 {config.EARLY_STOPPING_PATIENCE} 轮未改善。")
break
# --- 3.5. 检查点及曲线图保存 ---
checkpoint_path = output_dir / f"epoch_{epoch+1}.pth"
if isinstance(model, torch.nn.DataParallel):
torch.save(model.module.state_dict(), checkpoint_path)
else:
torch.save(model.state_dict(), checkpoint_path)
logging.info(f"{epoch+1} 轮的检查点已保存至 '{checkpoint_path}'")
# ---------- 清理旧检查点 ----------
# 只在 epoch > 0 时才开始删,避免第一轮就误删
if epoch > 0:
# 要尝试删除的 epoch 编号 = 上一轮
last_epoch = epoch
# 如果上一轮不是“10 的整数倍”就删掉
if last_epoch % config.CKPT_SAVE_INTERVAL != 0:
old_ckpt = output_dir / f"epoch_{last_epoch}.pth"
if old_ckpt.exists():
old_ckpt.unlink() # 文件删除
utils.plot_training_progress(metrics_csv_path, output_dir)
logging.info("所有训练曲线图已更新并保存。")
return run_name
if __name__ == "__main__":
# 创建一个参数解析器
parser = argparse.ArgumentParser(description="训练图像分割模型。")
# 添加 --architecture 参数
# - `choices` 会自动从你的 config 文件中获取所有可用的模型名称
# - `default` 设置了在不提供参数时的默认模型
parser.add_argument(
"-a", "--architecture",
type=str,
choices=list(config.ALL_MODEL_CONFIGS.keys()),
required=True,
help="选择要训练的模型架构。"
)
# 解析命令行传入的参数
args = parser.parse_args()
# 将解析出的架构名称传入 main 函数并执行
project_folder = None
try:
project_folder = main(model_architecture=args.architecture)
finally:
# 强制进行垃圾回收,释放 Python 对象占用的系统内存 (RAM)
gc.collect()
logging.info("已执行垃圾回收 (gc.collect())。")
# 清理 PyTorch 在 CUDA 上缓存的显存
torch.cuda.empty_cache()
logging.info("训练结束,已清理 CUDA 缓存。")
# 训练和清理完成后,移动结果到硬盘
if project_folder:
move_results_to_hardisk(project_folder)
else:
logging.error("由于训练失败或未生成项目文件夹,未执行结果移动操作。")