158 lines
6.6 KiB
Python
158 lines
6.6 KiB
Python
import argparse, shutil, glob
|
||
import logging
|
||
from pathlib import Path
|
||
from ultralytics import YOLO
|
||
from datetime import datetime
|
||
import yolo_config as config
|
||
import gc, torch
|
||
|
||
# --- Logging setup ---
# Configure the root logger once at import time: records at INFO level and
# above are written to a console stream handler with a timestamp and level.
logging.basicConfig(
    handlers=[logging.StreamHandler()],
    format="%(asctime)s [%(levelname)s] %(message)s",
    level=logging.INFO,
)
|
||
|
||
# Locate the most recent resumable checkpoint (last.pt) and the run folder that owns it.
def find_latest_last_pt(model_key: str, outputs_dir: Path):
    """Find the newest ``last.pt`` checkpoint that belongs to ``model_key``.

    Only run directories named exactly ``<model_key>_<timestamp>`` are
    considered, where the timestamp uses the ``%Y-%m-%d_%H-%M-%S`` layout.
    A plain ``<model_key>_*`` wildcard would also match runs of other models
    that merely share the prefix (``yolo`` vs. ``yolo_seg_<timestamp>``) and
    could resume from the wrong model's checkpoint.

    Args:
        model_key: Model key used as the run-directory prefix.
        outputs_dir: Root directory that contains the run directories.

    Returns:
        A ``(last_pt_path, run_folder_name)`` tuple, with the path as a string,
        for the checkpoint with the most recent modification time, or
        ``(None, None)`` when no matching checkpoint exists.
    """
    run_prefix = f"{model_key}_"
    timestamp_format = "%Y-%m-%d_%H-%M-%S"

    root = Path(outputs_dir)
    if not root.is_dir():
        return None, None

    newest_checkpoint = None
    newest_mtime = None
    for run_dir in root.iterdir():
        run_name = run_dir.name
        if not run_name.startswith(run_prefix):
            continue
        # The remainder after the prefix must be a bare timestamp; anything
        # else means the directory belongs to a different (longer) model key.
        try:
            datetime.strptime(run_name[len(run_prefix):], timestamp_format)
        except ValueError:
            continue

        checkpoint = run_dir / "weights" / "last.pt"
        try:
            mtime = checkpoint.stat().st_mtime
        except OSError:
            # Missing checkpoint, or one removed while we were scanning.
            continue

        if newest_mtime is None or mtime > newest_mtime:
            newest_checkpoint, newest_mtime = checkpoint, mtime

    if newest_checkpoint is None:
        return None, None

    # "weights" lives directly inside the run directory, so two levels up
    # from last.pt is the run folder itself.
    return str(newest_checkpoint), newest_checkpoint.parent.parent.name
|
||
|
||
def train_model(model_key: str):
    """Train (or resume training of) the YOLO model identified by ``model_key``.

    If a previous run for this model left a ``last.pt`` checkpoint under
    ``config.OUTPUTS_DIR``, training is resumed from it inside the same run
    folder. Otherwise a fresh run is started from the configured weights in a
    new folder named ``<model_key>_<timestamp>``.

    Args:
        model_key (str): Key of the model in ``yolo_config.MODEL_CONFIGS``.

    Returns:
        The run folder name (relative to ``config.OUTPUTS_DIR``) on success,
        or ``None`` when the key is unknown or training raised an error.
    """
    if model_key not in config.MODEL_CONFIGS:
        logging.error(f"错误:模型 '{model_key}' 不在 yolo_config.py 中定义。")
        logging.info(f"可用模型: {list(config.MODEL_CONFIGS.keys())}")
        return None

    # Timestamp distinguishes separate fresh training runs of the same model.
    timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    project_folder = f"{model_key}_{timestamp}"

    model_config = config.MODEL_CONFIGS[model_key]
    model_file = model_config['weights']
    model_batch_size = model_config['batch_size']
    model_image_size = model_config['image_size']
    logging.info(f"开始训练模型: {model_key} ({model_file})")

    try:
        # --- 1. Load the model (resume from a checkpoint when one exists) ---
        last_pt, resume_folder_name = find_latest_last_pt(model_key, config.OUTPUTS_DIR)
        resuming = bool(last_pt) and Path(last_pt).exists()

        if resuming:
            model = YOLO(last_pt)
            save_folder_name = resume_folder_name
            logging.info(f"断点续训: {last_pt}")
            logging.info(f"续训目录名为: {save_folder_name}")
        else:
            model = YOLO(model_file)
            save_folder_name = project_folder
            logging.info(f"从头训练: {model_file}")
            logging.info(f"数据集配置文件: {config.DATASET_YAML_PATH}")
            logging.info(f"训练参数: Epochs={config.EPOCHS}, Batch Size={model_batch_size}, Img Size={model_image_size}")

        # --- 2. Train ---
        # Both modes share the same settings; only ``resume`` differs.
        train_kwargs = {
            "data": str(config.DATASET_YAML_PATH),
            "epochs": config.EPOCHS,
            "imgsz": model_image_size,
            "batch": model_batch_size,
            "optimizer": config.OPTIMIZER,
            "lr0": config.LEARNING_RATE,
            "device": config.DEVICE,
            "project": str(config.OUTPUTS_DIR),  # root directory for outputs
            "name": save_folder_name,  # run folder inside the project dir
            "workers": config.WORKERS,
            "exist_ok": True,  # reuse the run folder if it already exists
            "patience": config.PATIENCE,  # early-stopping patience (epochs)
            "save_period": config.SAVE_PERIOD,  # checkpoint every N epochs
        }
        if resuming:
            train_kwargs["resume"] = True

        model.train(**train_kwargs)

        logging.info("模型训练完成。")
        logging.info(f"训练结果保存在: {config.OUTPUTS_DIR / save_folder_name}")
        return save_folder_name

    except Exception as e:
        # Top-level boundary for a training run: log with the full traceback
        # so the failure can be diagnosed, and report failure to the caller.
        logging.exception(f"训练过程中发生错误: {e}")
        return None
