import argparse, shutil import logging from ultralytics import YOLO from datetime import datetime import yolo_config as config import gc, torch # --- 日志设置 --- logging.basicConfig( level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s", handlers=[logging.StreamHandler()] ) def train_model(model_key: str): """ 根据给定的模型密钥训练YOLO分割模型。 Args: model_key (str): 在 yolo_config.MODEL_CONFIGS 中定义的模型密钥。 """ # 生成时间戳以区分不同训练运行 timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") project_folder = f"{model_key}_{timestamp}" if model_key not in config.MODEL_CONFIGS: logging.error(f"错误:模型 '{model_key}' 不在 yolo_config.py 中定义。") logging.info(f"可用模型: {list(config.MODEL_CONFIGS.keys())}") return model_config = config.MODEL_CONFIGS[model_key] model_file = model_config['weights'] logging.info(f"开始训练模型: {model_key} ({model_file})") try: # --- 1. 加载模型 --- model = YOLO(model_file) # 如果没有模型会自动下载 logging.info("模型加载成功。") # --- 2. 模型训练 --- logging.info(f"数据集配置文件: {config.DATASET_YAML_PATH}") logging.info(f"训练参数: Epochs={config.EPOCHS}, Batch Size={config.BATCH_SIZE}, Img Size={config.IMAGE_SIZE}") model.train( data=str(config.DATASET_YAML_PATH), epochs=config.EPOCHS, imgsz=config.IMAGE_SIZE, batch=config.BATCH_SIZE, optimizer=config.OPTIMIZER, lr0=config.LEARNING_RATE, device=config.DEVICE, project=str(config.OUTPUTS_DIR), # 指定输出的根目录 name=project_folder, # 指定本次训练的项目名 workers=config.WORKERS, exist_ok=True, # 如果项目已存在,则覆盖 patience=config.PATIENCE, # 提前停止训练的轮数 save_period=config.SAVE_PERIOD # 每隔多少轮保存一次模型 ) logging.info("模型训练完成。") logging.info(f"训练结果保存在: {config.OUTPUTS_DIR / project_folder}") return project_folder except Exception as e: logging.error(f"训练过程中发生错误: {e}") return None # --- 3. 复制结果到硬盘并删除原文件夹 (新增逻辑) --- def move_results_to_hardisk(project_folder: str): source_dir = config.OUTPUTS_DIR / project_folder # 确保目标硬盘目录存在 outputs_folder_name = config.OUTPUTS_DIR.name destination_dir = config.HARDISK_DIR / outputs_folder_name / project_folder destination_dir.parent.mkdir(parents=True, exist_ok=True) logging.info(f"准备将结果从 {source_dir} 移动到 {destination_dir}...") try: # 步骤 1: 复制文件夹 shutil.copytree(source_dir, destination_dir) logging.info(f"成功复制结果到: {destination_dir}") # 步骤 2: 复制成功后,删除原文件夹 logging.info(f"正在删除原文件夹: {source_dir}") shutil.rmtree(source_dir) logging.info("成功删除原文件夹。") logging.info(f"任务完成,最终结果已保存至: {destination_dir}") except Exception as e: logging.error(f"移动文件夹时发生错误: {e}") logging.warning(f"原始训练结果仍保留在: {source_dir}") if __name__ == "__main__": # 创建命令行参数解析器 parser = argparse.ArgumentParser(description="训练 YOLO 分割模型。") parser.add_argument( "--model", type=str, required=True, choices=list(config.MODEL_CONFIGS.keys()), help=f"选择要训练的模型。可选: {list(config.MODEL_CONFIGS.keys())}" ) args = parser.parse_args() # 1. 开始训练 project_folder = train_model(args.model) # 2. 训练完成后,移动结果到硬盘 if project_folder: move_results_to_hardisk(project_folder) else: logging.error("由于训练失败,未执行结果移动操作。") # 3. 释放资源 torch.cuda.empty_cache() # 释放 CUDA 显存 gc.collect() # 释放系统内存 logging.info("已清理 CUDA 显存与系统内存。")