#!/bin/bash # ================================================================= # YOLO 模型批量并行预测脚本 # ================================================================= # - 此脚本会自动为每个模型架构查找其训练好的 'best.pt' # - 使用 'echo "1" |' 来自动选择找到的第一个训练版本 # - 在不同的指定GPU上并行执行预测任务 # ================================================================= # --- 1. Conda 环境设置 --- CONDA_BASE_PATH="/home/wkmgc/miniconda3" # <--- 在这里修改为您自己的 Conda 路径 CONDA_ENV_NAME="${SEG_CONDA_ENV:-seg_smp}" # 可用 SEG_CONDA_ENV=SMP bash yolo_predict.sh 临时覆盖 pt_name="best.pt" # <--- 在这里修改为您想使用的权重文件名,例如 "best.pt" 或 "epoch100.pt" conf_threshold=0.2 # <--- [新增] 默认的置信度阈值 heatmap_method="None" # <--- [!! 新增 !!] 默认不运行热度图 # 循环解析参数 while [[ $# -gt 0 ]]; do key="$1" case $key in --pt_name) if [ -n "$2" ] && [[ "$2" != -* ]]; then pt_name="$2" shift # 移过 --pt_name shift # 移过它的值 else echo "错误: --pt_name 参数需要一个值。" >&2 exit 1 fi ;; --conf) if [ -n "$2" ] && [[ "$2" != -* ]]; then conf_threshold="$2" shift # 移过 --conf shift # 移过它的值 else echo "错误: --conf 参数需要一个值。" >&2 exit 1 fi ;; --heatmap_method) if [ -n "$2" ] && [[ "$2" != -* ]]; then heatmap_method="$2" shift # 移过 --heatmap_method shift # 移过它的值 else echo "错误: --heatmap_method 参数需要一个值 (例如 'GradCAM' 或 'All')。" >&2 exit 1 fi ;; *) # 移过未知参数,不报错 shift ;; esac done # 初始化并激活 Conda 环境 if [ -f "${CONDA_BASE_PATH}/etc/profile.d/conda.sh" ]; then source "${CONDA_BASE_PATH}/etc/profile.d/conda.sh" conda activate "${CONDA_ENV_NAME}" if [ $? -ne 0 ]; then echo "错误: 激活 Conda 环境 '${CONDA_ENV_NAME}' 失败!" exit 1 fi echo "Conda 环境 '${CONDA_ENV_NAME}' 已成功激活。" else echo "错误: 找不到 conda.sh 脚本。请检查您的 CONDA_BASE_PATH 设置是否正确。" exit 1 fi # --- 2. 模型与 GPU 配置 --- # 此处的分组应与 train.sh 保持一致,以确保能正确找到模型并分配资源 GPUS_GROUP_0="0" GPUS_GROUP_1="1" GPUS_GROUP_2="2" GPUS_GROUP_3="3" # TODO # GPUS_GROUP_4="0" GPUS_GROUP_5="1" GPUS_GROUP_6="2" GPUS_GROUP_7="3" # 从 yolo_config.py/train.sh 中选择的模型列表 GROUP_0_MODELS=("YOLO11l-seg") GROUP_1_MODELS=("YOLOv8n-seg" "YOLOv8m-seg") GROUP_2_MODELS=("YOLO11n-seg" "YOLO11s-seg" "YOLO11m-seg") GROUP_3_MODELS=("YOLOv9e-seg") GROUP_4_MODELS=("YOLO11x-seg") GROUP_5_MODELS=("YOLOv9c-seg" "YOLOv8s-seg") GROUP_6_MODELS=("YOLOv8l-seg" "YOLO12-seg") GROUP_7_MODELS=("YOLOv8x-seg") # 1. 从 config.py 中读取 PREDICT_BEST_MODEL_DIR 的值 PREDICT_BEST_MODEL_DIR=$(python -c "from yolo_config import PREDICT_BEST_MODEL_DIR; print(PREDICT_BEST_MODEL_DIR)") # 检查是否成功获取了 PREDICT_BEST_MODEL_DIR if [ -z "$PREDICT_BEST_MODEL_DIR" ] || [ ! -e "$PREDICT_BEST_MODEL_DIR" ]; then echo "PREDICT_BEST_MODEL_DIR: $PREDICT_BEST_MODEL_DIR" echo "Error: Could not read PREDICT_BEST_MODEL_DIR from yolo_config.py. Exiting." echo "Error 2: Or the directory specified by PREDICT_BEST_MODEL_DIR does not exist. Please create it first." exit 1 fi # 2. 定义带有时间戳的日志目录名 LOG_DIR_NAME="yolo_predict_logs_parallel_$(date +%Y-%m-%d_%H-%M-%S)" # 3. 拼接成最终的完整路径 LOG_DIR="$PREDICT_BEST_MODEL_DIR/$LOG_DIR_NAME" mkdir -p "${LOG_DIR}" echo "所有模型的预测日志将保存在 ./${LOG_DIR}/ 目录中。" echo "----------------------------------------------------" # --- 3. 预测执行函数 --- # 定义一个函数来启动一组预测,以避免代码重复 start_prediction_group() { # 使用 nameref (引用) 来传递数组 local -n models=$1 local gpus=$2 local group_name=$3 echo ">>> 准备启动 ${group_name} 的预测任务 (后台运行)..." # 遍历指定组中的所有模型 for model_key in "${models[@]}"; do if [ "${heatmap_method}" == "None" ]; then # --- 模式 1: 运行标准预测 (原有逻辑) --- echo " -> 正在后台启动 [标准预测]: ${model_key} on GPUs: ${gpus}" # [注意] 我为您添加了 --conf 参数,您原有的脚本 没有传递它 echo "1" | CUDA_VISIBLE_DEVICES=${gpus} python yolo_predict_V2.py --model "${model_key}" --pt_name "${pt_name}" --conf "${conf_threshold}" > "${LOG_DIR}/${model_key}_predict.log" 2>&1 & echo " - 模型 ${model_key} 的预测已在后台启动。日志文件: ${LOG_DIR}/${model_key}_predict.log" else # --- 模式 2: 运行热度图可视化 --- echo " -> 正在后台启动 [热度图可视化]: ${model_key} on GPUs: ${gpus} (Method: ${heatmap_method})" # [注意] 我们使用 yolo_predict_visualize_nn.py 并传递新参数 echo "1" | CUDA_VISIBLE_DEVICES=${gpus} python yolo_predict_visualize_nn.py --model "${model_key}" --target_layers "default" --cam_method "${heatmap_method}" --pt_name "${pt_name}" > "${LOG_DIR}/${model_key}_heatmap.log" 2>&1 & echo " - 模型 ${model_key} 的热度图已在后台启动。日志文件: ${LOG_DIR}/${model_key}_heatmap.log" fi echo " - 等待 5 秒,确保 GPU 资源稳定分配..." sleep 5 done echo ">>> ${group_name} 的所有模型均已启动。" echo "----------------------------------------------------" } # --- 4. 依次启动所有预测任务 --- # 脚本将快速地按顺序启动每一组任务到后台 start_prediction_group GROUP_0_MODELS "${GPUS_GROUP_0}" "第零组" start_prediction_group GROUP_1_MODELS "${GPUS_GROUP_1}" "第一组" start_prediction_group GROUP_2_MODELS "${GPUS_GROUP_2}" "第二组" start_prediction_group GROUP_3_MODELS "${GPUS_GROUP_3}" "第三组" start_prediction_group GROUP_4_MODELS "${GPUS_GROUP_4}" "第四组" start_prediction_group GROUP_5_MODELS "${GPUS_GROUP_5}" "第五组" start_prediction_group GROUP_6_MODELS "${GPUS_GROUP_6}" "第六组" start_prediction_group GROUP_7_MODELS "${GPUS_GROUP_7}" "第七组" # --- 5. 等待所有后台任务完成 --- echo "" echo "--- 所有模型均已在后台启动。现在等待所有预测任务完成... ---" # 'wait' 命令会暂停脚本,直到所有由此脚本启动的后台子进程全部执行完毕 wait echo "--- 所有后台预测任务已全部完成! ---" # --- 6. 退出脚本 --- echo "预测流程结束。" conda deactivate echo "已取消激活 Conda 环境。"