first commit

This commit is contained in:
admin
2026-05-20 15:05:35 +08:00
commit ac09b26253
2048 changed files with 189478 additions and 0 deletions

View File

@@ -0,0 +1,181 @@
# -*- coding: utf-8 -*-
import cv2
import numpy as np
from ultralytics import YOLO
import random
import os
from tqdm import tqdm # 引入进度条库
# --- Configuration ---
# Path to the segmentation model weights.
# Alternative weights: os.path.join('YOLOv9e-seg', 'weights', 'best.pt')
MODEL_PATH = os.path.join('YOLO11l-seg', 'best.pt')
# Path of the input video file.
VIDEO_IN_PATH = 'LC_Video_1.mp4'
# Path of the annotated output video file.
VIDEO_OUT_PATH = 'output_a_segmented_fps_controlled.mp4'
# --- Inference-rate control ---
# Number of frames per second on which the model is actually run; e.g. 5 means
# inference runs on roughly 5 frames for every second of video.
# Set to None to run inference on every frame.
PROCESS_FPS = None
# Inference parameters
CONFIDENCE_THRESHOLD = 0.3  # detections below this confidence are ignored
MASK_ALPHA = 0.4  # mask opacity (0.0 fully transparent, 1.0 fully opaque)
# --- 主处理逻辑 ---
def _draw_detections(frame, result, class_names, class_colors):
    """Draw the masks, boxes and labels of one YOLO result onto ``frame`` in place."""
    # Masks are painted on a copy so they can be alpha-blended back afterwards;
    # boxes and labels are drawn directly on the frame.
    overlay = frame.copy()
    if result.masks is not None:
        # Boxes and masks are paired positionally, one pair per detected object.
        for box, mask in zip(result.boxes, result.masks):
            class_name = class_names[int(box.cls)]
            color = class_colors[class_name]
            # mask.xy holds one point array per segment; draw each on its own
            # and skip empty ones, which leave nothing to fill.
            for polygon in mask.xy:
                if len(polygon) == 0:
                    continue
                cv2.fillPoly(overlay, [np.asarray(polygon, dtype=np.int32)], color)
            x1, y1, x2, y2 = map(int, box.xyxy[0])
            cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)
            # Label text: class name plus confidence, on a filled background.
            label = f"{class_name}: {float(box.conf):.2f}"
            (text_width, text_height), baseline = cv2.getTextSize(
                label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1
            )
            cv2.rectangle(
                frame, (x1, y1 - text_height - baseline), (x1 + text_width, y1), color, -1
            )
            cv2.putText(
                frame, label, (x1, y1 - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1
            )
    # frame = overlay * alpha + frame * (1 - alpha), written back into frame.
    cv2.addWeighted(overlay, MASK_ALPHA, frame, 1 - MASK_ALPHA, 0, frame)


def main():
    """
    Run the complete video segmentation pipeline.

    Loads the YOLO segmentation model from MODEL_PATH, reads VIDEO_IN_PATH frame
    by frame, draws a translucent mask, bounding box and label for every
    detection, and writes the annotated frames to VIDEO_OUT_PATH. When
    PROCESS_FPS is set, inference runs only on every N-th frame and the most
    recent detections are redrawn on the frames in between.

    Returns early (after printing an error) if the model, the input video or
    the output writer cannot be opened.
    """
    # 1. Load the segmentation model.
    print(f"正在加载模型: {MODEL_PATH}...")
    try:
        model = YOLO(MODEL_PATH)
    except Exception as e:
        print(f"错误: 无法加载模型。请检查路径是否正确以及 Ultralytics 是否已正确安装。")
        print(f"详细错误: {e}")
        return
    # Mapping of class id -> class name known to the model.
    class_names = model.names
    print(f"模型已加载,可识别 {len(class_names)} 个类别。")

    # 2. Open the input video and create the output writer.
    cap = cv2.VideoCapture(VIDEO_IN_PATH)
    if not cap.isOpened():
        print(f"错误: 无法打开视频文件: {VIDEO_IN_PATH}")
        return
    out = None
    try:
        fps = cap.get(cv2.CAP_PROP_FPS)
        frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        # Reported frame count; only used for the progress bar because the
        # value stored in the container is not guaranteed to be exact.
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        print(f"输入视频属性: {frame_width}x{frame_height} @ {fps:.2f} FPS, 总共 {total_frames}")

        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out = cv2.VideoWriter(VIDEO_OUT_PATH, fourcc, fps, (frame_width, frame_height))
        if not out.isOpened():
            print(f"错误: 无法创建输出视频文件: {VIDEO_OUT_PATH}")
            return

        # 3. One colour per class; the fixed seed keeps colours identical
        # across frames and across runs.
        random.seed(42)
        class_colors = {
            name: [random.randint(0, 255) for _ in range(3)]
            for name in class_names.values()
        }

        # Number of frames between two inference runs (1 = every frame).
        frame_counter = 0
        frame_skip_interval = 1
        if PROCESS_FPS is not None and PROCESS_FPS > 0:
            # Clamp to 1 so the modulo below never divides by zero.
            frame_skip_interval = max(1, round(fps / PROCESS_FPS))
            print(f"处理帧率设置为 {PROCESS_FPS} FPS. 将每 {frame_skip_interval} 帧运行一次模型推理。")
        else:
            print("将处理所有帧。")

        # 4. Process frames until the reader runs dry.
        print("开始逐帧处理视频... (在预览窗口激活时按 'q' 键可随时退出)")
        results = None  # latest inference output, reused on skipped frames
        progress_total = total_frames if total_frames > 0 else None
        with tqdm(total=progress_total, desc="处理进度") as progress:
            while True:
                success, frame = cap.read()
                if not success:
                    if frame_counter < total_frames:
                        # Fewer frames than the container announced.
                        print("\n视频读取完毕或发生错误。")
                    break
                frame_counter += 1

                # 5. Run inference only on every frame_skip_interval-th frame.
                if frame_counter % frame_skip_interval == 0:
                    results = model.predict(
                        source=frame, conf=CONFIDENCE_THRESHOLD, verbose=False
                    )

                # 6./7. Draw the latest detections (possibly from an earlier frame).
                if results is not None:
                    # predict() returns a list; a single image gives one entry.
                    _draw_detections(frame, results[0], class_names, class_colors)

                # 8. Append the frame to the output video.
                out.write(frame)
                progress.update(1)

                # 9. Optional live preview (disabled):
                # cv2.imshow('Real-time Segmentation Preview', frame)
                # if cv2.waitKey(1) & 0xFF == ord('q'):
                #     print("\n用户中断处理。")
                #     break
    finally:
        # 10. Always release the capture and the writer, even on errors.
        cap.release()
        if out is not None:
            out.release()
        # cv2.destroyAllWindows()
    print(f"处理完成!")
    print(f"输出视频已保存至: {VIDEO_OUT_PATH}")


if __name__ == "__main__":
    main()

View File

@@ -0,0 +1,193 @@
# -*- coding: utf-8 -*-
import cv2
import numpy as np
from ultralytics import YOLO
import random
import os
from tqdm import tqdm
import multiprocessing as mp
import time
# --- Configuration ---
# Model weights, input/output video paths and drawing parameters.
MODEL_PATH = os.path.join('YOLOv9e-seg', 'weights', 'best.pt')
VIDEO_IN_PATH = 'LC_Video_1.mp4'
VIDEO_OUT_PATH = 'output_parallel_segmented.mp4'
CONFIDENCE_THRESHOLD = 0.3  # passed to model.predict() as `conf`
MASK_ALPHA = 0.4  # weight of the mask overlay when blending it into the frame
# --- Parallel Processing Configuration ---
# Number of worker processes: one fewer than the CPU count, but never fewer
# than 1 and never more than 5. Each worker loads its own model instance, so
# adjust the cap to the memory available on your system.
NUM_WORKERS = min(5, max(1, mp.cpu_count() - 1))
def process_frame_worker(task_queue, results_queue, model_path, confidence_threshold):
    """
    Consume frames from ``task_queue`` and publish predictions to ``results_queue``.

    The worker builds its own model from ``model_path``; if that fails it logs
    the error and exits without touching either queue. Tasks are
    ``(frame_index, frame)`` tuples and a ``None`` task ends the loop. Every
    task produces exactly one ``(frame_index, result)`` item, where ``result``
    is ``None`` when prediction raised.
    """
    try:
        model = YOLO(model_path)
    except Exception as e:
        print(f"[Worker PID: {os.getpid()}] Error loading model: {e}")
        return

    # iter() with a sentinel keeps pulling tasks until the None marker arrives.
    for frame_index, frame in iter(task_queue.get, None):
        try:
            # One frame in, so only the first entry of the result list matters.
            prediction = model.predict(
                source=frame, conf=confidence_threshold, verbose=False
            )[0]
            results_queue.put((frame_index, prediction))
        except Exception as e:
            print(f"[Worker PID: {os.getpid()}] Error processing frame {frame_index}: {e}")
            # Still report the frame so the consumer is not left waiting for it.
            results_queue.put((frame_index, None))
def _draw_detections(frame, result, class_names, class_colors):
    """Draw the masks, boxes and labels of one YOLO result onto ``frame`` in place."""
    # Masks go on a copy so they can be alpha-blended back afterwards.
    overlay = frame.copy()
    if result.masks is not None:
        for box, mask in zip(result.boxes, result.masks):
            class_name = class_names[int(box.cls)]
            color = class_colors[class_name]
            # mask.xy holds one point array per segment; skip empty segments,
            # which leave nothing to fill.
            for polygon in mask.xy:
                if len(polygon) == 0:
                    continue
                cv2.fillPoly(overlay, [np.asarray(polygon, dtype=np.int32)], color)
            # Bounding box and label.
            x1, y1, x2, y2 = map(int, box.xyxy[0])
            cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)
            label = f"{class_name}: {float(box.conf):.2f}"
            (w, h), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)
            cv2.rectangle(frame, (x1, y1 - h - 5), (x1 + w, y1), color, -1)
            cv2.putText(
                frame, label, (x1, y1 - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1
            )
    # Blend the overlay with the annotated frame, writing into frame.
    cv2.addWeighted(overlay, MASK_ALPHA, frame, 1 - MASK_ALPHA, 0, frame)


def main():
    """
    Segment VIDEO_IN_PATH with NUM_WORKERS parallel workers and write VIDEO_OUT_PATH.

    Frames are dispatched to the workers and results are collected in the same
    loop, with at most ``NUM_WORKERS * 2`` frames in flight. That bound matches
    the size of both queues, so neither the dispatcher nor a worker can block
    on a full queue. Results may arrive out of order; they are buffered and
    written strictly in frame order.

    Raises:
        RuntimeError: if every worker process exited while frames were still
            waiting for a result (for example, no worker could load the model).
    """
    import queue  # local import: only needed for queue.Empty on the timed get()

    # 1. Initialize video capture and get properties
    cap = cv2.VideoCapture(VIDEO_IN_PATH)
    if not cap.isOpened():
        print(f"Error: Could not open video file: {VIDEO_IN_PATH}")
        return
    fps = cap.get(cv2.CAP_PROP_FPS)
    frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    # Reported frame count; used for the progress bar only.
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    print(f"Input video: {frame_width}x{frame_height} @ {fps:.2f} FPS, {total_frames} total frames")
    print(f"Using {NUM_WORKERS} parallel workers for processing.")

    # 2. Initialize video writer
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out = cv2.VideoWriter(VIDEO_OUT_PATH, fourcc, fps, (frame_width, frame_height))

    # 3. Generate stable random colors for each class.
    # A temporary model instance is loaded just to read the class names.
    try:
        temp_model = YOLO(MODEL_PATH)
        class_names = temp_model.names
        del temp_model
    except Exception as e:
        print(f"Fatal Error: Could not load model to get class names. Aborting. Details: {e}")
        cap.release()
        out.release()
        return
    random.seed(42)
    class_colors = {
        name: [random.randint(0, 255) for _ in range(3)] for name in class_names.values()
    }
    print(f"Model recognizes {len(class_names)} classes.")

    # 4. Set up multiprocessing queues and processes
    manager = mp.Manager()
    max_in_flight = NUM_WORKERS * 2
    task_queue = manager.Queue(maxsize=max_in_flight)
    results_queue = manager.Queue(maxsize=max_in_flight)
    processes = []
    for _ in range(NUM_WORKERS):
        p = mp.Process(
            target=process_frame_worker,
            args=(task_queue, results_queue, MODEL_PATH, CONFIDENCE_THRESHOLD),
        )
        processes.append(p)
        p.start()

    # 5. Dispatch frames and collect results in one loop.
    pending_frames = {}  # frame index -> original frame, dispatched but not yet written
    finished = {}  # frame index -> result (or None), received but not yet written
    next_to_dispatch = 0
    next_frame_to_write = 0
    source_exhausted = False
    completed = False
    pbar = tqdm(total=total_frames if total_frames > 0 else None, desc="Processing & Writing")
    try:
        while True:
            # Top up the in-flight window with freshly read frames.
            while not source_exhausted and len(pending_frames) < max_in_flight:
                success, frame = cap.read()
                if not success:
                    source_exhausted = True
                    break
                pending_frames[next_to_dispatch] = frame
                task_queue.put((next_to_dispatch, frame))
                next_to_dispatch += 1

            if not pending_frames:
                break  # everything that was read has been written

            # Wait for the next result; wake up periodically to check workers.
            try:
                res_index, result = results_queue.get(timeout=1.0)
            except queue.Empty:
                if not any(p.is_alive() for p in processes):
                    raise RuntimeError(
                        "All worker processes exited before every frame was processed."
                    )
                continue
            finished[res_index] = result

            # Write every consecutive frame that is now available.
            while next_frame_to_write in finished:
                result = finished.pop(next_frame_to_write)
                frame = pending_frames.pop(next_frame_to_write)
                # A None result means the worker failed on this frame; the
                # frame is then written without annotations.
                if result is not None:
                    _draw_detections(frame, result, class_names, class_colors)
                out.write(frame)
                pbar.update(1)
                next_frame_to_write += 1
        completed = True
    finally:
        # 6. Clean up all resources
        pbar.close()
        if completed:
            # The task queue is empty here, so the sentinels cannot block.
            for _ in processes:
                task_queue.put(None)
        else:
            # Abnormal exit: do not wait for workers to drain the queue.
            for p in processes:
                if p.is_alive():
                    p.terminate()
        for p in processes:
            p.join()
        manager.shutdown()
        cap.release()
        out.release()
    print("Processing complete!")
    print(f"Output video saved to: {VIDEO_OUT_PATH}")


if __name__ == "__main__":
    # 'spawn' is set explicitly so the script uses the same start method on
    # every platform, including those where 'fork' is not the default.
    mp.set_start_method("spawn", force=True)
    main()

View File

@@ -0,0 +1,130 @@
import yaml, sys
import torch
from pathlib import Path
# --- 1. Core directories ---
HARDISK_DIR = Path.home() / "Desktop" / "Seg" / "Hardisk"  # root directory on the hard disk
BASE_DIR = Path(__file__).parent.parent  # parent of the directory containing this script
DATASET_YAML_PATH = Path(__file__).parent / "dataset.yaml"  # dataset YAML configuration file
OUTPUTS_DIR = BASE_DIR / "DataSet_Public_outputs"  # output root; extended below to OUTPUTS_DIR / "<dataset_name>-Yolo"
TEST_IMAGE_DIR = None  # resolved below from the 'path' and 'test' entries of dataset.yaml

# --- 2. Read and parse the dataset YAML ---
# A missing or malformed file raises here and aborts the import.
with open(DATASET_YAML_PATH, 'r', encoding='utf-8') as f:
    # An empty file parses to None; fall back to an empty mapping so the
    # lookups below report the missing key instead of raising AttributeError.
    yaml_data = yaml.safe_load(f) or {}

relative_path_from_yaml = yaml_data.get('path')
test_path_from_yaml = yaml_data.get('test')

# Validate both keys before building any Path from them.
if not relative_path_from_yaml:
    print(f"警告: 在 '{DATASET_YAML_PATH}' 文件中没有找到 'path' 键。")
    sys.exit(1)
if not test_path_from_yaml:
    print(f"警告: 在 '{DATASET_YAML_PATH}' 文件中没有找到 'test' 键。")
    sys.exit(1)

dataset_name = Path(relative_path_from_yaml).name

# --- 3. Build the absolute test-image directory ---
# 'path' is treated as relative to the directory holding the YAML file, and
# 'test' (e.g. a val/test sub-folder) as relative to 'path'.
yaml_file_directory = DATASET_YAML_PATH.parent
TEST_IMAGE_DIR = (yaml_file_directory / relative_path_from_yaml / test_path_from_yaml).resolve()
if not TEST_IMAGE_DIR.exists():
    print(f"警告: 测试图片目录不存在: {TEST_IMAGE_DIR}")
    sys.exit(1)

# --- 4. Derive the per-dataset output directory ---
if dataset_name and dataset_name not in {'.', '..'}:
    dataset_name_ = dataset_name + "-Yolo"
    OUTPUTS_DIR = OUTPUTS_DIR / dataset_name_
    print(f"设定输出路径为: '{str(OUTPUTS_DIR)}'")
else:
    print(f"警告: 提取的dataset_name: '{dataset_name}' 无效(为空、'.''..')。")
    sys.exit(1)
# ==============================================================================
# --- Model selection (this is the only line to edit) ---
# Name of the model to use; must be one of the keys of MODEL_CONFIGS below.
SELECTED_MODEL = 'YOLOv9e-seg'
# ==============================================================================
# --- Per-model configuration: weights, image size and batch size in one place ---
# NOTE(review): the trailing notes on some entries were garbled in the source
# (e.g. "6401612.5GB"); they presumably record image size / batch size / memory
# use -- confirm before relying on them.
MODEL_CONFIGS = {
    # YOLOv8
    'YOLOv8n-seg': {'weights': 'yolov8n-seg.pt', 'image_size': 640, 'batch_size': 16},
    'YOLOv8s-seg': {'weights': 'yolov8s-seg.pt', 'image_size': 640, 'batch_size': 16},
    'YOLOv8m-seg': {'weights': 'yolov8m-seg.pt', 'image_size': 640, 'batch_size': 16},
    'YOLOv8l-seg': {'weights': 'yolov8l-seg.pt', 'image_size': 640, 'batch_size': 16},  # presumably 640/16: 12.5GB
    'YOLOv8x-seg': {'weights': 'yolov8x-seg.pt', 'image_size': 640, 'batch_size': 16},  # presumably 640/16: 15.5GB; example: the X model could use 1280 resolution
    # YOLOv9
    'YOLOv9c-seg': {'weights': 'yolov9c-seg.pt', 'image_size': 640, 'batch_size': 16},  # presumably 640/16: 13GB
    'YOLOv9e-seg': {'weights': 'yolov9e-seg.pt', 'image_size': 640, 'batch_size': 8},  # presumably 640/16 ran out of memory; example: the E model could use 1280 resolution with a smaller batch
    # YOLOv11 (assumed)
    'YOLO11n-seg': {'weights': 'yolo11n-seg.pt', 'image_size': 640, 'batch_size': 16},
    'YOLO11s-seg': {'weights': 'yolo11s-seg.pt', 'image_size': 640, 'batch_size': 16},
    'YOLO11m-seg': {'weights': 'yolo11m-seg.pt', 'image_size': 640, 'batch_size': 16},
    'YOLO11l-seg': {'weights': 'yolo11l-seg.pt', 'image_size': 640, 'batch_size': 16},  # presumably 640/16: 12.5GB
    'YOLO11x-seg': {'weights': 'yolo11x-seg.pt', 'image_size': 640, 'batch_size': 16},  # presumably 640/16: 19.5GB
    # YOLOv12 (assumed); built from a local YAML definition rather than a .pt file
    'YOLO12-seg': {'weights': str(Path(__file__).parent / 'yolo12-seg.yaml'), 'image_size': 640, 'batch_size': 16},  # presumably 640/16: 3GB
}
# --- Resolve the selected model's settings ---
# Bail out first if the selection is unknown, then unpack its configuration.
if SELECTED_MODEL not in MODEL_CONFIGS:
    print(f"错误: 所选模型 '{SELECTED_MODEL}' 不在配置列表中。")
    print("可用模型包括: ", list(MODEL_CONFIGS.keys()))
    sys.exit(1)

config = MODEL_CONFIGS[SELECTED_MODEL]
MODEL_WEIGHTS = config['weights']
IMAGE_SIZE = config['image_size']
BATCH_SIZE = config['batch_size']
# Summary banner, one item per line.
print(
    f"--- 模型配置加载成功 ---",
    f"模型名称: {SELECTED_MODEL}",
    f"权重文件: {MODEL_WEIGHTS}",
    f"图像大小: {IMAGE_SIZE}",
    f"批处理大小: {BATCH_SIZE}",
    f"------------------------",
    sep="\n",
)
# --- 4. Training hyperparameters ---
EPOCHS = 300  # number of training epochs
PATIENCE = 100  # early-stopping patience, in epochs
SAVE_PERIOD = 10  # save a checkpoint every this many epochs
# BATCH_SIZE is already set above from the selected model's configuration
LEARNING_RATE = 0.01  # initial learning rate
OPTIMIZER = 'Adam'  # optimizer (TODO: choose among 'SGD', 'Adam', 'AdamW', 'auto')
WORKERS = 4  # number of data-loading workers
# --- 动态设备选择 (Dynamic Device Selection) ---
def get_auto_device():
    """
    Choose the device string to train on, based on what torch reports.

    Returns:
        'cpu' when CUDA is unavailable, '0' for a single GPU, otherwise the
        comma-separated indices of all GPUs (e.g. '0,1').
    """
    # No CUDA at all: fall back to the CPU.
    if not torch.cuda.is_available():
        print("未检测到可用的GPU。将使用CPU。")
        return 'cpu'

    gpu_count = torch.cuda.device_count()
    # Exactly one GPU (or anything that is not "several"): use index 0.
    if not gpu_count > 1:
        print("检测到 1 个可用的GPU。将使用 GPU: 0")
        return '0'

    # Several GPUs: use every one of them.
    device_str = ",".join(map(str, range(gpu_count)))
    print(f"检测到 {gpu_count} 个可用的GPU。将使用所有GPU: {device_str}")
    return device_str
DEVICE = get_auto_device()  # resolve the device once, when this module is imported
# --- 5. Prediction settings ---
SHOW_LABELS = True  # whether to show class labels on prediction results
SHOW_CONF = True  # whether to show confidence scores and bounding boxes on prediction results
SAVE_PREDICTIONS = True  # whether to save the prediction result images

View File

@@ -0,0 +1,113 @@
import argparse, shutil
import logging
from ultralytics import YOLO
from datetime import datetime
import yolo_config as config
import gc, torch
# --- Logging setup ---
# INFO level, timestamped messages, emitted through a single stream handler.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s",
    handlers=[logging.StreamHandler()]
)
def train_model(model_key: str):
    """
    Train the YOLO segmentation model identified by ``model_key``.

    Args:
        model_key (str): Key of the model in yolo_config.MODEL_CONFIGS.

    Returns:
        The name of the run folder ("<model_key>_<timestamp>") created under
        config.OUTPUTS_DIR on success, or None when the key is unknown or
        training raised an exception.
    """
    if model_key not in config.MODEL_CONFIGS:
        logging.error(f"错误:模型 '{model_key}' 不在 yolo_config.py 中定义。")
        logging.info(f"可用模型: {list(config.MODEL_CONFIGS.keys())}")
        return None

    # The timestamp keeps separate training runs of one model apart.
    timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    project_folder = f"{model_key}_{timestamp}"

    # Take weights, image size and batch size from the entry of the model
    # being trained. config.IMAGE_SIZE / config.BATCH_SIZE belong to
    # config.SELECTED_MODEL, which may be a different model.
    model_config = config.MODEL_CONFIGS[model_key]
    model_file = model_config['weights']
    image_size = model_config['image_size']
    batch_size = model_config['batch_size']
    logging.info(f"开始训练模型: {model_key} ({model_file})")
    try:
        # --- 1. Load the model ---
        model = YOLO(model_file)
        logging.info("模型加载成功。")
        # --- 2. Train ---
        logging.info(f"数据集配置文件: {config.DATASET_YAML_PATH}")
        logging.info(f"训练参数: Epochs={config.EPOCHS}, Batch Size={batch_size}, Img Size={image_size}")
        model.train(
            data=str(config.DATASET_YAML_PATH),
            epochs=config.EPOCHS,
            imgsz=image_size,
            batch=batch_size,
            optimizer=config.OPTIMIZER,
            lr0=config.LEARNING_RATE,
            device=config.DEVICE,
            project=str(config.OUTPUTS_DIR),  # root directory for outputs
            name=project_folder,  # folder name of this run
            workers=config.WORKERS,
            exist_ok=True,  # reuse the run folder if it already exists
            patience=config.PATIENCE,  # early-stopping patience, in epochs
            save_period=config.SAVE_PERIOD,  # checkpoint interval, in epochs
        )
        logging.info("模型训练完成。")
        logging.info(f"训练结果保存在: {config.OUTPUTS_DIR / project_folder}")
        return project_folder
    except Exception as e:
        # logging.exception records the traceback along with the message.
        logging.exception(f"训练过程中发生错误: {e}")
        return None
