first commit
This commit is contained in:
167
Seg_All_In_One_YoloModel/yolo_predict.sh
Normal file
167
Seg_All_In_One_YoloModel/yolo_predict.sh
Normal file
@@ -0,0 +1,167 @@
|
||||
#!/bin/bash
|
||||
|
||||
# =================================================================
|
||||
# YOLO 模型批量并行预测脚本
|
||||
# =================================================================
|
||||
# - 此脚本会自动为每个模型架构查找其训练好的 'best.pt'
|
||||
# - 使用 'echo "1" |' 来自动选择找到的第一个训练版本
|
||||
# - 在不同的指定GPU上并行执行预测任务
|
||||
# =================================================================
|
||||
|
||||
# --- 1. Conda 环境设置 ---
|
||||
CONDA_BASE_PATH="/home/wkmgc/miniconda3" # <--- 在这里修改为您自己的 Conda 路径
|
||||
CONDA_ENV_NAME="${SEG_CONDA_ENV:-seg_smp}" # 可用 SEG_CONDA_ENV=SMP bash yolo_predict.sh 临时覆盖
|
||||
pt_name="best.pt" # <--- 在这里修改为您想使用的权重文件名,例如 "best.pt" 或 "epoch100.pt"
|
||||
conf_threshold=0.2 # <--- [新增] 默认的置信度阈值
|
||||
heatmap_method="None" # <--- [!! 新增 !!] 默认不运行热度图
|
||||
|
||||
# 循环解析参数
|
||||
while [[ $# -gt 0 ]]; do
|
||||
key="$1"
|
||||
case $key in
|
||||
--pt_name)
|
||||
if [ -n "$2" ] && [[ "$2" != -* ]]; then
|
||||
pt_name="$2"
|
||||
shift # 移过 --pt_name
|
||||
shift # 移过它的值
|
||||
else
|
||||
echo "错误: --pt_name 参数需要一个值。" >&2
|
||||
exit 1
|
||||
fi
|
||||
;;
|
||||
--conf)
|
||||
if [ -n "$2" ] && [[ "$2" != -* ]]; then
|
||||
conf_threshold="$2"
|
||||
shift # 移过 --conf
|
||||
shift # 移过它的值
|
||||
else
|
||||
echo "错误: --conf 参数需要一个值。" >&2
|
||||
exit 1
|
||||
fi
|
||||
;;
|
||||
--heatmap_method)
|
||||
if [ -n "$2" ] && [[ "$2" != -* ]]; then
|
||||
heatmap_method="$2"
|
||||
shift # 移过 --heatmap_method
|
||||
shift # 移过它的值
|
||||
else
|
||||
echo "错误: --heatmap_method 参数需要一个值 (例如 'GradCAM' 或 'All')。" >&2
|
||||
exit 1
|
||||
fi
|
||||
;;
|
||||
*)
|
||||
# 移过未知参数,不报错
|
||||
shift
|
||||
;;
|
||||
esac
|
||||
done
|
||||
|
||||
# 初始化并激活 Conda 环境
|
||||
if [ -f "${CONDA_BASE_PATH}/etc/profile.d/conda.sh" ]; then
|
||||
source "${CONDA_BASE_PATH}/etc/profile.d/conda.sh"
|
||||
conda activate "${CONDA_ENV_NAME}"
|
||||
if [ $? -ne 0 ]; then
|
||||
echo "错误: 激活 Conda 环境 '${CONDA_ENV_NAME}' 失败!"
|
||||
exit 1
|
||||
fi
|
||||
echo "Conda 环境 '${CONDA_ENV_NAME}' 已成功激活。"
|
||||
else
|
||||
echo "错误: 找不到 conda.sh 脚本。请检查您的 CONDA_BASE_PATH 设置是否正确。"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# --- 2. 模型与 GPU 配置 ---
|
||||
# 此处的分组应与 train.sh 保持一致,以确保能正确找到模型并分配资源
|
||||
GPUS_GROUP_0="0"
|
||||
GPUS_GROUP_1="1"
|
||||
GPUS_GROUP_2="2"
|
||||
GPUS_GROUP_3="3"
|
||||
# TODO #
|
||||
GPUS_GROUP_4="0"
|
||||
GPUS_GROUP_5="1"
|
||||
GPUS_GROUP_6="2"
|
||||
GPUS_GROUP_7="3"
|
||||
|
||||
# 从 yolo_config.py/train.sh 中选择的模型列表
|
||||
GROUP_0_MODELS=("YOLO11l-seg")
|
||||
GROUP_1_MODELS=("YOLOv8n-seg" "YOLOv8m-seg")
|
||||
GROUP_2_MODELS=("YOLO11n-seg" "YOLO11s-seg" "YOLO11m-seg")
|
||||
GROUP_3_MODELS=("YOLOv9e-seg")
|
||||
GROUP_4_MODELS=("YOLO11x-seg")
|
||||
GROUP_5_MODELS=("YOLOv9c-seg" "YOLOv8s-seg")
|
||||
GROUP_6_MODELS=("YOLOv8l-seg" "YOLO12-seg")
|
||||
GROUP_7_MODELS=("YOLOv8x-seg")
|
||||
|
||||
# 1. 从 config.py 中读取 PREDICT_BEST_MODEL_DIR 的值
|
||||
PREDICT_BEST_MODEL_DIR=$(python -c "from yolo_config import PREDICT_BEST_MODEL_DIR; print(PREDICT_BEST_MODEL_DIR)")
|
||||
# 检查是否成功获取了 PREDICT_BEST_MODEL_DIR
|
||||
if [ -z "$PREDICT_BEST_MODEL_DIR" ] || [ ! -e "$PREDICT_BEST_MODEL_DIR" ]; then
|
||||
echo "PREDICT_BEST_MODEL_DIR: $PREDICT_BEST_MODEL_DIR"
|
||||
echo "Error: Could not read PREDICT_BEST_MODEL_DIR from yolo_config.py. Exiting."
|
||||
echo "Error 2: Or the directory specified by PREDICT_BEST_MODEL_DIR does not exist. Please create it first."
|
||||
exit 1
|
||||
fi
|
||||
# 2. 定义带有时间戳的日志目录名
|
||||
LOG_DIR_NAME="yolo_predict_logs_parallel_$(date +%Y-%m-%d_%H-%M-%S)"
|
||||
# 3. 拼接成最终的完整路径
|
||||
LOG_DIR="$PREDICT_BEST_MODEL_DIR/$LOG_DIR_NAME"
|
||||
mkdir -p "${LOG_DIR}"
|
||||
echo "所有模型的预测日志将保存在 ./${LOG_DIR}/ 目录中。"
|
||||
echo "----------------------------------------------------"
|
||||
|
||||
|
||||
# --- 3. 预测执行函数 ---
|
||||
# 定义一个函数来启动一组预测,以避免代码重复
|
||||
start_prediction_group() {
|
||||
# 使用 nameref (引用) 来传递数组
|
||||
local -n models=$1
|
||||
local gpus=$2
|
||||
local group_name=$3
|
||||
|
||||
echo ">>> 准备启动 ${group_name} 的预测任务 (后台运行)..."
|
||||
# 遍历指定组中的所有模型
|
||||
for model_key in "${models[@]}"; do
|
||||
if [ "${heatmap_method}" == "None" ]; then
|
||||
# --- 模式 1: 运行标准预测 (原有逻辑) ---
|
||||
echo " -> 正在后台启动 [标准预测]: ${model_key} on GPUs: ${gpus}"
|
||||
# [注意] 我为您添加了 --conf 参数,您原有的脚本 没有传递它
|
||||
echo "1" | CUDA_VISIBLE_DEVICES=${gpus} python yolo_predict_V2.py --model "${model_key}" --pt_name "${pt_name}" --conf "${conf_threshold}" > "${LOG_DIR}/${model_key}_predict.log" 2>&1 &
|
||||
echo " - 模型 ${model_key} 的预测已在后台启动。日志文件: ${LOG_DIR}/${model_key}_predict.log"
|
||||
else
|
||||
# --- 模式 2: 运行热度图可视化 ---
|
||||
echo " -> 正在后台启动 [热度图可视化]: ${model_key} on GPUs: ${gpus} (Method: ${heatmap_method})"
|
||||
# [注意] 我们使用 yolo_predict_visualize_nn.py 并传递新参数
|
||||
echo "1" | CUDA_VISIBLE_DEVICES=${gpus} python yolo_predict_visualize_nn.py --model "${model_key}" --target_layers "default" --cam_method "${heatmap_method}" --pt_name "${pt_name}" > "${LOG_DIR}/${model_key}_heatmap.log" 2>&1 &
|
||||
echo " - 模型 ${model_key} 的热度图已在后台启动。日志文件: ${LOG_DIR}/${model_key}_heatmap.log"
|
||||
fi
|
||||
echo " - 等待 5 秒,确保 GPU 资源稳定分配..."
|
||||
sleep 5
|
||||
done
|
||||
echo ">>> ${group_name} 的所有模型均已启动。"
|
||||
echo "----------------------------------------------------"
|
||||
}
|
||||
|
||||
# --- 4. 依次启动所有预测任务 ---
|
||||
# 脚本将快速地按顺序启动每一组任务到后台
|
||||
start_prediction_group GROUP_0_MODELS "${GPUS_GROUP_0}" "第零组"
|
||||
start_prediction_group GROUP_1_MODELS "${GPUS_GROUP_1}" "第一组"
|
||||
start_prediction_group GROUP_2_MODELS "${GPUS_GROUP_2}" "第二组"
|
||||
start_prediction_group GROUP_3_MODELS "${GPUS_GROUP_3}" "第三组"
|
||||
start_prediction_group GROUP_4_MODELS "${GPUS_GROUP_4}" "第四组"
|
||||
start_prediction_group GROUP_5_MODELS "${GPUS_GROUP_5}" "第五组"
|
||||
start_prediction_group GROUP_6_MODELS "${GPUS_GROUP_6}" "第六组"
|
||||
start_prediction_group GROUP_7_MODELS "${GPUS_GROUP_7}" "第七组"
|
||||
|
||||
|
||||
# --- 5. 等待所有后台任务完成 ---
|
||||
echo ""
|
||||
echo "--- 所有模型均已在后台启动。现在等待所有预测任务完成... ---"
|
||||
# 'wait' 命令会暂停脚本,直到所有由此脚本启动的后台子进程全部执行完毕
|
||||
wait
|
||||
echo "--- 所有后台预测任务已全部完成! ---"
|
||||
|
||||
|
||||
# --- 6. 退出脚本 ---
|
||||
echo "预测流程结束。"
|
||||
conda deactivate
|
||||
echo "已取消激活 Conda 环境。"
|
||||
Reference in New Issue
Block a user