Files
Seg_Data_Server/Seg_All_In_One_YoloModel/yolo_train.py
2026-05-20 15:05:35 +08:00

158 lines
6.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import argparse, shutil, glob
import logging
from pathlib import Path
from ultralytics import YOLO
from datetime import datetime
import yolo_config as config
import gc, torch
# --- Logging setup: send INFO-and-above records to a console stream handler ---
logging.basicConfig(
    handlers=[logging.StreamHandler()],
    format="%(asctime)s [%(levelname)s] %(message)s",
    level=logging.INFO,
)
# 寻找最匹配的 last.pt文件 和 其所在文件夹
def find_latest_last_pt(model_key: str, outputs_dir: Path):
"""根据模型名查找最新的 last.pt 路径,并返回 (last_pt路径, 文件夹名)"""
pattern = str(outputs_dir / f"{model_key}_*" / "weights" / "last.pt")
candidates = glob.glob(pattern)
if not candidates:
return None, None
# 按修改时间排序,取最新的
latest_pt = max(candidates, key=lambda p: Path(p).stat().st_mtime)
# 提取文件夹名
save_folder_name = Path(latest_pt).parent.parent.name # weights 的上级就是训练目录
return latest_pt, save_folder_name
def train_model(model_key: str):
    """Train, or resume training of, a YOLO segmentation model.

    If ``find_latest_last_pt`` finds a ``weights/last.pt`` from an earlier run
    of ``model_key`` under ``config.OUTPUTS_DIR``, training resumes from that
    checkpoint in the same run folder. Otherwise a fresh run starts from
    ``MODEL_CONFIGS[model_key]['weights']`` in a new
    ``<model_key>_<YYYY-mm-dd_HH-MM-SS>`` folder.

    Args:
        model_key (str): Model key defined in ``yolo_config.MODEL_CONFIGS``.

    Returns:
        The run folder name (relative to ``config.OUTPUTS_DIR``) on success.
        ``None`` if the key is unknown or training raised an exception; the
        error is logged with its traceback.
    """
    if model_key not in config.MODEL_CONFIGS:
        logging.error(f"错误:模型 '{model_key}' 不在 yolo_config.py 中定义。")
        logging.info(f"可用模型: {list(config.MODEL_CONFIGS.keys())}")
        return None

    model_config = config.MODEL_CONFIGS[model_key]
    model_file = model_config['weights']
    model_batch_size = model_config['batch_size']
    model_image_size = model_config['image_size']
    logging.info(f"开始训练模型: {model_key} ({model_file})")

    try:
        # --- 1. Load the model, resuming the latest interrupted run if one exists ---
        last_pt, save_folder_name = find_latest_last_pt(model_key, config.OUTPUTS_DIR)
        resume = bool(last_pt) and Path(last_pt).exists()
        if resume:
            # NOTE(review): if that run already finished (e.g. a previous move to
            # the hard disk failed), Ultralytics may refuse to resume it; confirm
            # how finished checkpoints should be handled.
            model = YOLO(last_pt)
            logging.info(f"断点续训: {last_pt}")
            logging.info(f"续训目录名为: {save_folder_name}")
        else:
            # Per the original note, missing official weights are downloaded automatically.
            model = YOLO(model_file)
            logging.info(f"从头训练: {model_file}")
            # The timestamp keeps each fresh run in its own folder.
            timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
            save_folder_name = f"{model_key}_{timestamp}"

        # --- 2. Train ---
        # Both paths share one argument set and one train() call, so a resumed
        # run is never followed by a second, fresh training pass.
        train_kwargs = dict(
            data=str(config.DATASET_YAML_PATH),
            epochs=config.EPOCHS,
            imgsz=model_image_size,
            batch=model_batch_size,
            optimizer=config.OPTIMIZER,
            lr0=config.LEARNING_RATE,
            device=config.DEVICE,
            project=str(config.OUTPUTS_DIR),  # root directory for all runs
            name=save_folder_name,            # this run's folder under `project`
            workers=config.WORKERS,
            exist_ok=True,                    # write into the folder even if it already exists
            patience=config.PATIENCE,         # epochs without improvement before early stopping
            save_period=config.SAVE_PERIOD,   # save a checkpoint every N epochs
        )
        if resume:
            # NOTE(review): when resuming, Ultralytics restores most settings from
            # the checkpoint itself; confirm which overrides apply in the installed version.
            train_kwargs["resume"] = True

        logging.info(f"数据集配置文件: {config.DATASET_YAML_PATH}")
        logging.info(f"训练参数: Epochs={config.EPOCHS}, Batch Size={model_batch_size}, Img Size={model_image_size}")
        model.train(**train_kwargs)

        logging.info("模型训练完成。")
        logging.info(f"训练结果保存在: {config.OUTPUTS_DIR / save_folder_name}")
        return save_folder_name
    except Exception as e:
        # Top-level boundary for the script: log with traceback, and signal failure with None.
        logging.exception(f"训练过程中发生错误: {e}")
        return None
