import logging, argparse, shutil from datetime import datetime from pathlib import Path from typing import Dict, Tuple import gc import numpy as np import segmentation_models_pytorch as smp import torch import torch.optim as optim from torch.optim import lr_scheduler # 学习率调度器 from torch.amp import GradScaler from torch.utils.data import DataLoader # 本地应用/库的导入 import config, os import utils from dataset import SegmentationDataset from loss import MultiClassLoss, UNetPlusPlusLoss # --- 日志设置 --- logging.basicConfig( level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s", handlers=[ logging.StreamHandler() # 输出日志到控制台 ] ) # --- 辅助函数:移动训练结果文件夹从源输出目录移动到配置的硬盘目录 --- def move_results_to_hardisk(project_folder: str): """ 将指定的训练结果文件夹从源输出目录移动到配置的硬盘目录。 移动操作包括复制整个文件夹树,然后在复制成功后删除原始文件夹。 参数: project_folder (str): 要移动的项目文件夹的名称。 该名称通常由模型架构和时间戳构成。 """ source_dir = Path(config.OUTPUTS_DIR) / project_folder # 确保目标硬盘目录存在 outputs_folder_name = Path(config.OUTPUTS_DIR).name destination_dir = Path(config.HARDISK_DIR) / outputs_folder_name / project_folder destination_dir.parent.mkdir(parents=True, exist_ok=True) logging.info(f"准备将结果从 {source_dir} 移动到 {destination_dir}...") try: # 步骤 1: 复制文件夹 shutil.copytree(source_dir, destination_dir) logging.info(f"成功复制结果到: {destination_dir}") # 步骤 2: 复制成功后,删除原文件夹 logging.info(f"正在删除原文件夹: {source_dir}") shutil.rmtree(source_dir) logging.info("成功删除原文件夹。") logging.info(f"任务完成,最终结果已保存至: {destination_dir}") except Exception as e: logging.error(f"移动文件夹时发生错误: {e}") logging.warning(f"原始训练结果仍保留在: {source_dir}") # 1. 根据配置文件初始化模型、损失函数和优化器。 def initialize_components(model_architecture:str, num_classes: int, seg_mode: str) -> Tuple: """ 根据配置文件动态初始化模型、损失函数和优化器。 """ # --- 初始化模型 --- logging.info(f"正在初始化模型: '{model_architecture}'...") # 使用传入的 model_architecture 作为 key 来获取参数 try: model_params = config.ALL_MODEL_CONFIGS[model_architecture] model_class = getattr(smp, model_architecture) except KeyError: logging.error(f"模型 '{model_architecture}' 的配置未在 config.py 的 ALL_MODEL_CONFIGS 中定义!") raise except AttributeError: logging.error(f"模型 '{model_architecture}' 在 segmentation_models_pytorch 库中不存在!") raise # 2. 准备参数字典 # 首先复制 config 中的参数,然后添加固定的 `in_channels` 和 `classes` params = model_params.copy() params['in_channels'] = 3 params['classes'] = num_classes # ======================== 【新增代码段开始】 ======================== # # 自动检测 encoder_name 是否包含 'vit' encoder_name = model_params.get('encoder_name', '') if 'vit' in encoder_name.lower(): params['dynamic_img_size'] = True logging.info(f"检测到 ViT 编码器 ('{encoder_name}')。自动设置 dynamic_img_size=True。") # ======================== 【新增代码段结束】 ======================== # # 3. 使用字典解包 (**) 将所有参数传递给模型构造函数 # 这种方式非常灵活,无论你在 config 中定义了多少参数,都能正确传递 model = model_class(**params).to(config.DEVICE) # --- 设置损失函数 (保持不变) --- logging.info(f"为 '{seg_mode}' 模式设置损失函数。") if seg_mode == 'multiclass': loss_fn = MultiClassLoss(mode=seg_mode) elif seg_mode == 'multilabel': loss_fn = UNetPlusPlusLoss(mode=seg_mode) else: raise ValueError(f"无效的 SEG_MODE: '{seg_mode}'。必须是 'multiclass' 或 'multilabel'。") # --- 优化器和梯度缩放器 (保持不变) --- optimizer = optim.Adam(model.parameters(), lr=config.LEARNING_RATE, weight_decay=config.WEIGHT_DECAY) # --- TODO 添加学习率调度器 TODO --- # T_max 是调度器周期的最大迭代次数,通常设置为总的 epoch 数量 scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=config.EPOCHS, eta_min=1e-6) # ----------------------- scaler = GradScaler('cuda') return model, loss_fn, optimizer, scaler def main(model_architecture: str) -> None: """ 编排图像分割模型的主要训练流程。 该函数执行以下步骤: 1. 设置输出目录。 2. 准备数据增强,并为训练集和验证集创建 DataLoader 实例。 3. 初始化 U-Net++ 模型、损失函数、优化器和梯度缩放器。 4. 运行主训练循环,其中包括: - 训练一个 epoch。 - 在验证集上评估模型。 - 将各项指标记录到 CSV 文件。 - 基于验证损失实现早停机制。 - 保存性能最佳的模型和定期的断点。 - 绘制训练进度曲线图。 返回: str: 本次训练运行的文件夹名称 (run_name)。 """ logging.info(f"使用设备: {config.DEVICE}") # --- 1.1. 使用 pathlib 进行现代化的路径管理 --- timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") run_name = f"{model_architecture}_{timestamp}" output_dir = Path(config.OUTPUTS_DIR) / run_name # <-- 使用 / 操作符 metrics_csv_path = output_dir / config.METRICS_CSV_NAME best_model_path = output_dir / config.BEST_MODEL_SAVE_NAME utils.setup_directories(output_dir) # --- 1.2. 探测模型以确定最佳图像尺寸 --- logging.info("正在探测模型以确定输入尺寸要求...") # num_classes = len(config.CLASSES) # # 创建一个临时模型实例,仅用于检查其属性 # 注意:这里的调用现在更简单,不需要传递 dynamic_img_size probe_model, _, _, _ = initialize_components( model_architecture, num_classes, config.SEG_MODE # ) target_height, target_width = config.IMAGE_HEIGHT, config.IMAGE_WIDTH # try: # 检查编码器是否需要固定输入尺寸 if probe_model.encoder.is_fixed_input_size: required_size = probe_model.encoder.input_size target_height = required_size[1] target_width = required_size[2] logging.warning( f"模型 '{model_architecture}' 的编码器需要固定的输入尺寸 " f"({target_height}x{target_width})。" f"将忽略 config.py 中的尺寸设置。" ) else: logging.info(f"模型 '{model_architecture}' 支持动态输入尺寸。将使用 config.py 中定义的尺寸 ({target_height}x{target_width})。") except AttributeError: logging.info("无法自动检测模型尺寸要求。将使用 config.py 中定义的尺寸。") del probe_model # 删除临时模型,释放内存 # --- 1.3. 数据加载 --- logging.info(f"正在设置数据增强和加载器,目标尺寸为 {target_height}x{target_width}...") # # 使用探测到的或配置中指定的目标尺寸 train_transform = utils.get_training_augmentation(target_height, target_width) # val_transform = utils.get_validation_augmentation(target_height, target_width) # train_dataset = SegmentationDataset( image_dir=Path(config.TRAIN_IMAGE_DIR), mask_dir=Path(config.TRAIN_MASK_DIR), classes=config.CLASSES, class_rgb_values=config.CLASS_RGB_VALUES, augmentation=train_transform, seg_mode=config.SEG_MODE ) val_dataset = SegmentationDataset( # 已修正路径,使用配置文件中的验证数据目录 image_dir=Path(config.VAL_IMAGE_DIR), mask_dir=Path(config.VAL_MASK_DIR), classes=config.CLASSES, class_rgb_values=config.CLASS_RGB_VALUES, augmentation=val_transform, seg_mode=config.SEG_MODE ) train_loader = DataLoader(train_dataset, batch_size=config.BATCH_SIZE, shuffle=True, pin_memory=config.PIN_MEMORY, num_workers=config.NUM_WORKERS) val_loader = DataLoader(val_dataset, batch_size=config.BATCH_SIZE, shuffle=False, pin_memory=config.PIN_MEMORY, num_workers=config.NUM_WORKERS) # --- 2. 模型、损失函数、优化器、梯度缩放器 --- num_classes = len(config.CLASSES) model, loss_fn, optimizer, scaler = initialize_components(model_architecture, num_classes, config.SEG_MODE) # --- 启用多 GPU 训练 --- if torch.cuda.device_count() >= 1: logging.info(f"正在使用 {torch.cuda.device_count()} 块 GPU 进行并行训练。") model = torch.nn.DataParallel(model) # --- 3. 训练循环 --- best_val_loss = float('inf') epochs_no_improve = 0 for epoch in range(config.EPOCHS): logging.info(f"\n--- 第 {epoch+1}/{config.EPOCHS} 轮 ---") # --- 3.1. 训练集训练数据 --- train_loss = utils.train_fn(train_loader, model, optimizer, loss_fn, scaler, config.DEVICE, model_architecture) # --- 3.2. 验证集测试数据【获得指标】 --- val_metrics_dict = utils.evaluate_fn(val_loader, model, loss_fn, config.DEVICE, num_classes, model_architecture, beta=config.FBETA_BETA) # --- 3.3. 日志记录 及 显示 --- log_data = {'epoch': epoch + 1, 'train_loss': train_loss} log_data.update(val_metrics_dict) utils.log_metrics_to_csv(log_data, metrics_csv_path) val_loss = val_metrics_dict.get('val_loss') iou_score = val_metrics_dict.get('iou_score', 0) fbeta_score = val_metrics_dict.get('fbeta_score', 0) logging.info(f"训练集Loss: {train_loss:.4f} | 验证集Loss: {val_loss:.4f} | IoU: {iou_score:.4f} | F-beta: {fbeta_score:.4f}") # --- 3.4. 早停与模型检查点 --- if val_loss < best_val_loss: best_val_loss = val_loss epochs_no_improve = 0 if isinstance(model, torch.nn.DataParallel): torch.save(model.module.state_dict(), best_model_path) else: torch.save(model.state_dict(), best_model_path) logging.info(f"验证损失改善至 {best_val_loss:.4f}。正在保存最佳模型至 '{best_model_path}'。") else: epochs_no_improve += 1 logging.info(f"验证损失已连续 {epochs_no_improve} 轮未改善。") if isinstance(config.EARLY_STOPPING_PATIENCE, int) and epochs_no_improve >= config.EARLY_STOPPING_PATIENCE: logging.info(f"\n训练在 {epoch + 1} 轮后触发早停。") logging.warning(f"验证损失已连续 {config.EARLY_STOPPING_PATIENCE} 轮未改善。") break # --- 3.5. 检查点及曲线图保存 --- checkpoint_path = output_dir / f"epoch_{epoch+1}.pth" if isinstance(model, torch.nn.DataParallel): torch.save(model.module.state_dict(), checkpoint_path) else: torch.save(model.state_dict(), checkpoint_path) logging.info(f"第 {epoch+1} 轮的检查点已保存至 '{checkpoint_path}'。") # ---------- 清理旧检查点 ---------- # 只在 epoch > 0 时才开始删,避免第一轮就误删 if epoch > 0: # 要尝试删除的 epoch 编号 = 上一轮 last_epoch = epoch # 如果上一轮不是“10 的整数倍”就删掉 if last_epoch % config.CKPT_SAVE_INTERVAL != 0: old_ckpt = output_dir / f"epoch_{last_epoch}.pth" if old_ckpt.exists(): old_ckpt.unlink() # 文件删除 utils.plot_training_progress(metrics_csv_path, output_dir) logging.info("所有训练曲线图已更新并保存。") return run_name if __name__ == "__main__": # 创建一个参数解析器 parser = argparse.ArgumentParser(description="训练图像分割模型。") # 添加 --architecture 参数 # - `choices` 会自动从你的 config 文件中获取所有可用的模型名称 # - `default` 设置了在不提供参数时的默认模型 parser.add_argument( "-a", "--architecture", type=str, choices=list(config.ALL_MODEL_CONFIGS.keys()), required=True, help="选择要训练的模型架构。" ) # 解析命令行传入的参数 args = parser.parse_args() # 将解析出的架构名称传入 main 函数并执行 project_folder = None try: project_folder = main(model_architecture=args.architecture) finally: # 强制进行垃圾回收,释放 Python 对象占用的系统内存 (RAM) gc.collect() logging.info("已执行垃圾回收 (gc.collect())。") # 清理 PyTorch 在 CUDA 上缓存的显存 torch.cuda.empty_cache() logging.info("训练结束,已清理 CUDA 缓存。") # 训练和清理完成后,移动结果到硬盘 if project_folder: move_results_to_hardisk(project_folder) else: logging.error("由于训练失败或未生成项目文件夹,未执行结果移动操作。")