import argparse
import gc
import glob
import logging
import shutil
from datetime import datetime
from pathlib import Path
from typing import Optional, Tuple

import torch
from ultralytics import YOLO

import yolo_config as config

# --- Logging setup ---
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s",
    handlers=[logging.StreamHandler()],
)

# Timestamp suffix appended to every run folder created by train_model.
_TIMESTAMP_FORMAT = "%Y-%m-%d_%H-%M-%S"
# Glob pattern matching exactly the strings produced by _TIMESTAMP_FORMAT.
_D2 = "[0-9][0-9]"
_TIMESTAMP_GLOB = f"{_D2}{_D2}-{_D2}-{_D2}_{_D2}-{_D2}-{_D2}"


# Locate the most recent last.pt checkpoint and the run folder that holds it.
def find_latest_last_pt(
    model_key: str, outputs_dir: Path
) -> Tuple[Optional[str], Optional[str]]:
    """Find the newest ``last.pt`` checkpoint for a model.

    Searches ``outputs_dir/<model_key>_<timestamp>/weights/last.pt`` and picks
    the file with the most recent modification time.

    Args:
        model_key: Model key used as the run-folder prefix.
        outputs_dir: Root directory that contains the run folders.

    Returns:
        ``(last_pt_path, run_folder_name)``, or ``(None, None)`` when no
        checkpoint exists for this model.
    """
    # Anchor on the timestamp suffix so that a key which is a prefix of another
    # key (e.g. "a" vs "a_b") never picks up the other model's runs, and escape
    # both user-controlled parts so glob metacharacters are taken literally.
    run_dir_pattern = f"{glob.escape(model_key)}_{_TIMESTAMP_GLOB}"
    pattern = str(
        Path(glob.escape(str(outputs_dir)), run_dir_pattern, "weights", "last.pt")
    )
    candidates = glob.glob(pattern)
    if not candidates:
        return None, None

    # Newest checkpoint by modification time.
    latest_pt = max(candidates, key=lambda p: Path(p).stat().st_mtime)
    # The run folder is the parent of the "weights" directory.
    save_folder_name = Path(latest_pt).parent.parent.name
    return latest_pt, save_folder_name


def train_model(model_key: str) -> Optional[str]:
    """Train the YOLO segmentation model identified by ``model_key``.

    If an unfinished run for this model is found under ``config.OUTPUTS_DIR``
    (a ``last.pt`` checkpoint), training resumes from it; otherwise a new run
    folder named ``<model_key>_<timestamp>`` is started.

    Args:
        model_key: Key of the model in ``yolo_config.MODEL_CONFIGS``.

    Returns:
        The name of the run folder inside ``config.OUTPUTS_DIR`` on success,
        or ``None`` if the key is unknown or training raised an exception.
    """
    if model_key not in config.MODEL_CONFIGS:
        logging.error(f"错误:模型 '{model_key}' 不在 yolo_config.py 中定义。")
        logging.info(f"可用模型: {list(config.MODEL_CONFIGS.keys())}")
        return None

    model_config = config.MODEL_CONFIGS[model_key]
    model_file = model_config['weights']
    model_batch_size = model_config['batch_size']
    model_image_size = model_config['image_size']

    logging.info(f"开始训练模型: {model_key} ({model_file})")

    try:
        # --- 1. Load the model (with checkpoint-resume support) ---
        last_pt, save_folder_name = find_latest_last_pt(model_key, config.OUTPUTS_DIR)
        resume = bool(last_pt) and Path(last_pt).exists()

        if resume:
            model = YOLO(last_pt)
            logging.info(f"断点续训: {last_pt}")
            logging.info(f"续训目录名为: {save_folder_name}")
        else:
            # Fresh run: the timestamp keeps separate runs apart.
            timestamp = datetime.now().strftime(_TIMESTAMP_FORMAT)
            save_folder_name = f"{model_key}_{timestamp}"
            model = YOLO(model_file)  # weights are downloaded if missing
            logging.info(f"从头训练: {model_file}")
            logging.info(f"数据集配置文件: {config.DATASET_YAML_PATH}")
            logging.info(f"训练参数: Epochs={config.EPOCHS}, Batch Size={model_batch_size}, Img Size={model_image_size}")

        # --- 2. Train ---
        # One argument set for both paths, so they cannot drift apart.
        train_kwargs = dict(
            data=str(config.DATASET_YAML_PATH),
            epochs=config.EPOCHS,
            imgsz=model_image_size,
            batch=model_batch_size,
            optimizer=config.OPTIMIZER,
            lr0=config.LEARNING_RATE,
            device=config.DEVICE,
            project=str(config.OUTPUTS_DIR),  # root directory for outputs
            name=save_folder_name,            # run folder name
            workers=config.WORKERS,
            exist_ok=True,                    # reuse the run folder if present
            patience=config.PATIENCE,         # early-stopping patience (epochs)
            save_period=config.SAVE_PERIOD,   # checkpoint every N epochs
        )
        if resume:
            # Key switch for continuing an interrupted run.
            # NOTE(review): Ultralytics is expected to restore most settings
            # from the checkpoint when resuming - confirm for the installed
            # version.
            train_kwargs["resume"] = True
        model.train(**train_kwargs)

        logging.info("模型训练完成。")
        logging.info(f"训练结果保存在: {config.OUTPUTS_DIR / save_folder_name}")
        return save_folder_name

    except Exception as e:
        # Boundary handler: log with traceback and report failure to the caller.
        logging.exception(f"训练过程中发生错误: {e}")
        return None
