113 lines
4.4 KiB
Python
113 lines
4.4 KiB
Python
import argparse, shutil
|
|
import logging
|
|
from ultralytics import YOLO
|
|
from datetime import datetime
|
|
import yolo_config as config
|
|
import gc, torch
|
|
|
|
# --- Logging setup ---
# Configure the root logger once at import time: INFO and above, rendered as
# "<timestamp> [<LEVEL>] <message>", emitted through a single StreamHandler
# (which writes to stderr when no stream is given).
logging.basicConfig(
    format="%(asctime)s [%(levelname)s] %(message)s",
    level=logging.INFO,
    handlers=[logging.StreamHandler()],
)
|
|
|
|
def train_model(model_key: str):
    """Train the YOLO segmentation model registered under ``model_key``.

    Args:
        model_key (str): Key of the model in ``yolo_config.MODEL_CONFIGS``.

    Returns:
        The run folder name (``"<model_key>_<timestamp>"``) created under
        ``config.OUTPUTS_DIR`` when training succeeds, or ``None`` when the
        key is unknown or any exception is raised while loading/training.
    """
    # Validate the key before doing any other work.
    if model_key not in config.MODEL_CONFIGS:
        logging.error(f"错误:模型 '{model_key}' 不在 yolo_config.py 中定义。")
        logging.info(f"可用模型: {list(config.MODEL_CONFIGS.keys())}")
        return None

    # A timestamp suffix keeps separate training runs in separate folders.
    timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    project_folder = f"{model_key}_{timestamp}"

    model_config = config.MODEL_CONFIGS[model_key]
    model_file = model_config['weights']
    logging.info(f"开始训练模型: {model_key} ({model_file})")

    try:
        # --- 1. Load the model ---
        # Original author's note: the weights are downloaded automatically
        # when they are not available locally.
        model = YOLO(model_file)
        logging.info("模型加载成功。")

        # --- 2. Train the model ---
        logging.info(f"数据集配置文件: {config.DATASET_YAML_PATH}")
        logging.info(f"训练参数: Epochs={config.EPOCHS}, Batch Size={config.BATCH_SIZE}, Img Size={config.IMAGE_SIZE}")

        model.train(
            data=str(config.DATASET_YAML_PATH),
            epochs=config.EPOCHS,
            imgsz=config.IMAGE_SIZE,
            batch=config.BATCH_SIZE,
            optimizer=config.OPTIMIZER,
            lr0=config.LEARNING_RATE,
            device=config.DEVICE,
            project=str(config.OUTPUTS_DIR),  # root directory for all outputs
            name=project_folder,  # folder name of this particular run
            workers=config.WORKERS,
            exist_ok=True,  # reuse the run folder if it already exists
            patience=config.PATIENCE,  # early-stopping patience, in epochs
            save_period=config.SAVE_PERIOD,  # checkpoint interval, in epochs
        )

        logging.info("模型训练完成。")
        logging.info(f"训练结果保存在: {config.OUTPUTS_DIR / project_folder}")
        return project_folder

    except Exception as e:
        # logging.exception keeps the same message but also records the
        # traceback, which is needed to diagnose failures deep inside training.
        logging.exception(f"训练过程中发生错误: {e}")
        return None
|
|
|
|
# --- 3. Copy the results to the hard disk, then delete the original folder ---
def move_results_to_hardisk(project_folder: str):
    """Move one training run from ``config.OUTPUTS_DIR`` to ``config.HARDISK_DIR``.

    The run folder is copied to
    ``config.HARDISK_DIR / <config.OUTPUTS_DIR.name> / project_folder`` and the
    source folder is deleted only after the copy has finished without error.
    Any failure is logged (with traceback) instead of being raised.

    Args:
        project_folder (str): Name of the run folder inside ``config.OUTPUTS_DIR``.
    """
    source_dir = config.OUTPUTS_DIR / project_folder
    # Mirror the outputs folder name on the target disk.
    outputs_folder_name = config.OUTPUTS_DIR.name
    destination_dir = config.HARDISK_DIR / outputs_folder_name / project_folder

    logging.info(f"准备将结果从 {source_dir} 移动到 {destination_dir}...")
    try:
        # Creating the parent directory is inside the try block so that an
        # unavailable or read-only target disk is reported like any other
        # move failure instead of escaping as an unhandled exception.
        destination_dir.parent.mkdir(parents=True, exist_ok=True)

        # Step 1: copy the folder.
        shutil.copytree(source_dir, destination_dir)
        logging.info(f"成功复制结果到: {destination_dir}")

        # Step 2: only after a successful copy, delete the original folder.
        logging.info(f"正在删除原文件夹: {source_dir}")
        shutil.rmtree(source_dir)
        logging.info("成功删除原文件夹。")
        logging.info(f"任务完成,最终结果已保存至: {destination_dir}")

    except Exception as e:
        # Keep the traceback: it shows which step (mkdir/copy/delete) failed.
        logging.exception(f"移动文件夹时发生错误: {e}")
        logging.warning(f"原始训练结果仍保留在: {source_dir}")
|
|
|
|
if __name__ == "__main__":
    # Build the command-line interface; the model key must be one of the
    # keys defined in yolo_config.MODEL_CONFIGS.
    model_choices = list(config.MODEL_CONFIGS.keys())
    parser = argparse.ArgumentParser(description="训练 YOLO 分割模型。")
    parser.add_argument(
        "--model",
        type=str,
        required=True,
        choices=model_choices,
        help=f"选择要训练的模型。可选: {model_choices}",
    )
    args = parser.parse_args()

    project_folder = None
    try:
        # 1. Train.
        project_folder = train_model(args.model)
        # 2. After a successful run, move the results to the hard disk.
        if project_folder:
            move_results_to_hardisk(project_folder)
        else:
            logging.error("由于训练失败,未执行结果移动操作。")
    finally:
        # 3. Release resources even if something above raised.
        # Collect Python garbage first so that unreachable tensors are
        # destroyed; only then can empty_cache() hand their cached blocks
        # back to the CUDA driver.
        gc.collect()
        torch.cuda.empty_cache()
        logging.info("已清理 CUDA 显存与系统内存。")

    # Report failure to the calling shell/scheduler with a non-zero status.
    if not project_folder:
        raise SystemExit(1)