# --- 3. Copy the results to the hard disk, then delete the original folder ---
def move_results_to_hardisk(project_folder: str):
    """Move one training run from ``config.OUTPUTS_DIR`` to the hard disk.

    The run folder is copied to
    ``config.HARDISK_DIR / <name of OUTPUTS_DIR> / project_folder``. The source
    folder is deleted only after the copy succeeds, so a failure at any step
    leaves the original results in place. Errors are logged, not raised.

    Args:
        project_folder (str): Name of the run folder under ``config.OUTPUTS_DIR``.
    """
    source_dir = config.OUTPUTS_DIR / project_folder
    outputs_folder_name = config.OUTPUTS_DIR.name
    destination_dir = config.HARDISK_DIR / outputs_folder_name / project_folder

    if not source_dir.is_dir():
        logging.error(f"源目录不存在,无法移动: {source_dir}")
        return

    logging.info(f"准备将结果从 {source_dir} 移动到 {destination_dir}...")
    try:
        # Create the parent inside the try, so an unmounted or read-only disk is
        # logged like any other move failure instead of crashing the script.
        destination_dir.parent.mkdir(parents=True, exist_ok=True)
        # Step 1: copy. dirs_exist_ok lets a retry complete over a partial copy
        # left behind by an earlier failed attempt.
        shutil.copytree(source_dir, destination_dir, dirs_exist_ok=True)
        logging.info(f"成功复制结果到: {destination_dir}")
        # Step 2: delete the source only after the copy has succeeded.
        logging.info(f"正在删除原文件夹: {source_dir}")
        shutil.rmtree(source_dir)
        logging.info("成功删除原文件夹。")
        logging.info(f"任务完成,最终结果已保存至: {destination_dir}")
    except Exception as e:
        logging.error(f"移动文件夹时发生错误: {e}", exc_info=True)
        logging.warning(f"原始训练结果仍保留在: {source_dir}")
if __name__ == "__main__":
    # Command-line interface: choose one of the configured model keys.
    parser = argparse.ArgumentParser(description="训练 YOLO 分割模型。")
    parser.add_argument(
        "--model",
        type=str,
        required=True,
        choices=list(config.MODEL_CONFIGS.keys()),
        help=f"选择要训练的模型。可选: {list(config.MODEL_CONFIGS.keys())}",
    )
    args = parser.parse_args()

    # 1. Train (or resume training).
    save_folder_name = train_model(args.model)

    # 2. On success, move the run folder to the hard disk.
    if save_folder_name:
        move_results_to_hardisk(save_folder_name)
    else:
        logging.error("由于训练失败,未执行结果移动操作。")

    # 3. Release resources. Collect garbage first so that tensors kept alive only
    # by reference cycles are freed before the CUDA cache is emptied.
    gc.collect()              # free unreachable Python objects
    torch.cuda.empty_cache()  # release cached CUDA memory
    logging.info("已清理 CUDA 显存与系统内存。")

    # Report a training failure to the calling process with a non-zero exit status.
    if not save_folder_name:
        raise SystemExit(1)