first commit
This commit is contained in:
181
Seg_Predict_YoloModel/yolo_Seg_Video-V1-Visible.py
Normal file
181
Seg_Predict_YoloModel/yolo_Seg_Video-V1-Visible.py
Normal file
@@ -0,0 +1,181 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
from ultralytics import YOLO
|
||||
import random
|
||||
import os
|
||||
from tqdm import tqdm # 引入进度条库
|
||||
|
||||
# --- 配置参数 ---
|
||||
|
||||
# 模型文件路径
|
||||
# MODEL_PATH = os.path.join('YOLOv9e-seg', 'weights', 'best.pt')
|
||||
MODEL_PATH = os.path.join('YOLO11l-seg', 'best.pt')
|
||||
# 输入视频文件路径
|
||||
VIDEO_IN_PATH = 'LC_Video_1.mp4'
|
||||
# 输出视频文件路径
|
||||
VIDEO_OUT_PATH = 'output_a_segmented_fps_controlled.mp4'
|
||||
|
||||
# --- 新增配置 ---
|
||||
# 要处理的视频帧率。例如,设置为 5 表示每秒大约只对 5 帧进行模型推理。
|
||||
# 设置为 None 则处理所有帧。
|
||||
PROCESS_FPS = None
|
||||
|
||||
# 推理参数
|
||||
CONFIDENCE_THRESHOLD = 0.3 # 置信度阈值,低于此值的检测将被忽略
|
||||
MASK_ALPHA = 0.4 # 分割掩码的透明度 (0.0 完全透明, 1.0 完全不透明)
|
||||
|
||||
# --- 主处理逻辑 ---
|
||||
|
||||
def main():
|
||||
"""
|
||||
主函数,执行视频分割的完整流程。
|
||||
"""
|
||||
# 1. 加载预训练的 YOLOv9 分割模型
|
||||
# YOLO() 类会自动处理模型的加载、权重初始化以及设备选择(优先使用可用的 GPU)
|
||||
print(f"正在加载模型: {MODEL_PATH}...")
|
||||
try:
|
||||
model = YOLO(MODEL_PATH)
|
||||
except Exception as e:
|
||||
print(f"错误: 无法加载模型。请检查路径是否正确以及 Ultralytics 是否已正确安装。")
|
||||
print(f"详细错误: {e}")
|
||||
return
|
||||
|
||||
# 获取模型能够识别的所有类别名称
|
||||
class_names = model.names
|
||||
print(f"模型已加载,可识别 {len(class_names)} 个类别。")
|
||||
|
||||
# 2. 初始化视频读写对象
|
||||
# 使用 OpenCV 打开输入视频文件
|
||||
cap = cv2.VideoCapture(VIDEO_IN_PATH)
|
||||
if not cap.isOpened():
|
||||
print(f"错误: 无法打开视频文件: {VIDEO_IN_PATH}")
|
||||
return
|
||||
|
||||
# 获取输入视频的属性(帧率、宽度、高度、总帧数)
|
||||
fps = cap.get(cv2.CAP_PROP_FPS)
|
||||
frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
|
||||
frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
|
||||
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) # 获取总帧数用于进度条
|
||||
print(f"输入视频属性: {frame_width}x{frame_height} @ {fps:.2f} FPS, 总共 {total_frames} 帧")
|
||||
|
||||
# 定义视频编码器并创建 VideoWriter 对象以保存输出视频
|
||||
# 'mp4v' 是适用于.mp4 文件的常用编码器
|
||||
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
|
||||
out = cv2.VideoWriter(VIDEO_OUT_PATH, fourcc, fps, (frame_width, frame_height))
|
||||
|
||||
# 3. 为每个类别生成一个稳定的随机颜色
|
||||
# 这确保了在整个视频中,同一类别的对象掩码颜色保持一致
|
||||
random.seed(42) # 固定随机种子以保证每次运行颜色相同
|
||||
class_colors = {
|
||||
name: [random.randint(0, 255) for _ in range(3)]
|
||||
for name in class_names.values()
|
||||
}
|
||||
|
||||
# --- 新增:计算帧处理间隔 ---
|
||||
frame_counter = 0
|
||||
frame_skip_interval = 1
|
||||
if PROCESS_FPS is not None and PROCESS_FPS > 0:
|
||||
frame_skip_interval = round(fps / PROCESS_FPS)
|
||||
# 确保至少为1,避免除零错误
|
||||
if frame_skip_interval < 1:
|
||||
frame_skip_interval = 1
|
||||
print(f"处理帧率设置为 {PROCESS_FPS} FPS. 将每 {frame_skip_interval} 帧运行一次模型推理。")
|
||||
else:
|
||||
print("将处理所有帧。")
|
||||
|
||||
# 4. 逐帧处理视频,并使用 tqdm 显示进度条
|
||||
print("开始逐帧处理视频... (在预览窗口激活时按 'q' 键可随时退出)")
|
||||
for _ in tqdm(range(total_frames), desc="处理进度"):
|
||||
# 从视频中读取一帧
|
||||
success, frame = cap.read()
|
||||
|
||||
if not success:
|
||||
# 如果视频读取结束,则跳出循环
|
||||
print("\n视频读取完毕或发生错误。")
|
||||
break
|
||||
|
||||
frame_counter += 1
|
||||
|
||||
# 5. 对当前帧执行 YOLOv9 推理 (有条件地)
|
||||
# 仅在达到处理间隔时才运行模型,否则将重用上一次的结果
|
||||
if frame_counter % frame_skip_interval == 0:
|
||||
results = model.predict(source=frame, conf=CONFIDENCE_THRESHOLD, verbose=False)
|
||||
|
||||
# 检查是否有任何有效的 'results' 可供绘制
|
||||
# 'results' 可能来自当前帧的推理,或来自之前的帧
|
||||
# `locals()` 检查 `results` 变量是否已在当前作用域中定义
|
||||
if 'results' in locals() and results is not None:
|
||||
# ------------------- FIX START -------------------
|
||||
# `results` 是一个列表,因为我们处理的是单张图片,所以需要取出第一个元素
|
||||
result = results[0]
|
||||
# -------------------- FIX END --------------------
|
||||
|
||||
# 6. 可视化:绘制半透明掩码和边界框
|
||||
|
||||
# 创建一个与原始帧相同的副本,用于绘制掩码。这是实现透明效果的关键步骤。
|
||||
overlay = frame.copy()
|
||||
|
||||
# 检查是否存在检测到的掩码
|
||||
if result.masks is not None:
|
||||
# 遍历每个检测到的对象
|
||||
# zip() 函数将边界框和掩码一一对应起来
|
||||
for box, mask in zip(result.boxes, result.masks):
|
||||
# 获取类别 ID 和类别名称
|
||||
class_id = int(box.cls)
|
||||
class_name = class_names[class_id]
|
||||
|
||||
# 获取该类别的预定义颜色
|
||||
color = class_colors[class_name]
|
||||
|
||||
# --- 绘制分割掩码 ---
|
||||
# [修正] mask.xy返回的是一个列表的列表,需要用np.array处理
|
||||
points = np.array(mask.xy, dtype=np.int32)
|
||||
# 在 overlay 层上填充多边形(完全不透明)
|
||||
cv2.fillPoly(overlay, [points], color)
|
||||
|
||||
# --- 绘制边界框和标签 ---
|
||||
# 获取边界框坐标
|
||||
x1, y1, x2, y2 = map(int, box.xyxy[0]) # box.xyxy也是列表,取第一个
|
||||
# 在原始帧上绘制矩形边界框
|
||||
cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)
|
||||
|
||||
# 准备标签文本(类别名 + 置信度)
|
||||
confidence = float(box.conf)
|
||||
label = f"{class_name}: {confidence:.2f}"
|
||||
|
||||
# 计算文本尺寸以绘制背景
|
||||
(text_width, text_height), baseline = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)
|
||||
# 绘制文本背景框
|
||||
cv2.rectangle(frame, (x1, y1 - text_height - baseline), (x1 + text_width, y1), color, -1)
|
||||
# 在背景框上放置文本
|
||||
cv2.putText(frame, label, (x1, y1 - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)
|
||||
|
||||
# 7. 融合原始帧和掩码层
|
||||
# cv2.addWeighted() 将 overlay 和 frame 按照指定的权重(透明度)混合
|
||||
# 公式为: output = frame * (1 - alpha) + overlay * alpha + 0
|
||||
cv2.addWeighted(overlay, MASK_ALPHA, frame, 1 - MASK_ALPHA, 0, frame)
|
||||
|
||||
# 8. 将处理后的帧写入输出视频文件
|
||||
out.write(frame)
|
||||
|
||||
# # 9. [新增] 实时显示处理后的帧
|
||||
# cv2.imshow('Real-time Segmentation Preview', frame)
|
||||
|
||||
# # 检测 'q' 键是否被按下,如果是则退出循环
|
||||
# # cv2.waitKey(1) 对于视频处理至关重要,它等待 1ms,允许 OpenCV 刷新窗口
|
||||
# if cv2.waitKey(1) & 0xFF == ord('q'):
|
||||
# print("\n用户中断处理。")
|
||||
# break
|
||||
|
||||
# 10. 释放资源
|
||||
cap.release()
|
||||
out.release()
|
||||
# cv2.destroyAllWindows()
|
||||
print(f"处理完成!")
|
||||
print(f"输出视频已保存至: {VIDEO_OUT_PATH}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
193
Seg_Predict_YoloModel/yolo_Seg_Video-V2-UnVisible.py
Normal file
193
Seg_Predict_YoloModel/yolo_Seg_Video-V2-UnVisible.py
Normal file
@@ -0,0 +1,193 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
from ultralytics import YOLO
|
||||
import random
|
||||
import os
|
||||
from tqdm import tqdm
|
||||
import multiprocessing as mp
|
||||
import time
|
||||
|
||||
# --- Configuration ---
|
||||
MODEL_PATH = os.path.join('YOLOv9e-seg', 'weights', 'best.pt')
|
||||
VIDEO_IN_PATH = 'LC_Video_1.mp4'
|
||||
VIDEO_OUT_PATH = 'output_parallel_segmented.mp4'
|
||||
CONFIDENCE_THRESHOLD = 0.3
|
||||
MASK_ALPHA = 0.4
|
||||
|
||||
# --- Parallel Processing Configuration ---
|
||||
# Set the number of parallel processes. Using cpu_count() - 1 is a safe choice.
|
||||
# You can adjust this based on your system's resources.
|
||||
NUM_WORKERS = min(5, max(1, mp.cpu_count() - 1))
|
||||
|
||||
|
||||
def process_frame_worker(task_queue, results_queue, model_path, confidence_threshold):
|
||||
"""
|
||||
Worker function for parallel processing.
|
||||
Each worker loads its own model instance and processes frames from the task queue.
|
||||
"""
|
||||
try:
|
||||
model = YOLO(model_path)
|
||||
except Exception as e:
|
||||
print(f"[Worker PID: {os.getpid()}] Error loading model: {e}")
|
||||
return
|
||||
|
||||
while True:
|
||||
task = task_queue.get()
|
||||
# A None task is a signal to terminate the worker
|
||||
if task is None:
|
||||
break
|
||||
|
||||
frame_index, frame = task
|
||||
try:
|
||||
results = model.predict(source=frame, conf=confidence_threshold, verbose=False)
|
||||
# We only need the first result as we process one frame at a time
|
||||
results_queue.put((frame_index, results[0]))
|
||||
except Exception as e:
|
||||
print(f"[Worker PID: {os.getpid()}] Error processing frame {frame_index}: {e}")
|
||||
# Put a placeholder to avoid deadlocks
|
||||
results_queue.put((frame_index, None))
|
||||
|
||||
|
||||
def main():
|
||||
"""
|
||||
Main function to execute video segmentation using parallel processing.
|
||||
"""
|
||||
# 1. Initialize video capture and get properties
|
||||
cap = cv2.VideoCapture(VIDEO_IN_PATH)
|
||||
if not cap.isOpened():
|
||||
print(f"Error: Could not open video file: {VIDEO_IN_PATH}")
|
||||
return
|
||||
|
||||
fps = cap.get(cv2.CAP_PROP_FPS)
|
||||
frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
|
||||
frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
|
||||
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
|
||||
print(f"Input video: {frame_width}x{frame_height} @ {fps:.2f} FPS, {total_frames} total frames")
|
||||
print(f"Using {NUM_WORKERS} parallel workers for processing.")
|
||||
|
||||
# 2. Initialize video writer
|
||||
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
|
||||
out = cv2.VideoWriter(VIDEO_OUT_PATH, fourcc, fps, (frame_width, frame_height))
|
||||
|
||||
# 3. Generate stable random colors for each class
|
||||
# We load a temporary model instance just to get the class names
|
||||
try:
|
||||
temp_model = YOLO(MODEL_PATH)
|
||||
class_names = temp_model.names
|
||||
del temp_model
|
||||
except Exception as e:
|
||||
print(f"Fatal Error: Could not load model to get class names. Aborting. Details: {e}")
|
||||
cap.release()
|
||||
out.release()
|
||||
return
|
||||
|
||||
random.seed(42)
|
||||
class_colors = {name: [random.randint(0, 255) for _ in range(3)] for name in class_names.values()}
|
||||
print(f"Model recognizes {len(class_names)} classes.")
|
||||
|
||||
# 4. Set up multiprocessing queues and processes
|
||||
manager = mp.Manager()
|
||||
# Queue for frames to be processed. Maxsize helps prevent memory overload.
|
||||
task_queue = manager.Queue(maxsize=NUM_WORKERS * 2)
|
||||
# Queue for processed results.
|
||||
results_queue = manager.Queue(maxsize=NUM_WORKERS * 2)
|
||||
|
||||
processes = []
|
||||
for _ in range(NUM_WORKERS):
|
||||
p = mp.Process(target=process_frame_worker, args=(task_queue, results_queue, MODEL_PATH, CONFIDENCE_THRESHOLD))
|
||||
processes.append(p)
|
||||
p.start()
|
||||
|
||||
# 5. Main process: Read frames and dispatch to workers
|
||||
frame_index = 0
|
||||
pbar_read = tqdm(total=total_frames, desc="Reading frames")
|
||||
while True:
|
||||
success, frame = cap.read()
|
||||
if not success:
|
||||
break
|
||||
task_queue.put((frame_index, frame))
|
||||
frame_index += 1
|
||||
pbar_read.update(1)
|
||||
pbar_read.close()
|
||||
|
||||
# Signal workers to terminate by sending None tasks
|
||||
for _ in range(NUM_WORKERS):
|
||||
task_queue.put(None)
|
||||
|
||||
# 6. Main process: Collect results, draw segmentation, and write to video
|
||||
# This part must be sequential to ensure correct video order.
|
||||
frames_to_write = {}
|
||||
next_frame_to_write = 0
|
||||
pbar_write = tqdm(total=total_frames, desc="Processing & Writing")
|
||||
|
||||
for _ in range(total_frames):
|
||||
# Get a result from the queue (this might be out of order)
|
||||
res_index, result = results_queue.get()
|
||||
|
||||
# If a worker failed, the result might be None
|
||||
if result is None:
|
||||
# We need a placeholder frame. We'll re-read it.
|
||||
# This is inefficient but robust against worker failure.
|
||||
cap.set(cv2.CAP_PROP_POS_FRAMES, res_index)
|
||||
_, frame = cap.read()
|
||||
frames_to_write[res_index] = (frame, None)
|
||||
else:
|
||||
# Re-read the original frame to draw on.
|
||||
# This avoids passing bulky frames through the results queue.
|
||||
cap.set(cv2.CAP_PROP_POS_FRAMES, res_index)
|
||||
_, frame = cap.read()
|
||||
frames_to_write[res_index] = (frame, result)
|
||||
|
||||
# Write all consecutive frames that are now available
|
||||
while next_frame_to_write in frames_to_write:
|
||||
frame, result = frames_to_write.pop(next_frame_to_write)
|
||||
|
||||
if result is not None:
|
||||
overlay = frame.copy()
|
||||
if result.masks is not None:
|
||||
for box, mask in zip(result.boxes, result.masks):
|
||||
class_id = int(box.cls)
|
||||
class_name = class_names[class_id]
|
||||
color = class_colors[class_name]
|
||||
|
||||
# Draw segmentation mask
|
||||
points = np.array(mask.xy, dtype=np.int32)
|
||||
cv2.fillPoly(overlay, [points], color)
|
||||
|
||||
# Draw bounding box and label
|
||||
x1, y1, x2, y2 = map(int, box.xyxy[0])
|
||||
cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)
|
||||
|
||||
confidence = float(box.conf)
|
||||
label = f"{class_name}: {confidence:.2f}"
|
||||
(w, h), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)
|
||||
cv2.rectangle(frame, (x1, y1 - h - 5), (x1 + w, y1), color, -1)
|
||||
cv2.putText(frame, label, (x1, y1 - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)
|
||||
|
||||
# Blend the overlay with the original frame
|
||||
cv2.addWeighted(overlay, MASK_ALPHA, frame, 1 - MASK_ALPHA, 0, frame)
|
||||
|
||||
out.write(frame)
|
||||
pbar_write.update(1)
|
||||
next_frame_to_write += 1
|
||||
|
||||
pbar_write.close()
|
||||
|
||||
# 7. Clean up all resources
|
||||
for p in processes:
|
||||
p.join()
|
||||
|
||||
cap.release()
|
||||
out.release()
|
||||
print("Processing complete!")
|
||||
print(f"Output video saved to: {VIDEO_OUT_PATH}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# On Windows and macOS, 'fork' is not the default start method.
|
||||
# 'spawn' or 'forkserver' are safer and required on these platforms.
|
||||
# It's good practice to set it explicitly.
|
||||
mp.set_start_method("spawn", force=True)
|
||||
main()
|
||||
130
Seg_Predict_YoloModel/yolo_config.py
Normal file
130
Seg_Predict_YoloModel/yolo_config.py
Normal file
@@ -0,0 +1,130 @@
|
||||
import yaml, sys
|
||||
import torch
|
||||
from pathlib import Path
|
||||
|
||||
# --- 1. 核心目录设置 (Core Directories) ---
|
||||
HARDISK_DIR = Path.home() / "Desktop" / "Seg" / "Hardisk" # 硬盘根目录
|
||||
BASE_DIR = Path(__file__).parent.parent # 当前脚本所在目录上级
|
||||
DATASET_YAML_PATH = Path(__file__).parent / "dataset.yaml" # 数据集的 YAML 配置文件路径
|
||||
OUTPUTS_DIR = BASE_DIR / "DataSet_Public_outputs" # 输出结果目录 # 最终为:OUTPUTS_DIR / dataset_name
|
||||
|
||||
TEST_IMAGE_DIR = None # 初始化 TEST_IMAGE_DIR # 测试文件地址 dataset.path / TODO "val" / "test" TODO
|
||||
# try:
|
||||
# 3. 读取并解析 YAML 文件
|
||||
with open(DATASET_YAML_PATH, 'r', encoding='utf-8') as f:
|
||||
yaml_data = yaml.safe_load(f)
|
||||
# 4. 从解析后的数据中获取 'path' 的值
|
||||
relative_path_from_yaml = yaml_data.get('path')
|
||||
test_path_from_yaml = yaml_data.get('test')
|
||||
dataset_name = Path(relative_path_from_yaml).name
|
||||
if relative_path_from_yaml:
|
||||
# 5. 【核心步骤】构建绝对路径
|
||||
# relative_path_from_yaml 是相对于 .yaml 文件本身的路径。
|
||||
# 所以,我们需要获取 .yaml 文件所在的目录,然后与这个相对路径拼接。
|
||||
yaml_file_directory = DATASET_YAML_PATH.parent
|
||||
TEST_IMAGE_DIR = (DATASET_YAML_PATH.parent / relative_path_from_yaml / test_path_from_yaml).resolve() # 这里val 或 test在dataset.yaml中定义
|
||||
# 使用 .exists() 方法来检查路径是否存在
|
||||
if not TEST_IMAGE_DIR.exists():
|
||||
# 如果路径不存在,则执行这里的代码
|
||||
print(f"警告: 测试图片目录不存在: {TEST_IMAGE_DIR}")
|
||||
sys.exit(1)
|
||||
# 6. 获取OUTPUTS_DIR
|
||||
if dataset_name and dataset_name not in {'.', '..'}:
|
||||
dataset_name_ = dataset_name + "-Yolo"
|
||||
OUTPUTS_DIR = OUTPUTS_DIR / dataset_name_
|
||||
print(f"设定输出路径为: '{str(OUTPUTS_DIR)}'")
|
||||
else:
|
||||
print(f"警告: 提取的dataset_name: '{dataset_name}' 无效(为空、'.' 或 '..')。")
|
||||
sys.exit(1)
|
||||
else:
|
||||
print(f"警告: 在 '{DATASET_YAML_PATH}' 文件中没有找到 'path' 键。")
|
||||
sys.exit(1)
|
||||
# 4. 从解析后的数据中获取 'path' 的值
|
||||
relative_path_from_yaml = yaml_data.get('path')
|
||||
# except FileNotFoundError:
|
||||
# print(f"错误: YAML 配置文件未找到: '{DATASET_YAML_PATH}'")
|
||||
# except Exception as e:
|
||||
# print(f"读取或解析 YAML 文件时发生错误: {e}")
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
# --- 新增:模型选择 (只需修改这里) ---
|
||||
# 你想要使用的模型名称,从下面的 MODEL_CONFIGS 字典中选择一个键。
|
||||
SELECTED_MODEL = 'YOLOv9e-seg'
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
# --- 修改:模型、图像大小和批处理大小的集成配置 ---
|
||||
# 将所有模型的配置(权重、图像大小、批处理大小)集中管理
|
||||
MODEL_CONFIGS = {
|
||||
# YOLOv8
|
||||
'YOLOv8n-seg': {'weights': 'yolov8n-seg.pt', 'image_size': 640, 'batch_size': 16},
|
||||
'YOLOv8s-seg': {'weights': 'yolov8s-seg.pt', 'image_size': 640, 'batch_size': 16},
|
||||
'YOLOv8m-seg': {'weights': 'yolov8m-seg.pt', 'image_size': 640, 'batch_size': 16},
|
||||
'YOLOv8l-seg': {'weights': 'yolov8l-seg.pt', 'image_size': 640, 'batch_size': 16}, # 640,16,12.5GB
|
||||
'YOLOv8x-seg': {'weights': 'yolov8x-seg.pt', 'image_size': 640, 'batch_size': 16}, # 640,16,15.5GB # 示例:X模型使用1280分辨率
|
||||
# YOLOv9
|
||||
'YOLOv9c-seg': {'weights': 'yolov9c-seg.pt', 'image_size': 640, 'batch_size': 16}, # 640,16,13GB
|
||||
'YOLOv9e-seg': {'weights': 'yolov9e-seg.pt', 'image_size': 640, 'batch_size': 8}, # 640,16,内存超了 # 示例:E模型(最大)使用1280分辨率和更小的batch
|
||||
# YOLOv11 (假设)
|
||||
'YOLO11n-seg': {'weights': 'yolo11n-seg.pt', 'image_size': 640, 'batch_size': 16},
|
||||
'YOLO11s-seg': {'weights': 'yolo11s-seg.pt', 'image_size': 640, 'batch_size': 16},
|
||||
'YOLO11m-seg': {'weights': 'yolo11m-seg.pt', 'image_size': 640, 'batch_size': 16},
|
||||
'YOLO11l-seg': {'weights': 'yolo11l-seg.pt', 'image_size': 640, 'batch_size': 16}, # 640,16,12.5GB
|
||||
'YOLO11x-seg': {'weights': 'yolo11x-seg.pt', 'image_size': 640, 'batch_size': 16}, # 640,16,19.5GB
|
||||
# YOLOv12 (假设)
|
||||
'YOLO12-seg': {'weights': str(Path(__file__).parent / 'yolo12-seg.yaml'), 'image_size': 640, 'batch_size': 16}, # 640,16,3GB
|
||||
}
|
||||
|
||||
# --- 新增:根据选择自动加载配置 ---
|
||||
# 检查所选模型是否存在于配置字典中
|
||||
if SELECTED_MODEL in MODEL_CONFIGS:
|
||||
config = MODEL_CONFIGS[SELECTED_MODEL]
|
||||
MODEL_WEIGHTS = config['weights']
|
||||
IMAGE_SIZE = config['image_size']
|
||||
BATCH_SIZE = config['batch_size']
|
||||
print(f"--- 模型配置加载成功 ---")
|
||||
print(f"模型名称: {SELECTED_MODEL}")
|
||||
print(f"权重文件: {MODEL_WEIGHTS}")
|
||||
print(f"图像大小: {IMAGE_SIZE}")
|
||||
print(f"批处理大小: {BATCH_SIZE}")
|
||||
print(f"------------------------")
|
||||
else:
|
||||
print(f"错误: 所选模型 '{SELECTED_MODEL}' 不在配置列表中。")
|
||||
print("可用模型包括: ", list(MODEL_CONFIGS.keys()))
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
# --- 4. 训练超参数 (Training Hyperparameters) ---
|
||||
EPOCHS = 300 # 训练轮次
|
||||
PATIENCE = 100 # 提前停止训练的轮数
|
||||
SAVE_PERIOD = 10 # 每隔多少轮保存一次模型
|
||||
# BATCH_SIZE 已在上面根据模型自动设置
|
||||
LEARNING_RATE = 0.01 # 初始学习率
|
||||
OPTIMIZER = 'Adam' # 优化器 (TODO 'SGD', 'Adam', 'AdamW', 'auto' TODO)
|
||||
WORKERS = 4 # 数据加载的工作线程数
|
||||
# --- 动态设备选择 (Dynamic Device Selection) ---
|
||||
def get_auto_device():
|
||||
"""自动检测并返回最合适的设备"""
|
||||
if torch.cuda.is_available():
|
||||
gpu_count = torch.cuda.device_count()
|
||||
if gpu_count > 1:
|
||||
# 如果有多个GPU,使用所有GPU
|
||||
device_str = ",".join(str(i) for i in range(gpu_count))
|
||||
print(f"检测到 {gpu_count} 个可用的GPU。将使用所有GPU: {device_str}")
|
||||
return device_str
|
||||
else:
|
||||
# 如果只有1个GPU
|
||||
print("检测到 1 个可用的GPU。将使用 GPU: 0")
|
||||
return '0'
|
||||
else:
|
||||
# 如果没有可用的GPU,使用CPU
|
||||
print("未检测到可用的GPU。将使用CPU。")
|
||||
return 'cpu'
|
||||
DEVICE = get_auto_device() # 调用函数来设置设备
|
||||
|
||||
# --- 5. 预测设置 (Prediction Settings) ---
|
||||
|
||||
SHOW_LABELS = True # 是否在预测结果上显示类别标签
|
||||
SHOW_CONF = True # 是否在预测结果上显示置信度和边界框
|
||||
SAVE_PREDICTIONS = True # 是否保存预测结果图像
|
||||
113
Seg_Predict_YoloModel/yolo_train.py
Normal file
113
Seg_Predict_YoloModel/yolo_train.py
Normal file
@@ -0,0 +1,113 @@
|
||||
import argparse, shutil
|
||||
import logging
|
||||
from ultralytics import YOLO
|
||||
from datetime import datetime
|
||||
import yolo_config as config
|
||||
import gc, torch
|
||||
|
||||
# --- 日志设置 ---
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s [%(levelname)s] %(message)s",
|
||||
handlers=[logging.StreamHandler()]
|
||||
)
|
||||
|
||||
def train_model(model_key: str):
|
||||
"""
|
||||
根据给定的模型密钥训练YOLO分割模型。
|
||||
|
||||
Args:
|
||||
model_key (str): 在 yolo_config.MODEL_CONFIGS 中定义的模型密钥。
|
||||
"""
|
||||
# 生成时间戳以区分不同训练运行
|
||||
timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
|
||||
project_folder = f"{model_key}_{timestamp}"
|
||||
|
||||
if model_key not in config.MODEL_CONFIGS:
|
||||
logging.error(f"错误:模型 '{model_key}' 不在 yolo_config.py 中定义。")
|
||||
logging.info(f"可用模型: {list(config.MODEL_CONFIGS.keys())}")
|
||||
return
|
||||
|
||||
model_config = config.MODEL_CONFIGS[model_key]
|
||||
model_file = model_config['weights']
|
||||
logging.info(f"开始训练模型: {model_key} ({model_file})")
|
||||
|
||||
try:
|
||||
# --- 1. 加载模型 ---
|
||||
model = YOLO(model_file) # 如果没有模型会自动下载
|
||||
logging.info("模型加载成功。")
|
||||
|
||||
# --- 2. 模型训练 ---
|
||||
logging.info(f"数据集配置文件: {config.DATASET_YAML_PATH}")
|
||||
logging.info(f"训练参数: Epochs={config.EPOCHS}, Batch Size={config.BATCH_SIZE}, Img Size={config.IMAGE_SIZE}")
|
||||
|
||||
model.train(
|
||||
data=str(config.DATASET_YAML_PATH),
|
||||
epochs=config.EPOCHS,
|
||||
imgsz=config.IMAGE_SIZE,
|
||||
batch=config.BATCH_SIZE,
|
||||
optimizer=config.OPTIMIZER,
|
||||
lr0=config.LEARNING_RATE,
|
||||
device=config.DEVICE,
|
||||
project=str(config.OUTPUTS_DIR), # 指定输出的根目录
|
||||
name=project_folder, # 指定本次训练的项目名
|
||||
workers=config.WORKERS,
|
||||
exist_ok=True, # 如果项目已存在,则覆盖
|
||||
patience=config.PATIENCE, # 提前停止训练的轮数
|
||||
save_period=config.SAVE_PERIOD # 每隔多少轮保存一次模型
|
||||
)
|
||||
|
||||
logging.info("模型训练完成。")
|
||||
logging.info(f"训练结果保存在: {config.OUTPUTS_DIR / project_folder}")
|
||||
return project_folder
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"训练过程中发生错误: {e}")
|
||||
return None
|
||||
|
||||
# --- 3. 复制结果到硬盘并删除原文件夹 (新增逻辑) ---
|
||||
def move_results_to_hardisk(project_folder: str):
|
||||
source_dir = config.OUTPUTS_DIR / project_folder
|
||||
# 确保目标硬盘目录存在
|
||||
outputs_folder_name = config.OUTPUTS_DIR.name
|
||||
destination_dir = config.HARDISK_DIR / outputs_folder_name / project_folder
|
||||
destination_dir.parent.mkdir(parents=True, exist_ok=True)
|
||||
logging.info(f"准备将结果从 {source_dir} 移动到 {destination_dir}...")
|
||||
try:
|
||||
# 步骤 1: 复制文件夹
|
||||
shutil.copytree(source_dir, destination_dir)
|
||||
logging.info(f"成功复制结果到: {destination_dir}")
|
||||
|
||||
# 步骤 2: 复制成功后,删除原文件夹
|
||||
logging.info(f"正在删除原文件夹: {source_dir}")
|
||||
shutil.rmtree(source_dir)
|
||||
logging.info("成功删除原文件夹。")
|
||||
logging.info(f"任务完成,最终结果已保存至: {destination_dir}")
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"移动文件夹时发生错误: {e}")
|
||||
logging.warning(f"原始训练结果仍保留在: {source_dir}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 创建命令行参数解析器
|
||||
parser = argparse.ArgumentParser(description="训练 YOLO 分割模型。")
|
||||
parser.add_argument(
|
||||
"--model",
|
||||
type=str,
|
||||
required=True,
|
||||
choices=list(config.MODEL_CONFIGS.keys()),
|
||||
help=f"选择要训练的模型。可选: {list(config.MODEL_CONFIGS.keys())}"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
# 1. 开始训练
|
||||
project_folder = train_model(args.model)
|
||||
# 2. 训练完成后,移动结果到硬盘
|
||||
if project_folder:
|
||||
move_results_to_hardisk(project_folder)
|
||||
else:
|
||||
logging.error("由于训练失败,未执行结果移动操作。")
|
||||
# 3. 释放资源
|
||||
torch.cuda.empty_cache() # 释放 CUDA 显存
|
||||
gc.collect() # 释放系统内存
|
||||
logging.info("已清理 CUDA 显存与系统内存。")
|
||||
Reference in New Issue
Block a user