|
||
|
||
# --- 3. Copy the results to the hard disk, then delete the original folder ---
def move_results_to_hardisk(project_folder: str):
    """Move a finished run folder from ``config.OUTPUTS_DIR`` to the hard disk.

    The folder is copied to
    ``config.HARDISK_DIR / <outputs dir name> / <project_folder>`` and the
    source is deleted only after the copy succeeded. Any failure is logged
    and the original results are left in place.

    Args:
        project_folder: Name of the run folder inside ``config.OUTPUTS_DIR``.
    """
    source_dir = config.OUTPUTS_DIR / project_folder
    destination_dir = config.HARDISK_DIR / config.OUTPUTS_DIR.name / project_folder
    logging.info(f"准备将结果从 {source_dir} 移动到 {destination_dir}...")

    try:
        # Make sure the parent directory on the target disk exists. Done
        # inside the try so an unavailable disk is reported, not fatal.
        destination_dir.parent.mkdir(parents=True, exist_ok=True)

        # Step 1: copy the folder. dirs_exist_ok lets a retry succeed when
        # the destination already exists (e.g. after an interrupted earlier
        # copy, or when a resumed run reuses its folder name).
        shutil.copytree(source_dir, destination_dir, dirs_exist_ok=True)
        logging.info(f"成功复制结果到: {destination_dir}")

        # Step 2: only after a successful copy, delete the original folder.
        logging.info(f"正在删除原文件夹: {source_dir}")
        shutil.rmtree(source_dir)
        logging.info("成功删除原文件夹。")
        logging.info(f"任务完成,最终结果已保存至: {destination_dir}")

    except Exception as e:
        logging.exception(f"移动文件夹时发生错误: {e}")
        logging.warning(f"原始训练结果仍保留在: {source_dir}")
|
||
|
||
if __name__ == "__main__":
    # Build the command-line interface; --model must be one of the configured keys.
    parser = argparse.ArgumentParser(description="训练 YOLO 分割模型。")
    parser.add_argument(
        "--model",
        type=str,
        required=True,
        choices=list(config.MODEL_CONFIGS.keys()),
        help=f"选择要训练的模型。可选: {list(config.MODEL_CONFIGS.keys())}"
    )
    args = parser.parse_args()

    # 1. Train (or resume) the selected model.
    save_folder_name = train_model(args.model)

    # 2. After a successful training run, move the results to the hard disk.
    if save_folder_name:
        move_results_to_hardisk(save_folder_name)
    else:
        logging.error("由于训练失败,未执行结果移动操作。")

    # 3. Release resources. Collect garbage first so that tensors kept alive
    #    only by unreachable objects are freed before the CUDA cache is emptied.
    gc.collect()
    torch.cuda.empty_cache()
    logging.info("已清理 CUDA 显存与系统内存。")

    # Report failure to the calling shell / scheduler via the exit status.
    if not save_folder_name:
        raise SystemExit(1)