first commit
This commit is contained in:
312
Seg_All_In_One_SegModel/train.py
Normal file
312
Seg_All_In_One_SegModel/train.py
Normal file
@@ -0,0 +1,312 @@
|
||||
import logging, argparse, shutil
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Dict, Tuple
|
||||
import gc
|
||||
|
||||
import numpy as np
|
||||
import segmentation_models_pytorch as smp
|
||||
import torch
|
||||
import torch.optim as optim
|
||||
from torch.optim import lr_scheduler # 学习率调度器
|
||||
from torch.amp import GradScaler
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
# 本地应用/库的导入
|
||||
import config, os
|
||||
import utils
|
||||
from dataset import SegmentationDataset
|
||||
from loss import MultiClassLoss, UNetPlusPlusLoss
|
||||
|
||||
# --- 日志设置 ---
|
||||
# Configure the root logger once at import time: INFO and above, timestamped
# "[LEVEL] message" lines, delivered through a single stream (console) handler.
logging.basicConfig(
    format="%(asctime)s [%(levelname)s] %(message)s",
    handlers=[logging.StreamHandler()],  # console output only, no log file
    level=logging.INFO,
)
|
||||
|
||||
# --- 辅助函数:移动训练结果文件夹从源输出目录移动到配置的硬盘目录 ---
|
||||
def move_results_to_hardisk(project_folder: str):
    """Move one training-run folder from the outputs directory to the hard disk.

    The folder ``config.OUTPUTS_DIR / project_folder`` is copied to
    ``config.HARDISK_DIR / <name of OUTPUTS_DIR> / project_folder``; the
    original is deleted only after the copy has succeeded. The operation is
    best-effort: failures are logged (with traceback) instead of raised.

    Args:
        project_folder: Name of the run folder to move (relative to
            ``config.OUTPUTS_DIR``).
    """
    outputs_root = Path(config.OUTPUTS_DIR)
    source_dir = outputs_root / project_folder
    destination_dir = Path(config.HARDISK_DIR) / outputs_root.name / project_folder

    # Nothing to move: report that precisely instead of falling into the generic
    # handler below, which would claim the original results are still in place.
    if not source_dir.is_dir():
        logging.error(f"源文件夹不存在: {source_dir}")
        return

    logging.info(f"准备将结果从 {source_dir} 移动到 {destination_dir}...")

    try:
        # Creating the parent directory is part of the best-effort move: an
        # unavailable target disk must be logged, not crash the script.
        destination_dir.parent.mkdir(parents=True, exist_ok=True)

        # Step 1: copy the whole tree. copytree refuses to overwrite an existing
        # destination, so an earlier run's results are never clobbered.
        shutil.copytree(source_dir, destination_dir)
        logging.info(f"成功复制结果到: {destination_dir}")

        # Step 2: only after a successful copy, remove the original folder.
        logging.info(f"正在删除原文件夹: {source_dir}")
        shutil.rmtree(source_dir)
        logging.info("成功删除原文件夹。")
        logging.info(f"任务完成,最终结果已保存至: {destination_dir}")

    except Exception as e:
        # logging.exception keeps the traceback so the failing step is identifiable.
        logging.exception(f"移动文件夹时发生错误: {e}")
        logging.warning(f"原始训练结果仍保留在: {source_dir}")