# --- 3. Copy the results to the hard disk, then delete the source folder ---
def move_results_to_hardisk(project_folder: str):
    """Move a finished run folder from ``config.OUTPUTS_DIR`` to the hard disk.

    The folder is copied to
    ``config.HARDISK_DIR / <outputs dir name> / project_folder`` and the source
    is deleted only after the copy succeeded. Failures are logged, not raised;
    on failure the source folder is left in place.

    Args:
        project_folder: Name of the run folder inside ``config.OUTPUTS_DIR``.
    """
    source_dir = config.OUTPUTS_DIR / project_folder
    outputs_folder_name = config.OUTPUTS_DIR.name
    destination_dir = config.HARDISK_DIR / outputs_folder_name / project_folder

    logging.info(f"准备将结果从 {source_dir} 移动到 {destination_dir}...")

    # Without this guard a missing source would be reported as "original
    # results are still kept", which is misleading.
    if not source_dir.is_dir():
        logging.error(f"源文件夹不存在,无法移动: {source_dir}")
        return

    try:
        # Inside the try so an unmounted or read-only target disk is reported
        # through the same handler instead of crashing the script.
        destination_dir.parent.mkdir(parents=True, exist_ok=True)

        # Step 1: copy. dirs_exist_ok lets a retry complete after an earlier
        # partial copy left the destination behind.
        shutil.copytree(source_dir, destination_dir, dirs_exist_ok=True)
        logging.info(f"成功复制结果到: {destination_dir}")

        # Step 2: delete the source only after the copy succeeded.
        logging.info(f"正在删除原文件夹: {source_dir}")
        shutil.rmtree(source_dir)
        logging.info("成功删除原文件夹。")
        logging.info(f"任务完成,最终结果已保存至: {destination_dir}")

    except Exception as e:
        logging.exception(f"移动文件夹时发生错误: {e}")
        logging.warning(f"原始训练结果仍保留在: {source_dir}")


def main():
    """Parse the command line, train the chosen model and archive the results."""
    parser = argparse.ArgumentParser(description="训练 YOLO 分割模型。")
    parser.add_argument(
        "--model",
        type=str,
        required=True,
        choices=list(config.MODEL_CONFIGS.keys()),
        help=f"选择要训练的模型。可选: {list(config.MODEL_CONFIGS.keys())}",
    )
    args = parser.parse_args()

    # 1. Train.
    save_folder_name = train_model(args.model)

    # 2. Move the results to the hard disk once training succeeded.
    if save_folder_name:
        move_results_to_hardisk(save_folder_name)
    else:
        logging.error("由于训练失败,未执行结果移动操作。")

    # 3. Release resources. Collect garbage first so that tensors kept alive
    # only by reference cycles are freed before the CUDA cache is emptied.
    gc.collect()
    torch.cuda.empty_cache()
    logging.info("已清理 CUDA 显存与系统内存。")


if __name__ == "__main__":
    main()