# --- 3. 复制结果到硬盘并删除原文件夹 (新增逻辑) ---
def move_results_to_hardisk(project_folder: str):
    """
    Copy one training run to the hard-disk directory, then delete the original.

    The run is read from config.OUTPUTS_DIR / project_folder and copied to
    config.HARDISK_DIR / <name of OUTPUTS_DIR> / project_folder. Any error
    during the copy or the delete is logged, not raised, together with a
    warning stating where the original results are.

    Args:
        project_folder (str): Name of the run folder inside config.OUTPUTS_DIR.
    """
    source_dir = config.OUTPUTS_DIR.joinpath(project_folder)
    destination_dir = config.HARDISK_DIR.joinpath(config.OUTPUTS_DIR.name, project_folder)

    # Make sure the parent of the destination exists before copying into it.
    destination_dir.parent.mkdir(parents=True, exist_ok=True)
    logging.info(f"准备将结果从 {source_dir} 移动到 {destination_dir}...")

    try:
        # Step 1: copy the whole run folder.
        shutil.copytree(source_dir, destination_dir)
        logging.info(f"成功复制结果到: {destination_dir}")

        # Step 2: only after a successful copy, remove the source folder.
        logging.info(f"正在删除原文件夹: {source_dir}")
        shutil.rmtree(source_dir)
        logging.info("成功删除原文件夹。")
        logging.info(f"任务完成,最终结果已保存至: {destination_dir}")
    except Exception as e:
        logging.error(f"移动文件夹时发生错误: {e}")
        logging.warning(f"原始训练结果仍保留在: {source_dir}")
if __name__ == "__main__":
    # Command-line interface: --model is mandatory and limited to configured keys.
    parser = argparse.ArgumentParser(description="训练 YOLO 分割模型。")
    parser.add_argument(
        "--model",
        type=str,
        required=True,
        choices=list(config.MODEL_CONFIGS.keys()),
        help=f"选择要训练的模型。可选: {list(config.MODEL_CONFIGS.keys())}",
    )
    args = parser.parse_args()

    # 1. Train; a falsy return value means training did not succeed.
    project_folder = train_model(args.model)

    # 2. Move the results to the hard disk only after a successful run.
    if not project_folder:
        logging.error("由于训练失败,未执行结果移动操作。")
    else:
        move_results_to_hardisk(project_folder)

    # 3. Release resources.
    torch.cuda.empty_cache()  # free cached CUDA memory
    gc.collect()  # run the garbage collector
    logging.info("已清理 CUDA 显存与系统内存。")