|
||||
|
||||
# 1. 根据配置文件初始化模型、损失函数和优化器。
|
||||
def initialize_components(model_architecture: str, num_classes: int, seg_mode: str) -> Tuple:
    """Build the model, loss function, optimizer and gradient scaler for a run.

    Args:
        model_architecture: Key into ``config.ALL_MODEL_CONFIGS`` and, at the
            same time, the name of the model class looked up on
            ``segmentation_models_pytorch``.
        num_classes: Value passed to the model constructor as ``classes``.
        seg_mode: ``'multiclass'`` or ``'multilabel'``; selects the loss class.

    Returns:
        The 4-tuple ``(model, loss_fn, optimizer, scaler)``.

    Raises:
        KeyError: ``model_architecture`` has no entry in ``config.ALL_MODEL_CONFIGS``.
        AttributeError: ``segmentation_models_pytorch`` has no such model class.
        ValueError: ``seg_mode`` is neither ``'multiclass'`` nor ``'multilabel'``.
    """
    # --- Model ---
    logging.info(f"正在初始化模型: '{model_architecture}'...")

    # The two lookups are guarded separately so that each log message can only
    # be triggered by the failure it actually describes.
    try:
        model_params = config.ALL_MODEL_CONFIGS[model_architecture]
    except KeyError:
        logging.error(f"模型 '{model_architecture}' 的配置未在 config.py 的 ALL_MODEL_CONFIGS 中定义!")
        raise

    try:
        model_class = getattr(smp, model_architecture)
    except AttributeError:
        logging.error(f"模型 '{model_architecture}' 在 segmentation_models_pytorch 库中不存在!")
        raise

    # Start from a copy of the configured parameters (the config dict itself is
    # never mutated) and add the fixed `in_channels` and `classes` values.
    params = {**model_params, 'in_channels': 3, 'classes': num_classes}

    # Encoders whose name contains 'vit' are built with dynamic_img_size=True.
    # `or ''` tolerates a missing or None encoder_name.
    encoder_name = str(model_params.get('encoder_name') or '')
    if 'vit' in encoder_name.lower():
        params['dynamic_img_size'] = True
        logging.info(f"检测到 ViT 编码器 ('{encoder_name}')。自动设置 dynamic_img_size=True。")

    # Every configured key is forwarded to the constructor via ** unpacking.
    model = model_class(**params).to(config.DEVICE)

    # --- Loss function ---
    logging.info(f"为 '{seg_mode}' 模式设置损失函数。")
    if seg_mode == 'multiclass':
        loss_fn = MultiClassLoss(mode=seg_mode)
    elif seg_mode == 'multilabel':
        loss_fn = UNetPlusPlusLoss(mode=seg_mode)
    else:
        raise ValueError(f"无效的 SEG_MODE: '{seg_mode}'。必须是 'multiclass' 或 'multilabel'。")

    # --- Optimizer and gradient scaler ---
    optimizer = optim.Adam(model.parameters(), lr=config.LEARNING_RATE, weight_decay=config.WEIGHT_DECAY)

    # TODO: add a learning-rate scheduler. A CosineAnnealingLR (T_max=config.EPOCHS,
    # eta_min=1e-6) used to be instantiated here, but it was neither returned nor
    # stepped, so it never influenced training. To take effect it has to be handed
    # to the training loop and stepped once per epoch; returning it from this
    # function would change the 4-tuple that callers unpack.

    scaler = GradScaler('cuda')

    return model, loss_fn, optimizer, scaler
|
||||
|
||||
def _save_model_weights(model, destination) -> None:
    """Save the model's ``state_dict`` to *destination*, unwrapping ``DataParallel`` first."""
    if isinstance(model, torch.nn.DataParallel):
        model = model.module
    torch.save(model.state_dict(), destination)


def _probe_input_size(model_architecture: str, num_classes: int) -> Tuple[int, int]:
    """Return the ``(height, width)`` to train with, honouring fixed-size encoders."""
    logging.info("正在探测模型以确定输入尺寸要求...")

    # A throw-away model instance, built only so its encoder can be inspected.
    # It goes out of scope (and can be freed) when this helper returns.
    probe_model, _, _, _ = initialize_components(
        model_architecture, num_classes, config.SEG_MODE
    )

    target_height, target_width = config.IMAGE_HEIGHT, config.IMAGE_WIDTH

    try:
        if probe_model.encoder.is_fixed_input_size:
            # Indices 1 and 2 are used as height and width; presumably
            # input_size is laid out as (channels, height, width) - TODO confirm.
            required_size = probe_model.encoder.input_size
            target_height = required_size[1]
            target_width = required_size[2]
            logging.warning(
                f"模型 '{model_architecture}' 的编码器需要固定的输入尺寸 "
                f"({target_height}x{target_width})。"
                f"将忽略 config.py 中的尺寸设置。"
            )
        else:
            logging.info(f"模型 '{model_architecture}' 支持动态输入尺寸。将使用 config.py 中定义的尺寸 ({target_height}x{target_width})。")
    except AttributeError:
        # The model/encoder does not expose these attributes: keep the config size.
        logging.info("无法自动检测模型尺寸要求。将使用 config.py 中定义的尺寸。")

    return target_height, target_width


