first commit
This commit is contained in:
130
Seg_Predict_YoloModel/yolo_config.py
Normal file
130
Seg_Predict_YoloModel/yolo_config.py
Normal file
@@ -0,0 +1,130 @@
|
||||
import yaml, sys
|
||||
import torch
|
||||
from pathlib import Path
|
||||
|
||||
# --- 1. 核心目录设置 (Core Directories) ---
|
||||
HARDISK_DIR = Path.home() / "Desktop" / "Seg" / "Hardisk" # 硬盘根目录
|
||||
BASE_DIR = Path(__file__).parent.parent # 当前脚本所在目录上级
|
||||
DATASET_YAML_PATH = Path(__file__).parent / "dataset.yaml" # 数据集的 YAML 配置文件路径
|
||||
OUTPUTS_DIR = BASE_DIR / "DataSet_Public_outputs" # 输出结果目录 # 最终为:OUTPUTS_DIR / dataset_name
|
||||
|
||||
TEST_IMAGE_DIR = None # 初始化 TEST_IMAGE_DIR # 测试文件地址 dataset.path / TODO "val" / "test" TODO
|
||||
# try:
|
||||
# 3. 读取并解析 YAML 文件
|
||||
with open(DATASET_YAML_PATH, 'r', encoding='utf-8') as f:
|
||||
yaml_data = yaml.safe_load(f)
|
||||
# 4. 从解析后的数据中获取 'path' 的值
|
||||
relative_path_from_yaml = yaml_data.get('path')
|
||||
test_path_from_yaml = yaml_data.get('test')
|
||||
dataset_name = Path(relative_path_from_yaml).name
|
||||
if relative_path_from_yaml:
|
||||
# 5. 【核心步骤】构建绝对路径
|
||||
# relative_path_from_yaml 是相对于 .yaml 文件本身的路径。
|
||||
# 所以,我们需要获取 .yaml 文件所在的目录,然后与这个相对路径拼接。
|
||||
yaml_file_directory = DATASET_YAML_PATH.parent
|
||||
TEST_IMAGE_DIR = (DATASET_YAML_PATH.parent / relative_path_from_yaml / test_path_from_yaml).resolve() # 这里val 或 test在dataset.yaml中定义
|
||||
# 使用 .exists() 方法来检查路径是否存在
|
||||
if not TEST_IMAGE_DIR.exists():
|
||||
# 如果路径不存在,则执行这里的代码
|
||||
print(f"警告: 测试图片目录不存在: {TEST_IMAGE_DIR}")
|
||||
sys.exit(1)
|
||||
# 6. 获取OUTPUTS_DIR
|
||||
if dataset_name and dataset_name not in {'.', '..'}:
|
||||
dataset_name_ = dataset_name + "-Yolo"
|
||||
OUTPUTS_DIR = OUTPUTS_DIR / dataset_name_
|
||||
print(f"设定输出路径为: '{str(OUTPUTS_DIR)}'")
|
||||
else:
|
||||
print(f"警告: 提取的dataset_name: '{dataset_name}' 无效(为空、'.' 或 '..')。")
|
||||
sys.exit(1)
|
||||
else:
|
||||
print(f"警告: 在 '{DATASET_YAML_PATH}' 文件中没有找到 'path' 键。")
|
||||
sys.exit(1)
|
||||
# 4. 从解析后的数据中获取 'path' 的值
|
||||
relative_path_from_yaml = yaml_data.get('path')
|
||||
# except FileNotFoundError:
|
||||
# print(f"错误: YAML 配置文件未找到: '{DATASET_YAML_PATH}'")
|
||||
# except Exception as e:
|
||||
# print(f"读取或解析 YAML 文件时发生错误: {e}")
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
# --- 新增:模型选择 (只需修改这里) ---
|
||||
# 你想要使用的模型名称,从下面的 MODEL_CONFIGS 字典中选择一个键。
|
||||
SELECTED_MODEL = 'YOLOv9e-seg'
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
# --- 修改:模型、图像大小和批处理大小的集成配置 ---
|
||||
# 将所有模型的配置(权重、图像大小、批处理大小)集中管理
|
||||
MODEL_CONFIGS = {
|
||||
# YOLOv8
|
||||
'YOLOv8n-seg': {'weights': 'yolov8n-seg.pt', 'image_size': 640, 'batch_size': 16},
|
||||
'YOLOv8s-seg': {'weights': 'yolov8s-seg.pt', 'image_size': 640, 'batch_size': 16},
|
||||
'YOLOv8m-seg': {'weights': 'yolov8m-seg.pt', 'image_size': 640, 'batch_size': 16},
|
||||
'YOLOv8l-seg': {'weights': 'yolov8l-seg.pt', 'image_size': 640, 'batch_size': 16}, # 640,16,12.5GB
|
||||
'YOLOv8x-seg': {'weights': 'yolov8x-seg.pt', 'image_size': 640, 'batch_size': 16}, # 640,16,15.5GB # 示例:X模型使用1280分辨率
|
||||
# YOLOv9
|
||||
'YOLOv9c-seg': {'weights': 'yolov9c-seg.pt', 'image_size': 640, 'batch_size': 16}, # 640,16,13GB
|
||||
'YOLOv9e-seg': {'weights': 'yolov9e-seg.pt', 'image_size': 640, 'batch_size': 8}, # 640,16,内存超了 # 示例:E模型(最大)使用1280分辨率和更小的batch
|
||||
# YOLOv11 (假设)
|
||||
'YOLO11n-seg': {'weights': 'yolo11n-seg.pt', 'image_size': 640, 'batch_size': 16},
|
||||
'YOLO11s-seg': {'weights': 'yolo11s-seg.pt', 'image_size': 640, 'batch_size': 16},
|
||||
'YOLO11m-seg': {'weights': 'yolo11m-seg.pt', 'image_size': 640, 'batch_size': 16},
|
||||
'YOLO11l-seg': {'weights': 'yolo11l-seg.pt', 'image_size': 640, 'batch_size': 16}, # 640,16,12.5GB
|
||||
'YOLO11x-seg': {'weights': 'yolo11x-seg.pt', 'image_size': 640, 'batch_size': 16}, # 640,16,19.5GB
|
||||
# YOLOv12 (假设)
|
||||
'YOLO12-seg': {'weights': str(Path(__file__).parent / 'yolo12-seg.yaml'), 'image_size': 640, 'batch_size': 16}, # 640,16,3GB
|
||||
}
|
||||
|
||||
# --- 新增:根据选择自动加载配置 ---
|
||||
# 检查所选模型是否存在于配置字典中
|
||||
if SELECTED_MODEL in MODEL_CONFIGS:
|
||||
config = MODEL_CONFIGS[SELECTED_MODEL]
|
||||
MODEL_WEIGHTS = config['weights']
|
||||
IMAGE_SIZE = config['image_size']
|
||||
BATCH_SIZE = config['batch_size']
|
||||
print(f"--- 模型配置加载成功 ---")
|
||||
print(f"模型名称: {SELECTED_MODEL}")
|
||||
print(f"权重文件: {MODEL_WEIGHTS}")
|
||||
print(f"图像大小: {IMAGE_SIZE}")
|
||||
print(f"批处理大小: {BATCH_SIZE}")
|
||||
print(f"------------------------")
|
||||
else:
|
||||
print(f"错误: 所选模型 '{SELECTED_MODEL}' 不在配置列表中。")
|
||||
print("可用模型包括: ", list(MODEL_CONFIGS.keys()))
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
# --- 4. 训练超参数 (Training Hyperparameters) ---
|
||||
EPOCHS = 300 # 训练轮次
|
||||
PATIENCE = 100 # 提前停止训练的轮数
|
||||
SAVE_PERIOD = 10 # 每隔多少轮保存一次模型
|
||||
# BATCH_SIZE 已在上面根据模型自动设置
|
||||
LEARNING_RATE = 0.01 # 初始学习率
|
||||
OPTIMIZER = 'Adam' # 优化器 (TODO 'SGD', 'Adam', 'AdamW', 'auto' TODO)
|
||||
WORKERS = 4 # 数据加载的工作线程数
|
||||
# --- 动态设备选择 (Dynamic Device Selection) ---
|
||||
def get_auto_device():
|
||||
"""自动检测并返回最合适的设备"""
|
||||
if torch.cuda.is_available():
|
||||
gpu_count = torch.cuda.device_count()
|
||||
if gpu_count > 1:
|
||||
# 如果有多个GPU,使用所有GPU
|
||||
device_str = ",".join(str(i) for i in range(gpu_count))
|
||||
print(f"检测到 {gpu_count} 个可用的GPU。将使用所有GPU: {device_str}")
|
||||
return device_str
|
||||
else:
|
||||
# 如果只有1个GPU
|
||||
print("检测到 1 个可用的GPU。将使用 GPU: 0")
|
||||
return '0'
|
||||
else:
|
||||
# 如果没有可用的GPU,使用CPU
|
||||
print("未检测到可用的GPU。将使用CPU。")
|
||||
return 'cpu'
|
||||
DEVICE = get_auto_device() # 调用函数来设置设备
|
||||
|
||||
# --- 5. 预测设置 (Prediction Settings) ---
|
||||
|
||||
SHOW_LABELS = True # 是否在预测结果上显示类别标签
|
||||
SHOW_CONF = True # 是否在预测结果上显示置信度和边界框
|
||||
SAVE_PREDICTIONS = True # 是否保存预测结果图像
|
||||
Reference in New Issue
Block a user