Files
Seg_Data_Server/Seg_Predict_YoloModel/yolo_config.py
2026-05-20 15:05:35 +08:00

130 lines
6.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import yaml, sys
import torch
from pathlib import Path
# --- 1. 核心目录设置 (Core Directories) ---
HARDISK_DIR = Path.home() / "Desktop" / "Seg" / "Hardisk" # 硬盘根目录
BASE_DIR = Path(__file__).parent.parent # 当前脚本所在目录上级
DATASET_YAML_PATH = Path(__file__).parent / "dataset.yaml" # 数据集的 YAML 配置文件路径
OUTPUTS_DIR = BASE_DIR / "DataSet_Public_outputs" # 输出结果目录 # 最终为OUTPUTS_DIR / dataset_name
TEST_IMAGE_DIR = None # 初始化 TEST_IMAGE_DIR # 测试文件地址 dataset.path / TODO "val" / "test" TODO
# try:
# 3. 读取并解析 YAML 文件
with open(DATASET_YAML_PATH, 'r', encoding='utf-8') as f:
yaml_data = yaml.safe_load(f)
# 4. 从解析后的数据中获取 'path' 的值
relative_path_from_yaml = yaml_data.get('path')
test_path_from_yaml = yaml_data.get('test')
dataset_name = Path(relative_path_from_yaml).name
if relative_path_from_yaml:
# 5. 【核心步骤】构建绝对路径
# relative_path_from_yaml 是相对于 .yaml 文件本身的路径。
# 所以,我们需要获取 .yaml 文件所在的目录,然后与这个相对路径拼接。
yaml_file_directory = DATASET_YAML_PATH.parent
TEST_IMAGE_DIR = (DATASET_YAML_PATH.parent / relative_path_from_yaml / test_path_from_yaml).resolve() # 这里val 或 test在dataset.yaml中定义
# 使用 .exists() 方法来检查路径是否存在
if not TEST_IMAGE_DIR.exists():
# 如果路径不存在,则执行这里的代码
print(f"警告: 测试图片目录不存在: {TEST_IMAGE_DIR}")
sys.exit(1)
# 6. 获取OUTPUTS_DIR
if dataset_name and dataset_name not in {'.', '..'}:
dataset_name_ = dataset_name + "-Yolo"
OUTPUTS_DIR = OUTPUTS_DIR / dataset_name_
print(f"设定输出路径为: '{str(OUTPUTS_DIR)}'")
else:
print(f"警告: 提取的dataset_name: '{dataset_name}' 无效(为空、'.''..')。")
sys.exit(1)
else:
print(f"警告: 在 '{DATASET_YAML_PATH}' 文件中没有找到 'path' 键。")
sys.exit(1)
# 4. 从解析后的数据中获取 'path' 的值
relative_path_from_yaml = yaml_data.get('path')
# except FileNotFoundError:
# print(f"错误: YAML 配置文件未找到: '{DATASET_YAML_PATH}'")
# except Exception as e:
# print(f"读取或解析 YAML 文件时发生错误: {e}")
# ==============================================================================
# --- 新增:模型选择 (只需修改这里) ---
# 你想要使用的模型名称,从下面的 MODEL_CONFIGS 字典中选择一个键。
SELECTED_MODEL = 'YOLOv9e-seg'
# ==============================================================================
# --- 修改:模型、图像大小和批处理大小的集成配置 ---
# 将所有模型的配置(权重、图像大小、批处理大小)集中管理
MODEL_CONFIGS = {
# YOLOv8
'YOLOv8n-seg': {'weights': 'yolov8n-seg.pt', 'image_size': 640, 'batch_size': 16},
'YOLOv8s-seg': {'weights': 'yolov8s-seg.pt', 'image_size': 640, 'batch_size': 16},
'YOLOv8m-seg': {'weights': 'yolov8m-seg.pt', 'image_size': 640, 'batch_size': 16},
'YOLOv8l-seg': {'weights': 'yolov8l-seg.pt', 'image_size': 640, 'batch_size': 16}, # 6401612.5GB
'YOLOv8x-seg': {'weights': 'yolov8x-seg.pt', 'image_size': 640, 'batch_size': 16}, # 6401615.5GB # 示例X模型使用1280分辨率
# YOLOv9
'YOLOv9c-seg': {'weights': 'yolov9c-seg.pt', 'image_size': 640, 'batch_size': 16}, # 6401613GB
'YOLOv9e-seg': {'weights': 'yolov9e-seg.pt', 'image_size': 640, 'batch_size': 8}, # 64016内存超了 # 示例E模型最大使用1280分辨率和更小的batch
# YOLOv11 (假设)
'YOLO11n-seg': {'weights': 'yolo11n-seg.pt', 'image_size': 640, 'batch_size': 16},
'YOLO11s-seg': {'weights': 'yolo11s-seg.pt', 'image_size': 640, 'batch_size': 16},
'YOLO11m-seg': {'weights': 'yolo11m-seg.pt', 'image_size': 640, 'batch_size': 16},
'YOLO11l-seg': {'weights': 'yolo11l-seg.pt', 'image_size': 640, 'batch_size': 16}, # 6401612.5GB
'YOLO11x-seg': {'weights': 'yolo11x-seg.pt', 'image_size': 640, 'batch_size': 16}, # 6401619.5GB
# YOLOv12 (假设)
'YOLO12-seg': {'weights': str(Path(__file__).parent / 'yolo12-seg.yaml'), 'image_size': 640, 'batch_size': 16}, # 640163GB
}
# --- 新增:根据选择自动加载配置 ---
# 检查所选模型是否存在于配置字典中
if SELECTED_MODEL in MODEL_CONFIGS:
config = MODEL_CONFIGS[SELECTED_MODEL]
MODEL_WEIGHTS = config['weights']
IMAGE_SIZE = config['image_size']
BATCH_SIZE = config['batch_size']
print(f"--- 模型配置加载成功 ---")
print(f"模型名称: {SELECTED_MODEL}")
print(f"权重文件: {MODEL_WEIGHTS}")
print(f"图像大小: {IMAGE_SIZE}")
print(f"批处理大小: {BATCH_SIZE}")
print(f"------------------------")
else:
print(f"错误: 所选模型 '{SELECTED_MODEL}' 不在配置列表中。")
print("可用模型包括: ", list(MODEL_CONFIGS.keys()))
sys.exit(1)
# --- 4. 训练超参数 (Training Hyperparameters) ---
EPOCHS = 300 # 训练轮次
PATIENCE = 100 # 提前停止训练的轮数
SAVE_PERIOD = 10 # 每隔多少轮保存一次模型
# BATCH_SIZE 已在上面根据模型自动设置
LEARNING_RATE = 0.01 # 初始学习率
OPTIMIZER = 'Adam' # 优化器 (TODO 'SGD', 'Adam', 'AdamW', 'auto' TODO)
WORKERS = 4 # 数据加载的工作线程数
# --- 动态设备选择 (Dynamic Device Selection) ---
def get_auto_device():
"""自动检测并返回最合适的设备"""
if torch.cuda.is_available():
gpu_count = torch.cuda.device_count()
if gpu_count > 1:
# 如果有多个GPU使用所有GPU
device_str = ",".join(str(i) for i in range(gpu_count))
print(f"检测到 {gpu_count} 个可用的GPU。将使用所有GPU: {device_str}")
return device_str
else:
# 如果只有1个GPU
print("检测到 1 个可用的GPU。将使用 GPU: 0")
return '0'
else:
# 如果没有可用的GPU使用CPU
print("未检测到可用的GPU。将使用CPU。")
return 'cpu'
DEVICE = get_auto_device() # 调用函数来设置设备
# --- 5. 预测设置 (Prediction Settings) ---
SHOW_LABELS = True # 是否在预测结果上显示类别标签
SHOW_CONF = True # 是否在预测结果上显示置信度和边界框
SAVE_PREDICTIONS = True # 是否保存预测结果图像