def main(model_architecture: str) -> str:
    """Run the full training pipeline for one segmentation model.

    Steps:
        1. Create a timestamped output directory for the run.
        2. Determine the input size (probing the encoder), then build the
           augmentations, datasets and DataLoaders for training and validation.
        3. Initialize model, loss function, optimizer and gradient scaler for
           the requested architecture.
        4. For every epoch: train, evaluate on the validation set, append the
           metrics to the CSV file, save the best model (lowest validation
           loss), apply early stopping, save a per-epoch checkpoint, prune the
           previous checkpoint, and refresh the training-curve plots.

    Args:
        model_architecture: Architecture name, forwarded to
            ``initialize_components`` and used as prefix of the run name.

    Returns:
        The name of this run's output folder (``run_name``).

    Raises:
        KeyError: the validation metrics contain no ``'val_loss'`` entry.
    """
    logging.info(f"使用设备: {config.DEVICE}")

    # --- 1.1. Output paths for this run ---
    timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    run_name = f"{model_architecture}_{timestamp}"
    output_dir = Path(config.OUTPUTS_DIR) / run_name
    metrics_csv_path = output_dir / config.METRICS_CSV_NAME
    best_model_path = output_dir / config.BEST_MODEL_SAVE_NAME

    utils.setup_directories(output_dir)

    # --- 1.2. Probe the model for its input-size requirement ---
    num_classes = len(config.CLASSES)
    target_height, target_width = _probe_input_size(model_architecture, num_classes)

    # --- 1.3. Data loading ---
    logging.info(f"正在设置数据增强和加载器,目标尺寸为 {target_height}x{target_width}...")
    train_transform = utils.get_training_augmentation(target_height, target_width)
    val_transform = utils.get_validation_augmentation(target_height, target_width)

    train_dataset = SegmentationDataset(
        image_dir=Path(config.TRAIN_IMAGE_DIR),
        mask_dir=Path(config.TRAIN_MASK_DIR),
        classes=config.CLASSES,
        class_rgb_values=config.CLASS_RGB_VALUES,
        augmentation=train_transform,
        seg_mode=config.SEG_MODE
    )
    val_dataset = SegmentationDataset(
        image_dir=Path(config.VAL_IMAGE_DIR),
        mask_dir=Path(config.VAL_MASK_DIR),
        classes=config.CLASSES,
        class_rgb_values=config.CLASS_RGB_VALUES,
        augmentation=val_transform,
        seg_mode=config.SEG_MODE
    )

    train_loader = DataLoader(train_dataset, batch_size=config.BATCH_SIZE, shuffle=True, pin_memory=config.PIN_MEMORY, num_workers=config.NUM_WORKERS)
    val_loader = DataLoader(val_dataset, batch_size=config.BATCH_SIZE, shuffle=False, pin_memory=config.PIN_MEMORY, num_workers=config.NUM_WORKERS)

    # --- 2. Model, loss function, optimizer, gradient scaler ---
    model, loss_fn, optimizer, scaler = initialize_components(model_architecture, num_classes, config.SEG_MODE)

    # --- Multi-GPU training ---
    # DataParallel only helps with more than one GPU; with a single GPU (or
    # none) the model is used unwrapped on config.DEVICE.
    gpu_count = torch.cuda.device_count()
    if gpu_count > 1:
        logging.info(f"正在使用 {gpu_count} 块 GPU 进行并行训练。")
        model = torch.nn.DataParallel(model)

    # --- 3. Training loop ---
    best_val_loss = float('inf')
    epochs_no_improve = 0

    for epoch in range(config.EPOCHS):
        logging.info(f"\n--- 第 {epoch+1}/{config.EPOCHS} 轮 ---")

        # --- 3.1. Train for one epoch ---
        train_loss = utils.train_fn(train_loader, model, optimizer, loss_fn, scaler, config.DEVICE, model_architecture)

        # --- 3.2. Evaluate on the validation set ---
        val_metrics_dict = utils.evaluate_fn(val_loader, model, loss_fn, config.DEVICE, num_classes, model_architecture, beta=config.FBETA_BETA)

        # --- 3.3. Metric logging ---
        log_data = {'epoch': epoch + 1, 'train_loss': train_loss}
        log_data.update(val_metrics_dict)
        utils.log_metrics_to_csv(log_data, metrics_csv_path)

        # val_loss drives both best-model selection and early stopping, so a
        # missing value is reported explicitly instead of failing later inside
        # the format string / comparison.
        val_loss = val_metrics_dict.get('val_loss')
        if val_loss is None:
            raise KeyError("验证指标中缺少 'val_loss'。")
        iou_score = val_metrics_dict.get('iou_score', 0)
        fbeta_score = val_metrics_dict.get('fbeta_score', 0)

        logging.info(f"训练集Loss: {train_loss:.4f} | 验证集Loss: {val_loss:.4f} | IoU: {iou_score:.4f} | F-beta: {fbeta_score:.4f}")

        # --- 3.4. Best model and early stopping ---
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            epochs_no_improve = 0
            _save_model_weights(model, best_model_path)
            logging.info(f"验证损失改善至 {best_val_loss:.4f}。正在保存最佳模型至 '{best_model_path}'。")
        else:
            epochs_no_improve += 1
            logging.info(f"验证损失已连续 {epochs_no_improve} 轮未改善。")

        # Early stopping is only active when the configured patience is an int.
        if isinstance(config.EARLY_STOPPING_PATIENCE, int) and epochs_no_improve >= config.EARLY_STOPPING_PATIENCE:
            logging.info(f"\n训练在 {epoch + 1} 轮后触发早停。")
            logging.warning(f"验证损失已连续 {config.EARLY_STOPPING_PATIENCE} 轮未改善。")
            # This epoch's metrics are already in the CSV; refresh the curves
            # before leaving so the final plots include them.
            utils.plot_training_progress(metrics_csv_path, output_dir)
            logging.info("所有训练曲线图已更新并保存。")
            break

        # --- 3.5. Per-epoch checkpoint and plots ---
        checkpoint_path = output_dir / f"epoch_{epoch+1}.pth"
        _save_model_weights(model, checkpoint_path)
        logging.info(f"第 {epoch+1} 轮的检查点已保存至 '{checkpoint_path}'。")

        # Prune the previous epoch's checkpoint unless its (1-based) number is a
        # multiple of config.CKPT_SAVE_INTERVAL. Nothing to prune in the first
        # iteration.
        if epoch > 0:
            last_epoch = epoch  # 1-based number of the previous epoch
            if last_epoch % config.CKPT_SAVE_INTERVAL != 0:
                old_ckpt = output_dir / f"epoch_{last_epoch}.pth"
                if old_ckpt.exists():
                    old_ckpt.unlink()

        utils.plot_training_progress(metrics_csv_path, output_dir)
        logging.info("所有训练曲线图已更新并保存。")

    return run_name
|
||||
|
||||
if __name__ == "__main__":
    # Command-line interface: the architecture to train is mandatory and must be
    # one of the keys defined in config.ALL_MODEL_CONFIGS.
    parser = argparse.ArgumentParser(description="训练图像分割模型。")
    parser.add_argument(
        "-a", "--architecture",
        type=str,
        choices=list(config.ALL_MODEL_CONFIGS),
        required=True,
        help="选择要训练的模型架构。"
    )
    args = parser.parse_args()

    # Stays None when main() raises, which selects the error branch below.
    project_folder = None
    try:
        project_folder = main(model_architecture=args.architecture)
    finally:
        # Release Python-side objects first, then PyTorch's cached CUDA memory.
        gc.collect()
        logging.info("已执行垃圾回收 (gc.collect())。")
        torch.cuda.empty_cache()
        logging.info("训练结束,已清理 CUDA 缓存。")

        # Kept inside `finally` so the failure message is actually emitted when
        # training raises; the exception itself still propagates afterwards.
        if project_folder:
            move_results_to_hardisk(project_folder)
        else:
            logging.error("由于训练失败或未生成项目文件夹,未执行结果移动操作。")
|
||||
Reference in New Issue
Block a user