first commit
This commit is contained in:
92
.gitignore
vendored
Normal file
92
.gitignore
vendored
Normal file
@@ -0,0 +1,92 @@
|
||||
# Python caches
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
.pytest_cache/
|
||||
.mypy_cache/
|
||||
.ruff_cache/
|
||||
.ipynb_checkpoints/
|
||||
|
||||
# Conda / virtual environments
|
||||
.conda/
|
||||
.venv/
|
||||
venv/
|
||||
env/
|
||||
|
||||
# Build artifacts
|
||||
build/
|
||||
dist/
|
||||
*.egg-info/
|
||||
|
||||
# Logs and temporary files
|
||||
*.log
|
||||
.goutputstream-*
|
||||
*~
|
||||
*.tmp
|
||||
*.bak
|
||||
*.swp
|
||||
logs_parallel_*/
|
||||
predict_logs_parallel_*/
|
||||
yolo_train_logs_parallel_*/
|
||||
yolo_predict_logs_parallel_*/
|
||||
wandb/
|
||||
mlruns/
|
||||
.cache/
|
||||
tmp/
|
||||
temp/
|
||||
|
||||
# Large datasets, predictions, training outputs, and backups
|
||||
DataSet_Public/
|
||||
DataSet_Public_outputs/
|
||||
BestMode_Predict_Results_DataSet_Public/
|
||||
Hardisk/
|
||||
Nas_BackUp_Seg/
|
||||
|
||||
# Keep DataSet_Own preprocessing code/manuals, ignore actual image data
|
||||
DataSet_Own/*
|
||||
!DataSet_Own/1. 图片预处理(内含使用手册)/
|
||||
!DataSet_Own/1. 图片预处理(内含使用手册)/**
|
||||
DataSet_Own/1. 图片预处理(内含使用手册)/error*.txt
|
||||
|
||||
# Local model caches and heavy pretrained weights
|
||||
Seg_All_In_One_MMSeg/My_Local_Model/
|
||||
Seg_All_In_One_MMSeg/work_dirs/
|
||||
Seg_All_In_One_MMSeg/flops_results/
|
||||
Seg_All_In_One_MMSeg/tests/data/
|
||||
|
||||
# Generated analysis figures; keep CSV/SVG summaries
|
||||
Seg_All_In_One_Analysis/*/*.png
|
||||
|
||||
# Demo/output media and generated visual data
|
||||
Seg_Predict_Own_Video_V2/*.mp4
|
||||
Seg_Predict_YoloModel/*.mp4
|
||||
Seg_Predict_YoloModel/output_*/
|
||||
Seg_Predict_YoloModel/YOLO*/
|
||||
Seg_All_In_One_YoloModel/Yolo数据集构建/Data/
|
||||
Seg_All_In_One_YoloModel/Yolo数据集构建/Label/
|
||||
Seg_All_In_One_YoloModel/Yolo数据集构建/ORI/
|
||||
Seg_All_In_One_YoloModel/Yolo数据集构建/ORI_GT_label_fold/
|
||||
Seg_All_In_One_YoloModel/Yolo数据集构建/ORI_pro_label_fold/
|
||||
Tool-图片堆叠/ori/
|
||||
Tool-图片堆叠/label/
|
||||
Tool-图片堆叠/result_*/
|
||||
Tool-可视化/Data/
|
||||
Tool-可视化/runs/
|
||||
Tool-可视化/0_图片Labels生成/save_*_label_fold/
|
||||
Tool-可视化/0_图片Labels生成/*.png
|
||||
|
||||
# Large model/video/archive formats
|
||||
*.pt
|
||||
*.pth
|
||||
*.onnx
|
||||
*.engine
|
||||
*.mp4
|
||||
*.avi
|
||||
*.mov
|
||||
*.mkv
|
||||
*.zip
|
||||
*.tar
|
||||
*.tar.gz
|
||||
*.tgz
|
||||
*.7z
|
||||
*.rar
|
||||
104
Back_Up.sh
Normal file
104
Back_Up.sh
Normal file
@@ -0,0 +1,104 @@
|
||||
#!/usr/bin/env bash
|
||||
# sync_interactive.sh —— 交互式同步脚本 (v3)
|
||||
# 脚本会自动以其自身所在的位置为根目录,并提供不同的同步备份模式。
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
# --- 根目录设置 ---
|
||||
# 获取脚本文件所在的绝对路径,并将其作为所有操作的根目录
|
||||
SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
|
||||
|
||||
|
||||
# --- 配置区域 ---
|
||||
# 算法源目录 (路径基于脚本位置)
|
||||
SRC_DIRS=(
|
||||
"$SCRIPT_DIR/Seg_All_In_One_YoloModel"
|
||||
"$SCRIPT_DIR/Seg_All_In_One_SegModel"
|
||||
"$SCRIPT_DIR/Seg_All_In_One_MMSeg"
|
||||
"$SCRIPT_DIR/Seg_All_In_One_Analysis"
|
||||
)
|
||||
# 本地镜像/中转目录 (路径基于脚本位置)
|
||||
LOCAL_DST_ROOT="$SCRIPT_DIR/Hardisk"
|
||||
# NAS备份目标目录 (路径基于脚本位置)
|
||||
NAS_DST_ROOT="$SCRIPT_DIR/Nas_BackUp_Seg"
|
||||
|
||||
|
||||
# --- 用户选择操作 ---
|
||||
echo "--- 请选择要执行的同步操作 ---"
|
||||
echo " 1. [更新并备份算法文件] 从源头更新算法文件到Hardisk,并立即备份到NAS"
|
||||
echo " 2. [备份Hardisk] 将整个Hardisk目录的当前内容,完全拷贝到NAS【--delete】"
|
||||
echo " 3. [退出] 不执行任何操作"
|
||||
echo "------------------------------------------------"
|
||||
|
||||
read -p "请输入选项 [1, 2, 或 3]: " choice
|
||||
|
||||
case $choice in
|
||||
1)
|
||||
echo "--- 您选择了 [1]: 更新并备份算法文件 (源->Hardisk->NAS) ---"
|
||||
|
||||
# --- 第 1 步: 从源目录更新文件到 Hardisk ---
|
||||
echo ""
|
||||
echo "--> (1/2) 正在从源目录更新文件到 $LOCAL_DST_ROOT..."
|
||||
for src_path in "${SRC_DIRS[@]}"; do
|
||||
if [ ! -d "$src_path" ]; then
|
||||
echo " 警告: 源目录 '$src_path' 不存在,已跳过。"
|
||||
continue
|
||||
fi
|
||||
|
||||
dst_dir_name=$(basename "$src_path")
|
||||
dst_path=$(mkdir -p "$LOCAL_DST_ROOT/$dst_dir_name" && realpath "$LOCAL_DST_ROOT/$dst_dir_name")
|
||||
|
||||
echo " >>> 正在同步 $src_path -> $dst_path"
|
||||
rsync -avh --delete "$src_path/" "$dst_path/"
|
||||
done
|
||||
echo "--> (1/2) 本地 Hardisk 更新完成。"
|
||||
|
||||
# --- 第 2 步: 从 Hardisk 备份到 NAS ---
|
||||
echo ""
|
||||
echo "--> (2/2) 正在将更新后的算法文件从 Hardisk 备份到 $NAS_DST_ROOT..."
|
||||
for dir_full_path in "${SRC_DIRS[@]}"; do
|
||||
dir_name=$(basename "$dir_full_path")
|
||||
src_from_hardisk="$LOCAL_DST_ROOT/$dir_name"
|
||||
dst_to_nas="$NAS_DST_ROOT/$dir_name"
|
||||
|
||||
if [ ! -d "$src_from_hardisk" ]; then
|
||||
echo " 警告: 源目录 '$src_from_hardisk' 在 Hardisk 中不存在,已跳过备份。"
|
||||
continue
|
||||
fi
|
||||
|
||||
mkdir -p "$dst_to_nas"
|
||||
dst_path_final=$(realpath "$dst_to_nas")
|
||||
|
||||
echo " >>> 正在备份 $src_from_hardisk -> $dst_path_final"
|
||||
rsync -avh --delete "$src_from_hardisk/" "$dst_path_final/"
|
||||
done
|
||||
echo "--> (2/2) 指定的算法文件已成功备份到 NAS!"
|
||||
;;
|
||||
2)
|
||||
echo "--- 您选择了 [2]: 仅备份Hardisk (Hardisk->NAS) ---"
|
||||
src="$LOCAL_DST_ROOT/"
|
||||
dst="$NAS_DST_ROOT/"
|
||||
|
||||
if [ ! -d "$src" ]; then
|
||||
echo "错误: 源目录 '$src' 不存在,无法继续。"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
mkdir -p "$dst"
|
||||
|
||||
echo ">>> 正在将 $src 的全部内容备份到 $dst"
|
||||
rsync -avh --delete "$src" "$dst" # 使用增量复制
|
||||
echo ">>> Hardisk 目录已完全备份到 NAS!"
|
||||
;;
|
||||
3)
|
||||
echo "--- 您选择了 [3]: 退出 ---"
|
||||
echo "操作已取消。"
|
||||
;;
|
||||
*)
|
||||
echo "无效选项 '$choice'。请输入 1, 2, 或 3。"
|
||||
exit 1
|
||||
;;
|
||||
esac
|
||||
|
||||
echo ""
|
||||
echo ">>> 全部任务完成!"
|
||||
56
Check_Graph_Card.sh
Normal file
56
Check_Graph_Card.sh
Normal file
@@ -0,0 +1,56 @@
|
||||
#!/usr/bin/env bash
|
||||
# 检查系统识别到的 GPU 卡数和驱动状态
|
||||
echo "======== PCIe 在位 =========="
|
||||
lspci -d 10de: | grep -i vga
|
||||
echo
|
||||
echo "======== 驱动认到几卡 ========"
|
||||
nvidia-smi --list-gpus | wc -l
|
||||
echo
|
||||
# nvidia: probe of 0000:04:06.0 failed with error -1
|
||||
# NVRM: The NVIDIA probe routine failed for 1 device(s).
|
||||
echo "======== dmesg 关键报错 ======"
|
||||
dmesg | grep -iE 'nvidia.*fail|nvidia.*error|Xid.*79|GSP.*timeout' | tail -10
|
||||
|
||||
# 当前信息记录:
|
||||
# Ubuntu中识别到的7张卡
|
||||
# 卡0:00000000:04:00.0
|
||||
# 卡1:00000000:04:02.0
|
||||
# 卡2:00000000:04:04.0
|
||||
# 坏卡-第一次:卡3:00000000:04:06.0
|
||||
# 卡4:00000000:05:00.0
|
||||
# 坏卡-第二次:卡5:00000000:05:02.0
|
||||
# 卡6:00000000:05:04.0
|
||||
# 卡7:00000000:05:06.0
|
||||
# Exsi服务器识别到的7张卡
|
||||
# 卡0:0000:16:00.0
|
||||
# 卡1:0000:38:00.0
|
||||
# 卡2:0000:49:00.0
|
||||
# 坏卡卡槽-第一次:卡3:0000:5a:00.0
|
||||
# 卡4:0000:98:00.0
|
||||
# 坏卡卡槽-第二次:卡5:0000:b8:00.0
|
||||
# 卡6:0000:c8:00.0
|
||||
# 卡7:0000:d8:00.0
|
||||
|
||||
# 解决方案尝试
|
||||
# V1.
|
||||
# 1. 关闭图形界面
|
||||
# sudo systemctl set-default multi-user.target
|
||||
# sudo systemctl set-default multi-user.target
|
||||
# 2. 立刻关 GSP 重新加载驱动
|
||||
# sudo modprobe -r nvidia_drm nvidia_modeset nvidia nvidia_uvm
|
||||
# sudo modprobe nvidia NVreg_EnableGpuFirmware=0
|
||||
# sudo nvidia-smi
|
||||
# 3. 若 8 卡出现 → 就是 GSP 问题,长期生效:
|
||||
# echo "options nvidia NVreg_EnableGpuFirmware=0" | sudo tee /etc/modprobe.d/nvidia-disable-gsp.conf
|
||||
# sudo update-initramfs -u
|
||||
|
||||
# V2.
|
||||
# 1. 看是不是 BAR 空间不足
|
||||
# sudo dmesg | grep -i "BAR 0\|resource 0" | grep 04:06.0 # TODO 变为对应的显卡号
|
||||
# [ 4.747643] pci 0000:04:06.0: BAR 0 [mem 0xea000000-0xeaffffff]
|
||||
# 内核已经成功为 04:06.0 分配了 BAR 0,大小 16 MiB,没有报 “can’t allocate” 或 “failed”,因此 BAR 空间不足/Above 4G Decoding 问题可以排除
|
||||
|
||||
# V3.
|
||||
# 彻底冷复位
|
||||
# 宿主机或 云控制台 → 断电 10 秒再上电
|
||||
#
|
||||
308
DataSet_Own/1. 图片预处理(内含使用手册)/1_rename_pics.sh
Executable file
308
DataSet_Own/1. 图片预处理(内含使用手册)/1_rename_pics.sh
Executable file
@@ -0,0 +1,308 @@
|
||||
#!/bin/bash
|
||||
|
||||
usage() {
|
||||
echo "Usage: $0 -i <ori_image_directory> -l <ori_label_directory> [-h]"
|
||||
echo "对image图片和label图片进行处理"
|
||||
echo "-i:原始图片的路径,-l:原始标签的路径,-h:帮助"
|
||||
}
|
||||
|
||||
ori_image_directorys=""
|
||||
ori_label_directorys=""
|
||||
|
||||
while getopts "hl:i:" opt; do
|
||||
case $opt in
|
||||
h)
|
||||
usage
|
||||
exit 0
|
||||
;;
|
||||
i)
|
||||
ori_image_directorys=$OPTARG
|
||||
;;
|
||||
l)
|
||||
ori_label_directorys=$OPTARG
|
||||
;;
|
||||
*)
|
||||
echo -e '\033[31m!!! Error, Illegal input !!!\033[0m'
|
||||
usage
|
||||
exit 1
|
||||
;;
|
||||
esac
|
||||
done
|
||||
|
||||
# 判断输入地址是否都为空
|
||||
if [ -z "$ori_label_directorys" ] && [ -z "$ori_image_directorys" ]; then
|
||||
echo -e "\033[31m输入地址 -i -l 都为空\033[0m"
|
||||
usage
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# 地址转化
|
||||
ori_image_directory=$(readlink -f "$ori_image_directorys")
|
||||
ori_label_directory=$(readlink -f "$ori_label_directorys")
|
||||
if [ -z "$ori_label_directory" ] && [ -z "$ori_image_directory" ]; then
|
||||
echo "无法解析地址,程序退出"
|
||||
echo -e "\033[31mori_image_directory\033[0m: $ori_image_directory"
|
||||
echo -e "\033[31mori_label_directory\033[0m: $ori_label_directorys"
|
||||
exit 1
|
||||
fi
|
||||
if [ ! -d "$ori_label_directory" ] && [ ! -d "$ori_image_directory" ]; then
|
||||
echo "image、label两目录都不存在,程序退出"
|
||||
echo -e "\033[31mori_image_directory\033[0m: $ori_image_directory"
|
||||
echo -e "\033[31mori_label_directory\033[0m: $ori_label_directorys"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# 获取当前脚本的路径和名称
|
||||
script_path=$(dirname "$0")
|
||||
# 将当前目录更改为脚本所在的路径
|
||||
cd "$script_path"
|
||||
|
||||
echo -e "\033[32m______ 1_rename_data.sh _____\033[0m"
|
||||
while true; do
|
||||
PS3='Please enter your choice: '
|
||||
options=("Move all pics in dir to dir" "Delete all space in filename" "Add jpg to no suffix filename" "Replace content in filename" "Extract filename between prefix and suffix" "Add content before filename" "Add content behand filename" "Quit")
|
||||
echo "一般处理流程1(移动文件)、2(删除所有空格)、3(文件后缀添加.jpg)、4(删除内容)\"-恢复的\"/\"-副本\"、5(取出前缀、后缀中的内容)(方案1:Still/\"\",方案2:\"\"/.Still)、6(添加前缀)、7(添加后缀)、8(关闭)"
|
||||
select opt in "${options[@]}"
|
||||
do
|
||||
case $opt in
|
||||
### 选项1将所有在目录下的图片移动到目录中 ###
|
||||
"Move all pics in dir to dir")
|
||||
# 判断label图片路径是否存在
|
||||
if [ -d "$ori_label_directory" ]; then
|
||||
echo "**** Processing ori_label_directory: $ori_label_directory ****"
|
||||
# 遍历 标签图片 目录中的所有后缀为 .png 或 .PNG 的文件,将其移动到主目录
|
||||
find "$ori_label_directory" -not -path "*/error/*" \( -iname "*.png" -o -iname "*.PNG" \) -type f -print0 |
|
||||
while IFS= read -r -d $'\0' file; do
|
||||
echo cp -n "$file" "$ori_label_directory"
|
||||
cp -n "$file" "$ori_label_directory"
|
||||
done
|
||||
fi
|
||||
echo ""
|
||||
# 判断image图片路径是否存在
|
||||
if [ -d "$ori_image_directory" ]; then
|
||||
echo "**** Processing ori_image_directory: $ori_image_directory ****"
|
||||
# 遍历 原图片 目录中的所有后缀为 .png 或 .PNG 的文件,将其移动到主目录
|
||||
find "$ori_image_directory" -not -path "*/error/*" \( -iname "*.png" -o -iname "*.PNG" \) -type f -print0 |
|
||||
while IFS= read -r -d $'\0' file; do
|
||||
echo cp -n "$file" "$ori_image_directory"
|
||||
cp -n "$file" "$ori_image_directory"
|
||||
done
|
||||
fi
|
||||
echo -e ""
|
||||
break
|
||||
;;
|
||||
|
||||
### 选项2替换文件名中的空格 ###
|
||||
"Delete all space in filename")
|
||||
# 输入待删除的内容
|
||||
Del_str=" "
|
||||
Replace_str=""
|
||||
# 判断image图片路径是否存在
|
||||
if [ -d "$ori_image_directory" ]; then
|
||||
echo "**** Processing ori_image_directory: $ori_image_directory ****"
|
||||
ls "$ori_image_directory" | grep -e "$Del_str" | awk -v ori_image_directory="$ori_image_directory" -v Replace_str="$Replace_str" -F "$Del_str" '{s1=$0; gsub(/ /, ""); print "mv \""ori_image_directory"/"s1"\" \""ori_image_directory"/"$1""Replace_str""$2"\""}'
|
||||
ls "$ori_image_directory" | grep -e "$Del_str" | awk -v ori_image_directory="$ori_image_directory" -v Replace_str="$Replace_str" -F "$Del_str" '{s1=$0; gsub(/ /, ""); print "mv \""ori_image_directory"/"s1"\" \""ori_image_directory"/"$1""Replace_str""$2"\""}' | bash
|
||||
echo -e ""
|
||||
else
|
||||
echo "**** image图片目录不存在: $ori_image_directory ****"
|
||||
echo -e ""
|
||||
fi
|
||||
|
||||
# 判断label图片路径是否存在
|
||||
if [ -d "$ori_label_directory" ]; then
|
||||
echo "**** Processing ori_label_directory: $ori_label_directory ****"
|
||||
ls "$ori_label_directory" | grep -e "$Del_str" | awk -v ori_label_directory="$ori_label_directory" -v Replace_str="$Replace_str" -F "$Del_str" '{s1=$0; gsub(/ /, ""); print "mv \""ori_label_directory"/"s1"\" \""ori_label_directory"/"$1""Replace_str""$2"\""}'
|
||||
ls "$ori_label_directory" | grep -e "$Del_str" | awk -v ori_label_directory="$ori_label_directory" -v Replace_str="$Replace_str" -F "$Del_str" '{s1=$0; gsub(/ /, ""); print "mv \""ori_label_directory"/"s1"\" \""ori_label_directory"/"$1""Replace_str""$2"\""}' | bash
|
||||
echo -e ""
|
||||
else
|
||||
echo "**** label图片目录不存在: $ori_label_directory ****"
|
||||
echo -e ""
|
||||
fi
|
||||
break
|
||||
;;
|
||||
|
||||
### 选项3在无后缀文件后添加.jpg后缀 ###
|
||||
"Add jpg to no suffix filename")
|
||||
# 判断image图片路径是否存在
|
||||
if [ -d "$ori_image_directory" ]; then
|
||||
echo "**** Processing ori_image_directory: $ori_image_directory ****"
|
||||
find "$ori_image_directory" -type f ! -name "*.*" | awk -F/ '{print $NF}' | awk -v ori_image_directory="$ori_image_directory" '{print "mv \""ori_image_directory"/"$0"\" \""ori_image_directory"/"$0".jpg\""}'
|
||||
find "$ori_image_directory" -type f ! -name "*.*" | awk -F/ '{print $NF}' | awk -v ori_image_directory="$ori_image_directory" '{print "mv \""ori_image_directory"/"$0"\" \""ori_image_directory"/"$0".jpg\""}' | bash
|
||||
echo -e ""
|
||||
else
|
||||
echo "**** image图片目录不存在: $ori_image_directory ****"
|
||||
echo -e ""
|
||||
fi
|
||||
|
||||
# 判断label图片路径是否存在
|
||||
if [ -d "$ori_label_directory" ]; then
|
||||
echo "**** Processing ori_label_directory: $ori_label_directory ****"
|
||||
find "$ori_label_directory" -type f ! -name "*.*" | awk -F/ '{print $NF}' | awk -v ori_label_directory="$ori_label_directory" '{print "mv \""ori_label_directory"/"$0"\" \""ori_label_directory"/"$0".jpg\""}'
|
||||
find "$ori_label_directory" -type f ! -name "*.*" | awk -F/ '{print $NF}' | awk -v ori_label_directory="$ori_label_directory" '{print "mv \""ori_label_directory"/"$0"\" \""ori_label_directory"/"$0".jpg\""}' | bash
|
||||
echo -e ""
|
||||
else
|
||||
echo "**** label图片目录不存在: $ori_label_directory ****"
|
||||
echo -e ""
|
||||
fi
|
||||
break
|
||||
;;
|
||||
|
||||
### 选项4替换文件名中内容 ###
|
||||
"Replace content in filename")
|
||||
# 输入待删除的内容
|
||||
echo -n "Please input the content to be deleted = "
|
||||
read -r Del_str
|
||||
echo -n "Please input the content to be replace(default is None) = "
|
||||
read -r Replace_str
|
||||
# 判断image图片路径是否存在
|
||||
if [ -d "$ori_image_directory" ]; then
|
||||
echo "**** Processing ori_image_directory: $ori_image_directory ****"
|
||||
ls "$ori_image_directory" | grep -E "\.(png|jpg|PNG|JPG|BMP|bmp)$" | grep -e "$Del_str" | awk -v ori_image_directory="$ori_image_directory" -v Del_str="$Del_str" -v Replace_str="$Replace_str" -F "$Del_str" '{s1=$0; gsub(Del_str, Replace_str); print "mv \""ori_image_directory"/"s1"\" \""ori_image_directory"/"$0"\""}'
|
||||
ls "$ori_image_directory" | grep -E "\.(png|jpg|PNG|JPG|BMP|bmp)$" | grep -e "$Del_str" | awk -v ori_image_directory="$ori_image_directory" -v Del_str="$Del_str" -v Replace_str="$Replace_str" -F "$Del_str" '{s1=$0; gsub(Del_str, Replace_str); print "mv \""ori_image_directory"/"s1"\" \""ori_image_directory"/"$0"\""}' | bash
|
||||
echo -e ""
|
||||
else
|
||||
echo "**** image图片目录不存在: $ori_image_directory ****"
|
||||
echo -e ""
|
||||
fi
|
||||
|
||||
# 判断label图片路径是否存在
|
||||
if [ -d "$ori_label_directory" ]; then
|
||||
echo "**** Processing ori_label_directory: $ori_label_directory ****"
|
||||
ls "$ori_label_directory" | grep -E "\.(png|jpg|PNG|JPG|BMP|bmp)$" | grep -e "$Del_str" | awk -v ori_label_directory="$ori_label_directory" -v Del_str="$Del_str" -v Replace_str="$Replace_str" -F "$Del_str" '{s1=$0; gsub(Del_str, Replace_str); print "mv \""ori_label_directory"/"s1"\" \""ori_label_directory"/"$0"\""}'
|
||||
ls "$ori_label_directory" | grep -E "\.(png|jpg|PNG|JPG|BMP|bmp)$" | grep -e "$Del_str" | awk -v ori_label_directory="$ori_label_directory" -v Del_str="$Del_str" -v Replace_str="$Replace_str" -F "$Del_str" '{s1=$0; gsub(Del_str, Replace_str); print "mv \""ori_label_directory"/"s1"\" \""ori_label_directory"/"$0"\""}' | bash
|
||||
echo -e ""
|
||||
else
|
||||
echo "**** label图片目录不存在: $ori_label_directory ****"
|
||||
echo -e ""
|
||||
fi
|
||||
break
|
||||
;;
|
||||
|
||||
### 选项5删除文件名前缀、后缀之间的内容 ###
|
||||
"Extract filename between prefix and suffix")
|
||||
# 输入待删除的内容
|
||||
echo -n "Please input the prefix to be deleted = "
|
||||
read -r prefix
|
||||
echo -n "Please input the suffix to be deleted = "
|
||||
read -r suffix
|
||||
|
||||
|
||||
# 判断image图片路径是否存在
|
||||
if [ -d "$ori_image_directory" ]; then
|
||||
echo "**** Processing ori_image_directory: $ori_image_directory ****"
|
||||
# ls -1 以回车显示
|
||||
ls -1 "$ori_image_directory" | grep -E "\.(png|jpg|PNG|JPG|BMP|bmp)$" | grep -e "$prefix" | grep -e "$suffix" | while read file_name; do
|
||||
# 提取出新的文件名
|
||||
if [ -z $prefix ];then
|
||||
file_name_new=$(echo $file_name | sed "s/\(.*\)$suffix.*\(\.jpg\|\.png\|\.bmp\|\.JPG\|\.PNG\|\.BMP\)$/\1\2/")
|
||||
else
|
||||
echo "sed "s/^.*$prefix\(.*\)$suffix.*\(\.jpg\|\.png\|\.bmp\|\.JPG\|\.PNG\|\.BMP\)$/\1\2/""
|
||||
file_name_new=$(echo $file_name | sed "s/^.*$prefix\(.*\)$suffix.*\(\.jpg\|\.png\|\.bmp\|\.JPG\|\.PNG\|\.BMP\)$/\1\2/")
|
||||
fi
|
||||
echo "$file_name -> $file_name_new"
|
||||
echo "mv "$ori_image_directory/$file_name" "$ori_image_directory/$file_name_new""
|
||||
mv "$ori_image_directory/$file_name" "$ori_image_directory/$file_name_new"
|
||||
done
|
||||
# ls "$ori_image_directory" | grep -e "$Del_str" | awk -v ori_image_directory="$ori_image_directory" -v Del_str="$Del_str" '{ s1=$0; match($0, Del_str); sub(substr($0, 1, RSTART + RLENGTH - 1), ""); print "mv \""ori_image_directory"/"s1"\" \""ori_image_directory"/"$0"\""}'
|
||||
# ls "$ori_image_directory" | grep -e "$Del_str" | awk -v ori_image_directory="$ori_image_directory" -v Del_str="$Del_str" '{ s1=$0; match($0, Del_str); sub(substr($0, 1, RSTART + RLENGTH - 1), ""); print "mv \""ori_image_directory"/"s1"\" \""ori_image_directory"/"$0"\""}' | bash
|
||||
echo -e ""
|
||||
else
|
||||
echo "**** image图片目录不存在: $ori_image_directory ****"
|
||||
echo -e ""
|
||||
fi
|
||||
|
||||
# 判断label图片路径是否存在
|
||||
if [ -d "$ori_label_directory" ]; then
|
||||
echo "**** Processing ori_label_directory: $ori_label_directory ****"
|
||||
ls -1 "$ori_label_directory" | grep -E "\.(png|jpg|PNG|JPG|BMP|bmp)$" | grep -e "$prefix" | grep -e "$suffix" | while read file_name; do
|
||||
# 提取出新的文件名
|
||||
if [ -z $prefix ];then
|
||||
file_name_new=$(echo $file_name | sed "s/\(.*\)$suffix.*\(\.jpg\|\.png\|\.bmp\|\.JPG\|\.PNG\|\.BMP\)$/\1\2/")
|
||||
else
|
||||
echo "sed "s/^.*$prefix\(.*\)$suffix.*\(\.jpg\|\.png\|\.bmp\|\.JPG\|\.PNG\|\.BMP\)$/\1\2/""
|
||||
file_name_new=$(echo $file_name | sed "s/^.*$prefix\(.*\)$suffix.*\(\.jpg\|\.png\|\.bmp\|\.JPG\|\.PNG\|\.BMP\)$/\1\2/")
|
||||
fi
|
||||
echo "$file_name -> $file_name_new"
|
||||
echo "mv "$ori_label_directory/$file_name" "$ori_label_directory/$file_name_new""
|
||||
mv "$ori_label_directory/$file_name" "$ori_label_directory/$file_name_new"
|
||||
done
|
||||
# ls "$ori_label_directory" | grep -e "$Del_str" | awk -v ori_label_directory="$ori_label_directory" -v Del_str="$Del_str" '{ s1=$0; match($0, Del_str); sub(substr($0, 1, RSTART + RLENGTH - 1), ""); print "mv \""ori_label_directory"/"s1"\" \""ori_label_directory"/"$0"\""}'
|
||||
# ls "$ori_label_directory" | grep -e "$Del_str" | awk -v ori_label_directory="$ori_label_directory" -v Del_str="$Del_str" '{ s1=$0; match($0, Del_str); sub(substr($0, 1, RSTART + RLENGTH - 1), ""); print "mv \""ori_label_directory"/"s1"\" \""ori_label_directory"/"$0"\""}' | bash
|
||||
echo -e ""
|
||||
else
|
||||
echo "**** label图片目录不存在: $ori_label_directory ****"
|
||||
echo -e ""
|
||||
fi
|
||||
break
|
||||
;;
|
||||
|
||||
### 选项6在文件名前添加内容 ###
|
||||
"Add content before filename")
|
||||
echo -n "Please input the content to be added = "
|
||||
read Group_str
|
||||
# 判断image图片路径是否存在
|
||||
if [ -d "$ori_image_directory" ]; then
|
||||
echo "**** Processing ori_image_directory: $ori_image_directory ****"
|
||||
ls "$ori_image_directory" | grep -E "\.(png|jpg|PNG|JPG|BMP|bmp)$" | awk -v ori_image_directory="$ori_image_directory" -v Group_str="$Group_str" '{print "mv \""ori_image_directory"/"$0"\" \""ori_image_directory"/"Group_str""$0"\""}'
|
||||
ls "$ori_image_directory" | grep -E "\.(png|jpg|PNG|JPG|BMP|bmp)$" | awk -v ori_image_directory="$ori_image_directory" -v Group_str="$Group_str" '{print "mv \""ori_image_directory"/"$0"\" \""ori_image_directory"/"Group_str""$0"\""}' | bash
|
||||
echo -e ""
|
||||
else
|
||||
echo "**** image图片目录不存在: $ori_image_directory ****"
|
||||
echo -e ""
|
||||
fi
|
||||
|
||||
# 判断label图片路径是否存在
|
||||
if [ -d "$ori_label_directory" ]; then
|
||||
echo "**** Processing ori_label_directory: $ori_label_directory ****"
|
||||
ls "$ori_label_directory" | grep -E "\.(png|jpg|PNG|JPG|BMP|bmp)$" | awk -v ori_label_directory="$ori_label_directory" -v Group_str="$Group_str" '{print "mv \""ori_label_directory"/"$0"\" \""ori_label_directory"/"Group_str""$0"\""}'
|
||||
ls "$ori_label_directory" | grep -E "\.(png|jpg|PNG|JPG|BMP|bmp)$" | awk -v ori_label_directory="$ori_label_directory" -v Group_str="$Group_str" '{print "mv \""ori_label_directory"/"$0"\" \""ori_label_directory"/"Group_str""$0"\""}' | bash
|
||||
echo -e ""
|
||||
else
|
||||
echo "**** label图片目录不存在: $ori_label_directory ****"
|
||||
echo -e ""
|
||||
fi
|
||||
break
|
||||
;;
|
||||
|
||||
### 选项7在文件名后添加内容 ###
|
||||
"Add content behand filename")
|
||||
echo -n "Please input the content to be added = "
|
||||
read Group_str
|
||||
# 判断image图片路径是否存在
|
||||
if [ -d "$ori_image_directory" ]; then
|
||||
echo "**** Processing ori_image_directory: $ori_image_directory ****"
|
||||
ls "$ori_image_directory" | grep -E "\.(png|jpg|PNG|JPG|BMP|bmp)$" | awk -v ori_image_directory="$ori_image_directory" -v Group_str="$Group_str" -F . '{print "mv \""ori_image_directory"/"$0"\" \""ori_image_directory"/"substr($0, 1, length($0)-length($NF)-1)""Group_str"."$NF"\""}'
|
||||
ls "$ori_image_directory" | grep -E "\.(png|jpg|PNG|JPG|BMP|bmp)$" | awk -v ori_image_directory="$ori_image_directory" -v Group_str="$Group_str" -F . '{print "mv \""ori_image_directory"/"$0"\" \""ori_image_directory"/"substr($0, 1, length($0)-length($NF)-1)""Group_str"."$NF"\""}' | bash
|
||||
echo -e ""
|
||||
else
|
||||
echo "**** image图片目录不存在: $ori_image_directory ****"
|
||||
echo -e ""
|
||||
fi
|
||||
|
||||
# 判断label图片路径是否存在
|
||||
if [ -d "$ori_label_directory" ]; then
|
||||
echo "**** Processing ori_label_directory: $ori_label_directory ****"
|
||||
ls "$ori_label_directory" | grep -E "\.(png|jpg|PNG|JPG|BMP|bmp)$" | awk -v ori_label_directory="$ori_label_directory" -v Group_str="$Group_str" -F . '{print "mv \""ori_label_directory"/"$0"\" \""ori_label_directory"/"substr($0, 1, length($0)-length($NF)-1)""Group_str"."$NF"\""}'
|
||||
ls "$ori_label_directory" | grep -E "\.(png|jpg|PNG|JPG|BMP|bmp)$" | awk -v ori_label_directory="$ori_label_directory" -v Group_str="$Group_str" -F . '{print "mv \""ori_label_directory"/"$0"\" \""ori_label_directory"/"substr($0, 1, length($0)-length($NF)-1)""Group_str"."$NF"\""}' | bash
|
||||
echo -e ""
|
||||
else
|
||||
echo "**** label图片目录不存在: $ori_label_directory ****"
|
||||
echo -e ""
|
||||
fi
|
||||
break
|
||||
;;
|
||||
|
||||
### 选项8 退出 ###
|
||||
"Quit")
|
||||
echo "Exiting..."
|
||||
exit 0
|
||||
;;
|
||||
|
||||
*)
|
||||
echo "Invalid option: $REPLY"
|
||||
echo -e ""
|
||||
break
|
||||
;;
|
||||
esac
|
||||
done
|
||||
done
|
||||
73
DataSet_Own/1. 图片预处理(内含使用手册)/2_1_Trans_to_png.py
Executable file
73
DataSet_Own/1. 图片预处理(内含使用手册)/2_1_Trans_to_png.py
Executable file
@@ -0,0 +1,73 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*
|
||||
import os, cv2, sys
|
||||
|
||||
def transform(input_path, output_path):
|
||||
for root, dirs, files in os.walk(input_path):
|
||||
for name in files:
|
||||
print("2_1_正在检索", os.path.join(root, name))
|
||||
# 如果图片是jpg图片
|
||||
if name.endswith('.jpg') or name.endswith('.JPG') :
|
||||
# convert一下图片
|
||||
print("convert \""+os.path.join(root, name)+"\" \""+os.path.join(input_path, name)+"\"")
|
||||
os.system("convert \""+os.path.join(root, name)+"\" \""+os.path.join(input_path, name)+"\"")
|
||||
file = os.path.join(root, name)
|
||||
if os.path.basename(root) == 'error':
|
||||
break
|
||||
try:
|
||||
# 读取图片并且改变存储方式
|
||||
im = cv2.imread(file)
|
||||
if output_path:
|
||||
# 压缩度调为0
|
||||
cv2.imwrite(os.path.join(output_path, name.replace('jpg', 'png').replace('JPG', 'png')), im, [cv2.IMWRITE_PNG_COMPRESSION, 100, cv2.IMWRITE_PNG_COMPRESSION, 0])
|
||||
else:
|
||||
print('transform:' + file.replace('jpg', 'png').replace('JPG', 'png'))
|
||||
os.system("rm \""+file+"\"")
|
||||
# 压缩度调为0
|
||||
cv2.imwrite(file.replace('jpg', 'png').replace('JPG', 'png'), im, [cv2.IMWRITE_PNG_COMPRESSION, 100, cv2.IMWRITE_PNG_COMPRESSION, 0])
|
||||
except:
|
||||
os.system("echo "+file+" >> error.txt")
|
||||
# 检查文件夹是否存在
|
||||
if not os.path.exists(os.path.join(root, 'error')):
|
||||
# 如果不存在,创建文件夹
|
||||
os.mkdir(os.path.join(root, 'error'))
|
||||
os.system("mv "+file+" "+os.path.join(root, 'error'))
|
||||
# 如果图片是jpg图片
|
||||
if name.endswith('.bmp') or name.endswith('.BMP') :
|
||||
# convert一下图片
|
||||
print("convert "+os.path.join(root, name)+" "+os.path.join(input_path, name))
|
||||
os.system("convert "+os.path.join(root, name)+" "+os.path.join(input_path, name))
|
||||
file = os.path.join(root, name)
|
||||
if os.path.basename(root) == 'error':
|
||||
break
|
||||
try:
|
||||
# 读取图片并且改变存储方式
|
||||
im = cv2.imread(file)
|
||||
if output_path:
|
||||
# 压缩度调为0
|
||||
cv2.imwrite(os.path.join(output_path, name.replace('bmp', 'png').replace('BMP', 'png')), im, [cv2.IMWRITE_PNG_COMPRESSION, 100, cv2.IMWRITE_PNG_COMPRESSION, 0])
|
||||
else:
|
||||
print('transform:' + os.path.join(root, name.replace('bmp', 'png').replace('BMP', 'png')))
|
||||
os.system("rm \""+file+"\"")
|
||||
# 压缩度调为0
|
||||
cv2.imwrite(os.path.join(root, name.replace('bmp', 'png').replace('BMP', 'png')), im, [cv2.IMWRITE_PNG_COMPRESSION, 100, cv2.IMWRITE_PNG_COMPRESSION, 0])
|
||||
except:
|
||||
os.system("echo \""+file+"\" >> error.txt")
|
||||
# 检查文件夹是否存在
|
||||
if not os.path.exists(os.path.join(root, 'error')):
|
||||
# 如果不存在,创建文件夹
|
||||
os.mkdir(os.path.join(root, 'error'))
|
||||
os.system("mv \""+file+"\" \""+os.path.join(root, 'error')+"\"")
|
||||
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
input_path = sys.argv[1] # './2.C组/Group_label_C0'
|
||||
output_path = None
|
||||
if not os.path.exists(input_path):
|
||||
print("文件夹不存在!")
|
||||
else:
|
||||
print("Start to transform 2_1_Trans_to_png!")
|
||||
transform(input_path, output_path)
|
||||
print("Transform end 2_1_Trans_to_png!")
|
||||
print()
|
||||
103
DataSet_Own/1. 图片预处理(内含使用手册)/2_2_Resize.py
Executable file
103
DataSet_Own/1. 图片预处理(内含使用手册)/2_2_Resize.py
Executable file
@@ -0,0 +1,103 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*
|
||||
import cv2, sys
|
||||
import os
|
||||
|
||||
break_up = None
|
||||
break_up_point = False
|
||||
|
||||
# 重新改变图像大小
|
||||
def Resize(input_path, output_path, interpolation=False, default_height=1920, default_width=1080):
|
||||
global break_up, break_up_point
|
||||
# 遍历输入路径
|
||||
for root, dirs, files in os.walk(input_path):
|
||||
# 遍历所有文件
|
||||
for name in files:
|
||||
file = os.path.join(root, name)
|
||||
if(name == break_up or break_up == None): # 如果没到达断点或者没有断点
|
||||
break_up_point = True
|
||||
if break_up_point == False:
|
||||
continue
|
||||
|
||||
if not name.endswith(('.png', '.jpg', '.PNG', '.JPG')):
|
||||
continue
|
||||
|
||||
if os.path.basename(root) == 'error':
|
||||
break
|
||||
print("2_2_正在处理", file)
|
||||
try:
|
||||
# print("convert "+os.path.join(root, name)+" "+os.path.join(input_path, name))
|
||||
print("Processing: ","convert \""+file+"\" \""+file+"\"")
|
||||
status = os.system("convert \""+file+"\" \""+file+"\"") # 改变文件格式
|
||||
# 如果返回状态不为0,则移动对应图片到Error文件夹中
|
||||
if status != 0:
|
||||
os.system("echo \""+file+"\" >> error.txt")
|
||||
# 检查文件夹是否存在
|
||||
if not os.path.exists(os.path.join(root, 'error')):
|
||||
# 如果不存在,创建文件夹
|
||||
os.mkdir(os.path.join(root, 'error'))
|
||||
os.system("mv \""+file+"\" \""+os.path.join(root, 'error')+'\"')
|
||||
print("此文件是问题文件 ","mv \""+file+"\" \""+os.path.join(root, 'error')+'\"')
|
||||
continue
|
||||
|
||||
|
||||
# 读取图片
|
||||
im = cv2.imread(file)
|
||||
height, width, channels = im.shape
|
||||
# 判断高度和宽度是否符合要求
|
||||
if height == default_height and width == default_width:
|
||||
print("符合要求")
|
||||
continue
|
||||
else:
|
||||
# 是否满足最近临要求
|
||||
if interpolation == False:
|
||||
im = cv2.resize(im, (default_width, default_height))
|
||||
else:
|
||||
im = cv2.resize(im, (default_width, default_height),interpolation = cv2.INTER_NEAREST)
|
||||
# 输出影像
|
||||
if output_path:
|
||||
cv2.imwrite(os.path.join(output_path, name), im, [cv2.IMWRITE_PNG_COMPRESSION, 100, cv2.IMWRITE_PNG_COMPRESSION, 0])
|
||||
else:
|
||||
print('Resize:' + file)
|
||||
os.system("rm \""+file+"\"") # TODO
|
||||
# 压缩度调为0
|
||||
cv2.imwrite(file, im, [cv2.IMWRITE_PNG_COMPRESSION, 100, cv2.IMWRITE_PNG_COMPRESSION, 0])
|
||||
except:
|
||||
os.system("echo \""+file+"\" >> error.txt")
|
||||
# 检查文件夹是否存在
|
||||
if not os.path.exists(os.path.join(root, 'error')):
|
||||
# 如果不存在,创建文件夹
|
||||
os.mkdir(os.path.join(root, 'error'))
|
||||
os.system("mv \""+file+"\" \""+os.path.join(root, 'error')+"\"")
|
||||
|
||||
if __name__ == '__main__':
|
||||
input_path = sys.argv[1] # './2.C组/Group_label_C0'
|
||||
output_path = None
|
||||
try:
|
||||
default_width = int(sys.argv[3])
|
||||
except:
|
||||
default_width = 1920 # 默认宽度
|
||||
try:
|
||||
default_height = int(sys.argv[4])
|
||||
except:
|
||||
default_height = 1080 # 默认高度
|
||||
if default_width == 0 or default_height == 0:
|
||||
print("发生错误,default_width、default_height不应该为0")
|
||||
sys.exit()
|
||||
interpolation = sys.argv[2] # 最近临插值
|
||||
if interpolation == "False":
|
||||
interpolation = False
|
||||
elif interpolation == "True":
|
||||
interpolation = True
|
||||
else:
|
||||
print("interpolation must be True or False!")
|
||||
quit
|
||||
|
||||
if not os.path.exists(input_path):
|
||||
print(input_path)
|
||||
print("文件夹不存在!")
|
||||
else:
|
||||
print("Start to transform_2_2_Resize.py!")
|
||||
print(input_path, output_path, default_height,default_width)
|
||||
Resize(input_path, output_path, interpolation=interpolation, default_height=default_height, default_width=default_width)
|
||||
print("Transform end_2_2_Resize.py!")
|
||||
118
DataSet_Own/1. 图片预处理(内含使用手册)/2_reformate_pics.sh
Executable file
118
DataSet_Own/1. 图片预处理(内含使用手册)/2_reformate_pics.sh
Executable file
@@ -0,0 +1,118 @@
|
||||
#!/bin/bash
|
||||
|
||||
usage() {
|
||||
echo "Usage: $0 -i <ori_image_directory> -l <ori_label_directory> -w <width_of_pic> -h <height_of_pic> [-help]"
|
||||
echo "对image图片和label图片进行处理,将其转为PNG格式,并调整图片的宽和高和格式"
|
||||
echo "-i:原始图片的路径,-l:原始标签的路径,-w:图片宽度,-h:图片高度,-help:帮助"
|
||||
}
|
||||
|
||||
ori_image_directorys=""
|
||||
ori_label_directorys=""
|
||||
pic_width=1920
|
||||
pic_height=1080
|
||||
|
||||
while getopts "l:i:h:w:" opt; do
|
||||
case $opt in
|
||||
h)
|
||||
if [[ $OPTARG =~ ^-?[0-9]+$ ]];then
|
||||
pic_height=$OPTARG
|
||||
echo pic_height is $pic_height
|
||||
elif [ $OPTARG == 'elp' ];then
|
||||
usage
|
||||
exit 0
|
||||
else
|
||||
echo "-h(pic_height)必须为整数"
|
||||
usage
|
||||
exit 1
|
||||
fi
|
||||
;;
|
||||
i)
|
||||
ori_image_directorys=$OPTARG
|
||||
;;
|
||||
l)
|
||||
ori_label_directorys=$OPTARG
|
||||
;;
|
||||
w)
|
||||
if [[ $OPTARG =~ ^-?[0-9]+$ ]];then
|
||||
pic_width=$OPTARG
|
||||
echo pic_width is $pic_width
|
||||
else
|
||||
echo "-w(pic_height)必须为整数"
|
||||
usage
|
||||
exit 1
|
||||
fi
|
||||
;;
|
||||
*)
|
||||
echo -e '\033[31m!!! Error, Illegal input !!!\033[0m'
|
||||
usage
|
||||
exit 1
|
||||
;;
|
||||
esac
|
||||
done
|
||||
|
||||
# 判断输入地址是否都为空
|
||||
if [ -z "$ori_label_directorys" ] && [ -z "$ori_image_directorys" ]; then
|
||||
echo -e "\033[31m输入地址 -i -l 都为空\033[0m"
|
||||
usage
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# 地址转化
|
||||
ori_image_directory=$(readlink -f "$ori_image_directorys")
|
||||
ori_label_directory=$(readlink -f "$ori_label_directorys")
|
||||
if [ -z "$ori_label_directory" ] && [ -z "$ori_image_directory" ]; then
|
||||
echo "无法解析地址,程序退出"
|
||||
echo -e "\033[31mori_image_directory\033[0m: $ori_image_directorys"
|
||||
echo -e "\033[31mori_label_directory\033[0m: $ori_label_directorys"
|
||||
exit 1
|
||||
fi
|
||||
if [ ! -d "$ori_label_directory" ] && [ ! -d "$ori_image_directory" ]; then
|
||||
echo "image、label两目录都不存在,程序退出"
|
||||
echo -e "\033[31mori_image_directory\033[0m: $ori_image_directorys"
|
||||
echo -e "\033[31mori_label_directory\033[0m: $ori_label_directorys"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo -e "\033[32m_____ 2_reformate_data.sh _____\033[0m"
|
||||
|
||||
# 获取当前脚本的路径和名称
|
||||
script_path=$(dirname "$0")
|
||||
# 将当前目录更改为脚本所在的路径
|
||||
cd "$script_path"
|
||||
|
||||
# 激活conda环境
|
||||
source /home/"$USER"/miniconda/bin/activate Deal_pics
|
||||
|
||||
# 判断image图片路径是否存在
|
||||
if [ -d "$ori_image_directory" ]; then
|
||||
echo "**** Processing ori_image_directory: $ori_image_directory ****"
|
||||
echo "1.Trans pics to png"
|
||||
python 2_1_Trans_to_png.py "$ori_image_directory"
|
||||
echo -e ""
|
||||
echo "2.Resize image pics with nearest"
|
||||
echo -e "\033[35m运行:\033[0mpython 2_2_Resize.py "$ori_image_directory" False $pic_width $pic_height "
|
||||
python 2_2_Resize.py "$ori_image_directory" False $pic_width $pic_height # False 是不使用最近邻插值
|
||||
echo -e ""
|
||||
else
|
||||
echo "**** image图片目录不存在: $ori_image_directory ****"
|
||||
echo -e ""
|
||||
fi
|
||||
|
||||
# 判断label图片路径是否存在
|
||||
if [ -d "$ori_label_directory" ]; then
|
||||
echo "**** Processing ori_label_directory: $ori_label_directory ****"
|
||||
echo -e "\033[33m__ 1.Trans pics to png __\033[0m"
|
||||
echo -e "\033[35m运行:\033[0mpython 2_1_Trans_to_png.py "$ori_label_directory""
|
||||
python 2_1_Trans_to_png.py "$ori_label_directory"
|
||||
echo -e ""
|
||||
echo -e "\033[33m__ 2.Resize label pics without nearest __\033[0m"
|
||||
echo -e "\033[35m运行:\033[0mpython 2_2_Resize.py "$ori_label_directory" True $pic_width $pic_height"
|
||||
python 2_2_Resize.py "$ori_label_directory" True $pic_width $pic_height # True 是使用最近邻插值
|
||||
echo -e ""
|
||||
else
|
||||
echo -e "\033[33m**** label图片目录不存在: $ori_image_directory ****\033[0m"
|
||||
echo -e ""
|
||||
fi
|
||||
|
||||
source /home/"$USER"/miniconda/bin/deactivate
|
||||
|
||||
186
DataSet_Own/1. 图片预处理(内含使用手册)/3_pair_ori_label.sh
Executable file
186
DataSet_Own/1. 图片预处理(内含使用手册)/3_pair_ori_label.sh
Executable file
@@ -0,0 +1,186 @@
|
||||
#!/bin/bash
|
||||
|
||||
usage() {
|
||||
echo "Usage: $0 -i <ori_image_directory> -l <ori_label_directory> [ -p <prefix> -s <suffix> -h]"
|
||||
echo "对image图片和label图片进行匹配(-i、-l均不能为空)(-p -s默认为空"") "
|
||||
echo "-i:原始image的路径,-l:原始label的路径,-p:前缀内容,-s:后缀内容(不用管文件后缀名),-h:帮助"
|
||||
echo "e.g. 3_pair_ori_label.sh -i ./C组未标注 -l ./C组标注图片 -p Group_C_ -s _label"
|
||||
}
|
||||
|
||||
ori_image_directorys=""
|
||||
ori_label_directorys=""
|
||||
prefix=""
|
||||
suffix=""
|
||||
|
||||
while getopts "hl:i:p:s:" opt; do
|
||||
case $opt in
|
||||
h)
|
||||
usage
|
||||
exit 0
|
||||
;;
|
||||
i)
|
||||
ori_image_directorys=$OPTARG
|
||||
;;
|
||||
l)
|
||||
ori_label_directorys=$OPTARG
|
||||
;;
|
||||
p)
|
||||
prefix=$OPTARG
|
||||
;;
|
||||
s)
|
||||
suffix=$OPTARG
|
||||
;;
|
||||
*)
|
||||
echo "$opt"
|
||||
echo -e '\033[31m!!! Error, Illegal input !!!\033[0m'
|
||||
usage
|
||||
exit 1
|
||||
;;
|
||||
esac
|
||||
done
|
||||
|
||||
# 判断输入地址是否为空
|
||||
if [ -z "$ori_label_directorys" ] || [ -z "$ori_image_directorys" ]; then
|
||||
echo -e "\033[31m输入地址 -i -l 存在空地址\033[0m"
|
||||
usage
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# 地址转化
|
||||
ori_image_directory=$(readlink -f "$ori_image_directorys")
|
||||
ori_label_directory=$(readlink -f "$ori_label_directorys")
|
||||
if [ -z "$ori_label_directory" ] || [ -z "$ori_image_directory" ]; then
|
||||
echo "无法解析地址,程序退出"
|
||||
echo -e "\033[31mori_image_directory\033[0m: $ori_image_directorys"
|
||||
echo -e "\033[31mori_label_directory\033[0m: $ori_label_directorys"
|
||||
exit 1
|
||||
fi
|
||||
if [ ! -d "$ori_label_directory" ] || [ ! -d "$ori_image_directory" ]; then
|
||||
echo "image、label两目录有一个不存在,程序退出"
|
||||
echo -e "\033[31mori_image_directory\033[0m: $ori_image_directorys"
|
||||
echo -e "\033[31mori_label_directory\033[0m: $ori_label_directorys"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# 获取当前脚本的路径和名称
|
||||
script_path=$(dirname "$0")
|
||||
# 将当前目录更改为脚本所在的路径
|
||||
cd "$script_path"
|
||||
|
||||
echo -e "\033[32m_____ 3_pair_ori_label.sh _____\033[0m"
|
||||
# find -name 中"*"表示通配,与'.*'不同
|
||||
# 遍历img目录
|
||||
for file_path in "$ori_image_directory"/*; do
|
||||
# 判断是否是文件
|
||||
if [[ -f "$file_path" ]]; then
|
||||
file_name=$(basename "$file_path")
|
||||
# 判断文件名是否符合规范
|
||||
if [[ "$file_name" =~ .*\.(jpg|png|bmp|JPG|PNG|BMP) ]]; then # 判断是否有为图片
|
||||
# if [[ "$file_name" =~ "$prefix"".*$suffix".*\.(jpg|png|bmp|JPG|PNG|BMP)$ ]]; then # 判断是否有满足要求的文件名
|
||||
# 抽取文件名(有前缀、后缀的抽取前缀、后缀里面的,没有的返回整个)
|
||||
if [ -z $prefix ];then
|
||||
file_name_extract=$(echo $file_name | sed "s/"$prefix"\(.*\)"$suffix".*/\1/" | sed "s/\(.*\)\.\(jpg\|png\|bmp\|JPG\|PNG\|BMP\)$/\1/")
|
||||
else
|
||||
file_name_extract=$(echo $file_name | sed "s/".*$prefix"\(.*\)"$suffix".*/\1/" | sed "s/\(.*\)\.\(jpg\|png\|bmp\|JPG\|PNG\|BMP\)$/\1/")
|
||||
fi
|
||||
echo "$file_name -> $file_name_extract"
|
||||
# 从label目录中看是否有此文件,sed匹配时$suffix后要有.*
|
||||
if [ -z $prefix ];then
|
||||
file_name_other=$(ls $ori_label_directory | grep $file_name_extract | sed "s/"$prefix"\(.*\)"$suffix".*/\1/" | sed "s/\(.*\)\.\(jpg\|png\|bmp\|JPG\|PNG\|BMP\)$/\1/")
|
||||
else
|
||||
file_name_other=$(ls $ori_label_directory | grep $file_name_extract | sed "s/".*$prefix"\(.*\)"$suffix".*/\1/" | sed "s/\(.*\)\.\(jpg\|png\|bmp\|JPG\|PNG\|BMP\)$/\1/")
|
||||
fi
|
||||
# 如果另一个目录没有此文件的话
|
||||
if [ -z "$file_name_other" ]; then
|
||||
echo "$file_name image中内容未在$ori_label_directory搜索到"
|
||||
# 建立相关存储文件夹
|
||||
if [ ! -d "$ori_image_directory/Not_pair_pics" ]; then
|
||||
mkdir -p "$ori_image_directory/Not_pair_pics" # 建立存储文件夹
|
||||
fi
|
||||
# 移动相关文件
|
||||
echo "$file_name" >> "$ori_image_directory/Not_pair_pics/not_pair.txt"
|
||||
mv "$ori_image_directory/$file_name" "$ori_image_directory/Not_pair_pics"
|
||||
fi
|
||||
|
||||
fi
|
||||
fi
|
||||
done
|
||||
|
||||
# 遍历label目录
|
||||
for file_path in "$ori_label_directory"/*; do
|
||||
# 判断是否是文件
|
||||
if [[ -f "$file_path" ]]; then
|
||||
file_name=$(basename "$file_path")
|
||||
# 判断文件名是否符合规范
|
||||
if [[ "$file_name" =~ .*\.(jpg|png|bmp|JPG|PNG|BMP) ]]; then # 判断是否有为图片
|
||||
# if [[ "$file_name" =~ "$prefix"".*$suffix".*\.(jpg|png|bmp|JPG|PNG|BMP)$ ]]; then # 判断是否有满足要求的文件名
|
||||
# 抽取文件名(有前缀、后缀的抽取前缀、后缀里面的,没有的返回整个)
|
||||
if [ -z $prefix ];then
|
||||
file_name_extract=$(echo $file_name | sed "s/"$prefix"\(.*\)"$suffix".*/\1/" | sed "s/\(.*\)\.\(jpg\|png\|bmp\|JPG\|PNG\|BMP\)$/\1/")
|
||||
else
|
||||
file_name_extract=$(echo $file_name | sed "s/".*$prefix"\(.*\)"$suffix".*/\1/" | sed "s/\(.*\)\.\(jpg\|png\|bmp\|JPG\|PNG\|BMP\)$/\1/")
|
||||
fi
|
||||
echo "$file_name -> $file_name_extract"
|
||||
# 从image目录中看是否有此文件
|
||||
if [ -z $prefix ];then
|
||||
file_name_other=$(ls $ori_image_directory | grep $file_name_extract | sed "s/"$prefix"\(.*\)"$suffix".*/\1/" | sed "s/\(.*\)\.\(jpg\|png\|bmp\|JPG\|PNG\|BMP\)$/\1/")
|
||||
else
|
||||
file_name_other=$(ls $ori_image_directory | grep $file_name_extract | sed "s/".*$prefix"\(.*\)"$suffix".*/\1/" | sed "s/\(.*\)\.\(jpg\|png\|bmp\|JPG\|PNG\|BMP\)$/\1/")
|
||||
fi
|
||||
|
||||
# 如果另一个目录没有此文件的话
|
||||
if [ -z "$file_name_other" ]; then
|
||||
echo "$file_name label中对应内容未在$ori_image_directory搜索到"
|
||||
# 建立相关存储文件夹
|
||||
if [ ! -d "$ori_label_directory/Not_pair_pics" ]; then
|
||||
mkdir -p "$ori_label_directory/Not_pair_pics" # 建立存储文件夹
|
||||
fi
|
||||
# 移动相关文件
|
||||
mv "$ori_label_directory/$file_name" "$ori_label_directory/Not_pair_pics"
|
||||
echo "$file_name" >> "$ori_label_directory/Not_pair_pics/not_pair.txt"
|
||||
fi
|
||||
|
||||
fi
|
||||
fi
|
||||
done
|
||||
|
||||
# 第二次遍历img目录
|
||||
for file_path in "$ori_image_directory"/*; do
|
||||
# 判断是否是文件
|
||||
if [[ -f "$file_path" ]]; then
|
||||
file_name=$(basename "$file_path")
|
||||
# 判断文件名是否符合规范
|
||||
if [[ "$file_name" =~ .*\.(jpg|png|bmp|JPG|PNG|BMP) ]]; then # 判断是否有为图片
|
||||
# if [[ "$file_name" =~ "$prefix"".*$suffix".*\.(jpg|png|bmp|JPG|PNG|BMP)$ ]]; then # 判断是否有满足要求的文件名
|
||||
# 抽取文件名(有前缀、后缀的抽取前缀、后缀里面的,没有的返回整个)
|
||||
if [ -z $prefix ];then
|
||||
file_name_extract=$(echo $file_name | sed "s/"$prefix"\(.*\)"$suffix".*/\1/" | sed "s/\(.*\)\.\(jpg\|png\|bmp\|JPG\|PNG\|BMP\)$/\1/")
|
||||
else
|
||||
file_name_extract=$(echo $file_name | sed "s/".*$prefix"\(.*\)"$suffix".*/\1/" | sed "s/\(.*\)\.\(jpg\|png\|bmp\|JPG\|PNG\|BMP\)$/\1/")
|
||||
fi
|
||||
echo "$file_name -> $file_name_extract"
|
||||
# 从label目录中看是否有此文件,sed匹配时$suffix后要有.*
|
||||
if [ -z $prefix ];then
|
||||
file_name_other=$(ls $ori_label_directory | grep $file_name_extract | sed "s/"$prefix"\(.*\)"$suffix".*/\1/" | sed "s/\(.*\)\.\(jpg\|png\|bmp\|JPG\|PNG\|BMP\)$/\1/")
|
||||
else
|
||||
file_name_other=$(ls $ori_label_directory | grep $file_name_extract | sed "s/".*$prefix"\(.*\)"$suffix".*/\1/" | sed "s/\(.*\)\.\(jpg\|png\|bmp\|JPG\|PNG\|BMP\)$/\1/")
|
||||
fi
|
||||
# 如果另一个目录没有此文件的话
|
||||
if [ -z "$file_name_other" ]; then
|
||||
echo "$file_name image中内容未在$ori_label_directory搜索到"
|
||||
# 建立相关存储文件夹
|
||||
if [ ! -d "$ori_image_directory/Not_pair_pics" ]; then
|
||||
mkdir -p "$ori_image_directory/Not_pair_pics" # 建立存储文件夹
|
||||
fi
|
||||
# 移动相关文件
|
||||
echo "$file_name" >> "$ori_image_directory/Not_pair_pics/not_pair.txt"
|
||||
mv "$ori_image_directory/$file_name" "$ori_image_directory/Not_pair_pics"
|
||||
fi
|
||||
|
||||
fi
|
||||
fi
|
||||
done
|
||||
|
||||
|
||||
|
||||
|
||||
445
DataSet_Own/1. 图片预处理(内含使用手册)/4_deal_labels.py
Executable file
445
DataSet_Own/1. 图片预处理(内含使用手册)/4_deal_labels.py
Executable file
@@ -0,0 +1,445 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*
|
||||
import os,time,sys,threading, colorsys, argparse
|
||||
import asyncio, cv2, multiprocessing, random
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
from Tool_deal_labels import edge_detection, detect_connected_regions, Tool_color_connected_array, fill_white_regions, color_connected_regions
|
||||
|
||||
def getFileList(dir,Filelist=[], ext=None, Max_layer=1, layer=0, Donot_Search=['1_边缘检测并膨胀', '2_连通区域检测', '3_分水岭算法填充']):
|
||||
"""
|
||||
获取文件夹及其子文件夹中文件列表
|
||||
输入 dir:文件夹根目录
|
||||
输入 ext: 扩展名
|
||||
返回: 文件路径列表
|
||||
"""
|
||||
newDir = dir
|
||||
if os.path.isfile(dir):
|
||||
if ext is None:
|
||||
Filelist.append(dir)
|
||||
else:
|
||||
if ext in dir[-3:]:
|
||||
Filelist.append(dir)
|
||||
|
||||
elif os.path.isdir(dir):
|
||||
file_name = os.path.basename(dir)
|
||||
# 判断是否在禁搜名单中
|
||||
if file_name in Donot_Search:
|
||||
return Filelist
|
||||
for s in os.listdir(dir):
|
||||
newDir=os.path.join(dir,s)
|
||||
if layer <= Max_layer:
|
||||
getFileList(newDir, Filelist, ext, Max_layer, layer+1)
|
||||
|
||||
return Filelist
|
||||
|
||||
class Deal_image():
|
||||
def __init__(self, Annotate_CLASSES = ('肝脏','胆囊'), Annotate_PALETTE = [[255,91,0],[255,234,0]], src_label_fold = "./Label", save_pro_label_fold = "./LABEL_PNG_new", save_GT_label_fold = "./Label_Generate", GT_channel = 1, pro_append_name="_label", GT_append_name="_gtFine_labelTrainIds", ori_img_folder="./ORI_PNG", res_label_folder="./Result_label", save_merge_pic_folder="./Result_merge", back_gnd_color=0, first_class_color=1, pic_type="png", Max_width = 10000, Label_Max_Search_layer=1000, save_process_pics=False, bg_PALETTE = [0,0,0]):
|
||||
# 背景最好放在最后
|
||||
# self.src_CLASSES = ('肝脏','胆囊','分离钳','止血海绵','肝总管','胆总管','吸引器','剪刀','止血纱布','生物夹','无损伤钳','喷洒','胆囊管','胆囊动脉','电凝','标本袋','引流管','纱布','金属钛夹','术中超声','吻合器','乳胶管','推结器','肝带','钳夹','超声刀','脂肪','双极电凝','棉球','血管阻断夹','肿瘤','针','线','韧带','胆囊静脉','背景')
|
||||
# self.src_PALETTE = np.array([[255,91,0],[255,234,0],[85, 111, 181],[181, 227, 14],[72, 0, 255],[0, 155, 33],[255,0,255],[29, 32, 136],[160, 15, 95],[0,160,233],[52,184,178],[90,120,41],[255,0,0],[177,0,0],[167,24,233],[112,113,150],[0,255,0],[255,255,255],[0,255,255],[138,251,213],[136,162,196],[197,83,181],[202,202,200],[113,102,140],[66,115,82],[240,16,116],[155,132,0],[155,62,0],[146,175,236],[255,172,159],[245,161,0],[134,124,118], [0,157,142], [181,85,105], [42,8,66],[0,0,0]])
|
||||
# self.src_CLASSES_NUM = np.shape(self.src_CLASSES)[0]
|
||||
self.bg_PALETTE = bg_PALETTE # 背景颜色 TODO
|
||||
|
||||
self.Annotate_CLASSES = Annotate_CLASSES # 待分类的类
|
||||
self.Annotate_PALETTE = np.array(Annotate_PALETTE) # 每一类的像素直
|
||||
self.Annotate_CLASSES_NUM = np.shape(Annotate_CLASSES)[0] # 类数量
|
||||
|
||||
self.save_process_pics = save_process_pics # 保存中间过程图片
|
||||
|
||||
self.src_label_fold = src_label_fold # 原始标签图片 保存位置
|
||||
self.save_pro_label_fold = save_pro_label_fold # 优化后标签图片 保存位置
|
||||
self.save_GT_label_fold = save_GT_label_fold # GT标签图片 保存位置
|
||||
|
||||
self.ori_img_folder = ori_img_folder # 最原始手术图片 保存位置
|
||||
self.res_label_folder = res_label_folder # 训练出来的label 保存位置
|
||||
self.save_merge_pic_folder = save_merge_pic_folder # 融合图像保存位置
|
||||
|
||||
self.pro_append_name = pro_append_name # 优化后标签图片后缀
|
||||
self.GT_append_name = GT_append_name # GT标签图片后缀
|
||||
self.GT_channel = GT_channel # GT标签图片通道数
|
||||
|
||||
self.Max_width = Max_width # 最大图片宽度(匹配时候用)
|
||||
self.pic_type = pic_type # 图片类型
|
||||
self.back_gnd_color = back_gnd_color # 背景颜色
|
||||
self.first_class_color = first_class_color # 第一类上的颜色
|
||||
self.Label_Max_Search_layer=Label_Max_Search_layer # 文件夹最大搜索深度
|
||||
try:
|
||||
self.labellist_src = getFileList(src_label_fold, [], pic_type, self.Label_Max_Search_layer)
|
||||
print('本次执行检索到ori_label图片 '+str(len(self.labellist_src))+' 张图像')
|
||||
except:
|
||||
self.labellist_src = None
|
||||
print("没有ori_label相关文件")
|
||||
|
||||
try:
|
||||
# print(save_pro_label_fold)
|
||||
self.labellist_pro = getFileList(save_pro_label_fold, [], pic_type, self.Label_Max_Search_layer)
|
||||
print('本次执行检索到pro_label图片 '+str(len(self.labellist_pro))+' 张图像')
|
||||
except:
|
||||
self.labellist_pro = None
|
||||
print("没有pro_label相关文件")
|
||||
|
||||
try:
|
||||
self.imglist_src = getFileList(ori_img_folder, [], pic_type, self.Label_Max_Search_layer)
|
||||
self.reslist_src = getFileList(res_label_folder, [], pic_type, self.Label_Max_Search_layer)
|
||||
print('本次执行检索到ori原始图片 '+str(len(self.imglist_src))+' 张图像')
|
||||
print('本次执行检索到训练train_result图片 '+str(len(self.reslist_src))+' 张图像')
|
||||
except:
|
||||
self.imglist_src = None
|
||||
self.reslist_src = None
|
||||
print("没有train_result和原始图片相关文件")
|
||||
|
||||
# 获取单张图片各个通路信息
|
||||
def get_single_pic_rgb(self, imgpath):
|
||||
print(imgpath)
|
||||
image = Image.open(imgpath).convert('RGB') # 转为RGB图片
|
||||
# 将 RGB 色值分离
|
||||
image.load()
|
||||
r, g, b = image.split()
|
||||
r = np.array(r)
|
||||
g = np.array(g)
|
||||
b = np.array(b)
|
||||
return image, r, g, b
|
||||
|
||||
# 将单个pro图片变成GT图片
|
||||
def Conver_pro_label_pic_2_GT_pic(self, imgpath, imgname):
|
||||
time_start=time.time() # 记录开始时间
|
||||
# 获取单张图片各个通路信息
|
||||
image, r,g,b = self.get_single_pic_rgb(imgpath)
|
||||
|
||||
result_gt = np.ones(np.shape(image))*self.back_gnd_color # 初始化填充内容为back_gnd_color
|
||||
gt_number = self.first_class_color # 第一类上色颜色确定
|
||||
|
||||
# PALETTE中排除掉 '背景' [0,0,0]
|
||||
PALETTE_No_Bg = self.Annotate_PALETTE[~np.all(self.Annotate_PALETTE == self.bg_PALETTE, axis=1)]
|
||||
|
||||
# 遍历所有待识别颜色
|
||||
for [Annotate_PALETTE_r, Annotate_PALETTE_g, Annotate_PALETTE_b] in PALETTE_No_Bg:
|
||||
# 查找三原色匹配位置
|
||||
locate_r = np.where( r == Annotate_PALETTE_r )
|
||||
locate_g = np.where( g == Annotate_PALETTE_g )
|
||||
locate_b = np.where( b == Annotate_PALETTE_b )
|
||||
|
||||
# 查找都匹配位置(交集)
|
||||
# 将矩阵换一种表示形式
|
||||
locate_r = np.array(locate_r[0]) * Max_width + np.array(locate_r[1])
|
||||
locate_g = np.array(locate_g[0]) * Max_width + np.array(locate_g[1])
|
||||
locate_b = np.array(locate_b[0]) * Max_width + np.array(locate_b[1])
|
||||
|
||||
# 用自带函数寻找匹配项
|
||||
matched = np.intersect1d(np.intersect1d(locate_r, locate_g), locate_b)
|
||||
matched = np.concatenate(([matched // self.Max_width], [np.mod(matched, self.Max_width)]), 0)
|
||||
result_gt[matched[0],matched[1], :] = gt_number
|
||||
gt_number = gt_number + 1
|
||||
|
||||
# 输出GT图片
|
||||
if(int(self.GT_channel) == 1):
|
||||
result_gt = result_gt[:,:,0]
|
||||
elif(int(self.GT_channel) == 3):
|
||||
result_gt = cv2.cvtColor(np.float32(result_gt), cv2.COLOR_RGB2BGR) # rgb颜色互换
|
||||
else:
|
||||
print("GT_channel 必须为1或3")
|
||||
quit
|
||||
try: # 新建文件夹
|
||||
os.mkdir(self.save_GT_label_fold)
|
||||
except:
|
||||
print("已有"+self.save_GT_label_fold)
|
||||
if imgname.lower().endswith(('.jpg', '.png')):
|
||||
save_dir = os.path.join(self.save_GT_label_fold, os.path.basename(imgname).rpartition('.')[0]+self.GT_append_name+'.'+self.pic_type)
|
||||
else:
|
||||
save_dir = os.path.join(self.save_GT_label_fold, os.path.basename(imgname)+self.GT_append_name+'.'+self.pic_type)
|
||||
cv2.imwrite(save_dir, result_gt)
|
||||
print("GT图片已保存", save_dir)
|
||||
time_end=time.time() # 输出结束时间
|
||||
print('time cost',time_end-time_start,'s')
|
||||
|
||||
# 将处理好的图片转化为GT图片
|
||||
def Conver_pro_label_pic_2_GT_pic_all(self):
|
||||
print("\033[33m**** 进行转换将Pro_label_pic转换为GT_label_pic ****\033[0m")
|
||||
print("\033[33mPro_label_pic存储位置为:\033[0m", self.save_pro_label_fold)
|
||||
print("\033[33mGT_label_pic生成位置为:\033[0m", self.save_GT_label_fold)
|
||||
try:
|
||||
# print(save_pro_label_fold)
|
||||
self.labellist_pro = getFileList(save_pro_label_fold, [], pic_type, self.Label_Max_Search_layer)
|
||||
print('本次执行检索到pro_label图片 '+str(len(self.labellist_pro))+' 张图像')
|
||||
except:
|
||||
self.labellist_pro = None
|
||||
print("没有pro_label相关文件")
|
||||
try:
|
||||
os.mkdir(self.save_GT_label_fold) # 新建存储文件夹
|
||||
except:
|
||||
print("已有"+self.save_GT_label_fold)
|
||||
|
||||
# 指定最大进程数为 3
|
||||
max_processes = 20
|
||||
# 创建Pool对象
|
||||
pool = multiprocessing.Pool(processes=max_processes)
|
||||
# 创建并启动进程
|
||||
args_list1 = []
|
||||
args_list2 = []
|
||||
|
||||
# 遍历整个文件夹
|
||||
for imgpath in self.labellist_pro:
|
||||
imgname = os.path.basename(imgpath).rpartition('.')[0].replace(self.pro_append_name,"")
|
||||
args_list1.append(imgpath)
|
||||
args_list2.append(imgname)
|
||||
args_list = zip(args_list1, args_list2)
|
||||
# 使用进程池并行执行任务
|
||||
pool.starmap(self.Conver_pro_label_pic_2_GT_pic, args_list)
|
||||
# 关闭进程池
|
||||
pool.close()
|
||||
pool.join()
|
||||
|
||||
def Conver_ori_label_pic_2_pro_pic(self, imgpath, imgname):
|
||||
time_start=time.time() # 记录开始时间
|
||||
# 获取单张图片各个通路信息
|
||||
image = cv2.imread(imgpath)
|
||||
|
||||
# 1. 边缘检测并膨胀
|
||||
dilated_image = edge_detection(image)
|
||||
# 如果需要存储中间态图片
|
||||
if(self.save_process_pics == True):
|
||||
if imgname.lower().endswith(('.jpg', '.png')):
|
||||
save_dir = os.path.join(self.save_pro_label_fold, '1_边缘检测并膨胀', os.path.basename(imgname).rpartition('.')[0]+self.pro_append_name+'_Edge'+'.'+self.pic_type)
|
||||
else:
|
||||
save_dir = os.path.join(self.save_pro_label_fold, '1_边缘检测并膨胀', os.path.basename(imgname)+self.pro_append_name+'_Edge'+'.'+self.pic_type)
|
||||
cv2.imwrite(save_dir, dilated_image)
|
||||
print("中间态-边缘检测并膨胀 图片已保存", save_dir)
|
||||
time_end=time.time() # 输出结束时间
|
||||
print('time cost',time_end-time_start,'s')
|
||||
|
||||
# 2. 检测连通区域
|
||||
filtered_labeled_array, _ = detect_connected_regions(dilated_image)
|
||||
colored_image_filtered = Tool_color_connected_array(filtered_labeled_array)
|
||||
# 如果需要存储中间态图片
|
||||
if(self.save_process_pics == True):
|
||||
if imgname.lower().endswith(('.jpg', '.png')):
|
||||
save_dir = os.path.join(self.save_pro_label_fold, '2_连通区域检测', os.path.basename(imgname).rpartition('.')[0]+self.pro_append_name+'_Region'+'.'+self.pic_type)
|
||||
else:
|
||||
save_dir = os.path.join(self.save_pro_label_fold, '2_连通区域检测', os.path.basename(imgname)+self.pro_append_name+'_Region'+'.'+self.pic_type)
|
||||
cv2.imwrite(save_dir, colored_image_filtered)
|
||||
print("中间态-连通区域检测 图片已保存", save_dir)
|
||||
time_end=time.time() # 输出结束时间
|
||||
print('time cost',time_end-time_start,'s')
|
||||
|
||||
# 3. 分水岭填充白色区域
|
||||
filled_labeled_array = fill_white_regions(filtered_labeled_array)
|
||||
colored_image_filled = Tool_color_connected_array(filled_labeled_array)
|
||||
# 如果需要存储中间态图片
|
||||
if(self.save_process_pics == True):
|
||||
if imgname.lower().endswith(('.jpg', '.png')):
|
||||
save_dir = os.path.join(self.save_pro_label_fold, '3_分水岭算法填充', os.path.basename(imgname).rpartition('.')[0]+self.pro_append_name+'_FillEdge'+'.'+self.pic_type)
|
||||
else:
|
||||
save_dir = os.path.join(self.save_pro_label_fold, '3_分水岭算法填充', os.path.basename(imgname)+self.pro_append_name+'_FillEdge'+'.'+self.pic_type)
|
||||
cv2.imwrite(save_dir, colored_image_filled)
|
||||
print("中间态-分水岭算法填充 图片已保存", save_dir)
|
||||
time_end=time.time() # 输出结束时间
|
||||
print('time cost',time_end-time_start,'s')
|
||||
|
||||
# 4. 对连通区域最终上色
|
||||
ori_labeled_image = image
|
||||
result_pro = color_connected_regions(filled_labeled_array, filtered_labeled_array, ori_labeled_image, self.Annotate_PALETTE)
|
||||
if imgname.lower().endswith(('.jpg', '.png')):
|
||||
save_dir = os.path.join(self.save_pro_label_fold, os.path.basename(imgname).rpartition('.')[0]+self.pro_append_name+'.'+self.pic_type)
|
||||
else:
|
||||
save_dir = os.path.join(self.save_pro_label_fold, os.path.basename(imgname)+self.pro_append_name+'.'+self.pic_type)
|
||||
print("Pro图片已保存", save_dir)
|
||||
cv2.imwrite(save_dir, result_pro)
|
||||
time_end=time.time() # 输出结束时间
|
||||
print('time cost',time_end-time_start,'s')
|
||||
|
||||
# 将原始src图片转化为处理好的pro图片
|
||||
def Conver_ori_label_pic_2_pro_pic_all(self):
|
||||
print("\033[33m**** 进行转换将Ori_label_pic转换为Pro_label_pic ****\033[0m")
|
||||
print("\033[33mOri_label_pic存储位置为:\033[0m", self.src_label_fold)
|
||||
print("\033[33mPro_label_pic生成位置为:\033[0m", self.save_pro_label_fold)
|
||||
# 输出颜色预处理图片
|
||||
try:
|
||||
os.mkdir(self.save_pro_label_fold) # 新建存储文件夹
|
||||
except:
|
||||
print("已有"+self.save_pro_label_fold)
|
||||
if(self.save_process_pics == True):
|
||||
try:
|
||||
os.mkdir(os.path.join(self.save_pro_label_fold, '1_边缘检测并膨胀')) # 新建存储1_边缘检测并膨胀文件夹
|
||||
except:
|
||||
print("已有"+os.path.join(self.save_pro_label_fold, '1_边缘检测并膨胀'))
|
||||
try:
|
||||
os.mkdir(os.path.join(self.save_pro_label_fold, '2_连通区域检测')) # 新建存储2_连通区域检测文件夹
|
||||
except:
|
||||
print("已有"+os.path.join(self.save_pro_label_fold, '2_连通区域检测'))
|
||||
try:
|
||||
os.mkdir(os.path.join(self.save_pro_label_fold, '3_分水岭算法填充')) # 新建存储1_边缘检测并膨胀文件夹
|
||||
except:
|
||||
print("已有"+os.path.join(self.save_pro_label_fold, '3_分水岭算法填充'))
|
||||
|
||||
# 指定最大进程数为 20,多参数函数并行
|
||||
max_processes = 20
|
||||
# 创建Pool对象
|
||||
pool = multiprocessing.Pool(processes=max_processes)
|
||||
# 创建并启动进程
|
||||
args_list1 = []
|
||||
args_list2 = []
|
||||
|
||||
# 遍历整个文件夹
|
||||
for imgpath in self.labellist_src:
|
||||
if imgpath.lower().endswith(('.jpg', '.png')):
|
||||
imgname= os.path.basename(imgpath).rpartition('.')[0].replace(self.pro_append_name,"")
|
||||
else:
|
||||
imgname= os.path.basename(imgpath).replace(self.pro_append_name,"")
|
||||
try:
|
||||
print("Processing: ", imgname, "...")
|
||||
# self.Conver_ori_label_pic_2_pro_pic(imgpath, imgname)s
|
||||
# args_list.append({'imgpath': imgpath, 'imgname': imgname})
|
||||
args_list1.append(imgpath)
|
||||
args_list2.append(imgname)
|
||||
except:
|
||||
os.system("echo "+imgname+" >> error_1.txt")
|
||||
args_list = zip(args_list1, args_list2)
|
||||
# 使用进程池并行执行任务
|
||||
pool.starmap(self.Conver_ori_label_pic_2_pro_pic, args_list) # 使用starmap进行多参数并行
|
||||
# 关闭进程池
|
||||
pool.close()
|
||||
pool.join()
|
||||
|
||||
# 图片堆叠
|
||||
def Merge_ori_pic_and_label_pic(self, res_img_path, res_imgname):
|
||||
time_start=time.time() # 记录开始时间
|
||||
# 获取单张图片各个通路信息
|
||||
ori_img_path = os.path.join(self.ori_img_folder, res_imgname+'.'+self.pic_type)
|
||||
if not os.path.exists(ori_img_path):
|
||||
print("****照片不存在:****", ori_img_path)
|
||||
return -1
|
||||
ori_image, ori_r, ori_g, ori_b = self.get_single_pic_rgb(ori_img_path)
|
||||
res_image, res_r, res_g, res_b = self.get_single_pic_rgb(res_img_path)
|
||||
|
||||
merge_img = np.array(ori_image) # merge图片初始化,默认图片背景为0.0.0
|
||||
|
||||
# 遍历所有待识别颜色
|
||||
for [Annotate_PALETTE_r, Annotate_PALETTE_g, Annotate_PALETTE_b] in self.Annotate_PALETTE:
|
||||
# 查找三原色匹配位置
|
||||
locate_r = np.where( res_r == Annotate_PALETTE_r )
|
||||
locate_g = np.where( res_g == Annotate_PALETTE_g )
|
||||
locate_b = np.where( res_b == Annotate_PALETTE_b )
|
||||
|
||||
# 查找都匹配位置(交集)
|
||||
# 将矩阵换一种表示形式
|
||||
locate_r = np.array(locate_r[0]) * self.Max_width + np.array(locate_r[1])
|
||||
locate_g = np.array(locate_g[0]) * self.Max_width + np.array(locate_g[1])
|
||||
locate_b = np.array(locate_b[0]) * self.Max_width + np.array(locate_b[1])
|
||||
|
||||
# 用自带函数寻找匹配项
|
||||
matched = np.intersect1d(np.intersect1d(locate_r, locate_g), locate_b)
|
||||
matched = np.concatenate(([matched // self.Max_width], [np.mod(matched, self.Max_width)]), 0)
|
||||
merge_img[matched[0],matched[1], 0] = Annotate_PALETTE_r
|
||||
merge_img[matched[0],matched[1], 1] = Annotate_PALETTE_g
|
||||
merge_img[matched[0],matched[1], 2] = Annotate_PALETTE_b
|
||||
|
||||
# 转成cv2形式
|
||||
merge_img = cv2.cvtColor(np.float32(merge_img), cv2.COLOR_RGB2BGR)
|
||||
|
||||
try: # 新建文件夹
|
||||
os.mkdir(self.save_merge_pic_folder)
|
||||
except:
|
||||
print("已有"+self.save_merge_pic_folder)
|
||||
if res_imgname.lower().endswith(('.jpg', '.png')):
|
||||
save_dir = os.path.join(self.save_merge_pic_folder, os.path.basename(res_imgname).rpartition('.')[0]+'.'+self.pic_type)
|
||||
else:
|
||||
save_dir = os.path.join(self.save_merge_pic_folder, os.path.basename(res_imgname)+'.'+self.pic_type)
|
||||
|
||||
|
||||
cv2.imwrite(save_dir, merge_img)
|
||||
print("Merge图片已保存", save_dir)
|
||||
time_end=time.time() # 输出结束时间
|
||||
print('time cost',time_end-time_start,'s')
|
||||
|
||||
# 将label图片与原图片重合
|
||||
def Merge_ori_pic_and_label_pic_all(self):
|
||||
# 遍历整个文件夹
|
||||
for res_img_path in self.reslist_src:
|
||||
if res_img_path.lower().endswith(('.jpg', '.png')):
|
||||
res_imgname = os.path.basename(res_img_path).rpartition('.')[0].replace(self.pro_append_name,"")
|
||||
else:
|
||||
res_imgname = os.path.basename(res_img_path).replace(self.pro_append_name,"")
|
||||
print("Processing: ", res_imgname, "...")
|
||||
self.Merge_ori_pic_and_label_pic(res_img_path, res_imgname)
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Annotate_CLASSES = ('肝脏','胆囊','分离钳','止血海绵','肝总管','胆总管','吸引器','剪刀','止血纱布','生物夹','无损伤钳','喷洒','胆囊管','胆囊动脉','电凝','标本袋','引流管','纱布','金属钛夹','术中超声','吻合器','乳胶管','推结器','肝带','钳夹','超声刀','脂肪','双极电凝','棉球','血管阻断夹','肿瘤','针','线','韧带','胆囊静脉','背景') # 待分类的类
|
||||
Annotate_PALETTE = [[255,91,0],[255,234,0],[85, 111, 181],[181, 227, 14],[72, 0, 255],[0, 155, 33],[255,0,255],[29, 32, 136],[160, 15, 95],[0,160,233],[52,184,178],[90,120,41],[255,0,0],[177,0,0],[167,24,233],[112,113,150],[0,255,0],[255,255,255],[0,255,255],[138,251,213],[136,162,196],[197,83,181],[202,202,200],[113,102,140],[66,115,82],[240,16,116],[155,132,0],[155,62,0],[146,175,236],[255,172,159],[245,161,0],[134,124,118], [0,157,142], [181,85,105], [42,8,66],[0,0,0]] # 每一类的像素直
|
||||
bg_PALETTE = [0,0,0] # 背景的RGB
|
||||
|
||||
# 创建参数解析器
|
||||
parser = argparse.ArgumentParser(description='Process some files.')
|
||||
# 添加参数选项
|
||||
parser.add_argument('-src_fold', dest='src_label_fold', default='', help='source label folder')
|
||||
parser.add_argument('-save_pro_fold', dest='save_pro_label_fold', default='./save_pro_label_fold', help='processed label folder')
|
||||
parser.add_argument('-save_GT_fold', dest='save_GT_label_fold', default='./save_GT_label_fold', help='ground truth folder')
|
||||
parser.add_argument('-fold_search_depth', dest='Label_Max_Search_layer', default='1000', type=int, help='Folder Search Depth')
|
||||
parser.add_argument('-pro_suffix_name', dest='pro_append_name', default='_label', help='Pro file suffix')
|
||||
parser.add_argument('-GT_suffix_name', dest='GT_append_name', default='_gtFine_labelTrainIds', help='GT file suffix')
|
||||
parser.add_argument('-GT_channel', dest='GT_channel', default='1', type=int, help='GT file channel(1 or 3)')
|
||||
parser.add_argument('-back_gnd_color', dest='back_gnd_color', default='0', type=int, help='Color of "Back ground"(0 or 255)')
|
||||
parser.add_argument('-first_class_color', dest='first_class_color', default='1', type=int, help='Color of "First Class"')
|
||||
parser.add_argument('-pic_type', dest='pic_type', default='png', help='type of picture(Do not add ".")')
|
||||
parser.add_argument('-Max_width', dest='Max_width', default='10000', type=int, help='Max width of picture')
|
||||
parser.add_argument('-Rebuild_from', dest='Rebuild_from', default='label', help='Source to Rebuild Labels(label/pro)')
|
||||
parser.add_argument('-Rebuild_to', dest='Rebuild_to', default='GT', help='Destination of Rebuild Labels(pro/GT)')
|
||||
parser.add_argument('-save_process_pics', dest='save_process_pics', default='False', help='Save the processed pics(e.g.Gray_pics,Color_pics) in generating pro_pics')
|
||||
|
||||
# 解析命令行参数
|
||||
args = parser.parse_args()
|
||||
|
||||
src_label_fold = args.src_label_fold
|
||||
save_pro_label_fold = args.save_pro_label_fold
|
||||
save_GT_label_fold = args.save_GT_label_fold
|
||||
Label_Max_Search_layer = args.Label_Max_Search_layer
|
||||
pro_append_name = args.pro_append_name
|
||||
GT_append_name = args.GT_append_name
|
||||
GT_channel = args.GT_channel
|
||||
back_gnd_color = args.back_gnd_color
|
||||
first_class_color = args.first_class_color
|
||||
pic_type = args.pic_type
|
||||
Max_width = args.Max_width
|
||||
Rebuild_from = args.Rebuild_from
|
||||
Rebuild_to = args.Rebuild_to
|
||||
save_process_pics = args.save_process_pics
|
||||
|
||||
|
||||
try: # 遍历文件深度,最小为1
|
||||
Label_Max_Search_layer=int(Label_Max_Search_layer)
|
||||
except:
|
||||
Label_Max_Search_layer=1000
|
||||
try: # GT标签图片通道数
|
||||
GT_channel=int(GT_channel)
|
||||
except:
|
||||
GT_channel=1
|
||||
try: # 背景颜色(背景选择0或255)
|
||||
back_gnd_color=int(back_gnd_color)
|
||||
except:
|
||||
back_gnd_color=0
|
||||
try: # 第一类上的颜色(如果背景为0,选择1;)
|
||||
first_class_color=int(first_class_color)
|
||||
except:
|
||||
first_class_color=1
|
||||
try: # 最大图片宽度(匹配时候用)
|
||||
Max_width=int(Max_width)
|
||||
except:
|
||||
Max_width=10000
|
||||
if(save_process_pics.lower() == 'false'):
|
||||
save_process_pics = False
|
||||
elif(save_process_pics.lower() == 'true'):
|
||||
save_process_pics = True
|
||||
else:
|
||||
save_process_pics = False
|
||||
|
||||
D = Deal_image(Annotate_CLASSES=Annotate_CLASSES, Annotate_PALETTE=Annotate_PALETTE, src_label_fold=src_label_fold, save_pro_label_fold=save_pro_label_fold, save_GT_label_fold=save_GT_label_fold, GT_channel=GT_channel, pro_append_name=pro_append_name, GT_append_name=GT_append_name, back_gnd_color=back_gnd_color, first_class_color=first_class_color, pic_type=pic_type, Max_width=Max_width, Label_Max_Search_layer=Label_Max_Search_layer, save_process_pics=save_process_pics, bg_PALETTE = bg_PALETTE)
|
||||
# print(D.src_CLASSES_NUM)
|
||||
if Rebuild_from == 'label':
|
||||
# 1.先将所有原始图片转为pro图片
|
||||
D.Conver_ori_label_pic_2_pro_pic_all()
|
||||
pass
|
||||
if Rebuild_to == 'GT':
|
||||
# 2.再将pro图片转为GT图片
|
||||
D.Conver_pro_label_pic_2_GT_pic_all()
|
||||
pass
|
||||
463
DataSet_Own/1. 图片预处理(内含使用手册)/4_deal_labels_old(老版程序).py
Executable file
463
DataSet_Own/1. 图片预处理(内含使用手册)/4_deal_labels_old(老版程序).py
Executable file
@@ -0,0 +1,463 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*
|
||||
import os,time,sys,threading, colorsys, argparse
|
||||
import asyncio, cv2, multiprocessing
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
def getFileList(dir,Filelist=[], ext=None, Max_layer=1, layer=0, Donot_Search=['1_颜色预处理', '2_灰度化', '3_边缘化']):
|
||||
"""
|
||||
获取文件夹及其子文件夹中文件列表
|
||||
输入 dir:文件夹根目录
|
||||
输入 ext: 扩展名
|
||||
返回: 文件路径列表
|
||||
"""
|
||||
newDir = dir
|
||||
if os.path.isfile(dir):
|
||||
if ext is None:
|
||||
Filelist.append(dir)
|
||||
else:
|
||||
if ext in dir[-3:]:
|
||||
Filelist.append(dir)
|
||||
|
||||
elif os.path.isdir(dir):
|
||||
file_name = os.path.basename(dir)
|
||||
# 判断是否在禁搜名单中
|
||||
if file_name in Donot_Search:
|
||||
return Filelist
|
||||
for s in os.listdir(dir):
|
||||
newDir=os.path.join(dir,s)
|
||||
if layer <= Max_layer:
|
||||
getFileList(newDir, Filelist, ext, Max_layer, layer+1)
|
||||
|
||||
return Filelist
|
||||
|
||||
class Deal_image():
|
||||
def __init__(self, Annotate_CLASSES = ('肝脏','胆囊'), Annotate_PALETTE = [[255,91,0],[255,234,0]], src_label_fold = "./Label", save_pro_label_fold = "./LABEL_PNG_new", save_GT_label_fold = "./Label_Generate", GT_channel = 1, pro_append_name="_label", GT_append_name="_gtFine_labelTrainIds", ori_img_folder="./ORI_PNG", res_label_folder="./Result_label", save_merge_pic_folder="./Result_merge", back_gnd_color=0, first_class_color=1, pic_type="png", Max_width = 10000, Label_Max_Search_layer=1000, save_process_pics=False):
|
||||
self.src_CLASSES = ('肝脏','胆囊','分离钳','止血海绵','肝总管','胆总管','吸引器','剪刀','止血纱布','生物夹','无损伤钳','喷洒','胆囊管','胆囊动脉','电凝','标本袋','引流管','纱布','金属钛夹','背景')
|
||||
self.src_PALETTE = np.array([[255,91,0],[255,234,0],[85, 107, 179],[181, 227, 14],[72, 0, 255],[0, 155, 33],[255,0,255],[29, 32, 136],[160, 15, 95],[0,160,233],[52,184,178],[90,120,41],[255,0,0],[177,0,0],[167,24,233],[112,113,150],[0,255,0],[255,255,255],[0,255,255],[0,0,0]])
|
||||
self.src_CLASSES_NUM = np.shape(self.src_CLASSES)[0]
|
||||
|
||||
self.Annotate_CLASSES = Annotate_CLASSES # 待分类的类
|
||||
self.Annotate_PALETTE = np.array(Annotate_PALETTE) # 每一类的像素直
|
||||
self.Annotate_CLASSES_NUM = np.shape(Annotate_CLASSES)[0] # 类数量
|
||||
|
||||
self.save_process_pics = save_process_pics # 保存中间过程图片
|
||||
|
||||
self.src_label_fold = src_label_fold # 原始标签图片 保存位置
|
||||
self.save_pro_label_fold = save_pro_label_fold # 优化后标签图片 保存位置
|
||||
self.save_GT_label_fold = save_GT_label_fold # GT标签图片 保存位置
|
||||
|
||||
self.ori_img_folder = ori_img_folder # 最原始手术图片 保存位置
|
||||
self.res_label_folder = res_label_folder # 训练出来的label 保存位置
|
||||
self.save_merge_pic_folder = save_merge_pic_folder # 融合图像保存位置
|
||||
|
||||
self.pro_append_name = pro_append_name # 优化后标签图片后缀
|
||||
self.GT_append_name = GT_append_name # GT标签图片后缀
|
||||
self.GT_channel = GT_channel # GT标签图片通道数
|
||||
|
||||
self.Max_width = Max_width # 最大图片宽度(匹配时候用)
|
||||
self.pic_type = pic_type # 图片类型
|
||||
self.back_gnd_color = back_gnd_color # 背景颜色
|
||||
self.first_class_color = first_class_color # 第一类上的颜色
|
||||
self.Label_Max_Search_layer=Label_Max_Search_layer # 文件夹最大搜索深度
|
||||
try:
|
||||
self.labellist_src = getFileList(src_label_fold, [], pic_type, self.Label_Max_Search_layer)
|
||||
print('本次执行检索到ori_label图片 '+str(len(self.labellist_src))+' 张图像')
|
||||
except:
|
||||
self.labellist_src = None
|
||||
print("没有ori_label相关文件")
|
||||
|
||||
try:
|
||||
# print(save_pro_label_fold)
|
||||
self.labellist_pro = getFileList(save_pro_label_fold, [], pic_type, self.Label_Max_Search_layer)
|
||||
print('本次执行检索到pro_label图片 '+str(len(self.labellist_pro))+' 张图像')
|
||||
except:
|
||||
self.labellist_pro = None
|
||||
print("没有pro_label相关文件")
|
||||
|
||||
try:
|
||||
self.imglist_src = getFileList(ori_img_folder, [], pic_type, self.Label_Max_Search_layer)
|
||||
self.reslist_src = getFileList(res_label_folder, [], pic_type, self.Label_Max_Search_layer)
|
||||
print('本次执行检索到ori原始图片 '+str(len(self.imglist_src))+' 张图像')
|
||||
print('本次执行检索到训练train_result图片 '+str(len(self.reslist_src))+' 张图像')
|
||||
except:
|
||||
self.imglist_src = None
|
||||
self.reslist_src = None
|
||||
print("没有train_result和原始图片相关文件")
|
||||
|
||||
# 获取单张图片各个通路信息
|
||||
def get_single_pic_rgb(self, imgpath):
|
||||
print(imgpath)
|
||||
image = Image.open(imgpath).convert('RGB') # 转为RGB图片
|
||||
# 将 RGB 色值分离
|
||||
image.load()
|
||||
r, g, b = image.split()
|
||||
r = np.array(r)
|
||||
g = np.array(g)
|
||||
b = np.array(b)
|
||||
return image, r, g, b
|
||||
|
||||
# 将单个pro图片变成GT图片
|
||||
def Conver_pro_label_pic_2_GT_pic(self, imgpath, imgname):
|
||||
time_start=time.time() # 记录开始时间
|
||||
# 获取单张图片各个通路信息
|
||||
image, r,g,b = self.get_single_pic_rgb(imgpath)
|
||||
|
||||
result_gt = np.ones(np.shape(image))*self.back_gnd_color # 初始化填充内容为back_gnd_color
|
||||
gt_number = self.first_class_color # 第一类上色颜色确定
|
||||
|
||||
# 遍历所有待识别颜色
|
||||
for [Annotate_PALETTE_r, Annotate_PALETTE_g, Annotate_PALETTE_b] in self.Annotate_PALETTE:
|
||||
# 查找三原色匹配位置
|
||||
locate_r = np.where( r == Annotate_PALETTE_r )
|
||||
locate_g = np.where( g == Annotate_PALETTE_g )
|
||||
locate_b = np.where( b == Annotate_PALETTE_b )
|
||||
|
||||
# 查找都匹配位置(交集)
|
||||
# 将矩阵换一种表示形式
|
||||
locate_r = np.array(locate_r[0]) * Max_width + np.array(locate_r[1])
|
||||
locate_g = np.array(locate_g[0]) * Max_width + np.array(locate_g[1])
|
||||
locate_b = np.array(locate_b[0]) * Max_width + np.array(locate_b[1])
|
||||
|
||||
# 用自带函数寻找匹配项
|
||||
matched = np.intersect1d(np.intersect1d(locate_r, locate_g), locate_b)
|
||||
matched = np.concatenate(([matched // self.Max_width], [np.mod(matched, self.Max_width)]), 0)
|
||||
result_gt[matched[0],matched[1], :] = gt_number
|
||||
gt_number = gt_number + 1
|
||||
|
||||
# 输出GT图片
|
||||
if(int(self.GT_channel) == 1):
|
||||
result_gt = result_gt[:,:,0]
|
||||
elif(int(self.GT_channel) == 3):
|
||||
result_gt = cv2.cvtColor(np.float32(result_gt), cv2.COLOR_RGB2BGR) # rgb颜色互换
|
||||
else:
|
||||
print("GT_channel 必须为1或3")
|
||||
quit
|
||||
try: # 新建文件夹
|
||||
os.mkdir(self.save_GT_label_fold)
|
||||
except:
|
||||
print("已有"+self.save_GT_label_fold)
|
||||
save_dir = os.path.join(self.save_GT_label_fold, os.path.splitext(imgname)[0]+self.GT_append_name+'.'+self.pic_type)
|
||||
cv2.imwrite(save_dir, result_gt)
|
||||
print("GT图片已保存", save_dir)
|
||||
time_end=time.time() # 输出结束时间
|
||||
print('time cost',time_end-time_start,'s')
|
||||
|
||||
# 将处理好的图片转化为GT图片
|
||||
def Conver_pro_label_pic_2_GT_pic_all(self):
|
||||
print("\033[33m**** 进行转换将Pro_label_pic转换为GT_label_pic ****\033[0m")
|
||||
print("\033[33mPro_label_pic存储位置为:\033[0m", self.save_pro_label_fold)
|
||||
print("\033[33mGT_label_pic生成位置为:\033[0m", self.save_GT_label_fold)
|
||||
try:
|
||||
# print(save_pro_label_fold)
|
||||
self.labellist_pro = getFileList(save_pro_label_fold, [], pic_type, self.Label_Max_Search_layer)
|
||||
print('本次执行检索到pro_label图片 '+str(len(self.labellist_pro))+' 张图像')
|
||||
except:
|
||||
self.labellist_pro = None
|
||||
print("没有pro_label相关文件")
|
||||
try:
|
||||
os.mkdir(self.save_GT_label_fold) # 新建存储文件夹
|
||||
except:
|
||||
print("已有"+self.save_GT_label_fold)
|
||||
|
||||
# 指定最大进程数为 3
|
||||
max_processes = 20
|
||||
# 创建Pool对象
|
||||
pool = multiprocessing.Pool(processes=max_processes)
|
||||
# 创建并启动进程
|
||||
args_list1 = []
|
||||
args_list2 = []
|
||||
|
||||
# 遍历整个文件夹
|
||||
for imgpath in self.labellist_pro:
|
||||
imgname= os.path.splitext(os.path.basename(imgpath))[0].replace(self.pro_append_name,"")
|
||||
args_list1.append(imgpath)
|
||||
args_list2.append(imgname)
|
||||
args_list = zip(args_list1, args_list2)
|
||||
# 使用进程池并行执行任务
|
||||
pool.starmap(self.Conver_pro_label_pic_2_GT_pic, args_list)
|
||||
# 关闭进程池
|
||||
pool.close()
|
||||
pool.join()
|
||||
|
||||
def Conver_ori_label_pic_2_pro_pic(self, imgpath, imgname):
|
||||
time_start=time.time() # 记录开始时间
|
||||
# 获取单张图片各个通路信息
|
||||
image, r, g, b = self.get_single_pic_rgb(imgpath)
|
||||
result_pro = np.zeros(np.shape(image)) # pro图片初始化,默认图片背景为0.0.0
|
||||
|
||||
# 生成距离初始矩阵 width*height*num_of_classes
|
||||
Dis_mat = np.zeros((np.shape(image)[0], np.shape(image)[1], self.src_CLASSES_NUM))
|
||||
|
||||
# 遍历所有类
|
||||
i = 0 # 顺序指示物
|
||||
# 遍历寻找最短距离是第几个内容
|
||||
for palette in self.src_PALETTE:
|
||||
Dis_mat[:,:,i] = np.abs((r - palette[0])) + np.abs((g - palette[1])) + np.abs((b - palette[2]))
|
||||
i = i + 1
|
||||
Min_Dis_mat = np.argmin(Dis_mat, axis = 2)
|
||||
|
||||
# 给图片上色
|
||||
for number_class in range(self.src_CLASSES_NUM):
|
||||
# 查找图片种类
|
||||
loc_of_class_pic = np.where(Min_Dis_mat == number_class)
|
||||
# 否则给图片上对应颜色
|
||||
result_pro[loc_of_class_pic[0], loc_of_class_pic[1], :] = self.src_PALETTE[number_class]
|
||||
result_pro = cv2.cvtColor(np.float32(result_pro), cv2.COLOR_RGB2BGR) # rgb颜色互换
|
||||
|
||||
# 如果需要存储中间态图片
|
||||
if(self.save_process_pics == True):
|
||||
save_dir = os.path.join(self.save_pro_label_fold, '1_颜色预处理', os.path.splitext(imgname)[0]+self.pro_append_name+'_preproc'+'.'+self.pic_type)
|
||||
cv2.imwrite(save_dir, result_pro)
|
||||
print("中间态-颜色图片已保存", save_dir)
|
||||
time_end=time.time() # 输出结束时间
|
||||
print('time cost',time_end-time_start,'s')
|
||||
|
||||
# 边界区域大小
|
||||
half_region_len=(5-1)//2
|
||||
pic_height = result_pro.shape[0]
|
||||
pic_width = result_pro.shape[1]
|
||||
# 2_灰度化
|
||||
gray = cv2.cvtColor(result_pro, cv2.COLOR_BGR2GRAY)
|
||||
gray = (gray*255).astype(np.uint8) # cv2对图片格式要求较为严格
|
||||
# 如果需要存储中间态图片
|
||||
if(self.save_process_pics == True):
|
||||
save_dir = os.path.join(self.save_pro_label_fold, '2_灰度化', os.path.splitext(imgname)[0]+self.pro_append_name+'_gray'+'.'+self.pic_type)
|
||||
cv2.imwrite(save_dir, gray)
|
||||
print("中间态-灰度图片已保存", save_dir)
|
||||
time_end=time.time() # 输出结束时间
|
||||
print('time cost',time_end-time_start,'s')
|
||||
# 边缘检测
|
||||
edges = cv2.Canny(gray, 50, 150)
|
||||
# 定义膨胀核
|
||||
kernel = np.ones((3, 3), np.uint8)
|
||||
# 对Canny输出图像进行膨胀
|
||||
edges = cv2.dilate(edges, kernel, iterations=1)
|
||||
|
||||
# 如果需要存储中间态图片
|
||||
if(self.save_process_pics == True):
|
||||
save_dir = os.path.join(self.save_pro_label_fold, '3_边缘化', os.path.splitext(imgname)[0]+self.pro_append_name+'_edge'+'.'+self.pic_type)
|
||||
cv2.imwrite(save_dir, edges)
|
||||
print("中间态-边缘图片已保存", save_dir)
|
||||
time_end=time.time() # 输出结束时间
|
||||
print('time cost',time_end-time_start,'s')
|
||||
|
||||
# 查找边缘
|
||||
nonzero_idx = np.array(np.nonzero(edges)).T
|
||||
# nonzero_idx, counts = np.unique(nonzero_idx, axis=0, return_counts=True)
|
||||
print("边缘个数:",np.shape(nonzero_idx)[0])
|
||||
|
||||
for k in range(len(nonzero_idx)):
|
||||
i = nonzero_idx[k][0]
|
||||
j = nonzero_idx[k][1]
|
||||
# 遍历周围3*3的区域
|
||||
# 找到与当前像素值附近5*5*3(channel)的像素,并将矩阵变为25(5*5)*3(channel)
|
||||
if i <= half_region_len:
|
||||
i = half_region_len
|
||||
elif i >= pic_height-half_region_len:
|
||||
i = pic_height-half_region_len
|
||||
if j <= half_region_len:
|
||||
j = half_region_len
|
||||
elif j >= pic_width-half_region_len:
|
||||
j = pic_width-half_region_len
|
||||
region_x = i
|
||||
region_y = j
|
||||
region = np.array(result_pro[region_x-half_region_len:region_x+half_region_len, region_y-half_region_len:region_y+half_region_len,:]).reshape(-1, 3)
|
||||
# 计算唯一列及其出现次数
|
||||
# print(region_x-half_region_len, region_x+half_region_len, region_y-half_region_len, region_y+half_region_len)
|
||||
unique_columns, counts = np.unique(region, axis=0, return_counts=True)
|
||||
# 找出出现次数最多的列的索引
|
||||
max_count_index = np.argmax(counts)
|
||||
# 将其赋值为出现次数最多的列
|
||||
result_pro[i, j, :] = unique_columns[max_count_index]
|
||||
# result_pro[i, j, :] = [0,0,0]
|
||||
|
||||
save_dir = os.path.join(self.save_pro_label_fold, os.path.splitext(imgname)[0]+self.pro_append_name+'.'+self.pic_type)
|
||||
print("Pro图片已保存", save_dir)
|
||||
cv2.imwrite(save_dir, result_pro)
|
||||
time_end=time.time() # 输出结束时间
|
||||
print('time cost',time_end-time_start,'s')
|
||||
|
||||
# 将原始src图片转化为处理好的pro图片
|
||||
def Conver_ori_label_pic_2_pro_pic_all(self):
|
||||
print("\033[33m**** 进行转换将Ori_label_pic转换为Pro_label_pic ****\033[0m")
|
||||
print("\033[33mOri_label_pic存储位置为:\033[0m", self.src_label_fold)
|
||||
print("\033[33mPro_label_pic生成位置为:\033[0m", self.save_pro_label_fold)
|
||||
# 输出颜色预处理图片
|
||||
try:
|
||||
os.mkdir(self.save_pro_label_fold) # 新建存储文件夹
|
||||
except:
|
||||
print("已有"+self.save_pro_label_fold)
|
||||
if(self.save_process_pics == True):
|
||||
try:
|
||||
os.mkdir(os.path.join(self.save_pro_label_fold, '1_颜色预处理')) # 新建存储1_颜色预处理文件夹
|
||||
except:
|
||||
print("已有"+os.path.join(self.save_pro_label_fold, '1_颜色预处理'))
|
||||
try:
|
||||
os.mkdir(os.path.join(self.save_pro_label_fold, '2_灰度化')) # 新建存储2_灰度化文件夹
|
||||
except:
|
||||
print("已有"+os.path.join(self.save_pro_label_fold, '2_灰度化'))
|
||||
try:
|
||||
os.mkdir(os.path.join(self.save_pro_label_fold, '3_边缘化')) # 新建存储1_颜色预处理文件夹
|
||||
except:
|
||||
print("已有"+os.path.join(self.save_pro_label_fold, '3_边缘化'))
|
||||
|
||||
# 指定最大进程数为 20,多参数函数并行
|
||||
max_processes = 20
|
||||
# 创建Pool对象
|
||||
pool = multiprocessing.Pool(processes=max_processes)
|
||||
# 创建并启动进程
|
||||
args_list1 = []
|
||||
args_list2 = []
|
||||
|
||||
# 遍历整个文件夹
|
||||
for imgpath in self.labellist_src:
|
||||
try:
|
||||
imgname= os.path.splitext(os.path.basename(imgpath))[0].replace(self.pro_append_name,"")
|
||||
print("Processing: ", imgname, "...")
|
||||
# self.Conver_ori_label_pic_2_pro_pic(imgpath, imgname)s
|
||||
# args_list.append({'imgpath': imgpath, 'imgname': imgname})
|
||||
args_list1.append(imgpath)
|
||||
args_list2.append(imgname)
|
||||
except:
|
||||
os.system("echo "+imgname+" >> error_1.txt")
|
||||
args_list = zip(args_list1, args_list2)
|
||||
# 使用进程池并行执行任务
|
||||
pool.starmap(self.Conver_ori_label_pic_2_pro_pic, args_list) # 使用starmap进行多参数并行
|
||||
# 关闭进程池
|
||||
pool.close()
|
||||
pool.join()
|
||||
|
||||
# 图片堆叠
|
||||
def Merge_ori_pic_and_label_pic(self, res_img_path, res_imgname):
|
||||
time_start=time.time() # 记录开始时间
|
||||
# 获取单张图片各个通路信息
|
||||
ori_img_path = os.path.join(self.ori_img_folder, res_imgname+'.'+self.pic_type)
|
||||
if not os.path.exists(ori_img_path):
|
||||
print("****照片不存在:****", ori_img_path)
|
||||
return -1
|
||||
ori_image, ori_r, ori_g, ori_b = self.get_single_pic_rgb(ori_img_path)
|
||||
res_image, res_r, res_g, res_b = self.get_single_pic_rgb(res_img_path)
|
||||
|
||||
merge_img = np.array(ori_image) # merge图片初始化,默认图片背景为0.0.0
|
||||
|
||||
# 遍历所有待识别颜色
|
||||
for [Annotate_PALETTE_r, Annotate_PALETTE_g, Annotate_PALETTE_b] in self.Annotate_PALETTE:
|
||||
# 查找三原色匹配位置
|
||||
locate_r = np.where( res_r == Annotate_PALETTE_r )
|
||||
locate_g = np.where( res_g == Annotate_PALETTE_g )
|
||||
locate_b = np.where( res_b == Annotate_PALETTE_b )
|
||||
|
||||
# 查找都匹配位置(交集)
|
||||
# 将矩阵换一种表示形式
|
||||
locate_r = np.array(locate_r[0]) * self.Max_width + np.array(locate_r[1])
|
||||
locate_g = np.array(locate_g[0]) * self.Max_width + np.array(locate_g[1])
|
||||
locate_b = np.array(locate_b[0]) * self.Max_width + np.array(locate_b[1])
|
||||
|
||||
# 用自带函数寻找匹配项
|
||||
matched = np.intersect1d(np.intersect1d(locate_r, locate_g), locate_b)
|
||||
matched = np.concatenate(([matched // self.Max_width], [np.mod(matched, self.Max_width)]), 0)
|
||||
merge_img[matched[0],matched[1], 0] = Annotate_PALETTE_r
|
||||
merge_img[matched[0],matched[1], 1] = Annotate_PALETTE_g
|
||||
merge_img[matched[0],matched[1], 2] = Annotate_PALETTE_b
|
||||
|
||||
# 转成cv2形式
|
||||
merge_img = cv2.cvtColor(np.float32(merge_img), cv2.COLOR_RGB2BGR)
|
||||
|
||||
try: # 新建文件夹
|
||||
os.mkdir(self.save_merge_pic_folder)
|
||||
except:
|
||||
print("已有"+self.save_merge_pic_folder)
|
||||
save_dir = os.path.join(self.save_merge_pic_folder, os.path.splitext(res_imgname)[0]+'.'+self.pic_type)
|
||||
cv2.imwrite(save_dir, merge_img)
|
||||
print("Merge图片已保存", save_dir)
|
||||
time_end=time.time() # 输出结束时间
|
||||
print('time cost',time_end-time_start,'s')
|
||||
|
||||
# 将label图片与原图片重合
|
||||
def Merge_ori_pic_and_label_pic_all(self):
|
||||
# 遍历整个文件夹
|
||||
for res_img_path in self.reslist_src:
|
||||
res_imgname = os.path.splitext(os.path.basename(res_img_path))[0].replace(self.pro_append_name,"")
|
||||
print("Processing: ", res_imgname, "...")
|
||||
self.Merge_ori_pic_and_label_pic(res_img_path, res_imgname)
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Annotate_CLASSES = ('肝脏','胆囊','分离钳','止血海绵','肝总管','胆总管','吸引器','剪刀','止血纱布','生物夹','无损伤钳','喷洒','胆囊管','胆囊动脉','电凝','标本袋','引流管','纱布','金属钛夹') # 待分类的类
|
||||
Annotate_PALETTE = [[255,91,0],[255,234,0],[85, 107, 179],[181, 227, 14],[72, 0, 255],[0, 155, 33],[255,0,255],[29, 32, 136],[160, 15, 95],[0,160,233],[52,184,178],[90,120,41],[255,0,0],[177,0,0],[167,24,233],[112,113,150],[0,255,0],[255,255,255],[0,255,255]]# [[255,91,0],[255,234,0]] # 每一类的像素直
|
||||
|
||||
# 创建参数解析器
|
||||
parser = argparse.ArgumentParser(description='Process some files.')
|
||||
# 添加参数选项
|
||||
parser.add_argument('-src_fold', dest='src_label_fold', default='', help='source label folder')
|
||||
parser.add_argument('-save_pro_fold', dest='save_pro_label_fold', default='./save_pro_label_fold', help='processed label folder')
|
||||
parser.add_argument('-save_GT_fold', dest='save_GT_label_fold', default='./save_GT_label_fold', help='ground truth folder')
|
||||
parser.add_argument('-fold_search_depth', dest='Label_Max_Search_layer', default='1000', type=int, help='Folder Search Depth')
|
||||
parser.add_argument('-pro_suffix_name', dest='pro_append_name', default='_label', help='Pro file suffix')
|
||||
parser.add_argument('-GT_suffix_name', dest='GT_append_name', default='_gtFine_labelTrainIds', help='GT file suffix')
|
||||
parser.add_argument('-GT_channel', dest='GT_channel', default='1', type=int, help='GT file channel(1 or 3)')
|
||||
parser.add_argument('-back_gnd_color', dest='back_gnd_color', default='0', type=int, help='Color of "Back ground"(0 or 255)')
|
||||
parser.add_argument('-first_class_color', dest='first_class_color', default='1', type=int, help='Color of "First Class"')
|
||||
parser.add_argument('-pic_type', dest='pic_type', default='png', help='type of picture(Do not add ".")')
|
||||
parser.add_argument('-Max_width', dest='Max_width', default='10000', type=int, help='Max width of picture')
|
||||
parser.add_argument('-Rebuild_from', dest='Rebuild_from', default='label', help='Source to Rebuild Labels(label/pro)')
|
||||
parser.add_argument('-Rebuild_to', dest='Rebuild_to', default='GT', help='Destination of Rebuild Labels(pro/GT)')
|
||||
parser.add_argument('-save_process_pics', dest='save_process_pics', default='False', help='Save the processed pics(e.g.Gray_pics,Color_pics) in generating pro_pics')
|
||||
|
||||
# 解析命令行参数
|
||||
args = parser.parse_args()
|
||||
|
||||
src_label_fold = args.src_label_fold
|
||||
save_pro_label_fold = args.save_pro_label_fold
|
||||
save_GT_label_fold = args.save_GT_label_fold
|
||||
Label_Max_Search_layer = args.Label_Max_Search_layer
|
||||
pro_append_name = args.pro_append_name
|
||||
GT_append_name = args.GT_append_name
|
||||
GT_channel = args.GT_channel
|
||||
back_gnd_color = args.back_gnd_color
|
||||
first_class_color = args.first_class_color
|
||||
pic_type = args.pic_type
|
||||
Max_width = args.Max_width
|
||||
Rebuild_from = args.Rebuild_from
|
||||
Rebuild_to = args.Rebuild_to
|
||||
save_process_pics = args.save_process_pics
|
||||
|
||||
|
||||
try: # 遍历文件深度,最小为1
|
||||
Label_Max_Search_layer=int(Label_Max_Search_layer)
|
||||
except:
|
||||
Label_Max_Search_layer=1000
|
||||
try: # GT标签图片通道数
|
||||
GT_channel=int(GT_channel)
|
||||
except:
|
||||
GT_channel=1
|
||||
try: # 背景颜色(背景选择0或255)
|
||||
back_gnd_color=int(back_gnd_color)
|
||||
except:
|
||||
back_gnd_color=0
|
||||
try: # 第一类上的颜色(如果背景为0,选择1;)
|
||||
first_class_color=int(first_class_color)
|
||||
except:
|
||||
first_class_color=1
|
||||
try: # 最大图片宽度(匹配时候用)
|
||||
Max_width=int(Max_width)
|
||||
except:
|
||||
Max_width=10000
|
||||
if(save_process_pics.lower() == 'false'):
|
||||
save_process_pics = False
|
||||
elif(save_process_pics.lower() == 'true'):
|
||||
save_process_pics = True
|
||||
else:
|
||||
save_process_pics = False
|
||||
|
||||
D = Deal_image(Annotate_CLASSES=Annotate_CLASSES, Annotate_PALETTE=Annotate_PALETTE, src_label_fold=src_label_fold, save_pro_label_fold=save_pro_label_fold, save_GT_label_fold=save_GT_label_fold, GT_channel=GT_channel, pro_append_name=pro_append_name, GT_append_name=GT_append_name, back_gnd_color=back_gnd_color, first_class_color=first_class_color, pic_type=pic_type, Max_width=Max_width, Label_Max_Search_layer=Label_Max_Search_layer, save_process_pics=save_process_pics)
|
||||
# print(D.src_CLASSES_NUM)
|
||||
if Rebuild_from == 'label':
|
||||
# 1.先将所有原始图片转为pro图片
|
||||
D.Conver_ori_label_pic_2_pro_pic_all()
|
||||
pass
|
||||
if Rebuild_to == 'GT':
|
||||
# 2.再将pro图片转为GT图片
|
||||
D.Conver_pro_label_pic_2_GT_pic_all()
|
||||
pass
|
||||
163
DataSet_Own/1. 图片预处理(内含使用手册)/4_rebuild_labels.sh
Executable file
163
DataSet_Own/1. 图片预处理(内含使用手册)/4_rebuild_labels.sh
Executable file
@@ -0,0 +1,163 @@
|
||||
#!/bin/bash
|
||||
|
||||
usage() {
|
||||
echo "Usage: $0 l <ori_label_directory> [ -h ]"
|
||||
echo "对label图片进行处理及转化(-l不能为空) "
|
||||
echo "-l:原始label的路径,-h:帮助"
|
||||
echo "e.g. 4_rebuild_labels.sh -l ./C组标注图片"
|
||||
echo "接下来一路回车就ok"
|
||||
}
|
||||
|
||||
ori_label_directorys=""
|
||||
|
||||
while getopts "hl:" opt; do
|
||||
case $opt in
|
||||
h)
|
||||
usage
|
||||
exit 0
|
||||
;;
|
||||
l)
|
||||
ori_label_directorys=$OPTARG
|
||||
;;
|
||||
*)
|
||||
echo -e '\033[31m!!! Error, Illegal input !!!\033[0m'
|
||||
usage
|
||||
exit 1
|
||||
;;
|
||||
esac
|
||||
done
|
||||
|
||||
# 判断输入地址是否为空
|
||||
if [ -z "$ori_label_directorys" ]; then
|
||||
echo -e "\033[31m输入地址 -i -l 存在空地址\033[0m"
|
||||
usage
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# 地址转化
|
||||
ori_label_directory=$(readlink -f "$ori_label_directorys")
|
||||
if [ -z "$ori_label_directory" ]; then
|
||||
echo -e "\033[31m无法解析地址,程序退出\033[0m"
|
||||
echo -e "\033[31mori_label_directory\033[0m: $ori_label_directorys"
|
||||
exit 1
|
||||
fi
|
||||
if [ ! -d "$ori_label_directory" ]; then
|
||||
echo -e "\033[31mlabel目录不存在,程序退出\033[0m"
|
||||
echo -e "\033[31mori_label_directory\033[0m: $ori_label_directorys"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo -e "\033[32m_____ 4_rebuild_labels.sh _____\033[0m"
|
||||
|
||||
|
||||
echo -n "请选择label图片搜索深度(默认为1):"
|
||||
read -r fold_search_depth
|
||||
if [ -z $fold_search_depth ]; then
|
||||
fold_search_depth='1'
|
||||
fi
|
||||
|
||||
save_pro_folds=""${ori_label_directory%/}"_pro_label_fold" # 去掉末尾的足/
|
||||
save_pro_fold=$(readlink -f "$save_pro_folds")
|
||||
echo -n "请选择初步处理label后label_pro图片存储位置(默认为$save_pro_fold):"
|
||||
read -r save_pro_folds
|
||||
if [ -z $save_pro_folds ]; then
|
||||
save_pro_folds=""${ori_label_directory%/}"_pro_label_fold"
|
||||
fi
|
||||
save_pro_fold=$(readlink -f "$save_pro_folds")
|
||||
|
||||
echo -n "请选择label_pro图片后缀(默认为\"_label\"):"
|
||||
read -r pro_suffix_name
|
||||
if [ -z $pro_suffix_name ]; then
|
||||
pro_suffix_name='_label'
|
||||
fi
|
||||
|
||||
save_GT_folds=""${ori_label_directory%/}"_GT_label_fold"
|
||||
save_GT_fold=$(readlink -f "$save_GT_folds")
|
||||
echo -n "请选择处理label_pro后label_GT图片存储位置(默认为$save_GT_fold):"
|
||||
read -r save_pro_folds
|
||||
if [ -z $save_pro_folds ]; then
|
||||
save_GT_folds=""${ori_label_directory%/}"_GT_label_fold"
|
||||
fi
|
||||
save_GT_fold=$(readlink -f "$save_GT_folds")
|
||||
|
||||
echo -n "请选择label_GT图片后缀(默认为\"_gtFine_labelTrainIds\"):"
|
||||
read -r GT_suffix_name
|
||||
if [ -z $GT_suffix_name ]; then
|
||||
GT_suffix_name='_gtFine_labelTrainIds'
|
||||
fi
|
||||
|
||||
echo -n "请选择label_GT图片通道数(1或3,默认为1):"
|
||||
read -r GT_channel
|
||||
if [ -z $GT_channel ]; then
|
||||
GT_channel='1'
|
||||
fi
|
||||
if [[ $GT_channel != '1' && $GT_channel != '3' ]]; then
|
||||
echo -e "\033[35mGT_channel只能为1或3,输入有误,将其默认变为1\033[0m"
|
||||
GT_channel='1'
|
||||
fi
|
||||
|
||||
echo -n "请选择GT图片背景颜色(0或255,默认为0(黑色))):"
|
||||
read -r back_gnd_color
|
||||
if [ -z $back_gnd_color]; then
|
||||
back_gnd_color='0'
|
||||
fi
|
||||
|
||||
echo -n "请选择GT图片中第一类的颜色(黑色背景(0)下默认为1,白色背景(255)下默认为0):"
|
||||
read -r first_class_color
|
||||
if [ -z $first_class_color]; then
|
||||
if [ $back_gnd_color == '255' ]; then
|
||||
first_class_color='0'
|
||||
else
|
||||
first_class_color='1'
|
||||
fi
|
||||
fi
|
||||
|
||||
echo -n "请选择图片类型(png或jpg,默认为png(没有\".\"))):"
|
||||
read -r pic_type
|
||||
if [ -z $pic_type ]; then
|
||||
pic_type='png'
|
||||
fi
|
||||
|
||||
echo -n "请选择重建起始目录(label或pro,默认为label):"
|
||||
read -r Rebuild_from
|
||||
if [ -z $Rebuild_from ]; then
|
||||
Rebuild_from='label'
|
||||
fi
|
||||
if [[ $Rebuild_from != 'label' && $Rebuild_from != 'pro' ]]; then
|
||||
echo -e "\033[35mRebuild_from只能为label或pro,输入有误,将其默认变为label\033[0m"
|
||||
Rebuild_from='label'
|
||||
fi
|
||||
|
||||
echo -n "请选择重建最终目标(pro或GT,默认为GT):"
|
||||
read -r Rebuild_to
|
||||
if [ -z $Rebuild_to ]; then
|
||||
Rebuild_to='GT'
|
||||
fi
|
||||
if [[ $Rebuild_to != 'GT' && $Rebuild_to != 'pro' ]]; then
|
||||
echo -e "\033[35mRebuild_to只能为GT或pro,输入有误,将其默认变为GT\033[0m"
|
||||
Rebuild_to='GT'
|
||||
fi
|
||||
|
||||
echo -n "请选择是否保存pro图片生成中间状态(e.g.灰度图等)(false或true,默认为false):"
|
||||
read -r save_process_pics
|
||||
if [ -z $save_process_pics ]; then
|
||||
save_process_pics='false'
|
||||
fi
|
||||
if [[ $save_process_pics != 'true' && $save_process_pics != 'false' ]]; then
|
||||
echo -e "\033[35msave_process_pics只能为true或false,输入有误,将其默认变为false\033[0m"
|
||||
save_process_pics='false'
|
||||
fi
|
||||
|
||||
# 获取当前脚本的路径和名称
|
||||
script_path=$(dirname "$0")
|
||||
# 将当前目录更改为脚本所在的路径
|
||||
cd "$script_path"
|
||||
|
||||
# 激活conda环境
|
||||
source /home/"$USER"/miniconda/bin/activate Deal_pics
|
||||
echo -e "\033[35m运行:\033[0mpython 4_deal_labels.py -src_fold $ori_label_directory -save_pro_fold $save_pro_fold -save_GT_fold $save_GT_fold -fold_search_depth $fold_search_depth -pro_suffix_name $pro_suffix_name -GT_suffix_name $GT_suffix_name -GT_channel $GT_channel -back_gnd_color $back_gnd_color -first_class_color $first_class_color -pic_type $pic_type -Rebuild_from $Rebuild_from -Rebuild_to $Rebuild_to -save_process_pics $save_process_pics "
|
||||
echo ""
|
||||
python 4_deal_labels.py -src_fold $ori_label_directory -save_pro_fold $save_pro_fold -save_GT_fold $save_GT_fold -fold_search_depth $fold_search_depth -pro_suffix_name $pro_suffix_name -GT_suffix_name $GT_suffix_name -GT_channel $GT_channel -back_gnd_color $back_gnd_color -first_class_color $first_class_color -pic_type $pic_type -Rebuild_from $Rebuild_from -Rebuild_to $Rebuild_to -save_process_pics $save_process_pics
|
||||
|
||||
echo "4_rebuild_label_pics.sh重构完毕"
|
||||
|
||||
122
DataSet_Own/1. 图片预处理(内含使用手册)/5_TOOL_stack_pics.sh
Executable file
122
DataSet_Own/1. 图片预处理(内含使用手册)/5_TOOL_stack_pics.sh
Executable file
@@ -0,0 +1,122 @@
|
||||
#!/bin/bash
|
||||
|
||||
usage() {
|
||||
echo "Usage: $0 -i <ori_image_directory> -l <ori_label_directory> -r <stack_result_directory> [ -a <alpha> -p <prefix> -s <suffix> -h]"
|
||||
echo "对image图片和label图片进行匹配(-i、-l -r均不能为空)(-p -s默认为空"" -a默认为\"0.3\") "
|
||||
echo "-i:原始image的路径,-l:原始label的路径,-p:前缀内容,-s:后缀内容(不用管文件后缀名),-h:帮助"
|
||||
echo "e.g. 5_TOOL_stack_pics.sh -i ./C组未标注 -l ./C组标注图片 -r ./C组result_0.3透明度 -a 0.3 -p Group_C_ -s _label"
|
||||
}
|
||||
|
||||
ori_image_directorys=""
|
||||
ori_label_directorys=""
|
||||
stack_result_directorys=""
|
||||
prefix=""
|
||||
suffix=""
|
||||
alpha="0.3"
|
||||
|
||||
while getopts "hl:i:r:p:s:a:" opt; do
|
||||
case $opt in
|
||||
h)
|
||||
usage
|
||||
exit 0
|
||||
;;
|
||||
i)
|
||||
ori_image_directorys=$OPTARG
|
||||
;;
|
||||
l)
|
||||
ori_label_directorys=$OPTARG
|
||||
;;
|
||||
p)
|
||||
prefix=$OPTARG
|
||||
;;
|
||||
s)
|
||||
suffix=$OPTARG
|
||||
;;
|
||||
r)
|
||||
stack_result_directorys=$OPTARG
|
||||
;;
|
||||
a)
|
||||
alpha=$OPTARG
|
||||
;;
|
||||
*)
|
||||
echo -e '\033[31m!!! Error, Illegal input !!!\033[0m'
|
||||
usage
|
||||
exit 1
|
||||
;;
|
||||
esac
|
||||
done
|
||||
|
||||
# 判断输入地址是否为空
|
||||
if [ -z "$ori_label_directorys" ] || [ -z "$ori_image_directorys" ] || [ -z "$stack_result_directorys" ]; then
|
||||
echo -e "\033[31m输入地址 -i -l -z 存在空地址\033[0m"
|
||||
usage
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# 地址转化
|
||||
ori_image_directory=$(readlink -f "$ori_image_directorys")
|
||||
ori_label_directory=$(readlink -f "$ori_label_directorys")
|
||||
stack_result_directory=$(readlink -f "$stack_result_directorys")
|
||||
if [ -z "$ori_label_directory" ] || [ -z "$ori_image_directory" ]|| [ -z "$stack_result_directory" ]; then
|
||||
echo "image、label、result存在无法解析地址,程序退出"
|
||||
echo -e "\033[31mori_image_directory\033[0m: $ori_image_directorys"
|
||||
echo -e "\033[31mori_label_directory\033[0m: $ori_label_directorys"
|
||||
echo -e "\033[31mori_label_directory\033[0m: $stack_result_directorys"
|
||||
exit 1
|
||||
fi
|
||||
if [ ! -d "$ori_label_directory" ] || [ ! -d "$ori_image_directory" ]; then
|
||||
echo "image、label两目录有一个不存在,程序退出"
|
||||
echo -e "\033[31mori_image_directory\033[0m: $ori_image_directory"
|
||||
echo -e "\033[31mori_label_directory\033[0m: $ori_label_directorys"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# 获取当前脚本的路径和名称
|
||||
script_path=$(dirname "$0")
|
||||
# 将当前目录更改为脚本所在的路径
|
||||
cd "$script_path"
|
||||
|
||||
# 激活conda环境
|
||||
source /home/"$USER"/miniconda/bin/activate Deal_pics
|
||||
|
||||
echo -e "\033[32m_____ 5_TOOL_stack_pics.sh _____\033[0m"
|
||||
echo -e "\033[33mimage所在文件夹为$ori_image_directory\nlable所在文件夹为$ori_label_directory\033[0m"
|
||||
# 遍历label目录
|
||||
for file_path in "$ori_label_directory"/*; do
|
||||
# 判断是否是文件
|
||||
if [[ -f "$file_path" ]]; then
|
||||
file_name=$(basename "$file_path")
|
||||
# 判断文件名是否符合规范
|
||||
if [[ "$file_name" =~ .*\.(jpg|png|bmp|JPG|PNG|BMP) ]]; then # 判断是否有为图片
|
||||
# if [[ "$file_name" =~ "$prefix".*"$suffix".*\.(jpg|png|bmp|JPG|PNG|BMP)$ ]]; then # 判断是否有满足要求的文件名
|
||||
# 抽取文件名(有前缀、后缀的抽取前缀、后缀里面的,没有的返回整个)
|
||||
if [ -z $prefix ];then
|
||||
file_name_extract=$(echo $file_name | sed "s/"$prefix"\(.*\)"$suffix".*/\1/" | sed "s/\(.*\)\.\(jpg\|png\|bmp\|JPG\|PNG\|BMP\)$/\1/")
|
||||
else
|
||||
file_name_extract=$(echo $file_name | sed "s/".*$prefix"\(.*\)"$suffix".*/\1/" | sed "s/\(.*\)\.\(jpg\|png\|bmp\|JPG\|PNG\|BMP\)$/\1/")
|
||||
fi
|
||||
# 从label目录中看是否有此文件
|
||||
file_name_other=$(ls $ori_image_directory | grep $file_name_extract)
|
||||
file_name_other=$(echo "$(echo "$file_name_other" | sed '/^$/d')" | head -n1) # 提取出文件名
|
||||
# 如果另一个目录没有此文件的话
|
||||
if [ -z "$file_name_other" ]; then
|
||||
echo "$file_name label中对应内容未在$ori_image_directory搜索到"
|
||||
# 建立相关存储文件夹
|
||||
if [ ! -d "$ori_label_directory/Not_pair_pics" ]; then
|
||||
mkdir -p "$ori_label_directory/Not_pair_pics" # 建立存储文件夹
|
||||
fi
|
||||
# 移动相关文件
|
||||
cp "$ori_label_directory/$file_name" "$ori_label_directory/Not_pair_pics"
|
||||
echo "$file_name" >> "$ori_label_directory/Not_pair_pics/not_pair.txt"
|
||||
else # 如果另一个目录有此配对文件的话,则运行相关程序
|
||||
echo "image中的$file_name_other,与lable中的$file_name"
|
||||
mkdir -p "$stack_result_directory"
|
||||
python 5_stack_picture.py "$ori_image_directory/$file_name_other" "$ori_label_directory/$file_name" "$stack_result_directory" "$alpha"
|
||||
echo ""
|
||||
fi
|
||||
|
||||
fi
|
||||
else
|
||||
echo "$file_path不是文件"
|
||||
fi
|
||||
done
|
||||
41
DataSet_Own/1. 图片预处理(内含使用手册)/5_stack_picture.py
Executable file
41
DataSet_Own/1. 图片预处理(内含使用手册)/5_stack_picture.py
Executable file
@@ -0,0 +1,41 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: UTF-8 -*-
|
||||
import cv2, os, sys
|
||||
|
||||
def Stack_pic(Background_path, Overlay_path, Result_dir, alpha=0.3):
|
||||
# 读取两张没有alpha通道的图片
|
||||
img1 = cv2.imread(Background_path) # 底层图片
|
||||
img2 = cv2.imread(Overlay_path) # 顶层图片
|
||||
|
||||
Result_name = os.path.splitext(os.path.basename(Background_path))[0]
|
||||
|
||||
# 将img2调整为与img1大小相同
|
||||
img2 = cv2.resize(img2, (img1.shape[1], img1.shape[0]))
|
||||
|
||||
# 将img2的透明度调整为20%
|
||||
overlay_alpha = alpha
|
||||
|
||||
# 将img2叠加到img1上
|
||||
overlay = cv2.addWeighted(img1, 1 - overlay_alpha, img2, overlay_alpha, 0)
|
||||
|
||||
# 保存结果
|
||||
if not os.path.exists(Result_dir):
|
||||
os.makedirs(Result_dir)
|
||||
cv2.imwrite(os.path.join(Result_dir, Result_name+'.png'), overlay)
|
||||
print("堆叠图片写入地址:", os.path.join(Result_dir, Result_name+'.png'))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
Background_path = sys.argv[1] # 背景所在路径
|
||||
Overlay_path = sys.argv[2] # 上层图片所在路径
|
||||
Result_dir = sys.argv[3] # 结果所在目录
|
||||
# 透明度,默认为0.3
|
||||
try:
|
||||
alpha = float(sys.argv[4])
|
||||
if(alpha > 1 or alpha < 0):
|
||||
print("alpha 透明度输入不正确,其值应该在0~1之间")
|
||||
alpha = 0.3
|
||||
except:
|
||||
alpha = 0.3
|
||||
# 进行对叠程序
|
||||
Stack_pic(Background_path, Overlay_path, Result_dir, alpha)
|
||||
167
DataSet_Own/1. 图片预处理(内含使用手册)/6_TOOL_stitch_pics.sh
Executable file
167
DataSet_Own/1. 图片预处理(内含使用手册)/6_TOOL_stitch_pics.sh
Executable file
@@ -0,0 +1,167 @@
|
||||
#!/bin/bash
|
||||
|
||||
usage() {
|
||||
echo "Usage: $0 -i <ori_image_directory> -l <ori_label_directory> -r <stitch_result_directory> [ -p <prefix> -s <suffix> -h]"
|
||||
echo "对image图片和label图片进行拼接(-i -r不能为空)(-l为填写项-p -s默认为空"") "
|
||||
echo "-i:原始image的路径,-l:拼接label的路径,-p:前缀内容,-s:后缀内容(不用管文件后缀名),-h:帮助"
|
||||
echo "e.g. 6_TOOL_stitch_pics.sh -i ./C组未标注 -l ./C组标注图片 -r ./C组result_stitch -p Group_C_ -s _label"
|
||||
}
|
||||
|
||||
ori_image_directorys=""
|
||||
ori_label_directorys=""
|
||||
stitch_result_directorys=""
|
||||
prefix=""
|
||||
suffix=""
|
||||
|
||||
while getopts "hl:i:r:p:s:a:" opt; do
|
||||
case $opt in
|
||||
h)
|
||||
usage
|
||||
exit 0
|
||||
;;
|
||||
i)
|
||||
ori_image_directorys=$OPTARG
|
||||
;;
|
||||
l)
|
||||
ori_label_directorys=$OPTARG
|
||||
;;
|
||||
p)
|
||||
prefix=$OPTARG
|
||||
;;
|
||||
s)
|
||||
suffix=$OPTARG
|
||||
;;
|
||||
r)
|
||||
stitch_result_directorys=$OPTARG
|
||||
;;
|
||||
*)
|
||||
echo -e '\033[31m!!! Error, Illegal input !!!\033[0m'
|
||||
usage
|
||||
exit 1
|
||||
;;
|
||||
esac
|
||||
done
|
||||
|
||||
# 如果堆叠地址为空,生成默认堆叠地址
|
||||
if [ -z "$stitch_result_directorys" ]; then
|
||||
stitch_result_directorys=""$ori_label_directorys"_拼接"
|
||||
echo -n "请输入堆叠结果存储目录(默认为$stitch_result_directorys):"
|
||||
read -r temp
|
||||
if [ ! -z $temp ]; then
|
||||
stitch_result_directorys=$temp
|
||||
fi
|
||||
fi
|
||||
|
||||
# 判断输入地址是否为空
|
||||
if [ -z "$ori_label_directorys" ] || [ -z "$ori_image_directorys" ] || [ -z "$stitch_result_directorys" ]; then
|
||||
echo -e "\033[31m输入地址 -i -l -z 存在空地址\033[0m"
|
||||
usage
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# 地址转化
|
||||
ori_image_directory=$(readlink -f "$ori_image_directorys")
|
||||
ori_label_directory=$(readlink -f "$ori_label_directorys")
|
||||
stitch_result_directory=$(readlink -f "$stitch_result_directorys")
|
||||
if [ -z "$ori_label_directory" ] || [ -z "$ori_image_directory" ]|| [ -z "$stitch_result_directory" ]; then
|
||||
echo "image、label、result存在无法解析地址,程序退出"
|
||||
echo -e "\033[31mori_image_directory\033[0m: $ori_image_directorys"
|
||||
echo -e "\033[31mori_label_directory\033[0m: $ori_label_directorys"
|
||||
echo -e "\033[31mori_label_directory\033[0m: $stitch_result_directorys"
|
||||
exit 1
|
||||
fi
|
||||
if [ ! -d "$ori_label_directory" ] || [ ! -d "$ori_image_directory" ]; then
|
||||
echo "image、label两目录有一个不存在,程序退出"
|
||||
echo -e "\033[31mori_image_directory\033[0m: $ori_image_directory"
|
||||
echo -e "\033[31mori_label_directory\033[0m: $ori_label_directorys"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo -e "\033[32m_____ 6_TOOL_stitch_pics.sh _____\033[0m"
|
||||
|
||||
# 图片相对位置
|
||||
PS3='Please enter your choice:'
|
||||
options=("Imge ↑ Label ↓" "Label ↑ Imge ↓" "Imge ← Label →" "Label ← Imge →" )
|
||||
echo "请选择图片与Label的相对位置,默认为1.\"Imge↑Label↓\""
|
||||
select opt in "${options[@]}"
|
||||
do
|
||||
case $opt in
|
||||
# 这里面返回的结果都是img在前,label在后的
|
||||
"Imge ↑ Label ↓")
|
||||
relative_pos="Img_up_label_down"
|
||||
break
|
||||
;;
|
||||
"Label ↑ Imge ↓")
|
||||
relative_pos="Img_down_label_up"
|
||||
break
|
||||
;;
|
||||
"Imge ← Label →")
|
||||
relative_pos="Img_left_label_right"
|
||||
break
|
||||
;;
|
||||
"Label ← Imge →")
|
||||
relative_pos="Img_right_label_left"
|
||||
break
|
||||
;;
|
||||
*)
|
||||
echo "Set to default 1.\"Imge↑Label↓\""
|
||||
relative_pos="Img_up_label_down"
|
||||
break
|
||||
;;
|
||||
esac
|
||||
done
|
||||
|
||||
# 获取当前脚本的路径和名称
|
||||
script_path=$(dirname "$0")
|
||||
# 将当前目录更改为脚本所在的路径
|
||||
cd "$script_path"
|
||||
|
||||
# 激活conda环境
|
||||
source /home/"$USER"/miniconda/bin/activate Deal_pics
|
||||
|
||||
echo -e "\033[33mimage所在文件夹为$ori_image_directory\nlable所在文件夹为$ori_label_directory\033[0m"
|
||||
|
||||
# 遍历label目录
|
||||
for file_path in "$ori_label_directory"/*; do
|
||||
# 判断是否是文件
|
||||
if [[ -f "$file_path" ]]; then
|
||||
file_name=$(basename "$file_path")
|
||||
# 判断文件名是否符合规范
|
||||
if [[ "$file_name" =~ .*\.(jpg|png|bmp|JPG|PNG|BMP) ]]; then # 判断是否有为图片
|
||||
# if [[ "$file_name" =~ "$prefix".*"$suffix".*\.(jpg|png|bmp|JPG|PNG|BMP)$ ]]; then # 判断是否有满足要求的文件名
|
||||
# 抽取文件名(有前缀、后缀的抽取前缀、后缀里面的,没有的返回整个)
|
||||
if [ -z $prefix ];then
|
||||
file_name_extract=$(echo $file_name | sed "s/"$prefix"\(.*\)"$suffix".*/\1/" | sed "s/\(.*\)\.\(jpg\|png\|bmp\|JPG\|PNG\|BMP\)$/\1/")
|
||||
else
|
||||
file_name_extract=$(echo $file_name | sed "s/".*$prefix"\(.*\)"$suffix".*/\1/" | sed "s/\(.*\)\.\(jpg\|png\|bmp\|JPG\|PNG\|BMP\)$/\1/")
|
||||
fi
|
||||
# 从label目录中看是否有此文件
|
||||
file_name_other=$(ls $ori_image_directory | grep $file_name_extract)
|
||||
file_name_other=$(echo "$(echo "$file_name_other" | sed '/^$/d')" | head -n1) # 提取出文件名
|
||||
# 如果另一个目录没有此文件的话
|
||||
if [ -z "$file_name_other" ]; then
|
||||
echo "$file_name label中对应内容未在$ori_image_directory搜索到"
|
||||
# 建立相关存储文件夹
|
||||
if [ ! -d "$ori_label_directory/Not_pair_pics" ]; then
|
||||
mkdir -p "$ori_label_directory/Not_pair_pics" # 建立存储文件夹
|
||||
fi
|
||||
# 移动相关文件
|
||||
cp "$ori_label_directory/$file_name" "$ori_label_directory/Not_pair_pics"
|
||||
echo "$file_name" >> "$ori_label_directory/Not_pair_pics/not_pair.txt"
|
||||
else # 如果另一个目录有此配对文件的话,则运行相关程序
|
||||
echo "image中的$file_name_other,与lable中的$file_name"
|
||||
if [ ! -d "$stitch_result_directory" ]; then
|
||||
echo "创建stitch_result_directory存储文件夹"
|
||||
echo -e "\033[35运行:\033[0m mkdir -p $stitch_result_directory"
|
||||
mkdir -p $stitch_result_directory
|
||||
fi
|
||||
echo -e "\033[35运行:\033[0mpython 6_stitch_picture.py "$ori_image_directory/$file_name_other" "$ori_label_directory/$file_name" "$stitch_result_directory" "$relative_pos""
|
||||
python 6_stitch_picture.py "$ori_image_directory/$file_name_other" "$ori_label_directory/$file_name" "$stitch_result_directory" "$relative_pos"
|
||||
echo ""
|
||||
fi
|
||||
|
||||
fi
|
||||
else
|
||||
echo "$file_path不是文件"
|
||||
fi
|
||||
done
|
||||
75
DataSet_Own/1. 图片预处理(内含使用手册)/6_stitch_picture.py
Executable file
75
DataSet_Own/1. 图片预处理(内含使用手册)/6_stitch_picture.py
Executable file
@@ -0,0 +1,75 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: UTF-8 -*-
|
||||
import cv2, os, sys, re
|
||||
|
||||
def Stitch_pic(Ori_image_path, Ori_label_path, Result_dir, img_pos, label_pos):
|
||||
|
||||
# 读取两张没有stitch_pos通道的图片
|
||||
img1 = cv2.imread(Ori_image_path) # 底层图片
|
||||
img2 = cv2.imread(Ori_label_path) # 顶层图片
|
||||
|
||||
Result_name = os.path.splitext(os.path.basename(Ori_image_path))[0]
|
||||
|
||||
# 将img2调整为与img1大小相同
|
||||
if img1.shape != img2.shape:
|
||||
img2 = cv2.resize(img2, (img1.shape[1], img1.shape[0]))
|
||||
|
||||
# 拼接图片
|
||||
if img_pos == 'up' and label_pos == 'down':
|
||||
result = cv2.vconcat([img1, img2])
|
||||
elif img_pos == 'down' and label_pos == 'up':
|
||||
result = cv2.vconcat([img2, img1])
|
||||
elif img_pos == 'left' and label_pos == 'right':
|
||||
result = cv2.hconcat([img1, img2])
|
||||
elif img_pos == 'right' and label_pos == 'left':
|
||||
result = cv2.hconcat([img2, img1])
|
||||
else:
|
||||
RED = '\033[91m'
|
||||
END = '\033[0m'
|
||||
print(RED + "The input of relative_pos is wrong, img_pos is " + img_pos + " label_pos is " + label_pos + END)
|
||||
os.exit()
|
||||
|
||||
# 保存结果
|
||||
if not os.path.exists(Result_dir):
|
||||
os.makedirs(Result_dir)
|
||||
cv2.imwrite(os.path.join(Result_dir, Result_name+'.png'), result)
|
||||
print("堆叠图片写入地址:", os.path.join(Result_dir, Result_name+'.png'))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
Ori_image_path = sys.argv[1] # 背景所在路径
|
||||
Ori_label_path = sys.argv[2] # 上层图片所在路径
|
||||
Result_dir = sys.argv[3] # 结果所在目录
|
||||
relative_pos = str.lower(sys.argv[4])
|
||||
|
||||
match_up = re.search(r'up', relative_pos)
|
||||
match_down = re.search(r'down', relative_pos)
|
||||
match_left = re.search(r'up', relative_pos)
|
||||
match_right = re.search(r'down', relative_pos)
|
||||
|
||||
if match_up and match_down:
|
||||
pos_up = match_up.start()
|
||||
pos_down = match_down.start()
|
||||
if pos_down < pos_up :
|
||||
img_pos = "down"
|
||||
label_pos = "up"
|
||||
else:
|
||||
img_pos = "up"
|
||||
label_pos = "down"
|
||||
elif match_left and match_right:
|
||||
pos_left = match_up.start()
|
||||
pos_right = match_down.start()
|
||||
if pos_left < pos_right :
|
||||
img_pos = "left"
|
||||
label_pos = "right"
|
||||
else:
|
||||
img_pos = "left"
|
||||
label_pos = "right"
|
||||
else:
|
||||
print("Either 'up'/'down' 'left'/'right' is missing or in the text.")
|
||||
print("Set to default img_pos = up label_pos = down")
|
||||
img_pos = "up"
|
||||
label_pos = "down"
|
||||
|
||||
# 进行对叠程序
|
||||
Stitch_pic(Ori_image_path, Ori_label_path, Result_dir, img_pos, label_pos)
|
||||
246
DataSet_Own/1. 图片预处理(内含使用手册)/Seg_data_run.sh
Executable file
246
DataSet_Own/1. 图片预处理(内含使用手册)/Seg_data_run.sh
Executable file
@@ -0,0 +1,246 @@
|
||||
#!/bin/bash
|
||||
|
||||
usage() {
|
||||
echo "Usage: $0 -i <ori_image_directory> -l <ori_label_directory> [-h]"
|
||||
echo "对image图片和label图片进行统一处理"
|
||||
echo "-i:原始图片的路径,-l:原始标签的路径,-h:帮助"
|
||||
}
|
||||
|
||||
ori_image_directorys=""
|
||||
ori_label_directorys=""
|
||||
stack_result_directorys=""
|
||||
stitch_result_directorys=""
|
||||
|
||||
while getopts "hl:i:" opt; do
|
||||
case $opt in
|
||||
h)
|
||||
usage
|
||||
exit 0
|
||||
;;
|
||||
i)
|
||||
ori_image_directorys=$OPTARG
|
||||
;;
|
||||
l)
|
||||
ori_label_directorys=$OPTARG
|
||||
;;
|
||||
*)
|
||||
echo '!!! Error, Illegal input !!!'
|
||||
usage
|
||||
exit 1
|
||||
;;
|
||||
esac
|
||||
done
|
||||
|
||||
# 判断label、image是否都为空
|
||||
echo $ori_label_directorys $ori_image_directorys
|
||||
if [ -z "$ori_label_directorys" ] && [ -z "$ori_image_directorys" ]; then
|
||||
echo -e "\033[31mori_label_directory、ori_image_directory不能都为空\033[0m"
|
||||
usage
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# 进行绝对路径转化
|
||||
ori_image_directory=$(readlink -f "$ori_image_directorys")
|
||||
ori_label_directory=$(readlink -f "$ori_label_directorys")
|
||||
if [ ! -d "$ori_label_directory" ] && [ ! -d "$ori_image_directory" ]; then
|
||||
echo "image、label都不存在,程序退出"
|
||||
echo -e "\033[31mori_image_directory\033[0m: $ori_image_directorys"
|
||||
echo "$ori_image_directory"
|
||||
echo -e "\033[31mori_label_directory\033[0m: $ori_label_directorys"
|
||||
echo "$ori_label_directory"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# 激活conda环境
|
||||
# source /home/"$USER"/miniconda/bin/activate Pat_infos
|
||||
|
||||
# 记录脚本所在路径
|
||||
script_path=$(dirname "$0")
|
||||
echo -e "\033[32m******* 开始Run运行Seg_data_run.sh批量图片处理程序 *******\033[0m"
|
||||
# 进行操作选择
|
||||
while true; do
|
||||
PS3='Please enter your choice: '
|
||||
options=("移动图片与重命名" "图片统一大小与类型" "image、label配对检测" "对label进行重建" "TOOL_image、label堆叠" "TOOL_image、label拼接" "Quit")
|
||||
echo -e "\033[35m____ Seg_data_run选择 ____\033[0m"
|
||||
echo -e "\033[35mImage所在地址:\033[0m$ori_image_directory"
|
||||
echo -e "\033[35mLabel所在地址:\033[0m$ori_label_directory"
|
||||
select opt in "${options[@]}"
|
||||
do
|
||||
case $opt in
|
||||
### 选项1 ###
|
||||
"移动图片与重命名")
|
||||
echo -e "\033[31mRun运行:\033[0mbash $script_path/1_rename_pics.sh -i $ori_image_directory -l $ori_label_directory"
|
||||
bash $script_path/1_rename_pics.sh -i "$ori_image_directory" -l "$ori_label_directory"
|
||||
echo ""
|
||||
break
|
||||
;;
|
||||
|
||||
### 选项2 ###
|
||||
"图片统一大小与类型")
|
||||
echo -n "请输入宽度(默认1920):"
|
||||
read -r width
|
||||
echo -n "请输入高度(默认1080):"
|
||||
read -r height
|
||||
if [ -z "$width" ]; then
|
||||
width=1920
|
||||
fi
|
||||
if [ -z "$height" ]; then
|
||||
height=1080
|
||||
fi
|
||||
# 如果输入图片路径为空
|
||||
if [ -z $ori_image_directory ];then
|
||||
echo -e "\033[31mRun运行:\033[0mbash $script_path/2_reformate_pics.sh -l $ori_label_directory -w $width -h $height "
|
||||
bash $script_path/2_reformate_pics.sh -l $ori_label_directory -w $width -h $height
|
||||
elif [ -z $ori_label_directory ];then
|
||||
echo -e "\033[31mRun运行:\033[0mbash $script_path/2_reformate_pics.sh -l $ori_label_directory -w $width -h $height "
|
||||
bash $script_path/2_reformate_pics.sh -i $ori_image_directory -w $width -h $height
|
||||
else
|
||||
echo -e "\033[31mRun运行:\033[0mbash $script_path/2_reformate_pics.sh -i $ori_image_directory -l $ori_label_directory -w $width -h $height"
|
||||
bash $script_path/2_reformate_pics.sh -i $ori_image_directory -l $ori_label_directory -w $width -h $height
|
||||
fi
|
||||
echo ""
|
||||
break
|
||||
;;
|
||||
|
||||
### 选项3 ###
|
||||
"image、label配对检测")
|
||||
while [ ! -d "$ori_image_directory" ]; do
|
||||
echo -e "\033[31mImage地址为:$ori_image_directorys,其不存在\033[0m"
|
||||
echo -n "请输入image所在地址:"
|
||||
read -r ori_image_directorys
|
||||
ori_image_directory=$(readlink -f "$ori_image_directorys")
|
||||
done
|
||||
while [ ! -d "$ori_label_directory" ]; do
|
||||
echo -e "\033[31mLabel地址为:$ori_label_directorys,其不存在\033[0m"
|
||||
echo -n "请输入label所在地址:"
|
||||
read -r ori_label_directorys
|
||||
ori_label_directory=$(readlink -f "$ori_label_directorys")
|
||||
done
|
||||
echo -n "请输入图片前缀文本(默认为\"\"):"
|
||||
read -r prefix # 禁止转译
|
||||
echo -n "请输入图片后缀文本(非.png类后缀名,默认为\"\"):"
|
||||
read -r suffix
|
||||
if [ -z $preffix ];then
|
||||
preffix=""
|
||||
fi
|
||||
if [ -z $suffix ];then
|
||||
suffix=""
|
||||
fi
|
||||
echo -e "\033[31mRun运行:\033[0mbash $script_path/3_pair_ori_label.sh -i $ori_image_directory -l $ori_label_directory -p $prefix -s $suffix"
|
||||
bash $script_path/3_pair_ori_label.sh -i $ori_image_directory -l $ori_label_directory -p "$prefix" -s "$suffix"
|
||||
echo ""
|
||||
break
|
||||
;;
|
||||
|
||||
### 选项4 ###
|
||||
"对label进行重建")
|
||||
# 判断Label目录是否存在
|
||||
while [ -z "$ori_label_directory" ]; do
|
||||
echo -e "\033[31mLabel地址为:$ori_label_directorys,其存在异常\033[0m"
|
||||
echo -n "请输入堆叠结果存储目录地址:"
|
||||
read -r ori_label_directorys
|
||||
ori_label_directory=$(readlink -f "$ori_label_directorys")
|
||||
done
|
||||
echo -e "\033[31mRun运行:\033[0mbash $script_path/4_rebuild_labels.sh -l $ori_label_directory"
|
||||
bash $script_path/4_rebuild_labels.sh -l $ori_label_directory
|
||||
echo ""
|
||||
break
|
||||
;;
|
||||
|
||||
### 选项5 ###
|
||||
"TOOL_image、label堆叠")
|
||||
while [ ! -d "$ori_image_directory" ]; do
|
||||
echo -e "\033[31mImage地址为:$ori_image_directorys,其不存在\033[0m"
|
||||
echo -n "请输入image所在地址:"
|
||||
read -r ori_image_directorys
|
||||
ori_image_directory=$(readlink -f "$ori_image_directorys")
|
||||
done
|
||||
while [ ! -d "$ori_label_directory" ]; do
|
||||
echo -e "\033[31mLabel地址为:$ori_label_directorys,其不存在\033[0m"
|
||||
echo -n "请输入label所在地址:"
|
||||
read -r ori_label_directorys
|
||||
ori_label_directory=$(readlink -f "$ori_label_directorys")
|
||||
done
|
||||
|
||||
echo -n "请输入堆叠图片透明程度(0~1,0为最透明,默认为0.3)"
|
||||
read -r alpha
|
||||
echo -n "请输入图片前缀文本(默认为\"\"):"
|
||||
read -r prefix
|
||||
echo -n "请输入图片后缀文本(非.png类后缀名,默认为\"\"):"
|
||||
read -r suffix
|
||||
if [ -z "$alpha" ]; then
|
||||
alpha="0.3"
|
||||
fi
|
||||
stack_result_directory=""$ori_label_directory"_堆叠_"$alpha"_透明度"
|
||||
echo -n "请输入堆叠结果存储目录(默认为$stack_result_directory):"
|
||||
read -r stack_result_directorys
|
||||
# 判断堆叠目录非为空则进行转换
|
||||
if [ ! -z "$stack_result_directorys" ]; then
|
||||
stack_result_directory=$(readlink -f "$stack_result_directorys")
|
||||
fi
|
||||
# 判断Label目录是否存在
|
||||
while [ -z "$ori_label_directory" ]; do
|
||||
echo -e "\033[31mLabel地址为:$ori_label_directorys,其存在异常\033[0m"
|
||||
echo -n "请输入堆叠结果存储目录地址:"
|
||||
read -r ori_label_directorys
|
||||
ori_label_directory=$(readlink -f "$ori_label_directorys")
|
||||
done
|
||||
echo -e "\033[31mRun运行:\033[0mbash $script_path/5_TOOL_stack_pics.sh -i $ori_image_directory -l $ori_label_directory -p $prefix -s $suffix -a $alpha -r $stack_result_directory"
|
||||
bash $script_path/5_TOOL_stack_pics.sh -i "$ori_image_directory" -l "$ori_label_directory" -p "$prefix" -s "$suffix" -a "$alpha" -r "$stack_result_directory"
|
||||
echo ""
|
||||
break
|
||||
;;
|
||||
|
||||
### 选项6 图片拼接 ###
|
||||
"TOOL_image、label拼接")
|
||||
while [ ! -d "$ori_image_directory" ]; do
|
||||
echo -e "\033[31mImage地址为:$ori_image_directorys,其不存在\033[0m"
|
||||
echo -n "请输入image所在地址:"
|
||||
read -r ori_image_directorys
|
||||
ori_image_directory=$(readlink -f "$ori_image_directorys")
|
||||
done
|
||||
echo "1.label目前为\"$ori_label_directory\""
|
||||
echo -n "是否调整label所在文件夹(默认为不调整,有输入视为调整):"
|
||||
read -r temp
|
||||
while [ ! -z "$temp" ]; do
|
||||
read -r ori_label_directorys
|
||||
ori_label_directorys=$(readlink -f "$ori_label_directorys")
|
||||
done
|
||||
while [ ! -d "$ori_label_directory" ]; do
|
||||
echo -e "\033[31mLabel地址为:$ori_label_directorys,其不存在\033[0m"
|
||||
echo -n "请输入label所在地址:"
|
||||
read -r ori_label_directorys
|
||||
ori_label_directory=$(readlink -f "$ori_label_directorys")
|
||||
done
|
||||
stitch_result_directory=""$ori_label_directory"_拼接"
|
||||
echo -n "2.请输入堆叠结果存储目录(默认为$stitch_result_directory):"
|
||||
read -r stitch_result_directorys
|
||||
# 判断堆叠目录非为空则进行转换
|
||||
if [ ! -z "$stitch_result_directorys" ]; then
|
||||
stitch_result_directory=$(readlink -f "$stitch_result_directorys")
|
||||
fi
|
||||
|
||||
echo -n "3.请输入图片前缀文本(默认为\"\"):"
|
||||
read -r prefix
|
||||
echo -n "4.请输入图片后缀文本(非.png类后缀名,默认为\"\"):"
|
||||
read -r suffix
|
||||
|
||||
echo -e "\033[31mRun运行:\033[0mbash $script_path/6_TOOL_stitch_pics.sh -i $ori_image_directory -l $ori_label_directory -p "$prefix" -s "$suffix" -r "$stitch_result_directory""
|
||||
bash $script_path/6_TOOL_stitch_pics.sh -i "$ori_image_directory" -l "$ori_label_directory" -p "$prefix" -s "$suffix" -r "$stitch_result_directory"
|
||||
exit 0
|
||||
;;
|
||||
|
||||
### 选项7 退出 ###
|
||||
"Quit")
|
||||
echo -e "\033[35mSeg_data_run.sh Exiting...\033[0m"
|
||||
exit 0
|
||||
;;
|
||||
|
||||
*)
|
||||
echo "Invalid option: $REPLY"
|
||||
echo -e ""
|
||||
break
|
||||
;;
|
||||
esac
|
||||
done
|
||||
done
|
||||
14
DataSet_Own/1. 图片预处理(内含使用手册)/※1_环境安装.txt
Executable file
14
DataSet_Own/1. 图片预处理(内含使用手册)/※1_环境安装.txt
Executable file
@@ -0,0 +1,14 @@
|
||||
# 创建环境
|
||||
conda create -n Deal_pics python=3.8
|
||||
sudo apt install imagemagick
|
||||
# 安装包
|
||||
conda activate Deal_pics
|
||||
pip install --upgrade pip
|
||||
pip install opencv-python
|
||||
|
||||
# 解决BUG
|
||||
# 命令:python cv2.Canny(image, 50, 150)
|
||||
# 错误:cv2.error: OpenCV(4.5.5) /io/opencv/modules/imgproc/src/canny.cpp:829: error: (-215:Assertion failed) _src.depth() == CV_8U in function 'Canny'
|
||||
# 错误原因:Canny函数只对0~255起作用,对0~1不起作用
|
||||
# 解决方案:image = (image*255).astype(np.uint8)
|
||||
|
||||
28
DataSet_Own/1. 图片预处理(内含使用手册)/※2_使用手册.txt
Executable file
28
DataSet_Own/1. 图片预处理(内含使用手册)/※2_使用手册.txt
Executable file
@@ -0,0 +1,28 @@
|
||||
# 相关程序所在位置:./1_Preprocess_pics(已加入路径)
|
||||
|
||||
1.对于图片批量重命名(允许只输入-i或-l一个参数)
|
||||
1_rename_pics.sh -i <ori_image_directory> -l <ori_label_directory> [-h]
|
||||
|
||||
2.对于图片大小批量修改(允许只输入-i或-l一个参数)
|
||||
2_reformate_pics.sh -i <ori_image_directory> -l <ori_label_directory> [ -w <width_of_pic> -h <height_of_pic> -help]
|
||||
默认:-w=1920 -h=1080
|
||||
|
||||
3.对于原始图片、标签图片进行配对(必须输入-i和-l两个参数,文件可带前缀或后缀)
|
||||
3_pair_ori_label.sh -i <ori_image_directory> -l <ori_label_directory> [ -p <prefix> -s <suffix> -h]
|
||||
默认:-p="" -s=""
|
||||
|
||||
4.对于标签图片进行进一步处理
|
||||
※ 如果Label有变,请修改 4_deal_labels.py main 中的 Annotate_CLASSES、Annotate_PALETTE
|
||||
4_rebuild_labels.sh -l <ori_label_directory> [ -h ]
|
||||
|
||||
5.TOOL - 对于image图像、label图片进行堆叠(透明度可调,文件可带前缀或后缀)
|
||||
5_TOOL_stack_pics.sh -i <ori_image_directory> -l <ori_label_directory> -r <stack_result_directory> [ -a <alpha> -p <prefix> -s <suffix> -h]
|
||||
默认:-a=0.3 -p="" -s=""
|
||||
|
||||
6.TOOL - 对于image图像、label图像进行左右放置(文件可带前缀或后缀)
|
||||
6_TOOL_stitch_pics.sh -i <ori_image_directory> -l <ori_label_directory> -r <stitch_result_directory> [ -p <prefix> -s <suffix> -h]
|
||||
|
||||
============ 用法 ============
|
||||
1. 将原始Label与现有label进行匹配,查看差异
|
||||
bash 6_TOOL_stitch_pics.sh -i ./A_Label -l ./A_Label_pro_label_fold -r ./A_Label_Compare -s _label
|
||||
|
||||
264
README.md
Normal file
264
README.md
Normal file
@@ -0,0 +1,264 @@
|
||||
# Seg 图像分割项目使用说明
|
||||
|
||||
本项目是一个多路线图像分割实验与推理工程,包含自研 `segmentation_models_pytorch` 训练流程、YOLO 分割流程、MMSegmentation 流程,以及数据预处理、视频抽帧、图片叠加和结果分析工具。
|
||||
|
||||
## 1. 项目结构
|
||||
|
||||
| 路径 | 用途 |
|
||||
| --- | --- |
|
||||
| `Seg_All_In_One_SegModel/` | 基于 `segmentation_models_pytorch` 的语义分割训练、推理、参数量/FLOPs/FPS 统计和输出完整性检查。 |
|
||||
| `Seg_All_In_One_YoloModel/` | 基于 Ultralytics YOLO 的实例/语义分割训练、推理、热力图可视化和横向对比。 |
|
||||
| `Seg_All_In_One_MMSeg/` | 基于 OpenMMLab MMSegmentation 的模型配置、训练和结果汇总流程。 |
|
||||
| `Seg_All_In_One_Analysis/` | 汇总不同模型的指标、FLOPs、FPS,生成表格和 mIoU/FPS 图。 |
|
||||
| `DataSet_Own/1. 图片预处理(内含使用手册)/` | 自有数据的重命名、尺寸统一、图像/标签配对、标签重建、叠加和拼接检查。 |
|
||||
| `Seg_Predict_Own_Video_V2/` | 将视频按固定间隔抽帧,转换为 `DataSet_Public/<dataset>/images/val` 格式。 |
|
||||
| `Tool-图片堆叠/` | 快速检查原图和标签是否匹配,并生成透明叠加图。 |
|
||||
| `Tool-可视化/` | YOLO 标签生成、热图、FPS 等可视化辅助脚本。 |
|
||||
| `Back_Up.sh` | 将算法目录同步到 `Hardisk/` 和 `Nas_BackUp_Seg/` 的本地/NAS 备份脚本。 |
|
||||
|
||||
大体量目录如 `DataSet_Public/`、`DataSet_Own/` 中的数据、`BestMode_Predict_Results_DataSet_Public/`、`Hardisk/`、`Nas_BackUp_Seg/`、模型权重和训练产物不进入 Git。
|
||||
|
||||
## 2. Conda 环境
|
||||
|
||||
推荐使用独立环境 `seg_smp`。本机已有可运行的 `SMP` 环境时,最快方式是克隆它:
|
||||
|
||||
```bash
|
||||
conda create --name seg_smp --clone SMP -y
|
||||
conda activate seg_smp
|
||||
python -V
|
||||
python -c "import torch; print(torch.cuda.is_available(), torch.__version__)"
|
||||
```
|
||||
|
||||
从零安装时可参考:
|
||||
|
||||
```bash
|
||||
conda create -n seg_smp python=3.9 -y
|
||||
conda activate seg_smp
|
||||
|
||||
pip install torch==2.8.0+cu129 torchvision==0.23.0+cu129 --index-url https://download.pytorch.org/whl/cu129
|
||||
pip install -r requirements-seg_smp.txt
|
||||
|
||||
cd Seg_All_In_One_MMSeg
|
||||
pip install -v -e .
|
||||
cd ..
|
||||
```
|
||||
|
||||
如果使用批量脚本,默认会激活 `seg_smp`。如需临时使用旧环境:
|
||||
|
||||
```bash
|
||||
SEG_CONDA_ENV=SMP bash yolo_train.sh
|
||||
```
|
||||
|
||||
环境验证:
|
||||
|
||||
```bash
|
||||
conda run -n seg_smp python -c "import torch, segmentation_models_pytorch, ultralytics, mmcv, mmengine, mmseg, cv2, albumentations; print('ok')"
|
||||
```
|
||||
|
||||
## 3. 数据约定
|
||||
|
||||
SegModel 默认读取:
|
||||
|
||||
```text
|
||||
DataSet_Public/<dataset>/
|
||||
images/train
|
||||
images/val
|
||||
labels_GT/train
|
||||
labels_GT/val
|
||||
```
|
||||
|
||||
YOLO 默认读取:
|
||||
|
||||
```text
|
||||
DataSet_Public/<dataset>/
|
||||
images/train
|
||||
images/val
|
||||
labels/train
|
||||
labels/val
|
||||
```
|
||||
|
||||
切换数据集时:
|
||||
|
||||
- SegModel:修改 `Seg_All_In_One_SegModel/config.py` 中的 `DATA_DIR`、`OUTPUTS_DIR`、`PREDICT_BEST_MODEL_DIR`、类别和图像尺寸。
|
||||
- YOLO:修改 `Seg_All_In_One_YoloModel/dataset.yaml` 中的 `path`、`train`、`val`、`test`、`names`;必要时修改 `yolo_config.py` 的训练参数。
|
||||
- MMSeg:按 `Seg_All_In_One_MMSeg/※使用手册/※2025_9_23_MMSeg使用手册` 生成数据集和算法配置。
|
||||
|
||||
## 4. SegModel 使用方式
|
||||
|
||||
```bash
|
||||
cd Seg_All_In_One_SegModel
|
||||
conda activate seg_smp
|
||||
|
||||
# 单模型训练
|
||||
CUDA_VISIBLE_DEVICES=0 python train.py -a Unet
|
||||
|
||||
# 批量训练
|
||||
bash train.sh
|
||||
|
||||
# 单模型推理
|
||||
CUDA_VISIBLE_DEVICES=0 python 1_predict.py -a Unet
|
||||
|
||||
# 批量推理
|
||||
bash predict.sh
|
||||
|
||||
# 参数量、FLOPs、FPS
|
||||
CUDA_VISIBLE_DEVICES=0 python 2_predict_params_and_FLOPs_V2.py
|
||||
|
||||
# 检查预测 raw mask 是否齐全
|
||||
python 1_predict_raw_masks_check.py
|
||||
```
|
||||
|
||||
可选模型包括:
|
||||
|
||||
```text
|
||||
Unet, UnetPlusPlus, FPN, PSPNet, DeepLabV3, DeepLabV3Plus,
|
||||
Linknet, MAnet, PAN, UPerNet, Segformer, DPT
|
||||
```
|
||||
|
||||
训练结果先写入 `DataSet_Public_outputs/<dataset>_outputs-SegModel/`,脚本结束后会移动到 `Hardisk/`;推理结果写入 `BestMode_Predict_Results_DataSet_Public/<dataset>_outputs-SegModel/`。
|
||||
|
||||
## 5. YOLO 使用方式
|
||||
|
||||
```bash
|
||||
cd Seg_All_In_One_YoloModel
|
||||
conda activate seg_smp
|
||||
|
||||
# 检查当前 dataset.yaml 解析出的路径
|
||||
python yolo_config.py
|
||||
|
||||
# 单模型训练
|
||||
CUDA_VISIBLE_DEVICES=0 python yolo_train.py --model "YOLOv8n-seg"
|
||||
|
||||
# 批量训练
|
||||
bash yolo_train.sh
|
||||
|
||||
# 复制最佳权重到预测目录
|
||||
bash ./Tool_Yolo_Copy_Best_Model.sh --pt_name "best.pt"
|
||||
|
||||
# 单模型推理
|
||||
CUDA_VISIBLE_DEVICES=0 python yolo_predict_V2.py --model "YOLOv8n-seg" --conf 0.2 --pt_name "best.pt"
|
||||
|
||||
# 批量推理
|
||||
bash yolo_predict.sh --conf 0.2 --pt_name "best.pt"
|
||||
|
||||
# 批量热图
|
||||
bash yolo_predict.sh --heatmap_method "All" --pt_name "best.pt"
|
||||
|
||||
# 横向对比
|
||||
python yolo_predict_V2_compare_all.py --pt_name "all"
|
||||
|
||||
# 检查预测 raw mask 是否齐全
|
||||
python yolo_predict_raw_masks_check.py --pt_name "best.pt"
|
||||
```
|
||||
|
||||
常用模型包括:
|
||||
|
||||
```text
|
||||
YOLOv8n-seg, YOLOv8s-seg, YOLOv8m-seg, YOLOv8l-seg, YOLOv8x-seg,
|
||||
YOLOv9c-seg, YOLOv9e-seg,
|
||||
YOLO11n-seg, YOLO11s-seg, YOLO11m-seg, YOLO11l-seg, YOLO11x-seg,
|
||||
YOLO12-seg
|
||||
```
|
||||
|
||||
## 6. MMSeg 使用方式
|
||||
|
||||
```bash
|
||||
cd Seg_All_In_One_MMSeg
|
||||
conda activate seg_smp
|
||||
|
||||
# 首次或新增模块后注册工程
|
||||
pip install -v -e .
|
||||
|
||||
# 下载/保存必要预训练权重
|
||||
python My_All_In_One/0_Initial_Save_All_Model_locally.py
|
||||
|
||||
# 生成数据配置
|
||||
python My_All_In_One/1_Initial_Data_All_data_from_1_Data_Parameter-V2.py
|
||||
|
||||
# 生成算法配置
|
||||
python My_All_In_One/2_Initial_Alg_All_data_from_2_Alg_Program-V2.py
|
||||
|
||||
# 参数量、FLOPs、FPS
|
||||
CUDA_VISIBLE_DEVICES=0 python My_All_In_One/4_1_predict_params_FLOPs_FPS_V2.py
|
||||
|
||||
# 指标汇总
|
||||
CUDA_VISIBLE_DEVICES=0 python My_All_In_One/4_2_predict_matrics_from_log_V2.py
|
||||
|
||||
# 生成预测图和表格
|
||||
CUDA_VISIBLE_DEVICES=0 python My_All_In_One/4_3_predict_draw_pictures_and_tabels.py
|
||||
|
||||
# 提取 loss 和 best mIoU
|
||||
CUDA_VISIBLE_DEVICES=0 python My_All_In_One/4_4_extract_loss_and_best_miou.py
|
||||
```
|
||||
|
||||
MMSeg 对 `mmcv/mmengine/mmsegmentation` 版本较敏感;若遇到 `mmcv` CUDA 算子或版本错误,优先参考 MMSeg 使用手册中的安装记录。
|
||||
|
||||
## 7. 数据预处理与视频抽帧
|
||||
|
||||
自有图片预处理:
|
||||
|
||||
```bash
|
||||
cd "DataSet_Own/1. 图片预处理(内含使用手册)"
|
||||
|
||||
bash 1_rename_pics.sh -i <ori_image_directory> -l <ori_label_directory>
|
||||
bash 2_reformate_pics.sh -i <ori_image_directory> -l <ori_label_directory> -w 1920 -h 1080
|
||||
bash 3_pair_ori_label.sh -i <ori_image_directory> -l <ori_label_directory>
|
||||
bash 4_rebuild_labels.sh -l <ori_label_directory>
|
||||
bash 5_TOOL_stack_pics.sh -i <ori_image_directory> -l <ori_label_directory> -r <stack_result_directory> -a 0.3
|
||||
bash 6_TOOL_stitch_pics.sh -i <ori_image_directory> -l <ori_label_directory> -r <stitch_result_directory>
|
||||
```
|
||||
|
||||
视频抽帧:
|
||||
|
||||
```bash
|
||||
cd Seg_Predict_Own_Video_V2
|
||||
python 1_Save_Frame_V2.py \
|
||||
--video ./LC_Video_1.mp4 \
|
||||
--resize "1920x1080" \
|
||||
--output_dir "../DataSet_Public/5_Predict_Video" \
|
||||
--interval 0.5
|
||||
```
|
||||
|
||||
输出路径为:
|
||||
|
||||
```text
|
||||
DataSet_Public/5_Predict_Video/<video_name>/images/val
|
||||
```
|
||||
|
||||
图片叠加检查:
|
||||
|
||||
```bash
|
||||
cd Tool-图片堆叠
|
||||
python 1_check_picture_pair.py -i ./ori -l ./label
|
||||
bash 2_TOOL_stack_pics.sh -i ./ori -l ./label -r ./result_0.3透明度 -a 0.3 -s _label
|
||||
```
|
||||
|
||||
## 8. 结果分析
|
||||
|
||||
```bash
|
||||
cd Seg_All_In_One_Analysis
|
||||
conda activate seg_smp
|
||||
|
||||
python 1_Analysis_All.py \
|
||||
--input_dir ../BestMode_Predict_Results_DataSet_Public \
|
||||
--output_dir ./
|
||||
```
|
||||
|
||||
脚本会交互选择数据集,合并 SegModel/MMSeg 的指标和速度数据,并生成 CSV、PNG、SVG 等分析结果。
|
||||
|
||||
## 9. 备份与 Git
|
||||
|
||||
本仓库只提交程序和轻量配置。以下内容不进入 Git:
|
||||
|
||||
- 数据集:`DataSet_Public/`、除预处理脚本外的 `DataSet_Own/`
|
||||
- 训练和预测产物:`DataSet_Public_outputs/`、`BestMode_Predict_Results_DataSet_Public/`、`Hardisk/`、`Nas_BackUp_Seg/`
|
||||
- 大权重和视频:`*.pt`、`*.pth`、`*.onnx`、`*.mp4`
|
||||
- Python/构建缓存:`__pycache__/`、`.pytest_cache/`、`build/`、`dist/`
|
||||
|
||||
常用 Git 检查:
|
||||
|
||||
```bash
|
||||
git status --short
|
||||
git ls-files | grep -E '\\.(pt|pth|mp4|zip)$|DataSet_Public|Hardisk|BestMode' || true
|
||||
```
|
||||
|
||||
298
Seg_All_In_One_Analysis/1_Analysis_All.py
Normal file
298
Seg_All_In_One_Analysis/1_Analysis_All.py
Normal file
@@ -0,0 +1,298 @@
|
||||
import os
|
||||
import glob
|
||||
import pandas as pd
|
||||
import matplotlib.pyplot as plt
|
||||
import seaborn as sns
|
||||
import re
|
||||
import argparse
|
||||
from collections import defaultdict
|
||||
|
||||
def get_model_family(model_name):
|
||||
"""
|
||||
根据模型名称提取模型族。
|
||||
例如: 'my_bisenetv1_r50' -> 'my_bisenetv1'
|
||||
'my_fast_scnn' -> 'my_fast_scnn'
|
||||
"""
|
||||
# 使用正则表达式匹配,将 _rXX 或 _dXX 等后缀去掉
|
||||
match = re.match(r'^(.*?)_r\d+$', model_name)
|
||||
if match:
|
||||
return match.group(1)
|
||||
return model_name
|
||||
|
||||
def select_dataset(results_dir):
|
||||
"""
|
||||
扫描目录,对数据集进行分组,让用户交互式选择一个数据集进行合并分析。
|
||||
"""
|
||||
print("正在扫描可用的数据集...")
|
||||
try:
|
||||
# 查找所有匹配后缀的目录
|
||||
all_dirs = glob.glob(os.path.join(results_dir, '*_outputs-MMSeg')) + \
|
||||
glob.glob(os.path.join(results_dir, '*_outputs-SegModel'))
|
||||
|
||||
if not all_dirs:
|
||||
print(f"在 '{results_dir}' 中未找到任何数据集目录 (以 '_outputs-MMSeg' 或 '_outputs-SegModel' 结尾)。")
|
||||
return None, None
|
||||
|
||||
# --- 新增逻辑:按基本数据集名称对目录进行分组 ---
|
||||
datasets_map = defaultdict(list)
|
||||
for dir_path in all_dirs:
|
||||
if os.path.isdir(dir_path):
|
||||
# 提取基本名称,例如 '1_CholecSeg8k-13Type-1920x1080'
|
||||
base_name = re.sub(r'_outputs-(MMSeg|SegModel)$', '', os.path.basename(dir_path))
|
||||
datasets_map[base_name].append(dir_path)
|
||||
|
||||
sorted_dataset_names = sorted(datasets_map.keys())
|
||||
|
||||
except Exception as e:
|
||||
print(f"扫描目录 '{results_dir}' 时出错: {e}")
|
||||
return None, None
|
||||
|
||||
print("\n请选择要合并分析的数据集:")
|
||||
for i, name in enumerate(sorted_dataset_names):
|
||||
# 显示每个数据集包含的源文件夹数量
|
||||
source_count = len(datasets_map[name])
|
||||
print(f" [{i+1}] {name} ({source_count}个源)")
|
||||
|
||||
while True:
|
||||
try:
|
||||
choice = input(f"\n请输入选项编号 (1-{len(sorted_dataset_names)}): ")
|
||||
choice_idx = int(choice) - 1
|
||||
if 0 <= choice_idx < len(sorted_dataset_names):
|
||||
selected_name = sorted_dataset_names[choice_idx]
|
||||
selected_dirs = datasets_map[selected_name] # 获取与所选数据集关联的所有目录
|
||||
return selected_dirs, selected_name
|
||||
else:
|
||||
print("无效的选项,请输入列表中的编号。")
|
||||
except (ValueError, IndexError):
|
||||
print("无效的输入,请输入一个数字编号。")
|
||||
except (KeyboardInterrupt, EOFError):
|
||||
print("\n操作已取消。")
|
||||
return None, None
|
||||
|
||||
def F1_plot_performance_speed(selected_dirs, dataset_name, output_base_dir):
|
||||
"""
|
||||
根据选定的数据集目录列表,加载并合并数据、生成图表和表格,并保存到指定的输出目录。
|
||||
|
||||
Args:
|
||||
selected_dirs (list): 用户选择的原始数据所在的所有目录的列表。
|
||||
dataset_name (str): 从目录名中提取的数据集名称。
|
||||
output_base_dir (str): 保存所有输出文件的根目录。
|
||||
"""
|
||||
print(f"\n正在为数据集 '{dataset_name}' 合并数据并生成图表...")
|
||||
|
||||
# 在指定的输出根目录下,为当前数据集创建一个专属的输出文件夹
|
||||
dataset_output_dir = os.path.join(output_base_dir, dataset_name)
|
||||
os.makedirs(dataset_output_dir, exist_ok=True)
|
||||
print(f"所有输出文件将被保存到: {dataset_output_dir}")
|
||||
|
||||
# --- 修改逻辑:从多个目录加载并合并数据 ---
|
||||
all_metrics = []
|
||||
all_fps = []
|
||||
|
||||
print("正在读取以下来源的数据:")
|
||||
for selected_dir in selected_dirs:
|
||||
print(f" - {os.path.basename(selected_dir)}")
|
||||
metrics_file = os.path.join(selected_dir, f"{dataset_name}_metrics_summary_wide.csv")
|
||||
fps_file = os.path.join(selected_dir, f"{dataset_name}_flops_params_fps_summary.csv")
|
||||
|
||||
# 检查文件是否存在
|
||||
if not os.path.exists(metrics_file) or not os.path.exists(fps_file):
|
||||
print(f" -> 警告: 在目录 '{os.path.basename(selected_dir)}' 中缺少数据文件,已跳过。")
|
||||
continue
|
||||
|
||||
try:
|
||||
metrics_df_part = pd.read_csv(metrics_file)
|
||||
all_metrics.append(metrics_df_part)
|
||||
|
||||
fps_df_part = pd.read_csv(fps_file)
|
||||
all_fps.append(fps_df_part)
|
||||
except Exception as e:
|
||||
print(f" -> 错误: 读取CSV文件时出错: {e}")
|
||||
continue
|
||||
|
||||
if not all_metrics or not all_fps:
|
||||
print("\n错误: 未能从任何有效的源目录中加载数据,无法继续生成报告。")
|
||||
return
|
||||
|
||||
# 合并来自所有源的数据
|
||||
metrics_df = pd.concat(all_metrics, ignore_index=True)
|
||||
fps_df = pd.concat(all_fps, ignore_index=True)
|
||||
print("\n数据合并完成。")
|
||||
|
||||
# 对合并后的数据进行去重处理
|
||||
if 'Epoch' in metrics_df.columns:
|
||||
metrics_df = metrics_df.sort_values('Epoch', ascending=False).drop_duplicates('Algorithm')
|
||||
else:
|
||||
metrics_df = metrics_df.drop_duplicates('Algorithm')
|
||||
|
||||
fps_df = fps_df.drop_duplicates('Model')
|
||||
|
||||
# 合并两个DataFrame
|
||||
merged_df = pd.merge(metrics_df, fps_df, left_on='Algorithm', right_on='Model', how='inner')
|
||||
|
||||
if merged_df.empty:
|
||||
print("错误: 数据合并失败。请检查 'Algorithm' 和 'Model' 列中的模型名称是否完全匹配。")
|
||||
print(f" - 指标文件中的模型: {metrics_df['Algorithm'].unique()}")
|
||||
print(f" - 性能文件中的模型: {fps_df['Model'].unique()}")
|
||||
return
|
||||
|
||||
# 调用函数创建并保存摘要表格到新的输出目录
|
||||
T1_create_and_save_summary_table(merged_df, dataset_output_dir, dataset_name)
|
||||
|
||||
# 调用函数来提取和保存所有IoU数据到新的输出目录
|
||||
T2_extract_and_save_iou_data(metrics_df, dataset_output_dir, dataset_name)
|
||||
|
||||
# 提取模型族
|
||||
merged_df['Family'] = merged_df['Model'].apply(get_model_family)
|
||||
|
||||
# --- 绘图 ---
|
||||
plt.style.use('seaborn-v0_8-whitegrid')
|
||||
fig, ax = plt.subplots(figsize=(16, 10))
|
||||
|
||||
# 定义颜色和标记
|
||||
families = sorted(merged_df['Family'].unique())
|
||||
palette = sns.color_palette("husl", len(families))
|
||||
markers = ['o', 's', 'X', 'D', '^', 'P', '*', 'v', '<', '>']
|
||||
|
||||
# 循环绘制每个模型族
|
||||
for i, family in enumerate(families):
|
||||
family_df = merged_df[merged_df['Family'] == family].sort_values('Average_FPS')
|
||||
color = palette[i]
|
||||
marker = markers[i % len(markers)]
|
||||
|
||||
# 绘制散点
|
||||
ax.scatter(family_df['Average_FPS'], family_df['mIoU'],
|
||||
color=color, marker=marker, s=150, label=family, zorder=3)
|
||||
|
||||
# 如果族内有多个模型,则用线连接
|
||||
if len(family_df) > 1:
|
||||
ax.plot(family_df['Average_FPS'], family_df['mIoU'],
|
||||
color=color, linestyle='--', linewidth=1.5, zorder=2)
|
||||
|
||||
# 在每个点旁边添加模型全名注释
|
||||
for j, row in family_df.iterrows():
|
||||
ax.text(row['Average_FPS'] * 1.01, row['mIoU'], row['Model'],
|
||||
fontsize=9, verticalalignment='center')
|
||||
|
||||
# 设置图表属性
|
||||
ax.set_title(f'Model Performance vs. Inference Speed ({dataset_name})', fontsize=18, pad=20)
|
||||
ax.set_xlabel('Inference Speed (FPS)', fontsize=14)
|
||||
ax.set_ylabel('Mean IoU (%)', fontsize=14)
|
||||
ax.legend(title='Model Family', bbox_to_anchor=(1.02, 1), loc='upper left', borderaxespad=0.)
|
||||
|
||||
plt.tight_layout(rect=[0, 0, 0.88, 1]) # 调整布局为图例留出空间
|
||||
plt.grid(True, which='both', linestyle='--', linewidth=0.5)
|
||||
|
||||
# 保存图表到新的输出目录
|
||||
output_filename_png = f"F1_{dataset_name}_mIoU_vs_FPS.png"
|
||||
save_file_path_png = os.path.join(dataset_output_dir, output_filename_png)
|
||||
plt.savefig(save_file_path_png, dpi=600)
|
||||
|
||||
output_filename_svg = f"F1_{dataset_name}_mIoU_vs_FPS.svg"
|
||||
save_file_path_svg = os.path.join(dataset_output_dir, output_filename_svg)
|
||||
plt.savefig(save_file_path_svg)
|
||||
|
||||
print(f"\n图表已成功生成并保存为: {save_file_path_svg} 和 {save_file_path_png}")
|
||||
plt.close(fig) # 关闭图形,避免在循环中使用时重复显示
|
||||
|
||||
def T1_create_and_save_summary_table(merged_df, output_dir, dataset_name):
|
||||
"""
|
||||
根据合并后的数据创建、格式化并保存性能摘要表格。
|
||||
"""
|
||||
print("正在创建摘要表格...")
|
||||
|
||||
# 检查所需列是否存在
|
||||
required_columns = ['Model', 'mIoU', 'mAcc', 'aAcc', 'Average_FPS', 'FLOPs', 'Params']
|
||||
if not all(col in merged_df.columns for col in required_columns):
|
||||
print("错误: DataFrame中缺少必要的列。请检查CSV文件内容。")
|
||||
print(f" - 需要的列: {required_columns}")
|
||||
print(f" - 实际的列: {merged_df.columns.tolist()}")
|
||||
return
|
||||
|
||||
# 提取并复制数据,避免修改原始DataFrame
|
||||
summary_df = merged_df[required_columns].copy()
|
||||
|
||||
# 清理和转换数据
|
||||
summary_df['FLOPs'] = summary_df['FLOPs'].astype(str).str.replace(r'\s*G', '', regex=True).astype(float)
|
||||
summary_df['Params'] = summary_df['Params'].astype(str).str.replace(r'\s*M', '', regex=True).astype(float)
|
||||
|
||||
# 按照用户的要求重命名列
|
||||
summary_df.rename(columns={
|
||||
'Average_FPS': 'FPS',
|
||||
'FLOPs': 'FLOPs(G)',
|
||||
'Params': 'Params(M)'
|
||||
}, inplace=True)
|
||||
|
||||
# 按 mIoU 降序排序
|
||||
summary_df = summary_df.sort_values(by='mIoU', ascending=False)
|
||||
|
||||
# 保存表格到CSV文件
|
||||
summary_filename = f"T1_{dataset_name}_performance_summary.csv"
|
||||
summary_save_path = os.path.join(output_dir, summary_filename)
|
||||
|
||||
try:
|
||||
summary_df.to_csv(summary_save_path, index=False, float_format='%.3f')
|
||||
print(f"摘要表格已成功保存到: {summary_save_path}")
|
||||
except Exception as e:
|
||||
print(f"保存摘要表格时出错: {e}")
|
||||
|
||||
def T2_extract_and_save_iou_data(metrics_df, output_dir, dataset_name):
|
||||
"""
|
||||
从 metrics DataFrame 中提取所有 mIoU 和 Class_IoU,并保存到新的CSV文件。
|
||||
"""
|
||||
print("正在提取所有 mIoU 和 Class_IoU 数据...")
|
||||
|
||||
# 检查'Algorithm'列是否存在
|
||||
if 'Algorithm' not in metrics_df.columns:
|
||||
print("错误: 'Algorithm' 列未找到,无法继续。")
|
||||
return
|
||||
|
||||
# 找出所有与IoU相关的列
|
||||
iou_columns = ['Algorithm', 'mIoU'] + [col for col in metrics_df.columns if col.endswith('_IoU') and col != 'mIoU']
|
||||
|
||||
# 移除重复的列名(以防万一)
|
||||
iou_columns = list(dict.fromkeys(iou_columns))
|
||||
|
||||
# 提取数据
|
||||
iou_df = metrics_df[iou_columns].copy()
|
||||
|
||||
# 按 mIoU 降序排序,便于查看
|
||||
if 'mIoU' in iou_df.columns:
|
||||
iou_df = iou_df.sort_values(by='mIoU', ascending=False)
|
||||
|
||||
# 定义并保存文件
|
||||
iou_filename = f"T2_{dataset_name}_all_iou_summary.csv"
|
||||
iou_save_path = os.path.join(output_dir, iou_filename)
|
||||
|
||||
try:
|
||||
iou_df.to_csv(iou_save_path, index=False, float_format='%.2f')
|
||||
print(f"所有IoU数据已成功保存到: {iou_save_path}")
|
||||
except Exception as e:
|
||||
print(f"保存IoU数据时出错: {e}")
|
||||
|
||||
if __name__ == '__main__':
|
||||
# --- 设置命令行参数解析 ---
|
||||
parser = argparse.ArgumentParser(description="从模型评估结果生成性能与速度对比图和摘要表。")
|
||||
parser.add_argument(
|
||||
'--input_dir',
|
||||
type=str,
|
||||
default='../BestMode_Predict_Results_DataSet_Public',
|
||||
help="包含所有数据集结果的根目录 (例如 '..._outputs-MMSeg' 或 '..._outputs-SegModel' 的父目录)。"
|
||||
)
|
||||
parser.add_argument(
|
||||
'--output_dir',
|
||||
type=str,
|
||||
default='./',
|
||||
help="用于存储所有生成的图表和表格的根目录。"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
# 确保输出目录存在
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
|
||||
# 启动交互式选择
|
||||
selected_directories, selected_dataset_name = select_dataset(args.input_dir)
|
||||
|
||||
# 如果用户成功选择,则生成图表和表格
|
||||
if selected_directories and selected_dataset_name:
|
||||
F1_plot_performance_speed(selected_directories, selected_dataset_name, args.output_dir)
|
||||
File diff suppressed because it is too large
Load Diff
|
After Width: | Height: | Size: 86 KiB |
@@ -0,0 +1,21 @@
|
||||
Model,mIoU,mAcc,aAcc,FPS,FLOPs(G),Params(M)
|
||||
UnetPlusPlus,96.860,95.430,99.750,11.940,590.910,26.080
|
||||
UPerNet,96.670,95.380,99.740,17.250,574.480,29.600
|
||||
MAnet,96.630,94.960,99.740,23.480,271.820,31.790
|
||||
Unet,96.590,95.520,99.730,26.290,253.380,24.440
|
||||
DeepLabV3Plus,96.500,94.990,99.730,33.210,252.410,22.440
|
||||
Linknet,96.460,94.550,99.720,32.820,161.800,21.770
|
||||
Segformer,96.450,94.880,99.720,21.020,209.450,21.880
|
||||
DeepLabV3,96.420,94.730,99.720,13.860,871.240,26.010
|
||||
FPN,96.410,94.740,99.720,34.920,219.570,23.160
|
||||
PAN,96.370,94.480,99.720,37.630,238.120,21.480
|
||||
DPT,96.310,94.900,99.710,1.900,1696.580,137.810
|
||||
PSPNet,96.010,94.610,99.690,79.510,76.810,21.490
|
||||
my_fastfcn_r50,89.740,94.210,97.830,10.620,1032.000,66.346
|
||||
my_icnet_r50,88.840,93.150,97.780,58.690,122.000,47.527
|
||||
my_icnet_r18,85.760,92.400,96.600,101.260,73.869,24.873
|
||||
my_bisenetv1_r50,82.640,89.980,95.690,13.630,784.000,56.867
|
||||
my_bisenetv1_r18,82.610,89.220,94.890,66.760,118.000,13.274
|
||||
my_bisenetv2,74.610,82.580,92.090,68.050,97.578,3.353
|
||||
my_fast_scnn,69.290,76.970,93.650,179.900,7.426,1.400
|
||||
my_en_bisenetv2,30.950,44.500,67.960,66.090,62.729,2.776
|
||||
|
@@ -0,0 +1,21 @@
|
||||
Algorithm,mIoU,10_IoU,11_IoU,12_IoU,1_IoU,2_IoU,3_IoU,4_IoU,5_IoU,6_IoU,7_IoU,8_IoU,9_IoU,背景_IoU
|
||||
UnetPlusPlus,96.86,84.08,71.33,99.41,96.18,96.27,94.45,91.76,90.90,92.24,93.28,79.37,98.01,97.99
|
||||
UPerNet,96.67,83.18,71.32,99.32,95.95,96.10,94.12,90.87,90.75,92.12,92.90,78.68,97.92,97.85
|
||||
MAnet,96.63,83.84,69.46,99.34,95.94,96.05,94.00,91.09,90.14,91.92,92.85,78.50,98.05,97.82
|
||||
Unet,96.59,83.11,71.34,99.34,95.88,96.01,93.90,90.62,90.34,91.67,92.45,79.29,97.73,97.81
|
||||
DeepLabV3Plus,96.50,83.07,71.81,99.28,95.86,95.85,93.77,89.92,90.31,91.77,92.17,78.45,97.69,97.73
|
||||
Linknet,96.46,81.22,71.54,99.33,95.69,95.83,93.64,89.90,90.38,91.54,92.29,78.48,97.57,97.71
|
||||
Segformer,96.45,81.30,67.75,99.27,95.74,95.85,93.56,89.94,89.98,91.87,92.42,78.16,97.80,97.73
|
||||
DeepLabV3,96.42,82.35,67.37,99.26,95.76,95.78,93.51,88.96,90.20,91.91,92.48,76.75,97.86,97.75
|
||||
FPN,96.41,82.43,69.98,99.24,95.71,95.74,93.61,89.88,90.22,91.74,92.23,77.50,97.76,97.69
|
||||
PAN,96.37,81.33,71.47,99.24,95.65,95.77,93.51,89.62,90.08,91.60,92.43,77.88,97.81,97.61
|
||||
DPT,96.31,83.17,70.59,99.27,95.49,95.64,93.48,89.71,90.34,91.36,92.53,78.20,97.63,97.53
|
||||
PSPNet,96.01,81.24,68.49,99.13,95.33,95.30,92.75,87.61,89.36,91.01,91.76,76.03,97.51,97.45
|
||||
my_fastfcn_r50,89.74,83.23,64.19,97.57,95.43,95.86,94.40,91.26,89.54,90.49,93.22,76.89,97.98,96.56
|
||||
my_icnet_r50,88.84,80.89,61.34,97.85,94.99,95.16,93.92,89.80,89.45,89.54,91.77,75.20,97.76,97.26
|
||||
my_icnet_r18,85.76,79.53,60.25,94.73,93.08,94.07,92.81,84.45,87.61,83.91,91.24,73.36,84.31,95.45
|
||||
my_bisenetv1_r50,82.64,80.56,67.31,97.75,91.70,91.28,85.29,83.91,84.91,73.37,80.36,71.34,74.45,92.07
|
||||
my_bisenetv1_r18,82.61,74.28,59.56,94.47,88.80,91.20,88.70,87.56,84.82,78.50,86.37,68.96,80.99,89.71
|
||||
my_bisenetv2,74.61,71.97,0.00,88.12,89.12,85.45,84.42,77.65,77.16,75.13,87.12,65.45,86.14,82.19
|
||||
my_fast_scnn,69.29,0.00,0.00,92.37,86.24,88.04,82.33,78.77,76.54,80.84,84.43,63.59,76.35,91.32
|
||||
my_en_bisenetv2,30.95,0.00,0.00,79.41,56.32,44.01,28.35,27.32,47.21,20.52,26.05,27.07,0.00,46.09
|
||||
|
File diff suppressed because it is too large
Load Diff
|
After Width: | Height: | Size: 84 KiB |
@@ -0,0 +1,21 @@
|
||||
Model,mIoU,mAcc,aAcc,FPS,FLOPs(G),Params(M)
|
||||
DeepLabV3,80.740,83.990,99.520,14.030,871.240,26.010
|
||||
PSPNet,79.980,83.730,99.500,79.650,76.810,21.490
|
||||
UPerNet,79.960,85.130,99.500,17.440,574.480,29.600
|
||||
PAN,79.730,83.990,99.500,38.020,238.120,21.480
|
||||
DeepLabV3Plus,79.610,85.070,99.480,33.420,252.410,22.440
|
||||
Segformer,79.250,83.200,99.480,21.050,209.450,21.880
|
||||
FPN,78.990,83.980,99.470,35.060,219.570,23.160
|
||||
MAnet,77.380,82.040,99.420,23.610,271.820,31.790
|
||||
UnetPlusPlus,77.250,81.010,99.440,12.080,590.910,26.080
|
||||
Unet,76.160,83.160,99.380,26.410,253.380,24.440
|
||||
Linknet,75.510,81.050,99.380,33.040,161.800,21.770
|
||||
my_fastfcn_r50,71.040,79.630,92.280,10.610,1032.000,66.346
|
||||
my_icnet_r50,70.900,78.660,94.020,59.150,122.000,47.526
|
||||
my_icnet_r18,64.370,76.040,91.130,102.830,73.857,24.873
|
||||
DPT,58.120,62.610,98.840,1.910,1696.580,137.810
|
||||
my_bisenetv1_r50,49.540,70.640,85.890,13.670,784.000,56.864
|
||||
my_bisenetv1_r18,43.630,51.500,86.400,67.190,118.000,13.273
|
||||
my_fast_scnn,35.470,53.070,78.230,178.010,7.426,1.400
|
||||
my_bisenetv2,30.770,46.870,67.040,68.880,97.479,3.350
|
||||
my_en_bisenetv2,21.060,28.780,81.280,66.830,62.629,2.773
|
||||
|
@@ -0,0 +1,21 @@
|
||||
Algorithm,mIoU,1_IoU,2_IoU,3_IoU,4_IoU,5_IoU,6_IoU,7_IoU,8_IoU,9_IoU,背景_IoU
|
||||
DeepLabV3,80.74,68.94,77.29,80.52,86.86,67.33,78.49,40.91,85.00,80.77,95.97
|
||||
PSPNet,79.98,66.02,78.51,79.55,86.79,65.31,81.79,43.71,88.19,78.72,95.59
|
||||
UPerNet,79.96,68.17,79.22,79.58,88.29,66.22,76.92,48.90,87.01,78.32,95.69
|
||||
PAN,79.73,68.17,79.87,80.10,87.75,67.79,80.61,45.17,85.84,77.10,95.55
|
||||
DeepLabV3Plus,79.61,67.65,80.67,79.04,86.41,67.82,78.38,45.17,84.93,78.48,95.51
|
||||
Segformer,79.25,70.26,80.48,79.32,86.77,64.76,77.48,40.30,86.94,76.90,95.67
|
||||
FPN,78.99,64.32,76.67,77.73,85.13,66.86,80.62,41.37,86.36,78.77,95.61
|
||||
MAnet,77.38,68.36,75.96,76.54,85.39,64.29,75.99,42.33,80.07,76.40,95.13
|
||||
UnetPlusPlus,77.25,68.41,80.79,78.11,88.39,61.29,75.66,43.16,78.42,73.51,95.01
|
||||
Unet,76.16,65.81,75.72,77.40,86.54,64.59,78.09,41.00,86.14,71.37,94.50
|
||||
Linknet,75.51,67.53,72.66,77.15,85.43,62.20,66.99,42.97,80.32,72.87,94.71
|
||||
my_fastfcn_r50,71.04,72.94,75.20,83.46,86.34,75.83,79.99,15.77,77.42,52.27,91.19
|
||||
my_icnet_r50,70.90,70.59,79.88,82.49,88.10,75.12,84.08,0.00,76.83,58.64,93.31
|
||||
my_icnet_r18,64.37,68.27,66.15,76.69,80.02,69.07,77.24,0.00,67.25,48.86,90.15
|
||||
DPT,58.12,35.86,57.90,52.58,69.75,25.48,51.48,11.85,64.40,61.94,90.98
|
||||
my_bisenetv1_r50,49.54,41.45,52.43,51.70,70.38,31.42,58.09,4.18,58.32,42.20,85.24
|
||||
my_bisenetv1_r18,43.63,40.48,33.63,57.69,68.59,18.73,41.95,5.47,46.50,37.73,85.56
|
||||
my_fast_scnn,35.47,17.13,31.73,32.59,67.72,13.07,44.63,0.00,39.41,30.97,77.47
|
||||
my_bisenetv2,30.77,13.45,36.67,16.43,60.36,13.90,36.03,0.00,37.48,28.07,65.35
|
||||
my_en_bisenetv2,21.06,1.90,15.66,34.99,36.61,5.25,18.82,0.00,14.46,0.37,82.57
|
||||
|
File diff suppressed because it is too large
Load Diff
|
After Width: | Height: | Size: 81 KiB |
@@ -0,0 +1,21 @@
|
||||
Model,mIoU,mAcc,aAcc,FPS,FLOPs(G),Params(M)
|
||||
FPN,77.150,91.110,99.500,204.850,27.550,23.160
|
||||
DeepLabV3,77.110,92.370,99.490,94.790,109.330,26.010
|
||||
PAN,76.600,90.580,99.480,234.070,29.880,21.480
|
||||
UPerNet,75.930,90.910,99.450,107.590,72.100,29.600
|
||||
UnetPlusPlus,75.800,90.720,99.460,89.430,74.150,26.080
|
||||
Segformer,75.080,88.260,99.450,151.560,26.280,21.880
|
||||
PSPNet,74.850,86.440,99.450,573.660,9.640,21.490
|
||||
Unet,73.860,89.190,99.410,173.520,31.800,24.440
|
||||
DeepLabV3Plus,73.830,86.410,99.420,208.780,31.680,22.440
|
||||
Linknet,73.790,87.770,99.410,197.430,20.300,21.770
|
||||
MAnet,73.630,89.900,99.400,152.330,33.850,31.790
|
||||
my_fastfcn_r50,61.430,90.120,97.480,71.100,130.000,66.346
|
||||
DPT,61.420,82.190,99.070,30.180,212.980,137.810
|
||||
my_bisenetv1_r50,59.590,84.700,95.770,88.970,98.945,56.862
|
||||
my_icnet_r50,57.840,80.930,94.950,179.660,15.428,47.526
|
||||
my_icnet_r18,57.350,88.250,96.590,268.050,9.360,24.873
|
||||
my_bisenetv1_r18,56.730,84.610,96.770,310.400,14.827,13.273
|
||||
my_bisenetv2,45.950,73.530,94.010,223.740,12.311,3.348
|
||||
my_fast_scnn,41.590,67.120,92.870,314.130,0.936,1.400
|
||||
my_en_bisenetv2,26.470,47.770,88.930,167.350,7.907,2.771
|
||||
|
@@ -0,0 +1,21 @@
|
||||
Algorithm,mIoU,1_IoU,2_IoU,3_IoU,4_IoU,6_IoU,背景_IoU,5_IoU,7_IoU
|
||||
FPN,77.15,68.36,56.28,89.32,75.57,90.63,97.53,0.00,0.00
|
||||
DeepLabV3,77.11,70.40,56.63,88.94,75.66,91.14,97.59,0.00,0.00
|
||||
PAN,76.60,67.15,58.83,88.96,66.32,93.62,97.63,0.00,0.00
|
||||
UPerNet,75.93,70.51,55.94,87.54,66.97,93.12,97.42,0.00,0.00
|
||||
UnetPlusPlus,75.80,71.10,53.32,89.23,62.62,92.12,97.61,0.00,0.00
|
||||
Segformer,75.08,69.28,51.50,89.63,60.44,89.39,97.46,0.00,0.00
|
||||
PSPNet,74.85,69.13,59.87,89.43,41.93,88.80,97.14,0.00,0.00
|
||||
Unet,73.86,70.85,47.93,88.16,65.93,81.26,97.40,0.00,0.00
|
||||
DeepLabV3Plus,73.83,67.50,55.09,88.72,43.60,89.56,97.30,0.00,0.00
|
||||
Linknet,73.79,70.46,51.28,88.14,61.32,74.38,97.41,0.00,0.00
|
||||
MAnet,73.63,69.37,50.95,86.99,64.69,88.41,97.27,0.00,0.00
|
||||
my_fastfcn_r50,61.43,75.72,61.00,89.72,74.27,92.82,97.88,,
|
||||
DPT,61.42,53.96,44.36,74.11,42.46,72.76,95.73,0.00,0.00
|
||||
my_bisenetv1_r50,59.59,55.03,44.88,81.05,56.63,82.73,96.78,,
|
||||
my_icnet_r50,57.84,55.85,40.23,81.41,59.23,72.62,95.51,,
|
||||
my_icnet_r18,57.35,67.17,51.83,88.81,66.89,87.06,97.04,,
|
||||
my_bisenetv1_r18,56.73,65.55,56.03,89.65,54.60,90.81,97.18,,
|
||||
my_bisenetv2,45.95,35.33,43.64,75.24,32.28,86.09,95.00,,
|
||||
my_fast_scnn,41.59,33.67,19.74,60.73,38.78,84.75,95.02,,
|
||||
my_en_bisenetv2,26.47,21.72,3.36,51.58,1.01,42.74,91.37,,
|
||||
|
File diff suppressed because it is too large
Load Diff
|
After Width: | Height: | Size: 84 KiB |
@@ -0,0 +1,21 @@
|
||||
Model,mIoU,mAcc,aAcc,FPS,FLOPs(G),Params(M)
|
||||
MAnet,93.900,84.990,99.210,151.600,33.850,31.790
|
||||
Segformer,93.280,82.550,99.130,151.930,26.280,21.880
|
||||
UnetPlusPlus,92.970,82.760,99.090,90.260,74.150,26.080
|
||||
FPN,92.670,82.440,99.050,207.350,27.550,23.160
|
||||
DeepLabV3,92.550,82.140,99.030,92.870,109.330,26.010
|
||||
Unet,92.530,79.230,99.030,177.100,31.800,24.440
|
||||
PAN,92.480,81.380,99.020,232.430,29.880,21.480
|
||||
UPerNet,92.180,80.020,98.980,105.860,72.100,29.600
|
||||
Linknet,92.060,79.170,98.970,199.680,20.300,21.770
|
||||
PSPNet,91.940,76.940,98.950,578.550,9.640,21.490
|
||||
DeepLabV3Plus,91.500,78.630,98.890,213.500,31.680,22.440
|
||||
DPT,87.840,72.480,98.380,30.860,212.980,137.810
|
||||
my_fastfcn_r50,55.340,84.060,96.630,71.420,130.000,66.346
|
||||
my_icnet_r50,50.400,78.440,95.210,202.090,15.428,47.526
|
||||
my_bisenetv1_r50,49.620,78.030,96.150,88.850,98.945,56.862
|
||||
my_icnet_r18,47.540,76.700,94.100,275.600,9.360,24.873
|
||||
my_bisenetv1_r18,45.020,67.190,95.580,346.950,14.827,13.273
|
||||
my_bisenetv2,38.850,65.830,93.150,243.230,12.311,3.348
|
||||
my_fast_scnn,36.200,61.550,92.870,381.410,0.936,1.400
|
||||
my_en_bisenetv2,21.760,41.090,86.700,203.200,7.907,2.771
|
||||
|
@@ -0,0 +1,21 @@
|
||||
Algorithm,mIoU,1_IoU,2_IoU,3_IoU,4_IoU,6_IoU,背景_IoU,5_IoU,7_IoU
|
||||
MAnet,93.90,70.53,46.86,88.84,50.14,73.06,97.91,0.00,0.00
|
||||
Segformer,93.28,68.10,50.06,85.56,36.53,75.76,97.56,0.00,0.00
|
||||
UnetPlusPlus,92.97,64.84,47.57,81.19,59.57,62.17,97.61,0.00,0.00
|
||||
FPN,92.67,63.45,39.01,89.55,41.62,55.78,97.70,0.00,0.00
|
||||
DeepLabV3,92.55,68.69,38.16,83.63,42.43,56.29,97.59,0.00,0.00
|
||||
Unet,92.53,64.81,44.44,81.98,34.55,60.02,97.52,0.00,0.00
|
||||
PAN,92.48,64.05,37.24,81.88,47.61,63.51,97.47,0.00,0.00
|
||||
UPerNet,92.18,68.07,37.63,83.54,39.83,44.67,97.68,0.00,0.00
|
||||
Linknet,92.06,57.25,42.14,86.36,31.33,61.26,97.65,0.00,0.00
|
||||
PSPNet,91.94,62.48,37.34,82.50,19.62,61.92,97.44,0.00,0.00
|
||||
DeepLabV3Plus,91.50,62.77,36.12,76.23,40.89,55.95,97.32,0.00,0.00
|
||||
DPT,87.84,52.03,29.80,67.47,24.96,44.50,94.84,0.00,0.00
|
||||
my_fastfcn_r50,55.34,71.20,55.69,85.87,60.47,72.04,97.47,,
|
||||
my_icnet_r50,50.40,61.35,42.53,78.21,62.38,62.30,96.39,,
|
||||
my_bisenetv1_r50,49.62,63.38,40.38,82.82,47.85,64.80,97.78,,
|
||||
my_icnet_r18,47.54,48.47,27.63,83.13,55.34,70.40,95.37,,
|
||||
my_bisenetv1_r18,45.02,60.13,35.97,81.80,30.29,54.89,97.10,,
|
||||
my_bisenetv2,38.85,47.72,28.29,73.87,24.03,41.84,95.06,,
|
||||
my_fast_scnn,36.20,36.31,19.69,59.68,31.48,46.85,95.56,,
|
||||
my_en_bisenetv2,21.76,19.78,7.06,38.84,0.06,18.16,90.16,,
|
||||
|
File diff suppressed because it is too large
Load Diff
|
After Width: | Height: | Size: 84 KiB |
@@ -0,0 +1,21 @@
|
||||
Model,mIoU,mAcc,aAcc,FPS,FLOPs(G),Params(M)
|
||||
MAnet,92.710,58.340,99.310,152.290,33.850,31.790
|
||||
my_fastfcn_r50,37.810,51.890,96.480,71.040,130.000,66.346
|
||||
my_icnet_r50,35.930,54.750,95.850,193.830,15.430,47.526
|
||||
my_icnet_r18,34.840,47.480,95.570,286.580,9.362,24.873
|
||||
my_bisenetv1_r50,33.600,48.180,95.950,88.460,98.957,56.865
|
||||
PAN,32.230,52.260,99.570,238.750,29.880,21.480
|
||||
FPN,31.480,50.620,99.500,208.170,27.550,23.160
|
||||
UPerNet,30.810,56.260,99.540,108.960,72.100,29.600
|
||||
PSPNet,30.460,48.110,99.580,586.440,9.640,21.490
|
||||
DeepLabV3,30.440,45.870,99.550,96.470,109.330,26.010
|
||||
DeepLabV3Plus,30.390,53.560,99.520,218.940,31.680,22.440
|
||||
Segformer,30.280,50.430,99.550,153.490,26.280,21.880
|
||||
my_bisenetv1_r18,29.660,36.130,96.230,316.660,14.830,13.273
|
||||
Unet,29.560,48.490,99.500,177.730,31.800,24.440
|
||||
UnetPlusPlus,29.020,46.550,99.530,91.560,74.150,26.080
|
||||
Linknet,27.440,45.720,99.520,202.960,20.300,21.770
|
||||
my_bisenetv2,26.790,48.100,94.290,220.480,12.323,3.351
|
||||
my_fast_scnn,24.240,38.810,94.450,318.670,0.936,1.400
|
||||
DPT,12.740,27.010,99.490,30.930,212.980,137.810
|
||||
my_en_bisenetv2,12.710,29.950,85.020,202.950,7.919,2.774
|
||||
|
@@ -0,0 +1,21 @@
|
||||
Algorithm,mIoU,10_IoU,1_IoU,2_IoU,4_IoU,5_IoU,6_IoU,7_IoU,8_IoU,9_IoU,背景_IoU,3_IoU
|
||||
MAnet,92.71,35.32,26.82,27.63,69.93,17.24,1.35,68.73,39.18,18.98,96.27,0.00
|
||||
my_fastfcn_r50,37.81,38.43,20.81,34.67,71.51,20.69,0.00,63.65,45.78,23.86,96.53,
|
||||
my_icnet_r50,35.93,37.37,25.03,26.86,63.67,21.21,1.01,64.49,41.90,17.85,95.88,
|
||||
my_icnet_r18,34.84,37.50,29.12,26.10,67.33,17.58,0.17,64.91,38.94,5.98,95.65,
|
||||
my_bisenetv1_r50,33.60,33.75,23.22,33.49,57.73,7.82,0.54,61.04,39.39,16.52,96.05,
|
||||
PAN,32.23,24.44,27.98,17.51,55.65,14.65,1.28,52.96,33.81,14.59,95.93,0.00
|
||||
FPN,31.48,26.46,27.38,16.61,54.84,13.75,0.10,52.75,30.84,14.13,95.44,0.00
|
||||
UPerNet,30.81,25.82,23.85,18.94,48.88,19.74,3.77,55.93,28.24,13.85,95.60,0.00
|
||||
PSPNet,30.46,22.59,24.88,13.77,52.96,10.33,2.15,57.02,32.44,11.20,95.99,0.00
|
||||
DeepLabV3,30.44,20.32,23.59,9.83,53.98,12.70,0.00,54.00,33.47,12.89,95.82,0.00
|
||||
DeepLabV3Plus,30.39,21.50,25.78,16.46,55.69,14.54,1.83,49.53,31.38,9.97,95.64,0.00
|
||||
Segformer,30.28,29.15,22.06,13.37,52.94,11.50,0.52,58.03,40.69,16.09,95.76,0.00
|
||||
my_bisenetv1_r18,29.66,31.09,13.25,24.55,57.82,7.06,0.00,41.10,38.13,16.98,96.25,
|
||||
Unet,29.56,24.81,26.57,8.55,47.35,16.95,0.00,41.73,32.39,14.65,95.35,0.00
|
||||
UnetPlusPlus,29.02,15.46,25.18,9.08,49.10,17.60,0.33,56.31,31.64,12.25,95.60,0.00
|
||||
Linknet,27.44,20.81,22.23,14.15,48.85,15.84,0.01,54.10,28.56,10.97,95.59,0.00
|
||||
my_bisenetv2,26.79,15.92,20.58,16.00,53.72,10.36,0.42,37.73,38.00,7.52,94.43,
|
||||
my_fast_scnn,24.24,0.45,20.35,10.26,57.55,0.00,0.47,33.80,24.96,0.00,94.61,
|
||||
DPT,12.74,0.00,2.63,0.00,30.18,0.00,0.22,13.80,12.37,0.00,95.15,0.00
|
||||
my_en_bisenetv2,12.71,0.00,12.19,0.00,35.31,0.00,0.14,1.78,4.69,0.00,85.72,
|
||||
|
8
Seg_All_In_One_MMSeg/CITATION.cff
Normal file
8
Seg_All_In_One_MMSeg/CITATION.cff
Normal file
@@ -0,0 +1,8 @@
|
||||
cff-version: 1.2.0
|
||||
message: "If you use this software, please cite it as below."
|
||||
authors:
|
||||
- name: "MMSegmentation Contributors"
|
||||
title: "OpenMMLab Semantic Segmentation Toolbox and Benchmark"
|
||||
date-released: 2020-07-10
|
||||
url: "https://github.com/open-mmlab/mmsegmentation"
|
||||
license: Apache-2.0
|
||||
203
Seg_All_In_One_MMSeg/LICENSE
Normal file
203
Seg_All_In_One_MMSeg/LICENSE
Normal file
@@ -0,0 +1,203 @@
|
||||
Copyright 2020 The MMSegmentation Authors. All rights reserved.
|
||||
|
||||
Apache License
|
||||
Version 2.0, January 2004
|
||||
http://www.apache.org/licenses/
|
||||
|
||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||
|
||||
1. Definitions.
|
||||
|
||||
"License" shall mean the terms and conditions for use, reproduction,
|
||||
and distribution as defined by Sections 1 through 9 of this document.
|
||||
|
||||
"Licensor" shall mean the copyright owner or entity authorized by
|
||||
the copyright owner that is granting the License.
|
||||
|
||||
"Legal Entity" shall mean the union of the acting entity and all
|
||||
other entities that control, are controlled by, or are under common
|
||||
control with that entity. For the purposes of this definition,
|
||||
"control" means (i) the power, direct or indirect, to cause the
|
||||
direction or management of such entity, whether by contract or
|
||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||
|
||||
"You" (or "Your") shall mean an individual or Legal Entity
|
||||
exercising permissions granted by this License.
|
||||
|
||||
"Source" form shall mean the preferred form for making modifications,
|
||||
including but not limited to software source code, documentation
|
||||
source, and configuration files.
|
||||
|
||||
"Object" form shall mean any form resulting from mechanical
|
||||
transformation or translation of a Source form, including but
|
||||
not limited to compiled object code, generated documentation,
|
||||
and conversions to other media types.
|
||||
|
||||
"Work" shall mean the work of authorship, whether in Source or
|
||||
Object form, made available under the License, as indicated by a
|
||||
copyright notice that is included in or attached to the work
|
||||
(an example is provided in the Appendix below).
|
||||
|
||||
"Derivative Works" shall mean any work, whether in Source or Object
|
||||
form, that is based on (or derived from) the Work and for which the
|
||||
editorial revisions, annotations, elaborations, or other modifications
|
||||
represent, as a whole, an original work of authorship. For the purposes
|
||||
of this License, Derivative Works shall not include works that remain
|
||||
separable from, or merely link (or bind by name) to the interfaces of,
|
||||
the Work and Derivative Works thereof.
|
||||
|
||||
"Contribution" shall mean any work of authorship, including
|
||||
the original version of the Work and any modifications or additions
|
||||
to that Work or Derivative Works thereof, that is intentionally
|
||||
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||
or by an individual or Legal Entity authorized to submit on behalf of
|
||||
the copyright owner. For the purposes of this definition, "submitted"
|
||||
means any form of electronic, verbal, or written communication sent
|
||||
to the Licensor or its representatives, including but not limited to
|
||||
communication on electronic mailing lists, source code control systems,
|
||||
and issue tracking systems that are managed by, or on behalf of, the
|
||||
Licensor for the purpose of discussing and improving the Work, but
|
||||
excluding communication that is conspicuously marked or otherwise
|
||||
designated in writing by the copyright owner as "Not a Contribution."
|
||||
|
||||
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||
on behalf of whom a Contribution has been received by Licensor and
|
||||
subsequently incorporated within the Work.
|
||||
|
||||
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
copyright license to reproduce, prepare Derivative Works of,
|
||||
publicly display, publicly perform, sublicense, and distribute the
|
||||
Work and such Derivative Works in Source or Object form.
|
||||
|
||||
3. Grant of Patent License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
(except as stated in this section) patent license to make, have made,
|
||||
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||
where such license applies only to those patent claims licensable
|
||||
by such Contributor that are necessarily infringed by their
|
||||
Contribution(s) alone or by combination of their Contribution(s)
|
||||
with the Work to which such Contribution(s) was submitted. If You
|
||||
institute patent litigation against any entity (including a
|
||||
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||
or a Contribution incorporated within the Work constitutes direct
|
||||
or contributory patent infringement, then any patent licenses
|
||||
granted to You under this License for that Work shall terminate
|
||||
as of the date such litigation is filed.
|
||||
|
||||
4. Redistribution. You may reproduce and distribute copies of the
|
||||
Work or Derivative Works thereof in any medium, with or without
|
||||
modifications, and in Source or Object form, provided that You
|
||||
meet the following conditions:
|
||||
|
||||
(a) You must give any other recipients of the Work or
|
||||
Derivative Works a copy of this License; and
|
||||
|
||||
(b) You must cause any modified files to carry prominent notices
|
||||
stating that You changed the files; and
|
||||
|
||||
(c) You must retain, in the Source form of any Derivative Works
|
||||
that You distribute, all copyright, patent, trademark, and
|
||||
attribution notices from the Source form of the Work,
|
||||
excluding those notices that do not pertain to any part of
|
||||
the Derivative Works; and
|
||||
|
||||
(d) If the Work includes a "NOTICE" text file as part of its
|
||||
distribution, then any Derivative Works that You distribute must
|
||||
include a readable copy of the attribution notices contained
|
||||
within such NOTICE file, excluding those notices that do not
|
||||
pertain to any part of the Derivative Works, in at least one
|
||||
of the following places: within a NOTICE text file distributed
|
||||
as part of the Derivative Works; within the Source form or
|
||||
documentation, if provided along with the Derivative Works; or,
|
||||
within a display generated by the Derivative Works, if and
|
||||
wherever such third-party notices normally appear. The contents
|
||||
of the NOTICE file are for informational purposes only and
|
||||
do not modify the License. You may add Your own attribution
|
||||
notices within Derivative Works that You distribute, alongside
|
||||
or as an addendum to the NOTICE text from the Work, provided
|
||||
that such additional attribution notices cannot be construed
|
||||
as modifying the License.
|
||||
|
||||
You may add Your own copyright statement to Your modifications and
|
||||
may provide additional or different license terms and conditions
|
||||
for use, reproduction, or distribution of Your modifications, or
|
||||
for any such Derivative Works as a whole, provided Your use,
|
||||
reproduction, and distribution of the Work otherwise complies with
|
||||
the conditions stated in this License.
|
||||
|
||||
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||
any Contribution intentionally submitted for inclusion in the Work
|
||||
by You to the Licensor shall be under the terms and conditions of
|
||||
this License, without any additional terms or conditions.
|
||||
Notwithstanding the above, nothing herein shall supersede or modify
|
||||
the terms of any separate license agreement you may have executed
|
||||
with Licensor regarding such Contributions.
|
||||
|
||||
6. Trademarks. This License does not grant permission to use the trade
|
||||
names, trademarks, service marks, or product names of the Licensor,
|
||||
except as required for reasonable and customary use in describing the
|
||||
origin of the Work and reproducing the content of the NOTICE file.
|
||||
|
||||
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||
agreed to in writing, Licensor provides the Work (and each
|
||||
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
implied, including, without limitation, any warranties or conditions
|
||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||
appropriateness of using or redistributing the Work and assume any
|
||||
risks associated with Your exercise of permissions under this License.
|
||||
|
||||
8. Limitation of Liability. In no event and under no legal theory,
|
||||
whether in tort (including negligence), contract, or otherwise,
|
||||
unless required by applicable law (such as deliberate and grossly
|
||||
negligent acts) or agreed to in writing, shall any Contributor be
|
||||
liable to You for damages, including any direct, indirect, special,
|
||||
incidental, or consequential damages of any character arising as a
|
||||
result of this License or out of the use or inability to use the
|
||||
Work (including but not limited to damages for loss of goodwill,
|
||||
work stoppage, computer failure or malfunction, or any and all
|
||||
other commercial damages or losses), even if such Contributor
|
||||
has been advised of the possibility of such damages.
|
||||
|
||||
9. Accepting Warranty or Additional Liability. While redistributing
|
||||
the Work or Derivative Works thereof, You may choose to offer,
|
||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||
or other liability obligations and/or rights consistent with this
|
||||
License. However, in accepting such obligations, You may act only
|
||||
on Your own behalf and on Your sole responsibility, not on behalf
|
||||
of any other Contributor, and only if You agree to indemnify,
|
||||
defend, and hold each Contributor harmless for any liability
|
||||
incurred by, or claims asserted against, such Contributor by reason
|
||||
of your accepting any such warranty or additional liability.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
|
||||
APPENDIX: How to apply the Apache License to your work.
|
||||
|
||||
To apply the Apache License to your work, attach the following
|
||||
boilerplate notice, with the fields enclosed by brackets "[]"
|
||||
replaced with your own identifying information. (Don't include
|
||||
the brackets!) The text should be enclosed in the appropriate
|
||||
comment syntax for the file format. We also recommend that a
|
||||
file or class name and description of purpose be included on the
|
||||
same "printed page" as the copyright notice for easier
|
||||
identification within third-party archives.
|
||||
|
||||
Copyright 2020 The MMSegmentation Authors.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
5
Seg_All_In_One_MMSeg/MANIFEST.in
Normal file
5
Seg_All_In_One_MMSeg/MANIFEST.in
Normal file
@@ -0,0 +1,5 @@
|
||||
include requirements/*.txt
|
||||
include mmseg/.mim/model-index.yml
|
||||
include mmseg/utils/bpe_simple_vocab_16e6.txt.gz
|
||||
recursive-include mmseg/.mim/configs *.py *.yaml
|
||||
recursive-include mmseg/.mim/tools *.py *.sh
|
||||
@@ -0,0 +1,297 @@
|
||||
import os, requests, hashlib
|
||||
from tqdm import tqdm
|
||||
|
||||
### 链接获取网址:https://github.com/open-mmlab/mmcv/blob/master/mmcv/model_zoo/[deprecated.json | mmcls.json | open_mmlab.json | torchvision_0.12.json] ###
|
||||
|
||||
# open_mmlab JSON 数据
|
||||
open_mmlab_model_urls = {
|
||||
"vgg16_caffe": "https://download.openmmlab.com/pretrain/third_party/vgg16_caffe-292e1171.pth",
|
||||
"detectron/resnet50_caffe": "https://download.openmmlab.com/pretrain/third_party/resnet50_caffe-788b5fa3.pth",
|
||||
"detectron2/resnet50_caffe": "https://download.openmmlab.com/pretrain/third_party/resnet50_msra-5891d200.pth",
|
||||
"detectron/resnet101_caffe": "https://download.openmmlab.com/pretrain/third_party/resnet101_caffe-3ad79236.pth",
|
||||
"detectron2/resnet101_caffe": "https://download.openmmlab.com/pretrain/third_party/resnet101_msra-6cc46731.pth",
|
||||
"detectron2/resnext101_32x8d": "https://download.openmmlab.com/pretrain/third_party/resnext101_32x8d-1516f1aa.pth",
|
||||
"resnext50_32x4d": "https://download.openmmlab.com/pretrain/third_party/resnext50-32x4d-0ab1a123.pth",
|
||||
"resnext101_32x4d": "https://download.openmmlab.com/pretrain/third_party/resnext101_32x4d-a5af3160.pth",
|
||||
"resnext101_64x4d": "https://download.openmmlab.com/pretrain/third_party/resnext101_64x4d-ee2c6f71.pth",
|
||||
"contrib/resnet50_gn": "https://download.openmmlab.com/pretrain/third_party/resnet50_gn_thangvubk-ad1730dd.pth",
|
||||
"detectron/resnet50_gn": "https://download.openmmlab.com/pretrain/third_party/resnet50_gn-9186a21c.pth",
|
||||
"detectron/resnet101_gn": "https://download.openmmlab.com/pretrain/third_party/resnet101_gn-cac0ab98.pth",
|
||||
"jhu/resnet50_gn_ws": "https://download.openmmlab.com/pretrain/third_party/resnet50_gn_ws-15beedd8.pth",
|
||||
"jhu/resnet101_gn_ws": "https://download.openmmlab.com/pretrain/third_party/resnet101_gn_ws-3e3c308c.pth",
|
||||
"jhu/resnext50_32x4d_gn_ws": "https://download.openmmlab.com/pretrain/third_party/resnext50_32x4d_gn_ws-0d87ac85.pth",
|
||||
"jhu/resnext101_32x4d_gn_ws": "https://download.openmmlab.com/pretrain/third_party/resnext101_32x4d_gn_ws-34ac1a9e.pth",
|
||||
"jhu/resnext50_32x4d_gn": "https://download.openmmlab.com/pretrain/third_party/resnext50_32x4d_gn-c7e8b754.pth",
|
||||
"jhu/resnext101_32x4d_gn": "https://download.openmmlab.com/pretrain/third_party/resnext101_32x4d_gn-ac3bb84e.pth",
|
||||
"msra/hrnetv2_w18_small": "https://download.openmmlab.com/pretrain/third_party/hrnetv2_w18_small-b5a04e21.pth",
|
||||
"msra/hrnetv2_w18": "https://download.openmmlab.com/pretrain/third_party/hrnetv2_w18-00eb2006.pth",
|
||||
"msra/hrnetv2_w32": "https://download.openmmlab.com/pretrain/third_party/hrnetv2_w32-dc9eeb4f.pth",
|
||||
"msra/hrnetv2_w40": "https://download.openmmlab.com/pretrain/third_party/hrnetv2_w40-ed0b031c.pth",
|
||||
"msra/hrnetv2_w48": "https://download.openmmlab.com/pretrain/third_party/hrnetv2_w48-d2186c55.pth",
|
||||
"bninception_caffe": "https://download.openmmlab.com/pretrain/third_party/bn_inception_caffe-ed2e8665.pth",
|
||||
"kin400/i3d_r50_f32s2_k400": "https://download.openmmlab.com/pretrain/third_party/i3d_r50_f32s2_k400-2c57e077.pth",
|
||||
"kin400/nl3d_r50_f32s2_k400": "https://download.openmmlab.com/pretrain/third_party/nl3d_r50_f32s2_k400-fa7e7caa.pth",
|
||||
"res2net101_v1d_26w_4s": "https://download.openmmlab.com/pretrain/third_party/res2net101_v1d_26w_4s_mmdetv2-f0a600f9.pth",
|
||||
"regnetx_400mf": "https://download.openmmlab.com/pretrain/third_party/regnetx_400mf-a5b10d96.pth",
|
||||
"regnetx_800mf": "https://download.openmmlab.com/pretrain/third_party/regnetx_800mf-1f4be4c7.pth",
|
||||
"regnetx_1.6gf": "https://download.openmmlab.com/pretrain/third_party/regnetx_1.6gf-5791c176.pth",
|
||||
"regnetx_3.2gf": "https://download.openmmlab.com/pretrain/third_party/regnetx_3.2gf-c2599b0f.pth",
|
||||
"regnetx_4.0gf": "https://download.openmmlab.com/pretrain/third_party/regnetx_4.0gf-a88f671e.pth",
|
||||
"regnetx_6.4gf": "https://download.openmmlab.com/pretrain/third_party/regnetx_6.4gf-006af45d.pth",
|
||||
"regnetx_8.0gf": "https://download.openmmlab.com/pretrain/third_party/regnetx_8.0gf-3c68abe7.pth",
|
||||
"regnetx_12gf": "https://download.openmmlab.com/pretrain/third_party/regnetx_12gf-4c2a3350.pth",
|
||||
"resnet18_v1c": "https://download.openmmlab.com/pretrain/third_party/resnet18_v1c-b5776b93.pth",
|
||||
"resnet50_v1c": "https://download.openmmlab.com/pretrain/third_party/resnet50_v1c-2cccc1ad.pth",
|
||||
"resnet101_v1c": "https://download.openmmlab.com/pretrain/third_party/resnet101_v1c-e67eebb6.pth",
|
||||
"mmedit/vgg16": "https://download.openmmlab.com/mmediting/third_party/vgg_state_dict.pth",
|
||||
"mmedit/res34_en_nomixup": "https://download.openmmlab.com/mmediting/third_party/model_best_resnet34_En_nomixup.pth",
|
||||
"mmedit/mobilenet_v2": "https://download.openmmlab.com/mmediting/third_party/mobilenet_v2.pth",
|
||||
"contrib/mobilenet_v3_large": "https://download.openmmlab.com/pretrain/third_party/mobilenet_v3_large-bc2c3fd3.pth",
|
||||
"contrib/mobilenet_v3_small": "https://download.openmmlab.com/pretrain/third_party/mobilenet_v3_small-47085aa1.pth",
|
||||
"resnest50": "https://download.openmmlab.com/pretrain/third_party/resnest50_d2-7497a55b.pth",
|
||||
"resnest101": "https://download.openmmlab.com/pretrain/third_party/resnest101_d2-f3b931b2.pth",
|
||||
"resnest200": "https://download.openmmlab.com/pretrain/third_party/resnest200_d2-ca88e41f.pth",
|
||||
"darknet53": "https://download.openmmlab.com/pretrain/third_party/darknet53-a628ea1b.pth",
|
||||
"mmdet/mobilenet_v2": "https://download.openmmlab.com/mmdetection/v2.0/third_party/mobilenet_v2_batch256_imagenet-ff34753d.pth",
|
||||
"pidnet-s": "https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/pidnet/pidnet-s_imagenet1k_20230306-715e6273.pth",
|
||||
"pidnet-m": "https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/pidnet/pidnet-m_imagenet1k_20230306-39893c52.pth",
|
||||
"pidnet-l": "https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/pidnet/pidnet-l_imagenet1k_20230306-67889109.pth",
|
||||
"ddrnet23-s": "https://download.openmmlab.com/mmsegmentation/v0.5/ddrnet/pretrain/ddrnet23s-in1kpre_3rdparty-1ccac5b1.pth",
|
||||
"ddrnet23": "https://download.openmmlab.com/mmsegmentation/v0.5/ddrnet/pretrain/ddrnet23-in1kpre_3rdparty-9ca29f62.pth",
|
||||
"stdc1": "https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/stdc/stdc1_20220308-5368626c.pth",
|
||||
"stdc2": "https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/stdc/stdc2_20220308-7dbd9127.pth"
|
||||
}
|
||||
|
||||
# deprecated_model_urls = {{
|
||||
# "resnet50_caffe": "detectron/resnet50_caffe",
|
||||
# "resnet50_caffe_bgr": "detectron2/resnet50_caffe_bgr",
|
||||
# "resnet101_caffe": "detectron/resnet101_caffe",
|
||||
# "resnet101_caffe_bgr": "detectron2/resnet101_caffe_bgr"
|
||||
# }}
|
||||
|
||||
mmcls_model_urls = {
|
||||
"vgg11": "https://download.openmmlab.com/mmclassification/v0/vgg/vgg11_batch256_imagenet_20210208-4271cd6c.pth",
|
||||
"vgg13": "https://download.openmmlab.com/mmclassification/v0/vgg/vgg13_batch256_imagenet_20210208-4d1d6080.pth",
|
||||
"vgg16": "https://download.openmmlab.com/mmclassification/v0/vgg/vgg16_batch256_imagenet_20210208-db26f1a5.pth",
|
||||
"vgg19": "https://download.openmmlab.com/mmclassification/v0/vgg/vgg19_batch256_imagenet_20210208-e6920e4a.pth",
|
||||
"vgg11_bn": "https://download.openmmlab.com/mmclassification/v0/vgg/vgg11_bn_batch256_imagenet_20210207-f244902c.pth",
|
||||
"vgg13_bn": "https://download.openmmlab.com/mmclassification/v0/vgg/vgg13_bn_batch256_imagenet_20210207-1a8b7864.pth",
|
||||
"vgg16_bn": "https://download.openmmlab.com/mmclassification/v0/vgg/vgg16_bn_batch256_imagenet_20210208-7e55cd29.pth",
|
||||
"vgg19_bn": "https://download.openmmlab.com/mmclassification/v0/vgg/vgg19_bn_batch256_imagenet_20210208-da620c4f.pth",
|
||||
"resnet18": "https://download.openmmlab.com/mmclassification/v0/resnet/resnet18_8xb32_in1k_20210831-fbbb1da6.pth",
|
||||
"resnet34": "https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_8xb32_in1k_20210831-f257d4e6.pth",
|
||||
"resnet50": "https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb32_in1k_20210831-ea4938fc.pth",
|
||||
"resnet101": "https://download.openmmlab.com/mmclassification/v0/resnet/resnet101_8xb32_in1k_20210831-539c63f8.pth",
|
||||
"resnet152": "https://download.openmmlab.com/mmclassification/v0/resnet/resnet152_8xb32_in1k_20210901-4d7582fa.pth",
|
||||
"resnet50_v1d": "https://download.openmmlab.com/mmclassification/v0/resnet/resnetv1d50_b32x8_imagenet_20210531-db14775a.pth",
|
||||
"resnet101_v1d": "https://download.openmmlab.com/mmclassification/v0/resnet/resnetv1d101_b32x8_imagenet_20210531-6e13bcd3.pth",
|
||||
"resnet152_v1d": "https://download.openmmlab.com/mmclassification/v0/resnet/resnetv1d152_b32x8_imagenet_20210531-278cf22a.pth",
|
||||
"resnext50_32x4d": "https://download.openmmlab.com/mmclassification/v0/resnext/resnext50_32x4d_b32x8_imagenet_20210429-56066e27.pth",
|
||||
"resnext101_32x4d": "https://download.openmmlab.com/mmclassification/v0/resnext/resnext101_32x4d_b32x8_imagenet_20210506-e0fa3dd5.pth",
|
||||
"resnext101_32x8d": "https://download.openmmlab.com/mmclassification/v0/resnext/resnext101_32x8d_b32x8_imagenet_20210506-23a247d5.pth",
|
||||
"resnext152_32x4d": "https://download.openmmlab.com/mmclassification/v0/resnext/resnext152_32x4d_b32x8_imagenet_20210524-927787be.pth",
|
||||
"se-resnet50": "https://download.openmmlab.com/mmclassification/v0/se-resnet/se-resnet50_batch256_imagenet_20200804-ae206104.pth",
|
||||
"se-resnet101": "https://download.openmmlab.com/mmclassification/v0/se-resnet/se-resnet101_batch256_imagenet_20200804-ba5b51d4.pth",
|
||||
"resnest50": "https://download.openmmlab.com/mmclassification/v0/resnest/resnest50_imagenet_converted-1ebf0afe.pth",
|
||||
"resnest101": "https://download.openmmlab.com/mmclassification/v0/resnest/resnest101_imagenet_converted-032caa52.pth",
|
||||
"resnest200": "https://download.openmmlab.com/mmclassification/v0/resnest/resnest200_imagenet_converted-581a60f2.pth",
|
||||
"resnest269": "https://download.openmmlab.com/mmclassification/v0/resnest/resnest269_imagenet_converted-59930960.pth",
|
||||
"shufflenet_v1": "https://download.openmmlab.com/mmclassification/v0/shufflenet_v1/shufflenet_v1_batch1024_imagenet_20200804-5d6cec73.pth",
|
||||
"shufflenet_v2": "https://download.openmmlab.com/mmclassification/v0/shufflenet_v2/shufflenet_v2_batch1024_imagenet_20200812-5bf4721e.pth",
|
||||
"mobilenet_v2": "https://download.openmmlab.com/mmclassification/v0/mobilenet_v2/mobilenet_v2_batch256_imagenet_20200708-3b2dc3af.pth",
|
||||
"mobilenet_v3_small": "https://download.openmmlab.com/mmclassification/v0/mobilenet_v3/convert/mobilenet_v3_small-8427ecf0.pth",
|
||||
"mobilenet_v3_large": "https://download.openmmlab.com/mmclassification/v0/mobilenet_v3/convert/mobilenet_v3_large-3ea3c186.pth",
|
||||
"repvgg_A0": "https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-A0_3rdparty_4xb64-coslr-120e_in1k_20210909-883ab98c.pth",
|
||||
"repvgg_A1": "https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-A1_3rdparty_4xb64-coslr-120e_in1k_20210909-24003a24.pth",
|
||||
"repvgg_A2": "https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-A2_3rdparty_4xb64-coslr-120e_in1k_20210909-97d7695a.pth",
|
||||
"repvgg_B0": "https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-B0_3rdparty_4xb64-coslr-120e_in1k_20210909-446375f4.pth",
|
||||
"repvgg_B1": "https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-B1_3rdparty_4xb64-coslr-120e_in1k_20210909-750cdf67.pth",
|
||||
"repvgg_B1g2": "https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-B1g2_3rdparty_4xb64-coslr-120e_in1k_20210909-344f6422.pth",
|
||||
"repvgg_B1g4": "https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-B1g4_3rdparty_4xb64-coslr-120e_in1k_20210909-d4c1a642.pth",
|
||||
"repvgg_B2": "https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-B2_3rdparty_4xb64-coslr-120e_in1k_20210909-bd6b937c.pth",
|
||||
"repvgg_B2g4": "https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-B2g4_3rdparty_4xb64-autoaug-lbs-mixup-coslr-200e_in1k_20210909-7b7955f0.pth",
|
||||
"repvgg_B3": "https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-B3_3rdparty_4xb64-autoaug-lbs-mixup-coslr-200e_in1k_20210909-dda968bf.pth",
|
||||
"repvgg_B3g4": "https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-B3g4_3rdparty_4xb64-autoaug-lbs-mixup-coslr-200e_in1k_20210909-4e54846a.pth",
|
||||
"repvgg_D2se": "https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-D2se_3rdparty_4xb64-autoaug-lbs-mixup-coslr-200e_in1k_20210909-cf3139b7.pth",
|
||||
"res2net101_w26": "https://download.openmmlab.com/mmclassification/v0/res2net/res2net101-w26-s4_3rdparty_8xb32_in1k_20210927-870b6c36.pth",
|
||||
"res2net50_w14": "https://download.openmmlab.com/mmclassification/v0/res2net/res2net50-w14-s8_3rdparty_8xb32_in1k_20210927-bc967bf1.pth",
|
||||
"res2net50_w26": "https://download.openmmlab.com/mmclassification/v0/res2net/res2net50-w26-s8_3rdparty_8xb32_in1k_20210927-f547a94b.pth",
|
||||
"swin_tiny": "https://download.openmmlab.com/mmclassification/v0/swin-transformer/swin_tiny_224_b16x64_300e_imagenet_20210616_090925-66df6be6.pth",
|
||||
"swin_small": "https://download.openmmlab.com/mmclassification/v0/swin-transformer/swin_small_224_b16x64_300e_imagenet_20210615_110219-7f9d988b.pth",
|
||||
"swin_base": "https://download.openmmlab.com/mmclassification/v0/swin-transformer/convert/swin_base_patch4_window7_224_22kto1k-f967f799.pth",
|
||||
"swin_large": "https://download.openmmlab.com/mmclassification/v0/swin-transformer/convert/swin_large_patch4_window7_224_22kto1k-5f0996db.pth",
|
||||
"t2t_vit_t_14": "https://download.openmmlab.com/mmclassification/v0/t2t-vit/t2t-vit-t-14_3rdparty_8xb64_in1k_20210928-b7c09b62.pth",
|
||||
"t2t_vit_t_19": "https://download.openmmlab.com/mmclassification/v0/t2t-vit/t2t-vit-t-19_3rdparty_8xb64_in1k_20210928-7f1478d5.pth",
|
||||
"t2t_vit_t_24": "https://download.openmmlab.com/mmclassification/v0/t2t-vit/t2t-vit-t-24_3rdparty_8xb64_in1k_20210928-fe95a61b.pth",
|
||||
"tnt_small": "https://download.openmmlab.com/mmclassification/v0/tnt/tnt-small-p16_3rdparty_in1k_20210903-c56ee7df.pth",
|
||||
"vit_base_p16": "https://download.openmmlab.com/mmclassification/v0/vit/finetune/vit-base-p16_in21k-pre-3rdparty_ft-64xb64_in1k-384_20210928-98e8652b.pth",
|
||||
"vit_base_p32": "https://download.openmmlab.com/mmclassification/v0/vit/finetune/vit-base-p32_in21k-pre-3rdparty_ft-64xb64_in1k-384_20210928-9cea8599.pth",
|
||||
"vit_large_p16": "https://download.openmmlab.com/mmclassification/v0/vit/finetune/vit-large-p16_in21k-pre-3rdparty_ft-64xb64_in1k-384_20210928-b20ba619.pth"
|
||||
}
|
||||
|
||||
torchvision_012_model_urls = {
|
||||
"alexnet": "https://download.pytorch.org/models/alexnet-owt-7be5be79.pth",
|
||||
"densenet121": "https://download.pytorch.org/models/densenet121-a639ec97.pth",
|
||||
"densenet169": "https://download.pytorch.org/models/densenet169-b2777c0a.pth",
|
||||
"densenet201": "https://download.pytorch.org/models/densenet201-c1103571.pth",
|
||||
"densenet161": "https://download.pytorch.org/models/densenet161-8d451a50.pth",
|
||||
"efficientnet_b0": "https://download.pytorch.org/models/efficientnet_b0_rwightman-3dd342df.pth",
|
||||
"efficientnet_b1": "https://download.pytorch.org/models/efficientnet_b1_rwightman-533bc792.pth",
|
||||
"efficientnet_b2": "https://download.pytorch.org/models/efficientnet_b2_rwightman-bcdf34b7.pth",
|
||||
"efficientnet_b3": "https://download.pytorch.org/models/efficientnet_b3_rwightman-cf984f9c.pth",
|
||||
"efficientnet_b4": "https://download.pytorch.org/models/efficientnet_b4_rwightman-7eb33cd5.pth",
|
||||
"efficientnet_b5": "https://download.pytorch.org/models/efficientnet_b5_lukemelas-b6417697.pth",
|
||||
"efficientnet_b6": "https://download.pytorch.org/models/efficientnet_b6_lukemelas-c76e70fd.pth",
|
||||
"efficientnet_b7": "https://download.pytorch.org/models/efficientnet_b7_lukemelas-dcc49843.pth",
|
||||
"googlenet": "https://download.pytorch.org/models/googlenet-1378be20.pth",
|
||||
"inception_v3_google": "https://download.pytorch.org/models/inception_v3_google-0cc3c7bd.pth",
|
||||
"mobilenet_v2": "https://download.pytorch.org/models/mobilenet_v2-b0353104.pth",
|
||||
"mobilenet_v3_large": "https://download.pytorch.org/models/mobilenet_v3_large-8738ca79.pth",
|
||||
"mobilenet_v3_small": "https://download.pytorch.org/models/mobilenet_v3_small-047dcff4.pth",
|
||||
"regnet_y_400mf": "https://download.pytorch.org/models/regnet_y_400mf-c65dace8.pth",
|
||||
"regnet_y_800mf": "https://download.pytorch.org/models/regnet_y_800mf-1b27b58c.pth",
|
||||
"regnet_y_1_6gf": "https://download.pytorch.org/models/regnet_y_1_6gf-b11a554e.pth",
|
||||
"regnet_y_3_2gf": "https://download.pytorch.org/models/regnet_y_3_2gf-b5a9779c.pth",
|
||||
"regnet_y_8gf": "https://download.pytorch.org/models/regnet_y_8gf-d0d0e4a8.pth",
|
||||
"regnet_y_16gf": "https://download.pytorch.org/models/regnet_y_16gf-9e6ed7dd.pth",
|
||||
"regnet_y_32gf": "https://download.pytorch.org/models/regnet_y_32gf-4dee3f7a.pth",
|
||||
"regnet_x_400mf": "https://download.pytorch.org/models/regnet_x_400mf-adf1edd5.pth",
|
||||
"regnet_x_800mf": "https://download.pytorch.org/models/regnet_x_800mf-ad17e45c.pth",
|
||||
"regnet_x_1_6gf": "https://download.pytorch.org/models/regnet_x_1_6gf-e3633e7f.pth",
|
||||
"regnet_x_3_2gf": "https://download.pytorch.org/models/regnet_x_3_2gf-f342aeae.pth",
|
||||
"regnet_x_8gf": "https://download.pytorch.org/models/regnet_x_8gf-03ceed89.pth",
|
||||
"regnet_x_16gf": "https://download.pytorch.org/models/regnet_x_16gf-2007eb11.pth",
|
||||
"regnet_x_32gf": "https://download.pytorch.org/models/regnet_x_32gf-9d47f8d0.pth",
|
||||
"resnet18": "https://download.pytorch.org/models/resnet18-f37072fd.pth",
|
||||
"resnet34": "https://download.pytorch.org/models/resnet34-b627a593.pth",
|
||||
"resnet50": "https://download.pytorch.org/models/resnet50-0676ba61.pth",
|
||||
"resnet101": "https://download.pytorch.org/models/resnet101-63fe2227.pth",
|
||||
"resnet152": "https://download.pytorch.org/models/resnet152-394f9c45.pth",
|
||||
"resnext50_32x4d": "https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth",
|
||||
"resnext101_32x8d": "https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth",
|
||||
"wide_resnet50_2": "https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth",
|
||||
"wide_resnet101_2": "https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth",
|
||||
"shufflenetv2_x0.5": "https://download.pytorch.org/models/shufflenetv2_x0.5-f707e7126e.pth",
|
||||
"shufflenetv2_x1.0": "https://download.pytorch.org/models/shufflenetv2_x1-5666bf0f80.pth",
|
||||
"shufflenetv2_x1.5": None,
|
||||
"shufflenetv2_x2.0": None,
|
||||
"squeezenet1_0": "https://download.pytorch.org/models/squeezenet1_0-b66bff10.pth",
|
||||
"squeezenet1_1": "https://download.pytorch.org/models/squeezenet1_1-b8a52dc0.pth",
|
||||
"vgg11": "https://download.pytorch.org/models/vgg11-8a719046.pth",
|
||||
"vgg13": "https://download.pytorch.org/models/vgg13-19584684.pth",
|
||||
"vgg16": "https://download.pytorch.org/models/vgg16-397923af.pth",
|
||||
"vgg19": "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth",
|
||||
"vgg11_bn": "https://download.pytorch.org/models/vgg11_bn-6002323d.pth",
|
||||
"vgg13_bn": "https://download.pytorch.org/models/vgg13_bn-abd245e5.pth",
|
||||
"vgg16_bn": "https://download.pytorch.org/models/vgg16_bn-6c64b313.pth",
|
||||
"vgg19_bn": "https://download.pytorch.org/models/vgg19_bn-c79401a0.pth"
|
||||
}
|
||||
|
||||
def calculate_file_hash(file_path, hash_algorithm='md5'):
|
||||
"""计算文件的哈希值,默认使用 MD5"""
|
||||
hash_func = hashlib.new(hash_algorithm)
|
||||
with open(file_path, 'rb') as f:
|
||||
while chunk := f.read(8192):
|
||||
hash_func.update(chunk)
|
||||
return hash_func.hexdigest()
|
||||
|
||||
def download_file(url, output_path):
|
||||
"""下载并保存文件,显示下载进度条"""
|
||||
response = requests.get(url, stream=True)
|
||||
|
||||
if response.status_code == 200:
|
||||
# 获取文件的总大小,以便确定进度条的总长度
|
||||
total_size = int(response.headers.get('Content-Length', 0))
|
||||
|
||||
# 初始化 tqdm 进度条
|
||||
with tqdm(total=total_size, unit='B', unit_scale=True, desc=output_path, ncols=100) as pbar:
|
||||
with open(output_path, 'wb') as f:
|
||||
for chunk in response.iter_content(chunk_size=1024):
|
||||
if chunk:
|
||||
f.write(chunk)
|
||||
pbar.update(len(chunk)) # 更新进度条
|
||||
|
||||
print(f"Downloaded {output_path}")
|
||||
else:
|
||||
print(f"Failed to download {url}")
|
||||
|
||||
def file_exists_and_same(url, output_path):
|
||||
"""检查文件是否已存在并且相同"""
|
||||
if not os.path.exists(output_path):
|
||||
return False
|
||||
|
||||
# 计算远程文件的大小
|
||||
response = requests.head(url)
|
||||
remote_file_size = int(response.headers.get('Content-Length', 0))
|
||||
|
||||
# 获取本地文件的大小
|
||||
local_file_size = os.path.getsize(output_path)
|
||||
|
||||
# 如果文件大小不同,返回 False
|
||||
if local_file_size != remote_file_size:
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
|
||||
# # 如果大小相同,比较文件哈希值
|
||||
# remote_hash = requests.get(url + ".md5").text.strip() if url.endswith(".pth") else None
|
||||
# local_hash = calculate_file_hash(output_path) if remote_hash else None
|
||||
# print(remote_hash, local_hash)
|
||||
# # 返回 True 如果哈希值相同
|
||||
# return local_hash == remote_hash if remote_hash else False
|
||||
|
||||
def download_all_models(model_urls, output_dir):
|
||||
"""遍历JSON并下载文件"""
|
||||
for model_name, url in model_urls.items():
|
||||
# 如果url为空则继续
|
||||
if url == None:
|
||||
print(" ", end='')
|
||||
print(f"\033[91m{model_name}后URL为空,跳过下载!\033[0m")
|
||||
continue
|
||||
# 创建保存路径
|
||||
file_name = f"{model_name}.pth" # 将文件名按 key 进行命名
|
||||
output_path = os.path.join(output_dir, file_name)
|
||||
|
||||
# 获取 output_path 中的文件夹路径,并确保该路径存在
|
||||
output_folder = os.path.dirname(output_path)
|
||||
os.makedirs(output_folder, exist_ok=True) # 如果文件夹不存在则创建
|
||||
|
||||
# 检查文件是否已存在并且相同
|
||||
if file_exists_and_same(url, output_path):
|
||||
print(" ", end='')
|
||||
print(f"\033[93mFile {output_path} already exists and the size is same, skipping download.\033[0m")
|
||||
else:
|
||||
print(" ", end='')
|
||||
print(f"正在下载 {model_name} : {output_path} 中...")
|
||||
# 下载文件
|
||||
download_file(url, output_path)
|
||||
|
||||
if __name__ == '__main__':
|
||||
### 1.下载openmmlab数据集 ###
|
||||
# 创建存储下载内容的文件夹
|
||||
output_open_mmlab_dir = './My_Local_Model/open_mmlab'
|
||||
os.makedirs(output_open_mmlab_dir, exist_ok=True)
|
||||
# 执行下载
|
||||
print(f"\033[32m下载open_mmlab数据中...\033[0m")
|
||||
download_all_models(open_mmlab_model_urls, output_open_mmlab_dir)
|
||||
|
||||
# ### 2.下载deprecated数据集 ###
|
||||
# # 创建存储下载内容的文件夹
|
||||
# output_deprecated_dir = './My_Local_Model/deprecated'
|
||||
# os.makedirs(output_deprecated_dir, exist_ok=True)
|
||||
# # 执行下载
|
||||
# download_all_models(deprecated_model_urls, output_deprecated_dir)
|
||||
|
||||
### 3.下载mmcls数据集 ###
|
||||
# 创建存储下载内容的文件夹
|
||||
output_mmcls_dir = './My_Local_Model/mmcls'
|
||||
os.makedirs(output_mmcls_dir, exist_ok=True)
|
||||
# 执行下载
|
||||
download_all_models(mmcls_model_urls, output_mmcls_dir)
|
||||
|
||||
### 4.下载torchvision_012数据集 ###
|
||||
# 创建存储下载内容的文件夹
|
||||
output_torchvision_012_dir = './My_Local_Model/torchvision_012'
|
||||
os.makedirs(output_torchvision_012_dir, exist_ok=True)
|
||||
# 执行下载
|
||||
download_all_models(torchvision_012_model_urls, output_torchvision_012_dir)
|
||||
@@ -0,0 +1,628 @@
|
||||
{
|
||||
"publicdataset_cholecseg8k": {
|
||||
"train_imgs_num": 6464,
|
||||
"classes": [
|
||||
"背景",
|
||||
"1",
|
||||
"2",
|
||||
"3",
|
||||
"4",
|
||||
"5",
|
||||
"6",
|
||||
"7",
|
||||
"8",
|
||||
"9",
|
||||
"10",
|
||||
"11",
|
||||
"12"
|
||||
],
|
||||
"palette": [
|
||||
[
|
||||
0,
|
||||
0,
|
||||
0
|
||||
],
|
||||
[
|
||||
255,
|
||||
91,
|
||||
0
|
||||
],
|
||||
[
|
||||
255,
|
||||
234,
|
||||
0
|
||||
],
|
||||
[
|
||||
85,
|
||||
111,
|
||||
181
|
||||
],
|
||||
[
|
||||
181,
|
||||
227,
|
||||
14
|
||||
],
|
||||
[
|
||||
72,
|
||||
0,
|
||||
255
|
||||
],
|
||||
[
|
||||
0,
|
||||
155,
|
||||
33
|
||||
],
|
||||
[
|
||||
255,
|
||||
0,
|
||||
255
|
||||
],
|
||||
[
|
||||
29,
|
||||
32,
|
||||
136
|
||||
],
|
||||
[
|
||||
160,
|
||||
15,
|
||||
95
|
||||
],
|
||||
[
|
||||
0,
|
||||
160,
|
||||
233
|
||||
],
|
||||
[
|
||||
52,
|
||||
184,
|
||||
178
|
||||
],
|
||||
[
|
||||
90,
|
||||
120,
|
||||
41
|
||||
]
|
||||
],
|
||||
"palette_num": 13,
|
||||
"mean": [
|
||||
85.65740418979115,
|
||||
53.99282220050495,
|
||||
46.074045888534535
|
||||
],
|
||||
"std": [
|
||||
72.24589167201978,
|
||||
56.76979155397199,
|
||||
49.056637115061775
|
||||
],
|
||||
"imgs_num": 6464
|
||||
},
|
||||
"my_dataset_model": {
|
||||
"train_imgs_num": 631,
|
||||
"classes": [
|
||||
"背景",
|
||||
"肝脏",
|
||||
"胆囊",
|
||||
"分离钳",
|
||||
"止血海绵",
|
||||
"肝总管",
|
||||
"胆总管",
|
||||
"吸引器",
|
||||
"剪刀",
|
||||
"止血纱布",
|
||||
"生物夹",
|
||||
"无损伤钳",
|
||||
"喷洒",
|
||||
"胆囊管",
|
||||
"胆囊动脉",
|
||||
"电凝",
|
||||
"标本袋",
|
||||
"引流管",
|
||||
"纱布",
|
||||
"金属钛夹",
|
||||
"术中超声",
|
||||
"吻合器",
|
||||
"乳胶管",
|
||||
"推结器",
|
||||
"肝带",
|
||||
"钳夹",
|
||||
"超声刀",
|
||||
"脂肪",
|
||||
"双极电凝",
|
||||
"棉球",
|
||||
"血管阻断夹",
|
||||
"肿瘤",
|
||||
"针",
|
||||
"线",
|
||||
"韧带",
|
||||
"胆囊静脉"
|
||||
],
|
||||
"palette": [
|
||||
[
|
||||
0,
|
||||
0,
|
||||
0
|
||||
],
|
||||
[
|
||||
255,
|
||||
91,
|
||||
0
|
||||
],
|
||||
[
|
||||
255,
|
||||
234,
|
||||
0
|
||||
],
|
||||
[
|
||||
85,
|
||||
111,
|
||||
181
|
||||
],
|
||||
[
|
||||
181,
|
||||
227,
|
||||
14
|
||||
],
|
||||
[
|
||||
72,
|
||||
0,
|
||||
255
|
||||
],
|
||||
[
|
||||
0,
|
||||
155,
|
||||
33
|
||||
],
|
||||
[
|
||||
255,
|
||||
0,
|
||||
255
|
||||
],
|
||||
[
|
||||
29,
|
||||
32,
|
||||
136
|
||||
],
|
||||
[
|
||||
160,
|
||||
15,
|
||||
95
|
||||
],
|
||||
[
|
||||
0,
|
||||
160,
|
||||
233
|
||||
],
|
||||
[
|
||||
52,
|
||||
184,
|
||||
178
|
||||
],
|
||||
[
|
||||
90,
|
||||
120,
|
||||
41
|
||||
],
|
||||
[
|
||||
255,
|
||||
0,
|
||||
0
|
||||
],
|
||||
[
|
||||
177,
|
||||
0,
|
||||
0
|
||||
],
|
||||
[
|
||||
167,
|
||||
24,
|
||||
233
|
||||
],
|
||||
[
|
||||
112,
|
||||
113,
|
||||
150
|
||||
],
|
||||
[
|
||||
0,
|
||||
255,
|
||||
0
|
||||
],
|
||||
[
|
||||
255,
|
||||
255,
|
||||
255
|
||||
],
|
||||
[
|
||||
0,
|
||||
255,
|
||||
255
|
||||
],
|
||||
[
|
||||
138,
|
||||
251,
|
||||
213
|
||||
],
|
||||
[
|
||||
136,
|
||||
162,
|
||||
196
|
||||
],
|
||||
[
|
||||
197,
|
||||
83,
|
||||
181
|
||||
],
|
||||
[
|
||||
202,
|
||||
202,
|
||||
200
|
||||
],
|
||||
[
|
||||
113,
|
||||
102,
|
||||
140
|
||||
],
|
||||
[
|
||||
66,
|
||||
115,
|
||||
82
|
||||
],
|
||||
[
|
||||
240,
|
||||
16,
|
||||
116
|
||||
],
|
||||
[
|
||||
155,
|
||||
132,
|
||||
0
|
||||
],
|
||||
[
|
||||
155,
|
||||
62,
|
||||
0
|
||||
],
|
||||
[
|
||||
146,
|
||||
175,
|
||||
236
|
||||
],
|
||||
[
|
||||
255,
|
||||
172,
|
||||
159
|
||||
],
|
||||
[
|
||||
245,
|
||||
161,
|
||||
0
|
||||
],
|
||||
[
|
||||
134,
|
||||
124,
|
||||
118
|
||||
],
|
||||
[
|
||||
0,
|
||||
157,
|
||||
142
|
||||
],
|
||||
[
|
||||
181,
|
||||
85,
|
||||
105
|
||||
],
|
||||
[
|
||||
42,
|
||||
8,
|
||||
66
|
||||
]
|
||||
],
|
||||
"palette_num": 36,
|
||||
"mean": [
|
||||
94.94709810464319,
|
||||
61.729422339499315,
|
||||
75.93763705236911
|
||||
],
|
||||
"std": [
|
||||
44.00550608113231,
|
||||
42.695956669847746,
|
||||
44.99354156225513
|
||||
],
|
||||
"imgs_num": 2000
|
||||
},
|
||||
"publicdataset_autolaparo": {
|
||||
"train_imgs_num": 1440,
|
||||
"classes": [
|
||||
"背景",
|
||||
"1",
|
||||
"2",
|
||||
"3",
|
||||
"4",
|
||||
"5",
|
||||
"6",
|
||||
"7",
|
||||
"8",
|
||||
"9"
|
||||
],
|
||||
"palette": [
|
||||
[
|
||||
0,
|
||||
0,
|
||||
0
|
||||
],
|
||||
[
|
||||
255,
|
||||
91,
|
||||
0
|
||||
],
|
||||
[
|
||||
255,
|
||||
234,
|
||||
0
|
||||
],
|
||||
[
|
||||
85,
|
||||
111,
|
||||
181
|
||||
],
|
||||
[
|
||||
181,
|
||||
227,
|
||||
14
|
||||
],
|
||||
[
|
||||
72,
|
||||
0,
|
||||
255
|
||||
],
|
||||
[
|
||||
0,
|
||||
155,
|
||||
33
|
||||
],
|
||||
[
|
||||
255,
|
||||
0,
|
||||
255
|
||||
],
|
||||
[
|
||||
29,
|
||||
32,
|
||||
136
|
||||
],
|
||||
[
|
||||
160,
|
||||
15,
|
||||
95
|
||||
]
|
||||
],
|
||||
"palette_num": 10,
|
||||
"mean": [
|
||||
123.62464353460942,
|
||||
85.34836259209033,
|
||||
82.31539425671558
|
||||
],
|
||||
"std": [
|
||||
47.172211618459315,
|
||||
47.08256715323592,
|
||||
48.135121265163605
|
||||
]
|
||||
},
|
||||
"publicdataset_endovis_2017": {
|
||||
"train_imgs_num": 1800,
|
||||
"classes": [
|
||||
"背景",
|
||||
"1",
|
||||
"2",
|
||||
"3",
|
||||
"4",
|
||||
"5",
|
||||
"6",
|
||||
"7"
|
||||
],
|
||||
"palette": [
|
||||
[
|
||||
0,
|
||||
0,
|
||||
0
|
||||
],
|
||||
[
|
||||
255,
|
||||
91,
|
||||
0
|
||||
],
|
||||
[
|
||||
255,
|
||||
234,
|
||||
0
|
||||
],
|
||||
[
|
||||
85,
|
||||
111,
|
||||
181
|
||||
],
|
||||
[
|
||||
181,
|
||||
227,
|
||||
14
|
||||
],
|
||||
[
|
||||
72,
|
||||
0,
|
||||
255
|
||||
],
|
||||
[
|
||||
0,
|
||||
155,
|
||||
33
|
||||
],
|
||||
[
|
||||
255,
|
||||
0,
|
||||
255
|
||||
]
|
||||
],
|
||||
"palette_num": 8,
|
||||
"mean": [
|
||||
122.21429912990676,
|
||||
77.0821859677977,
|
||||
87.03836664626716
|
||||
],
|
||||
"std": [
|
||||
50.53335800365262,
|
||||
42.895340354037465,
|
||||
47.739426483390446
|
||||
]
|
||||
},
|
||||
"publicdataset_dresden": {
|
||||
"train_imgs_num": 17363,
|
||||
"classes": [
|
||||
"背景",
|
||||
"1",
|
||||
"2",
|
||||
"3",
|
||||
"4",
|
||||
"5",
|
||||
"6",
|
||||
"7",
|
||||
"8",
|
||||
"9",
|
||||
"10"
|
||||
],
|
||||
"palette": [
|
||||
[
|
||||
0,
|
||||
0,
|
||||
0
|
||||
],
|
||||
[
|
||||
255,
|
||||
91,
|
||||
0
|
||||
],
|
||||
[
|
||||
255,
|
||||
234,
|
||||
0
|
||||
],
|
||||
[
|
||||
85,
|
||||
111,
|
||||
181
|
||||
],
|
||||
[
|
||||
181,
|
||||
227,
|
||||
14
|
||||
],
|
||||
[
|
||||
72,
|
||||
0,
|
||||
255
|
||||
],
|
||||
[
|
||||
0,
|
||||
155,
|
||||
33
|
||||
],
|
||||
[
|
||||
255,
|
||||
0,
|
||||
255
|
||||
],
|
||||
[
|
||||
29,
|
||||
32,
|
||||
136
|
||||
],
|
||||
[
|
||||
160,
|
||||
15,
|
||||
95
|
||||
],
|
||||
[
|
||||
0,
|
||||
160,
|
||||
233
|
||||
]
|
||||
],
|
||||
"palette_num": 11,
|
||||
"mean": [
|
||||
103.172638338208,
|
||||
61.44762740851152,
|
||||
51.407770213021976
|
||||
],
|
||||
"std": [
|
||||
75.77031253622098,
|
||||
54.63616729031377,
|
||||
49.45572239497569
|
||||
]
|
||||
},
|
||||
"publicdataset_endovis_2018": {
|
||||
"train_imgs_num": 1800,
|
||||
"classes": [
|
||||
"背景",
|
||||
"1",
|
||||
"2",
|
||||
"3",
|
||||
"4",
|
||||
"5",
|
||||
"6",
|
||||
"7"
|
||||
],
|
||||
"palette": [
|
||||
[
|
||||
0,
|
||||
0,
|
||||
0
|
||||
],
|
||||
[
|
||||
255,
|
||||
91,
|
||||
0
|
||||
],
|
||||
[
|
||||
255,
|
||||
234,
|
||||
0
|
||||
],
|
||||
[
|
||||
85,
|
||||
111,
|
||||
181
|
||||
],
|
||||
[
|
||||
181,
|
||||
227,
|
||||
14
|
||||
],
|
||||
[
|
||||
72,
|
||||
0,
|
||||
255
|
||||
],
|
||||
[
|
||||
0,
|
||||
155,
|
||||
33
|
||||
],
|
||||
[
|
||||
255,
|
||||
0,
|
||||
255
|
||||
]
|
||||
],
|
||||
"palette_num": 8,
|
||||
"mean": [
|
||||
122.21429912990676,
|
||||
77.0821859677977,
|
||||
87.03836664626716
|
||||
],
|
||||
"std": [
|
||||
50.53335800365262,
|
||||
42.895340354037465,
|
||||
47.739426483390446
|
||||
]
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,50 @@
|
||||
{
|
||||
"____一、comment_dataset_info": "定义多个数据集文件名和对应类名",
|
||||
"dataset_info": {
|
||||
"dataset_file_name": "my_dataset_model",
|
||||
"dataset_class_name": "MyDataset_model",
|
||||
"data_root": "/home/wkmgc/Desktop/Seg/Seg_All_In_One_MMSeg/My_Data",
|
||||
"img_scale_width": 1920,
|
||||
"img_scale_height": 1080,
|
||||
"____#####comment_paths#####": "训练、验证、测试集所在文件夹",
|
||||
"paths": {
|
||||
"train_img_path": "A_Ori",
|
||||
"train_seg_map_path": "A_Label_GT_label_fold",
|
||||
"val_img_path": "A_Ori",
|
||||
"val_seg_map_path": "A_Label_GT_label_fold",
|
||||
"test_img_path": "A_Ori",
|
||||
"test_seg_map_path": "A_Label_GT_label_fold"
|
||||
}
|
||||
},
|
||||
"____二、comment_label_info": "定义Label图片相关参数",
|
||||
"label_info": {
|
||||
"classes": [
|
||||
"背景", "肝脏", "胆囊", "分离钳", "止血海绵", "肝总管", "胆总管", "吸引器", "剪刀", "止血纱布", "生物夹", "无损伤钳", "喷洒",
|
||||
"胆囊管", "胆囊动脉", "电凝", "标本袋", "引流管", "纱布", "金属钛夹", "术中超声", "吻合器", "乳胶管", "推结器",
|
||||
"肝带", "钳夹", "超声刀", "脂肪", "双极电凝", "棉球", "血管阻断夹", "肿瘤", "针", "线", "韧带", "胆囊静脉"
|
||||
],
|
||||
"palette": [
|
||||
[0, 0, 0], [255, 91, 0], [255, 234, 0], [85, 111, 181], [181, 227, 14], [72, 0, 255], [0, 155, 33], [255, 0, 255],
|
||||
[29, 32, 136], [160, 15, 95], [0, 160, 233], [52, 184, 178], [90, 120, 41], [255, 0, 0], [177, 0, 0],
|
||||
[167, 24, 233], [112, 113, 150], [0, 255, 0], [255, 255, 255], [0, 255, 255], [138, 251, 213], [136, 162, 196],
|
||||
[197, 83, 181], [202, 202, 200], [113, 102, 140], [66, 115, 82], [240, 16, 116], [155, 132, 0], [155, 62, 0],
|
||||
[146, 175, 236], [255, 172, 159], [245, 161, 0], [134, 124, 118], [0, 157, 142], [181, 85, 105], [42, 8, 66]
|
||||
],
|
||||
"____#####comment#####": "一般不太会变的参数",
|
||||
"img_suffix": ".png",
|
||||
"seg_map_suffix": "_gtFine_labelTrainIds.png",
|
||||
"____#####comment_reduce_zero_label_1#####": "在第 0 类为无意义黑边时,使用reduce_zero_label = True将其和待分类内容分开;",
|
||||
"____#####comment_reduce_zero_label_2#####": "在第 0 类为 background 类别的数据集上,如果您最终是需要将背景和您的其余类别分开时,是不需要使用reduce_zero_label的 【reduce_zero_label = False】",
|
||||
"reduce_zero_label": false
|
||||
},
|
||||
"____三、comment_training_info": "定义训练相关参数",
|
||||
"training_info": {
|
||||
"crop_size_width": 256,
|
||||
"crop_size_height": 256,
|
||||
"train_batch_size": 16,
|
||||
"train_num_workers": 4,
|
||||
"val_and_test_batch_size": 1,
|
||||
"val_and_test_num_workers": 4
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,42 @@
|
||||
{
|
||||
"____一、comment_dataset_info": "定义多个数据集文件名和对应类名",
|
||||
"dataset_info": {
|
||||
"dataset_file_name": "publicdataset_autolaparo",
|
||||
"dataset_class_name": "PublicDataSet_AutoLaparo",
|
||||
"data_root": "/home/wkmgc/Desktop/Seg/DataSet_Public/2_AutoLaparo-10Type-1920x1080",
|
||||
"img_scale_width": 1920,
|
||||
"img_scale_height": 1080,
|
||||
"____#####comment_paths#####": "训练、验证、测试集所在文件夹",
|
||||
"paths": {
|
||||
"train_img_path": "images/train",
|
||||
"train_seg_map_path": "labels_GT/train",
|
||||
"val_img_path": "images/val",
|
||||
"val_seg_map_path": "labels_GT/val",
|
||||
"test_img_path": "images/val",
|
||||
"test_seg_map_path": "labels_GT/val"
|
||||
}
|
||||
},
|
||||
"____二、comment_label_info": "定义Label图片相关参数",
|
||||
"label_info": {
|
||||
"classes": ["背景", "1", "2", "3", "4", "5", "6", "7", "8", "9"],
|
||||
"palette": [
|
||||
[0, 0, 0], [255, 91, 0], [255, 234, 0], [85, 111, 181], [181, 227, 14], [72, 0, 255], [0, 155, 33], [255, 0, 255], [29, 32, 136], [160, 15, 95]
|
||||
],
|
||||
"____#####comment#####": "一般不太会变的参数",
|
||||
"img_suffix": ".png",
|
||||
"seg_map_suffix": ".png",
|
||||
"____#####comment_reduce_zero_label_1#####": "在第 0 类为无意义黑边时,使用reduce_zero_label = True将其和待分类内容分开;",
|
||||
"____#####comment_reduce_zero_label_2#####": "在第 0 类为 background 类别的数据集上,如果您最终是需要将背景和您的其余类别分开时,是不需要使用reduce_zero_label的 【reduce_zero_label = False】",
|
||||
"reduce_zero_label": false
|
||||
},
|
||||
"____三、comment_training_info": "定义训练相关参数",
|
||||
"training_info": {
|
||||
"crop_size_width": 256,
|
||||
"crop_size_height": 256,
|
||||
"train_batch_size": 16,
|
||||
"train_num_workers": 4,
|
||||
"val_and_test_batch_size": 1,
|
||||
"val_and_test_num_workers": 4
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,42 @@
|
||||
{
|
||||
"____一、comment_dataset_info": "定义多个数据集文件名和对应类名",
|
||||
"dataset_info": {
|
||||
"dataset_file_name": "publicdataset_cholecseg8k",
|
||||
"dataset_class_name": "PublicDataSet_CholecSeg8k",
|
||||
"data_root": "/home/wkmgc/Desktop/Seg/DataSet_Public/1_CholecSeg8k-13Type-1920x1080",
|
||||
"img_scale_width": 1920,
|
||||
"img_scale_height": 1080,
|
||||
"____#####comment_paths#####": "训练、验证、测试集所在文件夹",
|
||||
"paths": {
|
||||
"train_img_path": "images/train",
|
||||
"train_seg_map_path": "labels_GT/train",
|
||||
"val_img_path": "images/val",
|
||||
"val_seg_map_path": "labels_GT/val",
|
||||
"test_img_path": "images/val",
|
||||
"test_seg_map_path": "labels_GT/val"
|
||||
}
|
||||
},
|
||||
"____二、comment_label_info": "定义Label图片相关参数",
|
||||
"label_info": {
|
||||
"classes": ["背景", "1", "2", "3", "4", "5", "6", "7", "8", "9", "10", "11", "12"],
|
||||
"palette": [
|
||||
[0, 0, 0], [255, 91, 0], [255, 234, 0], [85, 111, 181], [181, 227, 14], [72, 0, 255], [0, 155, 33], [255, 0, 255], [29, 32, 136], [160, 15, 95], [0, 160, 233], [52, 184, 178], [90, 120, 41]
|
||||
],
|
||||
"____#####comment#####": "一般不太会变的参数",
|
||||
"img_suffix": ".png",
|
||||
"seg_map_suffix": ".png",
|
||||
"____#####comment_reduce_zero_label_1#####": "在第 0 类为无意义黑边时,使用reduce_zero_label = True将其和待分类内容分开;",
|
||||
"____#####comment_reduce_zero_label_2#####": "在第 0 类为 background 类别的数据集上,如果您最终是需要将背景和您的其余类别分开时,是不需要使用reduce_zero_label的 【reduce_zero_label = False】",
|
||||
"reduce_zero_label": false
|
||||
},
|
||||
"____三、comment_training_info": "定义训练相关参数",
|
||||
"training_info": {
|
||||
"crop_size_width": 256,
|
||||
"crop_size_height": 256,
|
||||
"train_batch_size": 16,
|
||||
"train_num_workers": 4,
|
||||
"val_and_test_batch_size": 1,
|
||||
"val_and_test_num_workers": 4
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,42 @@
|
||||
{
|
||||
"____一、comment_dataset_info": "定义多个数据集文件名和对应类名",
|
||||
"dataset_info": {
|
||||
"dataset_file_name": "publicdataset_dresden",
|
||||
"dataset_class_name": "PublicDataSet_Dresden",
|
||||
"data_root": "/home/wkmgc/Desktop/Seg/DataSet_Public/4_Dresden-11Type-512x512",
|
||||
"img_scale_width": 512,
|
||||
"img_scale_height": 512,
|
||||
"____#####comment_paths#####": "训练、验证、测试集所在文件夹",
|
||||
"paths": {
|
||||
"train_img_path": "images/train",
|
||||
"train_seg_map_path": "labels_GT/train",
|
||||
"val_img_path": "images/val",
|
||||
"val_seg_map_path": "labels_GT/val",
|
||||
"test_img_path": "images/test",
|
||||
"test_seg_map_path": "labels_GT/test"
|
||||
}
|
||||
},
|
||||
"____二、comment_label_info": "定义Label图片相关参数",
|
||||
"label_info": {
|
||||
"classes": ["背景", "1", "2", "3", "4", "5", "6", "7", "8", "9", "10"],
|
||||
"palette": [
|
||||
[0, 0, 0], [255, 91, 0], [255, 234, 0], [85, 111, 181], [181, 227, 14], [72, 0, 255], [0, 155, 33], [255, 0, 255], [29, 32, 136], [160, 15, 95], [0, 160, 233]
|
||||
],
|
||||
"____#####comment#####": "一般不太会变的参数",
|
||||
"img_suffix": ".png",
|
||||
"seg_map_suffix": ".png",
|
||||
"____#####comment_reduce_zero_label_1#####": "在第 0 类为无意义黑边时,使用reduce_zero_label = True将其和待分类内容分开;",
|
||||
"____#####comment_reduce_zero_label_2#####": "在第 0 类为 background 类别的数据集上,如果您最终是需要将背景和您的其余类别分开时,是不需要使用reduce_zero_label的 【reduce_zero_label = False】",
|
||||
"reduce_zero_label": false
|
||||
},
|
||||
"____三、comment_training_info": "定义训练相关参数",
|
||||
"training_info": {
|
||||
"crop_size_width": 256,
|
||||
"crop_size_height": 256,
|
||||
"train_batch_size": 16,
|
||||
"train_num_workers": 4,
|
||||
"val_and_test_batch_size": 1,
|
||||
"val_and_test_num_workers": 4
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,42 @@
|
||||
{
|
||||
"____一、comment_dataset_info": "定义多个数据集文件名和对应类名",
|
||||
"dataset_info": {
|
||||
"dataset_file_name": "publicdataset_endovis_2017",
|
||||
"dataset_class_name": "PublicDataSet_Endovis_2017",
|
||||
"data_root": "/home/wkmgc/Desktop/Seg/DataSet_Public/3_1_Endovis_2017-8Type-512x512",
|
||||
"img_scale_width": 512,
|
||||
"img_scale_height": 512,
|
||||
"____#####comment_paths#####": "训练、验证、测试集所在文件夹",
|
||||
"paths": {
|
||||
"train_img_path": "images/train",
|
||||
"train_seg_map_path": "labels_GT/train",
|
||||
"val_img_path": "images/val",
|
||||
"val_seg_map_path": "labels_GT/val",
|
||||
"test_img_path": "images/val",
|
||||
"test_seg_map_path": "labels_GT/val"
|
||||
}
|
||||
},
|
||||
"____二、comment_label_info": "定义Label图片相关参数",
|
||||
"label_info": {
|
||||
"classes": ["背景", "1", "2", "3", "4", "5", "6", "7"],
|
||||
"palette": [
|
||||
[0, 0, 0], [255, 91, 0], [255, 234, 0], [85, 111, 181], [181, 227, 14], [72, 0, 255], [0, 155, 33], [255, 0, 255]
|
||||
],
|
||||
"____#####comment#####": "一般不太会变的参数",
|
||||
"img_suffix": ".bmp",
|
||||
"seg_map_suffix": ".bmp",
|
||||
"____#####comment_reduce_zero_label_1#####": "在第 0 类为无意义黑边时,使用reduce_zero_label = True将其和待分类内容分开;",
|
||||
"____#####comment_reduce_zero_label_2#####": "在第 0 类为 background 类别的数据集上,如果您最终是需要将背景和您的其余类别分开时,是不需要使用reduce_zero_label的 【reduce_zero_label = False】",
|
||||
"reduce_zero_label": false
|
||||
},
|
||||
"____三、comment_training_info": "定义训练相关参数",
|
||||
"training_info": {
|
||||
"crop_size_width": 256,
|
||||
"crop_size_height": 256,
|
||||
"train_batch_size": 16,
|
||||
"train_num_workers": 4,
|
||||
"val_and_test_batch_size": 1,
|
||||
"val_and_test_num_workers": 4
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,42 @@
|
||||
{
|
||||
"____一、comment_dataset_info": "定义多个数据集文件名和对应类名",
|
||||
"dataset_info": {
|
||||
"dataset_file_name": "publicdataset_endovis_2018",
|
||||
"dataset_class_name": "PublicDataSet_Endovis_2018",
|
||||
"data_root": "/home/wkmgc/Desktop/Seg/DataSet_Public/3_2_Endovis_2018-8Type-512x512",
|
||||
"img_scale_width": 512,
|
||||
"img_scale_height": 512,
|
||||
"____#####comment_paths#####": "训练、验证、测试集所在文件夹",
|
||||
"paths": {
|
||||
"train_img_path": "images/train",
|
||||
"train_seg_map_path": "labels_GT/train",
|
||||
"val_img_path": "images/val",
|
||||
"val_seg_map_path": "labels_GT/val",
|
||||
"test_img_path": "images/val",
|
||||
"test_seg_map_path": "labels_GT/val"
|
||||
}
|
||||
},
|
||||
"____二、comment_label_info": "定义Label图片相关参数",
|
||||
"label_info": {
|
||||
"classes": ["背景", "1", "2", "3", "4", "5", "6", "7"],
|
||||
"palette": [
|
||||
[0, 0, 0], [255, 91, 0], [255, 234, 0], [85, 111, 181], [181, 227, 14], [72, 0, 255], [0, 155, 33], [255, 0, 255]
|
||||
],
|
||||
"____#####comment#####": "一般不太会变的参数",
|
||||
"img_suffix": ".bmp",
|
||||
"seg_map_suffix": ".bmp",
|
||||
"____#####comment_reduce_zero_label_1#####": "在第 0 类为无意义黑边时,使用reduce_zero_label = True将其和待分类内容分开;",
|
||||
"____#####comment_reduce_zero_label_2#####": "在第 0 类为 background 类别的数据集上,如果您最终是需要将背景和您的其余类别分开时,是不需要使用reduce_zero_label的 【reduce_zero_label = False】",
|
||||
"reduce_zero_label": false
|
||||
},
|
||||
"____三、comment_training_info": "定义训练相关参数",
|
||||
"training_info": {
|
||||
"crop_size_width": 256,
|
||||
"crop_size_height": 256,
|
||||
"train_batch_size": 16,
|
||||
"train_num_workers": 4,
|
||||
"val_and_test_batch_size": 1,
|
||||
"val_and_test_num_workers": 4
|
||||
}
|
||||
}
|
||||
|
||||
61
Seg_All_In_One_MMSeg/My_All_In_One/1_Initial_Data_All-ori.py
Normal file
61
Seg_All_In_One_MMSeg/My_All_In_One/1_Initial_Data_All-ori.py
Normal file
@@ -0,0 +1,61 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: UTF-8 -*-
|
||||
import os
|
||||
from Initial_Data_Program.Initial_Data_Gen_configs_base_datasets_my_dataset import generate_configs_base_datasets_my_dataset_file
|
||||
from Initial_Data_Program.Initial_Data_Gen_mmseg_utils_class_names import generate_mmseg_utils_class_names_file
|
||||
from Initial_Data_Program.Initial_Data_Gen_mmseg_datasets_my_dataset import generate_mmseg_datasets_my_dataset_file
|
||||
from Initial_Data_Program.Initial_Data_Gen_mmseg_datasets_init_ import generate_mmseg_datasets_init_file
|
||||
|
||||
if __name__ == '__main__':
|
||||
########### 1.1.定义各数据集相关参数 ###########
|
||||
# 可以定义多个数据集文件名 和 对应类名
|
||||
dataset_file_names = ["my_dataset_model"] # =['my_dataset', 'my_dataset_2'] # =["my_dataset"]
|
||||
dataset_class_names = ["MyDataset_model"] # =['MyDataset', 'MyDataset2'] # =["MyDataset"]
|
||||
dataset_file_name='my_dataset_model' # 数据集 文件名.py TODO
|
||||
dataset_class_name='MyDataset_model' # 数据集 类名称 TODO
|
||||
data_root='/home/audience/Desktop/Seg_data/Data' # 数据根目录
|
||||
img_scale=(1920, 1080) # 图片大小
|
||||
# 训练、验证、测试集所在文件夹
|
||||
train_img_path='A_Ori'
|
||||
train_seg_map_path='A_Label_GT_label_fold'
|
||||
val_img_path='A_Ori'
|
||||
val_seg_map_path='A_Label_GT_label_fold'
|
||||
test_img_path='A_Ori'
|
||||
test_seg_map_path='A_Label_GT_label_fold'
|
||||
|
||||
########### 1.2.定义Label图片相关参数 ###########
|
||||
classes = ['肝脏','胆囊','分离钳','止血海绵','肝总管','胆总管','吸引器','剪刀','止血纱布','生物夹','无损伤钳','喷洒','胆囊管','胆囊动脉','电凝','标本袋','引流管','纱布','金属钛夹','术中超声','吻合器','乳胶管','推结器','肝带','钳夹','超声刀','脂肪','双极电凝','棉球','血管阻断夹','肿瘤','针','线','韧带','胆囊静脉']
|
||||
palette = [[255,91,0],[255,234,0],[85, 111, 181],[181, 227, 14],[72, 0, 255],[0, 155, 33],[255,0,255],[29, 32, 136],[160, 15, 95],[0,160,233],[52,184,178],[90,120,41],[255,0,0],[177,0,0],[167,24,233],[112,113,150],[0,255,0],[255,255,255],[0,255,255],[138,251,213],[136,162,196],[197,83,181],[202,202,200],[113,102,140],[66,115,82],[240,16,116],[155,132,0],[155,62,0],[146,175,236],[255,172,159],[245,161,0],[134,124,118],[0,157,142],[181,85,105],[42,8,66]]
|
||||
# 这里的classes一定是经过“Initial_Gen_mmseg_datasets_my_dataset.py”处理的
|
||||
classes_all = [
|
||||
['背景','肝脏','胆囊','分离钳','止血海绵','肝总管','胆总管','吸引器','剪刀','止血纱布','生物夹','无损伤钳','喷洒','胆囊管','胆囊动脉','电凝','标本袋','引流管','纱布','金属钛夹','术中超声','吻合器','乳胶管','推结器','肝带','钳夹','超声刀','脂肪','双极电凝','棉球','血管阻断夹','肿瘤','针','线','韧带','胆囊静脉'],
|
||||
]
|
||||
palette_all = [
|
||||
[[0,0,0],[255,91,0],[255,234,0],[85, 111, 181],[181, 227, 14],[72, 0, 255],[0, 155, 33],[255,0,255],[29, 32, 136],[160, 15, 95],[0,160,233],[52,184,178],[90,120,41],[255,0,0],[177,0,0],[167,24,233],[112,113,150],[0,255,0],[255,255,255],[0,255,255],[138,251,213],[136,162,196],[197,83,181],[202,202,200],[113,102,140],[66,115,82],[240,16,116],[155,132,0],[155,62,0],[146,175,236],[255,172,159],[245,161,0],[134,124,118],[0,157,142],[181,85,105],[42,8,66]],
|
||||
]
|
||||
# 一般不太会变的参数
|
||||
img_suffix = ".png"
|
||||
seg_map_suffix = "_gtFine_labelTrainIds.png"
|
||||
reduce_zero_label = False # 在第 0 类为 background 类别的数据集上,如果您最终是需要将背景和您的其余类别分开时,是不需要使用reduce_zero_label的 【reduce_zero_label = False】
|
||||
|
||||
########### 1.3.定义训练相关参数 ###########
|
||||
# 一般不太会变的参数
|
||||
crop_size=(512, 512) # 分割大小
|
||||
train_batch_size=4 # 训练batch
|
||||
train_num_workers=4 # 训练并行运行数量
|
||||
val_and_test_batch_size=1 # 验证集和测试集batch
|
||||
val_and_test_num_workers=4 # 验证集和测试集并行运行数量
|
||||
|
||||
########### 2.文件存储位置 ###########
|
||||
output_configs_base_datasets_my_dataset=f'./configs/_base_/datasets/{dataset_file_name}.py'
|
||||
output_mmseg_datasets_dataset_file_name = os.path.join(f'./mmseg/datasets/{dataset_file_name}.py')
|
||||
output_mmseg_datasets_init = os.path.join('./mmseg/datasets/__init__.py')
|
||||
output_mmseg_utils_class_names = f'./mmseg/utils/class_names.py'
|
||||
|
||||
########### 3.运行程序生成配置文件 ###########
|
||||
success = generate_configs_base_datasets_my_dataset_file(output_file=output_configs_base_datasets_my_dataset, dataset_class_name=dataset_class_name , data_root=data_root, img_scale=img_scale, crop_size=crop_size, train_batch_size=train_batch_size, train_num_workers=train_num_workers, val_and_test_batch_size=val_and_test_batch_size, val_and_test_num_workers=val_and_test_num_workers, train_img_path=train_img_path, train_seg_map_path=train_seg_map_path, val_img_path=val_img_path, val_seg_map_path=val_seg_map_path, test_img_path=test_img_path, test_seg_map_path=test_seg_map_path)
|
||||
success, classes, palette = generate_mmseg_datasets_my_dataset_file(output_file=output_mmseg_datasets_dataset_file_name, dataset_class_name=dataset_class_name, classes=classes, palette=palette, img_suffix=img_suffix, seg_map_suffix=seg_map_suffix, reduce_zero_label=reduce_zero_label)
|
||||
# 需要用到上一步的classes和palette
|
||||
success = generate_mmseg_datasets_init_file(output_file=output_mmseg_datasets_init, dataset_file_names=dataset_file_names, dataset_class_names=dataset_class_names)
|
||||
success = generate_mmseg_utils_class_names_file(output_file=output_mmseg_utils_class_names, dataset_file_names=dataset_file_names, classes_all=classes_all, palette_all=palette_all)
|
||||
|
||||
@@ -0,0 +1,209 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: UTF-8 -*-
|
||||
import os, json
|
||||
from Initial_Data_Program.Initial_Data_Calculate_std_and_mean import calculate_pic_std_and_mean
|
||||
from Initial_Data_Program.Initial_Data_Gen_configs_base_datasets_my_dataset import generate_configs_base_datasets_my_dataset_file
|
||||
from Initial_Data_Program.Initial_Data_Gen_mmseg_utils_class_names import generate_mmseg_utils_class_names_file
|
||||
from Initial_Data_Program.Initial_Data_Gen_mmseg_datasets_my_dataset import generate_mmseg_datasets_my_dataset_file
|
||||
from Initial_Data_Program.Initial_Data_Gen_mmseg_datasets_init_ import generate_mmseg_datasets_init_file
|
||||
|
||||
def load_json_files(directory, not_check_list=None):
|
||||
"""
|
||||
读取指定文件夹下的所有 JSON 文件,排除指定的文件。
|
||||
|
||||
:param directory: 要读取的目录
|
||||
:param not_check_list: 要排除的文件名列表,不包含路径。默认为空列表
|
||||
:return: 包含所有有效 JSON 数据的列表,每个元素为一个字典
|
||||
"""
|
||||
if not_check_list is None:
|
||||
not_check_list = [''] # 默认排除文件
|
||||
|
||||
# 获取所有 .json 文件,并排除在 not_check_list 中的文件
|
||||
json_files = [f for f in os.listdir(directory) if f.endswith('.json') and f not in not_check_list]
|
||||
|
||||
data_list = []
|
||||
|
||||
# 遍历每个 JSON 文件
|
||||
for json_file in json_files:
|
||||
json_path = os.path.join(directory, json_file)
|
||||
|
||||
try:
|
||||
# 打开并加载 JSON 文件
|
||||
with open(json_path, 'r', encoding='utf-8') as f:
|
||||
data = json.load(f)
|
||||
# 将文件名(去掉 .json)添加到数据中
|
||||
data['file_name_json'] = json_file.rstrip(".json")
|
||||
data_list.append(data)
|
||||
|
||||
except json.JSONDecodeError:
|
||||
print(f"\033[91mError decoding JSON file: {json_path}\033[0m")
|
||||
except Exception as e:
|
||||
print(f"\033[91mError reading file {json_path}: {str(e)}\033[0m")
|
||||
|
||||
return data_list
|
||||
|
||||
|
||||
def process_json_data(json_data):
|
||||
"""处理每个 JSON 数据,生成相应参数"""
|
||||
# 提取 dataset 信息
|
||||
dataset_file_name = json_data["dataset_info"]["dataset_file_name"]
|
||||
dataset_class_name = json_data["dataset_info"]["dataset_class_name"]
|
||||
data_root = json_data["dataset_info"]["data_root"]
|
||||
|
||||
# 转换 img_scale 为元组 (img_scale_width, img_scale_height)
|
||||
img_scale = (json_data["dataset_info"]["img_scale_width"], json_data["dataset_info"]["img_scale_height"])
|
||||
|
||||
# 提取其他必要信息
|
||||
train_img_path = json_data["dataset_info"]["paths"]["train_img_path"]
|
||||
train_seg_map_path = json_data["dataset_info"]["paths"]["train_seg_map_path"]
|
||||
val_img_path = json_data["dataset_info"]["paths"]["val_img_path"]
|
||||
val_seg_map_path = json_data["dataset_info"]["paths"]["val_seg_map_path"]
|
||||
test_img_path = json_data["dataset_info"]["paths"]["test_img_path"]
|
||||
test_seg_map_path = json_data["dataset_info"]["paths"]["test_seg_map_path"]
|
||||
|
||||
# 提取 label 相关信息
|
||||
classes = json_data["label_info"]["classes"]
|
||||
palette = json_data["label_info"]["palette"]
|
||||
img_suffix = json_data["label_info"]["img_suffix"]
|
||||
seg_map_suffix = json_data["label_info"]["seg_map_suffix"]
|
||||
reduce_zero_label = json_data["label_info"]["reduce_zero_label"]
|
||||
|
||||
# 提取训练相关参数
|
||||
# 转换 crop_size 为元组 (crop_size_width, crop_size_height)
|
||||
crop_size = (json_data["training_info"]["crop_size_width"], json_data["training_info"]["crop_size_height"])
|
||||
train_batch_size = json_data["training_info"]["train_batch_size"]
|
||||
train_num_workers = json_data["training_info"]["train_num_workers"]
|
||||
val_and_test_batch_size = json_data["training_info"]["val_and_test_batch_size"]
|
||||
val_and_test_num_workers = json_data["training_info"]["val_and_test_num_workers"]
|
||||
|
||||
return (dataset_file_name, dataset_class_name, data_root, img_scale, train_img_path, train_seg_map_path, val_img_path, val_seg_map_path, test_img_path, test_seg_map_path,
|
||||
classes, palette, img_suffix, seg_map_suffix, reduce_zero_label, crop_size, train_batch_size, train_num_workers, val_and_test_batch_size, val_and_test_num_workers,)
|
||||
|
||||
|
||||
def save_all_record_to_json(output_file, dataset_file_names, classes_all, palette_all, palette_num_all, mean_all, std_all):
|
||||
"""构建一个 JSON 文件,用于存储每个数据集的信息"""
|
||||
# 构建一个字典,用于存储每个数据集的信息
|
||||
data_record = {}
|
||||
|
||||
# 假设所有列表长度相同,遍历每个 dataset_file_name
|
||||
for i in range(len(dataset_file_names)):
|
||||
data_record[dataset_file_names[i]] = {
|
||||
"classes": classes_all[i],
|
||||
"palette": palette_all[i],
|
||||
"palette_num": palette_num_all[i],
|
||||
"mean": list(mean_all[i]),
|
||||
"std": list(std_all[i])
|
||||
}
|
||||
|
||||
# 将字典写入 JSON 文件
|
||||
with open(output_file, 'w', encoding='utf-8') as json_file:
|
||||
json.dump(data_record, json_file, ensure_ascii=False, indent=6)
|
||||
|
||||
print(f"\033[93m相关数据汇总到 {output_file} successfully!\033[0m")
|
||||
return True
|
||||
|
||||
if __name__ == '__main__':
|
||||
########### 1.超参数-定义训练文件夹路径 ###########
|
||||
train_parameter_dir = './My_All_In_One/1_Data_Parameter'
|
||||
all_data_record_json = "All_Data_Record.json" # 记录所有数据
|
||||
|
||||
########### 2.1.遍历所有配置文件,并生成对应数据集 ###########
|
||||
# 定义保存信息的列表
|
||||
dataset_file_names = []
|
||||
dataset_class_names = []
|
||||
classes_all = []
|
||||
palette_all = []
|
||||
palette_num_all = []
|
||||
mean_all = []
|
||||
std_all = []
|
||||
|
||||
# 2.1.1. 从 ./1_Data_Parameter 文件夹读取所有 JSON 文件
|
||||
json_data_list = load_json_files(train_parameter_dir, not_check_list = [all_data_record_json])
|
||||
|
||||
# 2.1.2. 遍历每个 JSON 数据文件并生成对应配置文件
|
||||
for json_data in json_data_list:
|
||||
# A. 输出当前处理文件
|
||||
print(f"\033[32m正在处理{json_data['file_name_json']}.json文件\033[0m")
|
||||
|
||||
# B. 处理 JSON 数据并提取参数
|
||||
(dataset_file_name, dataset_class_name, data_root, img_scale, train_img_path, train_seg_map_path, val_img_path, val_seg_map_path, test_img_path, test_seg_map_path,
|
||||
classes, palette, img_suffix, seg_map_suffix, reduce_zero_label, crop_size, train_batch_size, train_num_workers, val_and_test_batch_size, val_and_test_num_workers,
|
||||
) = process_json_data(json_data)
|
||||
|
||||
# 保存文件名和类别名
|
||||
dataset_file_names.append(dataset_file_name)
|
||||
dataset_class_names.append(dataset_class_name)
|
||||
|
||||
# C. 文件存储位置
|
||||
output_configs_base_datasets_my_dataset = f'./configs/_base_/datasets/{dataset_file_name}.py'
|
||||
output_mmseg_datasets_dataset_file_name = os.path.join(f'./mmseg/datasets/{dataset_file_name}.py')
|
||||
output_mmseg_datasets_init = os.path.join('./mmseg/datasets/__init__.py')
|
||||
output_mmseg_utils_class_names = f'./mmseg/utils/class_names.py'
|
||||
|
||||
# D. 运行程序生成配置文件
|
||||
# 生成 ./configs/_base_/datasets/{dataset_file_name}.py
|
||||
print(" ",end='')
|
||||
success = generate_configs_base_datasets_my_dataset_file(
|
||||
output_file=output_configs_base_datasets_my_dataset,
|
||||
dataset_class_name=dataset_class_name,
|
||||
data_root=data_root,
|
||||
img_scale=img_scale,
|
||||
crop_size=crop_size,
|
||||
train_batch_size=train_batch_size,
|
||||
train_num_workers=train_num_workers,
|
||||
val_and_test_batch_size=val_and_test_batch_size,
|
||||
val_and_test_num_workers=val_and_test_num_workers,
|
||||
train_img_path=train_img_path,
|
||||
train_seg_map_path=train_seg_map_path,
|
||||
val_img_path=val_img_path,
|
||||
val_seg_map_path=val_seg_map_path,
|
||||
test_img_path=test_img_path,
|
||||
test_seg_map_path=test_seg_map_path
|
||||
)
|
||||
|
||||
# 生成 ./mmseg/datasets/{dataset_file_name}.py
|
||||
print(" ",end='')
|
||||
success, classes, palette = generate_mmseg_datasets_my_dataset_file(
|
||||
output_file=output_mmseg_datasets_dataset_file_name,
|
||||
dataset_class_name=dataset_class_name,
|
||||
classes=classes,
|
||||
palette=palette,
|
||||
img_suffix=img_suffix,
|
||||
seg_map_suffix=seg_map_suffix,
|
||||
reduce_zero_label=reduce_zero_label
|
||||
)
|
||||
|
||||
mean, std = calculate_pic_std_and_mean(
|
||||
dataset_dir = os.path.join(data_root, train_img_path)
|
||||
)
|
||||
|
||||
# 保存标注类名和颜色
|
||||
classes_all.append(classes)
|
||||
palette_all.append(palette)
|
||||
palette_num_all.append(len(classes))
|
||||
mean_all.append(mean)
|
||||
std_all.append(std)
|
||||
|
||||
########### 2.2.汇总所有信息运行生成 init 和 class_names 文件 ###########
|
||||
# 生成 ./mmseg/datasets/__init__.py
|
||||
print(" ",end='')
|
||||
success = generate_mmseg_datasets_init_file(
|
||||
output_file=output_mmseg_datasets_init,
|
||||
dataset_file_names=dataset_file_names,
|
||||
dataset_class_names=dataset_class_names
|
||||
)
|
||||
|
||||
# 生成 ./mmseg/utils/class_names.py
|
||||
print(" ",end='')
|
||||
success = generate_mmseg_utils_class_names_file(
|
||||
output_file=output_mmseg_utils_class_names,
|
||||
dataset_file_names=dataset_file_names,
|
||||
classes_all=classes_all,
|
||||
palette_all=palette_all
|
||||
)
|
||||
|
||||
########### 2.2.汇总dataset_file_names、classes_all、palette_all、palette_num_all所有信息到My_All_In_One/1_Data_Parameter/All_Data_Record.json文件 ###########
|
||||
output_all_data_record = os.path.join(train_parameter_dir, all_data_record_json)
|
||||
# 调用函数保存数据
|
||||
success = save_all_record_to_json(output_file=output_all_data_record, dataset_file_names=dataset_file_names, classes_all=classes_all, palette_all=palette_all, palette_num_all=palette_num_all, mean_all=mean_all, std_all=std_all)
|
||||
|
||||
@@ -0,0 +1,212 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: UTF-8 -*-
|
||||
import os, json
|
||||
from Initial_Data_Program.Initial_Data_Calculate_std_and_mean import calculate_pic_std_and_mean
|
||||
from Initial_Data_Program.Initial_Data_Gen_configs_base_datasets_my_dataset import generate_configs_base_datasets_my_dataset_file
|
||||
from Initial_Data_Program.Initial_Data_Gen_mmseg_utils_class_names import generate_mmseg_utils_class_names_file
|
||||
from Initial_Data_Program.Initial_Data_Gen_mmseg_datasets_my_dataset import generate_mmseg_datasets_my_dataset_file
|
||||
from Initial_Data_Program.Initial_Data_Gen_mmseg_datasets_init_ import generate_mmseg_datasets_init_file
|
||||
|
||||
def load_json_files(directory, not_check_list=None):
|
||||
"""
|
||||
读取指定文件夹下的所有 JSON 文件,排除指定的文件。
|
||||
|
||||
:param directory: 要读取的目录
|
||||
:param not_check_list: 要排除的文件名列表,不包含路径。默认为空列表
|
||||
:return: 包含所有有效 JSON 数据的列表,每个元素为一个字典
|
||||
"""
|
||||
if not_check_list is None:
|
||||
not_check_list = [''] # 默认排除文件
|
||||
|
||||
# 获取所有 .json 文件,并排除在 not_check_list 中的文件
|
||||
json_files = [f for f in os.listdir(directory) if f.endswith('.json') and f not in not_check_list]
|
||||
|
||||
data_list = []
|
||||
|
||||
# 遍历每个 JSON 文件
|
||||
for json_file in json_files:
|
||||
json_path = os.path.join(directory, json_file)
|
||||
|
||||
try:
|
||||
# 打开并加载 JSON 文件
|
||||
with open(json_path, 'r', encoding='utf-8') as f:
|
||||
data = json.load(f)
|
||||
# 将文件名(去掉 .json)添加到数据中
|
||||
data['file_name_json'] = json_file.rstrip(".json")
|
||||
data_list.append(data)
|
||||
|
||||
except json.JSONDecodeError:
|
||||
print(f"\033[91mError decoding JSON file: {json_path}\033[0m")
|
||||
except Exception as e:
|
||||
print(f"\033[91mError reading file {json_path}: {str(e)}\033[0m")
|
||||
|
||||
return data_list
|
||||
|
||||
|
||||
def process_json_data(json_data):
|
||||
"""处理每个 JSON 数据,生成相应参数"""
|
||||
# 提取 dataset 信息
|
||||
dataset_file_name = json_data["dataset_info"]["dataset_file_name"]
|
||||
dataset_class_name = json_data["dataset_info"]["dataset_class_name"]
|
||||
data_root = json_data["dataset_info"]["data_root"]
|
||||
|
||||
# 转换 img_scale 为元组 (img_scale_width, img_scale_height)
|
||||
img_scale = (json_data["dataset_info"]["img_scale_width"], json_data["dataset_info"]["img_scale_height"])
|
||||
|
||||
# 提取其他必要信息
|
||||
train_img_path = json_data["dataset_info"]["paths"]["train_img_path"]
|
||||
train_seg_map_path = json_data["dataset_info"]["paths"]["train_seg_map_path"]
|
||||
val_img_path = json_data["dataset_info"]["paths"]["val_img_path"]
|
||||
val_seg_map_path = json_data["dataset_info"]["paths"]["val_seg_map_path"]
|
||||
test_img_path = json_data["dataset_info"]["paths"]["test_img_path"]
|
||||
test_seg_map_path = json_data["dataset_info"]["paths"]["test_seg_map_path"]
|
||||
|
||||
# 提取 label 相关信息
|
||||
classes = json_data["label_info"]["classes"]
|
||||
palette = json_data["label_info"]["palette"]
|
||||
img_suffix = json_data["label_info"]["img_suffix"]
|
||||
seg_map_suffix = json_data["label_info"]["seg_map_suffix"]
|
||||
reduce_zero_label = json_data["label_info"]["reduce_zero_label"]
|
||||
|
||||
# 提取训练相关参数
|
||||
# 转换 crop_size 为元组 (crop_size_width, crop_size_height)
|
||||
crop_size = (json_data["training_info"]["crop_size_width"], json_data["training_info"]["crop_size_height"])
|
||||
train_batch_size = json_data["training_info"]["train_batch_size"]
|
||||
train_num_workers = json_data["training_info"]["train_num_workers"]
|
||||
val_and_test_batch_size = json_data["training_info"]["val_and_test_batch_size"]
|
||||
val_and_test_num_workers = json_data["training_info"]["val_and_test_num_workers"]
|
||||
|
||||
return (dataset_file_name, dataset_class_name, data_root, img_scale, train_img_path, train_seg_map_path, val_img_path, val_seg_map_path, test_img_path, test_seg_map_path,
|
||||
classes, palette, img_suffix, seg_map_suffix, reduce_zero_label, crop_size, train_batch_size, train_num_workers, val_and_test_batch_size, val_and_test_num_workers,)
|
||||
|
||||
|
||||
def save_all_record_to_json(output_file, dataset_file_names, classes_all, palette_all, palette_num_all, mean_all, std_all, train_imgs_num):
|
||||
"""构建一个 JSON 文件,用于存储每个数据集的信息"""
|
||||
# 构建一个字典,用于存储每个数据集的信息
|
||||
data_record = {}
|
||||
|
||||
# 假设所有列表长度相同,遍历每个 dataset_file_name
|
||||
for i in range(len(dataset_file_names)):
|
||||
data_record[dataset_file_names[i]] = {
|
||||
"classes": classes_all[i],
|
||||
"palette": palette_all[i],
|
||||
"palette_num": palette_num_all[i],
|
||||
"mean": list(mean_all[i]),
|
||||
"std": list(std_all[i]),
|
||||
"train_imgs_num": train_imgs_num[i]
|
||||
}
|
||||
|
||||
# 将字典写入 JSON 文件
|
||||
with open(output_file, 'w', encoding='utf-8') as json_file:
|
||||
json.dump(data_record, json_file, ensure_ascii=False, indent=6)
|
||||
|
||||
print(f"\033[93m相关数据汇总到 {output_file} successfully!\033[0m")
|
||||
return True
|
||||
|
||||
if __name__ == '__main__':
|
||||
########### 1.超参数-定义训练文件夹路径 ###########
|
||||
train_parameter_dir = './My_All_In_One/1_Data_Parameter'
|
||||
all_data_record_json = "All_Data_Record.json" # 记录所有数据
|
||||
|
||||
########### 2.1.遍历所有配置文件,并生成对应数据集 ###########
|
||||
# 定义保存信息的列表
|
||||
dataset_file_names = []
|
||||
dataset_class_names = []
|
||||
classes_all = []
|
||||
palette_all = []
|
||||
palette_num_all = []
|
||||
mean_all = []
|
||||
std_all = []
|
||||
train_imgs_num = []
|
||||
|
||||
# 2.1.1. 从 ./1_Data_Parameter 文件夹读取所有 JSON 文件
|
||||
json_data_list = load_json_files(train_parameter_dir, not_check_list = [all_data_record_json])
|
||||
|
||||
# 2.1.2. 遍历每个 JSON 数据文件并生成对应配置文件
|
||||
for json_data in json_data_list:
|
||||
# A. 输出当前处理文件
|
||||
print(f"\033[32m正在处理{json_data['file_name_json']}.json文件\033[0m")
|
||||
|
||||
# B. 处理 JSON 数据并提取参数
|
||||
(dataset_file_name, dataset_class_name, data_root, img_scale, train_img_path, train_seg_map_path, val_img_path, val_seg_map_path, test_img_path, test_seg_map_path,
|
||||
classes, palette, img_suffix, seg_map_suffix, reduce_zero_label, crop_size, train_batch_size, train_num_workers, val_and_test_batch_size, val_and_test_num_workers,
|
||||
) = process_json_data(json_data)
|
||||
|
||||
# 保存文件名和类别名
|
||||
dataset_file_names.append(dataset_file_name)
|
||||
dataset_class_names.append(dataset_class_name)
|
||||
|
||||
# C. 文件存储位置
|
||||
output_configs_base_datasets_my_dataset = f'./configs/_base_/datasets/{dataset_file_name}.py'
|
||||
output_mmseg_datasets_dataset_file_name = os.path.join(f'./mmseg/datasets/{dataset_file_name}.py')
|
||||
output_mmseg_datasets_init = os.path.join('./mmseg/datasets/__init__.py')
|
||||
output_mmseg_utils_class_names = f'./mmseg/utils/class_names.py'
|
||||
|
||||
# D. 运行程序生成配置文件
|
||||
# 生成 ./configs/_base_/datasets/{dataset_file_name}.py
|
||||
print(" ",end='')
|
||||
success = generate_configs_base_datasets_my_dataset_file(
|
||||
output_file=output_configs_base_datasets_my_dataset,
|
||||
dataset_class_name=dataset_class_name,
|
||||
data_root=data_root,
|
||||
img_scale=img_scale,
|
||||
crop_size=crop_size,
|
||||
train_batch_size=train_batch_size,
|
||||
train_num_workers=train_num_workers,
|
||||
val_and_test_batch_size=val_and_test_batch_size,
|
||||
val_and_test_num_workers=val_and_test_num_workers,
|
||||
train_img_path=train_img_path,
|
||||
train_seg_map_path=train_seg_map_path,
|
||||
val_img_path=val_img_path,
|
||||
val_seg_map_path=val_seg_map_path,
|
||||
test_img_path=test_img_path,
|
||||
test_seg_map_path=test_seg_map_path
|
||||
)
|
||||
|
||||
# 生成 ./mmseg/datasets/{dataset_file_name}.py
|
||||
print(" ",end='')
|
||||
success, classes, palette = generate_mmseg_datasets_my_dataset_file(
|
||||
output_file=output_mmseg_datasets_dataset_file_name,
|
||||
dataset_class_name=dataset_class_name,
|
||||
classes=classes,
|
||||
palette=palette,
|
||||
img_suffix=img_suffix,
|
||||
seg_map_suffix=seg_map_suffix,
|
||||
reduce_zero_label=reduce_zero_label
|
||||
)
|
||||
|
||||
mean, std = calculate_pic_std_and_mean(
|
||||
dataset_dir = os.path.join(data_root, train_img_path)
|
||||
)
|
||||
|
||||
# 保存标注类名和颜色
|
||||
classes_all.append(classes)
|
||||
palette_all.append(palette)
|
||||
palette_num_all.append(len(classes))
|
||||
mean_all.append(mean)
|
||||
std_all.append(std)
|
||||
train_imgs_num.append(len(os.listdir(os.path.join(data_root, train_img_path))))
|
||||
|
||||
########### 2.2.汇总所有信息运行生成 init 和 class_names 文件 ###########
|
||||
# 生成 ./mmseg/datasets/__init__.py
|
||||
print(" ",end='')
|
||||
success = generate_mmseg_datasets_init_file(
|
||||
output_file=output_mmseg_datasets_init,
|
||||
dataset_file_names=dataset_file_names,
|
||||
dataset_class_names=dataset_class_names
|
||||
)
|
||||
|
||||
# 生成 ./mmseg/utils/class_names.py
|
||||
print(" ",end='')
|
||||
success = generate_mmseg_utils_class_names_file(
|
||||
output_file=output_mmseg_utils_class_names,
|
||||
dataset_file_names=dataset_file_names,
|
||||
classes_all=classes_all,
|
||||
palette_all=palette_all
|
||||
)
|
||||
|
||||
########### 2.2.汇总dataset_file_names、classes_all、palette_all、palette_num_all所有信息到My_All_In_One/1_Data_Parameter/All_Data_Record.json文件 ###########
|
||||
output_all_data_record = os.path.join(train_parameter_dir, all_data_record_json)
|
||||
# 调用函数保存数据
|
||||
success = save_all_record_to_json(output_file=output_all_data_record, dataset_file_names=dataset_file_names, classes_all=classes_all, palette_all=palette_all, palette_num_all=palette_num_all, mean_all=mean_all, std_all=std_all, train_imgs_num=train_imgs_num)
|
||||
|
||||
101
Seg_All_In_One_MMSeg/My_All_In_One/2_Alg_Program/my_ann_r50.py
Normal file
101
Seg_All_In_One_MMSeg/My_All_In_One/2_Alg_Program/my_ann_r50.py
Normal file
@@ -0,0 +1,101 @@
|
||||
import os, sys, argparse, json
|
||||
# 添加上级目录到 sys.path
|
||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../')))
|
||||
|
||||
from Initial_Alg_Program.Initial_Alg_Select_Tool import select_crop_size, select_pretrained_model, select_test_cfg_slide
|
||||
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_Tool import format_all_data, write_config_to_file
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_0_base_ import generate_base_config
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_1_norm_cfg import generate_norm_cfg
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_2_data_preprocessor import generate_data_preprocessor
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_3_model import generate_model_pretrained, generate_model_backbone, generate_model_data_preprocessor, generate_model_decode_head, generate_model_auxiliary_head, generate_model_train_cfg, generate_model_test_cfg
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_4_optimizer import generate_optim_wrapper, generate_param_scheduler
|
||||
|
||||
if __name__ == '__main__':
|
||||
########### 1.超参数-算法默认参数 ###########
|
||||
alg_file_name = 'ann_r50-d8'
|
||||
|
||||
########### 2.参数解析 ###########
|
||||
parser = argparse.ArgumentParser(description="Run algorithm with specified parameters.")
|
||||
|
||||
# 添加命令行参数
|
||||
parser.add_argument('--alg_name', type=str, required=True, help="Algorithm name (e.g., 'ann')")
|
||||
parser.add_argument('--dataset_file_name', type=str, required=True, help="Dataset file name (e.g., 'my_dataset_model')")
|
||||
parser.add_argument('--mean', type=str, required=True, help="Mean of dataset (e.g., [1.2, 1.2, 1.2])")
|
||||
parser.add_argument('--std', type=str, required=True, help="Mean of dataset (e.g., [1.2, 1.2, 1.2])")
|
||||
parser.add_argument('--dataset_num_classes', type=int, required=True, help="Dataset class num (e.g., 36)")
|
||||
parser.add_argument('--GPU_num', type=int, required=True, help="Num of GPU to train (e.g., 1)")
|
||||
parser.add_argument('--schedule_file_name', type=str, required=True, help="Schedule file name (e.g., 'schedule_4k_check_400')")
|
||||
parser.add_argument('--train_type', type=str, required=True, help="Schedule train type (e.g., Epoch or Iteration)")
|
||||
parser.add_argument('--train_time_or_epoch_k', type=int, required=True, help="Schedule train time k (e.g., 4)")
|
||||
|
||||
# 解析命令行参数
|
||||
args = parser.parse_args()
|
||||
|
||||
# 变量赋值
|
||||
alg_name = args.alg_name
|
||||
dataset_file_name = args.dataset_file_name
|
||||
mean = json.loads(args.mean) # 使用json.loads转换成list
|
||||
std = json.loads(args.std) # 使用json.loads转换成list
|
||||
num_classes = args.dataset_num_classes
|
||||
schedule_file_name = args.schedule_file_name
|
||||
train_type = args.train_type
|
||||
train_time_or_epoch_k = args.train_time_or_epoch_k
|
||||
GPU_num = args.GPU_num
|
||||
|
||||
|
||||
########### 3. 设定参数 ###########
|
||||
# 3.1. 自选参数
|
||||
crop_size = select_crop_size(predefined_options=[(512, 512), (512, 1024), (769, 769)]) # 选择切割大小
|
||||
model_list = ['openmmlab/resnet50_v1c', 'openmmlab/resnet101_v1c']
|
||||
selected_model_name, select_pretrained, pretrained_pth, selected_model_info = select_pretrained_model(model_list=model_list)
|
||||
depth = selected_model_info['depth']
|
||||
|
||||
test_cfg_mode, test_cfg_crop_div_stride = select_test_cfg_slide(crop_size)
|
||||
|
||||
# 3.2. 自动生成参数
|
||||
# decode、auxiliary损失下载方式
|
||||
decode_head_loss_decode_dict=dict(type='DiceLoss', use_sigmoid=False, loss_weight=1.0) # decode损失函数 # TODO
|
||||
auxiliary_head_loss_decode_dict=dict(type='DiceLoss', use_sigmoid=False, loss_weight=0.4) # DiceLoss损失函数 # TODO
|
||||
|
||||
# 如果使用slide模式,则开启align_corners
|
||||
if test_cfg_mode == "slide":
|
||||
align_corners = True
|
||||
test_slide = 'testslide' # 算法名进行标记
|
||||
else:
|
||||
align_corners = False
|
||||
test_slide = 'no_testslide' # 算法名进行标记
|
||||
|
||||
|
||||
########### 3.生成各个部分 ###########
|
||||
_base_ = generate_base_config(alg_file_name=alg_file_name, dataset_file_name=dataset_file_name, schedule_file_name=schedule_file_name)
|
||||
norm_cfg = generate_norm_cfg(GPU_num=GPU_num)
|
||||
data_preprocessor = generate_data_preprocessor(crop_size = crop_size, mean = mean, std = std)
|
||||
|
||||
pretrained = generate_model_pretrained(pretrained_pth = pretrained_pth)
|
||||
backbone = generate_model_backbone(depth=depth)
|
||||
model_data_preprocessor = data_preprocessor # 修改data_preprocessor
|
||||
model_data_preprocessor = generate_model_data_preprocessor(model_data_preprocessor=model_data_preprocessor)
|
||||
decode_head = generate_model_decode_head(num_classes=num_classes, decode_head_loss_decode_dict=decode_head_loss_decode_dict, align_corners=align_corners)
|
||||
auxiliary_head = generate_model_auxiliary_head(num_classes=num_classes, auxiliary_head_loss_decode_dict=auxiliary_head_loss_decode_dict, align_corners=align_corners)
|
||||
# 综合model
|
||||
model = dict(
|
||||
pretrained = pretrained,
|
||||
backbone = backbone,
|
||||
data_preprocessor = model_data_preprocessor,
|
||||
decode_head = decode_head,
|
||||
auxiliary_head = auxiliary_head,
|
||||
)
|
||||
|
||||
test_cfg = generate_model_test_cfg(test_cfg_mode=test_cfg_mode, crop_size=crop_size, test_cfg_crop_div_stride=test_cfg_crop_div_stride)
|
||||
optim_wrapper = generate_optim_wrapper()
|
||||
param_scheduler = generate_param_scheduler(train_type, train_time_or_epoch_k)
|
||||
########### 4.程序写入 ###########
|
||||
# 算法名称解析:ann【Alg】 _ r50【pretrained模型深度】 - d8【_base_/models】 _ 4【GPU数量】 x b2【Batch Size大小】 - 40k【schedule】 _ cityscapes【数据集】 - 512x1024【crop_size】.py
|
||||
alg_file_name = f"{alg_name}_r{depth}_{'Pre' if select_pretrained else 'NoPre'}_g{GPU_num}-{schedule_file_name.lstrip('schedule_')}_{dataset_file_name}-{crop_size[0]}x{crop_size[1]}-{test_slide}.py"
|
||||
output_configs_alg_my_alg = os.path.join(f'./configs/ann/', alg_file_name)
|
||||
write_config_to_file(output_configs_alg_my_alg, _base_=_base_, crop_size=crop_size, data_preprocessor=data_preprocessor, model=model, test_cfg=test_cfg, optim_wrapper=optim_wrapper, param_scheduler=param_scheduler)
|
||||
|
||||
# 将信息写入文件
|
||||
with open("_temp_.txt", "w") as file:
|
||||
json.dump({'alg_infos':{'alg_file_name':alg_file_name,'alg_file_pth':output_configs_alg_my_alg}}, file)
|
||||
@@ -0,0 +1,102 @@
|
||||
import os, sys, argparse, json
|
||||
# 添加上级目录到 sys.path
|
||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../')))
|
||||
|
||||
from Initial_Alg_Program.Initial_Alg_Select_Tool import select_crop_size, select_pretrained_model, select_test_cfg_slide
|
||||
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_Tool import format_all_data, write_config_to_file
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_0_base_ import generate_base_config
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_1_norm_cfg import generate_norm_cfg
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_2_data_preprocessor import generate_data_preprocessor
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_3_model import generate_model_pretrained, generate_model_backbone, generate_model_data_preprocessor, generate_model_decode_head, generate_model_auxiliary_head, generate_model_train_cfg, generate_model_test_cfg
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_4_optimizer import generate_optim_wrapper, generate_param_scheduler
|
||||
|
||||
if __name__ == '__main__':
|
||||
########### 1.超参数-算法默认参数 ###########
|
||||
alg_file_name = 'apcnet_r50-d8'
|
||||
|
||||
########### 2.参数解析 ###########
|
||||
parser = argparse.ArgumentParser(description="Run algorithm with specified parameters.")
|
||||
|
||||
# 添加命令行参数
|
||||
parser.add_argument('--alg_name', type=str, required=True, help="Algorithm name (e.g., 'ann')")
|
||||
parser.add_argument('--dataset_file_name', type=str, required=True, help="Dataset file name (e.g., 'my_dataset_model')")
|
||||
parser.add_argument('--mean', type=str, required=True, help="Mean of dataset (e.g., [1.2, 1.2, 1.2])")
|
||||
parser.add_argument('--std', type=str, required=True, help="Mean of dataset (e.g., [1.2, 1.2, 1.2])")
|
||||
parser.add_argument('--dataset_num_classes', type=int, required=True, help="Dataset class num (e.g., 36)")
|
||||
parser.add_argument('--GPU_num', type=int, required=True, help="Num of GPU to train (e.g., 1)")
|
||||
parser.add_argument('--schedule_file_name', type=str, required=True, help="Schedule file name (e.g., 'schedule_4k_check_400')")
|
||||
parser.add_argument('--train_type', type=str, required=True, help="Schedule train type (e.g., Epoch or Iteration)")
|
||||
parser.add_argument('--train_time_or_epoch_k', type=int, required=True, help="Schedule train time k (e.g., 4)")
|
||||
|
||||
# 解析命令行参数
|
||||
args = parser.parse_args()
|
||||
|
||||
# 变量赋值
|
||||
alg_name = args.alg_name
|
||||
dataset_file_name = args.dataset_file_name
|
||||
mean = json.loads(args.mean) # 使用json.loads转换成list
|
||||
std = json.loads(args.std) # 使用json.loads转换成list
|
||||
num_classes = args.dataset_num_classes
|
||||
schedule_file_name = args.schedule_file_name
|
||||
train_type = args.train_type
|
||||
train_time_or_epoch_k = args.train_time_or_epoch_k
|
||||
GPU_num = args.GPU_num
|
||||
|
||||
|
||||
########### 3. 设定参数 ###########
|
||||
# 3.1. 自选参数
|
||||
crop_size = select_crop_size(predefined_options=[(512, 512), (512, 1024), (769, 769)]) # 选择切割大小
|
||||
model_list = ['openmmlab/resnet50_v1c', 'openmmlab/resnet101_v1c']
|
||||
selected_model_name, select_pretrained, pretrained_pth, selected_model_info = select_pretrained_model(model_list=model_list)
|
||||
depth = selected_model_info['depth']
|
||||
|
||||
test_cfg_mode, test_cfg_crop_div_stride = select_test_cfg_slide(crop_size)
|
||||
|
||||
# 3.2. 自动生成参数
|
||||
# decode、auxiliary损失下载方式
|
||||
decode_head_loss_decode_dict=dict(type='DiceLoss', use_sigmoid=False, loss_weight=1.0) # decode损失函数 # TODO
|
||||
auxiliary_head_loss_decode_dict=dict(type='DiceLoss', use_sigmoid=False, loss_weight=0.4) # DiceLoss损失函数 # TODO
|
||||
|
||||
# 如果使用slide模式,则开启align_corners
|
||||
if test_cfg_mode == "slide":
|
||||
align_corners = True
|
||||
test_slide = 'testslide' # 算法名进行标记
|
||||
else:
|
||||
align_corners = False
|
||||
test_slide = 'no_testslide' # 算法名进行标记
|
||||
|
||||
|
||||
########### 3.生成各个部分 ###########
|
||||
_base_ = generate_base_config(alg_file_name=alg_file_name, dataset_file_name=dataset_file_name, schedule_file_name=schedule_file_name)
|
||||
norm_cfg = generate_norm_cfg(GPU_num=GPU_num)
|
||||
data_preprocessor = generate_data_preprocessor(crop_size = crop_size, mean = mean, std = std)
|
||||
|
||||
pretrained = generate_model_pretrained(pretrained_pth = pretrained_pth)
|
||||
backbone = generate_model_backbone(depth=depth)
|
||||
model_data_preprocessor = data_preprocessor # 修改data_preprocessor
|
||||
model_data_preprocessor = generate_model_data_preprocessor(model_data_preprocessor=model_data_preprocessor)
|
||||
decode_head = generate_model_decode_head(num_classes=num_classes, decode_head_loss_decode_dict=decode_head_loss_decode_dict, align_corners=align_corners)
|
||||
auxiliary_head = generate_model_auxiliary_head(num_classes=num_classes, auxiliary_head_loss_decode_dict=auxiliary_head_loss_decode_dict, align_corners=align_corners)
|
||||
# 综合model
|
||||
model = dict(
|
||||
pretrained = pretrained,
|
||||
backbone = backbone,
|
||||
data_preprocessor = model_data_preprocessor,
|
||||
decode_head = decode_head,
|
||||
auxiliary_head = auxiliary_head,
|
||||
)
|
||||
|
||||
test_cfg = generate_model_test_cfg(test_cfg_mode=test_cfg_mode, crop_size=crop_size, test_cfg_crop_div_stride=test_cfg_crop_div_stride)
|
||||
optim_wrapper = generate_optim_wrapper()
|
||||
param_scheduler = generate_param_scheduler(train_type, train_time_or_epoch_k)
|
||||
|
||||
########### 4.程序写入 ###########
|
||||
# 算法名称解析:apcnet【Alg】 _ r50【pretrained模型深度】 - d8【_base_/models】 _ 4【GPU数量】 x b2【Batch Size大小】 - 40k【schedule】 _ cityscapes【数据集】 - 512x1024【crop_size】.py
|
||||
alg_file_name = f"{alg_name}_r{depth}_{'Pre' if select_pretrained else 'NoPre'}_g{GPU_num}-{schedule_file_name.lstrip('schedule_')}_{dataset_file_name}-{crop_size[0]}x{crop_size[1]}-{test_slide}.py"
|
||||
output_configs_alg_my_alg = os.path.join(f'./configs/apcnet/', alg_file_name)
|
||||
write_config_to_file(output_configs_alg_my_alg, _base_=_base_, norm_cfg=norm_cfg, crop_size=crop_size, data_preprocessor=data_preprocessor, model=model, test_cfg=test_cfg, optim_wrapper=optim_wrapper, param_scheduler=param_scheduler)
|
||||
|
||||
# 将信息写入文件
|
||||
with open("_temp_.txt", "w") as file:
|
||||
json.dump({'alg_infos':{'alg_file_name':alg_file_name,'alg_file_pth':output_configs_alg_my_alg}}, file)
|
||||
@@ -0,0 +1,176 @@
|
||||
import os, sys, argparse, json
|
||||
# 添加上级目录到 sys.path
|
||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../')))
|
||||
|
||||
from Initial_Alg_Program.Initial_Alg_Select_Tool import select_crop_size, select_pretrained_model
|
||||
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_Tool import format_all_data, write_config_to_file, create_dict_by_kwargs, get_var_from_file, update_list_dict_var, get_var_from_py_file
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_0_base_ import generate_base_config
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_1_norm_cfg import generate_norm_cfg
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_2_data_preprocessor import generate_data_preprocessor
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_3_model import generate_model_pretrained, generate_model_backbone, generate_model_data_preprocessor, generate_model_decode_head, generate_model_auxiliary_head, generate_model_train_cfg, generate_model_test_cfg
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_4_optimizer import generate_optim_wrapper, generate_param_scheduler
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_5_train_dataloader import generate_train_dataloader
|
||||
|
||||
# 交互式选择 decode_head 的函数
|
||||
def select_decode_head(decode_head_choose):
|
||||
print("可用的 decode head 选项:")
|
||||
for i, key in enumerate(decode_head_choose.keys()):
|
||||
print(f"{i + 1}. {key}")
|
||||
|
||||
while True:
|
||||
try:
|
||||
# 提示用户输入选择
|
||||
choice = int(input("请选择需要的 decode head(输入编号):"))
|
||||
# 检查输入是否在有效范围内
|
||||
if 1 <= choice <= len(decode_head_choose):
|
||||
selected_key = list(decode_head_choose.keys())[choice - 1]
|
||||
print(f"你选择了: {selected_key}")
|
||||
return decode_head_choose[selected_key]
|
||||
else:
|
||||
print("输入的编号不正确,请重新输入。")
|
||||
except ValueError:
|
||||
print("输入无效,请输入有效的编号。")
|
||||
|
||||
if __name__ == '__main__':
|
||||
########### 1.超参数-算法默认参数 ###########
|
||||
alg_file_name = 'upernet_beit'
|
||||
alg_file_pth = f"configs/_base_/models/{alg_file_name}.py"
|
||||
|
||||
########### 2.参数解析 ###########
|
||||
parser = argparse.ArgumentParser(description="Run algorithm with specified parameters.")
|
||||
|
||||
# 添加命令行参数
|
||||
parser.add_argument('--alg_name', type=str, required=True, help="Algorithm name (e.g., 'ann')")
|
||||
parser.add_argument('--dataset_file_name', type=str, required=True, help="Dataset file name (e.g., 'my_dataset_model')")
|
||||
parser.add_argument('--mean', type=str, required=True, help="Mean of dataset (e.g., [1.2, 1.2, 1.2])")
|
||||
parser.add_argument('--std', type=str, required=True, help="Mean of dataset (e.g., [1.2, 1.2, 1.2])")
|
||||
parser.add_argument('--dataset_num_classes', type=int, required=True, help="Dataset class num (e.g., 36)")
|
||||
parser.add_argument('--GPU_num', type=int, required=True, help="Num of GPU to train (e.g., 1)")
|
||||
parser.add_argument('--schedule_file_name', type=str, required=True, help="Schedule file name (e.g., 'schedule_4k_check_400')")
|
||||
parser.add_argument('--train_type', type=str, required=True, help="Schedule train type (e.g., Epoch or Iteration)")
|
||||
parser.add_argument('--train_time_or_epoch_k', type=int, required=True, help="Schedule train time k (e.g., 4)")
|
||||
|
||||
# 解析命令行参数
|
||||
args = parser.parse_args()
|
||||
|
||||
# 变量赋值
|
||||
alg_name = args.alg_name
|
||||
dataset_file_name = args.dataset_file_name
|
||||
mean = json.loads(args.mean) # 使用json.loads转换成list
|
||||
std = json.loads(args.std) # 使用json.loads转换成list
|
||||
num_classes = args.dataset_num_classes
|
||||
schedule_file_name = args.schedule_file_name
|
||||
train_type = args.train_type
|
||||
train_time_or_epoch_k = args.train_time_or_epoch_k
|
||||
GPU_num = args.GPU_num
|
||||
|
||||
########### 3.1. 生成_base_部分 ###########
|
||||
_base_ = generate_base_config(alg_file_name=alg_file_name, dataset_file_name=dataset_file_name, schedule_file_name=schedule_file_name)
|
||||
|
||||
########### 3.1. 生成norm_cfg部分 ###########
|
||||
norm_cfg = generate_norm_cfg(GPU_num=GPU_num)
|
||||
|
||||
########### 3.2. 生成data_preprocessor部分 ###########
|
||||
crop_size = select_crop_size(predefined_options=[(640, 640)]) # 选择切割大小
|
||||
data_preprocessor = generate_data_preprocessor(crop_size = crop_size, mean = mean, std = std)
|
||||
|
||||
########### 3.3. 生成model部分 ###########
|
||||
model_data_preprocessor = data_preprocessor # 修改data_preprocessor
|
||||
# 获取backbone模型、是否需要预训练
|
||||
model_list = ['pretrain/beit_base_patch16_224_pt22k_ft22k.pth', 'pretrain/beit_large_patch16_224_pt22k_ft22k.pth']
|
||||
selected_model_name, select_pretrained, pretrained_pth, selected_model_info = select_pretrained_model(model_list=model_list) # 需要选择是否用预训练模型
|
||||
|
||||
backbone = create_dict_by_kwargs(type='BEiT', img_size=crop_size,)
|
||||
|
||||
decode_head_loss_decode=dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0) # Way Ori
|
||||
decode_head_loss_decode=dict(type='DiceLoss', use_sigmoid=False, loss_weight=1.0) # Way 1
|
||||
|
||||
auxiliary_head_loss_decode=dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4) # Way Ori
|
||||
auxiliary_head_loss_decode=dict(type='DiceLoss', use_sigmoid=False, loss_weight=0.4) # Way 1
|
||||
|
||||
if selected_model_name == 'pretrain/beit_base_patch16_224_pt22k_ft22k.pth':
|
||||
backbone_new=dict(
|
||||
# patch_size=16,
|
||||
# in_channels=3,
|
||||
embed_dims=768,
|
||||
num_layers=12,
|
||||
num_heads=12,
|
||||
# mlp_ratio=4,
|
||||
out_indices=(3, 5, 7, 11),
|
||||
# qv_bias=True,
|
||||
# attn_drop_rate=0.0,
|
||||
drop_path_rate=0.1,
|
||||
# norm_cfg=dict(type='LN', eps=1e-6),
|
||||
# act_cfg=dict(type='GELU'),
|
||||
# norm_eval=False,
|
||||
init_values=0.1)
|
||||
neck = dict(type='Feature2Pyramid', embed_dim=768, rescales=[4, 2, 1, 0.5])
|
||||
decode_head=dict(in_channels=[768, 768, 768, 768], num_classes=num_classes, channels=768, norm_cfg=norm_cfg, loss_decode=decode_head_loss_decode)
|
||||
auxiliary_head=dict(norm_cfg=norm_cfg, in_channels=768, num_classes=num_classes, loss_decode = auxiliary_head_loss_decode)
|
||||
|
||||
optim_wrapper = dict(
|
||||
_delete_=True,
|
||||
type='AmpOptimWrapper', # type='OptimWrapper', # TODO Ori
|
||||
optimizer=dict(
|
||||
type='AdamW', lr=3e-5, betas=(0.9, 0.999), weight_decay=0.05),
|
||||
constructor='LayerDecayOptimizerConstructor',
|
||||
paramwise_cfg=dict(num_layers=12, layer_decay_rate=0.9))
|
||||
|
||||
model_size = 'base'
|
||||
|
||||
train_dataloader, batch_size = generate_train_dataloader(2)
|
||||
|
||||
elif selected_model_name == 'pretrain/beit_large_patch16_224_pt22k_ft22k.pth':
|
||||
backbone_new=dict(
|
||||
embed_dims=1024,
|
||||
num_layers=24,
|
||||
num_heads=16,
|
||||
# mlp_ratio=4,
|
||||
# qv_bias=True,
|
||||
init_values=1e-6,
|
||||
drop_path_rate=0.2,
|
||||
out_indices=[7, 11, 15, 23])
|
||||
neck=dict(type='Feature2Pyramid', embed_dim=1024, rescales=[4, 2, 1, 0.5])
|
||||
decode_head=dict(in_channels=[1024, 1024, 1024, 1024], num_classes=num_classes, channels=1024, norm_cfg=norm_cfg, loss_decode=decode_head_loss_decode)
|
||||
auxiliary_head=dict(norm_cfg=norm_cfg, in_channels=1024, num_classes=num_classes, loss_decode = auxiliary_head_loss_decode)
|
||||
|
||||
optim_wrapper = dict(
|
||||
_delete_=True,
|
||||
type='AmpOptimWrapper',
|
||||
optimizer=dict(
|
||||
type='AdamW', lr=2e-5, betas=(0.9, 0.999), weight_decay=0.05),
|
||||
constructor='LayerDecayOptimizerConstructor',
|
||||
paramwise_cfg=dict(num_layers=24, layer_decay_rate=0.95),
|
||||
accumulative_counts=2)
|
||||
|
||||
model_size = 'large'
|
||||
|
||||
train_dataloader, batch_size = generate_train_dataloader(1)
|
||||
|
||||
# 更新backbone
|
||||
backbone.update(backbone_new)
|
||||
|
||||
model = dict(
|
||||
backbone = backbone,
|
||||
data_preprocessor = model_data_preprocessor,
|
||||
pretrained = pretrained_pth,
|
||||
neck = neck,
|
||||
decode_head = decode_head,
|
||||
auxiliary_head = auxiliary_head,
|
||||
)
|
||||
|
||||
|
||||
########### 3.4. 生成optim、param_scheduler部分 ###########
|
||||
optim_wrapper = optim_wrapper
|
||||
param_scheduler = generate_param_scheduler(train_type, train_time_or_epoch_k)
|
||||
|
||||
########### 4.程序写入 ###########
|
||||
# 算法名称解析:beit【Alg】 _ r50【pretrained模型深度】 - d8【_base_/models】 _ 4【GPU数量】 x b2【Batch Size大小】 - 40k【schedule】 _ cityscapes【数据集】 - 512x1024【crop_size】.py
|
||||
alg_file_name = f"{alg_name}-{model_size}_b{batch_size}_g{GPU_num}-{schedule_file_name.lstrip('schedule_')}_{dataset_file_name}-{crop_size[0]}x{crop_size[1]}.py"
|
||||
output_configs_alg_my_alg = os.path.join(f'./configs/beit/', alg_file_name)
|
||||
write_config_to_file(output_configs_alg_my_alg, _base_=_base_, norm_cfg=norm_cfg, crop_size=crop_size, data_preprocessor=data_preprocessor, model=model, optim_wrapper=optim_wrapper, param_scheduler=param_scheduler, train_dataloader = train_dataloader)
|
||||
|
||||
# 将信息写入文件
|
||||
with open("_temp_.txt", "w") as file:
|
||||
json.dump({'alg_infos':{'alg_file_name':alg_file_name,'alg_file_pth':output_configs_alg_my_alg}}, file)
|
||||
@@ -0,0 +1,133 @@
|
||||
import os, sys, argparse, json
|
||||
# 添加上级目录到 sys.path
|
||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../')))
|
||||
|
||||
from Initial_Alg_Program.Initial_Alg_Select_Tool import select_crop_size, select_pretrained_model
|
||||
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_Tool import format_all_data, write_config_to_file, create_dict_by_kwargs, get_var_from_file, update_list_dict_var, get_var_from_py_file
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_0_base_ import generate_base_config
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_1_norm_cfg import generate_norm_cfg
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_2_data_preprocessor import generate_data_preprocessor
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_3_model import generate_model_pretrained, generate_model_backbone, generate_model_data_preprocessor, generate_model_decode_head, generate_model_auxiliary_head, generate_model_train_cfg, generate_model_test_cfg
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_4_optimizer import generate_optim_wrapper, generate_param_scheduler
|
||||
|
||||
if __name__ == '__main__':
|
||||
########### 1.超参数-算法默认参数 ###########
|
||||
alg_file_name = 'bisenetv1_r18-d32'
|
||||
alg_file_pth = f"configs/_base_/models/{alg_file_name}.py"
|
||||
|
||||
########### 2.参数解析 ###########
|
||||
parser = argparse.ArgumentParser(description="Run algorithm with specified parameters.")
|
||||
|
||||
# 添加命令行参数
|
||||
parser.add_argument('--alg_name', type=str, required=True, help="Algorithm name (e.g., 'ann')")
|
||||
parser.add_argument('--dataset_file_name', type=str, required=True, help="Dataset file name (e.g., 'my_dataset_model')")
|
||||
parser.add_argument('--mean', type=str, required=True, help="Mean of dataset (e.g., [1.2, 1.2, 1.2])")
|
||||
parser.add_argument('--std', type=str, required=True, help="Mean of dataset (e.g., [1.2, 1.2, 1.2])")
|
||||
parser.add_argument('--dataset_num_classes', type=int, required=True, help="Dataset class num (e.g., 36)")
|
||||
parser.add_argument('--GPU_num', type=int, required=True, help="Num of GPU to train (e.g., 1)")
|
||||
parser.add_argument('--schedule_file_name', type=str, required=True, help="Schedule file name (e.g., 'schedule_4k_check_400')")
|
||||
parser.add_argument('--train_type', type=str, required=True, help="Schedule train type (e.g., Epoch or Iteration)")
|
||||
parser.add_argument('--train_time_or_epoch_k', type=int, required=True, help="Schedule train time k (e.g., 4)")
|
||||
|
||||
# 解析命令行参数
|
||||
args = parser.parse_args()
|
||||
|
||||
# 变量赋值
|
||||
alg_name = args.alg_name
|
||||
dataset_file_name = args.dataset_file_name
|
||||
mean = json.loads(args.mean) # 使用json.loads转换成list
|
||||
std = json.loads(args.std) # 使用json.loads转换成list
|
||||
num_classes = args.dataset_num_classes
|
||||
schedule_file_name = args.schedule_file_name
|
||||
train_type = args.train_type
|
||||
train_time_or_epoch_k = args.train_time_or_epoch_k
|
||||
GPU_num = args.GPU_num
|
||||
|
||||
########### 3.1. 生成_base_部分 ###########
|
||||
_base_ = generate_base_config(alg_file_name=alg_file_name, dataset_file_name=dataset_file_name, schedule_file_name=schedule_file_name)
|
||||
|
||||
########### 3.1. 生成norm_cfg部分 ###########
|
||||
norm_cfg = generate_norm_cfg(GPU_num=GPU_num)
|
||||
|
||||
########### 3.2. 生成data_preprocessor部分 ###########
|
||||
crop_size = select_crop_size(predefined_options=[(512, 512), (1024, 1024)]) # 选择切割大小
|
||||
data_preprocessor = generate_data_preprocessor(crop_size = crop_size, mean = mean, std = std)
|
||||
|
||||
########### 3.3. 生成model部分 ###########
|
||||
model_data_preprocessor = data_preprocessor # 修改data_preprocessor
|
||||
# 获取backbone模型、是否需要预训练
|
||||
model_list = ['openmmlab/resnet18_v1c', 'openmmlab/resnet50_v1c', 'openmmlab/resnet101_v1c']
|
||||
selected_model_name, select_pretrained, pretrained_pth, selected_model_info = select_pretrained_model(model_list=model_list, need_select_pretrained = True) # 需要选择是否用预训练模型
|
||||
depth = selected_model_info['depth']
|
||||
|
||||
if select_pretrained == True:
|
||||
pretrained_txt = 'Pre'
|
||||
else:
|
||||
pretrained_txt = 'NoPre'
|
||||
|
||||
# 模型信息
|
||||
if selected_model_name == 'openmmlab/resnet18_v1c' :
|
||||
backbone_context_channels = (128, 256, 512)
|
||||
backbone_spatial_channels = (64, 64, 64, 128)
|
||||
backbone_out_channels = 256
|
||||
decode_head_in_channels = 256
|
||||
decode_head_channels = 256
|
||||
auxiliary_head_in_channels = 128
|
||||
auxiliary_head_channels = 64
|
||||
|
||||
elif selected_model_name == 'openmmlab/resnet50_v1c' or selected_model_name == 'openmmlab/resnet101_v1c' :
|
||||
backbone_context_channels = (512, 1024, 2048)
|
||||
backbone_spatial_channels = (256, 256, 256, 512)
|
||||
backbone_out_channels = 1024
|
||||
decode_head_in_channels = 1024
|
||||
decode_head_channels = 1024
|
||||
auxiliary_head_in_channels = 512
|
||||
auxiliary_head_channels = 256
|
||||
|
||||
# 需要选择预训练模型
|
||||
if select_pretrained == True:
|
||||
backbone_backbone_cfg = create_dict_by_kwargs(type='ResNet', norm_cfg=norm_cfg, depth=depth, init_cfg=dict(type='Pretrained', checkpoint=pretrained_pth))
|
||||
else:
|
||||
backbone_backbone_cfg = create_dict_by_kwargs(type='ResNet', norm_cfg=norm_cfg, depth=depth)
|
||||
|
||||
backbone = create_dict_by_kwargs(context_channels=backbone_context_channels, norm_cfg=norm_cfg, spatial_channels=backbone_spatial_channels, out_channels=backbone_out_channels, backbone_cfg=backbone_backbone_cfg)
|
||||
|
||||
# decode、auxiliary损失下载方式 TODO
|
||||
decode_head_loss_decode_dict=dict(type='DiceLoss', use_sigmoid=False, loss_weight=1.0) # decode损失函数 # TODO
|
||||
decode_head = create_dict_by_kwargs(type='FCNHead',in_channels=decode_head_in_channels, channels=decode_head_channels, num_classes=num_classes)
|
||||
|
||||
# auxiliary损失下载方式 TODO # 可更改
|
||||
# auxiliary_head_loss_decode_dict = dict(type='DiceLoss', use_sigmoid=False, loss_weight=0.4) # DiceLoss损失函数 # TODO # 可更改
|
||||
auxiliary_head_loss_decode_dict = dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)
|
||||
auxiliary_head = [create_dict_by_kwargs(type='FCNHead', norm_cfg=norm_cfg, in_channels=auxiliary_head_in_channels, channels=auxiliary_head_channels, num_classes=num_classes, align_corners=False, loss_decode=auxiliary_head_loss_decode_dict),
|
||||
create_dict_by_kwargs(type='FCNHead', norm_cfg=norm_cfg, in_channels=auxiliary_head_in_channels, channels=auxiliary_head_channels, num_classes=num_classes, align_corners=False, loss_decode=auxiliary_head_loss_decode_dict)]
|
||||
# 获取原始auxiliary_head
|
||||
auxiliary_head = get_var_from_py_file(alg_file_pth, var_name="model")["auxiliary_head"]
|
||||
# 新的auxiliary_head
|
||||
auxiliary_head_new = [create_dict_by_kwargs(type='FCNHead', norm_cfg=norm_cfg, in_channels=auxiliary_head_in_channels, channels=auxiliary_head_channels, num_classes=num_classes, align_corners=False, loss_decode=auxiliary_head_loss_decode_dict),
|
||||
create_dict_by_kwargs(type='FCNHead', norm_cfg=norm_cfg, in_channels=auxiliary_head_in_channels, channels=auxiliary_head_channels, num_classes=num_classes, align_corners=False, loss_decode=auxiliary_head_loss_decode_dict)]
|
||||
# 更新auxiliary_head
|
||||
auxiliary_head = update_list_dict_var(auxiliary_head, auxiliary_head_new)
|
||||
|
||||
# 综合model
|
||||
model = dict(
|
||||
backbone = backbone,
|
||||
data_preprocessor = model_data_preprocessor,
|
||||
decode_head = decode_head,
|
||||
auxiliary_head = auxiliary_head,
|
||||
)
|
||||
|
||||
########### 3.4. 生成optim、param_scheduler部分 ###########
|
||||
optim_wrapper = generate_optim_wrapper()
|
||||
param_scheduler = generate_param_scheduler(train_type, train_time_or_epoch_k)
|
||||
|
||||
########### 4.程序写入 ###########
|
||||
# 算法名称解析:bisenetv1【Alg】 _ r50【pretrained模型深度】 - d8【_base_/models】 _ 4【GPU数量】 x b2【Batch Size大小】 - 40k【schedule】 _ cityscapes【数据集】 - 512x1024【crop_size】.py
|
||||
alg_file_name = f"{alg_name}_r{depth}_{'Pre' if select_pretrained else 'NoPre'}_g{GPU_num}-{schedule_file_name.lstrip('schedule_')}_{dataset_file_name}-{crop_size[0]}x{crop_size[1]}.py"
|
||||
output_configs_alg_my_alg = os.path.join(f'./configs/bisenetv1/', alg_file_name)
|
||||
write_config_to_file(output_configs_alg_my_alg, _base_=_base_, norm_cfg=norm_cfg, crop_size=crop_size, data_preprocessor=data_preprocessor, model=model, optim_wrapper=optim_wrapper, param_scheduler=param_scheduler)
|
||||
|
||||
# 将信息写入文件
|
||||
with open("_temp_.txt", "w") as file:
|
||||
json.dump({'alg_infos':{'alg_file_name':alg_file_name,'alg_file_pth':output_configs_alg_my_alg}}, file)
|
||||
111
Seg_All_In_One_MMSeg/My_All_In_One/2_Alg_Program/my_bisenetv2.py
Normal file
111
Seg_All_In_One_MMSeg/My_All_In_One/2_Alg_Program/my_bisenetv2.py
Normal file
@@ -0,0 +1,111 @@
|
||||
import os, sys, argparse, json
|
||||
# 添加上级目录到 sys.path
|
||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../')))
|
||||
|
||||
from Initial_Alg_Program.Initial_Alg_Select_Tool import select_crop_size, select_pretrained_model, select_sampler
|
||||
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_Tool import format_all_data, write_config_to_file, create_dict_by_kwargs, get_var_from_file, update_list_dict_var, get_var_from_py_file
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_0_base_ import generate_base_config
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_1_norm_cfg import generate_norm_cfg
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_2_data_preprocessor import generate_data_preprocessor
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_3_model import generate_model_pretrained, generate_model_backbone, generate_model_data_preprocessor, generate_model_decode_head, generate_model_auxiliary_head, generate_model_train_cfg, generate_model_test_cfg
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_4_optimizer import generate_optim_wrapper, generate_param_scheduler
|
||||
|
||||
if __name__ == '__main__':
|
||||
########### 1.超参数-算法默认参数 ###########
|
||||
alg_file_name = 'bisenetv2'
|
||||
alg_file_pth = f"configs/_base_/models/{alg_file_name}.py"
|
||||
|
||||
########### 2.参数解析 ###########
|
||||
parser = argparse.ArgumentParser(description="Run algorithm with specified parameters.")
|
||||
|
||||
# 添加命令行参数
|
||||
parser.add_argument('--alg_name', type=str, required=True, help="Algorithm name (e.g., 'ann')")
|
||||
parser.add_argument('--dataset_file_name', type=str, required=True, help="Dataset file name (e.g., 'my_dataset_model')")
|
||||
parser.add_argument('--mean', type=str, required=True, help="Mean of dataset (e.g., [1.2, 1.2, 1.2])")
|
||||
parser.add_argument('--std', type=str, required=True, help="Mean of dataset (e.g., [1.2, 1.2, 1.2])")
|
||||
parser.add_argument('--dataset_num_classes', type=int, required=True, help="Dataset class num (e.g., 36)")
|
||||
parser.add_argument('--GPU_num', type=int, required=True, help="Num of GPU to train (e.g., 1)")
|
||||
parser.add_argument('--schedule_file_name', type=str, required=True, help="Schedule file name (e.g., 'schedule_4k_check_400')")
|
||||
parser.add_argument('--train_type', type=str, required=True, help="Schedule train type (e.g., Epoch or Iteration)")
|
||||
parser.add_argument('--train_time_or_epoch_k', type=int, required=True, help="Schedule train time k (e.g., 4)")
|
||||
|
||||
# 解析命令行参数
|
||||
args = parser.parse_args()
|
||||
|
||||
# 变量赋值
|
||||
alg_name = args.alg_name
|
||||
dataset_file_name = args.dataset_file_name
|
||||
mean = json.loads(args.mean) # 使用json.loads转换成list
|
||||
std = json.loads(args.std) # 使用json.loads转换成list
|
||||
num_classes = args.dataset_num_classes
|
||||
schedule_file_name = args.schedule_file_name
|
||||
train_type = args.train_type
|
||||
train_time_or_epoch_k = args.train_time_or_epoch_k
|
||||
GPU_num = args.GPU_num
|
||||
|
||||
########### 3.1. 生成_base_部分 ###########
|
||||
_base_ = generate_base_config(alg_file_name=alg_file_name, dataset_file_name=dataset_file_name, schedule_file_name=schedule_file_name)
|
||||
|
||||
########### 3.1. 生成norm_cfg部分 ###########
|
||||
norm_cfg = generate_norm_cfg(GPU_num=GPU_num)
|
||||
|
||||
########### 3.2. 生成data_preprocessor部分 ###########
|
||||
crop_size = select_crop_size(predefined_options=[(512, 512)]) # 选择切割大小
|
||||
data_preprocessor = generate_data_preprocessor(crop_size = crop_size, mean = mean, std = std)
|
||||
|
||||
########### 3.3. 生成model部分 ###########
|
||||
# 3.3.1. 预处理data_preprocessor
|
||||
model_data_preprocessor = data_preprocessor # 修改data_preprocessor
|
||||
|
||||
# 3.3.2. 采样函数sampler
|
||||
selected_sampler_name, use_sampler, selected_sampler_info = select_sampler(sampler_list=['OHEMPixelSampler'])
|
||||
if use_sampler == True:
|
||||
decode_head_sampler=selected_sampler_info
|
||||
auxiliary_head_sampler=selected_sampler_info
|
||||
use_sampler_txt = 'ohempSampler' # 标记
|
||||
else:
|
||||
decode_head_sampler=None
|
||||
auxiliary_head_sampler=None
|
||||
use_sampler_txt = ''
|
||||
|
||||
# 3.3.3. 解码器decode_head
|
||||
# decode、auxiliary损失下载方式 TODO
|
||||
decode_head_loss_decode_dict=dict(type='DiceLoss', use_sigmoid=False, loss_weight=1.0) # decode损失函数 # TODO
|
||||
# decode_head_loss_decode_dict=dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)
|
||||
decode_head = create_dict_by_kwargs(type='FCNHead', loss_decode=decode_head_loss_decode_dict, sampler=decode_head_sampler, num_classes=num_classes)
|
||||
|
||||
# 3.3.3. 辅助部分auxiliary_head
|
||||
# auxiliary损失下载方式 TODO # 可更改
|
||||
auxiliary_head_loss_decode_dict = dict(type='DiceLoss', use_sigmoid=False, loss_weight=0.4, reduction="none") # DiceLoss损失函数 # TODO # 可更改
|
||||
# auxiliary_head_loss_decode_dict = dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)
|
||||
# 获取原始auxiliary_head
|
||||
auxiliary_head = get_var_from_py_file(alg_file_pth, var_name="model")["auxiliary_head"]
|
||||
# 新的auxiliary_head
|
||||
auxiliary_head_new = [create_dict_by_kwargs(type='FCNHead', sampler=auxiliary_head_sampler, norm_cfg=norm_cfg, align_corners=False, loss_decode=auxiliary_head_loss_decode_dict, num_classes=num_classes,),
|
||||
create_dict_by_kwargs(type='FCNHead', sampler=auxiliary_head_sampler, norm_cfg=norm_cfg, align_corners=False, loss_decode=auxiliary_head_loss_decode_dict, num_classes=num_classes,),
|
||||
create_dict_by_kwargs(type='FCNHead', sampler=auxiliary_head_sampler, norm_cfg=norm_cfg, align_corners=False, loss_decode=auxiliary_head_loss_decode_dict, num_classes=num_classes,),
|
||||
create_dict_by_kwargs(type='FCNHead', sampler=auxiliary_head_sampler, norm_cfg=norm_cfg, align_corners=False, loss_decode=auxiliary_head_loss_decode_dict, num_classes=num_classes,)]
|
||||
# 更新auxiliary_head
|
||||
auxiliary_head = update_list_dict_var(auxiliary_head, auxiliary_head_new)
|
||||
|
||||
# 3.3.4. 综合model
|
||||
model = dict(
|
||||
data_preprocessor = model_data_preprocessor,
|
||||
decode_head = decode_head,
|
||||
auxiliary_head = auxiliary_head,
|
||||
)
|
||||
|
||||
########### 3.4. 生成optim、param_scheduler部分 ###########
|
||||
optim_wrapper = generate_optim_wrapper()
|
||||
param_scheduler = generate_param_scheduler(train_type, train_time_or_epoch_k)
|
||||
|
||||
########### 4.程序写入 ###########
|
||||
# 4.1. 算法名称解析:bisenetv2【Alg】 _ r50【pretrained模型深度】 - d8【_base_/models】 _ 4【GPU数量】 x b2【Batch Size大小】 - 40k【schedule】 _ cityscapes【数据集】 - 512x1024【crop_size】.py
|
||||
alg_file_name = f"{alg_name}_fcn_{use_sampler_txt+'_' if use_sampler_txt != '' else ''}g{GPU_num}-{schedule_file_name.lstrip('schedule_')}_{dataset_file_name}-{crop_size[0]}x{crop_size[1]}.py"
|
||||
output_configs_alg_my_alg = os.path.join(f'./configs/bisenetv2/', alg_file_name)
|
||||
write_config_to_file(output_configs_alg_my_alg, _base_=_base_, norm_cfg=norm_cfg, crop_size=crop_size, data_preprocessor=data_preprocessor, model=model, optim_wrapper=optim_wrapper, param_scheduler=param_scheduler)
|
||||
|
||||
# 4.2. 将信息临时写入文件
|
||||
with open("_temp_.txt", "w") as file:
|
||||
json.dump({'alg_infos':{'alg_file_name':alg_file_name,'alg_file_pth':output_configs_alg_my_alg}}, file)
|
||||
102
Seg_All_In_One_MMSeg/My_All_In_One/2_Alg_Program/my_ccnet_r50.py
Normal file
102
Seg_All_In_One_MMSeg/My_All_In_One/2_Alg_Program/my_ccnet_r50.py
Normal file
@@ -0,0 +1,102 @@
|
||||
import os, sys, argparse, json
|
||||
# 添加上级目录到 sys.path
|
||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../')))
|
||||
|
||||
from Initial_Alg_Program.Initial_Alg_Select_Tool import select_crop_size, select_pretrained_model, select_test_cfg_slide
|
||||
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_Tool import format_all_data, write_config_to_file
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_0_base_ import generate_base_config
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_1_norm_cfg import generate_norm_cfg
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_2_data_preprocessor import generate_data_preprocessor
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_3_model import generate_model_pretrained, generate_model_backbone, generate_model_data_preprocessor, generate_model_decode_head, generate_model_auxiliary_head, generate_model_train_cfg, generate_model_test_cfg
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_4_optimizer import generate_optim_wrapper, generate_param_scheduler
|
||||
|
||||
if __name__ == '__main__':
|
||||
########### 1.超参数-算法默认参数 ###########
|
||||
alg_file_name = 'ccnet_r50-d8'
|
||||
|
||||
########### 2.参数解析 ###########
|
||||
parser = argparse.ArgumentParser(description="Run algorithm with specified parameters.")
|
||||
|
||||
# 添加命令行参数
|
||||
parser.add_argument('--alg_name', type=str, required=True, help="Algorithm name (e.g., 'ann')")
|
||||
parser.add_argument('--dataset_file_name', type=str, required=True, help="Dataset file name (e.g., 'my_dataset_model')")
|
||||
parser.add_argument('--mean', type=str, required=True, help="Mean of dataset (e.g., [1.2, 1.2, 1.2])")
|
||||
parser.add_argument('--std', type=str, required=True, help="Mean of dataset (e.g., [1.2, 1.2, 1.2])")
|
||||
parser.add_argument('--dataset_num_classes', type=int, required=True, help="Dataset class num (e.g., 36)")
|
||||
parser.add_argument('--GPU_num', type=int, required=True, help="Num of GPU to train (e.g., 1)")
|
||||
parser.add_argument('--schedule_file_name', type=str, required=True, help="Schedule file name (e.g., 'schedule_4k_check_400')")
|
||||
parser.add_argument('--train_type', type=str, required=True, help="Schedule train type (e.g., Epoch or Iteration)")
|
||||
parser.add_argument('--train_time_or_epoch_k', type=int, required=True, help="Schedule train time k (e.g., 4)")
|
||||
|
||||
# 解析命令行参数
|
||||
args = parser.parse_args()
|
||||
|
||||
# 变量赋值
|
||||
alg_name = args.alg_name
|
||||
dataset_file_name = args.dataset_file_name
|
||||
mean = json.loads(args.mean) # 使用json.loads转换成list
|
||||
std = json.loads(args.std) # 使用json.loads转换成list
|
||||
num_classes = args.dataset_num_classes
|
||||
schedule_file_name = args.schedule_file_name
|
||||
train_type = args.train_type
|
||||
train_time_or_epoch_k = args.train_time_or_epoch_k
|
||||
GPU_num = args.GPU_num
|
||||
|
||||
|
||||
########### 3. 设定参数 ###########
|
||||
# 3.1. 自选参数
|
||||
crop_size = select_crop_size(predefined_options=[(512, 512), (512, 1024), (769, 769)]) # 选择切割大小
|
||||
model_list = ['openmmlab/resnet50_v1c', 'openmmlab/resnet101_v1c']
|
||||
selected_model_name, select_pretrained, pretrained_pth, selected_model_info = select_pretrained_model(model_list=model_list)
|
||||
depth = selected_model_info['depth']
|
||||
|
||||
test_cfg_mode, test_cfg_crop_div_stride = select_test_cfg_slide(crop_size)
|
||||
|
||||
# 3.2. 自动生成参数
|
||||
# decode、auxiliary损失下载方式
|
||||
decode_head_loss_decode_dict=dict(type='DiceLoss', use_sigmoid=False, loss_weight=1.0) # decode损失函数 # TODO
|
||||
auxiliary_head_loss_decode_dict=dict(type='DiceLoss', use_sigmoid=False, loss_weight=0.4) # DiceLoss损失函数 # TODO
|
||||
|
||||
# 如果使用slide模式,则开启align_corners
|
||||
if test_cfg_mode == "slide":
|
||||
align_corners = True
|
||||
test_slide = 'testslide' # 算法名进行标记
|
||||
else:
|
||||
align_corners = False
|
||||
test_slide = 'no_testslide' # 算法名进行标记
|
||||
|
||||
|
||||
########### 3.生成各个部分 ###########
|
||||
_base_ = generate_base_config(alg_file_name=alg_file_name, dataset_file_name=dataset_file_name, schedule_file_name=schedule_file_name)
|
||||
norm_cfg = generate_norm_cfg(GPU_num=GPU_num)
|
||||
data_preprocessor = generate_data_preprocessor(crop_size = crop_size, mean = mean, std = std)
|
||||
|
||||
pretrained = generate_model_pretrained(pretrained_pth = pretrained_pth)
|
||||
backbone = generate_model_backbone(depth=depth)
|
||||
model_data_preprocessor = data_preprocessor # 修改data_preprocessor
|
||||
model_data_preprocessor = generate_model_data_preprocessor(model_data_preprocessor=model_data_preprocessor)
|
||||
decode_head = generate_model_decode_head(num_classes=num_classes, decode_head_loss_decode_dict=decode_head_loss_decode_dict, align_corners=align_corners)
|
||||
auxiliary_head = generate_model_auxiliary_head(num_classes=num_classes, auxiliary_head_loss_decode_dict=auxiliary_head_loss_decode_dict, align_corners=align_corners)
|
||||
# 综合model
|
||||
model = dict(
|
||||
pretrained = pretrained,
|
||||
backbone = backbone,
|
||||
data_preprocessor = model_data_preprocessor,
|
||||
decode_head = decode_head,
|
||||
auxiliary_head = auxiliary_head,
|
||||
)
|
||||
|
||||
test_cfg = generate_model_test_cfg(test_cfg_mode=test_cfg_mode, crop_size=crop_size, test_cfg_crop_div_stride=test_cfg_crop_div_stride)
|
||||
optim_wrapper = generate_optim_wrapper()
|
||||
param_scheduler = generate_param_scheduler(train_type, train_time_or_epoch_k)
|
||||
|
||||
########### 4.程序写入 ###########
|
||||
# 算法名称解析:ccnet【Alg】 _ r50【pretrained模型深度】 - d8【_base_/models】 _ 4【GPU数量】 x b2【Batch Size大小】 - 40k【schedule】 _ cityscapes【数据集】 - 512x1024【crop_size】.py
|
||||
alg_file_name = f"{alg_name}_r{depth}_{'Pre' if select_pretrained else 'NoPre'}_g{GPU_num}-{schedule_file_name.lstrip('schedule_')}_{dataset_file_name}-{crop_size[0]}x{crop_size[1]}-{test_slide}.py"
|
||||
output_configs_alg_my_alg = os.path.join(f'./configs/ccnet/', alg_file_name)
|
||||
write_config_to_file(output_configs_alg_my_alg, _base_=_base_, norm_cfg=norm_cfg, crop_size=crop_size, data_preprocessor=data_preprocessor, model=model, test_cfg=test_cfg, optim_wrapper=optim_wrapper, param_scheduler=param_scheduler)
|
||||
|
||||
# 将信息写入文件
|
||||
with open("_temp_.txt", "w") as file:
|
||||
json.dump({'alg_infos':{'alg_file_name':alg_file_name,'alg_file_pth':output_configs_alg_my_alg}}, file)
|
||||
@@ -0,0 +1,95 @@
|
||||
import os, sys, argparse, json
|
||||
# 添加上级目录到 sys.path
|
||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../')))
|
||||
|
||||
from Initial_Alg_Program.Initial_Alg_Select_Tool import select_crop_size, select_pretrained_model, select_test_cfg_slide
|
||||
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_Tool import format_all_data, write_config_to_file, get_var_from_py_file
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_0_base_ import generate_base_config
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_1_norm_cfg import generate_norm_cfg
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_2_data_preprocessor import generate_data_preprocessor
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_3_model import generate_model_pretrained, generate_model_backbone, generate_model_data_preprocessor, generate_model_decode_head, generate_model_auxiliary_head, generate_model_train_cfg, generate_model_test_cfg
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_4_optimizer import generate_optim_wrapper, generate_param_scheduler
|
||||
|
||||
if __name__ == '__main__':
|
||||
########### 1.超参数-算法默认参数 ###########
|
||||
alg_file_name = 'cgnet'
|
||||
|
||||
########### 2.参数解析 ###########
|
||||
parser = argparse.ArgumentParser(description="Run algorithm with specified parameters.")
|
||||
|
||||
# 添加命令行参数
|
||||
parser.add_argument('--alg_name', type=str, required=True, help="Algorithm name (e.g., 'ann')")
|
||||
parser.add_argument('--dataset_file_name', type=str, required=True, help="Dataset file name (e.g., 'my_dataset_model')")
|
||||
parser.add_argument('--mean', type=str, required=True, help="Mean of dataset (e.g., [1.2, 1.2, 1.2])")
|
||||
parser.add_argument('--std', type=str, required=True, help="Mean of dataset (e.g., [1.2, 1.2, 1.2])")
|
||||
parser.add_argument('--dataset_num_classes', type=int, required=True, help="Dataset class num (e.g., 36)")
|
||||
parser.add_argument('--GPU_num', type=int, required=True, help="Num of GPU to train (e.g., 1)")
|
||||
parser.add_argument('--schedule_file_name', type=str, required=True, help="Schedule file name (e.g., 'schedule_4k_check_400')")
|
||||
parser.add_argument('--train_type', type=str, required=True, help="Schedule train type (e.g., Epoch or Iteration)")
|
||||
parser.add_argument('--train_time_or_epoch_k', type=int, required=True, help="Schedule train time k (e.g., 4)")
|
||||
|
||||
# 解析命令行参数
|
||||
args = parser.parse_args()
|
||||
|
||||
# 变量赋值
|
||||
alg_name = args.alg_name
|
||||
dataset_file_name = args.dataset_file_name
|
||||
mean = json.loads(args.mean) # 使用json.loads转换成list
|
||||
std = json.loads(args.std) # 使用json.loads转换成list
|
||||
num_classes = args.dataset_num_classes
|
||||
schedule_file_name = args.schedule_file_name
|
||||
train_type = args.train_type
|
||||
train_time_or_epoch_k = args.train_time_or_epoch_k
|
||||
GPU_num = args.GPU_num
|
||||
|
||||
|
||||
########### 3. 设定参数 ###########
|
||||
# 3.1. 自选参数
|
||||
crop_size = select_crop_size(predefined_options=[(680, 680), (512, 1024)]) # 选择切割大小
|
||||
|
||||
# 3.2. 自动生成参数
|
||||
# decode、auxiliary损失下载方式
|
||||
decode_head = get_var_from_py_file(os.path.join('./configs/_base_/models', alg_file_name+'.py'), 'model')['decode_head']
|
||||
|
||||
# Way Ori: loss_decode
|
||||
decode_head_loss_decode_dict = dict(
|
||||
type='CrossEntropyLoss',
|
||||
use_sigmoid=False,
|
||||
loss_weight=1.0,
|
||||
class_weight=[
|
||||
2.5959933, 6.7415504, 3.5354059, 9.8663225, 9.690899, 9.369352,
|
||||
10.289121, 9.953208, 4.3097677, 9.490387, 7.674431, 9.396905,
|
||||
10.347791, 6.3927646, 10.226669, 10.241062, 10.280587,
|
||||
10.396974, 10.055647
|
||||
]
|
||||
)
|
||||
# Way 1: loss_decode
|
||||
decode_head_loss_decode_dict = dict(_delete_=True, type='DiceLoss', use_sigmoid=False, loss_weight=1.0) # decode损失函数 # TODO
|
||||
decode_head['_delete_'] = True
|
||||
decode_head['loss_decode'] = decode_head_loss_decode_dict
|
||||
|
||||
########### 3.生成各个部分 ###########
|
||||
_base_ = generate_base_config(alg_file_name=alg_file_name, dataset_file_name=dataset_file_name, schedule_file_name=schedule_file_name)
|
||||
norm_cfg = generate_norm_cfg(GPU_num=GPU_num)
|
||||
data_preprocessor = generate_data_preprocessor(crop_size = crop_size, mean = mean, std = std)
|
||||
|
||||
model_data_preprocessor = data_preprocessor # 修改data_preprocessor
|
||||
model_data_preprocessor = generate_model_data_preprocessor(model_data_preprocessor=model_data_preprocessor)
|
||||
|
||||
# 综合model
|
||||
model = dict(data_preprocessor = model_data_preprocessor, decode_head=decode_head)
|
||||
|
||||
# test_cfg = generate_model_test_cfg(test_cfg_mode=test_cfg_mode, crop_size=crop_size, test_cfg_crop_div_stride=test_cfg_crop_div_stride)
|
||||
optim_wrapper = generate_optim_wrapper()
|
||||
param_scheduler = generate_param_scheduler(train_type, train_time_or_epoch_k)
|
||||
|
||||
########### 4.程序写入 ###########
|
||||
# 算法名称解析:cgnet【Alg】 _ r50【pretrained模型深度】 - d8【_base_/models】 _ 4【GPU数量】 x b2【Batch Size大小】 - 40k【schedule】 _ cityscapes【数据集】 - 512x1024【crop_size】.py
|
||||
alg_file_name = f"{alg_name}_fcn_g{GPU_num}-{schedule_file_name.lstrip('schedule_')}_{dataset_file_name}-{crop_size[0]}x{crop_size[1]}.py"
|
||||
output_configs_alg_my_alg = os.path.join(f'./configs/cgnet/', alg_file_name)
|
||||
write_config_to_file(output_configs_alg_my_alg, _base_=_base_, norm_cfg=norm_cfg, crop_size=crop_size, data_preprocessor=data_preprocessor, model=model, optim_wrapper=optim_wrapper, param_scheduler=param_scheduler)
|
||||
|
||||
# 将信息写入文件
|
||||
with open("_temp_.txt", "w") as file:
|
||||
json.dump({'alg_infos':{'alg_file_name':alg_file_name,'alg_file_pth':output_configs_alg_my_alg}}, file)
|
||||
102
Seg_All_In_One_MMSeg/My_All_In_One/2_Alg_Program/my_danet_r50.py
Normal file
102
Seg_All_In_One_MMSeg/My_All_In_One/2_Alg_Program/my_danet_r50.py
Normal file
@@ -0,0 +1,102 @@
|
||||
import os, sys, argparse, json
|
||||
# 添加上级目录到 sys.path
|
||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../')))
|
||||
|
||||
from Initial_Alg_Program.Initial_Alg_Select_Tool import select_crop_size, select_pretrained_model, select_test_cfg_slide
|
||||
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_Tool import format_all_data, write_config_to_file
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_0_base_ import generate_base_config
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_1_norm_cfg import generate_norm_cfg
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_2_data_preprocessor import generate_data_preprocessor
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_3_model import generate_model_pretrained, generate_model_backbone, generate_model_data_preprocessor, generate_model_decode_head, generate_model_auxiliary_head, generate_model_train_cfg, generate_model_test_cfg
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_4_optimizer import generate_optim_wrapper, generate_param_scheduler
|
||||
|
||||
if __name__ == '__main__':
|
||||
########### 1.超参数-算法默认参数 ###########
|
||||
alg_file_name = 'danet_r50-d8'
|
||||
|
||||
########### 2.参数解析 ###########
|
||||
parser = argparse.ArgumentParser(description="Run algorithm with specified parameters.")
|
||||
|
||||
# 添加命令行参数
|
||||
parser.add_argument('--alg_name', type=str, required=True, help="Algorithm name (e.g., 'ann')")
|
||||
parser.add_argument('--dataset_file_name', type=str, required=True, help="Dataset file name (e.g., 'my_dataset_model')")
|
||||
parser.add_argument('--mean', type=str, required=True, help="Mean of dataset (e.g., [1.2, 1.2, 1.2])")
|
||||
parser.add_argument('--std', type=str, required=True, help="Mean of dataset (e.g., [1.2, 1.2, 1.2])")
|
||||
parser.add_argument('--dataset_num_classes', type=int, required=True, help="Dataset class num (e.g., 36)")
|
||||
parser.add_argument('--GPU_num', type=int, required=True, help="Num of GPU to train (e.g., 1)")
|
||||
parser.add_argument('--schedule_file_name', type=str, required=True, help="Schedule file name (e.g., 'schedule_4k_check_400')")
|
||||
parser.add_argument('--train_type', type=str, required=True, help="Schedule train type (e.g., Epoch or Iteration)")
|
||||
parser.add_argument('--train_time_or_epoch_k', type=int, required=True, help="Schedule train time k (e.g., 4)")
|
||||
|
||||
# 解析命令行参数
|
||||
args = parser.parse_args()
|
||||
|
||||
# 变量赋值
|
||||
alg_name = args.alg_name
|
||||
dataset_file_name = args.dataset_file_name
|
||||
mean = json.loads(args.mean) # 使用json.loads转换成list
|
||||
std = json.loads(args.std) # 使用json.loads转换成list
|
||||
num_classes = args.dataset_num_classes
|
||||
schedule_file_name = args.schedule_file_name
|
||||
train_type = args.train_type
|
||||
train_time_or_epoch_k = args.train_time_or_epoch_k
|
||||
GPU_num = args.GPU_num
|
||||
|
||||
|
||||
########### 3. 设定参数 ###########
|
||||
# 3.1. 自选参数
|
||||
crop_size = select_crop_size(predefined_options=[(512, 512), (512, 1024), (769, 769)]) # 选择切割大小
|
||||
model_list = ['openmmlab/resnet50_v1c', 'openmmlab/resnet101_v1c']
|
||||
selected_model_name, select_pretrained, pretrained_pth, selected_model_info = select_pretrained_model(model_list=model_list)
|
||||
depth = selected_model_info['depth']
|
||||
|
||||
test_cfg_mode, test_cfg_crop_div_stride = select_test_cfg_slide(crop_size)
|
||||
|
||||
# 3.2. 自动生成参数
|
||||
# decode、auxiliary损失下载方式
|
||||
decode_head_loss_decode_dict=dict(type='DiceLoss', use_sigmoid=False, loss_weight=1.0) # decode损失函数 # TODO
|
||||
auxiliary_head_loss_decode_dict=dict(type='DiceLoss', use_sigmoid=False, loss_weight=0.4) # DiceLoss损失函数 # TODO
|
||||
|
||||
# 如果使用slide模式,则开启align_corners
|
||||
if test_cfg_mode == "slide":
|
||||
align_corners = True
|
||||
test_slide = 'testslide' # 算法名进行标记
|
||||
else:
|
||||
align_corners = False
|
||||
test_slide = 'no_testslide' # 算法名进行标记
|
||||
|
||||
|
||||
########### 3.生成各个部分 ###########
|
||||
_base_ = generate_base_config(alg_file_name=alg_file_name, dataset_file_name=dataset_file_name, schedule_file_name=schedule_file_name)
|
||||
norm_cfg = generate_norm_cfg(GPU_num=GPU_num)
|
||||
data_preprocessor = generate_data_preprocessor(crop_size = crop_size, mean = mean, std = std)
|
||||
|
||||
pretrained = generate_model_pretrained(pretrained_pth = pretrained_pth)
|
||||
backbone = generate_model_backbone(depth=depth)
|
||||
model_data_preprocessor = data_preprocessor # 修改data_preprocessor
|
||||
model_data_preprocessor = generate_model_data_preprocessor(model_data_preprocessor=model_data_preprocessor)
|
||||
decode_head = generate_model_decode_head(num_classes=num_classes, decode_head_loss_decode_dict=decode_head_loss_decode_dict, align_corners=align_corners)
|
||||
auxiliary_head = generate_model_auxiliary_head(num_classes=num_classes, auxiliary_head_loss_decode_dict=auxiliary_head_loss_decode_dict, align_corners=align_corners)
|
||||
# 综合model
|
||||
model = dict(
|
||||
pretrained = pretrained,
|
||||
backbone = backbone,
|
||||
data_preprocessor = model_data_preprocessor,
|
||||
decode_head = decode_head,
|
||||
auxiliary_head = auxiliary_head,
|
||||
)
|
||||
|
||||
test_cfg = generate_model_test_cfg(test_cfg_mode=test_cfg_mode, crop_size=crop_size, test_cfg_crop_div_stride=test_cfg_crop_div_stride)
|
||||
optim_wrapper = generate_optim_wrapper()
|
||||
param_scheduler = generate_param_scheduler(train_type, train_time_or_epoch_k)
|
||||
|
||||
########### 4.程序写入 ###########
|
||||
# 算法名称解析:danet【Alg】 _ r50【pretrained模型深度】 - d8【_base_/models】 _ 4【GPU数量】 x b2【Batch Size大小】 - 40k【schedule】 _ cityscapes【数据集】 - 512x1024【crop_size】.py
|
||||
alg_file_name = f"{alg_name}_r{depth}_{'Pre' if select_pretrained else 'NoPre'}_g{GPU_num}-{schedule_file_name.lstrip('schedule_')}_{dataset_file_name}-{crop_size[0]}x{crop_size[1]}-{test_slide}.py"
|
||||
output_configs_alg_my_alg = os.path.join(f'./configs/danet/', alg_file_name)
|
||||
write_config_to_file(output_configs_alg_my_alg, _base_=_base_, norm_cfg=norm_cfg, crop_size=crop_size, data_preprocessor=data_preprocessor, model=model, test_cfg=test_cfg, optim_wrapper=optim_wrapper, param_scheduler=param_scheduler)
|
||||
|
||||
# 将信息写入文件
|
||||
with open("_temp_.txt", "w") as file:
|
||||
json.dump({'alg_infos':{'alg_file_name':alg_file_name,'alg_file_pth':output_configs_alg_my_alg}}, file)
|
||||
@@ -0,0 +1,95 @@
|
||||
import os, sys, argparse, json
|
||||
import importlib.util
|
||||
# 添加上级目录到 sys.path
|
||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../')))
|
||||
|
||||
from Initial_Alg_Program.Initial_Alg_Select_Tool import select_crop_size, select_pretrained_model, select_sampler
|
||||
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_Tool import format_all_data, write_config_to_file, create_dict_by_kwargs, get_var_from_file, update_list_dict_var, get_var_from_py_file
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_0_base_ import generate_base_config
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_1_norm_cfg import generate_norm_cfg
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_2_data_preprocessor import generate_data_preprocessor
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_3_model import generate_model_pretrained, generate_model_backbone, generate_model_data_preprocessor, generate_model_decode_head, generate_model_auxiliary_head, generate_model_train_cfg, generate_model_test_cfg
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_4_optimizer import generate_optim_wrapper, generate_param_scheduler
|
||||
|
||||
if __name__ == '__main__':
|
||||
########### 1.超参数-算法默认参数 ###########
|
||||
alg_file_name = 'ddrnet'
|
||||
alg_file_pth = f"configs/_base_/models/{alg_file_name}.py"
|
||||
|
||||
########### 2.参数解析 ###########
|
||||
parser = argparse.ArgumentParser(description="Run algorithm with specified parameters.")
|
||||
|
||||
# 添加命令行参数
|
||||
parser.add_argument('--alg_name', type=str, required=True, help="Algorithm name (e.g., 'ann')")
|
||||
parser.add_argument('--dataset_file_name', type=str, required=True, help="Dataset file name (e.g., 'my_dataset_model')")
|
||||
parser.add_argument('--mean', type=str, required=True, help="Mean of dataset (e.g., [1.2, 1.2, 1.2])")
|
||||
parser.add_argument('--std', type=str, required=True, help="Mean of dataset (e.g., [1.2, 1.2, 1.2])")
|
||||
parser.add_argument('--dataset_num_classes', type=int, required=True, help="Dataset class num (e.g., 36)")
|
||||
parser.add_argument('--GPU_num', type=int, required=True, help="Num of GPU to train (e.g., 1)")
|
||||
parser.add_argument('--schedule_file_name', type=str, required=True, help="Schedule file name (e.g., 'schedule_4k_check_400')")
|
||||
parser.add_argument('--train_type', type=str, required=True, help="Schedule train type (e.g., Epoch or Iteration)")
|
||||
parser.add_argument('--train_time_or_epoch_k', type=int, required=True, help="Schedule train time k (e.g., 4)")
|
||||
|
||||
# 解析命令行参数
|
||||
args = parser.parse_args()
|
||||
|
||||
# 变量赋值
|
||||
alg_name = args.alg_name
|
||||
dataset_file_name = args.dataset_file_name
|
||||
mean = json.loads(args.mean) # 使用json.loads转换成list
|
||||
std = json.loads(args.std) # 使用json.loads转换成list
|
||||
num_classes = args.dataset_num_classes
|
||||
schedule_file_name = args.schedule_file_name
|
||||
train_type = args.train_type
|
||||
train_time_or_epoch_k = args.train_time_or_epoch_k
|
||||
GPU_num = args.GPU_num
|
||||
|
||||
########### 3.1. 生成_base_部分 ###########
|
||||
_base_ = generate_base_config(alg_file_name=alg_file_name, dataset_file_name=dataset_file_name, schedule_file_name=schedule_file_name)
|
||||
|
||||
########### 3.1. 生成norm_cfg部分 ###########
|
||||
norm_cfg = generate_norm_cfg(GPU_num=GPU_num)
|
||||
|
||||
########### 3.2. 生成data_preprocessor部分 ###########
|
||||
crop_size = select_crop_size(predefined_options=[(512, 512)]) # 选择切割大小
|
||||
data_preprocessor = generate_data_preprocessor(crop_size = crop_size, mean = mean, std = std)
|
||||
|
||||
########### 3.3. 生成model部分 ###########
|
||||
# 3.3.1. 预处理data_preprocessor
|
||||
model_data_preprocessor = data_preprocessor # 修改data_preprocessor
|
||||
|
||||
# # 3.3.2. 骨架backbone、解码器decode_head
|
||||
model_list = ["openmmlab/ddrnet23-s", "openmmlab/ddrnet23"]
|
||||
selected_model_name, select_pretrained, pretrained_pth, selected_model_info = select_pretrained_model(model_list=model_list)
|
||||
if selected_model_name == "openmmlab/ddrnet23-s":
|
||||
backbone = dict(channels=32, init_cfg=dict(type='Pretrained', checkpoint=pretrained_pth))
|
||||
decode_head = dict(in_channels=32 * 4, channels=64, num_classes=num_classes)
|
||||
model_size = 'small'
|
||||
elif selected_model_name == "openmmlab/ddrnet23":
|
||||
backbone = dict(channels=64, init_cfg=dict(type='Pretrained', checkpoint=pretrained_pth))
|
||||
decode_head = dict(in_channels=64 * 4, channels=128, num_classes=num_classes)
|
||||
model_size = 'normal'
|
||||
else:
|
||||
quit("Error: 未知的模型名称")
|
||||
|
||||
# 3.3.4. 综合model
|
||||
model = dict(
|
||||
data_preprocessor = model_data_preprocessor,
|
||||
backbone = backbone,
|
||||
decode_head = decode_head,
|
||||
)
|
||||
|
||||
########### 3.4. 生成optim、param_scheduler部分 ###########
|
||||
optim_wrapper = generate_optim_wrapper()
|
||||
param_scheduler = generate_param_scheduler(train_type, train_time_or_epoch_k)
|
||||
|
||||
########### 4.程序写入 ###########
|
||||
# 4.1. 算法名称解析:bisenetv2【Alg】 _ r50【pretrained模型深度】 - d8【_base_/models】 _ 4【GPU数量】 x b2【Batch Size大小】 - 40k【schedule】 _ cityscapes【数据集】 - 512x1024【crop_size】.py
|
||||
alg_file_name = f"{alg_name}_{model_size}_g{GPU_num}-{schedule_file_name.lstrip('schedule_')}_{dataset_file_name}-{crop_size[0]}x{crop_size[1]}.py"
|
||||
output_configs_alg_my_alg = os.path.join(f'./configs/ddrnet/', alg_file_name)
|
||||
write_config_to_file(output_configs_alg_my_alg, _base_=_base_, norm_cfg=norm_cfg, crop_size=crop_size, data_preprocessor=data_preprocessor, model=model, optim_wrapper=optim_wrapper, param_scheduler=param_scheduler)
|
||||
|
||||
# 4.2. 将信息临时写入文件
|
||||
with open("_temp_.txt", "w") as file:
|
||||
json.dump({'alg_infos':{'alg_file_name':alg_file_name,'alg_file_pth':output_configs_alg_my_alg}}, file)
|
||||
@@ -0,0 +1,135 @@
|
||||
import os, sys, argparse, json
|
||||
# 添加上级目录到 sys.path
|
||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../')))
|
||||
|
||||
from Initial_Alg_Program.Initial_Alg_Select_Tool import select_crop_size, select_pretrained_model, select_test_cfg_slide, select_sampler
|
||||
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_Tool import format_all_data, write_config_to_file, create_dict_by_kwargs
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_0_base_ import generate_base_config
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_1_norm_cfg import generate_norm_cfg
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_2_data_preprocessor import generate_data_preprocessor
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_3_model import generate_model_pretrained, generate_model_backbone, generate_model_data_preprocessor, generate_model_decode_head, generate_model_auxiliary_head, generate_model_train_cfg, generate_model_test_cfg
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_4_optimizer import generate_optim_wrapper, generate_param_scheduler
|
||||
|
||||
if __name__ == '__main__':
|
||||
########### 1.超参数-算法默认参数 ###########
|
||||
alg_file_name = 'deeplabv3_r50-d8'
|
||||
|
||||
########### 2.参数解析 ###########
|
||||
parser = argparse.ArgumentParser(description="Run algorithm with specified parameters.")
|
||||
|
||||
# 添加命令行参数
|
||||
parser.add_argument('--alg_name', type=str, required=True, help="Algorithm name (e.g., 'ann')")
|
||||
parser.add_argument('--dataset_file_name', type=str, required=True, help="Dataset file name (e.g., 'my_dataset_model')")
|
||||
parser.add_argument('--mean', type=str, required=True, help="Mean of dataset (e.g., [1.2, 1.2, 1.2])")
|
||||
parser.add_argument('--std', type=str, required=True, help="Mean of dataset (e.g., [1.2, 1.2, 1.2])")
|
||||
parser.add_argument('--dataset_num_classes', type=int, required=True, help="Dataset class num (e.g., 36)")
|
||||
parser.add_argument('--GPU_num', type=int, required=True, help="Num of GPU to train (e.g., 1)")
|
||||
parser.add_argument('--schedule_file_name', type=str, required=True, help="Schedule file name (e.g., 'schedule_4k_check_400')")
|
||||
parser.add_argument('--train_type', type=str, required=True, help="Schedule train type (e.g., Epoch or Iteration)")
|
||||
parser.add_argument('--train_time_or_epoch_k', type=int, required=True, help="Schedule train time k (e.g., 4)")
|
||||
|
||||
# 解析命令行参数
|
||||
args = parser.parse_args()
|
||||
|
||||
# 变量赋值
|
||||
alg_name = args.alg_name
|
||||
dataset_file_name = args.dataset_file_name
|
||||
mean = json.loads(args.mean) # 使用json.loads转换成list
|
||||
std = json.loads(args.std) # 使用json.loads转换成list
|
||||
num_classes = args.dataset_num_classes
|
||||
schedule_file_name = args.schedule_file_name
|
||||
train_type = args.train_type
|
||||
train_time_or_epoch_k = args.train_time_or_epoch_k
|
||||
GPU_num = args.GPU_num
|
||||
|
||||
|
||||
########### 3. 设定参数 ###########
|
||||
# 3.1. 自选参数
|
||||
crop_size = select_crop_size(predefined_options=[(512, 512), (512, 1024), (769, 769),(1280,1280)]) # 选择切割大小
|
||||
model_list = ['openmmlab/resnet18_v1c', 'openmmlab/resnet50_v1c', 'openmmlab/resnet101_v1c', 'torchvision://resnet18', 'torchvision://resnet50', 'torchvision://resnet101', ]
|
||||
selected_model_name, select_pretrained, pretrained_pth, selected_model_info = select_pretrained_model(model_list=model_list)
|
||||
depth = selected_model_info['depth']
|
||||
type_model = selected_model_info['type']
|
||||
|
||||
test_cfg_mode, test_cfg_crop_div_stride = select_test_cfg_slide(crop_size)
|
||||
|
||||
# 3.2. 自动生成参数
|
||||
# decode、auxiliary损失下载方式
|
||||
decode_head_loss_decode_dict=dict(type='DiceLoss', use_sigmoid=False, loss_weight=1.0) # decode损失函数 # TODO
|
||||
auxiliary_head_loss_decode_dict=dict(type='DiceLoss', use_sigmoid=False, loss_weight=0.4) # DiceLoss损失函数 # TODO
|
||||
|
||||
# 如果使用slide模式,则开启align_corners
|
||||
if test_cfg_mode == "slide":
|
||||
align_corners = True
|
||||
test_slide = 'testslide' # 算法名进行标记
|
||||
else:
|
||||
align_corners = False
|
||||
test_slide = 'no_testslide' # 算法名进行标记
|
||||
|
||||
# 3.3.2. 采样函数sampler
|
||||
selected_sampler_name, use_sampler, selected_sampler_info = select_sampler(sampler_list=['OHEMPixelSampler'])
|
||||
if use_sampler == True:
|
||||
decode_head_sampler=selected_sampler_info
|
||||
auxiliary_head_sampler=selected_sampler_info
|
||||
use_sampler_txt = 'ohempSampler' # 标记
|
||||
backbone_dilations=(1, 1, 1, 2)
|
||||
backbone_strides=(1, 2, 2, 1)
|
||||
backbone_multi_grid=(1, 2, 4)
|
||||
decode_head_dilations=(1, 6, 12, 18)
|
||||
else:
|
||||
decode_head_sampler=None
|
||||
auxiliary_head_sampler=None
|
||||
use_sampler_txt = ''
|
||||
backbone_dilations=None # (1, 1, 2, 4),
|
||||
backbone_strides=None # (1, 2, 1, 1)
|
||||
backbone_multi_grid=None # 没有
|
||||
decode_head_dilations=None # (1, 12, 24, 36),
|
||||
|
||||
|
||||
########### 3.生成各个部分 ###########
|
||||
_base_ = generate_base_config(alg_file_name=alg_file_name, dataset_file_name=dataset_file_name, schedule_file_name=schedule_file_name)
|
||||
|
||||
norm_cfg = generate_norm_cfg(GPU_num=GPU_num)
|
||||
|
||||
data_preprocessor = generate_data_preprocessor(crop_size = crop_size, mean = mean, std = std)
|
||||
|
||||
pretrained = generate_model_pretrained(pretrained_pth = pretrained_pth)
|
||||
|
||||
backbone = create_dict_by_kwargs(depth=depth, type=type_model, dilations=backbone_dilations, strides=backbone_strides, multi_grid=backbone_multi_grid) # generate_model_backbone(depth=depth,)
|
||||
|
||||
model_data_preprocessor = data_preprocessor # 修改data_preprocessor
|
||||
model_data_preprocessor = generate_model_data_preprocessor(model_data_preprocessor=model_data_preprocessor)
|
||||
|
||||
if depth == 18:
|
||||
decode_head = create_dict_by_kwargs(sampler=decode_head_sampler, dilations=decode_head_dilations, in_channels=512, channels=128, num_classes=num_classes, loss_decode=decode_head_loss_decode_dict, align_corners=align_corners, norm_cfg=norm_cfg,)
|
||||
elif depth == 101 or depth == 50:
|
||||
decode_head = create_dict_by_kwargs(sampler=decode_head_sampler, dilations=decode_head_dilations, in_channels=2048, channels=512, num_classes=num_classes, loss_decode=decode_head_loss_decode_dict, align_corners=align_corners, norm_cfg=norm_cfg,)
|
||||
|
||||
if depth == 18:
|
||||
auxiliary_head = create_dict_by_kwargs(in_channels=256, channels=64, num_classes=num_classes, loss_decode=decode_head_loss_decode_dict, align_corners=align_corners, norm_cfg=norm_cfg,)
|
||||
elif depth == 101 or depth == 50:
|
||||
auxiliary_head = create_dict_by_kwargs(in_channels=1024, channels=256, num_classes=num_classes, loss_decode=decode_head_loss_decode_dict, align_corners=align_corners, norm_cfg=norm_cfg,)
|
||||
|
||||
# 综合model
|
||||
model = dict(
|
||||
pretrained = pretrained,
|
||||
backbone = backbone,
|
||||
data_preprocessor = model_data_preprocessor,
|
||||
decode_head = decode_head,
|
||||
auxiliary_head = auxiliary_head,
|
||||
)
|
||||
|
||||
test_cfg = generate_model_test_cfg(test_cfg_mode=test_cfg_mode, crop_size=crop_size, test_cfg_crop_div_stride=test_cfg_crop_div_stride)
|
||||
optim_wrapper = generate_optim_wrapper()
|
||||
param_scheduler = generate_param_scheduler(train_type, train_time_or_epoch_k)
|
||||
|
||||
########### 4.程序写入 ###########
|
||||
# 算法名称解析:deeplabv3【Alg】 _ r50【pretrained模型深度】 - d8【_base_/models】 _ 4【GPU数量】 x b2【Batch Size大小】 - 40k【schedule】 _ cityscapes【数据集】 - 512x1024【crop_size】.py
|
||||
alg_file_name = f"{alg_name}_r{depth}_{'Pre' if select_pretrained else 'NoPre'}_{use_sampler_txt+'_' if use_sampler_txt != '' else ''}g{GPU_num}-{schedule_file_name.lstrip('schedule_')}_{dataset_file_name}-{crop_size[0]}x{crop_size[1]}-{test_slide}.py"
|
||||
output_configs_alg_my_alg = os.path.join(f'./configs/deeplabv3/', alg_file_name)
|
||||
write_config_to_file(output_configs_alg_my_alg, _base_=_base_, norm_cfg=norm_cfg, crop_size=crop_size, data_preprocessor=data_preprocessor, model=model, test_cfg=test_cfg, optim_wrapper=optim_wrapper, param_scheduler=param_scheduler)
|
||||
|
||||
# 将信息写入文件
|
||||
with open("_temp_.txt", "w") as file:
|
||||
json.dump({'alg_infos':{'alg_file_name':alg_file_name,'alg_file_pth':output_configs_alg_my_alg}}, file)
|
||||
@@ -0,0 +1,135 @@
|
||||
import os, sys, argparse, json
|
||||
# 添加上级目录到 sys.path
|
||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../')))
|
||||
|
||||
from Initial_Alg_Program.Initial_Alg_Select_Tool import select_crop_size, select_pretrained_model, select_test_cfg_slide, select_sampler
|
||||
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_Tool import format_all_data, write_config_to_file, create_dict_by_kwargs
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_0_base_ import generate_base_config
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_1_norm_cfg import generate_norm_cfg
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_2_data_preprocessor import generate_data_preprocessor
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_3_model import generate_model_pretrained, generate_model_backbone, generate_model_data_preprocessor, generate_model_decode_head, generate_model_auxiliary_head, generate_model_train_cfg, generate_model_test_cfg
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_4_optimizer import generate_optim_wrapper, generate_param_scheduler
|
||||
|
||||
if __name__ == '__main__':
|
||||
########### 1.超参数-算法默认参数 ###########
|
||||
alg_file_name = 'deeplabv3plus_r50-d8'
|
||||
|
||||
########### 2.参数解析 ###########
|
||||
parser = argparse.ArgumentParser(description="Run algorithm with specified parameters.")
|
||||
|
||||
# 添加命令行参数
|
||||
parser.add_argument('--alg_name', type=str, required=True, help="Algorithm name (e.g., 'ann')")
|
||||
parser.add_argument('--dataset_file_name', type=str, required=True, help="Dataset file name (e.g., 'my_dataset_model')")
|
||||
parser.add_argument('--mean', type=str, required=True, help="Mean of dataset (e.g., [1.2, 1.2, 1.2])")
|
||||
parser.add_argument('--std', type=str, required=True, help="Mean of dataset (e.g., [1.2, 1.2, 1.2])")
|
||||
parser.add_argument('--dataset_num_classes', type=int, required=True, help="Dataset class num (e.g., 36)")
|
||||
parser.add_argument('--GPU_num', type=int, required=True, help="Num of GPU to train (e.g., 1)")
|
||||
parser.add_argument('--schedule_file_name', type=str, required=True, help="Schedule file name (e.g., 'schedule_4k_check_400')")
|
||||
parser.add_argument('--train_type', type=str, required=True, help="Schedule train type (e.g., Epoch or Iteration)")
|
||||
parser.add_argument('--train_time_or_epoch_k', type=int, required=True, help="Schedule train time k (e.g., 4)")
|
||||
|
||||
# 解析命令行参数
|
||||
args = parser.parse_args()
|
||||
|
||||
# 变量赋值
|
||||
alg_name = args.alg_name
|
||||
dataset_file_name = args.dataset_file_name
|
||||
mean = json.loads(args.mean) # 使用json.loads转换成list
|
||||
std = json.loads(args.std) # 使用json.loads转换成list
|
||||
num_classes = args.dataset_num_classes
|
||||
schedule_file_name = args.schedule_file_name
|
||||
train_type = args.train_type
|
||||
train_time_or_epoch_k = args.train_time_or_epoch_k
|
||||
GPU_num = args.GPU_num
|
||||
|
||||
|
||||
########### 3. 设定参数 ###########
|
||||
# 3.1. 自选参数
|
||||
crop_size = select_crop_size(predefined_options=[(512, 512), (512, 1024), (769, 769),(1280,1280)]) # 选择切割大小
|
||||
model_list = ['openmmlab/resnet18_v1c', 'openmmlab/resnet50_v1c', 'openmmlab/resnet101_v1c', 'torchvision://resnet18', 'torchvision://resnet50', 'torchvision://resnet101', ]
|
||||
selected_model_name, select_pretrained, pretrained_pth, selected_model_info = select_pretrained_model(model_list=model_list)
|
||||
depth = selected_model_info['depth']
|
||||
type_model = selected_model_info['type']
|
||||
|
||||
test_cfg_mode, test_cfg_crop_div_stride = select_test_cfg_slide(crop_size)
|
||||
|
||||
# 3.2. 自动生成参数
|
||||
# decode、auxiliary损失下载方式
|
||||
decode_head_loss_decode_dict=dict(type='DiceLoss', use_sigmoid=False, loss_weight=1.0) # decode损失函数 # TODO
|
||||
auxiliary_head_loss_decode_dict=dict(type='DiceLoss', use_sigmoid=False, loss_weight=0.4) # DiceLoss损失函数 # TODO
|
||||
|
||||
# 如果使用slide模式,则开启align_corners
|
||||
if test_cfg_mode == "slide":
|
||||
align_corners = True
|
||||
test_slide = 'testslide' # 算法名进行标记
|
||||
else:
|
||||
align_corners = False
|
||||
test_slide = 'no_testslide' # 算法名进行标记
|
||||
|
||||
# 3.3.2. 采样函数sampler
|
||||
selected_sampler_name, use_sampler, selected_sampler_info = select_sampler(sampler_list=['OHEMPixelSampler'])
|
||||
if use_sampler == True:
|
||||
decode_head_sampler=selected_sampler_info
|
||||
auxiliary_head_sampler=selected_sampler_info
|
||||
use_sampler_txt = 'ohempSampler' # 标记
|
||||
backbone_dilations=(1, 1, 1, 2)
|
||||
backbone_strides=(1, 2, 2, 1)
|
||||
backbone_multi_grid=(1, 2, 4)
|
||||
decode_head_dilations=(1, 6, 12, 18)
|
||||
else:
|
||||
decode_head_sampler=None
|
||||
auxiliary_head_sampler=None
|
||||
use_sampler_txt = ''
|
||||
backbone_dilations=None # (1, 1, 2, 4),
|
||||
backbone_strides=None # (1, 2, 1, 1)
|
||||
backbone_multi_grid=None # 没有
|
||||
decode_head_dilations=None # (1, 12, 24, 36),
|
||||
|
||||
|
||||
########### 3.生成各个部分 ###########
|
||||
_base_ = generate_base_config(alg_file_name=alg_file_name, dataset_file_name=dataset_file_name, schedule_file_name=schedule_file_name)
|
||||
|
||||
norm_cfg = generate_norm_cfg(GPU_num=GPU_num)
|
||||
|
||||
data_preprocessor = generate_data_preprocessor(crop_size = crop_size, mean = mean, std = std)
|
||||
|
||||
pretrained = generate_model_pretrained(pretrained_pth = pretrained_pth)
|
||||
|
||||
backbone = create_dict_by_kwargs(depth=depth, type=type_model, dilations=backbone_dilations, strides=backbone_strides, multi_grid=backbone_multi_grid) # generate_model_backbone(depth=depth,)
|
||||
|
||||
model_data_preprocessor = data_preprocessor # 修改data_preprocessor
|
||||
model_data_preprocessor = generate_model_data_preprocessor(model_data_preprocessor=model_data_preprocessor)
|
||||
|
||||
if depth == 18:
|
||||
decode_head = create_dict_by_kwargs(sampler=decode_head_sampler, dilations=decode_head_dilations, c1_in_channels=64, c1_channels=12, in_channels=512, channels=128, num_classes=num_classes, loss_decode=decode_head_loss_decode_dict, align_corners=align_corners, norm_cfg=norm_cfg,)
|
||||
elif depth == 101 or depth == 50:
|
||||
decode_head = create_dict_by_kwargs(sampler=decode_head_sampler, dilations=decode_head_dilations, c1_in_channels=256, c1_channels=48, in_channels=2048, channels=512, num_classes=num_classes, loss_decode=decode_head_loss_decode_dict, align_corners=align_corners, norm_cfg=norm_cfg,)
|
||||
|
||||
if depth == 18:
|
||||
auxiliary_head = create_dict_by_kwargs(in_channels=256, channels=64, num_classes=num_classes, loss_decode=decode_head_loss_decode_dict, align_corners=align_corners, norm_cfg=norm_cfg,)
|
||||
elif depth == 101 or depth == 50:
|
||||
auxiliary_head = create_dict_by_kwargs(in_channels=1024, channels=256, num_classes=num_classes, loss_decode=decode_head_loss_decode_dict, align_corners=align_corners, norm_cfg=norm_cfg,)
|
||||
|
||||
# 综合model
|
||||
model = dict(
|
||||
pretrained = pretrained,
|
||||
backbone = backbone,
|
||||
data_preprocessor = model_data_preprocessor,
|
||||
decode_head = decode_head,
|
||||
auxiliary_head = auxiliary_head,
|
||||
)
|
||||
|
||||
test_cfg = generate_model_test_cfg(test_cfg_mode=test_cfg_mode, crop_size=crop_size, test_cfg_crop_div_stride=test_cfg_crop_div_stride)
|
||||
optim_wrapper = generate_optim_wrapper()
|
||||
param_scheduler = generate_param_scheduler(train_type, train_time_or_epoch_k)
|
||||
|
||||
########### 4.程序写入 ###########
|
||||
# 算法名称解析:deeplabv3plus【Alg】 _ r50【pretrained模型深度】 - d8【_base_/models】 _ 4【GPU数量】 x b2【Batch Size大小】 - 40k【schedule】 _ cityscapes【数据集】 - 512x1024【crop_size】.py
|
||||
alg_file_name = f"{alg_name}_r{depth}_{'Pre' if select_pretrained else 'NoPre'}_{use_sampler_txt+'_' if use_sampler_txt != '' else ''}g{GPU_num}-{schedule_file_name.lstrip('schedule_')}_{dataset_file_name}-{crop_size[0]}x{crop_size[1]}-{test_slide}.py"
|
||||
output_configs_alg_my_alg = os.path.join(f'./configs/deeplabv3plus/', alg_file_name)
|
||||
write_config_to_file(output_configs_alg_my_alg, _base_=_base_, norm_cfg=norm_cfg, crop_size=crop_size, data_preprocessor=data_preprocessor, model=model, test_cfg=test_cfg, optim_wrapper=optim_wrapper, param_scheduler=param_scheduler)
|
||||
|
||||
# 将信息写入文件
|
||||
with open("_temp_.txt", "w") as file:
|
||||
json.dump({'alg_infos':{'alg_file_name':alg_file_name,'alg_file_pth':output_configs_alg_my_alg}}, file)
|
||||
@@ -0,0 +1,102 @@
|
||||
import os, sys, argparse, json
|
||||
# 添加上级目录到 sys.path
|
||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../')))
|
||||
|
||||
from Initial_Alg_Program.Initial_Alg_Select_Tool import select_crop_size, select_pretrained_model, select_test_cfg_slide
|
||||
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_Tool import format_all_data, write_config_to_file
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_0_base_ import generate_base_config
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_1_norm_cfg import generate_norm_cfg
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_2_data_preprocessor import generate_data_preprocessor
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_3_model import generate_model_pretrained, generate_model_backbone, generate_model_data_preprocessor, generate_model_decode_head, generate_model_auxiliary_head, generate_model_train_cfg, generate_model_test_cfg
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_4_optimizer import generate_optim_wrapper, generate_param_scheduler
|
||||
|
||||
if __name__ == '__main__':
|
||||
########### 1.超参数-算法默认参数 ###########
|
||||
alg_file_name = 'dnl_r50-d8'
|
||||
|
||||
########### 2.参数解析 ###########
|
||||
parser = argparse.ArgumentParser(description="Run algorithm with specified parameters.")
|
||||
|
||||
# 添加命令行参数
|
||||
parser.add_argument('--alg_name', type=str, required=True, help="Algorithm name (e.g., 'ann')")
|
||||
parser.add_argument('--dataset_file_name', type=str, required=True, help="Dataset file name (e.g., 'my_dataset_model')")
|
||||
parser.add_argument('--mean', type=str, required=True, help="Mean of dataset (e.g., [1.2, 1.2, 1.2])")
|
||||
parser.add_argument('--std', type=str, required=True, help="Mean of dataset (e.g., [1.2, 1.2, 1.2])")
|
||||
parser.add_argument('--dataset_num_classes', type=int, required=True, help="Dataset class num (e.g., 36)")
|
||||
parser.add_argument('--GPU_num', type=int, required=True, help="Num of GPU to train (e.g., 1)")
|
||||
parser.add_argument('--schedule_file_name', type=str, required=True, help="Schedule file name (e.g., 'schedule_4k_check_400')")
|
||||
parser.add_argument('--train_type', type=str, required=True, help="Schedule train type (e.g., Epoch or Iteration)")
|
||||
parser.add_argument('--train_time_or_epoch_k', type=int, required=True, help="Schedule train time k (e.g., 4)")
|
||||
|
||||
# 解析命令行参数
|
||||
args = parser.parse_args()
|
||||
|
||||
# 变量赋值
|
||||
alg_name = args.alg_name
|
||||
dataset_file_name = args.dataset_file_name
|
||||
mean = json.loads(args.mean) # 使用json.loads转换成list
|
||||
std = json.loads(args.std) # 使用json.loads转换成list
|
||||
num_classes = args.dataset_num_classes
|
||||
schedule_file_name = args.schedule_file_name
|
||||
train_type = args.train_type
|
||||
train_time_or_epoch_k = args.train_time_or_epoch_k
|
||||
GPU_num = args.GPU_num
|
||||
|
||||
|
||||
########### 3. 设定参数 ###########
|
||||
# 3.1. 自选参数
|
||||
crop_size = select_crop_size(predefined_options=[(512, 512), (512, 1024), (769, 769)]) # 选择切割大小
|
||||
model_list = ['openmmlab/resnet50_v1c', 'openmmlab/resnet101_v1c']
|
||||
selected_model_name, select_pretrained, pretrained_pth, selected_model_info = select_pretrained_model(model_list=model_list)
|
||||
depth = selected_model_info['depth']
|
||||
|
||||
test_cfg_mode, test_cfg_crop_div_stride = select_test_cfg_slide(crop_size)
|
||||
|
||||
# 3.2. 自动生成参数
|
||||
# decode、auxiliary损失下载方式
|
||||
decode_head_loss_decode_dict=dict(type='DiceLoss', use_sigmoid=False, loss_weight=1.0) # decode损失函数 # TODO
|
||||
auxiliary_head_loss_decode_dict=dict(type='DiceLoss', use_sigmoid=False, loss_weight=0.4) # DiceLoss损失函数 # TODO
|
||||
|
||||
# 如果使用slide模式,则开启align_corners
|
||||
if test_cfg_mode == "slide":
|
||||
align_corners = True
|
||||
test_slide = 'testslide' # 算法名进行标记
|
||||
else:
|
||||
align_corners = False
|
||||
test_slide = 'no_testslide' # 算法名进行标记
|
||||
|
||||
|
||||
########### 3.生成各个部分 ###########
|
||||
_base_ = generate_base_config(alg_file_name=alg_file_name, dataset_file_name=dataset_file_name, schedule_file_name=schedule_file_name)
|
||||
norm_cfg = generate_norm_cfg(GPU_num=GPU_num)
|
||||
data_preprocessor = generate_data_preprocessor(crop_size = crop_size, mean = mean, std = std)
|
||||
|
||||
pretrained = generate_model_pretrained(pretrained_pth = pretrained_pth)
|
||||
backbone = generate_model_backbone(depth=depth)
|
||||
model_data_preprocessor = data_preprocessor # 修改data_preprocessor
|
||||
model_data_preprocessor = generate_model_data_preprocessor(model_data_preprocessor=model_data_preprocessor)
|
||||
decode_head = generate_model_decode_head(num_classes=num_classes, decode_head_loss_decode_dict=decode_head_loss_decode_dict, align_corners=align_corners)
|
||||
auxiliary_head = generate_model_auxiliary_head(num_classes=num_classes, auxiliary_head_loss_decode_dict=auxiliary_head_loss_decode_dict, align_corners=align_corners)
|
||||
# 综合model
|
||||
model = dict(
|
||||
pretrained = pretrained,
|
||||
backbone = backbone,
|
||||
data_preprocessor = model_data_preprocessor,
|
||||
decode_head = decode_head,
|
||||
auxiliary_head = auxiliary_head,
|
||||
)
|
||||
|
||||
test_cfg = generate_model_test_cfg(test_cfg_mode=test_cfg_mode, crop_size=crop_size, test_cfg_crop_div_stride=test_cfg_crop_div_stride)
|
||||
optim_wrapper = generate_optim_wrapper()
|
||||
param_scheduler = generate_param_scheduler(train_type, train_time_or_epoch_k)
|
||||
|
||||
########### 4.程序写入 ###########
|
||||
# 算法名称解析:dnlnet【Alg】 _ r50【pretrained模型深度】 - d8【_base_/models】 _ 4【GPU数量】 x b2【Batch Size大小】 - 40k【schedule】 _ cityscapes【数据集】 - 512x1024【crop_size】.py
|
||||
alg_file_name = f"{alg_name}_r{depth}_{'Pre' if select_pretrained else 'NoPre'}_g{GPU_num}-{schedule_file_name.lstrip('schedule_')}_{dataset_file_name}-{crop_size[0]}x{crop_size[1]}-{test_slide}.py"
|
||||
output_configs_alg_my_alg = os.path.join(f'./configs/dnlnet/', alg_file_name)
|
||||
write_config_to_file(output_configs_alg_my_alg, _base_=_base_, norm_cfg=norm_cfg, crop_size=crop_size, data_preprocessor=data_preprocessor, model=model, test_cfg=test_cfg, optim_wrapper=optim_wrapper, param_scheduler=param_scheduler)
|
||||
|
||||
# 将信息写入文件
|
||||
with open("_temp_.txt", "w") as file:
|
||||
json.dump({'alg_infos':{'alg_file_name':alg_file_name,'alg_file_pth':output_configs_alg_my_alg}}, file)
|
||||
100
Seg_All_In_One_MMSeg/My_All_In_One/2_Alg_Program/my_dpt_vit.py
Normal file
100
Seg_All_In_One_MMSeg/My_All_In_One/2_Alg_Program/my_dpt_vit.py
Normal file
@@ -0,0 +1,100 @@
|
||||
import os, sys, argparse, json
|
||||
# 添加上级目录到 sys.path
|
||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../')))
|
||||
|
||||
from Initial_Alg_Program.Initial_Alg_Select_Tool import select_crop_size, select_pretrained_model, select_sampler
|
||||
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_Tool import format_all_data, write_config_to_file, create_dict_by_kwargs, get_var_from_file, update_list_dict_var, get_var_from_py_file
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_0_base_ import generate_base_config
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_1_norm_cfg import generate_norm_cfg
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_2_data_preprocessor import generate_data_preprocessor
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_3_model import generate_model_pretrained, generate_model_backbone, generate_model_data_preprocessor, generate_model_decode_head, generate_model_auxiliary_head, generate_model_train_cfg, generate_model_test_cfg
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_4_optimizer import generate_optim_wrapper, generate_param_scheduler
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_5_train_dataloader import generate_train_dataloader
|
||||
|
||||
if __name__ == '__main__':
|
||||
########### 1.超参数-算法默认参数 ###########
|
||||
alg_file_name = 'dpt_vit-b16'
|
||||
alg_file_pth = f"configs/_base_/models/{alg_file_name}.py"
|
||||
|
||||
########### 2.参数解析 ###########
|
||||
parser = argparse.ArgumentParser(description="Run algorithm with specified parameters.")
|
||||
|
||||
# 添加命令行参数
|
||||
parser.add_argument('--alg_name', type=str, required=True, help="Algorithm name (e.g., 'ann')")
|
||||
parser.add_argument('--dataset_file_name', type=str, required=True, help="Dataset file name (e.g., 'my_dataset_model')")
|
||||
parser.add_argument('--mean', type=str, required=True, help="Mean of dataset (e.g., [1.2, 1.2, 1.2])")
|
||||
parser.add_argument('--std', type=str, required=True, help="Mean of dataset (e.g., [1.2, 1.2, 1.2])")
|
||||
parser.add_argument('--dataset_num_classes', type=int, required=True, help="Dataset class num (e.g., 36)")
|
||||
parser.add_argument('--GPU_num', type=int, required=True, help="Num of GPU to train (e.g., 1)")
|
||||
parser.add_argument('--schedule_file_name', type=str, required=True, help="Schedule file name (e.g., 'schedule_4k_check_400')")
|
||||
parser.add_argument('--train_type', type=str, required=True, help="Schedule train type (e.g., Epoch or Iteration)")
|
||||
parser.add_argument('--train_time_or_epoch_k', type=int, required=True, help="Schedule train time k (e.g., 4)")
|
||||
|
||||
# 解析命令行参数
|
||||
args = parser.parse_args()
|
||||
|
||||
# 变量赋值
|
||||
alg_name = args.alg_name
|
||||
dataset_file_name = args.dataset_file_name
|
||||
mean = json.loads(args.mean) # 使用json.loads转换成list
|
||||
std = json.loads(args.std) # 使用json.loads转换成list
|
||||
num_classes = args.dataset_num_classes
|
||||
schedule_file_name = args.schedule_file_name
|
||||
train_type = args.train_type
|
||||
train_time_or_epoch_k = args.train_time_or_epoch_k
|
||||
GPU_num = args.GPU_num
|
||||
|
||||
########### 3.1. 生成_base_部分 ###########
|
||||
_base_ = generate_base_config(alg_file_name=alg_file_name, dataset_file_name=dataset_file_name, schedule_file_name=schedule_file_name)
|
||||
|
||||
########### 3.1. 生成norm_cfg部分 ###########
|
||||
norm_cfg = generate_norm_cfg(GPU_num=GPU_num)
|
||||
|
||||
########### 3.2. 生成data_preprocessor部分 ###########
|
||||
crop_size = select_crop_size(predefined_options=[(1024, 1024)]) # 选择切割大小
|
||||
data_preprocessor = generate_data_preprocessor(crop_size = crop_size, mean = mean, std = std)
|
||||
|
||||
########### 3.3. 生成model部分 ###########
|
||||
# 3.3.1. 预处理data_preprocessor
|
||||
model_data_preprocessor = data_preprocessor # 修改data_preprocessor
|
||||
|
||||
model_list = ['pretrain/vit-b16_p16_224-80ecf9dd.pth']
|
||||
selected_model_name, select_pretrained, pretrained_pth, selected_model_info = select_pretrained_model(model_list=model_list)
|
||||
|
||||
pretrained = generate_model_pretrained(pretrained_pth = pretrained_pth)
|
||||
model_data_preprocessor = data_preprocessor
|
||||
|
||||
# 3.3.2. 解码器decode_head
|
||||
# decode、auxiliary损失下载方式 TODO
|
||||
decode_head_loss_decode_dict=dict(type='DiceLoss', use_sigmoid=False, loss_weight=1.0) # decode损失函数 # TODO
|
||||
# decode_head_loss_decode_dict=dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0) # 默认 TODO
|
||||
decode_head = create_dict_by_kwargs(loss_decode=decode_head_loss_decode_dict, num_classes=num_classes)
|
||||
|
||||
# 3.3.3. 辅助部分auxiliary_head [空]
|
||||
auxiliary_head = None
|
||||
|
||||
# 3.3.4. 综合model
|
||||
model = dict(
|
||||
data_preprocessor = model_data_preprocessor,
|
||||
pretrained = pretrained,
|
||||
decode_head = decode_head,
|
||||
# auxiliary_head = auxiliary_head,
|
||||
)
|
||||
|
||||
########### 3.4. 生成optim、param_scheduler部分 ###########
|
||||
optim_wrapper = generate_optim_wrapper(type_of_back_bone = "Vit")
|
||||
param_scheduler = generate_param_scheduler(train_type, train_time_or_epoch_k)
|
||||
|
||||
########### 3.5. 生成train_dataloader部分[1个batch_size就要20G] ###########
|
||||
train_dataloader, batch_size = generate_train_dataloader(batch_size_default=1)
|
||||
|
||||
########### 4.程序写入 ###########
|
||||
# 4.1. 算法名称解析:dpt_vit【Alg】 _ r50【pretrained模型深度】 - d8【_base_/models】 _ 4【GPU数量】 x b2【Batch Size大小】 - 40k【schedule】 _ cityscapes【数据集】 - 512x1024【crop_size】.py
|
||||
alg_file_name = f"{alg_name}_g{GPU_num}_b{batch_size}-{schedule_file_name.lstrip('schedule_')}_{dataset_file_name}-{crop_size[0]}x{crop_size[1]}.py"
|
||||
output_configs_alg_my_alg = os.path.join(f'./configs/dpt/', alg_file_name)
|
||||
write_config_to_file(output_configs_alg_my_alg, _base_=_base_, norm_cfg=norm_cfg, crop_size=crop_size, data_preprocessor=data_preprocessor, model=model, optim_wrapper=optim_wrapper, param_scheduler=param_scheduler, train_dataloader=train_dataloader)
|
||||
|
||||
# 4.2. 将信息临时写入文件
|
||||
with open("_temp_.txt", "w") as file:
|
||||
json.dump({'alg_infos':{'alg_file_name':alg_file_name,'alg_file_pth':output_configs_alg_my_alg}}, file)
|
||||
@@ -0,0 +1,106 @@
|
||||
import os, sys, argparse, json
|
||||
# 添加上级目录到 sys.path
|
||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../')))
|
||||
|
||||
from Initial_Alg_Program.Initial_Alg_Select_Tool import select_crop_size, select_pretrained_model, select_test_cfg_slide
|
||||
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_Tool import format_all_data, write_config_to_file
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_0_base_ import generate_base_config
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_1_norm_cfg import generate_norm_cfg
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_2_data_preprocessor import generate_data_preprocessor
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_3_model import generate_model_pretrained, generate_model_backbone, generate_model_data_preprocessor, generate_model_decode_head, generate_model_auxiliary_head, generate_model_train_cfg, generate_model_test_cfg
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_4_optimizer import generate_optim_wrapper, generate_param_scheduler
|
||||
|
||||
if __name__ == '__main__':
|
||||
########### 1.超参数-算法默认参数 ###########
|
||||
alg_file_name = 'emanet_r50-d8'
|
||||
|
||||
########### 2.参数解析 ###########
|
||||
parser = argparse.ArgumentParser(description="Run algorithm with specified parameters.")
|
||||
|
||||
# 添加命令行参数
|
||||
parser.add_argument('--alg_name', type=str, required=True, help="Algorithm name (e.g., 'ann')")
|
||||
parser.add_argument('--dataset_file_name', type=str, required=True, help="Dataset file name (e.g., 'my_dataset_model')")
|
||||
parser.add_argument('--mean', type=str, required=True, help="Mean of dataset (e.g., [1.2, 1.2, 1.2])")
|
||||
parser.add_argument('--std', type=str, required=True, help="Mean of dataset (e.g., [1.2, 1.2, 1.2])")
|
||||
parser.add_argument('--dataset_num_classes', type=int, required=True, help="Dataset class num (e.g., 36)")
|
||||
parser.add_argument('--GPU_num', type=int, required=True, help="Num of GPU to train (e.g., 1)")
|
||||
parser.add_argument('--schedule_file_name', type=str, required=True, help="Schedule file name (e.g., 'schedule_4k_check_400')")
|
||||
parser.add_argument('--train_type', type=str, required=True, help="Schedule train type (e.g., Epoch or Iteration)")
|
||||
parser.add_argument('--train_time_or_epoch_k', type=int, required=True, help="Schedule train time k (e.g., 4)")
|
||||
|
||||
# 解析命令行参数
|
||||
args = parser.parse_args()
|
||||
|
||||
# 变量赋值
|
||||
alg_name = args.alg_name
|
||||
dataset_file_name = args.dataset_file_name
|
||||
mean = json.loads(args.mean) # 使用json.loads转换成list
|
||||
std = json.loads(args.std) # 使用json.loads转换成list
|
||||
num_classes = args.dataset_num_classes
|
||||
schedule_file_name = args.schedule_file_name
|
||||
train_type = args.train_type
|
||||
train_time_or_epoch_k = args.train_time_or_epoch_k
|
||||
GPU_num = args.GPU_num
|
||||
|
||||
|
||||
########### 3. 设定参数 ###########
|
||||
# 3.1. 自选参数
|
||||
crop_size = select_crop_size(predefined_options=[(512, 512), (512, 1024), (769, 769)]) # 选择切割大小
|
||||
model_list = ['openmmlab/resnet50_v1c', 'openmmlab/resnet101_v1c']
|
||||
selected_model_name, select_pretrained, pretrained_pth, selected_model_info = select_pretrained_model(model_list=model_list)
|
||||
depth = selected_model_info['depth']
|
||||
|
||||
test_cfg_mode, test_cfg_crop_div_stride = select_test_cfg_slide(crop_size)
|
||||
# test_cfg_mode = None
|
||||
# test_cfg_crop_div_stride = crop_size
|
||||
|
||||
# 3.2. 自动生成参数
|
||||
# decode、auxiliary损失下载方式
|
||||
decode_head_loss_decode_dict=dict(type='DiceLoss', use_sigmoid=False, loss_weight=1.0) # decode损失函数 # TODO
|
||||
# decode_head_loss_decode_dict=dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0) # decode损失函数 # TODO
|
||||
auxiliary_head_loss_decode_dict=dict(type='DiceLoss', use_sigmoid=False, loss_weight=0.4) # DiceLoss损失函数 # TODO
|
||||
# auxiliary_head_loss_decode_dict=dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4) # DiceLoss损失函数 # TODO
|
||||
|
||||
# 如果使用slide模式,则开启align_corners
|
||||
if test_cfg_mode == "slide":
|
||||
align_corners = True
|
||||
test_slide = 'testslide' # 算法名进行标记
|
||||
else:
|
||||
align_corners = False
|
||||
test_slide = 'no_testslide' # 算法名进行标记
|
||||
|
||||
|
||||
########### 3.生成各个部分 ###########
|
||||
_base_ = generate_base_config(alg_file_name=alg_file_name, dataset_file_name=dataset_file_name, schedule_file_name=schedule_file_name)
|
||||
norm_cfg = generate_norm_cfg(GPU_num=GPU_num)
|
||||
data_preprocessor = generate_data_preprocessor(crop_size = crop_size, mean = mean, std = std)
|
||||
|
||||
pretrained = generate_model_pretrained(pretrained_pth = pretrained_pth)
|
||||
backbone = generate_model_backbone(depth=depth)
|
||||
model_data_preprocessor = data_preprocessor # 修改data_preprocessor
|
||||
model_data_preprocessor = generate_model_data_preprocessor(model_data_preprocessor=model_data_preprocessor)
|
||||
decode_head = generate_model_decode_head(num_classes=num_classes, decode_head_loss_decode_dict=decode_head_loss_decode_dict, align_corners=align_corners)
|
||||
auxiliary_head = generate_model_auxiliary_head(num_classes=num_classes, auxiliary_head_loss_decode_dict=auxiliary_head_loss_decode_dict, align_corners=align_corners)
|
||||
# 综合model
|
||||
model = dict(
|
||||
pretrained = pretrained,
|
||||
backbone = backbone,
|
||||
data_preprocessor = model_data_preprocessor,
|
||||
decode_head = decode_head,
|
||||
auxiliary_head = auxiliary_head,
|
||||
)
|
||||
|
||||
test_cfg = generate_model_test_cfg(test_cfg_mode=test_cfg_mode, crop_size=crop_size, test_cfg_crop_div_stride=test_cfg_crop_div_stride)
|
||||
optim_wrapper = generate_optim_wrapper()
|
||||
param_scheduler = generate_param_scheduler(train_type, train_time_or_epoch_k)
|
||||
|
||||
########### 4.程序写入 ###########
|
||||
# 算法名称解析:emanet【Alg】 _ r50【pretrained模型深度】 - d8【_base_/models】 _ 4【GPU数量】 x b2【Batch Size大小】 - 40k【schedule】 _ cityscapes【数据集】 - 512x1024【crop_size】.py
|
||||
alg_file_name = f"{alg_name}_r{depth}_{'Pre' if select_pretrained else 'NoPre'}_g{GPU_num}-{schedule_file_name.lstrip('schedule_')}_{dataset_file_name}-{crop_size[0]}x{crop_size[1]}-{test_slide}.py"
|
||||
output_configs_alg_my_alg = os.path.join(f'./configs/emanet/', alg_file_name)
|
||||
write_config_to_file(output_configs_alg_my_alg, _base_=_base_, norm_cfg=norm_cfg, crop_size=crop_size, data_preprocessor=data_preprocessor, model=model, test_cfg=test_cfg, optim_wrapper=optim_wrapper, param_scheduler=param_scheduler)
|
||||
|
||||
# 将信息写入文件
|
||||
with open("_temp_.txt", "w") as file:
|
||||
json.dump({'alg_infos':{'alg_file_name':alg_file_name,'alg_file_pth':output_configs_alg_my_alg}}, file)
|
||||
@@ -0,0 +1,106 @@
|
||||
import os, sys, argparse, json
|
||||
# 添加上级目录到 sys.path
|
||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../')))
|
||||
|
||||
from Initial_Alg_Program.Initial_Alg_Select_Tool import select_crop_size, select_pretrained_model, select_test_cfg_slide
|
||||
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_Tool import format_all_data, write_config_to_file
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_0_base_ import generate_base_config
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_1_norm_cfg import generate_norm_cfg
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_2_data_preprocessor import generate_data_preprocessor
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_3_model import generate_model_pretrained, generate_model_backbone, generate_model_data_preprocessor, generate_model_decode_head, generate_model_auxiliary_head, generate_model_train_cfg, generate_model_test_cfg
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_4_optimizer import generate_optim_wrapper, generate_param_scheduler
|
||||
|
||||
if __name__ == '__main__':
|
||||
########### 1.超参数-算法默认参数 ###########
|
||||
alg_file_name = 'encnet_r50-d8'
|
||||
|
||||
########### 2.参数解析 ###########
|
||||
parser = argparse.ArgumentParser(description="Run algorithm with specified parameters.")
|
||||
|
||||
# 添加命令行参数
|
||||
parser.add_argument('--alg_name', type=str, required=True, help="Algorithm name (e.g., 'ann')")
|
||||
parser.add_argument('--dataset_file_name', type=str, required=True, help="Dataset file name (e.g., 'my_dataset_model')")
|
||||
parser.add_argument('--mean', type=str, required=True, help="Mean of dataset (e.g., [1.2, 1.2, 1.2])")
|
||||
parser.add_argument('--std', type=str, required=True, help="Mean of dataset (e.g., [1.2, 1.2, 1.2])")
|
||||
parser.add_argument('--dataset_num_classes', type=int, required=True, help="Dataset class num (e.g., 36)")
|
||||
parser.add_argument('--GPU_num', type=int, required=True, help="Num of GPU to train (e.g., 1)")
|
||||
parser.add_argument('--schedule_file_name', type=str, required=True, help="Schedule file name (e.g., 'schedule_4k_check_400')")
|
||||
parser.add_argument('--train_type', type=str, required=True, help="Schedule train type (e.g., Epoch or Iteration)")
|
||||
parser.add_argument('--train_time_or_epoch_k', type=int, required=True, help="Schedule train time k (e.g., 4)")
|
||||
|
||||
# 解析命令行参数
|
||||
args = parser.parse_args()
|
||||
|
||||
# 变量赋值
|
||||
alg_name = args.alg_name
|
||||
dataset_file_name = args.dataset_file_name
|
||||
mean = json.loads(args.mean) # 使用json.loads转换成list
|
||||
std = json.loads(args.std) # 使用json.loads转换成list
|
||||
num_classes = args.dataset_num_classes
|
||||
schedule_file_name = args.schedule_file_name
|
||||
train_type = args.train_type
|
||||
train_time_or_epoch_k = args.train_time_or_epoch_k
|
||||
GPU_num = args.GPU_num
|
||||
|
||||
|
||||
########### 3. 设定参数 ###########
|
||||
# 3.1. 自选参数
|
||||
crop_size = select_crop_size(predefined_options=[(512, 512), (512, 1024), (769, 769)]) # 选择切割大小
|
||||
model_list = ['openmmlab/resnet50_v1c', 'openmmlab/resnet101_v1c']
|
||||
selected_model_name, select_pretrained, pretrained_pth, selected_model_info = select_pretrained_model(model_list=model_list)
|
||||
depth = selected_model_info['depth']
|
||||
|
||||
test_cfg_mode, test_cfg_crop_div_stride = select_test_cfg_slide(crop_size)
|
||||
# test_cfg_mode = None
|
||||
# test_cfg_crop_div_stride = crop_size
|
||||
|
||||
# 3.2. 自动生成参数
|
||||
# decode、auxiliary损失下载方式
|
||||
decode_head_loss_decode_dict=dict(type='DiceLoss', use_sigmoid=False, loss_weight=1.0) # decode损失函数 # TODO
|
||||
# decode_head_loss_decode_dict=dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0) # decode损失函数 # TODO
|
||||
auxiliary_head_loss_decode_dict=dict(type='DiceLoss', use_sigmoid=False, loss_weight=0.4) # DiceLoss损失函数 # TODO
|
||||
# auxiliary_head_loss_decode_dict=dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4) # DiceLoss损失函数 # TODO
|
||||
|
||||
# 如果使用slide模式,则开启align_corners
|
||||
if test_cfg_mode == "slide":
|
||||
align_corners = True
|
||||
test_slide = 'testslide' # 算法名进行标记
|
||||
else:
|
||||
align_corners = False
|
||||
test_slide = 'no_testslide' # 算法名进行标记
|
||||
|
||||
|
||||
########### 3.生成各个部分 ###########
|
||||
_base_ = generate_base_config(alg_file_name=alg_file_name, dataset_file_name=dataset_file_name, schedule_file_name=schedule_file_name)
|
||||
norm_cfg = generate_norm_cfg(GPU_num=GPU_num)
|
||||
data_preprocessor = generate_data_preprocessor(crop_size = crop_size, mean = mean, std = std)
|
||||
|
||||
pretrained = generate_model_pretrained(pretrained_pth = pretrained_pth)
|
||||
backbone = generate_model_backbone(depth=depth)
|
||||
model_data_preprocessor = data_preprocessor # 修改data_preprocessor
|
||||
model_data_preprocessor = generate_model_data_preprocessor(model_data_preprocessor=model_data_preprocessor)
|
||||
decode_head = generate_model_decode_head(num_classes=num_classes, decode_head_loss_decode_dict=decode_head_loss_decode_dict, align_corners=align_corners)
|
||||
auxiliary_head = generate_model_auxiliary_head(num_classes=num_classes, auxiliary_head_loss_decode_dict=auxiliary_head_loss_decode_dict, align_corners=align_corners)
|
||||
# 综合model
|
||||
model = dict(
|
||||
pretrained = pretrained,
|
||||
backbone = backbone,
|
||||
data_preprocessor = model_data_preprocessor,
|
||||
decode_head = decode_head,
|
||||
auxiliary_head = auxiliary_head,
|
||||
)
|
||||
|
||||
test_cfg = generate_model_test_cfg(test_cfg_mode=test_cfg_mode, crop_size=crop_size, test_cfg_crop_div_stride=test_cfg_crop_div_stride)
|
||||
optim_wrapper = generate_optim_wrapper()
|
||||
param_scheduler = generate_param_scheduler(train_type, train_time_or_epoch_k)
|
||||
|
||||
########### 4.程序写入 ###########
|
||||
# 算法名称解析:encnet【Alg】 _ r50【pretrained模型深度】 - d8【_base_/models】 _ 4【GPU数量】 x b2【Batch Size大小】 - 40k【schedule】 _ cityscapes【数据集】 - 512x1024【crop_size】.py
|
||||
alg_file_name = f"{alg_name}_r{depth}_{'Pre' if select_pretrained else 'NoPre'}_g{GPU_num}-{schedule_file_name.lstrip('schedule_')}_{dataset_file_name}-{crop_size[0]}x{crop_size[1]}-{test_slide}.py"
|
||||
output_configs_alg_my_alg = os.path.join(f'./configs/encnet/', alg_file_name)
|
||||
write_config_to_file(output_configs_alg_my_alg, _base_=_base_, norm_cfg=norm_cfg, crop_size=crop_size, data_preprocessor=data_preprocessor, model=model, test_cfg=test_cfg, optim_wrapper=optim_wrapper, param_scheduler=param_scheduler)
|
||||
|
||||
# 将信息写入文件
|
||||
with open("_temp_.txt", "w") as file:
|
||||
json.dump({'alg_infos':{'alg_file_name':alg_file_name,'alg_file_pth':output_configs_alg_my_alg}}, file)
|
||||
@@ -0,0 +1,83 @@
|
||||
import os, sys, argparse, json
|
||||
# 添加上级目录到 sys.path
|
||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../')))
|
||||
|
||||
from Initial_Alg_Program.Initial_Alg_Select_Tool import select_crop_size, select_pretrained_model, select_test_cfg_slide
|
||||
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_Tool import format_all_data, write_config_to_file
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_0_base_ import generate_base_config
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_1_norm_cfg import generate_norm_cfg
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_2_data_preprocessor import generate_data_preprocessor
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_3_model import generate_model_pretrained, generate_model_backbone, generate_model_data_preprocessor, generate_model_decode_head, generate_model_auxiliary_head, generate_model_train_cfg, generate_model_test_cfg
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_4_optimizer import generate_optim_wrapper, generate_param_scheduler
|
||||
|
||||
if __name__ == '__main__':
|
||||
########### 1.超参数-算法默认参数 ###########
|
||||
alg_file_name = 'erfnet_fcn'
|
||||
|
||||
########### 2.参数解析 ###########
|
||||
parser = argparse.ArgumentParser(description="Run algorithm with specified parameters.")
|
||||
|
||||
# 添加命令行参数
|
||||
parser.add_argument('--alg_name', type=str, required=True, help="Algorithm name (e.g., 'ann')")
|
||||
parser.add_argument('--dataset_file_name', type=str, required=True, help="Dataset file name (e.g., 'my_dataset_model')")
|
||||
parser.add_argument('--mean', type=str, required=True, help="Mean of dataset (e.g., [1.2, 1.2, 1.2])")
|
||||
parser.add_argument('--std', type=str, required=True, help="Mean of dataset (e.g., [1.2, 1.2, 1.2])")
|
||||
parser.add_argument('--dataset_num_classes', type=int, required=True, help="Dataset class num (e.g., 36)")
|
||||
parser.add_argument('--GPU_num', type=int, required=True, help="Num of GPU to train (e.g., 1)")
|
||||
parser.add_argument('--schedule_file_name', type=str, required=True, help="Schedule file name (e.g., 'schedule_4k_check_400')")
|
||||
parser.add_argument('--train_type', type=str, required=True, help="Schedule train type (e.g., Epoch or Iteration)")
|
||||
parser.add_argument('--train_time_or_epoch_k', type=int, required=True, help="Schedule train time k (e.g., 4)")
|
||||
|
||||
# 解析命令行参数
|
||||
args = parser.parse_args()
|
||||
|
||||
# 变量赋值
|
||||
alg_name = args.alg_name
|
||||
dataset_file_name = args.dataset_file_name
|
||||
mean = json.loads(args.mean) # 使用json.loads转换成list
|
||||
std = json.loads(args.std) # 使用json.loads转换成list
|
||||
num_classes = args.dataset_num_classes
|
||||
schedule_file_name = args.schedule_file_name
|
||||
train_type = args.train_type
|
||||
train_time_or_epoch_k = args.train_time_or_epoch_k
|
||||
GPU_num = args.GPU_num
|
||||
|
||||
|
||||
########### 3. 设定参数 ###########
|
||||
# 3.1. 自选参数
|
||||
crop_size = select_crop_size(predefined_options=[(512,1024), (512, 512)]) # 选择切割大小
|
||||
|
||||
# 3.2. 自动生成参数
|
||||
# decode、auxiliary损失下载方式
|
||||
# decode_head_loss_decode_dict = dict(type='CrossEntropyLoss',
|
||||
# use_sigmoid=False,
|
||||
# loss_weight=1.0)
|
||||
decode_head_loss_decode_dict=dict(type='DiceLoss', use_sigmoid=False, loss_weight=1.0) # decode损失函数 # TODO
|
||||
decode_head=dict(loss_decode=decode_head_loss_decode_dict)
|
||||
|
||||
|
||||
########### 3.生成各个部分 ###########
|
||||
_base_ = generate_base_config(alg_file_name=alg_file_name, dataset_file_name=dataset_file_name, schedule_file_name=schedule_file_name)
|
||||
norm_cfg = generate_norm_cfg(GPU_num=GPU_num)
|
||||
data_preprocessor = generate_data_preprocessor(crop_size = crop_size, mean = mean, std = std)
|
||||
|
||||
model_data_preprocessor = data_preprocessor # 修改data_preprocessor
|
||||
model_data_preprocessor = generate_model_data_preprocessor(model_data_preprocessor=model_data_preprocessor)
|
||||
|
||||
# 综合model
|
||||
model = dict(data_preprocessor = model_data_preprocessor, decode_head=decode_head)
|
||||
|
||||
# test_cfg = generate_model_test_cfg(test_cfg_mode=test_cfg_mode, crop_size=crop_size, test_cfg_crop_div_stride=test_cfg_crop_div_stride)
|
||||
optim_wrapper = generate_optim_wrapper()
|
||||
param_scheduler = generate_param_scheduler(train_type, train_time_or_epoch_k)
|
||||
|
||||
########### 4.程序写入 ###########
|
||||
# 算法名称解析:erfnet【Alg】 _ r50【pretrained模型深度】 - d8【_base_/models】 _ 4【GPU数量】 x b2【Batch Size大小】 - 40k【schedule】 _ cityscapes【数据集】 - 512x1024【crop_size】.py
|
||||
alg_file_name = f"{alg_name}_fcn_g{GPU_num}-{schedule_file_name.lstrip('schedule_')}_{dataset_file_name}-{crop_size[0]}x{crop_size[1]}.py"
|
||||
output_configs_alg_my_alg = os.path.join(f'./configs/erfnet/', alg_file_name)
|
||||
write_config_to_file(output_configs_alg_my_alg, _base_=_base_, norm_cfg=norm_cfg, crop_size=crop_size, data_preprocessor=data_preprocessor, model=model, optim_wrapper=optim_wrapper, param_scheduler=param_scheduler)
|
||||
|
||||
# 将信息写入文件
|
||||
with open("_temp_.txt", "w") as file:
|
||||
json.dump({'alg_infos':{'alg_file_name':alg_file_name,'alg_file_pth':output_configs_alg_my_alg}}, file)
|
||||
@@ -0,0 +1,89 @@
|
||||
import os, sys, argparse, json
|
||||
# 添加上级目录到 sys.path
|
||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../')))
|
||||
|
||||
from Initial_Alg_Program.Initial_Alg_Select_Tool import select_crop_size, select_pretrained_model, select_test_cfg_slide
|
||||
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_Tool import format_all_data, write_config_to_file, get_var_from_py_file
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_0_base_ import generate_base_config
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_1_norm_cfg import generate_norm_cfg
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_2_data_preprocessor import generate_data_preprocessor
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_3_model import generate_model_pretrained, generate_model_backbone, generate_model_data_preprocessor, generate_model_decode_head, generate_model_auxiliary_head, generate_model_train_cfg, generate_model_test_cfg
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_4_optimizer import generate_optim_wrapper, generate_param_scheduler
|
||||
|
||||
if __name__ == '__main__':
|
||||
########### 1.超参数-算法默认参数 ###########
|
||||
alg_file_name = 'fast_scnn'
|
||||
|
||||
########### 2.参数解析 ###########
|
||||
parser = argparse.ArgumentParser(description="Run algorithm with specified parameters.")
|
||||
|
||||
# 添加命令行参数
|
||||
parser.add_argument('--alg_name', type=str, required=True, help="Algorithm name (e.g., 'ann')")
|
||||
parser.add_argument('--dataset_file_name', type=str, required=True, help="Dataset file name (e.g., 'my_dataset_model')")
|
||||
parser.add_argument('--mean', type=str, required=True, help="Mean of dataset (e.g., [1.2, 1.2, 1.2])")
|
||||
parser.add_argument('--std', type=str, required=True, help="Mean of dataset (e.g., [1.2, 1.2, 1.2])")
|
||||
parser.add_argument('--dataset_num_classes', type=int, required=True, help="Dataset class num (e.g., 36)")
|
||||
parser.add_argument('--GPU_num', type=int, required=True, help="Num of GPU to train (e.g., 1)")
|
||||
parser.add_argument('--schedule_file_name', type=str, required=True, help="Schedule file name (e.g., 'schedule_4k_check_400')")
|
||||
parser.add_argument('--train_type', type=str, required=True, help="Schedule train type (e.g., Epoch or Iteration)")
|
||||
parser.add_argument('--train_time_or_epoch_k', type=int, required=True, help="Schedule train time k (e.g., 4)")
|
||||
|
||||
# 解析命令行参数
|
||||
args = parser.parse_args()
|
||||
|
||||
# 变量赋值
|
||||
alg_name = args.alg_name
|
||||
dataset_file_name = args.dataset_file_name
|
||||
mean = json.loads(args.mean) # 使用json.loads转换成list
|
||||
std = json.loads(args.std) # 使用json.loads转换成list
|
||||
num_classes = args.dataset_num_classes
|
||||
schedule_file_name = args.schedule_file_name
|
||||
train_type = args.train_type
|
||||
train_time_or_epoch_k = args.train_time_or_epoch_k
|
||||
GPU_num = args.GPU_num
|
||||
|
||||
|
||||
########### 3. 设定参数 ###########
|
||||
# 3.1. 自选参数
|
||||
crop_size = select_crop_size(predefined_options=[(512,1024), (512, 512)]) # 选择切割大小
|
||||
|
||||
########### 3.生成各个部分 ###########
|
||||
# 3.1. base、norm_cfg、data_preprocessor部分
|
||||
_base_ = generate_base_config(alg_file_name=alg_file_name, dataset_file_name=dataset_file_name, schedule_file_name=schedule_file_name)
|
||||
norm_cfg = generate_norm_cfg(GPU_num=GPU_num)
|
||||
data_preprocessor = generate_data_preprocessor(crop_size = crop_size, mean = mean, std = std)
|
||||
|
||||
# 3.2. model部分
|
||||
model_data_preprocessor = data_preprocessor # 修改data_preprocessor
|
||||
model_data_preprocessor = generate_model_data_preprocessor(model_data_preprocessor=model_data_preprocessor)
|
||||
|
||||
# decode损失
|
||||
decode_head_loss_decode_dict = dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)
|
||||
decode_head_loss_decode_dict=dict(type='DiceLoss', use_sigmoid=False, loss_weight=1.0) # decode损失函数 # TODO
|
||||
decode_head=dict(loss_decode=decode_head_loss_decode_dict)
|
||||
|
||||
# auxiliary损失
|
||||
auxiliary_head = get_var_from_py_file(os.path.join('./configs/_base_/models', alg_file_name+'.py'), 'model')['auxiliary_head']
|
||||
for i in range(len(auxiliary_head)):
|
||||
auxiliary_head[i]['loss_decode']['use_sigmoid']=False
|
||||
auxiliary_head[i]['loss_decode']['type']='DiceLoss'
|
||||
auxiliary_head[i]['num_classes']=num_classes
|
||||
auxiliary_head[i]['norm_cfg']=norm_cfg
|
||||
|
||||
# 综合model
|
||||
model = dict(data_preprocessor = model_data_preprocessor, decode_head = decode_head, auxiliary_head=auxiliary_head)
|
||||
|
||||
# 3.3. optim_wrapper、param_scheduler
|
||||
optim_wrapper = generate_optim_wrapper()
|
||||
param_scheduler = generate_param_scheduler(train_type, train_time_or_epoch_k)
|
||||
|
||||
########### 4.程序写入 ###########
|
||||
# 算法名称解析:fastscnn【Alg】 _ r50【pretrained模型深度】 - d8【_base_/models】 _ 4【GPU数量】 x b2【Batch Size大小】 - 40k【schedule】 _ cityscapes【数据集】 - 512x1024【crop_size】.py
|
||||
alg_file_name = f"{alg_name}_g{GPU_num}-{schedule_file_name.lstrip('schedule_')}_{dataset_file_name}-{crop_size[0]}x{crop_size[1]}.py"
|
||||
output_configs_alg_my_alg = os.path.join(f'./configs/fastscnn/', alg_file_name)
|
||||
write_config_to_file(output_configs_alg_my_alg, _base_=_base_, norm_cfg=norm_cfg, crop_size=crop_size, data_preprocessor=data_preprocessor, model=model, optim_wrapper=optim_wrapper, param_scheduler=param_scheduler)
|
||||
|
||||
# 将信息写入文件
|
||||
with open("_temp_.txt", "w") as file:
|
||||
json.dump({'alg_infos':{'alg_file_name':alg_file_name,'alg_file_pth':output_configs_alg_my_alg}}, file)
|
||||
@@ -0,0 +1,189 @@
|
||||
import os, sys, argparse, json
|
||||
# 添加上级目录到 sys.path
|
||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../')))
|
||||
|
||||
from Initial_Alg_Program.Initial_Alg_Select_Tool import select_crop_size, select_pretrained_model
|
||||
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_Tool import format_all_data, write_config_to_file, create_dict_by_kwargs, get_var_from_file, update_list_dict_var, get_var_from_py_file
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_0_base_ import generate_base_config
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_1_norm_cfg import generate_norm_cfg
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_2_data_preprocessor import generate_data_preprocessor
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_3_model import generate_model_pretrained, generate_model_backbone, generate_model_data_preprocessor, generate_model_decode_head, generate_model_auxiliary_head, generate_model_train_cfg, generate_model_test_cfg
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_4_optimizer import generate_optim_wrapper, generate_param_scheduler
|
||||
|
||||
# 交互式选择 decode_head 的函数
|
||||
def select_decode_head(decode_head_choose):
|
||||
print("可用的 decode head 选项:")
|
||||
for i, key in enumerate(decode_head_choose.keys()):
|
||||
print(f"{i + 1}. {key}")
|
||||
|
||||
while True:
|
||||
try:
|
||||
# 提示用户输入选择
|
||||
choice = int(input("请选择需要的 decode head(输入编号):"))
|
||||
# 检查输入是否在有效范围内
|
||||
if 1 <= choice <= len(decode_head_choose):
|
||||
selected_key = list(decode_head_choose.keys())[choice - 1]
|
||||
print(f"你选择了: {selected_key}")
|
||||
return decode_head_choose[selected_key]
|
||||
else:
|
||||
print("输入的编号不正确,请重新输入。")
|
||||
except ValueError:
|
||||
print("输入无效,请输入有效的编号。")
|
||||
|
||||
if __name__ == '__main__':
|
||||
########### 1.超参数-算法默认参数 ###########
|
||||
alg_file_name = 'fastfcn_r50-d32_jpu_psp'
|
||||
alg_file_pth = f"configs/_base_/models/{alg_file_name}.py"
|
||||
|
||||
########### 2.参数解析 ###########
|
||||
parser = argparse.ArgumentParser(description="Run algorithm with specified parameters.")
|
||||
|
||||
# 添加命令行参数
|
||||
parser.add_argument('--alg_name', type=str, required=True, help="Algorithm name (e.g., 'ann')")
|
||||
parser.add_argument('--dataset_file_name', type=str, required=True, help="Dataset file name (e.g., 'my_dataset_model')")
|
||||
parser.add_argument('--mean', type=str, required=True, help="Mean of dataset (e.g., [1.2, 1.2, 1.2])")
|
||||
parser.add_argument('--std', type=str, required=True, help="Mean of dataset (e.g., [1.2, 1.2, 1.2])")
|
||||
parser.add_argument('--dataset_num_classes', type=int, required=True, help="Dataset class num (e.g., 36)")
|
||||
parser.add_argument('--GPU_num', type=int, required=True, help="Num of GPU to train (e.g., 1)")
|
||||
parser.add_argument('--schedule_file_name', type=str, required=True, help="Schedule file name (e.g., 'schedule_4k_check_400')")
|
||||
parser.add_argument('--train_type', type=str, required=True, help="Schedule train type (e.g., Epoch or Iteration)")
|
||||
parser.add_argument('--train_time_or_epoch_k', type=int, required=True, help="Schedule train time k (e.g., 4)")
|
||||
|
||||
# 解析命令行参数
|
||||
args = parser.parse_args()
|
||||
|
||||
# 变量赋值
|
||||
alg_name = args.alg_name
|
||||
dataset_file_name = args.dataset_file_name
|
||||
mean = json.loads(args.mean) # 使用json.loads转换成list
|
||||
std = json.loads(args.std) # 使用json.loads转换成list
|
||||
num_classes = args.dataset_num_classes
|
||||
schedule_file_name = args.schedule_file_name
|
||||
train_type = args.train_type
|
||||
train_time_or_epoch_k = args.train_time_or_epoch_k
|
||||
GPU_num = args.GPU_num
|
||||
|
||||
########### 3.1. 生成_base_部分 ###########
|
||||
_base_ = generate_base_config(alg_file_name=alg_file_name, dataset_file_name=dataset_file_name, schedule_file_name=schedule_file_name)
|
||||
|
||||
########### 3.1. 生成norm_cfg部分 ###########
|
||||
norm_cfg = generate_norm_cfg(GPU_num=GPU_num)
|
||||
|
||||
########### 3.2. 生成data_preprocessor部分 ###########
|
||||
crop_size = select_crop_size(predefined_options=[(512, 512), (1024, 1024)]) # 选择切割大小
|
||||
data_preprocessor = generate_data_preprocessor(crop_size = crop_size, mean = mean, std = std)
|
||||
|
||||
########### 3.3. 生成model部分 ###########
|
||||
model_data_preprocessor = data_preprocessor # 修改data_preprocessor
|
||||
# 获取backbone模型、是否需要预训练
|
||||
model_list = ['openmmlab/resnet50_v1c', 'openmmlab/resnet101_v1c']
|
||||
selected_model_name, select_pretrained, pretrained_pth, selected_model_info = select_pretrained_model(model_list=model_list) # 需要选择是否用预训练模型
|
||||
depth = selected_model_info['depth']
|
||||
|
||||
backbone = create_dict_by_kwargs(norm_cfg=norm_cfg, depth=depth)
|
||||
|
||||
# 定义 decode_head 的选项字典
|
||||
decode_head_choose = {
|
||||
'psp-PSPHead': {
|
||||
'decode_head':dict(
|
||||
type='PSPHead',
|
||||
in_channels=2048,
|
||||
in_index=2,
|
||||
channels=512,
|
||||
pool_scales=(1, 2, 3, 6),
|
||||
dropout_ratio=0.1,
|
||||
num_classes=19,
|
||||
norm_cfg=norm_cfg, # 假设 norm_cfg 预定义
|
||||
align_corners=False,
|
||||
loss_decode=dict(
|
||||
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)
|
||||
),
|
||||
'decode_head_name':'aspp'
|
||||
},
|
||||
'enc-EncHead': {
|
||||
'decode_head':dict(
|
||||
_delete_=True,
|
||||
type='EncHead',
|
||||
in_channels=[512, 1024, 2048],
|
||||
in_index=(0, 1, 2),
|
||||
channels=512,
|
||||
num_codes=32,
|
||||
use_se_loss=True,
|
||||
add_lateral=False,
|
||||
dropout_ratio=0.1,
|
||||
num_classes=num_classes,
|
||||
norm_cfg=norm_cfg,
|
||||
align_corners=False,
|
||||
loss_decode=dict(
|
||||
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
|
||||
loss_se_decode=dict(
|
||||
type='CrossEntropyLoss', use_sigmoid=True, loss_weight=0.2)
|
||||
),
|
||||
'decode_head_name':'aspp'
|
||||
},
|
||||
'aspp-ASPPHead': {
|
||||
'decode_head':dict(
|
||||
_delete_=True,
|
||||
type='ASPPHead',
|
||||
in_channels=2048,
|
||||
in_index=2,
|
||||
channels=512,
|
||||
dilations=(1, 12, 24, 36),
|
||||
dropout_ratio=0.1,
|
||||
num_classes=num_classes,
|
||||
norm_cfg=norm_cfg,
|
||||
align_corners=False,
|
||||
loss_decode=dict(
|
||||
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)
|
||||
),
|
||||
'decode_head_name':'aspp'
|
||||
}
|
||||
}
|
||||
|
||||
decode_head_dict = select_decode_head(decode_head_choose=decode_head_choose)
|
||||
decode_head = decode_head_dict['decode_head']
|
||||
decode_head_name = decode_head_dict['decode_head_name']
|
||||
|
||||
decode_head_loss_decode_type = 'DiceLoss' # Way 1: 更改 Loss
|
||||
# decode_head_loss_decode_type = 'CrossEntropyLoss' # Way ori: 不更改
|
||||
|
||||
# 修改decode_head中loss_decode type
|
||||
if 'loss_se_decode' in decode_head.keys():
|
||||
decode_head['loss_se_decode']['type'] = decode_head_loss_decode_type
|
||||
if 'loss_decode' in decode_head.keys():
|
||||
decode_head['loss_decode']['type'] = decode_head_loss_decode_type
|
||||
|
||||
# auxiliary 为 None
|
||||
auxiliary_head_loss_decode_dict = dict(type='DiceLoss', use_sigmoid=False, loss_weight=0.4)
|
||||
auxiliary_head = dict(loss_decode=auxiliary_head_loss_decode_dict, norm_cfg=norm_cfg, num_classes=num_classes) # Way 1: 更改 Loss
|
||||
# auxiliary_head = dict(norm_cfg=norm_cfg, num_classes=num_classes) # Way ori: 不更改
|
||||
|
||||
# 综合model
|
||||
if auxiliary_head == None:
|
||||
model = dict(
|
||||
backbone = backbone,
|
||||
data_preprocessor = model_data_preprocessor,
|
||||
decode_head = decode_head
|
||||
)
|
||||
else:
|
||||
model = dict(
|
||||
backbone = backbone,
|
||||
data_preprocessor = model_data_preprocessor,
|
||||
decode_head = decode_head,
|
||||
auxiliary_head = auxiliary_head
|
||||
)
|
||||
|
||||
########### 3.4. 生成optim、param_scheduler部分 ###########
|
||||
optim_wrapper = generate_optim_wrapper()
|
||||
param_scheduler = generate_param_scheduler(train_type, train_time_or_epoch_k)
|
||||
|
||||
########### 4.程序写入 ###########
|
||||
# 算法名称解析:fastfcn【Alg】 _ r50【pretrained模型深度】 - d8【_base_/models】 _ 4【GPU数量】 x b2【Batch Size大小】 - 40k【schedule】 _ cityscapes【数据集】 - 512x1024【crop_size】.py
|
||||
alg_file_name = f"{alg_name}_r{depth}_{'Pre' if select_pretrained else 'NoPre'}_jpu_{decode_head_name}_g{GPU_num}-{schedule_file_name.lstrip('schedule_')}_{dataset_file_name}-{crop_size[0]}x{crop_size[1]}.py"
|
||||
output_configs_alg_my_alg = os.path.join(f'./configs/fastfcn/', alg_file_name)
|
||||
write_config_to_file(output_configs_alg_my_alg, _base_=_base_, norm_cfg=norm_cfg, crop_size=crop_size, data_preprocessor=data_preprocessor, model=model, optim_wrapper=optim_wrapper, param_scheduler=param_scheduler)
|
||||
|
||||
# 将信息写入文件
|
||||
with open("_temp_.txt", "w") as file:
|
||||
json.dump({'alg_infos':{'alg_file_name':alg_file_name,'alg_file_pth':output_configs_alg_my_alg}}, file)
|
||||
132
Seg_All_In_One_MMSeg/My_All_In_One/2_Alg_Program/my_fcn_r18.py
Normal file
132
Seg_All_In_One_MMSeg/My_All_In_One/2_Alg_Program/my_fcn_r18.py
Normal file
@@ -0,0 +1,132 @@
|
||||
import os, sys, argparse, json
|
||||
# 添加上级目录到 sys.path
|
||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../')))
|
||||
|
||||
from Initial_Alg_Program.Initial_Alg_Select_Tool import select_crop_size, select_pretrained_model, select_test_cfg_slide
|
||||
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_Tool import format_all_data, write_config_to_file, create_dict_by_kwargs
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_0_base_ import generate_base_config
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_1_norm_cfg import generate_norm_cfg
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_2_data_preprocessor import generate_data_preprocessor
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_3_model import generate_model_pretrained, generate_model_backbone, generate_model_data_preprocessor, generate_model_decode_head, generate_model_auxiliary_head, generate_model_train_cfg, generate_model_test_cfg
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_4_optimizer import generate_optim_wrapper, generate_param_scheduler
|
||||
|
||||
if __name__ == '__main__':
|
||||
########### 1.超参数-算法默认参数 ###########
|
||||
alg_file_name = 'fcn_r50-d8'
|
||||
|
||||
########### 2.参数解析 ###########
|
||||
parser = argparse.ArgumentParser(description="Run algorithm with specified parameters.")
|
||||
|
||||
# 添加命令行参数
|
||||
parser.add_argument('--alg_name', type=str, required=True, help="Algorithm name (e.g., 'ann')")
|
||||
parser.add_argument('--dataset_file_name', type=str, required=True, help="Dataset file name (e.g., 'my_dataset_model')")
|
||||
parser.add_argument('--mean', type=str, required=True, help="Mean of dataset (e.g., [1.2, 1.2, 1.2])")
|
||||
parser.add_argument('--std', type=str, required=True, help="Mean of dataset (e.g., [1.2, 1.2, 1.2])")
|
||||
parser.add_argument('--dataset_num_classes', type=int, required=True, help="Dataset class num (e.g., 36)")
|
||||
parser.add_argument('--GPU_num', type=int, required=True, help="Num of GPU to train (e.g., 1)")
|
||||
parser.add_argument('--schedule_file_name', type=str, required=True, help="Schedule file name (e.g., 'schedule_4k_check_400')")
|
||||
parser.add_argument('--train_type', type=str, required=True, help="Schedule train type (e.g., Epoch or Iteration)")
|
||||
parser.add_argument('--train_time_or_epoch_k', type=int, required=True, help="Schedule train time k (e.g., 4)")
|
||||
|
||||
# 解析命令行参数
|
||||
args = parser.parse_args()
|
||||
|
||||
# 变量赋值
|
||||
alg_name = args.alg_name
|
||||
dataset_file_name = args.dataset_file_name
|
||||
mean = json.loads(args.mean) # 使用json.loads转换成list
|
||||
std = json.loads(args.std) # 使用json.loads转换成list
|
||||
num_classes = args.dataset_num_classes
|
||||
schedule_file_name = args.schedule_file_name
|
||||
train_type = args.train_type
|
||||
train_time_or_epoch_k = args.train_time_or_epoch_k
|
||||
GPU_num = args.GPU_num
|
||||
|
||||
|
||||
########### 3. 设定参数 ###########
|
||||
# 3.1. 自选参数
|
||||
crop_size = select_crop_size(predefined_options=[(512, 512), (512, 1024), (769, 769)]) # 选择切割大小
|
||||
# 如果crop_size更大则调整dilations和strides
|
||||
if crop_size[0]*crop_size[1] > 512*512:
|
||||
backbone_dilations = (1, 1, 1, 2) # 每一步更精细信息
|
||||
backbone_strides = (1, 2, 2, 1) # 步子走的大
|
||||
decode_head_dilation = 6 # 每一步更多信息
|
||||
auxiliary_head_dilation = 6 # 每一步更多信息
|
||||
else:
|
||||
backbone_dilations = (1, 1, 2, 4) # 每一步更多信息
|
||||
backbone_strides = (1, 2, 1, 1) # 步子走的小
|
||||
decode_head_dilation = 1 # 每一步更精细信息
|
||||
auxiliary_head_dilation = 1 # 每一步更精细信息
|
||||
|
||||
model_list = ['torchvision://resnet18', 'torchvision://resnet50', 'torchvision://resnet101', 'open-mmlab/resnet18_v1c', 'openmmlab/resnet50_v1c', 'openmmlab/resnet101_v1c']
|
||||
selected_model_name, select_pretrained, pretrained_pth, selected_model_info = select_pretrained_model(model_list=model_list)
|
||||
depth = selected_model_info['depth']
|
||||
model_type = selected_model_info['type']
|
||||
|
||||
test_cfg_mode, test_cfg_crop_div_stride = select_test_cfg_slide(crop_size)
|
||||
|
||||
# 模型信息
|
||||
if str(depth) == '18' :
|
||||
backbone_out_channels = 256
|
||||
decode_head_in_channels = 512 #
|
||||
decode_head_channels = 128 #
|
||||
auxiliary_head_in_channels = 256 #
|
||||
auxiliary_head_channels = 64 #
|
||||
|
||||
elif str(depth) == '50' or str(depth) == '101':
|
||||
backbone_out_channels = 1024
|
||||
decode_head_in_channels = 2048 #
|
||||
decode_head_channels = 512 #
|
||||
auxiliary_head_in_channels = 1024 #
|
||||
auxiliary_head_channels = 256 #
|
||||
|
||||
# 3.2. 自动生成参数
|
||||
# decode、auxiliary损失下载方式
|
||||
decode_head_loss_decode_dict=dict(type='DiceLoss', use_sigmoid=False, loss_weight=1.0) # decode损失函数 # TODO
|
||||
# decode_head_loss_decode_dict=dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0) # decode损失函数 # TODO
|
||||
auxiliary_head_loss_decode_dict=dict(type='DiceLoss', use_sigmoid=False, loss_weight=0.4) # DiceLoss损失函数 # TODO
|
||||
# auxiliary_head_loss_decode_dict=dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4) # DiceLoss损失函数 # TODO
|
||||
|
||||
# 如果使用slide模式,则开启align_corners
|
||||
if test_cfg_mode == "slide":
|
||||
align_corners = True
|
||||
test_slide = 'testslide' # 算法名进行标记
|
||||
else:
|
||||
align_corners = False
|
||||
test_slide = 'no_testslide' # 算法名进行标记
|
||||
|
||||
|
||||
########### 3.生成各个部分 ###########
|
||||
_base_ = generate_base_config(alg_file_name=alg_file_name, dataset_file_name=dataset_file_name, schedule_file_name=schedule_file_name)
|
||||
norm_cfg = generate_norm_cfg(GPU_num=GPU_num)
|
||||
data_preprocessor = generate_data_preprocessor(crop_size = crop_size, mean = mean, std = std)
|
||||
|
||||
pretrained = generate_model_pretrained(pretrained_pth = pretrained_pth)
|
||||
backbone = create_dict_by_kwargs(type=model_type, depth=depth, strides = backbone_strides, dilations=backbone_dilations)
|
||||
model_data_preprocessor = data_preprocessor # 修改data_preprocessor
|
||||
model_data_preprocessor = generate_model_data_preprocessor(model_data_preprocessor=model_data_preprocessor)
|
||||
decode_head = create_dict_by_kwargs(dilation=decode_head_dilation, channels=decode_head_channels, in_channels=decode_head_in_channels, num_classes=num_classes, loss_decode=decode_head_loss_decode_dict, align_corners=align_corners)
|
||||
auxiliary_head = create_dict_by_kwargs(dilation=auxiliary_head_dilation, channels=auxiliary_head_channels, in_channels=auxiliary_head_in_channels, num_classes=num_classes, loss_decode=auxiliary_head_loss_decode_dict, align_corners=align_corners)
|
||||
# 综合model
|
||||
model = dict(
|
||||
pretrained = pretrained,
|
||||
backbone = backbone,
|
||||
data_preprocessor = model_data_preprocessor,
|
||||
decode_head = decode_head,
|
||||
auxiliary_head = auxiliary_head,
|
||||
)
|
||||
|
||||
test_cfg = generate_model_test_cfg(test_cfg_mode=test_cfg_mode, crop_size=crop_size, test_cfg_crop_div_stride=test_cfg_crop_div_stride)
|
||||
optim_wrapper = generate_optim_wrapper()
|
||||
param_scheduler = generate_param_scheduler(train_type, train_time_or_epoch_k)
|
||||
|
||||
########### 4.程序写入 ###########
|
||||
# 算法名称解析:fcn【Alg】 _ r50【pretrained模型深度】 - d8【_base_/models】 _ 4【GPU数量】 x b2【Batch Size大小】 - 40k【schedule】 _ cityscapes【数据集】 - 512x1024【crop_size】.py
|
||||
alg_file_name = f"{alg_name}_r{depth}_{'Pre' if select_pretrained else 'NoPre'}_g{GPU_num}-{schedule_file_name.lstrip('schedule_')}_{dataset_file_name}-{crop_size[0]}x{crop_size[1]}-{test_slide}.py"
|
||||
output_configs_alg_my_alg = os.path.join(f'./configs/fcn/', alg_file_name)
|
||||
write_config_to_file(output_configs_alg_my_alg, _base_=_base_, norm_cfg=norm_cfg, crop_size=crop_size, data_preprocessor=data_preprocessor, model=model, test_cfg=test_cfg, optim_wrapper=optim_wrapper, param_scheduler=param_scheduler)
|
||||
|
||||
# 将信息写入文件
|
||||
with open("_temp_.txt", "w") as file:
|
||||
json.dump({'alg_infos':{'alg_file_name':alg_file_name,'alg_file_pth':output_configs_alg_my_alg}}, file)
|
||||
106
Seg_All_In_One_MMSeg/My_All_In_One/2_Alg_Program/my_gcnet_r50.py
Normal file
106
Seg_All_In_One_MMSeg/My_All_In_One/2_Alg_Program/my_gcnet_r50.py
Normal file
@@ -0,0 +1,106 @@
|
||||
import os, sys, argparse, json
|
||||
# 添加上级目录到 sys.path
|
||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../')))
|
||||
|
||||
from Initial_Alg_Program.Initial_Alg_Select_Tool import select_crop_size, select_pretrained_model, select_test_cfg_slide
|
||||
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_Tool import format_all_data, write_config_to_file
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_0_base_ import generate_base_config
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_1_norm_cfg import generate_norm_cfg
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_2_data_preprocessor import generate_data_preprocessor
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_3_model import generate_model_pretrained, generate_model_backbone, generate_model_data_preprocessor, generate_model_decode_head, generate_model_auxiliary_head, generate_model_train_cfg, generate_model_test_cfg
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_4_optimizer import generate_optim_wrapper, generate_param_scheduler
|
||||
|
||||
if __name__ == '__main__':
|
||||
########### 1.超参数-算法默认参数 ###########
|
||||
alg_file_name = 'gcnet_r50-d8'
|
||||
|
||||
########### 2.参数解析 ###########
|
||||
parser = argparse.ArgumentParser(description="Run algorithm with specified parameters.")
|
||||
|
||||
# 添加命令行参数
|
||||
parser.add_argument('--alg_name', type=str, required=True, help="Algorithm name (e.g., 'ann')")
|
||||
parser.add_argument('--dataset_file_name', type=str, required=True, help="Dataset file name (e.g., 'my_dataset_model')")
|
||||
parser.add_argument('--mean', type=str, required=True, help="Mean of dataset (e.g., [1.2, 1.2, 1.2])")
|
||||
parser.add_argument('--std', type=str, required=True, help="Mean of dataset (e.g., [1.2, 1.2, 1.2])")
|
||||
parser.add_argument('--dataset_num_classes', type=int, required=True, help="Dataset class num (e.g., 36)")
|
||||
parser.add_argument('--GPU_num', type=int, required=True, help="Num of GPU to train (e.g., 1)")
|
||||
parser.add_argument('--schedule_file_name', type=str, required=True, help="Schedule file name (e.g., 'schedule_4k_check_400')")
|
||||
parser.add_argument('--train_type', type=str, required=True, help="Schedule train type (e.g., Epoch or Iteration)")
|
||||
parser.add_argument('--train_time_or_epoch_k', type=int, required=True, help="Schedule train time k (e.g., 4)")
|
||||
|
||||
# 解析命令行参数
|
||||
args = parser.parse_args()
|
||||
|
||||
# 变量赋值
|
||||
alg_name = args.alg_name
|
||||
dataset_file_name = args.dataset_file_name
|
||||
mean = json.loads(args.mean) # 使用json.loads转换成list
|
||||
std = json.loads(args.std) # 使用json.loads转换成list
|
||||
num_classes = args.dataset_num_classes
|
||||
schedule_file_name = args.schedule_file_name
|
||||
train_type = args.train_type
|
||||
train_time_or_epoch_k = args.train_time_or_epoch_k
|
||||
GPU_num = args.GPU_num
|
||||
|
||||
|
||||
########### 3. 设定参数 ###########
|
||||
# 3.1. 自选参数
|
||||
crop_size = select_crop_size(predefined_options=[(512, 512), (512, 1024), (769, 769)]) # 选择切割大小
|
||||
model_list = ['openmmlab/resnet50_v1c', 'openmmlab/resnet101_v1c']
|
||||
selected_model_name, select_pretrained, pretrained_pth, selected_model_info = select_pretrained_model(model_list=model_list)
|
||||
depth = selected_model_info['depth']
|
||||
|
||||
test_cfg_mode, test_cfg_crop_div_stride = select_test_cfg_slide(crop_size)
|
||||
# test_cfg_mode = None
|
||||
# test_cfg_crop_div_stride = crop_size
|
||||
|
||||
# 3.2. 自动生成参数
|
||||
# decode、auxiliary损失下载方式
|
||||
decode_head_loss_decode_dict=dict(type='DiceLoss', use_sigmoid=False, loss_weight=1.0) # decode损失函数 # TODO
|
||||
# decode_head_loss_decode_dict=dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0) # decode损失函数 # TODO
|
||||
auxiliary_head_loss_decode_dict=dict(type='DiceLoss', use_sigmoid=False, loss_weight=0.4) # DiceLoss损失函数 # TODO
|
||||
# auxiliary_head_loss_decode_dict=dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4) # DiceLoss损失函数 # TODO
|
||||
|
||||
# 如果使用slide模式,则开启align_corners
|
||||
if test_cfg_mode == "slide":
|
||||
align_corners = True
|
||||
test_slide = 'testslide' # 算法名进行标记
|
||||
else:
|
||||
align_corners = False
|
||||
test_slide = 'no_testslide' # 算法名进行标记
|
||||
|
||||
|
||||
########### 3.生成各个部分 ###########
|
||||
_base_ = generate_base_config(alg_file_name=alg_file_name, dataset_file_name=dataset_file_name, schedule_file_name=schedule_file_name)
|
||||
norm_cfg = generate_norm_cfg(GPU_num=GPU_num)
|
||||
data_preprocessor = generate_data_preprocessor(crop_size = crop_size, mean = mean, std = std)
|
||||
|
||||
pretrained = generate_model_pretrained(pretrained_pth = pretrained_pth)
|
||||
backbone = generate_model_backbone(depth=depth)
|
||||
model_data_preprocessor = data_preprocessor # 修改data_preprocessor
|
||||
model_data_preprocessor = generate_model_data_preprocessor(model_data_preprocessor=model_data_preprocessor)
|
||||
decode_head = generate_model_decode_head(num_classes=num_classes, decode_head_loss_decode_dict=decode_head_loss_decode_dict, align_corners=align_corners)
|
||||
auxiliary_head = generate_model_auxiliary_head(num_classes=num_classes, auxiliary_head_loss_decode_dict=auxiliary_head_loss_decode_dict, align_corners=align_corners)
|
||||
# 综合model
|
||||
model = dict(
|
||||
pretrained = pretrained,
|
||||
backbone = backbone,
|
||||
data_preprocessor = model_data_preprocessor,
|
||||
decode_head = decode_head,
|
||||
auxiliary_head = auxiliary_head,
|
||||
)
|
||||
|
||||
test_cfg = generate_model_test_cfg(test_cfg_mode=test_cfg_mode, crop_size=crop_size, test_cfg_crop_div_stride=test_cfg_crop_div_stride)
|
||||
optim_wrapper = generate_optim_wrapper()
|
||||
param_scheduler = generate_param_scheduler(train_type, train_time_or_epoch_k)
|
||||
|
||||
########### 4.程序写入 ###########
|
||||
# 算法名称解析:gcnet【Alg】 _ r50【pretrained模型深度】 - d8【_base_/models】 _ 4【GPU数量】 x b2【Batch Size大小】 - 40k【schedule】 _ cityscapes【数据集】 - 512x1024【crop_size】.py
|
||||
alg_file_name = f"{alg_name}_r{depth}_{'Pre' if select_pretrained else 'NoPre'}_g{GPU_num}-{schedule_file_name.lstrip('schedule_')}_{dataset_file_name}-{crop_size[0]}x{crop_size[1]}-{test_slide}.py"
|
||||
output_configs_alg_my_alg = os.path.join(f'./configs/gcnet/', alg_file_name)
|
||||
write_config_to_file(output_configs_alg_my_alg, _base_=_base_, norm_cfg=norm_cfg, crop_size=crop_size, data_preprocessor=data_preprocessor, model=model, test_cfg=test_cfg, optim_wrapper=optim_wrapper, param_scheduler=param_scheduler)
|
||||
|
||||
# 将信息写入文件
|
||||
with open("_temp_.txt", "w") as file:
|
||||
json.dump({'alg_infos':{'alg_file_name':alg_file_name,'alg_file_pth':output_configs_alg_my_alg}}, file)
|
||||
144
Seg_All_In_One_MMSeg/My_All_In_One/2_Alg_Program/my_hrnet_fcn.py
Normal file
144
Seg_All_In_One_MMSeg/My_All_In_One/2_Alg_Program/my_hrnet_fcn.py
Normal file
@@ -0,0 +1,144 @@
|
||||
import os, sys, argparse, json
|
||||
# 添加上级目录到 sys.path
|
||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../')))
|
||||
|
||||
from Initial_Alg_Program.Initial_Alg_Select_Tool import select_crop_size, select_pretrained_model, select_test_cfg_slide
|
||||
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_Tool import format_all_data, write_config_to_file, create_dict_by_kwargs
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_0_base_ import generate_base_config
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_1_norm_cfg import generate_norm_cfg
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_2_data_preprocessor import generate_data_preprocessor
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_3_model import generate_model_pretrained, generate_model_backbone, generate_model_data_preprocessor, generate_model_decode_head, generate_model_auxiliary_head, generate_model_train_cfg, generate_model_test_cfg
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_4_optimizer import generate_optim_wrapper, generate_param_scheduler
|
||||
|
||||
if __name__ == '__main__':
|
||||
########### 1.超参数-算法默认参数 ###########
|
||||
alg_file_name = 'fcn_hr18'
|
||||
|
||||
########### 2.参数解析 ###########
|
||||
parser = argparse.ArgumentParser(description="Run algorithm with specified parameters.")
|
||||
|
||||
# 添加命令行参数
|
||||
parser.add_argument('--alg_name', type=str, required=True, help="Algorithm name (e.g., 'ann')")
|
||||
parser.add_argument('--dataset_file_name', type=str, required=True, help="Dataset file name (e.g., 'my_dataset_model')")
|
||||
parser.add_argument('--mean', type=str, required=True, help="Mean of dataset (e.g., [1.2, 1.2, 1.2])")
|
||||
parser.add_argument('--std', type=str, required=True, help="Mean of dataset (e.g., [1.2, 1.2, 1.2])")
|
||||
parser.add_argument('--dataset_num_classes', type=int, required=True, help="Dataset class num (e.g., 36)")
|
||||
parser.add_argument('--GPU_num', type=int, required=True, help="Num of GPU to train (e.g., 1)")
|
||||
parser.add_argument('--schedule_file_name', type=str, required=True, help="Schedule file name (e.g., 'schedule_4k_check_400')")
|
||||
parser.add_argument('--train_type', type=str, required=True, help="Schedule train type (e.g., Epoch or Iteration)")
|
||||
parser.add_argument('--train_time_or_epoch_k', type=int, required=True, help="Schedule train time k (e.g., 4)")
|
||||
|
||||
# 解析命令行参数
|
||||
args = parser.parse_args()
|
||||
|
||||
# 变量赋值
|
||||
alg_name = args.alg_name
|
||||
dataset_file_name = args.dataset_file_name
|
||||
mean = json.loads(args.mean) # 使用json.loads转换成list
|
||||
std = json.loads(args.std) # 使用json.loads转换成list
|
||||
num_classes = args.dataset_num_classes
|
||||
schedule_file_name = args.schedule_file_name
|
||||
train_type = args.train_type
|
||||
train_time_or_epoch_k = args.train_time_or_epoch_k
|
||||
GPU_num = args.GPU_num
|
||||
|
||||
|
||||
########### 3. 设定参数 ###########
|
||||
# 3.1. 自选参数
|
||||
crop_size = select_crop_size(predefined_options=[(512, 512), (512, 1024), (769, 769)]) # 选择切割大小
|
||||
# # 如果crop_size更大则调整dilations和strides
|
||||
# if crop_size[0]*crop_size[1] > 512*512:
|
||||
# backbone_dilations = (1, 1, 1, 2) # 每一步更精细信息
|
||||
# backbone_strides = (1, 2, 2, 1) # 步子走的大
|
||||
# decode_head_dilation = 6 # 每一步更多信息
|
||||
# auxiliary_head_dilation = 6 # 每一步更多信息
|
||||
# else:
|
||||
# backbone_dilations = (1, 1, 2, 4) # 每一步更多信息
|
||||
# backbone_strides = (1, 2, 1, 1) # 步子走的小
|
||||
# decode_head_dilation = 1 # 每一步更精细信息
|
||||
# auxiliary_head_dilation = 1 # 每一步更精细信息
|
||||
|
||||
model_list = ['open-mmlab://msra/hrnetv2_w18', 'open-mmlab://msra/hrnetv2_w18_small', 'open-mmlab://msra/hrnetv2_w48']
|
||||
selected_model_name, select_pretrained, pretrained_pth, selected_model_info = select_pretrained_model(model_list=model_list)
|
||||
|
||||
test_cfg_mode, test_cfg_crop_div_stride = select_test_cfg_slide(crop_size)
|
||||
|
||||
# 模型信息
|
||||
if selected_model_name == 'open-mmlab://msra/hrnetv2_w18':
|
||||
backbone_extra=dict(
|
||||
stage1=dict(num_blocks=(4, ), num_channels=(64, )),
|
||||
stage2=dict(num_blocks=(4, 4), num_channels=(18, 36)),
|
||||
stage3=dict(num_modules=4, num_blocks=(4, 4, 4), num_channels=(18, 36, 72)),
|
||||
stage4=dict(num_modules=3, num_blocks=(4, 4, 4, 4), num_channels=(18, 36, 72, 144)))
|
||||
|
||||
decode_head_channels = dict(in_channels=[18, 36, 72, 144], channels=sum([18, 36, 72, 144]))
|
||||
alg_text = 'w18'
|
||||
elif selected_model_name == 'open-mmlab://msra/hrnetv2_w18_small':
|
||||
backbone_extra=dict(
|
||||
stage1=dict(num_blocks=(2, )),
|
||||
stage2=dict(num_blocks=(2, 2)),
|
||||
stage3=dict(num_modules=3, num_blocks=(2, 2, 2)),
|
||||
stage4=dict(num_modules=2, num_blocks=(2, 2, 2, 2)))
|
||||
|
||||
decode_head_channels = dict(in_channels=[18, 36, 72, 144], channels=sum([18, 36, 72, 144])) # 同上
|
||||
alg_text = 'w18-small'
|
||||
elif selected_model_name == 'open-mmlab://msra/hrnetv2_w48':
|
||||
# 改变channel
|
||||
backbone_extra=dict(
|
||||
stage2=dict(num_channels=(48, 96)),
|
||||
stage3=dict(num_channels=(48, 96, 192)),
|
||||
stage4=dict(num_channels=(48, 96, 192, 384)))
|
||||
|
||||
decode_head_channels = dict(in_channels=[48, 96, 192, 384], channels=sum([48, 96, 192, 384]))
|
||||
alg_text = 'w48'
|
||||
|
||||
# 3.2. 自动生成参数
|
||||
# decode、auxiliary损失下载方式
|
||||
# decode_head_loss_decode_dict=dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0) # decode损失函数 # Way: Ori TODO
|
||||
decode_head_loss_decode_dict=dict(type='DiceLoss', use_sigmoid=False, loss_weight=1.0) # decode损失函数 # Way: 1 TODO
|
||||
|
||||
# 如果使用slide模式,则开启align_corners
|
||||
if test_cfg_mode == "slide":
|
||||
align_corners = True
|
||||
test_slide = 'testslide' # 算法名进行标记
|
||||
else:
|
||||
align_corners = False
|
||||
test_slide = 'no_testslide' # 算法名进行标记
|
||||
|
||||
########### 3.生成各个部分 ###########
|
||||
_base_ = generate_base_config(alg_file_name=alg_file_name, dataset_file_name=dataset_file_name, schedule_file_name=schedule_file_name)
|
||||
norm_cfg = generate_norm_cfg(GPU_num=GPU_num)
|
||||
data_preprocessor = generate_data_preprocessor(crop_size = crop_size, mean = mean, std = std)
|
||||
|
||||
pretrained = generate_model_pretrained(pretrained_pth = pretrained_pth)
|
||||
backbone = create_dict_by_kwargs(extra = backbone_extra)
|
||||
model_data_preprocessor = data_preprocessor # 修改data_preprocessor
|
||||
model_data_preprocessor = generate_model_data_preprocessor(model_data_preprocessor=model_data_preprocessor)
|
||||
decode_head = create_dict_by_kwargs(num_classes=num_classes, loss_decode=decode_head_loss_decode_dict, align_corners=align_corners)
|
||||
decode_head.update(decode_head_channels) # 加入channels相关信息
|
||||
|
||||
auxiliary_head = None
|
||||
|
||||
# 综合model
|
||||
model = dict(
|
||||
pretrained = pretrained,
|
||||
backbone = backbone,
|
||||
data_preprocessor = model_data_preprocessor,
|
||||
decode_head = decode_head,
|
||||
# auxiliary_head = auxiliary_head,
|
||||
)
|
||||
|
||||
test_cfg = generate_model_test_cfg(test_cfg_mode=test_cfg_mode, crop_size=crop_size, test_cfg_crop_div_stride=test_cfg_crop_div_stride)
|
||||
optim_wrapper = generate_optim_wrapper()
|
||||
param_scheduler = generate_param_scheduler(train_type, train_time_or_epoch_k)
|
||||
|
||||
########### 4.程序写入 ###########
|
||||
# 算法名称解析:hrnet【Alg】 _ r50【pretrained模型深度】 - d8【_base_/models】 _ 4【GPU数量】 x b2【Batch Size大小】 - 40k【schedule】 _ cityscapes【数据集】 - 512x1024【crop_size】.py
|
||||
alg_file_name = f"{alg_name}_{alg_text}_{'Pre' if select_pretrained else 'NoPre'}_g{GPU_num}-{schedule_file_name.lstrip('schedule_')}_{dataset_file_name}-{crop_size[0]}x{crop_size[1]}-{test_slide}.py"
|
||||
output_configs_alg_my_alg = os.path.join(f'./configs/hrnet/', alg_file_name)
|
||||
write_config_to_file(output_configs_alg_my_alg, _base_=_base_, norm_cfg=norm_cfg, crop_size=crop_size, data_preprocessor=data_preprocessor, model=model, test_cfg=test_cfg, optim_wrapper=optim_wrapper, param_scheduler=param_scheduler)
|
||||
|
||||
# 将信息写入文件
|
||||
with open("_temp_.txt", "w") as file:
|
||||
json.dump({'alg_infos':{'alg_file_name':alg_file_name,'alg_file_pth':output_configs_alg_my_alg}}, file)
|
||||
124
Seg_All_In_One_MMSeg/My_All_In_One/2_Alg_Program/my_icnet_r18.py
Normal file
124
Seg_All_In_One_MMSeg/My_All_In_One/2_Alg_Program/my_icnet_r18.py
Normal file
@@ -0,0 +1,124 @@
|
||||
import os, sys, argparse, json
|
||||
# 添加上级目录到 sys.path
|
||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../')))
|
||||
|
||||
from Initial_Alg_Program.Initial_Alg_Select_Tool import select_crop_size, select_pretrained_model, select_test_cfg_slide
|
||||
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_Tool import format_all_data, write_config_to_file, create_dict_by_kwargs, get_var_from_py_file
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_0_base_ import generate_base_config
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_1_norm_cfg import generate_norm_cfg
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_2_data_preprocessor import generate_data_preprocessor
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_3_model import generate_model_pretrained, generate_model_backbone, generate_model_data_preprocessor, generate_model_decode_head, generate_model_auxiliary_head, generate_model_train_cfg, generate_model_test_cfg
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_4_optimizer import generate_optim_wrapper, generate_param_scheduler
|
||||
|
||||
if __name__ == '__main__':
|
||||
########### 1.超参数-算法默认参数 ###########
|
||||
alg_file_name = 'icnet_r50-d8'
|
||||
|
||||
########### 2.参数解析 ###########
|
||||
parser = argparse.ArgumentParser(description="Run algorithm with specified parameters.")
|
||||
|
||||
# 添加命令行参数
|
||||
parser.add_argument('--alg_name', type=str, required=True, help="Algorithm name (e.g., 'ann')")
|
||||
parser.add_argument('--dataset_file_name', type=str, required=True, help="Dataset file name (e.g., 'my_dataset_model')")
|
||||
parser.add_argument('--mean', type=str, required=True, help="Mean of dataset (e.g., [1.2, 1.2, 1.2])")
|
||||
parser.add_argument('--std', type=str, required=True, help="Mean of dataset (e.g., [1.2, 1.2, 1.2])")
|
||||
parser.add_argument('--dataset_num_classes', type=int, required=True, help="Dataset class num (e.g., 36)")
|
||||
parser.add_argument('--GPU_num', type=int, required=True, help="Num of GPU to train (e.g., 1)")
|
||||
parser.add_argument('--schedule_file_name', type=str, required=True, help="Schedule file name (e.g., 'schedule_4k_check_400')")
|
||||
parser.add_argument('--train_type', type=str, required=True, help="Schedule train type (e.g., Epoch or Iteration)")
|
||||
parser.add_argument('--train_time_or_epoch_k', type=int, required=True, help="Schedule train time k (e.g., 4)")
|
||||
|
||||
# 解析命令行参数
|
||||
args = parser.parse_args()
|
||||
|
||||
# 变量赋值
|
||||
alg_name = args.alg_name
|
||||
dataset_file_name = args.dataset_file_name
|
||||
mean = json.loads(args.mean) # 使用json.loads转换成list
|
||||
std = json.loads(args.std) # 使用json.loads转换成list
|
||||
num_classes = args.dataset_num_classes
|
||||
schedule_file_name = args.schedule_file_name
|
||||
train_type = args.train_type
|
||||
train_time_or_epoch_k = args.train_time_or_epoch_k
|
||||
GPU_num = args.GPU_num
|
||||
|
||||
|
||||
########### 3. 设定参数 ###########
|
||||
# 3.1. 自选参数
|
||||
crop_size = select_crop_size(predefined_options=[(832, 832), (512, 512)]) # 选择切割大小
|
||||
|
||||
model_list = ['openmmlab/resnet18_v1c', 'openmmlab/resnet50_v1c', 'openmmlab/resnet101_v1c']
|
||||
selected_model_name, select_pretrained, pretrained_pth, selected_model_info = select_pretrained_model(model_list=model_list, need_select_pretrained=True)
|
||||
depth = selected_model_info['depth']
|
||||
|
||||
test_cfg_mode, test_cfg_crop_div_stride = select_test_cfg_slide(crop_size)
|
||||
|
||||
# 需要预训练模型
|
||||
if select_pretrained == True:
|
||||
backbone_backbone_cfg_init_cfg = dict(type='Pretrained', checkpoint=pretrained_pth)
|
||||
else:
|
||||
backbone_cfg_init_cfg = None
|
||||
backbone_backbone_cfg = create_dict_by_kwargs(depth=depth, init_cfg=backbone_backbone_cfg_init_cfg)
|
||||
|
||||
# 模型信息
|
||||
if str(depth) == '18' :
|
||||
backbone_layer_channels = (128, 512)
|
||||
elif str(depth) == '50' or str(depth) == '101':
|
||||
backbone_layer_channels = (512, 2048)
|
||||
|
||||
# 3.2. 自动生成参数
|
||||
# decode、auxiliary损失下载方式
|
||||
decode_head_loss_decode_dict=dict(type='DiceLoss', use_sigmoid=False, loss_weight=1.0) # Way 1: decode损失函数 # TODO
|
||||
# decode_head_loss_decode_dict=dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0) # Way Ori: decode损失函数 # TODO
|
||||
auxiliary_head_loss_decode_dict=dict(type='DiceLoss', use_sigmoid=False, loss_weight=0.4) # Way 1: DiceLoss损失函数 # TODO
|
||||
# auxiliary_head_loss_decode_dict=dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4) # Way Ori: DiceLoss损失函数 # TODO
|
||||
|
||||
# 如果使用slide模式,则开启align_corners
|
||||
if test_cfg_mode == "slide":
|
||||
align_corners = True
|
||||
test_slide = 'testslide' # 算法名进行标记
|
||||
else:
|
||||
align_corners = False
|
||||
test_slide = 'no_testslide' # 算法名进行标记
|
||||
|
||||
|
||||
########### 3.生成各个部分 ###########
|
||||
_base_ = generate_base_config(alg_file_name=alg_file_name, dataset_file_name=dataset_file_name, schedule_file_name=schedule_file_name)
|
||||
norm_cfg = generate_norm_cfg(GPU_num=GPU_num)
|
||||
data_preprocessor = generate_data_preprocessor(crop_size = crop_size, mean = mean, std = std)
|
||||
|
||||
model_data_preprocessor = data_preprocessor # 修改data_preprocessor
|
||||
model_data_preprocessor = generate_model_data_preprocessor(model_data_preprocessor=model_data_preprocessor)
|
||||
|
||||
backbone = create_dict_by_kwargs(backbone_cfg = backbone_backbone_cfg, layer_channels=backbone_layer_channels, norm_cfg=norm_cfg, align_corners=align_corners)
|
||||
|
||||
neck = create_dict_by_kwargs(norm_cfg=norm_cfg, align_corners=align_corners)
|
||||
|
||||
decode_head = create_dict_by_kwargs(num_classes=num_classes, loss_decode=decode_head_loss_decode_dict, align_corners=align_corners)
|
||||
|
||||
auxiliary_head = get_var_from_py_file(os.path.join('./configs/_base_/models', alg_file_name+'.py'), 'model')['auxiliary_head']
|
||||
for i in range(len(auxiliary_head)):
|
||||
auxiliary_head[i].update(num_classes=num_classes, loss_decode=auxiliary_head_loss_decode_dict, align_corners=align_corners)
|
||||
|
||||
# 综合model
|
||||
model = dict(
|
||||
backbone = backbone,
|
||||
data_preprocessor = model_data_preprocessor,
|
||||
decode_head = decode_head,
|
||||
auxiliary_head = auxiliary_head,
|
||||
)
|
||||
|
||||
test_cfg = generate_model_test_cfg(test_cfg_mode=test_cfg_mode, crop_size=crop_size, test_cfg_crop_div_stride=test_cfg_crop_div_stride)
|
||||
optim_wrapper = generate_optim_wrapper()
|
||||
param_scheduler = generate_param_scheduler(train_type, train_time_or_epoch_k)
|
||||
|
||||
########### 4.程序写入 ###########
|
||||
# 算法名称解析:icnet【Alg】 _ r50【pretrained模型深度】 - d8【_base_/models】 _ 4【GPU数量】 x b2【Batch Size大小】 - 40k【schedule】 _ cityscapes【数据集】 - 512x1024【crop_size】.py
|
||||
alg_file_name = f"{alg_name}_r{depth}_{'Pre' if select_pretrained else 'NoPre'}_g{GPU_num}-{schedule_file_name.lstrip('schedule_')}_{dataset_file_name}-{crop_size[0]}x{crop_size[1]}-{test_slide}.py"
|
||||
output_configs_alg_my_alg = os.path.join(f'./configs/icnet/', alg_file_name)
|
||||
write_config_to_file(output_configs_alg_my_alg, _base_=_base_, norm_cfg=norm_cfg, crop_size=crop_size, data_preprocessor=data_preprocessor, model=model, test_cfg=test_cfg, optim_wrapper=optim_wrapper, param_scheduler=param_scheduler)
|
||||
|
||||
# 将信息写入文件
|
||||
with open("_temp_.txt", "w") as file:
|
||||
json.dump({'alg_infos':{'alg_file_name':alg_file_name,'alg_file_pth':output_configs_alg_my_alg}}, file)
|
||||
@@ -0,0 +1,106 @@
|
||||
import os, sys, argparse, json
|
||||
# 添加上级目录到 sys.path
|
||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../')))
|
||||
|
||||
from Initial_Alg_Program.Initial_Alg_Select_Tool import select_crop_size, select_pretrained_model, select_test_cfg_slide
|
||||
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_Tool import format_all_data, write_config_to_file
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_0_base_ import generate_base_config
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_1_norm_cfg import generate_norm_cfg
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_2_data_preprocessor import generate_data_preprocessor
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_3_model import generate_model_pretrained, generate_model_backbone, generate_model_data_preprocessor, generate_model_decode_head, generate_model_auxiliary_head, generate_model_train_cfg, generate_model_test_cfg
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_4_optimizer import generate_optim_wrapper, generate_param_scheduler
|
||||
|
||||
if __name__ == '__main__':
|
||||
########### 1.超参数-算法默认参数 ###########
|
||||
alg_file_name = 'isanet_r50-d8'
|
||||
|
||||
########### 2.参数解析 ###########
|
||||
parser = argparse.ArgumentParser(description="Run algorithm with specified parameters.")
|
||||
|
||||
# 添加命令行参数
|
||||
parser.add_argument('--alg_name', type=str, required=True, help="Algorithm name (e.g., 'ann')")
|
||||
parser.add_argument('--dataset_file_name', type=str, required=True, help="Dataset file name (e.g., 'my_dataset_model')")
|
||||
parser.add_argument('--mean', type=str, required=True, help="Mean of dataset (e.g., [1.2, 1.2, 1.2])")
|
||||
parser.add_argument('--std', type=str, required=True, help="Mean of dataset (e.g., [1.2, 1.2, 1.2])")
|
||||
parser.add_argument('--dataset_num_classes', type=int, required=True, help="Dataset class num (e.g., 36)")
|
||||
parser.add_argument('--GPU_num', type=int, required=True, help="Num of GPU to train (e.g., 1)")
|
||||
parser.add_argument('--schedule_file_name', type=str, required=True, help="Schedule file name (e.g., 'schedule_4k_check_400')")
|
||||
parser.add_argument('--train_type', type=str, required=True, help="Schedule train type (e.g., Epoch or Iteration)")
|
||||
parser.add_argument('--train_time_or_epoch_k', type=int, required=True, help="Schedule train time k (e.g., 4)")
|
||||
|
||||
# 解析命令行参数
|
||||
args = parser.parse_args()
|
||||
|
||||
# 变量赋值
|
||||
alg_name = args.alg_name
|
||||
dataset_file_name = args.dataset_file_name
|
||||
mean = json.loads(args.mean) # 使用json.loads转换成list
|
||||
std = json.loads(args.std) # 使用json.loads转换成list
|
||||
num_classes = args.dataset_num_classes
|
||||
schedule_file_name = args.schedule_file_name
|
||||
train_type = args.train_type
|
||||
train_time_or_epoch_k = args.train_time_or_epoch_k
|
||||
GPU_num = args.GPU_num
|
||||
|
||||
|
||||
########### 3. 设定参数 ###########
|
||||
# 3.1. 自选参数
|
||||
crop_size = select_crop_size(predefined_options=[(512, 512), (512, 1024), (769, 769)]) # 选择切割大小
|
||||
model_list = ['openmmlab/resnet50_v1c', 'openmmlab/resnet101_v1c']
|
||||
selected_model_name, select_pretrained, pretrained_pth, selected_model_info = select_pretrained_model(model_list=model_list)
|
||||
depth = selected_model_info['depth']
|
||||
|
||||
test_cfg_mode, test_cfg_crop_div_stride = select_test_cfg_slide(crop_size)
|
||||
# test_cfg_mode = None
|
||||
# test_cfg_crop_div_stride = crop_size
|
||||
|
||||
# 3.2. 自动生成参数
|
||||
# decode、auxiliary损失下载方式
|
||||
decode_head_loss_decode_dict=dict(type='DiceLoss', use_sigmoid=False, loss_weight=1.0) # decode损失函数 # TODO
|
||||
# decode_head_loss_decode_dict=dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0) # decode损失函数 # TODO
|
||||
auxiliary_head_loss_decode_dict=dict(type='DiceLoss', use_sigmoid=False, loss_weight=0.4) # DiceLoss损失函数 # TODO
|
||||
# auxiliary_head_loss_decode_dict=dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4) # DiceLoss损失函数 # TODO
|
||||
|
||||
# 如果使用slide模式,则开启align_corners
|
||||
if test_cfg_mode == "slide":
|
||||
align_corners = True
|
||||
test_slide = 'testslide' # 算法名进行标记
|
||||
else:
|
||||
align_corners = False
|
||||
test_slide = 'no_testslide' # 算法名进行标记
|
||||
|
||||
|
||||
########### 3.生成各个部分 ###########
|
||||
_base_ = generate_base_config(alg_file_name=alg_file_name, dataset_file_name=dataset_file_name, schedule_file_name=schedule_file_name)
|
||||
norm_cfg = generate_norm_cfg(GPU_num=GPU_num)
|
||||
data_preprocessor = generate_data_preprocessor(crop_size = crop_size, mean = mean, std = std)
|
||||
|
||||
pretrained = generate_model_pretrained(pretrained_pth = pretrained_pth)
|
||||
backbone = generate_model_backbone(depth=depth)
|
||||
model_data_preprocessor = data_preprocessor # 修改data_preprocessor
|
||||
model_data_preprocessor = generate_model_data_preprocessor(model_data_preprocessor=model_data_preprocessor)
|
||||
decode_head = generate_model_decode_head(num_classes=num_classes, decode_head_loss_decode_dict=decode_head_loss_decode_dict, align_corners=align_corners)
|
||||
auxiliary_head = generate_model_auxiliary_head(num_classes=num_classes, auxiliary_head_loss_decode_dict=auxiliary_head_loss_decode_dict, align_corners=align_corners)
|
||||
# 综合model
|
||||
model = dict(
|
||||
pretrained = pretrained,
|
||||
backbone = backbone,
|
||||
data_preprocessor = model_data_preprocessor,
|
||||
decode_head = decode_head,
|
||||
auxiliary_head = auxiliary_head,
|
||||
)
|
||||
|
||||
test_cfg = generate_model_test_cfg(test_cfg_mode=test_cfg_mode, crop_size=crop_size, test_cfg_crop_div_stride=test_cfg_crop_div_stride)
|
||||
optim_wrapper = generate_optim_wrapper()
|
||||
param_scheduler = generate_param_scheduler(train_type, train_time_or_epoch_k)
|
||||
|
||||
########### 4.程序写入 ###########
|
||||
# 算法名称解析:isanet【Alg】 _ r50【pretrained模型深度】 - d8【_base_/models】 _ 4【GPU数量】 x b2【Batch Size大小】 - 40k【schedule】 _ cityscapes【数据集】 - 512x1024【crop_size】.py
|
||||
alg_file_name = f"{alg_name}_r{depth}_{'Pre' if select_pretrained else 'NoPre'}_g{GPU_num}-{schedule_file_name.lstrip('schedule_')}_{dataset_file_name}-{crop_size[0]}x{crop_size[1]}-{test_slide}.py"
|
||||
output_configs_alg_my_alg = os.path.join(f'./configs/isanet/', alg_file_name)
|
||||
write_config_to_file(output_configs_alg_my_alg, _base_=_base_, norm_cfg=norm_cfg, crop_size=crop_size, data_preprocessor=data_preprocessor, model=model, test_cfg=test_cfg, optim_wrapper=optim_wrapper, param_scheduler=param_scheduler)
|
||||
|
||||
# 将信息写入文件
|
||||
with open("_temp_.txt", "w") as file:
|
||||
json.dump({'alg_infos':{'alg_file_name':alg_file_name,'alg_file_pth':output_configs_alg_my_alg}}, file)
|
||||
276
Seg_All_In_One_MMSeg/My_All_In_One/2_Alg_Program/my_knet.py
Normal file
276
Seg_All_In_One_MMSeg/My_All_In_One/2_Alg_Program/my_knet.py
Normal file
@@ -0,0 +1,276 @@
|
||||
import os, sys, argparse, json
|
||||
# 添加上级目录到 sys.path
|
||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../')))
|
||||
|
||||
from Initial_Alg_Program.Initial_Alg_Select_Tool import select_crop_size, select_pretrained_model, select_test_cfg_slide
|
||||
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_Tool import format_all_data, write_config_to_file, create_dict_by_kwargs, get_var_from_py_file
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_0_base_ import generate_base_config
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_1_norm_cfg import generate_norm_cfg
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_2_data_preprocessor import generate_data_preprocessor
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_3_model import generate_model_pretrained, generate_model_backbone, generate_model_data_preprocessor, generate_model_decode_head, generate_model_auxiliary_head, generate_model_train_cfg, generate_model_test_cfg
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_4_optimizer import generate_optim_wrapper, generate_param_scheduler
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_5_train_dataloader import generate_train_dataloader
|
||||
|
||||
if __name__ == '__main__':
|
||||
########### 1.超参数-算法默认参数 ###########
|
||||
alg_file_name = 'knet_r50-d8_my'
|
||||
|
||||
########### 2.参数解析 ###########
|
||||
parser = argparse.ArgumentParser(description="Run algorithm with specified parameters.")
|
||||
|
||||
# 添加命令行参数
|
||||
parser.add_argument('--alg_name', type=str, required=True, help="Algorithm name (e.g., 'ann')")
|
||||
parser.add_argument('--dataset_file_name', type=str, required=True, help="Dataset file name (e.g., 'my_dataset_model')")
|
||||
parser.add_argument('--mean', type=str, required=True, help="Mean of dataset (e.g., [1.2, 1.2, 1.2])")
|
||||
parser.add_argument('--std', type=str, required=True, help="Mean of dataset (e.g., [1.2, 1.2, 1.2])")
|
||||
parser.add_argument('--dataset_num_classes', type=int, required=True, help="Dataset class num (e.g., 36)")
|
||||
parser.add_argument('--GPU_num', type=int, required=True, help="Num of GPU to train (e.g., 1)")
|
||||
parser.add_argument('--schedule_file_name', type=str, required=True, help="Schedule file name (e.g., 'schedule_4k_check_400')")
|
||||
parser.add_argument('--train_type', type=str, required=True, help="Schedule train type (e.g., Epoch or Iteration)")
|
||||
parser.add_argument('--train_time_or_epoch_k', type=int, required=True, help="Schedule train time k (e.g., 4)")
|
||||
|
||||
# 解析命令行参数
|
||||
args = parser.parse_args()
|
||||
|
||||
# 变量赋值
|
||||
alg_name = args.alg_name
|
||||
dataset_file_name = args.dataset_file_name
|
||||
mean = json.loads(args.mean) # 使用json.loads转换成list
|
||||
std = json.loads(args.std) # 使用json.loads转换成list
|
||||
num_classes = args.dataset_num_classes
|
||||
schedule_file_name = args.schedule_file_name
|
||||
train_type = args.train_type
|
||||
train_time_or_epoch_k = args.train_time_or_epoch_k
|
||||
GPU_num = args.GPU_num
|
||||
|
||||
|
||||
########### 3. 设定参数 ###########
|
||||
|
||||
_base_ = generate_base_config(alg_file_name=alg_file_name, dataset_file_name=dataset_file_name, schedule_file_name=schedule_file_name) # _base_无算法
|
||||
|
||||
crop_size = select_crop_size(predefined_options=[(512, 512), (640,640)]) # 选择切割大小
|
||||
|
||||
data_preprocessor = generate_data_preprocessor(crop_size = crop_size, mean = mean, std = std)
|
||||
|
||||
model_data_preprocessor = data_preprocessor # 修改data_preprocessor
|
||||
model_data_preprocessor = generate_model_data_preprocessor(model_data_preprocessor=model_data_preprocessor)
|
||||
|
||||
norm_cfg = generate_norm_cfg(GPU_num=GPU_num)
|
||||
|
||||
# 选择对应的模型
|
||||
config_dict = {
|
||||
'DeepLabV3': {
|
||||
'decode_head_kernel_generate_head_dict': dict(
|
||||
_delete_ = True,
|
||||
type='ASPPHead',
|
||||
in_channels=2048,
|
||||
in_index=3,
|
||||
channels=512,
|
||||
dilations=(1, 12, 24, 36),
|
||||
dropout_ratio=0.1,
|
||||
num_classes=150,
|
||||
norm_cfg=norm_cfg,
|
||||
align_corners=False),
|
||||
'backbone_dilations': (1, 1, 2, 4),
|
||||
'backbone_strides': (1, 2, 1, 1),
|
||||
'auxiliary_head_in_channels':1024
|
||||
},
|
||||
'PSPNet': {
|
||||
'decode_head_kernel_generate_head_dict': dict(
|
||||
_delete_ = True,
|
||||
type='PSPHead',
|
||||
in_channels=2048,
|
||||
in_index=3,
|
||||
channels=512,
|
||||
pool_scales=(1, 2, 3, 6),
|
||||
dropout_ratio=0.1,
|
||||
num_classes=150,
|
||||
norm_cfg=norm_cfg,
|
||||
align_corners=False),
|
||||
'backbone_dilations': (1, 1, 2, 4),
|
||||
'backbone_strides': (1, 2, 1, 1),
|
||||
'auxiliary_head_in_channels':1024
|
||||
},
|
||||
'FCN': {
|
||||
'decode_head_kernel_generate_head_dict': dict(
|
||||
_delete_ = True,
|
||||
type='FCNHead',
|
||||
in_channels=2048,
|
||||
in_index=3,
|
||||
channels=512,
|
||||
num_convs=2,
|
||||
concat_input=True,
|
||||
dropout_ratio=0.1,
|
||||
num_classes=150,
|
||||
norm_cfg=norm_cfg,
|
||||
align_corners=False),
|
||||
'backbone_dilations': (1, 1, 2, 4),
|
||||
'backbone_strides': (1, 2, 1, 1),
|
||||
'auxiliary_head_in_channels':1024
|
||||
},
|
||||
'UPerNet': {
|
||||
'decode_head_kernel_generate_head_dict': dict(
|
||||
type='UPerHead',
|
||||
in_channels=[256, 512, 1024, 2048],
|
||||
in_index=[0, 1, 2, 3],
|
||||
pool_scales=(1, 2, 3, 6),
|
||||
channels=512,
|
||||
dropout_ratio=0.1,
|
||||
num_classes=150,
|
||||
norm_cfg=norm_cfg,
|
||||
align_corners=False),
|
||||
'backbone_dilations': (1, 1, 1, 1),
|
||||
'backbone_strides': (1, 2, 2, 2),
|
||||
'auxiliary_head_in_channels':1024
|
||||
}
|
||||
}
|
||||
|
||||
models = ['DeepLabV3', 'PSPNet', 'FCN', 'UPerNet']
|
||||
print("请你选择对应的模型:")
|
||||
while True:
|
||||
for index, model in enumerate(models, 1):
|
||||
print(f"{index}. {model}")
|
||||
try:
|
||||
user_input = int(input("Enter your choice (1-4): "))
|
||||
if 1 <= user_input <= 4:
|
||||
model_choice = models[user_input - 1]
|
||||
break
|
||||
else:
|
||||
print("Invalid input, please enter a number between 1 and 4.")
|
||||
except ValueError:
|
||||
print("Invalid input, please enter a valid number.")
|
||||
|
||||
decode_head_kernel_generate_head_dict = config_dict[model_choice]['decode_head_kernel_generate_head_dict']
|
||||
backbone_dilations = config_dict[model_choice]['backbone_dilations']
|
||||
backbone_strides = config_dict[model_choice]['backbone_strides']
|
||||
auxiliary_head_in_channels = config_dict[model_choice]['auxiliary_head_in_channels']
|
||||
|
||||
model_list = ['openmmlab/resnet50_v1c', 'openmmlab/resnet101_v1c', "pretrain/swin_tiny-f41b89d3.pth", "pretrain/swin_large-d5bdebaf.pth"]
|
||||
selected_model_name, select_pretrained, pretrained_pth, selected_model_info = select_pretrained_model(model_list=model_list)
|
||||
if selected_model_info['type'] == 'ResNetV1c':
|
||||
depth = selected_model_info['depth']
|
||||
backbone = create_dict_by_kwargs(depth = depth, norm_cfg=norm_cfg, dilations=backbone_dilations, strides=backbone_strides)
|
||||
elif selected_model_info['type'] == 'swin':
|
||||
if selected_model_name == 'pretrain/swin_tiny-f41b89d3.pth':
|
||||
backbone = dict(
|
||||
_delete_=True,
|
||||
type='SwinTransformer',
|
||||
embed_dims=96, #
|
||||
depths=[2, 2, 6, 2], #
|
||||
num_heads=[3, 6, 12, 24], #
|
||||
window_size=7,
|
||||
mlp_ratio=4,
|
||||
qkv_bias=True,
|
||||
qk_scale=None,
|
||||
drop_rate=0.,
|
||||
attn_drop_rate=0.,
|
||||
drop_path_rate=0.3, #
|
||||
use_abs_pos_embed=False,
|
||||
patch_norm=True,
|
||||
out_indices=(0, 1, 2, 3))
|
||||
if selected_model_name == 'pretrain/swin_large-d5bdebaf.pth':
|
||||
backbone = dict(
|
||||
_delete_=True,
|
||||
type='SwinTransformer',
|
||||
embed_dims=192, #
|
||||
depths=[2, 2, 18, 2], #
|
||||
num_heads=[6, 12, 24, 48], #
|
||||
window_size=7,
|
||||
mlp_ratio=4,
|
||||
qkv_bias=True,
|
||||
qk_scale=None,
|
||||
drop_rate=0.,
|
||||
attn_drop_rate=0.,
|
||||
drop_path_rate=0.4, #
|
||||
use_abs_pos_embed=False,
|
||||
patch_norm=True,
|
||||
out_indices=(0, 1, 2, 3))
|
||||
backbone.update(create_dict_by_kwargs(norm_cfg=dict(type='LN'))) # TODO Transformer一类的内容norm_cfg都设置为LN TODO # 没有 dilations 和 strides
|
||||
|
||||
test_cfg_mode, test_cfg_crop_div_stride = select_test_cfg_slide(crop_size)
|
||||
|
||||
# 3.2. 自动生成参数
|
||||
# 如果使用slide模式,则开启align_corners
|
||||
if test_cfg_mode == "slide":
|
||||
align_corners = True
|
||||
test_slide = 'testslide' # 算法名进行标记
|
||||
else:
|
||||
align_corners = False
|
||||
test_slide = 'no_testslide' # 算法名进行标记
|
||||
|
||||
|
||||
########### 3.生成各个部分 ###########
|
||||
|
||||
# 3.2. decode_head
|
||||
decode_head = get_var_from_py_file(os.path.join('./configs/_base_/models', alg_file_name+'.py'), 'model')['decode_head']
|
||||
|
||||
decode_head_kernel_update_head_list = decode_head['kernel_update_head']
|
||||
for i in range(len(decode_head_kernel_update_head_list)):
|
||||
decode_head_kernel_update_head_list[i]['num_classes'] = num_classes
|
||||
decode_head['kernel_update_head'] = decode_head_kernel_update_head_list
|
||||
|
||||
decode_head_kernel_generate_head_dict['num_classes'] = num_classes
|
||||
if model_choice == 'UPerNet':
|
||||
if selected_model_name == 'pretrain/swin_tiny-f41b89d3.pth':
|
||||
decode_head_kernel_generate_head_dict['in_channels'] = [96, 192, 384, 768]
|
||||
auxiliary_head_in_channels = 384
|
||||
elif selected_model_name == 'pretrain/swin_large-d5bdebaf.pth':
|
||||
decode_head_kernel_generate_head_dict['in_channels'] = [192, 384, 768, 1536]
|
||||
auxiliary_head_in_channels = 768
|
||||
elif model_choice in ['DeepLabV3', 'PSPNet', 'FCN']:
|
||||
if selected_model_name == 'pretrain/swin_tiny-f41b89d3.pth':
|
||||
decode_head_kernel_generate_head_dict['in_channels'] = 768
|
||||
auxiliary_head_in_channels = 384
|
||||
elif selected_model_name == 'pretrain/swin_large-d5bdebaf.pth':
|
||||
decode_head_kernel_generate_head_dict['in_channels'] = 1536
|
||||
auxiliary_head_in_channels = 768
|
||||
|
||||
|
||||
decode_head_kernel_generate_head_dict_loss_decode = dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0) # Way Ori
|
||||
decode_head_kernel_generate_head_dict_loss_decode = dict(type='DiceLoss', use_sigmoid=False, loss_weight=1.0) # Way 1
|
||||
decode_head_kernel_generate_head_dict['loss_decode'] = decode_head_kernel_generate_head_dict_loss_decode
|
||||
|
||||
decode_head['kernel_generate_head'] = decode_head_kernel_generate_head_dict
|
||||
|
||||
decode_head['kernel_generate_head']['align_corners'] = align_corners
|
||||
|
||||
# 3.2. auxiliary_head
|
||||
auxiliary_head = get_var_from_py_file(os.path.join('./configs/_base_/models', alg_file_name+'.py'), 'model')['auxiliary_head']
|
||||
auxiliary_head_loss_decode = dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4) # Way Ori
|
||||
auxiliary_head_loss_decode = dict(type='DiceLoss', use_sigmoid=False, loss_weight=0.4) # Way 1
|
||||
auxiliary_head['loss_decode'] = auxiliary_head_loss_decode
|
||||
auxiliary_head['align_corners'] = align_corners
|
||||
auxiliary_head['in_channels'] = auxiliary_head_in_channels
|
||||
|
||||
# 综合model
|
||||
model = dict(
|
||||
backbone = backbone,
|
||||
pretrained = pretrained_pth,
|
||||
data_preprocessor = model_data_preprocessor,
|
||||
decode_head = decode_head,
|
||||
auxiliary_head = auxiliary_head,
|
||||
)
|
||||
|
||||
# 4. train_dataloader
|
||||
train_dataloader, batch_size, num_workers = generate_train_dataloader(batch_size_default=2, num_workers_default=2)
|
||||
if selected_model_info['type'] == 'ResNetV1c':
|
||||
optim_wrapper = generate_optim_wrapper()
|
||||
elif selected_model_info['type'] == 'swin':
|
||||
optim_wrapper = generate_optim_wrapper('swin')
|
||||
|
||||
param_scheduler = generate_param_scheduler(train_type, train_time_or_epoch_k)
|
||||
|
||||
|
||||
########### 4.程序写入 ###########
|
||||
# 算法名称解析:knet【Alg】 _ r50【pretrained模型深度】 - d8【_base_/models】 _ 4【GPU数量】 x b2【Batch Size大小】 - 40k【schedule】 _ cityscapes【数据集】 - 512x1024【crop_size】.py
|
||||
if selected_model_info['type'] == 'ResNetV1c':
|
||||
alg_file_name = f"{alg_name}_r{depth}_{model_choice}_{'Pre' if select_pretrained else 'NoPre'}_g{GPU_num}-{schedule_file_name.lstrip('schedule_')}_{dataset_file_name}-{crop_size[0]}x{crop_size[1]}-{test_slide}.py"
|
||||
elif selected_model_info['type'] == 'swin':
|
||||
alg_file_name = f"{alg_name}_swin_{selected_model_info['size']}_{model_choice}_{'Pre' if select_pretrained else 'NoPre'}_g{GPU_num}-{schedule_file_name.lstrip('schedule_')}_{dataset_file_name}-{crop_size[0]}x{crop_size[1]}-{test_slide}.py"
|
||||
output_configs_alg_my_alg = os.path.join(f'./configs/knet/', alg_file_name)
|
||||
write_config_to_file(output_configs_alg_my_alg, _base_=_base_, norm_cfg=norm_cfg, crop_size=crop_size, data_preprocessor=data_preprocessor, model=model, train_dataloader=train_dataloader, optim_wrapper=optim_wrapper, param_scheduler=param_scheduler)
|
||||
|
||||
# 将信息写入文件
|
||||
with open("_temp_.txt", "w") as file:
|
||||
json.dump({'alg_infos':{'alg_file_name':alg_file_name,'alg_file_pth':output_configs_alg_my_alg}}, file)
|
||||
@@ -0,0 +1,136 @@
|
||||
import os, sys, argparse, json
|
||||
# 添加上级目录到 sys.path
|
||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../')))
|
||||
|
||||
from Initial_Alg_Program.Initial_Alg_Select_Tool import select_crop_size, select_pretrained_model, select_test_cfg_slide
|
||||
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_Tool import format_all_data, write_config_to_file, create_dict_by_kwargs, get_var_from_file, update_list_dict_var, get_var_from_py_file
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_0_base_ import generate_base_config
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_1_norm_cfg import generate_norm_cfg
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_2_data_preprocessor import generate_data_preprocessor
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_3_model import generate_model_pretrained, generate_model_backbone, generate_model_data_preprocessor, generate_model_decode_head, generate_model_auxiliary_head, generate_model_train_cfg, generate_model_test_cfg
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_4_optimizer import generate_optim_wrapper, generate_param_scheduler
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_5_train_dataloader import generate_train_dataloader
|
||||
|
||||
# 交互式选择 decode_head 的函数
|
||||
def select_decode_head(decode_head_choose):
|
||||
print("可用的 decode head 选项:")
|
||||
for i, key in enumerate(decode_head_choose.keys()):
|
||||
print(f"{i + 1}. {key}")
|
||||
|
||||
while True:
|
||||
try:
|
||||
# 提示用户输入选择
|
||||
choice = int(input("请选择需要的 decode head(输入编号):"))
|
||||
# 检查输入是否在有效范围内
|
||||
if 1 <= choice <= len(decode_head_choose):
|
||||
selected_key = list(decode_head_choose.keys())[choice - 1]
|
||||
print(f"你选择了: {selected_key}")
|
||||
return decode_head_choose[selected_key]
|
||||
else:
|
||||
print("输入的编号不正确,请重新输入。")
|
||||
except ValueError:
|
||||
print("输入无效,请输入有效的编号。")
|
||||
|
||||
if __name__ == '__main__':
|
||||
########### 1.超参数-算法默认参数 ###########
|
||||
alg_file_name = 'upernet_mae'
|
||||
alg_file_pth = f"configs/_base_/models/{alg_file_name}.py"
|
||||
|
||||
########### 2.参数解析 ###########
|
||||
parser = argparse.ArgumentParser(description="Run algorithm with specified parameters.")
|
||||
|
||||
# 添加命令行参数
|
||||
parser.add_argument('--alg_name', type=str, required=True, help="Algorithm name (e.g., 'ann')")
|
||||
parser.add_argument('--dataset_file_name', type=str, required=True, help="Dataset file name (e.g., 'my_dataset_model')")
|
||||
parser.add_argument('--mean', type=str, required=True, help="Mean of dataset (e.g., [1.2, 1.2, 1.2])")
|
||||
parser.add_argument('--std', type=str, required=True, help="Mean of dataset (e.g., [1.2, 1.2, 1.2])")
|
||||
parser.add_argument('--dataset_num_classes', type=int, required=True, help="Dataset class num (e.g., 36)")
|
||||
parser.add_argument('--GPU_num', type=int, required=True, help="Num of GPU to train (e.g., 1)")
|
||||
parser.add_argument('--schedule_file_name', type=str, required=True, help="Schedule file name (e.g., 'schedule_4k_check_400')")
|
||||
parser.add_argument('--train_type', type=str, required=True, help="Schedule train type (e.g., Epoch or Iteration)")
|
||||
parser.add_argument('--train_time_or_epoch_k', type=int, required=True, help="Schedule train time k (e.g., 4)")
|
||||
|
||||
# 解析命令行参数
|
||||
args = parser.parse_args()
|
||||
|
||||
# 变量赋值
|
||||
alg_name = args.alg_name
|
||||
dataset_file_name = args.dataset_file_name
|
||||
mean = json.loads(args.mean) # 使用json.loads转换成list
|
||||
std = json.loads(args.std) # 使用json.loads转换成list
|
||||
num_classes = args.dataset_num_classes
|
||||
schedule_file_name = args.schedule_file_name
|
||||
train_type = args.train_type
|
||||
train_time_or_epoch_k = args.train_time_or_epoch_k
|
||||
GPU_num = args.GPU_num
|
||||
|
||||
########### 3.1. 生成_base_部分 ###########
|
||||
_base_ = generate_base_config(alg_file_name=alg_file_name, dataset_file_name=dataset_file_name, schedule_file_name=schedule_file_name)
|
||||
|
||||
########### 3.1. 生成norm_cfg部分 ###########
|
||||
norm_cfg = generate_norm_cfg(GPU_num=GPU_num)
|
||||
|
||||
########### 3.2. 生成data_preprocessor部分 ###########
|
||||
crop_size = select_crop_size(predefined_options=[(512, 512)]) # 选择切割大小
|
||||
data_preprocessor = generate_data_preprocessor(crop_size = crop_size, mean = mean, std = std)
|
||||
|
||||
########### 3.3. 生成model部分 ###########
|
||||
model_data_preprocessor = data_preprocessor # 修改data_preprocessor
|
||||
# 获取backbone模型、是否需要预训练
|
||||
model_list = ["mae_pretrain_vit_base_mmcls.pth"]
|
||||
selected_model_name, select_pretrained, pretrained_pth, selected_model_info = select_pretrained_model(model_list=model_list) # 需要选择是否用预训练模型
|
||||
|
||||
backbone = create_dict_by_kwargs(type='MAE', img_size=crop_size, init_values=1.0)
|
||||
|
||||
decode_head_loss_decode=dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0) # Way Ori
|
||||
decode_head_loss_decode=dict(type='DiceLoss', use_sigmoid=False, loss_weight=1.0) # Way 1
|
||||
|
||||
auxiliary_head_loss_decode=dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4) # Way Ori
|
||||
auxiliary_head_loss_decode=dict(type='DiceLoss', use_sigmoid=False, loss_weight=0.4) # Way 1
|
||||
|
||||
neck=dict(embed_dim=768, rescales=[4, 2, 1, 0.5])
|
||||
decode_head=dict(in_channels=[768, 768, 768, 768], channels=768, num_classes=num_classes, norm_cfg=norm_cfg, loss_decode=decode_head_loss_decode)
|
||||
auxiliary_head=dict(in_channels=768, norm_cfg=norm_cfg, num_classes=num_classes, loss_decode = auxiliary_head_loss_decode)
|
||||
|
||||
test_cfg_mode, test_cfg_crop_div_stride = select_test_cfg_slide(crop_size, select_slide=True)
|
||||
test_cfg = generate_model_test_cfg(test_cfg_mode='slide', crop_size=crop_size, test_cfg_crop_div_stride=test_cfg_crop_div_stride)
|
||||
|
||||
optim_wrapper = dict(
|
||||
_delete_=True,
|
||||
type='AmpOptimWrapper', # type='OptimWrapper', # TODO Ori
|
||||
optimizer=dict(
|
||||
type='AdamW', lr=3e-5, betas=(0.9, 0.999), weight_decay=0.05),
|
||||
constructor='LayerDecayOptimizerConstructor',
|
||||
paramwise_cfg=dict(num_layers=12, layer_decay_rate=0.9))
|
||||
|
||||
model_size = 'base'
|
||||
|
||||
train_dataloader, batch_size = generate_train_dataloader(4)
|
||||
|
||||
model = dict(
|
||||
backbone = backbone,
|
||||
data_preprocessor = model_data_preprocessor,
|
||||
pretrained = pretrained_pth,
|
||||
decode_head = decode_head,
|
||||
neck = neck,
|
||||
auxiliary_head = auxiliary_head,
|
||||
test_cfg = test_cfg
|
||||
)
|
||||
|
||||
########### 3.4. 生成fp16部分 ###########
|
||||
fp16 = dict(loss_scale='dynamic') # mixed precision
|
||||
|
||||
########### 3.4. 生成optim、param_scheduler部分 ###########
|
||||
optim_wrapper = optim_wrapper
|
||||
param_scheduler = generate_param_scheduler(train_type, train_time_or_epoch_k)
|
||||
|
||||
########### 4.程序写入 ###########
|
||||
# 算法名称解析:beit【Alg】 _ r50【pretrained模型深度】 - d8【_base_/models】 _ 4【GPU数量】 x b2【Batch Size大小】 - 40k【schedule】 _ cityscapes【数据集】 - 512x1024【crop_size】.py
|
||||
alg_file_name = f"{alg_name}-{model_size}_b{batch_size}_g{GPU_num}-{schedule_file_name.lstrip('schedule_')}_{dataset_file_name}-{crop_size[0]}x{crop_size[1]}-testslide.py"
|
||||
output_configs_alg_my_alg = os.path.join(f'./configs/mae/', alg_file_name)
|
||||
write_config_to_file(output_configs_alg_my_alg, _base_=_base_, norm_cfg=norm_cfg, crop_size=crop_size, data_preprocessor=data_preprocessor, model=model, fp16=fp16,optim_wrapper=optim_wrapper, param_scheduler=param_scheduler, train_dataloader = train_dataloader)
|
||||
|
||||
# 将信息写入文件
|
||||
with open("_temp_.txt", "w") as file:
|
||||
json.dump({'alg_infos':{'alg_file_name':alg_file_name,'alg_file_pth':output_configs_alg_my_alg}}, file)
|
||||
@@ -0,0 +1,284 @@
|
||||
import os, sys, argparse, json
|
||||
# 添加上级目录到 sys.path
|
||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../')))
|
||||
|
||||
from Initial_Alg_Program.Initial_Alg_Select_Tool import select_crop_size, select_pretrained_model, select_test_cfg_slide
|
||||
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_Tool import format_all_data, write_config_to_file, create_dict_by_kwargs, get_var_from_py_file
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_0_base_ import generate_base_config
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_1_norm_cfg import generate_norm_cfg
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_2_data_preprocessor import generate_data_preprocessor
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_3_model import generate_model_pretrained, generate_model_backbone, generate_model_data_preprocessor, generate_model_decode_head, generate_model_auxiliary_head, generate_model_train_cfg, generate_model_test_cfg
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_4_optimizer import generate_optim_wrapper, generate_param_scheduler
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_5_train_dataloader import generate_train_dataloader
|
||||
|
||||
if __name__ == '__main__':
|
||||
########### 1.超参数-算法默认参数 ###########
|
||||
alg_file_name = 'mask2former_my'
|
||||
|
||||
########### 2.参数解析 ###########
|
||||
parser = argparse.ArgumentParser(description="Run algorithm with specified parameters.")
|
||||
|
||||
# 添加命令行参数
|
||||
parser.add_argument('--alg_name', type=str, required=True, help="Algorithm name (e.g., 'ann')")
|
||||
parser.add_argument('--dataset_file_name', type=str, required=True, help="Dataset file name (e.g., 'my_dataset_model')")
|
||||
parser.add_argument('--mean', type=str, required=True, help="Mean of dataset (e.g., [1.2, 1.2, 1.2])")
|
||||
parser.add_argument('--std', type=str, required=True, help="Mean of dataset (e.g., [1.2, 1.2, 1.2])")
|
||||
parser.add_argument('--dataset_num_classes', type=int, required=True, help="Dataset class num (e.g., 36)")
|
||||
parser.add_argument('--GPU_num', type=int, required=True, help="Num of GPU to train (e.g., 1)")
|
||||
parser.add_argument('--schedule_file_name', type=str, required=True, help="Schedule file name (e.g., 'schedule_4k_check_400')")
|
||||
parser.add_argument('--train_type', type=str, required=True, help="Schedule train type (e.g., Epoch or Iteration)")
|
||||
parser.add_argument('--train_time_or_epoch_k', type=int, required=True, help="Schedule train time k (e.g., 4)")
|
||||
|
||||
# 解析命令行参数
|
||||
args = parser.parse_args()
|
||||
|
||||
# 变量赋值
|
||||
alg_name = args.alg_name
|
||||
dataset_file_name = args.dataset_file_name
|
||||
mean = json.loads(args.mean) # 使用json.loads转换成list
|
||||
std = json.loads(args.std) # 使用json.loads转换成list
|
||||
num_classes = args.dataset_num_classes
|
||||
schedule_file_name = args.schedule_file_name
|
||||
train_type = args.train_type
|
||||
train_time_or_epoch_k = args.train_time_or_epoch_k
|
||||
GPU_num = args.GPU_num
|
||||
|
||||
|
||||
########### 3. 设定参数 ###########
|
||||
|
||||
_base_ = generate_base_config(alg_file_name=alg_file_name, dataset_file_name=dataset_file_name, schedule_file_name=schedule_file_name) # _base_无算法
|
||||
|
||||
crop_size = select_crop_size(predefined_options=[(512, 512), (640, 640), (512, 1024)]) # 选择切割大小
|
||||
|
||||
data_preprocessor = generate_data_preprocessor(crop_size = crop_size, mean = mean, std = std)
|
||||
|
||||
model_data_preprocessor = data_preprocessor # 修改data_preprocessor
|
||||
model_data_preprocessor = generate_model_data_preprocessor(model_data_preprocessor=model_data_preprocessor)
|
||||
|
||||
norm_cfg = generate_norm_cfg(GPU_num=GPU_num)
|
||||
|
||||
# 3.1. 选择backbone
|
||||
model_list = ['torchvision://resnet50', 'torchvision://resnet101', 'pretrain/swin_large-6580f57d.pth', 'pretrain/swin_base-e5c09f74.pth', 'pretrain/swin_small-7ba6d6dd.pth', 'pretrain/swin_tiny-1cdeb081.pth']
|
||||
selected_model_name, select_pretrained, pretrained_pth, selected_model_info = select_pretrained_model(model_list=model_list)
|
||||
|
||||
if selected_model_info['type'] == 'ResNet':
|
||||
depth = selected_model_info['depth']
|
||||
backbone = create_dict_by_kwargs(depth = depth, type=selected_model_info['type'], norm_cfg=norm_cfg, num_stages=4, frozen_stages=-1, style='pytorch', init_cfg=dict(type='Pretrained', checkpoint=pretrained_pth))
|
||||
in_channels = [256, 512, 1024, 2048]
|
||||
elif selected_model_info['type'] == 'swin':
|
||||
if selected_model_name == 'pretrain/swin_tiny-1cdeb081.pth':
|
||||
in_channels = [96, 192, 384, 768]
|
||||
depths=[2, 2, 6, 2]
|
||||
backbone = dict(
|
||||
_delete_=True,
|
||||
type='SwinTransformer',
|
||||
pretrain_img_size=224,
|
||||
embed_dims=96,
|
||||
depths=depths,
|
||||
num_heads=[3, 6, 12, 24],
|
||||
window_size=7,
|
||||
mlp_ratio=4,
|
||||
qkv_bias=True,
|
||||
qk_scale=None,
|
||||
drop_rate=0.,
|
||||
attn_drop_rate=0.,
|
||||
drop_path_rate=0.3,
|
||||
patch_norm=True,
|
||||
# out_indices=(0, 1, 2, 3),
|
||||
with_cp=False,
|
||||
# frozen_stages=-1,
|
||||
init_cfg=dict(type='Pretrained', checkpoint=pretrained_pth))
|
||||
if selected_model_name == 'pretrain/swin_small-7ba6d6dd.pth':
|
||||
depths=[2, 2, 18, 2]
|
||||
in_channels = [96, 192, 384, 768]
|
||||
backbone = dict(
|
||||
_delete_=True,
|
||||
type='SwinTransformer',
|
||||
pretrain_img_size=224,
|
||||
embed_dims=96,
|
||||
depths=depths,
|
||||
num_heads=[3, 6, 12, 24],
|
||||
window_size=7,
|
||||
mlp_ratio=4,
|
||||
qkv_bias=True,
|
||||
qk_scale=None,
|
||||
drop_rate=0.,
|
||||
attn_drop_rate=0.,
|
||||
drop_path_rate=0.3,
|
||||
patch_norm=True,
|
||||
# out_indices=(0, 1, 2, 3),
|
||||
with_cp=False,
|
||||
# frozen_stages=-1,
|
||||
init_cfg=dict(type='Pretrained', checkpoint=pretrained_pth))
|
||||
if selected_model_name == 'pretrain/swin_base-e5c09f74.pth':
|
||||
depths=[2, 2, 18, 2]
|
||||
in_channels = [128, 256, 512, 1024]
|
||||
backbone = dict(
|
||||
type='SwinTransformer',
|
||||
pretrain_img_size=384,
|
||||
embed_dims=128,
|
||||
depths=depths,
|
||||
num_heads=[4, 8, 16, 32],
|
||||
window_size=12,
|
||||
mlp_ratio=4,
|
||||
qkv_bias=True,
|
||||
qk_scale=None,
|
||||
drop_rate=0.,
|
||||
attn_drop_rate=0.,
|
||||
drop_path_rate=0.3,
|
||||
patch_norm=True,
|
||||
# out_indices=(0, 1, 2, 3),
|
||||
with_cp=False,
|
||||
# frozen_stages=-1,
|
||||
init_cfg=dict(type='Pretrained', checkpoint=pretrained_pth))
|
||||
if selected_model_name == 'pretrain/swin_large-6580f57d.pth':
|
||||
in_channels = [192, 384, 768, 1536]
|
||||
depths=[2, 2, 18, 2]
|
||||
backbone = dict(
|
||||
type='SwinTransformer',
|
||||
pretrain_img_size=384,
|
||||
embed_dims=192,
|
||||
depths=depths,
|
||||
num_heads=[6, 12, 24, 48],
|
||||
window_size=12,
|
||||
mlp_ratio=4,
|
||||
qkv_bias=True,
|
||||
qk_scale=None,
|
||||
drop_rate=0.,
|
||||
attn_drop_rate=0.,
|
||||
drop_path_rate=0.3,
|
||||
patch_norm=True,
|
||||
# out_indices=(0, 1, 2, 3),
|
||||
with_cp=False,
|
||||
# frozen_stages=-1,
|
||||
init_cfg=dict(type='Pretrained', checkpoint=pretrained_pth))
|
||||
# set all layers in backbone to lr_mult=0.1
|
||||
# set all norm layers, position_embeding,
|
||||
# query_embeding, level_embeding to decay_multi=0.0
|
||||
backbone_norm_multi = dict(lr_mult=0.1, decay_mult=0.0)
|
||||
backbone_embed_multi = dict(lr_mult=0.1, decay_mult=0.0)
|
||||
embed_multi = dict(lr_mult=1.0, decay_mult=0.0)
|
||||
custom_keys = {
|
||||
'backbone': dict(lr_mult=0.1, decay_mult=1.0),
|
||||
'backbone.patch_embed.norm': backbone_norm_multi,
|
||||
'backbone.norm': backbone_norm_multi,
|
||||
'absolute_pos_embed': backbone_embed_multi,
|
||||
'relative_position_bias_table': backbone_embed_multi,
|
||||
'query_embed': embed_multi,
|
||||
'query_feat': embed_multi,
|
||||
'level_embed': embed_multi
|
||||
}
|
||||
custom_keys.update({
|
||||
f'backbone.stages.{stage_id}.blocks.{block_id}.norm': backbone_norm_multi
|
||||
for stage_id, num_blocks in enumerate(depths)
|
||||
for block_id in range(num_blocks)
|
||||
})
|
||||
custom_keys.update({
|
||||
f'backbone.stages.{stage_id}.downsample.norm': backbone_norm_multi
|
||||
for stage_id in range(len(depths) - 1)
|
||||
})
|
||||
|
||||
# 3.2. decode_head
|
||||
decode_head = dict(num_classes=num_classes, in_channels=in_channels,
|
||||
loss_cls=dict(class_weight=[1.0] * num_classes + [0.1]))
|
||||
|
||||
# 3.3. optimizer部分
|
||||
if selected_model_info['type'] == 'ResNet':
|
||||
# optimizer
|
||||
embed_multi = dict(lr_mult=1.0, decay_mult=0.0)
|
||||
optimizer = dict(
|
||||
type='AdamW', lr=0.0001, weight_decay=0.05, eps=1e-8, betas=(0.9, 0.999))
|
||||
optim_wrapper = dict(
|
||||
_delete_=True,
|
||||
type='OptimWrapper',
|
||||
optimizer=optimizer,
|
||||
clip_grad=dict(max_norm=0.01, norm_type=2),
|
||||
paramwise_cfg=dict(
|
||||
custom_keys={
|
||||
'backbone': dict(lr_mult=0.1, decay_mult=1.0),
|
||||
'query_embed': embed_multi,
|
||||
'query_feat': embed_multi,
|
||||
'level_embed': embed_multi,
|
||||
},
|
||||
norm_decay_mult=0.0))
|
||||
train_dataloader, batch_size, num_workers = generate_train_dataloader(batch_size_default=4, num_workers_default=4)
|
||||
elif selected_model_info['type'] == 'swin':
|
||||
# set all layers in backbone to lr_mult=0.1
|
||||
# set all norm layers, position_embeding,
|
||||
# query_embeding, level_embeding to decay_multi=0.0
|
||||
backbone_norm_multi = dict(lr_mult=0.1, decay_mult=0.0)
|
||||
backbone_embed_multi = dict(lr_mult=0.1, decay_mult=0.0)
|
||||
embed_multi = dict(lr_mult=1.0, decay_mult=0.0)
|
||||
custom_keys = {
|
||||
'backbone': dict(lr_mult=0.1, decay_mult=1.0),
|
||||
'backbone.patch_embed.norm': backbone_norm_multi,
|
||||
'backbone.norm': backbone_norm_multi,
|
||||
'absolute_pos_embed': backbone_embed_multi,
|
||||
'relative_position_bias_table': backbone_embed_multi,
|
||||
'query_embed': embed_multi,
|
||||
'query_feat': embed_multi,
|
||||
'level_embed': embed_multi
|
||||
}
|
||||
custom_keys.update({
|
||||
f'backbone.stages.{stage_id}.blocks.{block_id}.norm': backbone_norm_multi
|
||||
for stage_id, num_blocks in enumerate(depths)
|
||||
for block_id in range(num_blocks)
|
||||
})
|
||||
custom_keys.update({
|
||||
f'backbone.stages.{stage_id}.downsample.norm': backbone_norm_multi
|
||||
for stage_id in range(len(depths) - 1)
|
||||
})
|
||||
# optimizer
|
||||
optimizer = dict(
|
||||
type='AdamW', lr=0.0001, weight_decay=0.05, eps=1e-8, betas=(0.9, 0.999))
|
||||
optim_wrapper = dict(
|
||||
_delete_=True,
|
||||
type='OptimWrapper',
|
||||
optimizer=optimizer,
|
||||
clip_grad=dict(max_norm=0.01, norm_type=2),
|
||||
paramwise_cfg={'custom_keys':custom_keys, 'norm_decay_mult':0.0})
|
||||
|
||||
# 更新train_dataloader
|
||||
train_dataloader, batch_size, num_workers = generate_train_dataloader(batch_size_default=2, num_workers_default=2)
|
||||
|
||||
# 3.4. train_dataloader部分
|
||||
|
||||
train_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(type='LoadAnnotations'),
|
||||
dict(
|
||||
type='RandomChoiceResize',
|
||||
scales=[int(max(crop_size[0], crop_size[1]) * x * 0.1) for x in range(5, 21)],
|
||||
resize_type='ResizeShortestEdge',
|
||||
max_size=max(crop_size[0], crop_size[1])*4),
|
||||
dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
|
||||
dict(type='RandomFlip', prob=0.5),
|
||||
dict(type='PhotoMetricDistortion'),
|
||||
dict(type='PackSegInputs')
|
||||
]
|
||||
train_dataloader.update(dict(dataset=dict(pipeline=train_pipeline)))
|
||||
|
||||
# 3.5. param_scheduler部分
|
||||
param_scheduler = generate_param_scheduler(train_type, train_time_or_epoch_k)
|
||||
|
||||
########### 3.生成各个部分 ###########
|
||||
# 综合model
|
||||
model = dict(
|
||||
backbone = backbone,
|
||||
data_preprocessor = model_data_preprocessor,
|
||||
decode_head = decode_head
|
||||
)
|
||||
|
||||
########### 4.程序写入 ###########
|
||||
# 算法名称解析:mask2former【Alg】 _ r50【pretrained模型深度】 - d8【_base_/models】 _ 4【GPU数量】 x b2【Batch Size大小】 - 40k【schedule】 _ cityscapes【数据集】 - 512x1024【crop_size】.py
|
||||
if selected_model_info['type'] == 'ResNet':
|
||||
alg_file_name = f"{alg_name}_r{depth}_{'Pre' if select_pretrained else 'NoPre'}_g{GPU_num}-{schedule_file_name.lstrip('schedule_')}_{dataset_file_name}-{crop_size[0]}x{crop_size[1]}.py"
|
||||
elif selected_model_info['type'] == 'swin':
|
||||
alg_file_name = f"{alg_name}_swin_{selected_model_info['size']}_{'Pre' if select_pretrained else 'NoPre'}_g{GPU_num}-{schedule_file_name.lstrip('schedule_')}_{dataset_file_name}-{crop_size[0]}x{crop_size[1]}.py"
|
||||
output_configs_alg_my_alg = os.path.join(f'./configs/mask2former/', alg_file_name)
|
||||
write_config_to_file(output_configs_alg_my_alg, _base_=_base_, norm_cfg=norm_cfg, crop_size=crop_size, data_preprocessor=data_preprocessor, model=model, train_dataloader=train_dataloader, optim_wrapper=optim_wrapper, param_scheduler=param_scheduler)
|
||||
|
||||
# 将信息写入文件
|
||||
with open("_temp_.txt", "w") as file:
|
||||
json.dump({'alg_infos':{'alg_file_name':alg_file_name,'alg_file_pth':output_configs_alg_my_alg}}, file)
|
||||
@@ -0,0 +1,246 @@
|
||||
import os, sys, argparse, json
|
||||
# 添加上级目录到 sys.path
|
||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../')))
|
||||
|
||||
from Initial_Alg_Program.Initial_Alg_Select_Tool import select_crop_size, select_pretrained_model, select_test_cfg_slide
|
||||
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_Tool import format_all_data, write_config_to_file, create_dict_by_kwargs, get_var_from_py_file
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_0_base_ import generate_base_config
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_1_norm_cfg import generate_norm_cfg
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_2_data_preprocessor import generate_data_preprocessor
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_3_model import generate_model_pretrained, generate_model_backbone, generate_model_data_preprocessor, generate_model_decode_head, generate_model_auxiliary_head, generate_model_train_cfg, generate_model_test_cfg
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_4_optimizer import generate_optim_wrapper, generate_param_scheduler
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_5_train_dataloader import generate_train_dataloader
|
||||
|
||||
if __name__ == '__main__':
|
||||
########### 1.超参数-算法默认参数 ###########
|
||||
alg_file_name = 'maskformer_my'
|
||||
|
||||
########### 2.参数解析 ###########
|
||||
parser = argparse.ArgumentParser(description="Run algorithm with specified parameters.")
|
||||
|
||||
# 添加命令行参数
|
||||
parser.add_argument('--alg_name', type=str, required=True, help="Algorithm name (e.g., 'ann')")
|
||||
parser.add_argument('--dataset_file_name', type=str, required=True, help="Dataset file name (e.g., 'my_dataset_model')")
|
||||
parser.add_argument('--mean', type=str, required=True, help="Mean of dataset (e.g., [1.2, 1.2, 1.2])")
|
||||
parser.add_argument('--std', type=str, required=True, help="Mean of dataset (e.g., [1.2, 1.2, 1.2])")
|
||||
parser.add_argument('--dataset_num_classes', type=int, required=True, help="Dataset class num (e.g., 36)")
|
||||
parser.add_argument('--GPU_num', type=int, required=True, help="Num of GPU to train (e.g., 1)")
|
||||
parser.add_argument('--schedule_file_name', type=str, required=True, help="Schedule file name (e.g., 'schedule_4k_check_400')")
|
||||
parser.add_argument('--train_type', type=str, required=True, help="Schedule train type (e.g., Epoch or Iteration)")
|
||||
parser.add_argument('--train_time_or_epoch_k', type=int, required=True, help="Schedule train time k (e.g., 4)")
|
||||
|
||||
# 解析命令行参数
|
||||
args = parser.parse_args()
|
||||
|
||||
# 变量赋值
|
||||
alg_name = args.alg_name
|
||||
dataset_file_name = args.dataset_file_name
|
||||
mean = json.loads(args.mean) # 使用json.loads转换成list
|
||||
std = json.loads(args.std) # 使用json.loads转换成list
|
||||
num_classes = args.dataset_num_classes
|
||||
schedule_file_name = args.schedule_file_name
|
||||
train_type = args.train_type
|
||||
train_time_or_epoch_k = args.train_time_or_epoch_k
|
||||
GPU_num = args.GPU_num
|
||||
|
||||
|
||||
########### 3. 设定参数 ###########
|
||||
|
||||
_base_ = generate_base_config(alg_file_name=alg_file_name, dataset_file_name=dataset_file_name, schedule_file_name=schedule_file_name) # _base_无算法
|
||||
|
||||
crop_size = select_crop_size(predefined_options=[(512, 512), (640, 640), (512, 1024)]) # 选择切割大小
|
||||
|
||||
data_preprocessor = generate_data_preprocessor(crop_size = crop_size, mean = mean, std = std)
|
||||
|
||||
model_data_preprocessor = data_preprocessor # 修改data_preprocessor
|
||||
model_data_preprocessor = generate_model_data_preprocessor(model_data_preprocessor=model_data_preprocessor)
|
||||
|
||||
norm_cfg = generate_norm_cfg(GPU_num=GPU_num)
|
||||
|
||||
# 3.1. 选择backbone
|
||||
model_list = ['torchvision://resnet50', 'torchvision://resnet101', 'pretrain/swin_large-6580f57d.pth', 'pretrain/swin_base-e5c09f74.pth', 'pretrain/swin_small-7ba6d6dd.pth', 'pretrain/swin_tiny-1cdeb081.pth']
|
||||
selected_model_name, select_pretrained, pretrained_pth, selected_model_info = select_pretrained_model(model_list=model_list)
|
||||
|
||||
if selected_model_info['type'] == 'ResNet':
|
||||
depth = selected_model_info['depth']
|
||||
backbone = create_dict_by_kwargs(depth = depth, type=selected_model_info['type'], norm_cfg=norm_cfg, num_stages=4, frozen_stages=-1, style='pytorch', init_cfg=dict(type='Pretrained', checkpoint=pretrained_pth))
|
||||
in_channels = [256, 512, 1024, 2048]
|
||||
elif selected_model_info['type'] == 'swin':
|
||||
if selected_model_name == 'pretrain/swin_tiny-1cdeb081.pth':
|
||||
in_channels = [96, 192, 384, 768]
|
||||
depths=[2, 2, 6, 2]
|
||||
backbone = dict(
|
||||
_delete_=True,
|
||||
type='SwinTransformer',
|
||||
pretrain_img_size=224,
|
||||
embed_dims=96,
|
||||
depths=depths,
|
||||
num_heads=[3, 6, 12, 24],
|
||||
window_size=7,
|
||||
mlp_ratio=4,
|
||||
qkv_bias=True,
|
||||
qk_scale=None,
|
||||
drop_rate=0.,
|
||||
attn_drop_rate=0.,
|
||||
drop_path_rate=0.3,
|
||||
patch_norm=True,
|
||||
act_cfg=dict(type='GELU'), # ADD
|
||||
use_abs_pos_embed=False, # ADD
|
||||
# out_indices=(0, 1, 2, 3),
|
||||
with_cp=False,
|
||||
init_cfg=dict(type='Pretrained', checkpoint=pretrained_pth))
|
||||
if selected_model_name == 'pretrain/swin_small-7ba6d6dd.pth':
|
||||
depths=[2, 2, 18, 2]
|
||||
in_channels = [96, 192, 384, 768]
|
||||
backbone = dict(
|
||||
_delete_=True,
|
||||
type='SwinTransformer',
|
||||
pretrain_img_size=224,
|
||||
embed_dims=96,
|
||||
depths=depths,
|
||||
num_heads=[3, 6, 12, 24],
|
||||
window_size=7,
|
||||
mlp_ratio=4,
|
||||
qkv_bias=True,
|
||||
qk_scale=None,
|
||||
drop_rate=0.,
|
||||
attn_drop_rate=0.,
|
||||
drop_path_rate=0.3,
|
||||
patch_norm=True,
|
||||
act_cfg=dict(type='GELU'), # ADD
|
||||
use_abs_pos_embed=False, # ADD
|
||||
# out_indices=(0, 1, 2, 3),
|
||||
with_cp=False,
|
||||
init_cfg=dict(type='Pretrained', checkpoint=pretrained_pth))
|
||||
if selected_model_name == 'pretrain/swin_base-e5c09f74.pth':
|
||||
depths=[2, 2, 18, 2]
|
||||
in_channels = [128, 256, 512, 1024]
|
||||
backbone = dict(
|
||||
type='SwinTransformer',
|
||||
pretrain_img_size=384,
|
||||
embed_dims=128,
|
||||
depths=depths,
|
||||
num_heads=[4, 8, 16, 32],
|
||||
window_size=12,
|
||||
mlp_ratio=4,
|
||||
qkv_bias=True,
|
||||
qk_scale=None,
|
||||
drop_rate=0.,
|
||||
attn_drop_rate=0.,
|
||||
drop_path_rate=0.3,
|
||||
patch_norm=True,
|
||||
act_cfg=dict(type='GELU'), # ADD
|
||||
use_abs_pos_embed=False, # ADD
|
||||
# out_indices=(0, 1, 2, 3),
|
||||
with_cp=False,
|
||||
init_cfg=dict(type='Pretrained', checkpoint=pretrained_pth))
|
||||
if selected_model_name == 'pretrain/swin_large-6580f57d.pth':
|
||||
in_channels = [192, 384, 768, 1536]
|
||||
depths=[2, 2, 18, 2]
|
||||
backbone = dict(
|
||||
type='SwinTransformer',
|
||||
pretrain_img_size=384,
|
||||
embed_dims=192,
|
||||
depths=depths,
|
||||
num_heads=[6, 12, 24, 48],
|
||||
window_size=12,
|
||||
mlp_ratio=4,
|
||||
qkv_bias=True,
|
||||
qk_scale=None,
|
||||
drop_rate=0.,
|
||||
attn_drop_rate=0.,
|
||||
drop_path_rate=0.3,
|
||||
patch_norm=True,
|
||||
act_cfg=dict(type='GELU'), # ADD
|
||||
use_abs_pos_embed=False, # ADD
|
||||
# out_indices=(0, 1, 2, 3),
|
||||
with_cp=False,
|
||||
init_cfg=dict(type='Pretrained', checkpoint=pretrained_pth))
|
||||
|
||||
# 3.2. decode_head
|
||||
decode_head = dict(num_classes=num_classes, in_channels=in_channels,
|
||||
loss_cls=dict(class_weight=[1.0] * num_classes + [0.1]))
|
||||
|
||||
# 3.3. optimizer部分
|
||||
if selected_model_info['type'] == 'ResNet':
|
||||
# optimizer
|
||||
embed_multi = dict(lr_mult=1.0, decay_mult=0.0)
|
||||
optimizer = dict(
|
||||
type='AdamW', lr=0.0001, weight_decay=0.05, eps=1e-8, betas=(0.9, 0.999))
|
||||
optim_wrapper = dict(
|
||||
_delete_=True,
|
||||
type='OptimWrapper',
|
||||
optimizer=optimizer,
|
||||
clip_grad=dict(max_norm=0.01, norm_type=2),
|
||||
paramwise_cfg=dict(
|
||||
custom_keys={
|
||||
'backbone': dict(lr_mult=0.1, decay_mult=1.0),
|
||||
'query_embed': embed_multi,
|
||||
'query_feat': embed_multi,
|
||||
'level_embed': embed_multi,
|
||||
},
|
||||
norm_decay_mult=0.0))
|
||||
train_dataloader, batch_size, num_workers = generate_train_dataloader(batch_size_default=4, num_workers_default=4)
|
||||
elif selected_model_info['type'] == 'swin':
|
||||
# set all layers in backbone to lr_mult=0.1
|
||||
# set all norm layers, position_embeding,
|
||||
# query_embeding, level_embeding to decay_multi=0.0
|
||||
backbone_norm_multi = dict(lr_mult=0.1, decay_mult=0.0)
|
||||
backbone_embed_multi = dict(lr_mult=0.1, decay_mult=0.0)
|
||||
embed_multi = dict(lr_mult=1.0, decay_mult=0.0)
|
||||
custom_keys = {
|
||||
'backbone': dict(lr_mult=0.1, decay_mult=1.0),
|
||||
'backbone.patch_embed.norm': backbone_norm_multi,
|
||||
'backbone.norm': backbone_norm_multi,
|
||||
'absolute_pos_embed': backbone_embed_multi,
|
||||
'relative_position_bias_table': backbone_embed_multi,
|
||||
'query_embed': embed_multi,
|
||||
'query_feat': embed_multi,
|
||||
'level_embed': embed_multi
|
||||
}
|
||||
custom_keys.update({
|
||||
f'backbone.stages.{stage_id}.blocks.{block_id}.norm': backbone_norm_multi
|
||||
for stage_id, num_blocks in enumerate(depths)
|
||||
for block_id in range(num_blocks)
|
||||
})
|
||||
custom_keys.update({
|
||||
f'backbone.stages.{stage_id}.downsample.norm': backbone_norm_multi
|
||||
for stage_id in range(len(depths) - 1)
|
||||
})
|
||||
# optimizer
|
||||
optimizer = dict(
|
||||
type='AdamW', lr=0.0001, weight_decay=0.05, eps=1e-8, betas=(0.9, 0.999))
|
||||
optim_wrapper = dict(
|
||||
_delete_=True,
|
||||
type='OptimWrapper',
|
||||
optimizer=optimizer,
|
||||
clip_grad=dict(max_norm=0.01, norm_type=2),
|
||||
paramwise_cfg={'custom_keys':custom_keys, 'norm_decay_mult':0.0})
|
||||
|
||||
# 更新train_dataloader
|
||||
train_dataloader, batch_size, num_workers = generate_train_dataloader(batch_size_default=2, num_workers_default=2)
|
||||
|
||||
# 3.5. param_scheduler部分
|
||||
param_scheduler = generate_param_scheduler(train_time_or_epoch_k)
|
||||
|
||||
########### 3.生成各个部分 ###########
|
||||
# 综合model
|
||||
model = dict(
|
||||
backbone = backbone,
|
||||
data_preprocessor = model_data_preprocessor,
|
||||
decode_head = decode_head
|
||||
)
|
||||
|
||||
########### 4.程序写入 ###########
|
||||
# 算法名称解析:maskformer【Alg】 _ r50【pretrained模型深度】 - d8【_base_/models】 _ 4【GPU数量】 x b2【Batch Size大小】 - 40k【schedule】 _ cityscapes【数据集】 - 512x1024【crop_size】.py
|
||||
if selected_model_info['type'] == 'ResNet':
|
||||
alg_file_name = f"{alg_name}_r{depth}_{'Pre' if select_pretrained else 'NoPre'}_g{GPU_num}-{schedule_file_name.lstrip('schedule_')}_{dataset_file_name}-{crop_size[0]}x{crop_size[1]}.py"
|
||||
elif selected_model_info['type'] == 'swin':
|
||||
alg_file_name = f"{alg_name}_swin_{selected_model_info['size']}_{'Pre' if select_pretrained else 'NoPre'}_g{GPU_num}-{schedule_file_name.lstrip('schedule_')}_{dataset_file_name}-{crop_size[0]}x{crop_size[1]}.py"
|
||||
output_configs_alg_my_alg = os.path.join(f'./configs/maskformer/', alg_file_name)
|
||||
write_config_to_file(output_configs_alg_my_alg, _base_=_base_, norm_cfg=norm_cfg, crop_size=crop_size, data_preprocessor=data_preprocessor, model=model, train_dataloader=train_dataloader, optim_wrapper=optim_wrapper, param_scheduler=param_scheduler)
|
||||
|
||||
# 将信息写入文件
|
||||
with open("_temp_.txt", "w") as file:
|
||||
json.dump({'alg_infos':{'alg_file_name':alg_file_name,'alg_file_pth':output_configs_alg_my_alg}}, file)
|
||||
152
Seg_All_In_One_MMSeg/My_All_In_One/2_Alg_Program/my_pidnet.py
Normal file
152
Seg_All_In_One_MMSeg/My_All_In_One/2_Alg_Program/my_pidnet.py
Normal file
@@ -0,0 +1,152 @@
|
||||
import os, sys, argparse, json
|
||||
import importlib.util
|
||||
# 添加上级目录到 sys.path
|
||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../')))
|
||||
|
||||
from Initial_Alg_Program.Initial_Alg_Select_Tool import select_crop_size, select_pretrained_model, select_sampler
|
||||
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_Tool import format_all_data, write_config_to_file, create_dict_by_kwargs, get_var_from_file, update_list_dict_var, get_var_from_py_file
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_0_base_ import generate_base_config
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_1_norm_cfg import generate_norm_cfg
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_2_data_preprocessor import generate_data_preprocessor
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_3_model import generate_model_pretrained, generate_model_backbone, generate_model_data_preprocessor, generate_model_decode_head, generate_model_auxiliary_head, generate_model_train_cfg, generate_model_test_cfg
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_4_optimizer import generate_optim_wrapper, generate_param_scheduler
|
||||
|
||||
def get_train_pipeline_from_config(dataset_file_name: str):
|
||||
"""
|
||||
动态加载指定的配置文件并读取其中的 'train_pipeline' 列表。
|
||||
|
||||
Args:
|
||||
dataset_file_name (str): 数据集配置文件的名字 (不包含.py后缀)。
|
||||
|
||||
Returns:
|
||||
list | None: 如果成功找到,则返回 train_pipeline 列表;
|
||||
如果文件不存在或文件中没有 train_pipeline 变量,则返回 None。
|
||||
"""
|
||||
# 1. 构建配置文件的相对路径
|
||||
# 使用 os.path.join 来确保路径在不同操作系统上都是正确的
|
||||
config_path = os.path.join('configs/_base_/datasets/', f'{dataset_file_name}.py')
|
||||
|
||||
# 2. 检查文件是否存在
|
||||
if not os.path.exists(config_path):
|
||||
print(f"\033[30错误:配置文件不存在于 '{config_path}'\033[0m")
|
||||
return None
|
||||
|
||||
try:
|
||||
# 3. 动态加载 .py 文件作为一个模块
|
||||
# 创建一个模块规范 (spec)
|
||||
# 模块名可以是任意的,这里用文件名以防冲突
|
||||
spec = importlib.util.spec_from_file_location(dataset_file_name, config_path)
|
||||
|
||||
# 根据规范创建一个模块对象
|
||||
config_module = importlib.util.module_from_spec(spec)
|
||||
|
||||
# 执行模块代码,使其所有变量(如 train_pipeline)都加载到模块对象中
|
||||
spec.loader.exec_module(config_module)
|
||||
|
||||
# 4. 从加载的模块中获取 train_pipeline 变量
|
||||
# 使用 getattr 来安全地获取,如果不存在,可以设置一个默认值
|
||||
train_pipeline = getattr(config_module, 'train_pipeline', None)
|
||||
|
||||
if train_pipeline is None:
|
||||
print(f"错误:在 '{config_path}' 文件中未找到 'train_pipeline' 变量。")
|
||||
return None
|
||||
|
||||
return train_pipeline
|
||||
|
||||
except Exception as e:
|
||||
print(f"读取配置文件时发生未知错误: {e}")
|
||||
return None
|
||||
|
||||
if __name__ == '__main__':
|
||||
########### 1.超参数-算法默认参数 ###########
|
||||
alg_file_name = 'pidnet'
|
||||
alg_file_pth = f"configs/_base_/models/{alg_file_name}.py"
|
||||
|
||||
########### 2.参数解析 ###########
|
||||
parser = argparse.ArgumentParser(description="Run algorithm with specified parameters.")
|
||||
|
||||
# 添加命令行参数
|
||||
parser.add_argument('--alg_name', type=str, required=True, help="Algorithm name (e.g., 'ann')")
|
||||
parser.add_argument('--dataset_file_name', type=str, required=True, help="Dataset file name (e.g., 'my_dataset_model')")
|
||||
parser.add_argument('--mean', type=str, required=True, help="Mean of dataset (e.g., [1.2, 1.2, 1.2])")
|
||||
parser.add_argument('--std', type=str, required=True, help="Mean of dataset (e.g., [1.2, 1.2, 1.2])")
|
||||
parser.add_argument('--dataset_num_classes', type=int, required=True, help="Dataset class num (e.g., 36)")
|
||||
parser.add_argument('--GPU_num', type=int, required=True, help="Num of GPU to train (e.g., 1)")
|
||||
parser.add_argument('--schedule_file_name', type=str, required=True, help="Schedule file name (e.g., 'schedule_4k_check_400')")
|
||||
parser.add_argument('--train_type', type=str, required=True, help="Schedule train type (e.g., Epoch or Iteration)")
|
||||
parser.add_argument('--train_time_or_epoch_k', type=int, required=True, help="Schedule train time k (e.g., 4)")
|
||||
|
||||
# 解析命令行参数
|
||||
args = parser.parse_args()
|
||||
|
||||
# 变量赋值
|
||||
alg_name = args.alg_name
|
||||
dataset_file_name = args.dataset_file_name
|
||||
mean = json.loads(args.mean) # 使用json.loads转换成list
|
||||
std = json.loads(args.std) # 使用json.loads转换成list
|
||||
num_classes = args.dataset_num_classes
|
||||
schedule_file_name = args.schedule_file_name
|
||||
train_type = args.train_type
|
||||
train_time_or_epoch_k = args.train_time_or_epoch_k
|
||||
GPU_num = args.GPU_num
|
||||
|
||||
########### 3.1. 生成_base_部分 ###########
|
||||
_base_ = generate_base_config(alg_file_name=alg_file_name, dataset_file_name=dataset_file_name, schedule_file_name=schedule_file_name)
|
||||
|
||||
########### 3.1. 生成norm_cfg部分 ###########
|
||||
norm_cfg = generate_norm_cfg(GPU_num=GPU_num)
|
||||
|
||||
########### 3.1. 生成train_pipeline部分 ###########
|
||||
config_path = os.path.join('../_base_/datasets/', f'{dataset_file_name}.py')
|
||||
|
||||
train_pipeline = get_train_pipeline_from_config(dataset_file_name)
|
||||
train_pipeline.insert(-1, dict(type='GenerateEdge', edge_width=4)) # For pidnet
|
||||
train_dataloader = dict(dataset=dict(pipeline=train_pipeline))
|
||||
|
||||
########### 3.2. 生成data_preprocessor部分 ###########
|
||||
crop_size = select_crop_size(predefined_options=[(512, 512)]) # 选择切割大小
|
||||
data_preprocessor = generate_data_preprocessor(crop_size = crop_size, mean = mean, std = std)
|
||||
|
||||
########### 3.3. 生成model部分 ###########
|
||||
# 3.3.1. 预处理data_preprocessor
|
||||
model_data_preprocessor = data_preprocessor # 修改data_preprocessor
|
||||
|
||||
# # 3.3.2. 骨架backbone、解码器decode_head
|
||||
model_list = ['openmmlab/pidnet-s', 'openmmlab/pidnet-m', 'openmmlab/pidnet-l']
|
||||
selected_model_name, select_pretrained, pretrained_pth, selected_model_info = select_pretrained_model(model_list=model_list)
|
||||
if selected_model_name == 'openmmlab/pidnet-s':
|
||||
backbone = dict(channels=32, ppm_channels=96, num_stem_blocks=2, num_branch_blocks=3, init_cfg=dict(type='Pretrained', checkpoint=pretrained_pth))
|
||||
decode_head = dict(num_classes=num_classes, in_channels=128, channels=128)
|
||||
model_size = 'small'
|
||||
elif selected_model_name == 'openmmlab/pidnet-m':
|
||||
backbone = dict(channels=64, ppm_channels=96, num_stem_blocks=2, num_branch_blocks=3, init_cfg=dict(type='Pretrained', checkpoint=pretrained_pth))
|
||||
decode_head = dict(num_classes=num_classes, in_channels=256, channels=128)
|
||||
model_size = 'middle'
|
||||
elif selected_model_name == 'openmmlab/pidnet-l':
|
||||
backbone = dict(channels=64, ppm_channels=112, num_stem_blocks=3, num_branch_blocks=4, init_cfg=dict(type='Pretrained', checkpoint=pretrained_pth))
|
||||
decode_head = dict(num_classes=num_classes, in_channels=256, channels=256)
|
||||
model_size = 'large'
|
||||
else:
|
||||
quit("Error: 未知的模型名称")
|
||||
|
||||
# 3.3.4. 综合model
|
||||
model = dict(
|
||||
data_preprocessor = model_data_preprocessor,
|
||||
backbone = backbone,
|
||||
decode_head = decode_head,
|
||||
)
|
||||
|
||||
########### 3.4. 生成optim、param_scheduler部分 ###########
|
||||
optim_wrapper = generate_optim_wrapper()
|
||||
param_scheduler = generate_param_scheduler(train_type, train_time_or_epoch_k)
|
||||
|
||||
########### 4.程序写入 ###########
|
||||
# 4.1. 算法名称解析:bisenetv2【Alg】 _ r50【pretrained模型深度】 - d8【_base_/models】 _ 4【GPU数量】 x b2【Batch Size大小】 - 40k【schedule】 _ cityscapes【数据集】 - 512x1024【crop_size】.py
|
||||
alg_file_name = f"{alg_name}_{model_size}_g{GPU_num}-{schedule_file_name.lstrip('schedule_')}_{dataset_file_name}-{crop_size[0]}x{crop_size[1]}.py"
|
||||
output_configs_alg_my_alg = os.path.join(f'./configs/pidnet/', alg_file_name)
|
||||
write_config_to_file(output_configs_alg_my_alg, _base_=_base_, norm_cfg=norm_cfg, crop_size=crop_size, data_preprocessor=data_preprocessor, train_pipeline = train_pipeline, train_dataloader = train_dataloader, model=model, optim_wrapper=optim_wrapper, param_scheduler=param_scheduler)
|
||||
|
||||
# 4.2. 将信息临时写入文件
|
||||
with open("_temp_.txt", "w") as file:
|
||||
json.dump({'alg_infos':{'alg_file_name':alg_file_name,'alg_file_pth':output_configs_alg_my_alg}}, file)
|
||||
111
Seg_All_In_One_MMSeg/My_All_In_One/2_Alg_Program/my_pspnet.py
Normal file
111
Seg_All_In_One_MMSeg/My_All_In_One/2_Alg_Program/my_pspnet.py
Normal file
@@ -0,0 +1,111 @@
|
||||
import os, sys, argparse, json
|
||||
# 添加上级目录到 sys.path
|
||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../')))
|
||||
|
||||
from Initial_Alg_Program.Initial_Alg_Select_Tool import select_crop_size, select_pretrained_model, select_test_cfg_slide, select_sampler
|
||||
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_Tool import format_all_data, write_config_to_file, create_dict_by_kwargs
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_0_base_ import generate_base_config
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_1_norm_cfg import generate_norm_cfg
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_2_data_preprocessor import generate_data_preprocessor
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_3_model import generate_model_pretrained, generate_model_backbone, generate_model_data_preprocessor, generate_model_decode_head, generate_model_auxiliary_head, generate_model_train_cfg, generate_model_test_cfg
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_4_optimizer import generate_optim_wrapper, generate_param_scheduler
|
||||
|
||||
if __name__ == '__main__':
|
||||
########### 1.超参数-算法默认参数 ###########
|
||||
alg_file_name = 'pspnet_r50-d8'
|
||||
|
||||
########### 2.参数解析 ###########
|
||||
parser = argparse.ArgumentParser(description="Run algorithm with specified parameters.")
|
||||
|
||||
# 添加命令行参数
|
||||
parser.add_argument('--alg_name', type=str, required=True, help="Algorithm name (e.g., 'ann')")
|
||||
parser.add_argument('--dataset_file_name', type=str, required=True, help="Dataset file name (e.g., 'my_dataset_model')")
|
||||
parser.add_argument('--mean', type=str, required=True, help="Mean of dataset (e.g., [1.2, 1.2, 1.2])")
|
||||
parser.add_argument('--std', type=str, required=True, help="Mean of dataset (e.g., [1.2, 1.2, 1.2])")
|
||||
parser.add_argument('--dataset_num_classes', type=int, required=True, help="Dataset class num (e.g., 36)")
|
||||
parser.add_argument('--GPU_num', type=int, required=True, help="Num of GPU to train (e.g., 1)")
|
||||
parser.add_argument('--schedule_file_name', type=str, required=True, help="Schedule file name (e.g., 'schedule_4k_check_400')")
|
||||
parser.add_argument('--train_type', type=str, required=True, help="Schedule train type (e.g., Epoch or Iteration)")
|
||||
parser.add_argument('--train_time_or_epoch_k', type=int, required=True, help="Schedule train time k (e.g., 4)")
|
||||
|
||||
# 解析命令行参数
|
||||
args = parser.parse_args()
|
||||
|
||||
# 变量赋值
|
||||
alg_name = args.alg_name
|
||||
dataset_file_name = args.dataset_file_name
|
||||
mean = json.loads(args.mean) # 使用json.loads转换成list
|
||||
std = json.loads(args.std) # 使用json.loads转换成list
|
||||
num_classes = args.dataset_num_classes
|
||||
schedule_file_name = args.schedule_file_name
|
||||
train_type = args.train_type
|
||||
train_time_or_epoch_k = args.train_time_or_epoch_k
|
||||
GPU_num = args.GPU_num
|
||||
|
||||
|
||||
########### 3. 设定参数 ###########
|
||||
# 3.1. 自选参数
|
||||
crop_size = select_crop_size(predefined_options=[(512, 512), (512, 1024), (769, 769),(1280,1280)]) # 选择切割大小
|
||||
model_list = ['openmmlab/resnet18_v1c', 'openmmlab/resnet50_v1c', 'openmmlab/resnet101_v1c', 'torchvision://resnet18', 'torchvision://resnet50', 'torchvision://resnet101', ]
|
||||
selected_model_name, select_pretrained, pretrained_pth, selected_model_info = select_pretrained_model(model_list=model_list)
|
||||
depth = selected_model_info['depth']
|
||||
type_model = selected_model_info['type']
|
||||
|
||||
test_cfg_mode, test_cfg_crop_div_stride = select_test_cfg_slide(crop_size)
|
||||
|
||||
# 3.2. 自动生成参数
|
||||
# decode、auxiliary损失下载方式
|
||||
decode_head_loss_decode_dict=dict(type='DiceLoss', use_sigmoid=False, loss_weight=1.0) # decode损失函数 # TODO
|
||||
auxiliary_head_loss_decode_dict=dict(type='DiceLoss', use_sigmoid=False, loss_weight=0.4) # DiceLoss损失函数 # TODO
|
||||
|
||||
# decode_head_loss_decode_dict=dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0) # decode损失函数 # TODO
|
||||
# auxiliary_head_loss_decode_dict=dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4) # DiceLoss损失函数 # TODO
|
||||
|
||||
|
||||
########### 3.生成各个部分 ###########
|
||||
_base_ = generate_base_config(alg_file_name=alg_file_name, dataset_file_name=dataset_file_name, schedule_file_name=schedule_file_name)
|
||||
|
||||
norm_cfg = generate_norm_cfg(GPU_num=GPU_num)
|
||||
|
||||
data_preprocessor = generate_data_preprocessor(crop_size = crop_size, mean = mean, std = std)
|
||||
|
||||
pretrained = generate_model_pretrained(pretrained_pth = pretrained_pth)
|
||||
|
||||
backbone = create_dict_by_kwargs(depth=depth, type=type_model) # generate_model_backbone(depth=depth,)
|
||||
|
||||
model_data_preprocessor = data_preprocessor # 修改data_preprocessor
|
||||
model_data_preprocessor = generate_model_data_preprocessor(model_data_preprocessor=model_data_preprocessor)
|
||||
|
||||
if depth == 18:
|
||||
decode_head = create_dict_by_kwargs(in_channels=512, channels=128, num_classes=num_classes, loss_decode=decode_head_loss_decode_dict, norm_cfg=norm_cfg,)
|
||||
elif depth == 101 or depth == 50:
|
||||
decode_head = create_dict_by_kwargs(in_channels=2048, channels=512, num_classes=num_classes, loss_decode=decode_head_loss_decode_dict, norm_cfg=norm_cfg,)
|
||||
|
||||
if depth == 18:
|
||||
auxiliary_head = create_dict_by_kwargs(in_channels=256, channels=64, num_classes=num_classes, loss_decode=decode_head_loss_decode_dict, norm_cfg=norm_cfg,)
|
||||
elif depth == 101 or depth == 50:
|
||||
auxiliary_head = create_dict_by_kwargs(in_channels=1024, channels=256, num_classes=num_classes, loss_decode=decode_head_loss_decode_dict, norm_cfg=norm_cfg,)
|
||||
|
||||
# 综合model
|
||||
model = dict(
|
||||
pretrained = pretrained,
|
||||
backbone = backbone,
|
||||
data_preprocessor = model_data_preprocessor,
|
||||
decode_head = decode_head,
|
||||
auxiliary_head = auxiliary_head,
|
||||
)
|
||||
|
||||
test_cfg = generate_model_test_cfg(test_cfg_mode=test_cfg_mode, crop_size=crop_size, test_cfg_crop_div_stride=test_cfg_crop_div_stride)
|
||||
optim_wrapper = generate_optim_wrapper()
|
||||
param_scheduler = generate_param_scheduler(train_type, train_time_or_epoch_k)
|
||||
|
||||
########### 4.程序写入 ###########
|
||||
# 算法名称解析:pspnet【Alg】 _ r50【pretrained模型深度】 - d8【_base_/models】 _ 4【GPU数量】 x b2【Batch Size大小】 - 40k【schedule】 _ cityscapes【数据集】 - 512x1024【crop_size】.py
|
||||
alg_file_name = f"{alg_name}_r{depth}_{'Pre' if select_pretrained else 'NoPre'}_g{GPU_num}-{schedule_file_name.lstrip('schedule_')}_{dataset_file_name}-{crop_size[0]}x{crop_size[1]}.py"
|
||||
output_configs_alg_my_alg = os.path.join(f'./configs/pspnet/', alg_file_name)
|
||||
write_config_to_file(output_configs_alg_my_alg, _base_=_base_, norm_cfg=norm_cfg, crop_size=crop_size, data_preprocessor=data_preprocessor, model=model, test_cfg=test_cfg, optim_wrapper=optim_wrapper, param_scheduler=param_scheduler)
|
||||
|
||||
# 将信息写入文件
|
||||
with open("_temp_.txt", "w") as file:
|
||||
json.dump({'alg_infos':{'alg_file_name':alg_file_name,'alg_file_pth':output_configs_alg_my_alg}}, file)
|
||||
139
Seg_All_In_One_MMSeg/My_All_In_One/2_Alg_Program/my_stdc.py
Normal file
139
Seg_All_In_One_MMSeg/My_All_In_One/2_Alg_Program/my_stdc.py
Normal file
@@ -0,0 +1,139 @@
|
||||
import os, sys, argparse, json
|
||||
import importlib.util
|
||||
# 添加上级目录到 sys.path
|
||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../')))
|
||||
|
||||
from Initial_Alg_Program.Initial_Alg_Select_Tool import select_crop_size, select_pretrained_model, select_sampler
|
||||
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_Tool import format_all_data, write_config_to_file, create_dict_by_kwargs, get_var_from_file, update_list_dict_var, get_var_from_py_file
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_0_base_ import generate_base_config
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_1_norm_cfg import generate_norm_cfg
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_2_data_preprocessor import generate_data_preprocessor
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_3_model import generate_model_pretrained, generate_model_backbone, generate_model_data_preprocessor, generate_model_decode_head, generate_model_auxiliary_head, generate_model_train_cfg, generate_model_test_cfg
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_4_optimizer import generate_optim_wrapper, generate_param_scheduler
|
||||
|
||||
def get_pretrained_model_choice():
|
||||
"""
|
||||
提示用户选择是否使用预训练模型,并返回他们的选择。
|
||||
|
||||
返回:
|
||||
str: 如果用户选择 '是' 或默认,则返回 '是'。
|
||||
如果用户选择 '否',则返回 '否'。
|
||||
"""
|
||||
while True:
|
||||
# 如果用户直接按回车,input()返回空字符串,or "1" 使其默认值为 "1"
|
||||
choice = input("可用的预训练模型选项:\n1. 是\n2. 否\n请选择是否使用预训练模型(默认1): ") or "1"
|
||||
|
||||
if choice == "1":
|
||||
print("您已选择:1. 是")
|
||||
return True
|
||||
elif choice == "2":
|
||||
print("您已选择:2. 否")
|
||||
return False
|
||||
else:
|
||||
# 如果输入了其他无效内容(如"3"),则提示重新输入
|
||||
print("\n无效输入,请输入 1 或 2。\n")
|
||||
|
||||
if __name__ == '__main__':
|
||||
########### 1.超参数-算法默认参数 ###########
|
||||
alg_file_name = 'stdc'
|
||||
alg_file_pth = f"configs/_base_/models/{alg_file_name}.py"
|
||||
|
||||
########### 2.参数解析 ###########
|
||||
parser = argparse.ArgumentParser(description="Run algorithm with specified parameters.")
|
||||
|
||||
# 添加命令行参数
|
||||
parser.add_argument('--alg_name', type=str, required=True, help="Algorithm name (e.g., 'ann')")
|
||||
parser.add_argument('--dataset_file_name', type=str, required=True, help="Dataset file name (e.g., 'my_dataset_model')")
|
||||
parser.add_argument('--mean', type=str, required=True, help="Mean of dataset (e.g., [1.2, 1.2, 1.2])")
|
||||
parser.add_argument('--std', type=str, required=True, help="Mean of dataset (e.g., [1.2, 1.2, 1.2])")
|
||||
parser.add_argument('--dataset_num_classes', type=int, required=True, help="Dataset class num (e.g., 36)")
|
||||
parser.add_argument('--GPU_num', type=int, required=True, help="Num of GPU to train (e.g., 1)")
|
||||
parser.add_argument('--schedule_file_name', type=str, required=True, help="Schedule file name (e.g., 'schedule_4k_check_400')")
|
||||
parser.add_argument('--train_type', type=str, required=True, help="Schedule train type (e.g., Epoch or Iteration)")
|
||||
parser.add_argument('--train_time_or_epoch_k', type=int, required=True, help="Schedule train time k (e.g., 4)")
|
||||
|
||||
# 解析命令行参数
|
||||
args = parser.parse_args()
|
||||
|
||||
# 变量赋值
|
||||
alg_name = args.alg_name
|
||||
dataset_file_name = args.dataset_file_name
|
||||
mean = json.loads(args.mean) # 使用json.loads转换成list
|
||||
std = json.loads(args.std) # 使用json.loads转换成list
|
||||
num_classes = args.dataset_num_classes
|
||||
schedule_file_name = args.schedule_file_name
|
||||
train_type = args.train_type
|
||||
train_time_or_epoch_k = args.train_time_or_epoch_k
|
||||
GPU_num = args.GPU_num
|
||||
|
||||
########### 3.1. 生成_base_部分 ###########
|
||||
_base_ = generate_base_config(alg_file_name=alg_file_name, dataset_file_name=dataset_file_name, schedule_file_name=schedule_file_name)
|
||||
|
||||
########### 3.1. 生成norm_cfg部分 ###########
|
||||
norm_cfg = generate_norm_cfg(GPU_num=GPU_num)
|
||||
|
||||
########### 3.2. 生成data_preprocessor部分 ###########
|
||||
crop_size = select_crop_size(predefined_options=[(512, 512)]) # 选择切割大小
|
||||
data_preprocessor = generate_data_preprocessor(crop_size = crop_size, mean = mean, std = std)
|
||||
|
||||
########### 3.3. 生成model部分 ###########
|
||||
# 3.3.1. 预处理data_preprocessor
|
||||
model_data_preprocessor = data_preprocessor # 修改data_preprocessor
|
||||
|
||||
# # 3.3.2. 骨架backbone、解码器decode_head
|
||||
decode_head_loss_decode=dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0) # Way Ori
|
||||
decode_head_loss_decode=dict(type='DiceLoss', use_sigmoid=False, loss_weight=1.0) # Way 1
|
||||
|
||||
# auxiliary_head_loss_decode_dict=dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0) # Way Ori: DiceLoss损失函数 # TODO
|
||||
auxiliary_head_loss_decode_dict=dict(type='DiceLoss', use_sigmoid=False, loss_weight=1.0) # Way 1: DiceLoss损失函数 # TODO
|
||||
|
||||
auxiliary_head = get_var_from_py_file(alg_file_pth, 'model')['auxiliary_head']
|
||||
for i in range(len(auxiliary_head)):
|
||||
if auxiliary_head[i]['type'] == 'FCNHead':
|
||||
auxiliary_head[i].update(num_classes=num_classes, loss_decode=auxiliary_head_loss_decode_dict)
|
||||
|
||||
model_list = ["openmmlab/stdc1", "openmmlab/stdc2"]
|
||||
selected_model_name, select_pretrained, pretrained_pth, selected_model_info = select_pretrained_model(model_list=model_list)
|
||||
use_pretrained = get_pretrained_model_choice()
|
||||
if selected_model_name == "openmmlab/stdc1":
|
||||
if use_pretrained:
|
||||
backbone = dict(backbone_cfg=dict(stdc_type='STDCNet1'), init_cfg=dict(type='Pretrained', checkpoint=pretrained_pth))
|
||||
Pre = "Pre"
|
||||
else:
|
||||
backbone = dict(backbone_cfg=dict(stdc_type='STDCNet1'))
|
||||
Pre = "NoPre"
|
||||
decode_head = dict(num_classes=num_classes, loss_decode=decode_head_loss_decode)
|
||||
model_size = 'V1_'+Pre
|
||||
elif selected_model_name == "openmmlab/stdc2":
|
||||
if use_pretrained:
|
||||
backbone = dict(backbone_cfg=dict(stdc_type='STDCNet2'), init_cfg=dict(type='Pretrained', checkpoint=pretrained_pth))
|
||||
Pre = "Pre"
|
||||
else:
|
||||
backbone = dict(backbone_cfg=dict(stdc_type='STDCNet2'))
|
||||
Pre = "NoPre"
|
||||
decode_head = dict(num_classes=num_classes, loss_decode=decode_head_loss_decode)
|
||||
model_size = 'V2_'+Pre
|
||||
else:
|
||||
quit("Error: 未知的模型名称")
|
||||
|
||||
# 3.3.4. 综合model
|
||||
model = dict(
|
||||
data_preprocessor = model_data_preprocessor,
|
||||
backbone = backbone,
|
||||
decode_head = decode_head,
|
||||
)
|
||||
|
||||
########### 3.4. 生成optim、param_scheduler部分 ###########
|
||||
optim_wrapper = generate_optim_wrapper()
|
||||
param_scheduler = generate_param_scheduler(train_type, train_time_or_epoch_k)
|
||||
|
||||
########### 4.程序写入 ###########
|
||||
# 4.1. 算法名称解析:bisenetv2【Alg】 _ r50【pretrained模型深度】 - d8【_base_/models】 _ 4【GPU数量】 x b2【Batch Size大小】 - 40k【schedule】 _ cityscapes【数据集】 - 512x1024【crop_size】.py
|
||||
alg_file_name = f"{alg_name}_{model_size}_g{GPU_num}-{schedule_file_name.lstrip('schedule_')}_{dataset_file_name}-{crop_size[0]}x{crop_size[1]}.py"
|
||||
output_configs_alg_my_alg = os.path.join(f'./configs/stdc/', alg_file_name)
|
||||
write_config_to_file(output_configs_alg_my_alg, _base_=_base_, norm_cfg=norm_cfg, crop_size=crop_size, data_preprocessor=data_preprocessor, model=model, optim_wrapper=optim_wrapper, param_scheduler=param_scheduler)
|
||||
|
||||
# 4.2. 将信息临时写入文件
|
||||
with open("_temp_.txt", "w") as file:
|
||||
json.dump({'alg_infos':{'alg_file_name':alg_file_name,'alg_file_pth':output_configs_alg_my_alg}}, file)
|
||||
@@ -0,0 +1,105 @@
|
||||
import os, sys, argparse, json
|
||||
# 添加上级目录到 sys.path
|
||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../')))
|
||||
|
||||
from Initial_Alg_Program.Initial_Alg_Select_Tool import select_crop_size, select_pretrained_model, select_test_cfg_slide, select_sampler
|
||||
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_Tool import format_all_data, write_config_to_file, create_dict_by_kwargs
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_0_base_ import generate_base_config
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_1_norm_cfg import generate_norm_cfg
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_2_data_preprocessor import generate_data_preprocessor
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_3_model import generate_model_pretrained, generate_model_backbone, generate_model_data_preprocessor, generate_model_decode_head, generate_model_auxiliary_head, generate_model_train_cfg, generate_model_test_cfg
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_4_optimizer import generate_optim_wrapper, generate_param_scheduler
|
||||
|
||||
if __name__ == '__main__':
|
||||
########### 1.超参数-算法默认参数 ###########
|
||||
alg_file_name = 'deeplabv3_unet_s5-d16'
|
||||
|
||||
########### 2.参数解析 ###########
|
||||
parser = argparse.ArgumentParser(description="Run algorithm with specified parameters.")
|
||||
|
||||
# 添加命令行参数
|
||||
parser.add_argument('--alg_name', type=str, required=True, help="Algorithm name (e.g., 'ann')")
|
||||
parser.add_argument('--dataset_file_name', type=str, required=True, help="Dataset file name (e.g., 'my_dataset_model')")
|
||||
parser.add_argument('--mean', type=str, required=True, help="Mean of dataset (e.g., [1.2, 1.2, 1.2])")
|
||||
parser.add_argument('--std', type=str, required=True, help="Mean of dataset (e.g., [1.2, 1.2, 1.2])")
|
||||
parser.add_argument('--dataset_num_classes', type=int, required=True, help="Dataset class num (e.g., 36)")
|
||||
parser.add_argument('--GPU_num', type=int, required=True, help="Num of GPU to train (e.g., 1)")
|
||||
parser.add_argument('--schedule_file_name', type=str, required=True, help="Schedule file name (e.g., 'schedule_4k_check_400')")
|
||||
parser.add_argument('--train_type', type=str, required=True, help="Schedule train type (e.g., Epoch or Iteration)")
|
||||
parser.add_argument('--train_time_or_epoch_k', type=int, required=True, help="Schedule train time k (e.g., 4)")
|
||||
|
||||
# 解析命令行参数
|
||||
args = parser.parse_args()
|
||||
|
||||
# 变量赋值
|
||||
alg_name = args.alg_name
|
||||
dataset_file_name = args.dataset_file_name
|
||||
mean = json.loads(args.mean) # 使用json.loads转换成list
|
||||
std = json.loads(args.std) # 使用json.loads转换成list
|
||||
num_classes = args.dataset_num_classes
|
||||
schedule_file_name = args.schedule_file_name
|
||||
train_type = args.train_type
|
||||
train_time_or_epoch_k = args.train_time_or_epoch_k
|
||||
GPU_num = args.GPU_num
|
||||
|
||||
|
||||
########### 3. 设定参数 ###########
|
||||
# 3.1. 自选参数
|
||||
crop_size = select_crop_size(predefined_options=[(128, 128), (256, 256)]) # 选择切割大小
|
||||
test_cfg_mode, test_cfg_crop_div_stride = select_test_cfg_slide(crop_size, select_slide=True) # 默认选择滑动模式
|
||||
|
||||
# 3.2. 自动生成参数
|
||||
# decode、auxiliary损失下载方式
|
||||
decode_head_loss_decode_dict=dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0) # Way Ori: decode损失函数 # TODO
|
||||
decode_head_loss_decode_dict=dict(type='DiceLoss', use_sigmoid=False, loss_weight=1.0) # Way 1: decode损失函数 # TODO
|
||||
decode_head_loss_decode_dict = [
|
||||
dict(type='CrossEntropyLoss', loss_name='loss_ce', loss_weight=1.0),
|
||||
dict(type='DiceLoss', loss_name='loss_dice', loss_weight=3.0)]
|
||||
auxiliary_head_loss_decode_dict=dict(type='DiceLoss', use_sigmoid=False, loss_weight=0.4) # Way 1: DiceLoss损失函数 # TODO
|
||||
|
||||
# 开启align_corners
|
||||
if test_cfg_mode == "slide":
|
||||
align_corners = True
|
||||
test_slide = 'testslide' # 算法名进行标记
|
||||
|
||||
|
||||
########### 3.生成各个部分 ###########
|
||||
_base_ = generate_base_config(alg_file_name=alg_file_name, dataset_file_name=dataset_file_name, schedule_file_name=schedule_file_name)
|
||||
|
||||
norm_cfg = generate_norm_cfg(GPU_num=GPU_num)
|
||||
|
||||
data_preprocessor = generate_data_preprocessor(crop_size = crop_size, mean = mean, std = std)
|
||||
|
||||
model_data_preprocessor = data_preprocessor # 修改data_preprocessor
|
||||
model_data_preprocessor = generate_model_data_preprocessor(model_data_preprocessor=model_data_preprocessor)
|
||||
|
||||
# TODO TODO decode_head_loss_decode_dict
|
||||
decode_head = create_dict_by_kwargs(num_classes=num_classes, loss_decode=decode_head_loss_decode_dict, align_corners=align_corners, norm_cfg=norm_cfg,)
|
||||
|
||||
# auxiliary_head_loss_decode=dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)
|
||||
auxiliary_head = create_dict_by_kwargs(num_classes=num_classes, norm_cfg=norm_cfg, loss_decode=auxiliary_head_loss_decode_dict)
|
||||
|
||||
test_cfg = generate_model_test_cfg(test_cfg_mode=test_cfg_mode, crop_size=crop_size, test_cfg_crop_div_stride=test_cfg_crop_div_stride)
|
||||
|
||||
# 综合model
|
||||
model = dict(
|
||||
data_preprocessor = model_data_preprocessor,
|
||||
decode_head = decode_head,
|
||||
test_cfg = test_cfg,
|
||||
auxiliary_head = auxiliary_head,
|
||||
)
|
||||
|
||||
|
||||
optim_wrapper = generate_optim_wrapper()
|
||||
param_scheduler = generate_param_scheduler(train_type, train_time_or_epoch_k)
|
||||
|
||||
########### 4.程序写入 ###########
|
||||
# 算法名称解析:unet【Alg】 _ r50【pretrained模型深度】 - d8【_base_/models】 _ 4【GPU数量】 x b2【Batch Size大小】 - 40k【schedule】 _ cityscapes【数据集】 - 512x1024【crop_size】.py
|
||||
alg_file_name = f"{alg_name}_g{GPU_num}-{schedule_file_name.lstrip('schedule_')}_{dataset_file_name}-{crop_size[0]}x{crop_size[1]}-{test_slide}.py"
|
||||
output_configs_alg_my_alg = os.path.join(f'./configs/unet/', alg_file_name)
|
||||
write_config_to_file(output_configs_alg_my_alg, _base_=_base_, norm_cfg=norm_cfg, crop_size=crop_size, data_preprocessor=data_preprocessor, model=model, optim_wrapper=optim_wrapper, param_scheduler=param_scheduler)
|
||||
|
||||
# 将信息写入文件
|
||||
with open("_temp_.txt", "w") as file:
|
||||
json.dump({'alg_infos':{'alg_file_name':alg_file_name,'alg_file_pth':output_configs_alg_my_alg}}, file)
|
||||
@@ -0,0 +1,70 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: UTF-8 -*-
|
||||
import os
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_Tool import format_all_data, write_config_to_file
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_0_base_ import generate_base_config
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_1_norm_cfg import generate_norm_cfg
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_2_data_preprocessor import generate_data_preprocessor
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_3_model import generate_model_pretrained, generate_model_backbone, generate_model_data_preprocessor, generate_model_decode_head, generate_model_auxiliary_head, generate_model_train_cfg, generate_model_test_cfg
|
||||
|
||||
if __name__ == '__main__':
|
||||
########### 1.1.定义算法基本参数 ###########
|
||||
# A. generate_base_config
|
||||
alg_name = 'ann' # ./configs中算法简称
|
||||
alg_file_name = 'ann_r50-d8' # 算法根文件
|
||||
dataset_file_name = 'my_dataset_model' # 数据文件
|
||||
schedule_file_name = 'schedule_4k_check_400' # schedule文件
|
||||
_base_ = generate_base_config(alg_file_name=alg_file_name, dataset_file_name=dataset_file_name, schedule_file_name=schedule_file_name)
|
||||
|
||||
# B. generate_norm_cfg
|
||||
GPU_num = 1 # GPU数量
|
||||
norm_cfg = generate_norm_cfg(GPU_num=GPU_num)
|
||||
|
||||
# C. generate_data_preprocessor
|
||||
crop_size = (512,512) # 分割大小
|
||||
data_preprocessor = generate_data_preprocessor(crop_size = crop_size)
|
||||
|
||||
# D. generate_model
|
||||
# D.1. pretrained
|
||||
pretrained_pth = './My_Local_Model/open_mmlab/resnet50_v1c.pth' # 预训练模型位置
|
||||
pretrained = generate_model_pretrained(pretrained_pth = pretrained_pth)
|
||||
|
||||
# D.2. backbone
|
||||
depth = 50 # 模型深度
|
||||
backbone = generate_model_backbone(depth=depth)
|
||||
|
||||
# D.3. data_preprocessor
|
||||
model_data_preprocessor = data_preprocessor # 修改data_preprocessor
|
||||
model_data_preprocessor = generate_model_data_preprocessor(model_data_preprocessor=model_data_preprocessor)
|
||||
|
||||
# D.4. decode_head、auxiliary_head
|
||||
num_classes=36 # 分类数
|
||||
decode_head_loss_decode_dict=dict(type='DiceLoss', use_sigmoid=False, loss_weight=1.0) # decode损失函数
|
||||
align_corners=False # 是否需要角对齐
|
||||
auxiliary_head_loss_decode_dict=dict(type='DiceLoss', use_sigmoid=False, loss_weight=0.4) # DiceLoss损失函数
|
||||
decode_head = generate_model_decode_head(num_classes=num_classes, decode_head_loss_decode_dict=decode_head_loss_decode_dict, align_corners=align_corners)
|
||||
auxiliary_head = generate_model_auxiliary_head(num_classes=num_classes, auxiliary_head_loss_decode_dict=auxiliary_head_loss_decode_dict, align_corners=align_corners)
|
||||
|
||||
# D.5. train_cfg
|
||||
# train_cfg = generate_model_train_cfg()
|
||||
|
||||
# D.6. test_cfg
|
||||
# test_cfg = generate_model_test_cfg()
|
||||
|
||||
# E. 综合model
|
||||
model = dict(
|
||||
pretrained = pretrained,
|
||||
backbone = backbone,
|
||||
data_preprocessor = model_data_preprocessor,
|
||||
decode_head = decode_head,
|
||||
auxiliary_head = auxiliary_head,
|
||||
# train_cfg = train_cfg,
|
||||
# test_cfg = test_cfg,
|
||||
)
|
||||
|
||||
########### 2.文件存储 ###########
|
||||
# output_configs_alg_my_alg = os.path.join(f'my_{alg_name}.py')
|
||||
output_configs_alg_my_alg = os.path.join(f'./configs/{alg_name}/', f'my_{alg_name}.py')
|
||||
|
||||
write_config_to_file(output_configs_alg_my_alg, _base_=_base_, crop_size=crop_size, data_preprocessor=data_preprocessor, model=model)
|
||||
|
||||
@@ -0,0 +1,301 @@
|
||||
import os, json, subprocess
|
||||
from Initial_Schedule_Program.Initial_Train_Gen_configs_base_schedules_schedule_XXk import generate_times_configs_base_schedules_schedule_file
|
||||
from Initial_Schedule_Program.Initial_Train_Gen_configs_base_schedules_schedule_XXe import generate_epochs_configs_base_schedules_schedule_file
|
||||
from datetime import datetime
|
||||
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_Tool import format_all_data, write_config_to_file, get_gpu_info
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_0_base_ import generate_base_config
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_1_norm_cfg import generate_norm_cfg
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_2_data_preprocessor import generate_data_preprocessor
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_3_model import generate_model_pretrained, generate_model_backbone, generate_model_data_preprocessor, generate_model_decode_head, generate_model_auxiliary_head, generate_model_train_cfg, generate_model_test_cfg
|
||||
|
||||
def load_json_file(file_name):
|
||||
"""
|
||||
读取并返回指定 JSON 文件的内容。
|
||||
|
||||
:param file_name: 要读取的 JSON 文件路径
|
||||
:return: JSON 文件的内容作为字典或列表
|
||||
"""
|
||||
try:
|
||||
# 打开并读取 JSON 文件
|
||||
with open(file_name, 'r', encoding='utf-8') as f:
|
||||
data = json.load(f)
|
||||
return data
|
||||
except FileNotFoundError:
|
||||
print(f"\033[91mError: File {file_name} not found.\033[0m")
|
||||
return None
|
||||
except json.JSONDecodeError:
|
||||
print(f"\033[91mError: Failed to decode JSON from {file_name}.\033[0m")
|
||||
return None
|
||||
except Exception as e:
|
||||
print(f"\033[91mAn error occurred while reading {file_name}: {str(e)}\033[0m")
|
||||
return None
|
||||
|
||||
def process_data_record_json_data(json_data):
|
||||
"""
|
||||
将 JSON 数据直接存入一个大的字典 all_data_record。
|
||||
:param json_data: 从 JSON 文件加载的数据
|
||||
:return: all_data_record 字典
|
||||
"""
|
||||
all_data_record = {}
|
||||
|
||||
# 遍历 json_data 中的每个键(即每个数据集)
|
||||
for dataset_name, dataset_info in json_data.items():
|
||||
all_data_record[dataset_name] = {
|
||||
"classes": dataset_info['classes'],
|
||||
"palette": dataset_info['palette'],
|
||||
"palette_num": dataset_info['palette_num'],
|
||||
"mean": dataset_info['mean'],
|
||||
"std": dataset_info['std'],
|
||||
}
|
||||
|
||||
return all_data_record
|
||||
|
||||
# 选择要处理的数据集
|
||||
def select_dataset(all_data_record):
|
||||
"""
|
||||
让用户从 all_data_record 中选择数据集,并返回 dataset_file_name 和对应的 palette_num。
|
||||
|
||||
:param all_data_record: 包含所有数据集信息的字典
|
||||
:return: 选定的 dataset_file_name 和 palette_num
|
||||
"""
|
||||
# 获取所有数据集的名称
|
||||
dataset_names = list(all_data_record.keys())
|
||||
|
||||
# 显示可用的数据集,并让用户选择
|
||||
print("选择可用数据集:")
|
||||
for i, name in enumerate(dataset_names):
|
||||
print(f"{i + 1}. {name} - {all_data_record[name]['palette_num']}类")
|
||||
|
||||
# 用户输入选择的数据集编号
|
||||
while True:
|
||||
try:
|
||||
selection = int(input("请选择数据集编号(输入数字):")) - 1
|
||||
if 0 <= selection < len(dataset_names):
|
||||
dataset_file_name = dataset_names[selection]
|
||||
break
|
||||
else:
|
||||
print(f"输入的编号无效,请输入 1 到 {len(dataset_names)} 之间的数字。")
|
||||
except ValueError:
|
||||
print("无效输入,请输入数字。")
|
||||
|
||||
# 获取对应的 palette_num
|
||||
palette_num = all_data_record[dataset_file_name]['palette_num']
|
||||
mean = all_data_record[dataset_file_name]['mean']
|
||||
std = all_data_record[dataset_file_name]['std']
|
||||
|
||||
print(f" 已选择数据集: {dataset_file_name} ,其对应的分类数为: {palette_num}")
|
||||
|
||||
return dataset_file_name, palette_num, mean, std
|
||||
|
||||
# 选择对应的算法
|
||||
def select_alg(alg_directory):
|
||||
"""
|
||||
选择一个算法文件并运行它。
|
||||
|
||||
:return: 选定的算法文件名和相对路径
|
||||
"""
|
||||
# 获取 ./Alg 目录下的所有 Python 文件
|
||||
algorithms = [f for f in os.listdir(alg_directory) if f.endswith('.py')]
|
||||
|
||||
# 显示可用的算法文件
|
||||
print("选择可用算法:")
|
||||
for i, alg in enumerate(algorithms):
|
||||
print(f"{i + 1}. {alg}")
|
||||
|
||||
# 用户选择算法
|
||||
while True:
|
||||
try:
|
||||
selection = int(input("请选择算法编号(输入数字):")) - 1
|
||||
if 0 <= selection < len(algorithms):
|
||||
selected_alg = algorithms[selection]
|
||||
break
|
||||
else:
|
||||
print(f"输入的编号无效,请输入 1 到 {len(algorithms)} 之间的数字。")
|
||||
except ValueError:
|
||||
print("无效输入,请输入数字。")
|
||||
|
||||
# 生成选定算法的相对路径
|
||||
relative_alg_path = os.path.join(alg_directory, selected_alg)
|
||||
print(f" 已选择算法: {selected_alg} {relative_alg_path}")
|
||||
|
||||
return selected_alg, relative_alg_path
|
||||
|
||||
# 选择计算用的GPU
|
||||
def select_GPU():
|
||||
"""
|
||||
让用户选择要使用的 GPU 数量,并返回相应的 GPU 列表(以 0,1 格式输出)。
|
||||
如果用户没有输入任何选择,默认为选择一块 GPU,编号为 0。
|
||||
:return: 用户选择的 GPU 列表和数量
|
||||
"""
|
||||
# 获取 GPU 信息
|
||||
gpu_info = get_gpu_info()
|
||||
|
||||
if not gpu_info:
|
||||
return None, 0
|
||||
|
||||
# 显示 GPU 信息
|
||||
print("可用的 GPU 列表:")
|
||||
for idx, mem_free in gpu_info:
|
||||
print(f"GPU {idx}: 剩余显存 {mem_free} MB")
|
||||
|
||||
# 提供默认选择 GPU 个数和编号
|
||||
default_num_gpus = 1
|
||||
default_gpu_idx = 0
|
||||
|
||||
# 用户选择 GPU 个数(默认选择 1)
|
||||
try:
|
||||
num_gpus = input(f"请选择使用的 GPU 个数 (1-{len(gpu_info)}, 默认为 1): ")
|
||||
num_gpus = int(num_gpus) if num_gpus else default_num_gpus
|
||||
if not (1 <= num_gpus <= len(gpu_info)):
|
||||
print(f"输入无效,使用默认值 1 个 GPU。")
|
||||
num_gpus = default_num_gpus
|
||||
except ValueError:
|
||||
print(f"无效输入,使用默认值 1 个 GPU。")
|
||||
num_gpus = default_num_gpus
|
||||
|
||||
# 用户选择 GPU 编号(如果选择多个 GPU,逐个选择编号)
|
||||
selected_gpus = []
|
||||
for i in range(num_gpus):
|
||||
try:
|
||||
gpu_idx = input(f"请输入要使用的第 {i + 1} 个 GPU 的编号 (默认为 {default_gpu_idx}): ")
|
||||
gpu_idx = int(gpu_idx) if gpu_idx else default_gpu_idx
|
||||
if gpu_idx in [idx for idx, _ in gpu_info] and gpu_idx not in selected_gpus:
|
||||
selected_gpus.append(gpu_idx)
|
||||
else:
|
||||
print(f"无效输入,使用默认 GPU {default_gpu_idx}")
|
||||
selected_gpus.append(default_gpu_idx)
|
||||
except ValueError:
|
||||
print(f"无效输入,使用默认 GPU {default_gpu_idx}")
|
||||
selected_gpus.append(default_gpu_idx)
|
||||
|
||||
# 返回以 0,1 格式的 GPU 列表和 GPU 数量
|
||||
gpu_list_str = ','.join(map(str, selected_gpus))
|
||||
print(f"\n已选择 GPU: {gpu_list_str},共 {num_gpus} 块 GPU")
|
||||
|
||||
return gpu_list_str, num_gpus
|
||||
|
||||
# 选择训练批次相关信息
|
||||
def select_schedule():
|
||||
"""
|
||||
选择训练时间、验证次数、日志间隔,并生成训练计划配置文件。
|
||||
提供默认值:train_time_k=40, check_num=10, loggerhook_interval=50
|
||||
"""
|
||||
# 交互式输入 train_time_k,默认值 40
|
||||
try:
|
||||
train_time_k = input("请输入训练次数(k为单位,默认40k):").strip()
|
||||
train_time_k = int(train_time_k) if train_time_k else 40
|
||||
except ValueError:
|
||||
print("无效输入,使用默认训练时间 40k")
|
||||
train_time_k = 40
|
||||
|
||||
# 交互式输入 check_num,默认值 10
|
||||
try:
|
||||
check_num = input("请输入检查点数量(默认10):").strip()
|
||||
check_num = int(check_num) if check_num else 10
|
||||
except ValueError:
|
||||
print("无效输入,使用默认检查点数量 20")
|
||||
check_num = 20
|
||||
|
||||
# 交互式输入 loggerhook_interval,默认值 50
|
||||
try:
|
||||
loggerhook_interval = input("请输入日志间隔(默认50次迭代):").strip()
|
||||
loggerhook_interval = int(loggerhook_interval) if loggerhook_interval else 50
|
||||
except ValueError:
|
||||
print("无效输入,使用默认日志间隔 50 次迭代")
|
||||
loggerhook_interval = 50
|
||||
|
||||
# 计算验证比例和检查点间隔
|
||||
val_proportion = 1 / check_num # 验证比例
|
||||
checkpoint_interval = int(train_time_k * 1000 * val_proportion) # 计算保存检查点的间隔
|
||||
|
||||
# 生成文件名和路径
|
||||
output_configs_base_schedules_schedules_Timek = os.path.join('./configs/_base_/schedules/', f'schedule_{train_time_k}k_check_{checkpoint_interval}.py')
|
||||
schedule_file_name = f'schedule_{train_time_k}k_check_{checkpoint_interval}.py'
|
||||
|
||||
# 调用生成配置文件函数
|
||||
generate_times_configs_base_schedules_schedule_file(
|
||||
output_file=output_configs_base_schedules_schedules_Timek,
|
||||
train_time_k=train_time_k,
|
||||
val_proportion=val_proportion,
|
||||
loggerhook_interval=loggerhook_interval
|
||||
)
|
||||
|
||||
return schedule_file_name, train_time_k
|
||||
|
||||
# 运行对应的代码去生成算法配置文件
|
||||
def run_alg_to_gen_alg(relative_alg_path, mean, std, alg_name, dataset_file_name, num_classes, GPU_num, schedule_file_name, train_time_k):
|
||||
try:
|
||||
# 构建命令和参数
|
||||
cmd = [
|
||||
'python', relative_alg_path,
|
||||
'--alg_name', alg_name,
|
||||
'--dataset_file_name', dataset_file_name,
|
||||
'--mean', str(mean),
|
||||
'--std', str(std),
|
||||
'--dataset_num_classes', str(num_classes),
|
||||
'--GPU_num', str(GPU_num),
|
||||
'--schedule_file_name', schedule_file_name,
|
||||
'--train_time_k', str(train_time_k)
|
||||
]
|
||||
|
||||
# 打印当前执行的命令
|
||||
print(f"\n正在运行: \033[33m{' '.join(cmd)}\033[0m")
|
||||
|
||||
# 运行 Python 脚本并传递参数
|
||||
subprocess.run(cmd, check=True)
|
||||
|
||||
print(f"\033[92m算法生成器 {relative_alg_path} 运行成功。\033[0m")
|
||||
except subprocess.CalledProcessError as e:
|
||||
print(f"\033[91m算法生成器 {relative_alg_path} 运行失败,错误信息: {e}\033[0m")
|
||||
|
||||
if __name__ == '__main__':
|
||||
########### 1.超参数-定义数据集、算法信息路径 ###########
|
||||
train_parameter_dir = './My_All_In_One/1_Data_Parameter'
|
||||
all_data_record_json = "All_Data_Record.json"
|
||||
all_data_record_file = os.path.join(train_parameter_dir, all_data_record_json) # 数据集信息记录路径
|
||||
|
||||
alg_directory = './My_All_In_One/2_Alg_Program' # 算法配置生成路径
|
||||
|
||||
work_dir_base = './work_dirs' # 工作路径
|
||||
|
||||
########### 2.获取现有数据集信息 ###########
|
||||
print(f"\033[36m{'='*10} 一、选择训练数据集 {'='*10}\033[0m")
|
||||
data_record_json_data = load_json_file(file_name = all_data_record_file) # 加载数据集信息
|
||||
all_data_record = process_data_record_json_data(json_data = data_record_json_data) # 分析数据集信息
|
||||
dataset_file_name, num_classes, mean, std = select_dataset(all_data_record = all_data_record) # 选择特定数据集
|
||||
|
||||
########### 3.获取现有GPU信息 ###########
|
||||
print(f"\033[36m{'='*10} 二、选择训练GPU {'='*10}\033[0m")
|
||||
gpu_list_str, GPU_num = select_GPU() # 选择GPU
|
||||
|
||||
########### 4.获取训练批次信息 ###########
|
||||
print(f"\033[36m{'='*10} 三、选择训练批次信息 {'='*10}\033[0m")
|
||||
schedule_file_name, train_time_k = select_schedule() # 选择训练批次
|
||||
schedule_file_name = schedule_file_name.rstrip('.py')
|
||||
|
||||
########### 5.获取现有算法信息 ###########
|
||||
print(f"\033[36m{'='*10} 四、选择训练算法 {'='*10}\033[0m")
|
||||
selected_alg, relative_alg_path = select_alg(alg_directory=alg_directory) # 选择算法
|
||||
alg_name = selected_alg.rstrip(".py")
|
||||
|
||||
########### 6.运行选定的算法 ###########
|
||||
print(f"\033[36m{'='*5} 生成训练算法配置 {'='*5}\033[0m")
|
||||
run_alg_to_gen_alg(relative_alg_path=relative_alg_path, alg_name=alg_name, dataset_file_name=dataset_file_name, num_classes=num_classes, GPU_num=GPU_num, mean=mean, std=std,schedule_file_name=schedule_file_name, train_time_k=train_time_k)
|
||||
|
||||
# 算法相关信息
|
||||
with open("_temp_.txt", "r") as file:
|
||||
data = json.load(file)
|
||||
alg_file_name = data['alg_infos']['alg_file_name']
|
||||
alg_file_pth = data['alg_infos']['alg_file_pth']
|
||||
os.remove("_temp_.txt")
|
||||
|
||||
########### 7.输出工作、算法目录 ###########
|
||||
# 获取并打印当前的年月日时分秒
|
||||
data_now = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
|
||||
work_dir_file = os.path.join(work_dir_base, f"{dataset_file_name}-Class_{num_classes}-Alg_{alg_name}-AlgName_{alg_file_name}-Card_{GPU_num}-Data_{data_now}")
|
||||
|
||||
print(f"\033[36m训练指令:python tools/train.py {alg_file_pth} --work-dir {work_dir_file}")
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,357 @@
|
||||
import os, json, subprocess
|
||||
from Initial_Schedule_Program.Initial_Train_Gen_configs_base_schedules_schedule_XXk import generate_times_configs_base_schedules_schedule_file
|
||||
from Initial_Schedule_Program.Initial_Train_Gen_configs_base_schedules_schedule_XXe import generate_epochs_configs_base_schedules_schedule_file
|
||||
from datetime import datetime
|
||||
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_Tool import format_all_data, write_config_to_file, get_gpu_info
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_0_base_ import generate_base_config
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_1_norm_cfg import generate_norm_cfg
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_2_data_preprocessor import generate_data_preprocessor
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_3_model import generate_model_pretrained, generate_model_backbone, generate_model_data_preprocessor, generate_model_decode_head, generate_model_auxiliary_head, generate_model_train_cfg, generate_model_test_cfg
|
||||
|
||||
def load_json_file(file_name):
|
||||
"""
|
||||
读取并返回指定 JSON 文件的内容。
|
||||
|
||||
:param file_name: 要读取的 JSON 文件路径
|
||||
:return: JSON 文件的内容作为字典或列表
|
||||
"""
|
||||
try:
|
||||
# 打开并读取 JSON 文件
|
||||
with open(file_name, 'r', encoding='utf-8') as f:
|
||||
data = json.load(f)
|
||||
return data
|
||||
except FileNotFoundError:
|
||||
print(f"\033[91mError: File {file_name} not found.\033[0m")
|
||||
return None
|
||||
except json.JSONDecodeError:
|
||||
print(f"\033[91mError: Failed to decode JSON from {file_name}.\033[0m")
|
||||
return None
|
||||
except Exception as e:
|
||||
print(f"\033[91mAn error occurred while reading {file_name}: {str(e)}\033[0m")
|
||||
return None
|
||||
|
||||
def process_data_record_json_data(json_data):
|
||||
"""
|
||||
将 JSON 数据直接存入一个大的字典 all_data_record。
|
||||
:param json_data: 从 JSON 文件加载的数据
|
||||
:return: all_data_record 字典
|
||||
"""
|
||||
all_data_record = {}
|
||||
|
||||
# 遍历 json_data 中的每个键(即每个数据集)
|
||||
for dataset_name, dataset_info in json_data.items():
|
||||
all_data_record[dataset_name] = {
|
||||
"classes": dataset_info['classes'],
|
||||
"palette": dataset_info['palette'],
|
||||
"palette_num": dataset_info['palette_num'],
|
||||
"mean": dataset_info['mean'],
|
||||
"std": dataset_info['std'],
|
||||
"train_imgs_num": dataset_info['train_imgs_num']
|
||||
}
|
||||
|
||||
return all_data_record
|
||||
|
||||
# 选择要处理的数据集
|
||||
def select_dataset(all_data_record):
|
||||
"""
|
||||
让用户从 all_data_record 中选择数据集,并返回 dataset_file_name 和对应的 palette_num。
|
||||
|
||||
:param all_data_record: 包含所有数据集信息的字典
|
||||
:return: 选定的 dataset_file_name 和 palette_num
|
||||
"""
|
||||
# 获取所有数据集的名称
|
||||
dataset_names = list(all_data_record.keys())
|
||||
|
||||
# 显示可用的数据集,并让用户选择
|
||||
print("选择可用数据集:")
|
||||
for i, name in enumerate(dataset_names):
|
||||
print(f"{i + 1}. {name} - {all_data_record[name]['palette_num']}类")
|
||||
|
||||
# 用户输入选择的数据集编号
|
||||
while True:
|
||||
try:
|
||||
selection = int(input("请选择数据集编号(输入数字):")) - 1
|
||||
if 0 <= selection < len(dataset_names):
|
||||
dataset_file_name = dataset_names[selection]
|
||||
break
|
||||
else:
|
||||
print(f"输入的编号无效,请输入 1 到 {len(dataset_names)} 之间的数字。")
|
||||
except ValueError:
|
||||
print("无效输入,请输入数字。")
|
||||
|
||||
# 获取对应的 palette_num
|
||||
palette_num = all_data_record[dataset_file_name]['palette_num']
|
||||
mean = all_data_record[dataset_file_name]['mean']
|
||||
std = all_data_record[dataset_file_name]['std']
|
||||
train_imgs_num = all_data_record[dataset_file_name]['train_imgs_num']
|
||||
|
||||
print(f" 已选择数据集: {dataset_file_name} ,其对应的分类数为: {palette_num}")
|
||||
|
||||
return dataset_file_name, palette_num, mean, std, train_imgs_num
|
||||
|
||||
# 选择对应的算法
|
||||
def select_alg(alg_directory):
|
||||
"""
|
||||
选择一个算法文件并运行它。
|
||||
|
||||
:return: 选定的算法文件名和相对路径
|
||||
"""
|
||||
# 获取 ./Alg 目录下的所有 Python 文件
|
||||
algorithms = [f for f in os.listdir(alg_directory) if f.endswith('.py')]
|
||||
|
||||
# 显示可用的算法文件
|
||||
print("选择可用算法:")
|
||||
for i, alg in enumerate(algorithms):
|
||||
print(f"{i + 1}. {alg}")
|
||||
|
||||
# 用户选择算法
|
||||
while True:
|
||||
try:
|
||||
selection = int(input("请选择算法编号(输入数字):")) - 1
|
||||
if 0 <= selection < len(algorithms):
|
||||
selected_alg = algorithms[selection]
|
||||
break
|
||||
else:
|
||||
print(f"输入的编号无效,请输入 1 到 {len(algorithms)} 之间的数字。")
|
||||
except ValueError:
|
||||
print("无效输入,请输入数字。")
|
||||
|
||||
# 生成选定算法的相对路径
|
||||
relative_alg_path = os.path.join(alg_directory, selected_alg)
|
||||
print(f" 已选择算法: {selected_alg} {relative_alg_path}")
|
||||
|
||||
return selected_alg, relative_alg_path
|
||||
|
||||
# 选择计算用的GPU
|
||||
def select_GPU():
|
||||
"""
|
||||
让用户选择要使用的 GPU 数量,并返回相应的 GPU 列表(以 0,1 格式输出)。
|
||||
如果用户没有输入任何选择,默认为选择一块 GPU,编号为 0。
|
||||
:return: 用户选择的 GPU 列表和数量
|
||||
"""
|
||||
# 获取 GPU 信息
|
||||
gpu_info = get_gpu_info()
|
||||
|
||||
if not gpu_info:
|
||||
return None, 0
|
||||
|
||||
# 显示 GPU 信息
|
||||
print("可用的 GPU 列表:")
|
||||
for idx, mem_free in gpu_info:
|
||||
print(f"GPU {idx}: 剩余显存 {mem_free} MB")
|
||||
|
||||
# 提供默认选择 GPU 个数和编号
|
||||
default_num_gpus = 1
|
||||
default_gpu_idx = 0
|
||||
|
||||
# 用户选择 GPU 个数(默认选择 1)
|
||||
try:
|
||||
num_gpus = input(f"请选择使用的 GPU 个数 (1-{len(gpu_info)}, 默认为 1): ")
|
||||
num_gpus = int(num_gpus) if num_gpus else default_num_gpus
|
||||
if not (1 <= num_gpus <= len(gpu_info)):
|
||||
print(f"输入无效,使用默认值 1 个 GPU。")
|
||||
num_gpus = default_num_gpus
|
||||
except ValueError:
|
||||
print(f"无效输入,使用默认值 1 个 GPU。")
|
||||
num_gpus = default_num_gpus
|
||||
|
||||
# 用户选择 GPU 编号(如果选择多个 GPU,逐个选择编号)
|
||||
selected_gpus = []
|
||||
for i in range(num_gpus):
|
||||
try:
|
||||
gpu_idx = input(f"请输入要使用的第 {i + 1} 个 GPU 的编号 (默认为 {default_gpu_idx}): ")
|
||||
gpu_idx = int(gpu_idx) if gpu_idx else default_gpu_idx
|
||||
if gpu_idx in [idx for idx, _ in gpu_info] and gpu_idx not in selected_gpus:
|
||||
selected_gpus.append(gpu_idx)
|
||||
else:
|
||||
print(f"无效输入,使用默认 GPU {default_gpu_idx}")
|
||||
selected_gpus.append(default_gpu_idx)
|
||||
except ValueError:
|
||||
print(f"无效输入,使用默认 GPU {default_gpu_idx}")
|
||||
selected_gpus.append(default_gpu_idx)
|
||||
|
||||
# 返回以 0,1 格式的 GPU 列表和 GPU 数量
|
||||
gpu_list_str = ','.join(map(str, selected_gpus))
|
||||
print(f"\n已选择 GPU: {gpu_list_str},共 {num_gpus} 块 GPU")
|
||||
|
||||
return gpu_list_str, num_gpus
|
||||
|
||||
# 选择训练批次相关信息
|
||||
def select_schedule(train_imgs_num):
|
||||
"""
|
||||
让用户选择训练模式(Iteration 或 Epoch),并根据选择收集参数,
|
||||
最终生成对应的训练计划 schedule 配置文件。
|
||||
"""
|
||||
# 1. 让用户选择模式
|
||||
while True:
|
||||
mode = input("请选择训练模式 (1: Iteration, 2: Epoch) [默认: 2]: ").strip()
|
||||
if mode in ['1', '2', '']:
|
||||
mode = '2' if mode == '' else mode
|
||||
break
|
||||
else:
|
||||
print("无效输入,请输入 1 或 2。")
|
||||
|
||||
# --- 模式 1: 基于 Iteration 的配置 ---
|
||||
if mode == '1':
|
||||
print("\n--- 您已选择 Iteration 模式 ---")
|
||||
try:
|
||||
train_time_or_epoch_k = input("请输入训练次数 (k为单位, 默认40k): ").strip()
|
||||
train_time_or_epoch_k = int(train_time_or_epoch_k) if train_time_or_epoch_k else 40
|
||||
except ValueError:
|
||||
print("无效输入,使用默认训练时间 40k")
|
||||
train_time_or_epoch_k = 40
|
||||
|
||||
try:
|
||||
check_num = input("请输入总的验证/保存次数 (默认10): ").strip()
|
||||
check_num = int(check_num) if check_num else 10
|
||||
except ValueError:
|
||||
print("无效输入,使用默认次数 10")
|
||||
check_num = 10
|
||||
|
||||
try:
|
||||
loggerhook_interval = input("请输入日志间隔 (默认50次迭代): ").strip()
|
||||
loggerhook_interval = int(loggerhook_interval) if loggerhook_interval else 50
|
||||
except ValueError:
|
||||
print("无效输入,使用默认日志间隔 50")
|
||||
loggerhook_interval = 50
|
||||
|
||||
val_proportion = 1 / check_num
|
||||
# 根据验证比例计算间隔,确保至少为1
|
||||
interval = max(1, int(train_time_or_epoch_k * 1000 * val_proportion))
|
||||
|
||||
schedule_file_name = f'schedule_{train_time_or_epoch_k}k_check_{interval}.py'
|
||||
output_path = os.path.join('./configs/_base_/schedules/', schedule_file_name)
|
||||
|
||||
generate_times_configs_base_schedules_schedule_file(
|
||||
output_file=output_path,
|
||||
train_time_or_epoch_k=train_time_or_epoch_k,
|
||||
val_proportion=val_proportion,
|
||||
loggerhook_interval=loggerhook_interval
|
||||
)
|
||||
# 返回文件名和训练时长
|
||||
return schedule_file_name, "Iteration", train_time_or_epoch_k
|
||||
|
||||
# --- 模式 2: 基于 Epoch 的配置 ---
|
||||
elif mode == '2':
|
||||
print("\n--- 您已选择 Epoch 模式 ---")
|
||||
try:
|
||||
max_epochs = input("请输入训练总轮数 (Epochs, 默认300): ").strip()
|
||||
max_epochs = int(max_epochs) if max_epochs else 300
|
||||
except ValueError:
|
||||
print("无效输入,使用默认轮数 300")
|
||||
max_epochs = 300
|
||||
|
||||
try:
|
||||
val_interval = input("请输入验证间隔的轮次数 (Epochs, 默认1): ").strip()
|
||||
val_interval = int(val_interval) if val_interval else 1
|
||||
except ValueError:
|
||||
print("无效输入,使用默认验证间隔 1")
|
||||
val_interval = 1
|
||||
|
||||
try:
|
||||
checkpoint_interval = input("请输入模型保存间隔的轮次数 (Epochs, 默认10): ").strip()
|
||||
checkpoint_interval = int(checkpoint_interval) if checkpoint_interval else 10
|
||||
except ValueError:
|
||||
print("无效输入,使用默认保存间隔 10")
|
||||
checkpoint_interval = 10
|
||||
|
||||
try:
|
||||
loggerhook_interval_default = train_imgs_num // 16
|
||||
loggerhook_interval = input(f"请输入日志间隔 (默认{loggerhook_interval_default}次迭代): ").strip()
|
||||
loggerhook_interval = int(loggerhook_interval) if loggerhook_interval else loggerhook_interval_default
|
||||
except ValueError:
|
||||
print(f"无效输入,使用默认日志间隔 {loggerhook_interval_default}")
|
||||
loggerhook_interval = loggerhook_interval_default
|
||||
|
||||
schedule_file_name = f'schedule_{max_epochs}e_val{val_interval}_check{checkpoint_interval}.py'
|
||||
output_path = os.path.join('./configs/_base_/schedules/', schedule_file_name)
|
||||
|
||||
# 注意:这里调用的是基于Epoch的生成函数
|
||||
generate_epochs_configs_base_schedules_schedule_file(
|
||||
output_file=output_path,
|
||||
max_epochs=max_epochs,
|
||||
val_interval=val_interval,
|
||||
checkpoint_interval=checkpoint_interval,
|
||||
loggerhook_interval=loggerhook_interval
|
||||
)
|
||||
# 返回文件名和训练时长
|
||||
return schedule_file_name, "Epoch", max_epochs
|
||||
|
||||
# 运行对应的代码去生成算法配置文件
|
||||
def run_alg_to_gen_alg(relative_alg_path, mean, std, alg_name, dataset_file_name, num_classes, GPU_num, schedule_file_name, train_type, train_time_or_epoch_k):
|
||||
try:
|
||||
# 构建命令和参数
|
||||
cmd = [
|
||||
'python', relative_alg_path,
|
||||
'--alg_name', alg_name,
|
||||
'--dataset_file_name', dataset_file_name,
|
||||
'--mean', str(mean),
|
||||
'--std', str(std),
|
||||
'--dataset_num_classes', str(num_classes),
|
||||
'--GPU_num', str(GPU_num),
|
||||
'--schedule_file_name', schedule_file_name,
|
||||
'--train_type', train_type,
|
||||
'--train_time_or_epoch_k', str(train_time_or_epoch_k)
|
||||
]
|
||||
|
||||
# 打印当前执行的命令
|
||||
print(f"\n正在运行: \033[33m{' '.join(cmd)}\033[0m")
|
||||
|
||||
# 运行 Python 脚本并传递参数
|
||||
subprocess.run(cmd, check=True)
|
||||
|
||||
print(f"\033[92m算法生成器 {relative_alg_path} 运行成功。\033[0m")
|
||||
except subprocess.CalledProcessError as e:
|
||||
print(f"\033[91m算法生成器 {relative_alg_path} 运行失败,错误信息: {e}\033[0m")
|
||||
|
||||
if __name__ == '__main__':
|
||||
########### 1.超参数-定义数据集、算法信息路径 ###########
|
||||
train_parameter_dir = './My_All_In_One/1_Data_Parameter'
|
||||
all_data_record_json = "All_Data_Record.json"
|
||||
all_data_record_file = os.path.join(train_parameter_dir, all_data_record_json) # 数据集信息记录路径
|
||||
|
||||
alg_directory = './My_All_In_One/2_Alg_Program' # 算法配置生成路径
|
||||
|
||||
work_dir_base = '../DataSet_Public_outputs' # 工作路径
|
||||
|
||||
########### 2.获取现有数据集信息 ###########
|
||||
print(f"\033[36m{'='*10} 一、选择训练数据集 {'='*10}\033[0m")
|
||||
data_record_json_data = load_json_file(file_name = all_data_record_file) # 加载数据集信息
|
||||
all_data_record = process_data_record_json_data(json_data = data_record_json_data) # 分析数据集信息
|
||||
dataset_file_name, num_classes, mean, std, train_imgs_num = select_dataset(all_data_record = all_data_record) # 选择特定数据集
|
||||
|
||||
########### 3.获取现有GPU信息 ###########
|
||||
print(f"\033[36m{'='*10} 二、选择训练GPU {'='*10}\033[0m")
|
||||
gpu_list_str, GPU_num = select_GPU() # 选择GPU
|
||||
|
||||
########### 4.获取训练批次信息 ###########
|
||||
print(f"\033[36m{'='*10} 三、选择训练批次信息 {'='*10}\033[0m")
|
||||
schedule_file_name, train_type, train_time_or_epoch_k = select_schedule(train_imgs_num) # 选择训练批次
|
||||
schedule_file_name = schedule_file_name.rstrip('.py')
|
||||
|
||||
########### 5.获取现有算法信息 ###########
|
||||
print(f"\033[36m{'='*10} 四、选择训练算法 {'='*10}\033[0m")
|
||||
selected_alg, relative_alg_path = select_alg(alg_directory=alg_directory) # 选择算法
|
||||
alg_name = selected_alg.rstrip(".py")
|
||||
|
||||
########### 6.运行选定的算法 ###########
|
||||
print(f"\033[36m{'='*5} 生成训练算法配置 {'='*5}\033[0m")
|
||||
run_alg_to_gen_alg(relative_alg_path=relative_alg_path, alg_name=alg_name, dataset_file_name=dataset_file_name, num_classes=num_classes, GPU_num=GPU_num, mean=mean, std=std,schedule_file_name=schedule_file_name, train_type = train_type, train_time_or_epoch_k=train_time_or_epoch_k)
|
||||
|
||||
# 算法相关信息
|
||||
with open("_temp_.txt", "r") as file:
|
||||
data = json.load(file)
|
||||
alg_file_name = data['alg_infos']['alg_file_name']
|
||||
alg_file_pth = data['alg_infos']['alg_file_pth']
|
||||
os.remove("_temp_.txt")
|
||||
|
||||
########### 7.输出工作、算法目录 ###########
|
||||
# 获取并打印当前的年月日时分秒
|
||||
data_now = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
|
||||
work_dir_file = os.path.join(work_dir_base, f"{dataset_file_name}-Class_{num_classes}-Alg_{alg_name}-AlgName_{alg_file_name}-Card_{GPU_num}-Data_{data_now}")
|
||||
|
||||
print(f"\033[36m训练指令:python tools/train.py {alg_file_pth} --work-dir {work_dir_file}")
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,78 @@
|
||||
import os
|
||||
import sys
|
||||
|
||||
def find_specific_epochs(root_folder):
|
||||
"""
|
||||
遍历指定文件夹,查找所有 epoch_XXX.pth 文件,
|
||||
其中 XXX 为整数且不能被 10 整除,并返回它们的绝对路径。
|
||||
|
||||
:param root_folder: 要搜索的根文件夹路径
|
||||
:return: 一个包含符合条件文件绝对路径的列表
|
||||
"""
|
||||
matching_files = []
|
||||
|
||||
for dirpath, _, filenames in os.walk(root_folder):
|
||||
for filename in filenames:
|
||||
if filename.startswith('epoch_') and filename.endswith('.pth'):
|
||||
try:
|
||||
number_str = filename[len('epoch_'):-len('.pth')]
|
||||
if number_str.isdigit():
|
||||
epoch_number = int(number_str)
|
||||
if epoch_number % 10 != 0:
|
||||
full_path = os.path.join(dirpath, filename)
|
||||
matching_files.append(os.path.abspath(full_path))
|
||||
except (ValueError, IndexError):
|
||||
continue
|
||||
|
||||
return matching_files
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 1. 获取当前脚本文件所在的绝对目录
|
||||
script_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
|
||||
# 2. 以脚本目录为基准,构建目标路径
|
||||
# V1
|
||||
target_directory = os.path.abspath(os.path.join(script_dir, '../../Hardisk'))
|
||||
# V2
|
||||
# target_directory = os.path.abspath(os.path.join(script_dir, '../../DataSet_Public_outputs'))
|
||||
|
||||
# 1. 查找符合条件的文件
|
||||
found_paths = find_specific_epochs(target_directory)
|
||||
|
||||
if not found_paths:
|
||||
print(f"在 '{os.path.abspath(target_directory)}' 及其子目录中没有找到符合条件的文件。")
|
||||
sys.exit(0)
|
||||
|
||||
# 2. 列出所有找到的文件,并请求用户确认
|
||||
print("找到了以下符合条件的文件:")
|
||||
for path in found_paths:
|
||||
print(path)
|
||||
|
||||
print("\n" + "="*50)
|
||||
print("警告:接下来的操作将永久删除以上列出的所有文件!")
|
||||
print("="*50 + "\n")
|
||||
|
||||
# 3. 获取用户输入
|
||||
try:
|
||||
# 使用 strip() 去除首尾空格,使用 lower() 转换为小写
|
||||
confirm = input("您确定要删除这 {} 个文件吗? (请输入 'yes' 进行确认,输入其他任何内容则取消): ".format(len(found_paths)))
|
||||
except KeyboardInterrupt:
|
||||
print("\n\n操作被用户中断。")
|
||||
sys.exit(1)
|
||||
|
||||
# 4. 根据用户的输入执行操作
|
||||
if confirm.lower().strip() == 'yes':
|
||||
print("\n正在开始删除文件...")
|
||||
deleted_count = 0
|
||||
error_count = 0
|
||||
for path in found_paths:
|
||||
try:
|
||||
os.remove(path)
|
||||
print(f"已删除: {path}")
|
||||
deleted_count += 1
|
||||
except OSError as e:
|
||||
print(f"删除失败: {path} (原因: {e})")
|
||||
error_count += 1
|
||||
print(f"\n操作完成。成功删除 {deleted_count} 个文件,{error_count} 个文件删除失败。")
|
||||
else:
|
||||
print("\n操作已取消,没有文件被删除。")
|
||||
@@ -0,0 +1,99 @@
|
||||
#!/bin/bash
|
||||
|
||||
# --- 脚本说明 ---
|
||||
# 功能: 使用 rsync 将指定格式的文件夹从源目录同步到目标目录。
|
||||
# - 脚本可以从任何位置安全执行。
|
||||
# - 在同步前会检查源文件/目录是否存在,如果不存在则跳过该任务。
|
||||
#
|
||||
# rsync 参数说明:
|
||||
# -a: 归档模式,保留文件属性。
|
||||
# -v: 详细模式。
|
||||
# -h: 人性化显示大小。
|
||||
# --progress: 显示传输进度。
|
||||
# --stats: 显示任务统计。
|
||||
|
||||
# --- 路径配置 ---
|
||||
# 获取脚本文件所在的绝对目录路径,确保路径的准确性。
|
||||
SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
|
||||
|
||||
# 基于脚本所在目录设置源和目标的绝对路径
|
||||
SOURCE_DIR="$SCRIPT_DIR/../../DataSet_Public_outputs"
|
||||
DEST_DIR="$SCRIPT_DIR/../../Hardisk"
|
||||
|
||||
# --- 主程序 ---
|
||||
echo "脚本执行目录已锁定为: $SCRIPT_DIR"
|
||||
echo "解析后的源目录 (Source): $SOURCE_DIR"
|
||||
echo "解析后的目标根目录 (Destination): $DEST_DIR"
|
||||
echo "========================================"
|
||||
echo "开始执行 rsync 同步任务(带安全检查)..."
|
||||
echo ""
|
||||
|
||||
# --- 任务处理函数 (简化代码) ---
|
||||
# 定义一个函数来处理每个同步任务,避免代码重复
|
||||
# 参数1: 任务名称 (例如: 1_cholecseg8k)
|
||||
# 参数2: 源文件/目录的通配符模式
|
||||
# 参数3: 目标目录
|
||||
handle_sync_task() {
|
||||
local task_name="$1"
|
||||
local src_pattern="$2"
|
||||
local dest_dir="$3"
|
||||
|
||||
echo "--> 正在检查任务: $task_name"
|
||||
|
||||
# 将通配符匹配到的文件存入数组
|
||||
local sources=($src_pattern)
|
||||
|
||||
# 检查数组的第一个元素是否存在。如果通配符没有匹配到任何文件,
|
||||
# bash会把通配符本身作为字符串返回,而这个字符串命名的文件通常不存在。
|
||||
if [ -e "${sources[0]}" ]; then
|
||||
echo " 源文件已找到,准备同步..."
|
||||
echo " 目标目录: $dest_dir"
|
||||
|
||||
# 创建目标目录(如果不存在)
|
||||
mkdir -p "$dest_dir"
|
||||
|
||||
# 执行 rsync
|
||||
rsync -avh --progress --stats "${sources[@]}" "$dest_dir/"
|
||||
echo "--> 任务 '$task_name' 同步完成。"
|
||||
else
|
||||
echo " 警告: 源路径 '$src_pattern' 未匹配到任何文件,已跳过此任务。"
|
||||
fi
|
||||
echo "----------------------------------------"
|
||||
}
|
||||
|
||||
|
||||
# --- 任务列表 ---
|
||||
|
||||
# 任务1: 同步 1_cholecseg8k
|
||||
handle_sync_task \
|
||||
"1_cholecseg8k" \
|
||||
"$SOURCE_DIR/1_cholecseg8k-Class_13-Alg*" \
|
||||
"$DEST_DIR/1_CholecSeg8k-13Type-1920x1080_outputs-MMSeg"
|
||||
|
||||
# 任务2: 同步 2_autolaparo
|
||||
handle_sync_task \
|
||||
"2_autolaparo" \
|
||||
"$SOURCE_DIR/2_autolaparo-Class_10-Alg*" \
|
||||
"$DEST_DIR/2_AutoLaparo-10Type-1920x1080_outputs-MMSeg"
|
||||
|
||||
# 任务3: 同步 3_1_endovis_2017
|
||||
handle_sync_task \
|
||||
"3_1_endovis_2017" \
|
||||
"$SOURCE_DIR/3_1_endovis_2017-Class_8-Alg*" \
|
||||
"$DEST_DIR/3_1_Endovis_2017-8Type-512x512_outputs-MMSeg"
|
||||
|
||||
# 任务4: 同步 3_2_endovis_2018
|
||||
handle_sync_task \
|
||||
"3_2_endovis_2018" \
|
||||
"$SOURCE_DIR/3_2_endovis_2018-Class_8-Alg*" \
|
||||
"$DEST_DIR/3_2_Endovis_2018-8Type-512x512_outputs-MMSeg"
|
||||
|
||||
# 任务5: 同步 4_dresden
|
||||
handle_sync_task \
|
||||
"4_dresden" \
|
||||
"$SOURCE_DIR/4_dresden-Class_11-Alg*" \
|
||||
"$DEST_DIR/4_Dresden-11Type-512x512_outputs-MMSeg"
|
||||
|
||||
|
||||
echo "========================================"
|
||||
echo "所有 rsync 同步任务已执行完毕!"
|
||||
@@ -0,0 +1,427 @@
|
||||
import os
|
||||
import glob
|
||||
import logging
|
||||
import argparse
|
||||
import re
|
||||
import subprocess
|
||||
import csv
|
||||
from typing import Dict, Optional, Tuple, List
|
||||
import time
|
||||
import numpy as np
|
||||
import torch
|
||||
from mmengine import Config
|
||||
from mmengine.model.utils import revert_sync_batchnorm
|
||||
from mmengine.registry import init_default_scope
|
||||
from mmengine.runner import Runner, load_checkpoint
|
||||
from mmseg.registry import MODELS
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
||||
|
||||
# --- 辅助函数 ---
|
||||
def find_model_files(model_dir: str):
|
||||
"""
|
||||
在给定的模型目录中查找配置文件、最佳检查点和日志文件。
|
||||
|
||||
Args:
|
||||
model_dir (str): 模型的根目录。
|
||||
|
||||
Returns:
|
||||
Optional]: 包含 'config', 'checkpoint', 'log' 路径的字典,
|
||||
如果缺少任何必要文件,则返回 None。
|
||||
"""
|
||||
config_files = glob.glob(os.path.join(model_dir, '*.py'))
|
||||
if not config_files:
|
||||
logging.warning(f"在目录 {model_dir} 中未找到配置文件 (.py)。")
|
||||
return None
|
||||
config_path = config_files[0]
|
||||
|
||||
checkpoint_path = os.path.join(model_dir, 'best.pth')
|
||||
if not os.path.exists(checkpoint_path):
|
||||
epoch_files = glob.glob(os.path.join(model_dir, 'epoch_*.pth'))
|
||||
if not epoch_files:
|
||||
logging.warning(f"在目录 {model_dir} 中未找到 'best.pth' 或 'epoch_*.pth' 检查点文件。")
|
||||
return None
|
||||
|
||||
# 通过正则表达式从文件名中提取周期数并找到最大的
|
||||
latest_epoch = -1
|
||||
latest_file = None
|
||||
for f in epoch_files:
|
||||
match = re.search(r'epoch_(\d+)\.pth', os.path.basename(f))
|
||||
if match:
|
||||
epoch_num = int(match.group(1))
|
||||
if epoch_num > latest_epoch:
|
||||
latest_epoch = epoch_num
|
||||
latest_file = f
|
||||
|
||||
if latest_file:
|
||||
checkpoint_path = latest_file
|
||||
else:
|
||||
logging.warning(f"在目录 {model_dir} 中无法确定最新的检查点文件。")
|
||||
return None
|
||||
|
||||
return {'config': config_path, 'checkpoint': checkpoint_path}
|
||||
|
||||
def find_model_config(model_dir: str):
|
||||
"""
|
||||
在给定的模型目录中查找配置文件 (.py)。
|
||||
|
||||
Args:
|
||||
model_dir (str): 模型的根目录。
|
||||
|
||||
Returns:
|
||||
Optional[str]: 配置文件的路径,如果未找到则返回 None。
|
||||
"""
|
||||
config_files = glob.glob(os.path.join(model_dir, '*.py'))
|
||||
if not config_files:
|
||||
logging.warning(f"在目录 {model_dir} 中未找到配置文件 (.py)。")
|
||||
return None
|
||||
return config_files[0]
|
||||
|
||||
def get_shape_from_path(path: str):
|
||||
"""
|
||||
从文件夹路径中通过正则表达式提取分辨率 (宽x高)。
|
||||
|
||||
Args:
|
||||
path (str): 数据集文件夹的路径。
|
||||
|
||||
Returns:
|
||||
Optional[Tuple[int, int]]: 一个包含 (高度, 宽度) 的元组,如果未找到则返回 None。
|
||||
注意:工具需要 H W 格式。
|
||||
"""
|
||||
match = re.search(r'(\d+)x(\d+)', os.path.basename(path))
|
||||
if match:
|
||||
width, height = int(match.group(1)), int(match.group(2))
|
||||
return (height, width) # 返回 H, W
|
||||
return None
|
||||
|
||||
def get_flops_and_params(config_path: str, shape: Tuple[int, int]) -> Optional[Dict[str, str]]:
|
||||
"""
|
||||
运行 mmsegmentation 的 get_flops.py 工具并解析其输出。
|
||||
此版本适配了新版的直接输出格式 (例如 "Flops: 0.118T")。
|
||||
|
||||
Args:
|
||||
config_path (str): 模型的 .py 配置文件路径。
|
||||
shape (Tuple[int, int]): 输入图像的 (H, W) 元组。
|
||||
|
||||
Returns:
|
||||
Optional[Dict[str, str]]: 包含 'params' 和 'flops' 的字典,如果失败则返回 None。
|
||||
"""
|
||||
# 检查工具脚本是否存在
|
||||
tool_script = 'tools/analysis_tools/get_flops.py'
|
||||
if not os.path.exists(tool_script):
|
||||
logging.error(f"错误: '{tool_script}' 未找到。请确保在 MMSegmentation 项目的根目录下运行此脚本。")
|
||||
return None
|
||||
|
||||
# 构建命令行
|
||||
command = [
|
||||
'python', tool_script, config_path,
|
||||
'--shape', str(shape[0]), str(shape[1])
|
||||
]
|
||||
|
||||
logging.info(f"执行命令: {' '.join(command)}")
|
||||
|
||||
try:
|
||||
# 执行命令并捕获输出
|
||||
result = subprocess.run(command, capture_output=True, text=True, check=True, encoding='utf-8')
|
||||
output = result.stdout
|
||||
|
||||
# 使用新的正则表达式来匹配更新后的输出格式
|
||||
flops_match = re.search(r"Flops:\s*([0-9.]+\s*[TGMK]?)", output)
|
||||
params_match = re.search(r"Params:\s*([0-9.]+\s*[TGMK]?)", output)
|
||||
|
||||
if flops_match and params_match:
|
||||
raw_flops_str = flops_match.group(1).strip()
|
||||
params = params_match.group(1).strip()
|
||||
# --- 开始单位换算 ---
|
||||
value_str = raw_flops_str.rstrip('TGMKtgmk').strip()
|
||||
unit = raw_flops_str[-1].upper() if raw_flops_str[-1].isalpha() else 'G'
|
||||
try:
|
||||
value = float(value_str)
|
||||
if unit == 'T':
|
||||
value_in_g = value * 1000
|
||||
elif unit == 'M':
|
||||
value_in_g = value / 1000
|
||||
elif unit == 'K':
|
||||
value_in_g = value / 1_000_000
|
||||
else: # 默认单位是 G
|
||||
value_in_g = value
|
||||
# 使用 :g 格式化可以去除末尾多余的0
|
||||
flops = f"{value_in_g:g} G"
|
||||
except ValueError:
|
||||
flops = raw_flops_str # 如果转换失败,则使用原始值
|
||||
# --- 单位换算结束 ---
|
||||
|
||||
logging.info(f"✅ 解析成功: FLOPs={flops} (原始值: {raw_flops_str}), Params={params}")
|
||||
return {'flops': flops, 'params': params}
|
||||
else:
|
||||
logging.warning(f"❌ 无法从命令输出中解析 FLOPs 或 Params。请检查以下输出内容:\n---\n{output}\n---")
|
||||
return None
|
||||
|
||||
except subprocess.CalledProcessError as e:
|
||||
logging.error(f"执行 get_flops.py 时出错。返回码: {e.returncode}")
|
||||
logging.error(f"错误输出 (stderr):\n---\n{e.stderr}\n---")
|
||||
return None
|
||||
except FileNotFoundError:
|
||||
logging.error("错误: 'python' 命令未找到。请确保 Python 环境已正确配置。")
|
||||
return None
|
||||
|
||||
def get_benchmark_stats(config_path: str, checkpoint_path: str, repeat_times: int) -> Optional[Dict[str, float]]:
|
||||
"""
|
||||
Runs inference benchmark based on the logic from benchmark.py.
|
||||
|
||||
Args:
|
||||
config_path (str): Path to the model's .py config file.
|
||||
checkpoint_path (str): Path to the model's .pth checkpoint file.
|
||||
repeat_times (int): Number of times to run the benchmark.
|
||||
|
||||
Returns:
|
||||
Optional[Dict[str, float]]: Dict containing 'average_fps' and 'fps_variance', or None on failure.
|
||||
"""
|
||||
try:
|
||||
cfg = Config.fromfile(config_path)
|
||||
init_default_scope(cfg.get('default_scope', 'mmseg'))
|
||||
|
||||
torch.backends.cudnn.benchmark = False
|
||||
cfg.model.pretrained = None
|
||||
cfg.test_dataloader.batch_size = 1 # Crucial for FPS measurement
|
||||
|
||||
overall_fps_list = []
|
||||
for time_index in range(repeat_times):
|
||||
logging.info(f"--- Starting Benchmark Run {time_index + 1}/{repeat_times} ---")
|
||||
|
||||
data_loader = Runner.build_dataloader(cfg.test_dataloader)
|
||||
|
||||
cfg.model.train_cfg = None
|
||||
model = MODELS.build(cfg.model)
|
||||
|
||||
load_checkpoint(model, checkpoint_path, map_location='cpu')
|
||||
|
||||
if torch.cuda.is_available():
|
||||
model = model.cuda(0)
|
||||
else:
|
||||
logging.warning("CUDA is not available. Benchmarking on CPU, results may be slow.")
|
||||
|
||||
model = revert_sync_batchnorm(model)
|
||||
model.eval()
|
||||
|
||||
num_warmup = 5
|
||||
pure_inf_time = 0
|
||||
total_iters = 100 # Reduced from 200 for faster script execution
|
||||
|
||||
for i, data in enumerate(data_loader):
|
||||
data = model.data_preprocessor(data, True)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.synchronize()
|
||||
start_time = time.perf_counter()
|
||||
|
||||
with torch.no_grad():
|
||||
model(data['inputs'], data['data_samples'], mode='predict')
|
||||
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.synchronize()
|
||||
elapsed = time.perf_counter() - start_time
|
||||
|
||||
if i >= num_warmup:
|
||||
pure_inf_time += elapsed
|
||||
|
||||
if (i + 1) == total_iters:
|
||||
fps = (total_iters - num_warmup) / pure_inf_time
|
||||
logging.info(f"Run {time_index + 1} Overall FPS: {fps:.2f} img/s")
|
||||
overall_fps_list.append(fps)
|
||||
break
|
||||
|
||||
if not overall_fps_list:
|
||||
logging.error("Benchmark failed to produce any results.")
|
||||
return None
|
||||
|
||||
avg_fps = round(np.mean(overall_fps_list), 2)
|
||||
fps_var = round(np.var(overall_fps_list), 4)
|
||||
|
||||
logging.info(f"✅ Benchmark Complete: Average FPS={avg_fps}, Variance={fps_var}")
|
||||
return {'average_fps': avg_fps, 'fps_variance': fps_var}
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"An exception occurred during benchmarking: {e}")
|
||||
return None
|
||||
|
||||
# --- 主函数 ---
|
||||
def main(args):
|
||||
"""
|
||||
脚本主入口,负责编排整个自动化分析流程。
|
||||
"""
|
||||
input_root = args.input_dir
|
||||
output_root = args.output_dir
|
||||
# --- 开始交互式选择修改 (V2 - 两级菜单) ---
|
||||
if not os.path.isdir(input_root):
|
||||
logging.error(f"输入目录不存在: {input_root}")
|
||||
return
|
||||
|
||||
# 1. 定义有效的数据集文件夹白名单
|
||||
VALID_DATASET_FOLDERS = [
|
||||
'1_CholecSeg8k-13Type-1920x1080_outputs-MMSeg',
|
||||
'2_AutoLaparo-10Type-1920x1080_outputs-MMSeg',
|
||||
'3_1_Endovis_2017-8Type-512x512_outputs-MMSeg',
|
||||
'3_2_Endovis_2018-8Type-512x512_outputs-MMSeg',
|
||||
'4_Dresden-11Type-512x512_outputs-MMSeg'
|
||||
]
|
||||
|
||||
# 2. 查找存在的、有效的数据集目录
|
||||
existing_dataset_dirs = [
|
||||
os.path.join(input_root, d) for d in VALID_DATASET_FOLDERS
|
||||
if os.path.isdir(os.path.join(input_root, d))
|
||||
]
|
||||
|
||||
if not existing_dataset_dirs:
|
||||
logging.error(f"在输入目录 {input_root} 中未找到任何有效的数据集文件夹。")
|
||||
return
|
||||
|
||||
# 3. 第一级菜单:选择数据集
|
||||
dataset_map = {str(i + 1): path for i, path in enumerate(existing_dataset_dirs)}
|
||||
|
||||
print("\n" + "="*50)
|
||||
print("--- 步骤 1: 请选择要处理的数据集 ---")
|
||||
for key, path in dataset_map.items():
|
||||
print(f"{key}: {os.path.basename(path)}")
|
||||
print("="*50)
|
||||
|
||||
choice1 = input("请输入数据集编号并按回车键: ").strip()
|
||||
|
||||
model_dirs = [] # 初始化最终要处理的目录列表
|
||||
|
||||
if choice1 in dataset_map:
|
||||
selected_dataset_dir = dataset_map[choice1]
|
||||
logging.info(f"您已选择数据集: [{os.path.basename(selected_dataset_dir)}]")
|
||||
|
||||
# 4. 查找选定数据集下的所有算法子目录
|
||||
alg_dirs = sorted([
|
||||
d for d in glob.glob(os.path.join(selected_dataset_dir, '*')) if os.path.isdir(d)
|
||||
])
|
||||
|
||||
if not alg_dirs:
|
||||
logging.warning(f"在 {os.path.basename(selected_dataset_dir)} 中未发现任何算法子文件夹。程序退出。")
|
||||
else:
|
||||
# 5. 第二级菜单:选择算法
|
||||
alg_map = {str(i + 1): path for i, path in enumerate(alg_dirs)}
|
||||
print("\n" + "="*50)
|
||||
print("--- 步骤 2: 请选择要处理的算法 ---")
|
||||
print("0: 批量处理当前数据集下的【全部】算法")
|
||||
for key, path in alg_map.items():
|
||||
print(f"{key}: {os.path.basename(path)}")
|
||||
print("="*50)
|
||||
|
||||
choice2 = input("请输入算法编号 (或输入 '0' 处理全部) 并按回车键: ").strip()
|
||||
|
||||
# 6. 根据第二级选择,最终确定 model_dirs
|
||||
if choice2 == '0':
|
||||
model_dirs = alg_dirs
|
||||
logging.info(f"您选择了批量处理全部 {len(model_dirs)} 个算法。")
|
||||
elif choice2 in alg_map:
|
||||
model_dirs = [alg_map[choice2]] # 将单个路径放入列表中
|
||||
logging.info(f"您选择了处理单个算法: {os.path.basename(model_dirs[0])}")
|
||||
else:
|
||||
logging.error("无效的算法选择,程序已退出。")
|
||||
else:
|
||||
logging.error("无效的数据集选择,程序已退出。")
|
||||
|
||||
# --- 交互式选择修改结束 ---
|
||||
results: List[Dict[str, str]] = []
|
||||
# 修改后的循环,将遍历经过用户筛选后的 model_dirs 列表
|
||||
for model_dir in model_dirs:
|
||||
model_name = os.path.basename(model_dir)
|
||||
logging.info(f"--- 开始处理模型: {model_name} ---")
|
||||
files = find_model_files(model_dir)
|
||||
if not files:
|
||||
logging.warning(f"跳过目录 {model_dir},因为缺少必要文件。")
|
||||
continue
|
||||
|
||||
# 构建输出目录
|
||||
# 从模型名中提取数据集标识作为Key
|
||||
dataset_key = model_name.split('-')[0]
|
||||
dataset_map = {
|
||||
'1_cholecseg8k': '1_CholecSeg8k-13Type-1920x1080_outputs-MMSeg',
|
||||
'2_autolaparo': '2_AutoLaparo-10Type-1920x1080_outputs-MMSeg',
|
||||
'3_1_endovis_2017': '3_1_Endovis_2017-8Type-512x512_outputs-MMSeg',
|
||||
'3_2_endovis_2018': '3_2_Endovis_2018-8Type-512x512_outputs-MMSeg',
|
||||
'4_dresden': '4_Dresden-11Type-512x512_outputs-MMSeg'
|
||||
}
|
||||
# 使用提取的Key(字符串)进行查询,并为默认值也使用该Key
|
||||
output_dataset_folder = dataset_map.get(dataset_key, f"{dataset_key}_outputs-MMSeg")
|
||||
|
||||
# 尝试从数据集文件夹名称中获取分辨率
|
||||
input_shape = get_shape_from_path(selected_dataset_dir)
|
||||
if not input_shape:
|
||||
# 如果无法自动提取,要求用户输入
|
||||
logging.warning(f"无法从文件夹 '{os.path.basename(selected_dataset_dir)}' 名称中自动检测分辨率。")
|
||||
try:
|
||||
h_str = input("请输入默认测试高度 (H),例如 512: ").strip()
|
||||
w_str = input("请输入默认测试宽度 (W),例如 512: ").strip()
|
||||
input_shape = (int(h_str), int(w_str))
|
||||
except ValueError:
|
||||
logging.error("输入无效,必须是整数。程序退出。")
|
||||
return
|
||||
logging.info(f"将使用输入形状 (H, W): {input_shape} 进行计算。")
|
||||
|
||||
# 加载配置
|
||||
config_file = find_model_config(model_dir)
|
||||
|
||||
# 获取 FLOPs 和 Params
|
||||
flops_and_params_stats = get_flops_and_params(config_file, input_shape)
|
||||
benchmark_stats = get_benchmark_stats(files['config'], files['checkpoint'], args.repeat_times)
|
||||
if flops_and_params_stats:
|
||||
short_model_name = model_name.split('Alg_', 1)[1]
|
||||
results.append({
|
||||
'Model': short_model_name,
|
||||
'Params': flops_and_params_stats['params'] if flops_and_params_stats else 'N/A' ,
|
||||
'FLOPs': flops_and_params_stats['flops'] if flops_and_params_stats else 'N/A' ,
|
||||
'Input_Shape (HxW)': f"{input_shape[0]}x{input_shape[1]}",
|
||||
'Average_FPS': benchmark_stats['average_fps'] if benchmark_stats else 'N/A',
|
||||
'FPS_Variance': benchmark_stats['fps_variance'] if benchmark_stats else 'N/A'
|
||||
})
|
||||
else:
|
||||
logging.warning(f"未能获取模型 {model_name} 的统计信息。")
|
||||
|
||||
# --- 将结果写入 CSV 文件 ---
|
||||
if not results:
|
||||
logging.info("没有成功获取任何模型的统计数据,不生成 CSV 文件。")
|
||||
return
|
||||
|
||||
# 新建文件夹并保存 CSV
|
||||
final_output_dir = os.path.join(output_root, output_dataset_folder)
|
||||
os.makedirs(final_output_dir, exist_ok=True)
|
||||
dataset_name = os.path.basename(selected_dataset_dir).split('_outputs-MMSeg')[0]
|
||||
output_csv_path = os.path.join(final_output_dir, f'{dataset_name}_flops_params_fps_summary.csv')
|
||||
|
||||
try:
|
||||
with open(output_csv_path, 'w', newline='', encoding='utf-8') as csvfile:
|
||||
fieldnames = ['Model', 'Params', 'FLOPs', 'Input_Shape (HxW)', 'Average_FPS', 'FPS_Variance']
|
||||
writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
|
||||
|
||||
writer.writeheader()
|
||||
writer.writerows(results)
|
||||
|
||||
logging.info(f"=== 全部处理完成!结果已成功保存到: {output_csv_path} ===")
|
||||
except IOError as e:
|
||||
logging.error(f"无法写入 CSV 文件: {output_csv_path}。错误: {e}")
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser(description="MMSegmentation 自动化评估脚本")
|
||||
parser.add_argument(
|
||||
'--input_dir',
|
||||
type=str,
|
||||
default='../Hardisk',
|
||||
help="包含已训练模型文件夹的根目录。"
|
||||
)
|
||||
parser.add_argument(
|
||||
'--output_dir',
|
||||
type=str,
|
||||
default='../BestMode_Predict_Results_DataSet_Public',
|
||||
help="用于存储所有分析结果的根目录。"
|
||||
)
|
||||
parser.add_argument(
|
||||
'--repeat-times',
|
||||
type=int,
|
||||
default=3,
|
||||
help="Number of times to repeat the benchmark for averaging."
|
||||
)
|
||||
args = parser.parse_args()
|
||||
main(args)
|
||||
@@ -0,0 +1,332 @@
|
||||
import os
|
||||
import glob
|
||||
import logging
|
||||
import argparse
|
||||
import re
|
||||
import subprocess
|
||||
import csv
|
||||
from typing import Dict, Optional, Tuple, List
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
||||
|
||||
# --- 辅助函数 ---
|
||||
def find_model_files(model_dir: str):
|
||||
"""
|
||||
在给定的模型目录中查找配置文件、最佳检查点和日志文件。
|
||||
|
||||
Args:
|
||||
model_dir (str): 模型的根目录。
|
||||
|
||||
Returns:
|
||||
Optional]: 包含 'config', 'checkpoint', 'log' 路径的字典,
|
||||
如果缺少任何必要文件,则返回 None。
|
||||
"""
|
||||
config_files = glob.glob(os.path.join(model_dir, '*.py'))
|
||||
if not config_files:
|
||||
logging.warning(f"在目录 {model_dir} 中未找到配置文件 (.py)。")
|
||||
return None
|
||||
config_path = config_files[0]
|
||||
|
||||
checkpoint_path = os.path.join(model_dir, 'best.pth')
|
||||
if not os.path.exists(checkpoint_path):
|
||||
epoch_files = glob.glob(os.path.join(model_dir, 'epoch_*.pth'))
|
||||
if not epoch_files:
|
||||
logging.warning(f"在目录 {model_dir} 中未找到 'best.pth' 或 'epoch_*.pth' 检查点文件。")
|
||||
return None
|
||||
|
||||
# 通过正则表达式从文件名中提取周期数并找到最大的
|
||||
latest_epoch = -1
|
||||
latest_file = None
|
||||
for f in epoch_files:
|
||||
match = re.search(r'epoch_(\d+)\.pth', os.path.basename(f))
|
||||
if match:
|
||||
epoch_num = int(match.group(1))
|
||||
if epoch_num > latest_epoch:
|
||||
latest_epoch = epoch_num
|
||||
latest_file = f
|
||||
|
||||
if latest_file:
|
||||
checkpoint_path = latest_file
|
||||
else:
|
||||
logging.warning(f"在目录 {model_dir} 中无法确定最新的检查点文件。")
|
||||
return None
|
||||
|
||||
return {'config': config_path, 'checkpoint': checkpoint_path}
|
||||
|
||||
def find_model_config(model_dir: str):
|
||||
"""
|
||||
在给定的模型目录中查找配置文件 (.py)。
|
||||
|
||||
Args:
|
||||
model_dir (str): 模型的根目录。
|
||||
|
||||
Returns:
|
||||
Optional[str]: 配置文件的路径,如果未找到则返回 None。
|
||||
"""
|
||||
config_files = glob.glob(os.path.join(model_dir, '*.py'))
|
||||
if not config_files:
|
||||
logging.warning(f"在目录 {model_dir} 中未找到配置文件 (.py)。")
|
||||
return None
|
||||
return config_files[0]
|
||||
|
||||
def get_shape_from_path(path: str):
|
||||
"""
|
||||
从文件夹路径中通过正则表达式提取分辨率 (宽x高)。
|
||||
|
||||
Args:
|
||||
path (str): 数据集文件夹的路径。
|
||||
|
||||
Returns:
|
||||
Optional[Tuple[int, int]]: 一个包含 (高度, 宽度) 的元组,如果未找到则返回 None。
|
||||
注意:工具需要 H W 格式。
|
||||
"""
|
||||
match = re.search(r'(\d+)x(\d+)', os.path.basename(path))
|
||||
if match:
|
||||
width, height = int(match.group(1)), int(match.group(2))
|
||||
return (height, width) # 返回 H, W
|
||||
return None
|
||||
|
||||
def get_flops_and_params(config_path: str, shape: Tuple[int, int]) -> Optional[Dict[str, str]]:
|
||||
"""
|
||||
运行 mmsegmentation 的 get_flops.py 工具并解析其输出。
|
||||
此版本适配了新版的直接输出格式 (例如 "Flops: 0.118T")。
|
||||
|
||||
Args:
|
||||
config_path (str): 模型的 .py 配置文件路径。
|
||||
shape (Tuple[int, int]): 输入图像的 (H, W) 元组。
|
||||
|
||||
Returns:
|
||||
Optional[Dict[str, str]]: 包含 'params' 和 'flops' 的字典,如果失败则返回 None。
|
||||
"""
|
||||
# 检查工具脚本是否存在
|
||||
tool_script = 'tools/analysis_tools/get_flops.py'
|
||||
if not os.path.exists(tool_script):
|
||||
logging.error(f"错误: '{tool_script}' 未找到。请确保在 MMSegmentation 项目的根目录下运行此脚本。")
|
||||
return None
|
||||
|
||||
# 构建命令行
|
||||
command = [
|
||||
'python', tool_script, config_path,
|
||||
'--shape', str(shape[0]), str(shape[1])
|
||||
]
|
||||
|
||||
logging.info(f"执行命令: {' '.join(command)}")
|
||||
|
||||
try:
|
||||
# 执行命令并捕获输出
|
||||
result = subprocess.run(command, capture_output=True, text=True, check=True, encoding='utf-8')
|
||||
output = result.stdout
|
||||
|
||||
# 使用新的正则表达式来匹配更新后的输出格式
|
||||
flops_match = re.search(r"Flops:\s*([0-9.]+\s*[TGMK]?)", output)
|
||||
params_match = re.search(r"Params:\s*([0-9.]+\s*[TGMK]?)", output)
|
||||
|
||||
if flops_match and params_match:
|
||||
raw_flops_str = flops_match.group(1).strip()
|
||||
params = params_match.group(1).strip()
|
||||
# --- 开始单位换算 ---
|
||||
value_str = raw_flops_str.rstrip('TGMKtgmk').strip()
|
||||
unit = raw_flops_str[-1].upper() if raw_flops_str[-1].isalpha() else 'G'
|
||||
try:
|
||||
value = float(value_str)
|
||||
if unit == 'T':
|
||||
value_in_g = value * 1000
|
||||
elif unit == 'M':
|
||||
value_in_g = value / 1000
|
||||
elif unit == 'K':
|
||||
value_in_g = value / 1_000_000
|
||||
else: # 默认单位是 G
|
||||
value_in_g = value
|
||||
# 使用 :g 格式化可以去除末尾多余的0
|
||||
flops = f"{value_in_g:g} G"
|
||||
except ValueError:
|
||||
flops = raw_flops_str # 如果转换失败,则使用原始值
|
||||
# --- 单位换算结束 ---
|
||||
|
||||
logging.info(f"✅ 解析成功: FLOPs={flops} (原始值: {raw_flops_str}), Params={params}")
|
||||
return {'flops': flops, 'params': params}
|
||||
else:
|
||||
logging.warning(f"❌ 无法从命令输出中解析 FLOPs 或 Params。请检查以下输出内容:\n---\n{output}\n---")
|
||||
return None
|
||||
|
||||
except subprocess.CalledProcessError as e:
|
||||
logging.error(f"执行 get_flops.py 时出错。返回码: {e.returncode}")
|
||||
logging.error(f"错误输出 (stderr):\n---\n{e.stderr}\n---")
|
||||
return None
|
||||
except FileNotFoundError:
|
||||
logging.error("错误: 'python' 命令未找到。请确保 Python 环境已正确配置。")
|
||||
return None
|
||||
|
||||
# --- 主函数 ---
|
||||
def main(args):
|
||||
"""
|
||||
脚本主入口,负责编排整个自动化分析流程。
|
||||
"""
|
||||
input_root = args.input_dir
|
||||
output_root = args.output_dir
|
||||
# --- 开始交互式选择修改 (V2 - 两级菜单) ---
|
||||
if not os.path.isdir(input_root):
|
||||
logging.error(f"输入目录不存在: {input_root}")
|
||||
return
|
||||
|
||||
# 1. 定义有效的数据集文件夹白名单
|
||||
VALID_DATASET_FOLDERS = [
|
||||
'1_CholecSeg8k-13Type-1920x1080_outputs-MMSeg',
|
||||
'2_AutoLaparo-10Type-1920x1080_outputs-MMSeg',
|
||||
'3_1_Endovis_2017-8Type-512x512_outputs-MMSeg',
|
||||
'3_2_Endovis_2018-8Type-512x512_outputs-MMSeg',
|
||||
'4_Dresden-11Type-512x512_outputs-MMSeg'
|
||||
]
|
||||
|
||||
# 2. 查找存在的、有效的数据集目录
|
||||
existing_dataset_dirs = [
|
||||
os.path.join(input_root, d) for d in VALID_DATASET_FOLDERS
|
||||
if os.path.isdir(os.path.join(input_root, d))
|
||||
]
|
||||
|
||||
if not existing_dataset_dirs:
|
||||
logging.error(f"在输入目录 {input_root} 中未找到任何有效的数据集文件夹。")
|
||||
return
|
||||
|
||||
# 3. 第一级菜单:选择数据集
|
||||
dataset_map = {str(i + 1): path for i, path in enumerate(existing_dataset_dirs)}
|
||||
|
||||
print("\n" + "="*50)
|
||||
print("--- 步骤 1: 请选择要处理的数据集 ---")
|
||||
for key, path in dataset_map.items():
|
||||
print(f"{key}: {os.path.basename(path)}")
|
||||
print("="*50)
|
||||
|
||||
choice1 = input("请输入数据集编号并按回车键: ").strip()
|
||||
|
||||
model_dirs = [] # 初始化最终要处理的目录列表
|
||||
|
||||
if choice1 in dataset_map:
|
||||
selected_dataset_dir = dataset_map[choice1]
|
||||
logging.info(f"您已选择数据集: [{os.path.basename(selected_dataset_dir)}]")
|
||||
|
||||
# 4. 查找选定数据集下的所有算法子目录
|
||||
alg_dirs = sorted([
|
||||
d for d in glob.glob(os.path.join(selected_dataset_dir, '*')) if os.path.isdir(d)
|
||||
])
|
||||
|
||||
if not alg_dirs:
|
||||
logging.warning(f"在 {os.path.basename(selected_dataset_dir)} 中未发现任何算法子文件夹。程序退出。")
|
||||
else:
|
||||
# 5. 第二级菜单:选择算法
|
||||
alg_map = {str(i + 1): path for i, path in enumerate(alg_dirs)}
|
||||
print("\n" + "="*50)
|
||||
print("--- 步骤 2: 请选择要处理的算法 ---")
|
||||
print("0: 批量处理当前数据集下的【全部】算法")
|
||||
for key, path in alg_map.items():
|
||||
print(f"{key}: {os.path.basename(path)}")
|
||||
print("="*50)
|
||||
|
||||
choice2 = input("请输入算法编号 (或输入 '0' 处理全部) 并按回车键: ").strip()
|
||||
|
||||
# 6. 根据第二级选择,最终确定 model_dirs
|
||||
if choice2 == '0':
|
||||
model_dirs = alg_dirs
|
||||
logging.info(f"您选择了批量处理全部 {len(model_dirs)} 个算法。")
|
||||
elif choice2 in alg_map:
|
||||
model_dirs = [alg_map[choice2]] # 将单个路径放入列表中
|
||||
logging.info(f"您选择了处理单个算法: {os.path.basename(model_dirs[0])}")
|
||||
else:
|
||||
logging.error("无效的算法选择,程序已退出。")
|
||||
else:
|
||||
logging.error("无效的数据集选择,程序已退出。")
|
||||
|
||||
# --- 交互式选择修改结束 ---
|
||||
results: List[Dict[str, str]] = []
|
||||
# 修改后的循环,将遍历经过用户筛选后的 model_dirs 列表
|
||||
for model_dir in model_dirs:
|
||||
model_name = os.path.basename(model_dir)
|
||||
logging.info(f"--- 开始处理模型: {model_name} ---")
|
||||
files = find_model_files(model_dir)
|
||||
if not files:
|
||||
logging.warning(f"跳过目录 {model_dir},因为缺少必要文件。")
|
||||
continue
|
||||
|
||||
# 构建输出目录
|
||||
# 从模型名中提取数据集标识作为Key
|
||||
dataset_key = model_name.split('-')[0]
|
||||
dataset_map = {
|
||||
'1_cholecseg8k': '1_CholecSeg8k-13Type-1920x1080_outputs-MMSeg',
|
||||
'2_autolaparo': '2_AutoLaparo-10Type-1280x1024_outputs-MMSeg',
|
||||
'3_1_endovis_2017': '3_1_EndoVis_2017-7Type-1280x1024_outputs-MMSeg',
|
||||
'3_2_endovis_2018': '3_2_EndoVis_2018-11Type-1280x1024_outputs-MMSeg',
|
||||
'4_dresden': '4_Dresden-6Type-1920x1080_outputs-MMSeg'
|
||||
}
|
||||
# 使用提取的Key(字符串)进行查询,并为默认值也使用该Key
|
||||
output_dataset_folder = dataset_map.get(dataset_key, f"{dataset_key}_outputs-MMSeg")
|
||||
|
||||
# 尝试从数据集文件夹名称中获取分辨率
|
||||
input_shape = get_shape_from_path(selected_dataset_dir)
|
||||
if not input_shape:
|
||||
# 如果无法自动提取,要求用户输入
|
||||
logging.warning(f"无法从文件夹 '{os.path.basename(selected_dataset_dir)}' 名称中自动检测分辨率。")
|
||||
try:
|
||||
h_str = input("请输入默认测试高度 (H),例如 512: ").strip()
|
||||
w_str = input("请输入默认测试宽度 (W),例如 512: ").strip()
|
||||
input_shape = (int(h_str), int(w_str))
|
||||
except ValueError:
|
||||
logging.error("输入无效,必须是整数。程序退出。")
|
||||
return
|
||||
logging.info(f"将使用输入形状 (H, W): {input_shape} 进行计算。")
|
||||
|
||||
# 加载配置
|
||||
config_file = find_model_config(model_dir)
|
||||
|
||||
# 获取 FLOPs 和 Params
|
||||
stats = get_flops_and_params(config_file, input_shape)
|
||||
if stats:
|
||||
short_model_name = model_name.split('Alg_', 1)[1]
|
||||
results.append({
|
||||
'Model': short_model_name,
|
||||
'Params': stats['params'],
|
||||
'FLOPs': stats['flops'],
|
||||
'Input_Shape (HxW)': f"{input_shape[0]}x{input_shape[1]}"
|
||||
})
|
||||
else:
|
||||
logging.warning(f"未能获取模型 {model_name} 的统计信息。")
|
||||
|
||||
# --- 将结果写入 CSV 文件 ---
|
||||
if not results:
|
||||
logging.info("没有成功获取任何模型的统计数据,不生成 CSV 文件。")
|
||||
return
|
||||
|
||||
# 新建文件夹并保存 CSV
|
||||
final_output_dir = os.path.join(output_root, output_dataset_folder)
|
||||
os.makedirs(final_output_dir, exist_ok=True)
|
||||
dataset_name = os.path.basename(selected_dataset_dir).split('_outputs-MMSeg')[0]
|
||||
output_csv_path = os.path.join(final_output_dir, f'{dataset_name}_flops_params_summary.csv')
|
||||
|
||||
try:
|
||||
with open(output_csv_path, 'w', newline='', encoding='utf-8') as csvfile:
|
||||
fieldnames = ['Model', 'Params', 'FLOPs', 'Input_Shape (HxW)']
|
||||
writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
|
||||
|
||||
writer.writeheader()
|
||||
writer.writerows(results)
|
||||
|
||||
logging.info(f"=== 全部处理完成!结果已成功保存到: {output_csv_path} ===")
|
||||
except IOError as e:
|
||||
logging.error(f"无法写入 CSV 文件: {output_csv_path}。错误: {e}")
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser(description="MMSegmentation 自动化评估脚本")
|
||||
parser.add_argument(
|
||||
'--input_dir',
|
||||
type=str,
|
||||
default='../Hardisk',
|
||||
help="包含已训练模型文件夹的根目录。"
|
||||
)
|
||||
parser.add_argument(
|
||||
'--output_dir',
|
||||
type=str,
|
||||
default='../BestMode_Predict_Results_DataSet_Public',
|
||||
help="用于存储所有分析结果的根目录。"
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
main(args)
|
||||
@@ -0,0 +1,322 @@
|
||||
import os
|
||||
import glob
|
||||
import logging
|
||||
import argparse
|
||||
import re
|
||||
import csv
|
||||
from typing import Dict, Optional, List
|
||||
|
||||
# TODO 这个是获取最后一次结果的 TODO
|
||||
|
||||
# --- 配置日志记录 ---
|
||||
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
||||
|
||||
# --- 辅助函数 ---
|
||||
def find_all_log_files_sorted(algorithm_dir: str) -> List[str]:
|
||||
"""
|
||||
查找给定算法目录中所有的.log文件,并按从新到旧的顺序排列。
|
||||
|
||||
Args:
|
||||
algorithm_dir (str): 算法的根目录。
|
||||
|
||||
Returns:
|
||||
List[str]: 按时间倒序排列的日志文件路径列表。
|
||||
"""
|
||||
try:
|
||||
subdirs = [d for d in os.listdir(algorithm_dir) if os.path.isdir(os.path.join(algorithm_dir, d))]
|
||||
except FileNotFoundError:
|
||||
logging.error(f"算法目录不存在: {algorithm_dir}")
|
||||
return []
|
||||
|
||||
if not subdirs:
|
||||
logging.warning(f"在目录 {algorithm_dir} 中未找到任何时间戳子目录。")
|
||||
return []
|
||||
|
||||
# 按名称倒序排序,最新的目录会排在最前面
|
||||
sorted_subdirs = sorted(subdirs, reverse=True)
|
||||
|
||||
log_files = []
|
||||
for subdir_name in sorted_subdirs:
|
||||
subdir_path = os.path.join(algorithm_dir, subdir_name)
|
||||
logs_in_subdir = glob.glob(os.path.join(subdir_path, '*.log'))
|
||||
if logs_in_subdir:
|
||||
# 假设每个子目录只有一个log文件
|
||||
log_files.append(logs_in_subdir[0])
|
||||
|
||||
return log_files
|
||||
|
||||
def parse_log_metrics(log_path: str) -> Optional[Dict]:
|
||||
"""
|
||||
解析日志文件,提取最后一次完整验证(validation)的结果及其对应的Epoch。
|
||||
|
||||
Args:
|
||||
log_path (str): 日志文件的路径。
|
||||
|
||||
Returns:
|
||||
Optional[Dict]: 包含 'epoch', 'summary', 'class_wise' 指标的字典,如果解析失败则返回 None。
|
||||
"""
|
||||
try:
|
||||
with open(log_path, 'r', encoding='utf-8') as f:
|
||||
content = f.read()
|
||||
except IOError as e:
|
||||
logging.error(f"无法读取日志文件: {log_path}。错误: {e}")
|
||||
return None
|
||||
# 定义正则表达式
|
||||
summary_pattern = re.compile(
|
||||
r"Iter\(val\) \[\d+/\d+\]\s+aAcc:\s*([\d.]+)\s+mIoU:\s*([\d.]+)\s+mAcc:\s*([\d.]+)"
|
||||
)
|
||||
class_table_pattern = re.compile(
|
||||
# --- 使用新的模式匹配可变长度的顶部边框 ---
|
||||
r"\+(?:-+\+)+\s*\n"
|
||||
r"\|.*?Class.*?\|.*?IoU.*?\|.*?Acc.*?\|\n"
|
||||
# --- 匹配中间边框 ---
|
||||
r"\+(?:-+\+)+\s*\n"
|
||||
# --- 捕获表格主体 ---
|
||||
r"((?:\|.*?\|.*?\|.*?\|\n)+)"
|
||||
# --- 匹配底部边框 ---
|
||||
r"\+(?:-+\+)+\s*\n",
|
||||
re.MULTILINE
|
||||
)
|
||||
epoch_pattern = re.compile(r"Saving checkpoint at (\d+) epochs|resumed epoch: (\d+)")
|
||||
|
||||
# 查找所有匹配项
|
||||
summary_matches = list(re.finditer(summary_pattern, content))
|
||||
table_matches = list(re.finditer(class_table_pattern, content))
|
||||
epoch_matches = list(re.finditer(epoch_pattern, content))
|
||||
|
||||
if not summary_matches or not table_matches:
|
||||
logging.warning(f"❌ 在日志 {os.path.basename(log_path)} 中未能找到完整的验证结果。")
|
||||
return None
|
||||
|
||||
last_summary_match = summary_matches[-1]
|
||||
last_table_match = None
|
||||
for table in reversed(table_matches):
|
||||
if table.end() < last_summary_match.start():
|
||||
last_table_match = table
|
||||
break
|
||||
|
||||
if not last_table_match:
|
||||
logging.warning(f"❌ 在日志 {os.path.basename(log_path)} 中找到总结行但未能匹配到对应的类别表格。")
|
||||
return None
|
||||
|
||||
# 寻找关联的最新Epoch
|
||||
last_epoch = "N/A"
|
||||
latest_epoch_num = -1
|
||||
for epoch_match in epoch_matches:
|
||||
if epoch_match.end() < last_table_match.start():
|
||||
epoch_str = epoch_match.group(1) or epoch_match.group(2)
|
||||
if epoch_str:
|
||||
epoch_num = int(epoch_str)
|
||||
if epoch_num > latest_epoch_num:
|
||||
latest_epoch_num = epoch_num
|
||||
last_epoch = f"epoch_{epoch_num}"
|
||||
|
||||
# 解析数据
|
||||
summary_groups = last_summary_match.groups()
|
||||
results = {
|
||||
'epoch': last_epoch,
|
||||
'summary': {
|
||||
'aAcc': summary_groups[0],
|
||||
'mIoU': summary_groups[1],
|
||||
'mAcc': summary_groups[2]
|
||||
},
|
||||
'class_wise': []
|
||||
}
|
||||
|
||||
table_content = last_table_match.group(0)
|
||||
row_pattern = re.compile(r"\|\s*([\w\s]+?)\s*\|\s*([\d.]+)\s*\|\s*([\d.]+)\s*\|")
|
||||
for line in table_content.strip().split('\n'):
|
||||
row_match = row_pattern.match(line)
|
||||
if row_match:
|
||||
class_name, iou, acc = row_match.groups()
|
||||
results['class_wise'].append({
|
||||
'Class': class_name.strip(),
|
||||
'IoU': iou,
|
||||
'Acc': acc
|
||||
})
|
||||
|
||||
if results['class_wise']:
|
||||
logging.info(f"✅ 成功从 {os.path.basename(log_path)} 中解析出 Epoch '{last_epoch}' 的指标。")
|
||||
return results
|
||||
else:
|
||||
logging.warning(f"❌ 在 {os.path.basename(log_path)} 中未能解析出任何类别行。")
|
||||
return None
|
||||
|
||||
# --- 主函数 ---
|
||||
def main(args):
|
||||
"""
|
||||
脚本主入口,负责编排整个自动化分析流程。
|
||||
"""
|
||||
input_root = args.input_dir
|
||||
output_root = args.output_dir
|
||||
|
||||
if not os.path.isdir(input_root):
|
||||
logging.error(f"输入目录不存在: {input_root}")
|
||||
return
|
||||
|
||||
# --- 交互式菜单 ---
|
||||
all_dataset_dirs = sorted([
|
||||
d for d in glob.glob(os.path.join(input_root, '*_outputs-MMSeg')) if os.path.isdir(d)
|
||||
])
|
||||
|
||||
if not all_dataset_dirs:
|
||||
logging.error(f"在输入目录 {input_root} 中未找到任何有效的数据集文件夹。")
|
||||
return
|
||||
|
||||
dataset_map = {str(i + 1): path for i, path in enumerate(all_dataset_dirs)}
|
||||
|
||||
print("\n" + "="*50)
|
||||
print("--- 步骤 1: 请选择要处理的数据集 ---")
|
||||
for key, path in dataset_map.items():
|
||||
print(f"{key}: {os.path.basename(path)}")
|
||||
print("="*50)
|
||||
|
||||
choice1 = input("请输入数据集编号并按回车键: ").strip()
|
||||
|
||||
model_dirs = []
|
||||
selected_dataset_dir = None
|
||||
|
||||
if choice1 in dataset_map:
|
||||
selected_dataset_dir = dataset_map[choice1]
|
||||
logging.info(f"您已选择数据集: [{os.path.basename(selected_dataset_dir)}]")
|
||||
|
||||
alg_dirs = sorted([
|
||||
d for d in glob.glob(os.path.join(selected_dataset_dir, '*')) if os.path.isdir(d)
|
||||
])
|
||||
|
||||
if not alg_dirs:
|
||||
logging.warning(f"在 {os.path.basename(selected_dataset_dir)} 中未发现任何算法子文件夹。")
|
||||
return
|
||||
|
||||
alg_map = {str(i + 1): path for i, path in enumerate(alg_dirs)}
|
||||
print("\n" + "="*50)
|
||||
print("--- 步骤 2: 请选择要处理的算法 ---")
|
||||
print("0: 批量处理当前数据集下的【全部】算法")
|
||||
for key, path in alg_map.items():
|
||||
print(f"{key}: {os.path.basename(path)}")
|
||||
print("="*50)
|
||||
|
||||
choice2 = input("请输入算法编号 (或输入 '0' 处理全部) 并按回车键: ").strip()
|
||||
|
||||
if choice2 == '0':
|
||||
model_dirs = alg_dirs
|
||||
logging.info(f"您选择了批量处理全部 {len(model_dirs)} 个算法。")
|
||||
elif choice2 in alg_map:
|
||||
model_dirs = [alg_map[choice2]]
|
||||
logging.info(f"您选择了处理单个算法: {os.path.basename(model_dirs[0])}")
|
||||
else:
|
||||
logging.error("无效的算法选择,程序已退出。")
|
||||
return
|
||||
else:
|
||||
logging.error("无效的数据集选择,程序已退出。")
|
||||
return
|
||||
|
||||
# --- 开始处理选定的算法 (逻辑已修改) ---
|
||||
csv_rows = []
|
||||
output_dataset_folder = ""
|
||||
|
||||
for model_dir in model_dirs:
|
||||
model_name = os.path.basename(model_dir)
|
||||
logging.info(f"\n--- 开始处理算法: {model_name} ---")
|
||||
|
||||
# (路径构建代码保持不变)
|
||||
if not output_dataset_folder:
|
||||
dataset_key = model_name.split('-')[0]
|
||||
dataset_folder_map = {
|
||||
'1_cholecseg8k': '1_CholecSeg8k-13Type-1920x1080_outputs-MMSeg',
|
||||
'2_autolaparo': '2_AutoLaparo-10Type-1920x1080_outputs-MMSeg',
|
||||
'3_1_endovis_2017': '3_1_Endovis_2017-8Type-512x512_outputs-MMSeg',
|
||||
'3_2_endovis_2018': '3_2_Endovis_2018-8Type-512x512_outputs-MMSeg',
|
||||
'4_dresden': '4_Dresden-11Type-512x512_outputs-MMSeg'
|
||||
}
|
||||
output_dataset_folder = dataset_folder_map.get(dataset_key, f"{dataset_key}_outputs-MMSeg")
|
||||
|
||||
# --- 新的循环查找逻辑 ---
|
||||
# 1. 获取所有按时间倒序排列的日志文件
|
||||
all_logs_sorted = find_all_log_files_sorted(model_dir)
|
||||
|
||||
if not all_logs_sorted:
|
||||
logging.warning(f"跳过算法 {model_name},因为未找到任何日志文件。")
|
||||
continue
|
||||
|
||||
# 2. 循环尝试解析,直到成功或全部失败
|
||||
metrics = None
|
||||
for log_file_path in all_logs_sorted:
|
||||
logging.info(f"正在尝试解析: {os.path.relpath(log_file_path)}")
|
||||
metrics = parse_log_metrics(log_file_path)
|
||||
if metrics:
|
||||
logging.info(f"在 {os.path.basename(log_file_path)} 中成功找到并解析了指标。")
|
||||
break # 找到后立即跳出循环
|
||||
|
||||
# 3. 如果所有日志都尝试失败,则跳过此算法
|
||||
if not metrics:
|
||||
logging.warning(f"❌❌❌跳过算法 {model_name},因为在其所有日志文件中都未能找到有效的指标。❌❌❌")
|
||||
continue
|
||||
|
||||
# --- 创建一个 "宽" 格式的行 ---
|
||||
summary = metrics['summary']
|
||||
short_model_name = model_name.split('Alg_', 1)[1]
|
||||
row_data = {
|
||||
'Algorithm': short_model_name,
|
||||
'Epoch': metrics['epoch'],
|
||||
'mIoU': summary['mIoU'],
|
||||
'mAcc': summary['mAcc'],
|
||||
'aAcc': summary['aAcc']
|
||||
}
|
||||
|
||||
# 将每个类别的IoU和Acc作为新列添加到行数据中
|
||||
for class_data in metrics['class_wise']:
|
||||
class_name = class_data['Class'].replace(' ', '_') # 清理类名以用作表头
|
||||
row_data[f'{class_name}_IoU'] = class_data['IoU']
|
||||
row_data[f'{class_name}_Acc'] = class_data['Acc']
|
||||
|
||||
csv_rows.append(row_data)
|
||||
|
||||
# --- 将结果写入 CSV 文件 (逻辑已修改) ---
|
||||
if not csv_rows:
|
||||
logging.info("没有成功获取任何模型的统计数据,不生成 CSV 文件。")
|
||||
return
|
||||
|
||||
# --- 动态生成并排序表头 ---
|
||||
# 基础列保持固定顺序
|
||||
base_fieldnames = ['Algorithm', 'Epoch', 'mIoU', 'mAcc', 'aAcc']
|
||||
# 从第一个结果中获取所有与类别相关的列名,并按字母排序
|
||||
first_row_keys = csv_rows[0].keys()
|
||||
class_fieldnames = sorted([key for key in first_row_keys if key not in base_fieldnames])
|
||||
# 最终的完整表头
|
||||
final_fieldnames = base_fieldnames + class_fieldnames
|
||||
|
||||
# --- 构建输出路径并写入文件 ---
|
||||
final_output_dir = os.path.join(output_root, output_dataset_folder)
|
||||
os.makedirs(final_output_dir, exist_ok=True)
|
||||
dataset_name = os.path.basename(selected_dataset_dir).split('_outputs-MMSeg')[0]
|
||||
# 在文件名中加入 "_wide" 以区分格式
|
||||
output_csv_path = os.path.join(final_output_dir, f'{dataset_name}_metrics_summary_wide.csv')
|
||||
|
||||
try:
|
||||
with open(output_csv_path, 'w', newline='', encoding='utf-8-sig') as csvfile:
|
||||
writer = csv.DictWriter(csvfile, fieldnames=final_fieldnames)
|
||||
writer.writeheader()
|
||||
writer.writerows(csv_rows)
|
||||
|
||||
logging.info(f"=== 全部处理完成!结果已成功保存到: {output_csv_path} ===")
|
||||
except IOError as e:
|
||||
logging.error(f"无法写入 CSV 文件: {output_csv_path}。错误: {e}")
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser(description="MMSegmentation 最终指标提取脚本 (V2)")
|
||||
parser.add_argument(
|
||||
'--input_dir',
|
||||
type=str,
|
||||
default='../Hardisk',
|
||||
help="包含数据集输出文件夹 (例如 '..._outputs-MMSeg') 的根目录。"
|
||||
)
|
||||
parser.add_argument(
|
||||
'--output_dir',
|
||||
type=str,
|
||||
default='../BestMode_Predict_Results_DataSet_Public',
|
||||
help="用于存储所有分析结果的根目录。"
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
main(args)
|
||||
@@ -0,0 +1,366 @@
|
||||
import os
|
||||
import glob
|
||||
import logging
|
||||
import argparse
|
||||
import re
|
||||
import csv
|
||||
from typing import Dict, Optional, List
|
||||
|
||||
# TODO 这个是获取最后一次结果的 TODO
|
||||
|
||||
# --- 配置日志记录 ---
|
||||
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
||||
|
||||
# --- 辅助函数 ---
|
||||
def find_all_log_files_sorted(algorithm_dir: str) -> List[str]:
|
||||
"""
|
||||
查找给定算法目录中所有的.log文件,并按从新到旧的顺序排列。
|
||||
|
||||
Args:
|
||||
algorithm_dir (str): 算法的根目录。
|
||||
|
||||
Returns:
|
||||
List[str]: 按时间倒序排列的日志文件路径列表。
|
||||
"""
|
||||
try:
|
||||
subdirs = [d for d in os.listdir(algorithm_dir) if os.path.isdir(os.path.join(algorithm_dir, d))]
|
||||
except FileNotFoundError:
|
||||
logging.error(f"算法目录不存在: {algorithm_dir}")
|
||||
return []
|
||||
|
||||
if not subdirs:
|
||||
logging.warning(f"在目录 {algorithm_dir} 中未找到任何时间戳子目录。")
|
||||
return []
|
||||
|
||||
# 按名称倒序排序,最新的目录会排在最前面
|
||||
sorted_subdirs = sorted(subdirs, reverse=True)
|
||||
|
||||
log_files = []
|
||||
for subdir_name in sorted_subdirs:
|
||||
subdir_path = os.path.join(algorithm_dir, subdir_name)
|
||||
logs_in_subdir = glob.glob(os.path.join(subdir_path, '*.log'))
|
||||
if logs_in_subdir:
|
||||
# 假设每个子目录只有一个log文件
|
||||
log_files.append(logs_in_subdir[0])
|
||||
|
||||
return log_files
|
||||
|
||||
def get_max_epochs(config_path: str) -> Optional[int]:
|
||||
"""
|
||||
从config.py文件中解析train_cfg字典以获取max_epochs的值。
|
||||
|
||||
Args:
|
||||
config_path (str): config.py文件的路径。
|
||||
|
||||
Returns:
|
||||
Optional[int]: max_epochs的值,如果找不到则返回None。
|
||||
"""
|
||||
if not os.path.exists(config_path):
|
||||
logging.error(f"配置文件不存在: {config_path}")
|
||||
return None
|
||||
|
||||
try:
|
||||
with open(config_path, 'r', encoding='utf-8') as f:
|
||||
content = f.read()
|
||||
|
||||
# 使用正则表达式查找 max_epochs
|
||||
match = re.search(r"train_cfg\s*=\s*dict\(.*?max_epochs\s*=\s*(\d+),.*?\)", content, re.DOTALL)
|
||||
|
||||
if match:
|
||||
max_epochs = int(match.group(1))
|
||||
logging.info(f"从 {os.path.basename(config_path)} 中成功读取 max_epochs: {max_epochs}")
|
||||
return max_epochs
|
||||
else:
|
||||
logging.warning(f"在 {config_path} 中未找到 'max_epochs'。")
|
||||
return None
|
||||
except Exception as e:
|
||||
logging.error(f"解析 {config_path} 时出错: {e}")
|
||||
return None
|
||||
|
||||
def parse_log_metrics(log_path: str, max_epochs: int) -> List[Dict]:
|
||||
"""
|
||||
解析日志文件,提取所有完整验证(validation)的结果,并计算其对应的Epoch。
|
||||
|
||||
Args:
|
||||
log_path (str): 日志文件的路径。
|
||||
max_epochs (int): 从config.py中读取的最大epoch数。
|
||||
|
||||
Returns:
|
||||
List[Dict]: 包含每次验证的 'epoch', 'summary', 'class_wise' 指标的字典列表。
|
||||
"""
|
||||
try:
|
||||
with open(log_path, 'r', encoding='utf-8') as f:
|
||||
content = f.read()
|
||||
except IOError as e:
|
||||
logging.error(f"无法读取日志文件: {log_path}。错误: {e}")
|
||||
return []
|
||||
|
||||
# 定义所有需要的正则表达式
|
||||
summary_pattern = re.compile(
|
||||
r"Iter\(val\) \[\d+/\d+\]\s+aAcc:\s*([\d.]+)\s+mIoU:\s*([\d.]+)\s+mAcc:\s*([\d.]+)"
|
||||
)
|
||||
class_table_pattern = re.compile(
|
||||
r"\+(?:-+\+)+\s*\n"
|
||||
r"\|.*?Class.*?\|.*?IoU.*?\|.*?Acc.*?\|\n"
|
||||
r"\+(?:-+\+)+\s*\n"
|
||||
r"((?:\|.*?\|.*?\|.*?\|\n)+)"
|
||||
r"\+(?:-+\+)+\s*\n",
|
||||
re.MULTILINE
|
||||
)
|
||||
train_iter_pattern = re.compile(r"Iter\(train\)\s*\[\s*(\d+)\s*/\s*(\d+)\]")
|
||||
|
||||
# 查找所有匹配项
|
||||
summary_matches = list(re.finditer(summary_pattern, content))
|
||||
table_matches = list(re.finditer(class_table_pattern, content))
|
||||
train_iter_matches = list(re.finditer(train_iter_pattern, content))
|
||||
|
||||
all_metrics = []
|
||||
|
||||
# 遍历每一次的总结行 (summary)
|
||||
for i, summary_match in enumerate(summary_matches):
|
||||
# 寻找与总结行对应的类别表格
|
||||
# 表格应该出现在总结行之前
|
||||
last_table_match = None
|
||||
for table in reversed(table_matches):
|
||||
if table.end() < summary_match.start():
|
||||
# 确保这个表格没有被上一个总结行用过
|
||||
is_already_used = False
|
||||
if i > 0:
|
||||
if table.end() < summary_matches[i-1].start():
|
||||
is_already_used = True
|
||||
if not is_already_used:
|
||||
last_table_match = table
|
||||
break
|
||||
|
||||
if not last_table_match:
|
||||
continue
|
||||
|
||||
# 寻找表格前最近的 Iter(train) 行来计算epoch
|
||||
last_train_iter_match = None
|
||||
for train_iter in reversed(train_iter_matches):
|
||||
if train_iter.end() < last_table_match.start():
|
||||
last_train_iter_match = train_iter
|
||||
break
|
||||
|
||||
epoch = "N/A"
|
||||
if last_train_iter_match and max_epochs:
|
||||
current_iter, total_iters = last_train_iter_match.groups()
|
||||
try:
|
||||
# 根据公式计算epoch
|
||||
epoch = int(int(current_iter) / int(total_iters) * max_epochs)
|
||||
except (ValueError, ZeroDivisionError) as e:
|
||||
logging.warning(f"Epoch 计算失败: {e}")
|
||||
|
||||
# 解析总结指标
|
||||
summary_groups = summary_match.groups()
|
||||
results = {
|
||||
'epoch': epoch,
|
||||
'summary': {
|
||||
'aAcc': summary_groups[0],
|
||||
'mIoU': summary_groups[1],
|
||||
'mAcc': summary_groups[2]
|
||||
},
|
||||
'class_wise': []
|
||||
}
|
||||
|
||||
# 解析每个类别的数据
|
||||
table_content = last_table_match.group(1)
|
||||
row_pattern = re.compile(r"\|\s*([\w\s.-]+?)\s*\|\s*([\d.]+)\s*\|\s*([\d.]+)\s*\|")
|
||||
for line in table_content.strip().split('\n'):
|
||||
row_match = row_pattern.match(line)
|
||||
if row_match:
|
||||
class_name, iou, acc = row_match.groups()
|
||||
results['class_wise'].append({
|
||||
'Class': class_name.strip(),
|
||||
'IoU': iou,
|
||||
'Acc': acc
|
||||
})
|
||||
|
||||
if results['class_wise']:
|
||||
all_metrics.append(results)
|
||||
|
||||
if all_metrics:
|
||||
logging.info(f"✅ 成功从 {os.path.basename(log_path)} 中解析出 {len(all_metrics)} 组指标。")
|
||||
else:
|
||||
logging.warning(f"❌ 在日志 {os.path.basename(log_path)} 中未能找到完整的验证结果。")
|
||||
|
||||
return all_metrics
|
||||
|
||||
# --- 主函数 ---
|
||||
def main(args):
|
||||
"""
|
||||
脚本主入口,负责编排整个自动化分析流程。
|
||||
"""
|
||||
input_root = args.input_dir
|
||||
output_root = args.output_dir
|
||||
|
||||
if not os.path.isdir(input_root):
|
||||
logging.error(f"输入目录不存在: {input_root}")
|
||||
return
|
||||
|
||||
# --- 交互式菜单 (这部分保持不变) ---
|
||||
all_dataset_dirs = sorted([
|
||||
d for d in glob.glob(os.path.join(input_root, '*_outputs-MMSeg')) if os.path.isdir(d)
|
||||
])
|
||||
|
||||
if not all_dataset_dirs:
|
||||
logging.error(f"在输入目录 {input_root} 中未找到任何有效的数据集文件夹。")
|
||||
return
|
||||
|
||||
dataset_map = {str(i + 1): path for i, path in enumerate(all_dataset_dirs)}
|
||||
|
||||
print("\n" + "="*50)
|
||||
print("--- 步骤 1: 请选择要处理的数据集 ---")
|
||||
for key, path in dataset_map.items():
|
||||
print(f"{key}: {os.path.basename(path)}")
|
||||
print("="*50)
|
||||
|
||||
choice1 = input("请输入数据集编号并按回车键: ").strip()
|
||||
|
||||
model_dirs = []
|
||||
selected_dataset_dir = None
|
||||
|
||||
if choice1 in dataset_map:
|
||||
selected_dataset_dir = dataset_map[choice1]
|
||||
logging.info(f"您已选择数据集: [{os.path.basename(selected_dataset_dir)}]")
|
||||
|
||||
alg_dirs = sorted([
|
||||
d for d in glob.glob(os.path.join(selected_dataset_dir, '*')) if os.path.isdir(d)
|
||||
])
|
||||
|
||||
if not alg_dirs:
|
||||
logging.warning(f"在 {os.path.basename(selected_dataset_dir)} 中未发现任何算法子文件夹。")
|
||||
return
|
||||
|
||||
alg_map = {str(i + 1): path for i, path in enumerate(alg_dirs)}
|
||||
print("\n" + "="*50)
|
||||
print("--- 步骤 2: 请选择要处理的算法 ---")
|
||||
print("0: 批量处理当前数据集下的【全部】算法")
|
||||
for key, path in alg_map.items():
|
||||
print(f"{key}: {os.path.basename(path)}")
|
||||
print("="*50)
|
||||
|
||||
choice2 = input("请输入算法编号 (或输入 '0' 处理全部) 并按回车键: ").strip()
|
||||
|
||||
if choice2 == '0':
|
||||
model_dirs = alg_dirs
|
||||
logging.info(f"您选择了批量处理全部 {len(model_dirs)} 个算法。")
|
||||
elif choice2 in alg_map:
|
||||
model_dirs = [alg_map[choice2]]
|
||||
logging.info(f"您选择了处理单个算法: {os.path.basename(model_dirs[0])}")
|
||||
else:
|
||||
logging.error("无效的算法选择,程序已退出。")
|
||||
return
|
||||
else:
|
||||
logging.error("无效的数据集选择,程序已退出。")
|
||||
return
|
||||
|
||||
# --- 开始处理选定的算法 ---
|
||||
csv_rows = []
|
||||
output_dataset_folder = ""
|
||||
|
||||
for model_dir in model_dirs:
|
||||
model_name = os.path.basename(model_dir)
|
||||
logging.info(f"\n--- 开始处理算法: {model_name} ---")
|
||||
|
||||
if not output_dataset_folder:
|
||||
dataset_key = model_name.split('-')[0]
|
||||
# ... (这部分路径映射逻辑保持不变)
|
||||
output_dataset_folder = f"{dataset_key}_outputs-MMSeg"
|
||||
|
||||
all_logs_sorted = find_all_log_files_sorted(model_dir)
|
||||
|
||||
if not all_logs_sorted:
|
||||
logging.warning(f"跳过算法 {model_name},因为未找到任何日志文件。")
|
||||
continue
|
||||
|
||||
# --- 新逻辑:获取max_epochs ---
|
||||
# 假设同一算法下所有训练的config是相同的,因此我们从最新的log对应的config读取
|
||||
latest_log_path = all_logs_sorted[0]
|
||||
config_path = os.path.join(os.path.dirname(latest_log_path), 'vis_data', 'config.py')
|
||||
max_epochs = get_max_epochs(config_path)
|
||||
if max_epochs is None:
|
||||
logging.error(f"无法为算法 {model_name} 找到 max_epochs,将跳过。")
|
||||
continue
|
||||
|
||||
# --- 新逻辑:聚合所有日志的所有指标 ---
|
||||
all_metrics_for_model = []
|
||||
for log_file_path in all_logs_sorted:
|
||||
logging.info(f"正在解析: {os.path.relpath(log_file_path)}")
|
||||
# 传递max_epochs
|
||||
metrics_from_log = parse_log_metrics(log_file_path, max_epochs)
|
||||
if metrics_from_log:
|
||||
all_metrics_for_model.extend(metrics_from_log)
|
||||
|
||||
if not all_metrics_for_model:
|
||||
logging.warning(f"❌❌❌跳过算法 {model_name},因为在其所有日志文件中都未能找到有效的指标。❌❌❌")
|
||||
continue
|
||||
|
||||
# --- 新逻辑:选择mIoU最高的记录 ---
|
||||
try:
|
||||
best_metric = max(all_metrics_for_model, key=lambda x: float(x['summary']['mIoU']))
|
||||
logging.info(f"找到了最佳指标: Epoch '{best_metric['epoch']}', mIoU: {best_metric['summary']['mIoU']}")
|
||||
except (ValueError, TypeError) as e:
|
||||
logging.error(f"为算法 {model_name} 寻找最佳mIoU时出错: {e}")
|
||||
continue
|
||||
|
||||
# --- 创建一个 "宽" 格式的行 ---
|
||||
summary = best_metric['summary']
|
||||
short_model_name = model_name.split('Alg_', 1)[1] if 'Alg_' in model_name else model_name
|
||||
row_data = {
|
||||
'Algorithm': short_model_name,
|
||||
'Epoch': best_metric['epoch'],
|
||||
'mIoU': summary['mIoU'],
|
||||
'mAcc': summary['mAcc'],
|
||||
'aAcc': summary['aAcc']
|
||||
}
|
||||
|
||||
for class_data in best_metric['class_wise']:
|
||||
class_name = class_data['Class'].replace(' ', '_')
|
||||
row_data[f'{class_name}_IoU'] = class_data['IoU']
|
||||
row_data[f'{class_name}_Acc'] = class_data['Acc']
|
||||
|
||||
csv_rows.append(row_data)
|
||||
|
||||
# --- 将结果写入 CSV 文件 (这部分保持不变) ---
|
||||
if not csv_rows:
|
||||
logging.info("没有成功获取任何模型的统计数据,不生成 CSV 文件。")
|
||||
return
|
||||
|
||||
base_fieldnames = ['Algorithm', 'Epoch', 'mIoU', 'mAcc', 'aAcc']
|
||||
first_row_keys = csv_rows[0].keys()
|
||||
class_fieldnames = sorted([key for key in first_row_keys if key not in base_fieldnames])
|
||||
final_fieldnames = base_fieldnames + class_fieldnames
|
||||
|
||||
final_output_dir = os.path.join(output_root, os.path.basename(selected_dataset_dir))
|
||||
os.makedirs(final_output_dir, exist_ok=True)
|
||||
dataset_name = os.path.basename(selected_dataset_dir).split('_outputs-MMSeg')[0]
|
||||
output_csv_path = os.path.join(final_output_dir, f'{dataset_name}_metrics_summary_wide.csv')
|
||||
|
||||
try:
|
||||
with open(output_csv_path, 'w', newline='', encoding='utf-8-sig') as csvfile:
|
||||
writer = csv.DictWriter(csvfile, fieldnames=final_fieldnames)
|
||||
writer.writeheader()
|
||||
writer.writerows(csv_rows)
|
||||
|
||||
logging.info(f"\n=== 全部处理完成!最佳结果已成功保存到: {output_csv_path} ===")
|
||||
except IOError as e:
|
||||
logging.error(f"无法写入 CSV 文件: {output_csv_path}。错误: {e}")
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser(description="MMSegmentation 最终指标提取脚本 (V2)")
|
||||
parser.add_argument(
|
||||
'--input_dir',
|
||||
type=str,
|
||||
default='../Hardisk',
|
||||
help="包含数据集输出文件夹 (例如 '..._outputs-MMSeg') 的根目录。"
|
||||
)
|
||||
parser.add_argument(
|
||||
'--output_dir',
|
||||
type=str,
|
||||
default='../BestMode_Predict_Results_DataSet_Public',
|
||||
help="用于存储所有分析结果的根目录。"
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
main(args)
|
||||
@@ -0,0 +1,249 @@
|
||||
import os
|
||||
import glob
|
||||
import pandas as pd
|
||||
import matplotlib.pyplot as plt
|
||||
import seaborn as sns
|
||||
import re
|
||||
|
||||
def get_model_family(model_name):
|
||||
"""
|
||||
根据模型名称提取模型族。
|
||||
例如: 'my_bisenetv1_r50' -> 'my_bisenetv1'
|
||||
'my_fast_scnn' -> 'my_fast_scnn'
|
||||
"""
|
||||
# 使用正则表达式匹配,将 _rXX 或 _dXX 等后缀去掉
|
||||
match = re.match(r'^(.*?)_r\d+$', model_name)
|
||||
if match:
|
||||
return match.group(1)
|
||||
return model_name
|
||||
|
||||
def select_dataset(results_dir):
|
||||
"""
|
||||
扫描目录,让用户交互式选择一个数据集。
|
||||
"""
|
||||
print("正在扫描可用的数据集...")
|
||||
try:
|
||||
all_dataset_dirs = sorted([
|
||||
d for d in glob.glob(os.path.join(results_dir, '*_outputs-MMSeg')) if os.path.isdir(d)
|
||||
])
|
||||
except Exception as e:
|
||||
print(f"扫描目录 '{results_dir}' 时出错: {e}")
|
||||
return None, None
|
||||
|
||||
if not all_dataset_dirs:
|
||||
print(f"在 '{results_dir}' 中未找到任何数据集目录 (以 '_outputs-MMSeg' 结尾)。")
|
||||
print("请确保脚本与 'BestMode_Predict_Results_DataSet_Public' 文件夹在同一级目录下。")
|
||||
return None, None
|
||||
|
||||
print("\n请选择要可视化的数据集:")
|
||||
for i, dir_path in enumerate(all_dataset_dirs):
|
||||
dataset_name = os.path.basename(dir_path).replace('_outputs-MMSeg', '')
|
||||
print(f" [{i+1}] {dataset_name}")
|
||||
|
||||
while True:
|
||||
try:
|
||||
choice = input(f"\n请输入选项编号 (1-{len(all_dataset_dirs)}): ")
|
||||
choice_idx = int(choice) - 1
|
||||
if 0 <= choice_idx < len(all_dataset_dirs):
|
||||
selected_dir = all_dataset_dirs[choice_idx]
|
||||
dataset_name = os.path.basename(selected_dir).replace('_outputs-MMSeg', '')
|
||||
return selected_dir, dataset_name
|
||||
else:
|
||||
print("无效的选项,请输入列表中的编号。")
|
||||
except (ValueError, IndexError):
|
||||
print("无效的输入,请输入一个数字编号。")
|
||||
except (KeyboardInterrupt, EOFError):
|
||||
print("\n操作已取消。")
|
||||
return None, None
|
||||
|
||||
def plot_performance_speed(selected_dir, dataset_name):
|
||||
"""
|
||||
根据选定的数据集目录,加载数据并生成图表。
|
||||
"""
|
||||
print(f"\n正在为数据集 '{dataset_name}' 生成图表...")
|
||||
|
||||
# 构建文件路径
|
||||
metrics_file = os.path.join(selected_dir, f"{dataset_name}_metrics_summary_wide.csv")
|
||||
fps_file = os.path.join(selected_dir, f"{dataset_name}_flops_params_fps_summary.csv")
|
||||
|
||||
# 检查文件是否存在
|
||||
if not os.path.exists(metrics_file) or not os.path.exists(fps_file):
|
||||
print(f"错误: 在目录 '{selected_dir}' 中缺少所需的数据文件。")
|
||||
print(f" - 检查是否存在: {os.path.basename(metrics_file)}")
|
||||
print(f" - 检查是否存在: {os.path.basename(fps_file)}")
|
||||
return
|
||||
|
||||
# 加载数据
|
||||
try:
|
||||
metrics_df = pd.read_csv(metrics_file)
|
||||
# 只保留最新的epoch结果,避免重复
|
||||
metrics_df = metrics_df.sort_values('Epoch', ascending=False).drop_duplicates('Algorithm')
|
||||
|
||||
fps_df = pd.read_csv(fps_file)
|
||||
except FileNotFoundError as e:
|
||||
print(f"错误: 无法找到文件 {e.filename}")
|
||||
return
|
||||
except Exception as e:
|
||||
print(f"读取CSV文件时出错: {e}")
|
||||
return
|
||||
|
||||
# 合并两个DataFrame
|
||||
# metrics_df中的'Algorithm'列对应fps_df中的'Model'列
|
||||
merged_df = pd.merge(metrics_df, fps_df, left_on='Algorithm', right_on='Model')
|
||||
|
||||
if merged_df.empty:
|
||||
print("错误: 数据合并失败。请检查 'Algorithm' 和 'Model' 列中的模型名称是否匹配。")
|
||||
return
|
||||
|
||||
# 调用新函数来创建并保存摘要表格
|
||||
T1_create_and_save_summary_table(merged_df, selected_dir, dataset_name)
|
||||
|
||||
# 调用新函数来提取和保存所有IoU数据
|
||||
T2_extract_and_save_iou_data(metrics_df, selected_dir, dataset_name)
|
||||
|
||||
# 提取模型族
|
||||
merged_df['Family'] = merged_df['Model'].apply(get_model_family)
|
||||
|
||||
# --- 绘图 ---
|
||||
plt.style.use('seaborn-v0_8-whitegrid')
|
||||
fig, ax = plt.subplots(figsize=(16, 10))
|
||||
|
||||
# 定义颜色和标记
|
||||
families = sorted(merged_df['Family'].unique())
|
||||
palette = sns.color_palette("husl", len(families))
|
||||
markers = ['o', 's', 'X', 'D', '^', 'P', '*', 'v', '<', '>']
|
||||
|
||||
# 循环绘制每个模型族
|
||||
for i, family in enumerate(families):
|
||||
family_df = merged_df[merged_df['Family'] == family].sort_values('Average_FPS')
|
||||
color = palette[i]
|
||||
marker = markers[i % len(markers)]
|
||||
|
||||
# 绘制散点
|
||||
ax.scatter(family_df['Average_FPS'], family_df['mIoU'],
|
||||
color=color, marker=marker, s=150, label=family, zorder=3)
|
||||
|
||||
# 如果族内有多个模型,则用线连接
|
||||
if len(family_df) > 1:
|
||||
ax.plot(family_df['Average_FPS'], family_df['mIoU'],
|
||||
color=color, linestyle='--', linewidth=1.5, zorder=2)
|
||||
|
||||
# 在每个点旁边添加模型全名注释
|
||||
for j, row in family_df.iterrows():
|
||||
ax.text(row['Average_FPS'] * 1.01, row['mIoU'], row['Model'],
|
||||
fontsize=9, verticalalignment='center')
|
||||
|
||||
# 设置图表属性
|
||||
ax.set_title(f'Model Performance vs. Inference Speed ({dataset_name})', fontsize=18, pad=20)
|
||||
ax.set_xlabel('Inference Speed (FPS)', fontsize=14)
|
||||
ax.set_ylabel('Mean IoU (%)', fontsize=14)
|
||||
ax.legend(title='Model Family', bbox_to_anchor=(1.02, 1), loc='upper left', borderaxespad=0.)
|
||||
|
||||
plt.tight_layout(rect=[0, 0, 0.88, 1]) # 调整布局为图例留出空间
|
||||
plt.grid(True, which='both', linestyle='--', linewidth=0.5)
|
||||
|
||||
# 保存并显示图表
|
||||
output_filename_png = f"F1_{dataset_name}_mIoU_vs_FPS.png"
|
||||
save_file_path_png = os.path.join(selected_dir, output_filename_png)
|
||||
plt.savefig(save_file_path_png, dpi=600)
|
||||
output_filename_svg = f"F1_{dataset_name}_mIoU_vs_FPS.svg"
|
||||
save_file_path_svg = os.path.join(selected_dir, output_filename_svg)
|
||||
plt.savefig(save_file_path_svg)
|
||||
print(f"\n图表已成功生成并保存为: {save_file_path_svg} 和 {save_file_path_png}")
|
||||
plt.show()
|
||||
|
||||
def T1_create_and_save_summary_table(merged_df, output_dir, dataset_name):
|
||||
"""
|
||||
根据合并后的数据创建、格式化并保存性能摘要表格。
|
||||
|
||||
Args:
|
||||
merged_df (pd.DataFrame): 包含所有模型指标和性能数据的DataFrame。
|
||||
output_dir (str): 保存CSV文件的目标目录。
|
||||
dataset_name (str): 数据集的名称,用于生成文件名。
|
||||
"""
|
||||
print("正在创建摘要表格...")
|
||||
|
||||
# 检查所需列是否存在
|
||||
required_columns = ['Model', 'mIoU', 'mAcc', 'aAcc', 'Average_FPS', 'FLOPs', 'Params']
|
||||
if not all(col in merged_df.columns for col in required_columns):
|
||||
print("错误: DataFrame中缺少必要的列。请检查CSV文件内容。")
|
||||
return
|
||||
|
||||
# 提取并复制数据,避免修改原始DataFrame
|
||||
summary_df = merged_df[required_columns].copy()
|
||||
|
||||
# 清理和转换数据
|
||||
# 将 '118 G' -> 118.0
|
||||
summary_df['FLOPs'] = summary_df['FLOPs'].astype(str).str.replace(' G', '', regex=False).astype(float)
|
||||
# 将 '13.274M' -> 13.274
|
||||
summary_df['Params'] = summary_df['Params'].astype(str).str.replace('M', '', regex=False).astype(float)
|
||||
|
||||
# 按照用户的要求重命名列
|
||||
summary_df.rename(columns={
|
||||
'Average_FPS': 'FPS',
|
||||
'FLOPs': 'G(GFLOPS)',
|
||||
'Params': 'Para(Params)'
|
||||
}, inplace=True)
|
||||
|
||||
# 按 mIoU 降序排序
|
||||
summary_df = summary_df.sort_values(by='mIoU', ascending=False)
|
||||
|
||||
# 保存表格到CSV文件
|
||||
summary_filename = f"T1_{dataset_name}_performance_summary.csv"
|
||||
summary_save_path = os.path.join(output_dir, summary_filename)
|
||||
|
||||
try:
|
||||
summary_df.to_csv(summary_save_path, index=False, float_format='%.3f')
|
||||
print(f"摘要表格已成功保存到: {summary_save_path}")
|
||||
except Exception as e:
|
||||
print(f"保存摘要表格时出错: {e}")
|
||||
|
||||
def T2_extract_and_save_iou_data(metrics_df, output_dir, dataset_name):
|
||||
"""
|
||||
从 metrics DataFrame 中提取所有 mIoU 和 Class_IoU,并保存到新的CSV文件。
|
||||
|
||||
Args:
|
||||
metrics_df (pd.DataFrame): 包含所有指标的原始DataFrame。
|
||||
output_dir (str): 保存CSV文件的目标目录。
|
||||
dataset_name (str): 数据集的名称,用于生成文件名。
|
||||
"""
|
||||
print("正在提取所有 mIoU 和 Class_IoU 数据...")
|
||||
|
||||
# 检查'Algorithm'列是否存在
|
||||
if 'Algorithm' not in metrics_df.columns:
|
||||
print("错误: 'Algorithm' 列未找到,无法继续。")
|
||||
return
|
||||
|
||||
# 找出所有与IoU相关的列
|
||||
# 包括 'mIoU' 以及所有以 '_IoU' 结尾的列
|
||||
iou_columns = ['Algorithm', 'mIoU'] + [col for col in metrics_df.columns if col.endswith('_IoU') and col != 'mIoU']
|
||||
|
||||
# 移除重复的列名(以防万一)
|
||||
iou_columns = list(dict.fromkeys(iou_columns))
|
||||
|
||||
# 提取数据
|
||||
iou_df = metrics_df[iou_columns].copy()
|
||||
|
||||
# 按 mIoU 降序排序,便于查看
|
||||
iou_df = iou_df.sort_values(by='mIoU', ascending=False)
|
||||
|
||||
# 定义并保存文件
|
||||
iou_filename = f"T2_{dataset_name}_all_iou_summary.csv"
|
||||
iou_save_path = os.path.join(output_dir, iou_filename)
|
||||
|
||||
try:
|
||||
iou_df.to_csv(iou_save_path, index=False, float_format='%.2f')
|
||||
print(f"所有IoU数据已成功保存到: {iou_save_path}")
|
||||
except Exception as e:
|
||||
print(f"保存IoU数据时出错: {e}")
|
||||
|
||||
if __name__ == '__main__':
|
||||
# 设置包含所有数据集结果的根目录
|
||||
results_root_dir = '../BestMode_Predict_Results_DataSet_Public'
|
||||
|
||||
# 启动交互式选择
|
||||
selected_directory, selected_dataset_name = select_dataset(results_root_dir)
|
||||
|
||||
# 如果用户成功选择,则生成图表
|
||||
if selected_directory and selected_dataset_name:
|
||||
plot_performance_speed(selected_directory, selected_dataset_name)
|
||||
@@ -0,0 +1,413 @@
|
||||
import os
|
||||
import glob
|
||||
import logging
|
||||
import argparse
|
||||
import re
|
||||
import csv
|
||||
from typing import Dict, Optional, List
|
||||
from collections import defaultdict
|
||||
|
||||
# --- 新增导入:用于绘图 ---
|
||||
try:
|
||||
import pandas as pd
|
||||
import matplotlib.pyplot as plt
|
||||
MATPLOTLIB_AVAILABLE = True
|
||||
except ImportError:
|
||||
MATPLOTLIB_AVAILABLE = False
|
||||
logging.warning(
|
||||
"未找到 'pandas' 或 'matplotlib' 库。"
|
||||
"脚本将只生成CSV文件,无法自动绘图。"
|
||||
"请运行 'pip install pandas matplotlib' 来安装它们。"
|
||||
)
|
||||
# --- 导入结束 ---
|
||||
|
||||
|
||||
# --- 配置日志记录 ---
|
||||
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
||||
|
||||
# --- 辅助函数 (与 4_3 版本相同) ---
|
||||
|
||||
def find_all_log_files_sorted(algorithm_dir: str) -> List[str]:
|
||||
"""
|
||||
查找给定算法目录中所有的.log文件,并按从新到旧的顺序排列。
|
||||
"""
|
||||
try:
|
||||
subdirs = [d for d in os.listdir(algorithm_dir) if os.path.isdir(os.path.join(algorithm_dir, d))]
|
||||
except FileNotFoundError:
|
||||
logging.error(f"算法目录不存在: {algorithm_dir}")
|
||||
return []
|
||||
|
||||
if not subdirs:
|
||||
logging.warning(f"在目录 {algorithm_dir} 中未找到任何时间戳子目录。")
|
||||
return []
|
||||
|
||||
sorted_subdirs = sorted(subdirs, reverse=True)
|
||||
log_files = []
|
||||
for subdir_name in sorted_subdirs:
|
||||
subdir_path = os.path.join(algorithm_dir, subdir_name)
|
||||
logs_in_subdir = glob.glob(os.path.join(subdir_path, '*.log'))
|
||||
if logs_in_subdir:
|
||||
log_files.append(logs_in_subdir[0])
|
||||
|
||||
return log_files
|
||||
|
||||
def get_max_epochs(config_path: str) -> Optional[int]:
|
||||
"""
|
||||
从config.py文件中解析train_cfg字典以获取max_epochs的值。
|
||||
"""
|
||||
if not os.path.exists(config_path):
|
||||
logging.error(f"配置文件不存在: {config_path}")
|
||||
return None
|
||||
|
||||
try:
|
||||
with open(config_path, 'r', encoding='utf-8') as f:
|
||||
content = f.read()
|
||||
|
||||
match = re.search(r"train_cfg\s*=\s*dict\(.*?max_epochs\s*=\s*(\d+),.*?\)", content, re.DOTALL)
|
||||
|
||||
if match:
|
||||
max_epochs = int(match.group(1))
|
||||
logging.info(f"从 {os.path.basename(config_path)} 中成功读取 max_epochs: {max_epochs}")
|
||||
return max_epochs
|
||||
else:
|
||||
logging.warning(f"在 {config_path} 中未找到 'max_epochs'。")
|
||||
return None
|
||||
except Exception as e:
|
||||
logging.error(f"解析 {config_path} 时出错: {e}")
|
||||
return None
|
||||
|
||||
def parse_log_file_data(log_path: str, max_epochs: int) -> Dict:
|
||||
"""
|
||||
解析日志文件,提取所有训练损失和所有验证mIoU。
|
||||
"""
|
||||
try:
|
||||
with open(log_path, 'r', encoding='utf-8') as f:
|
||||
content = f.read()
|
||||
except IOError as e:
|
||||
logging.error(f"无法读取日志文件: {log_path}。错误: {e}")
|
||||
return {'training_losses': {}, 'validation_mious': []}
|
||||
|
||||
training_losses_by_epoch = defaultdict(list)
|
||||
validation_mious = []
|
||||
|
||||
# 1. 提取训练损失
|
||||
total_iters_match = re.search(r"Iter\(train\)\s*\[\s*(\d+)\s*/\s*(\d+)\]", content)
|
||||
total_iters = 0
|
||||
if total_iters_match:
|
||||
try:
|
||||
total_iters = int(total_iters_match.group(2))
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
if total_iters > 0:
|
||||
train_loss_pattern = re.compile(
|
||||
r"Iter\(train\)\s*\[\s*(\d+)\s*/\s*(\d+)\]"
|
||||
r"(?:.*?)loss:\s*([\d\.]+)"
|
||||
)
|
||||
for match in re.finditer(train_loss_pattern, content):
|
||||
try:
|
||||
current_iter = int(match.group(1))
|
||||
loss = float(match.group(3))
|
||||
epoch = int((current_iter / total_iters) * max_epochs)
|
||||
training_losses_by_epoch[epoch].append(loss)
|
||||
except (ValueError, ZeroDivisionError) as e:
|
||||
logging.warning(f"解析训练损失时出错: {e}")
|
||||
else:
|
||||
logging.warning(f"在 {os.path.basename(log_path)} 中未找到有效的 'Iter(train)' 行来确定总迭代次数。")
|
||||
|
||||
# 2. 提取验证 mIoU
|
||||
val_summary_pattern = re.compile(
|
||||
r"Iter\(val\) \[\d+/\d+\]\s+aAcc:\s*[\d.]+\s+mIoU:\s*([\d.]+)\s+mAcc:\s*[\d.]+"
|
||||
)
|
||||
for match in re.finditer(val_summary_pattern, content):
|
||||
try:
|
||||
miou = float(match.group(1))
|
||||
validation_mious.append(miou)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
log_name = os.path.basename(log_path)
|
||||
if training_losses_by_epoch:
|
||||
logging.info(f"✅ 从 {log_name} 提取了 {len(training_losses_by_epoch)} 个Epoch的训练损失。")
|
||||
if validation_mious:
|
||||
logging.info(f"✅ 从 {log_name} 提取了 {len(validation_mious)} 次验证的mIoU。")
|
||||
if not training_losses_by_epoch and not validation_mious:
|
||||
logging.warning(f"❌ 在 {log_name} 中未找到训练损失或验证mIoU。")
|
||||
|
||||
return {
|
||||
'training_losses': training_losses_by_epoch,
|
||||
'validation_mious': validation_mious
|
||||
}
|
||||
|
||||
# --- 新增:从 4_4 脚本中合并过来的绘图函数 ---
|
||||
|
||||
def plot_loss_curves(csv_path: str):
|
||||
"""
|
||||
读取_training_loss_summary.csv文件并绘制训练损失曲线。
|
||||
(此版本已根据用户需求修改)
|
||||
|
||||
Args:
|
||||
csv_path (str): 输入的CSV文件路径。
|
||||
"""
|
||||
if not MATPLOTLIB_AVAILABLE:
|
||||
logging.warning("由于缺少 'pandas' 或 'matplotlib',跳过绘图。")
|
||||
return
|
||||
|
||||
if not os.path.exists(csv_path):
|
||||
logging.error(f"[绘图] 文件未找到: {csv_path}")
|
||||
return
|
||||
|
||||
logging.info(f"[绘图] 正在读取数据: {os.path.basename(csv_path)}")
|
||||
try:
|
||||
df = pd.read_csv(csv_path)
|
||||
except Exception as e:
|
||||
logging.error(f"[绘图] 读取CSV时出错: {e}")
|
||||
return
|
||||
|
||||
try:
|
||||
df = df.set_index('Algorithm')
|
||||
except KeyError:
|
||||
logging.error("[绘图] CSV文件中未找到 'Algorithm' 列。")
|
||||
return
|
||||
|
||||
loss_cols = [col for col in df.columns if col.startswith('Epoch_') and col.endswith('_Loss')]
|
||||
|
||||
if not loss_cols:
|
||||
logging.warning("[绘图] 在CSV中未找到任何 'Epoch_X_Loss' 列。")
|
||||
return
|
||||
|
||||
try:
|
||||
epochs = [int(col.split('_')[1]) for col in loss_cols]
|
||||
except (ValueError, IndexError) as e:
|
||||
logging.error(f"[绘图] 解析Epoch列名时出错: {e}。列名格式应为 'Epoch_N_Loss'。")
|
||||
return
|
||||
|
||||
df_losses = df[loss_cols].apply(pd.to_numeric, errors='coerce')
|
||||
|
||||
# --- 开始绘图 ---
|
||||
logging.info("[绘图] 正在生成图表...")
|
||||
|
||||
fig, ax = plt.subplots(figsize=(14, 8))
|
||||
|
||||
for alg_name, row in df_losses.iterrows():
|
||||
# --- 修改点 1 ---
|
||||
# 移除 marker 和 markersize,使用实线 (linestyle='-')
|
||||
ax.plot(epochs, row.values, label=alg_name, linestyle='-')
|
||||
|
||||
# --- 设置图表样式 ---
|
||||
ax.set_xlabel('Epoch', fontsize=12)
|
||||
ax.set_ylabel('Average Training Loss', fontsize=12)
|
||||
ax.set_title(f'Training Loss per Epoch\n(Source: {os.path.basename(csv_path)})', fontsize=14)
|
||||
ax.grid(True, linestyle=':', alpha=0.7)
|
||||
|
||||
# --- 修改点 2 ---
|
||||
# 将Y轴(纵轴)的范围设置为 0 到 5
|
||||
ax.set_ylim(bottom=0, top=5) # TODO TODO
|
||||
|
||||
# 将图例放在图表右侧外部
|
||||
ax.legend(bbox_to_anchor=(1.04, 1), loc='upper left', title="Algorithms")
|
||||
|
||||
fig.tight_layout(rect=[0, 0, 0.85, 1])
|
||||
|
||||
# --- 保存图表 ---
|
||||
output_png_path = os.path.splitext(csv_path)[0] + '.png'
|
||||
try:
|
||||
plt.savefig(output_png_path, dpi=150, bbox_inches='tight')
|
||||
logging.info(f"✅ 图表已成功保存到: {output_png_path}")
|
||||
except IOError as e:
|
||||
logging.error(f"[绘图] 保存图像时出错: {e}")
|
||||
finally:
|
||||
plt.close(fig) # 释放内存
|
||||
|
||||
# --- 主函数 (小幅修改以调用绘图) ---
|
||||
|
||||
def main(args):
|
||||
"""
|
||||
脚本主入口,负责编排整个自动化分析流程。
|
||||
"""
|
||||
input_root = args.input_dir
|
||||
output_root = args.output_dir
|
||||
|
||||
if not os.path.isdir(input_root):
|
||||
logging.error(f"输入目录不存在: {input_root}")
|
||||
return
|
||||
|
||||
# --- 交互式菜单 (保持不变) ---
|
||||
all_dataset_dirs = sorted([
|
||||
d for d in glob.glob(os.path.join(input_root, '*_outputs-MMSeg')) if os.path.isdir(d)
|
||||
])
|
||||
if not all_dataset_dirs:
|
||||
logging.error(f"在输入目录 {input_root} 中未找到任何有效的数据集文件夹。")
|
||||
return
|
||||
dataset_map = {str(i + 1): path for i, path in enumerate(all_dataset_dirs)}
|
||||
print("\n" + "="*50)
|
||||
print("--- 步骤 1: 请选择要处理的数据集 ---")
|
||||
for key, path in dataset_map.items():
|
||||
print(f"{key}: {os.path.basename(path)}")
|
||||
print("="*50)
|
||||
choice1 = input("请输入数据集编号并按回车键: ").strip()
|
||||
|
||||
model_dirs = []
|
||||
selected_dataset_dir = None
|
||||
|
||||
if choice1 in dataset_map:
|
||||
selected_dataset_dir = dataset_map[choice1]
|
||||
logging.info(f"您已选择数据集: [{os.path.basename(selected_dataset_dir)}]")
|
||||
alg_dirs = sorted([
|
||||
d for d in glob.glob(os.path.join(selected_dataset_dir, '*')) if os.path.isdir(d)
|
||||
])
|
||||
if not alg_dirs:
|
||||
logging.warning(f"在 {os.path.basename(selected_dataset_dir)} 中未发现任何算法子文件夹。")
|
||||
return
|
||||
alg_map = {str(i + 1): path for i, path in enumerate(alg_dirs)}
|
||||
print("\n" + "="*50)
|
||||
print("--- 步骤 2: 请选择要处理的算法 ---")
|
||||
print("0: 批量处理当前数据集下的【全部】算法")
|
||||
for key, path in alg_map.items():
|
||||
print(f"{key}: {os.path.basename(path)}")
|
||||
print("="*50)
|
||||
choice2 = input("请输入算法编号 (或输入 '0' 处理全部) 并按回车键: ").strip()
|
||||
if choice2 == '0':
|
||||
model_dirs = alg_dirs
|
||||
logging.info(f"您选择了批量处理全部 {len(model_dirs)} 个算法。")
|
||||
elif choice2 in alg_map:
|
||||
model_dirs = [alg_map[choice2]]
|
||||
logging.info(f"您选择了处理单个算法: {os.path.basename(model_dirs[0])}")
|
||||
else:
|
||||
logging.error("无效的算法选择,程序已退出。")
|
||||
return
|
||||
else:
|
||||
logging.error("无效的数据集选择,程序已退出。")
|
||||
return
|
||||
# --- 交互式菜单结束 ---
|
||||
|
||||
|
||||
# --- 处理逻辑 (保持不变) ---
|
||||
csv_rows = []
|
||||
|
||||
for model_dir in model_dirs:
|
||||
model_name = os.path.basename(model_dir)
|
||||
logging.info(f"\n--- 开始处理算法: {model_name} ---")
|
||||
|
||||
all_logs_sorted = find_all_log_files_sorted(model_dir)
|
||||
|
||||
if not all_logs_sorted:
|
||||
logging.warning(f"跳过算法 {model_name},因为未找到任何日志文件。")
|
||||
continue
|
||||
|
||||
latest_log_path = all_logs_sorted[0]
|
||||
config_path = os.path.join(os.path.dirname(latest_log_path), 'vis_data', 'config.py')
|
||||
max_epochs = get_max_epochs(config_path)
|
||||
if max_epochs is None:
|
||||
logging.error(f"无法为算法 {model_name} 找到 max_epochs,将跳过。")
|
||||
continue
|
||||
|
||||
all_losses_for_model = defaultdict(list)
|
||||
all_mious_for_model = []
|
||||
|
||||
for log_file_path in all_logs_sorted:
|
||||
logging.info(f"正在解析: {os.path.relpath(log_file_path)}")
|
||||
parsed_data = parse_log_file_data(log_file_path, max_epochs)
|
||||
for epoch, losses in parsed_data['training_losses'].items():
|
||||
all_losses_for_model[epoch].extend(losses)
|
||||
if parsed_data['validation_mious']:
|
||||
all_mious_for_model.extend(parsed_data['validation_mious'])
|
||||
|
||||
if not all_losses_for_model and not all_mious_for_model:
|
||||
logging.warning(f"❌❌❌跳过算法 {model_name},因为在其所有日志文件中都未能找到有效的训练或验证数据。❌❌❌")
|
||||
continue
|
||||
|
||||
best_miou = 0.0
|
||||
if all_mious_for_model:
|
||||
try:
|
||||
best_miou = max(all_mious_for_model)
|
||||
logging.info(f"找到了最佳 mIoU: {best_miou:.4f}")
|
||||
except (ValueError, TypeError) as e:
|
||||
logging.error(f"为算法 {model_name} 寻找最佳mIoU时出错: {e}")
|
||||
else:
|
||||
logging.warning(f"算法 {model_name} 没有找到任何mIoU数据。")
|
||||
|
||||
short_model_name = model_name.split('Alg_', 1)[1] if 'Alg_' in model_name else model_name
|
||||
row_data = {
|
||||
'Algorithm': short_model_name
|
||||
}
|
||||
|
||||
sorted_epochs = sorted(all_losses_for_model.keys())
|
||||
if not sorted_epochs:
|
||||
logging.warning(f"算法 {model_name} 没有找到任何训练损失数据。")
|
||||
|
||||
for epoch in sorted_epochs:
|
||||
losses = all_losses_for_model[epoch]
|
||||
avg_loss = sum(losses) / len(losses)
|
||||
row_data[f'Epoch_{epoch}_Loss'] = f"{avg_loss:.4f}"
|
||||
|
||||
row_data['Best_mIoU'] = f"{best_miou:.4f}"
|
||||
csv_rows.append(row_data)
|
||||
# --- 处理逻辑结束 ---
|
||||
|
||||
|
||||
# --- 写入 CSV 文件 (修改后) ---
|
||||
if not csv_rows:
|
||||
logging.info("没有成功获取任何模型的统计数据,不生成 CSV 文件。")
|
||||
return
|
||||
|
||||
# 动态生成所有列名
|
||||
all_fieldnames_set = set()
|
||||
for row in csv_rows:
|
||||
all_fieldnames_set.update(row.keys())
|
||||
|
||||
base_fields = ['Algorithm']
|
||||
miou_field = ['Best_mIoU']
|
||||
loss_fields = [f for f in all_fieldnames_set if f.startswith('Epoch_')]
|
||||
|
||||
try:
|
||||
loss_fields.sort(key=lambda x: int(x.split('_')[1]))
|
||||
except (ValueError, IndexError):
|
||||
logging.error("排序Epoch列名时出错,将按字母顺序排序。")
|
||||
loss_fields.sort()
|
||||
|
||||
final_fieldnames = base_fields + loss_fields + miou_field
|
||||
|
||||
final_output_dir = os.path.join(output_root, os.path.basename(selected_dataset_dir))
|
||||
os.makedirs(final_output_dir, exist_ok=True)
|
||||
dataset_name = os.path.basename(selected_dataset_dir).split('_outputs-MMSeg')[0]
|
||||
output_csv_path = os.path.join(final_output_dir, f'{dataset_name}_training_loss_summary.csv')
|
||||
|
||||
try:
|
||||
with open(output_csv_path, 'w', newline='', encoding='utf-8-sig') as csvfile:
|
||||
writer = csv.DictWriter(csvfile, fieldnames=final_fieldnames, extrasaction='ignore')
|
||||
writer.writeheader()
|
||||
writer.writerows(csv_rows)
|
||||
|
||||
logging.info(f"\n=== CSV文件已成功保存到: {output_csv_path} ===")
|
||||
|
||||
# --- 新增:调用绘图函数 ---
|
||||
# 仅在CSV成功写入后才尝试绘图
|
||||
plot_loss_curves(output_csv_path)
|
||||
# --- 新增结束 ---
|
||||
|
||||
except IOError as e:
|
||||
logging.error(f"无法写入 CSV 文件: {output_csv_path}。错误: {e}")
|
||||
except Exception as e_plot:
|
||||
# 捕获绘图时可能发生的其他错误
|
||||
logging.error(f"在主流程中调用绘图时出错: {e_plot}")
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser(
|
||||
description="MMSegmentation 训练损失提取与绘图脚本 (V3-Integrated)"
|
||||
)
|
||||
parser.add_argument(
|
||||
'--input_dir',
|
||||
type=str,
|
||||
default='../Hardisk',
|
||||
help="包含数据集输出文件夹 (例如 '..._outputs-MMSeg') 的根目录。"
|
||||
)
|
||||
parser.add_argument(
|
||||
'--output_dir',
|
||||
type=str,
|
||||
default='../BestMode_Predict_Results_DataSet_Public',
|
||||
help="用于存储所有分析结果 (CSV和PNG) 的根目录。"
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
main(args)
|
||||
@@ -0,0 +1,28 @@
|
||||
from .Initial_Alg_Gen_Tool import get_var_from_file
|
||||
|
||||
# 生成 _base_ 变量,传入算法配置、数据集配置、schedule配置
|
||||
def generate_base_config(alg_file_name, dataset_file_name, schedule_file_name):
|
||||
if alg_file_name != None:
|
||||
base_config =[
|
||||
f'../_base_/models/{alg_file_name}.py',
|
||||
f'../_base_/datasets/{dataset_file_name}.py', #换成自己定义的数据集
|
||||
f'../_base_/default_runtime.py',
|
||||
f'../_base_/schedules/{schedule_file_name}.py'
|
||||
]
|
||||
else:
|
||||
base_config =[
|
||||
f'../_base_/datasets/{dataset_file_name}.py', #换成自己定义的数据集
|
||||
f'../_base_/default_runtime.py',
|
||||
f'../_base_/schedules/{schedule_file_name}.py'
|
||||
]
|
||||
|
||||
return base_config
|
||||
|
||||
if __name__ == '__main__':
|
||||
# 示例用法
|
||||
alg_file_name = 'ann_r50-d8' # 算法根文件
|
||||
dataset_file_name = 'my_dataset_model' # 数据文件
|
||||
schedule_file_name = 'schedule_4k_check_400' # schedule文件
|
||||
|
||||
_base_ = generate_base_config(alg_file_name=alg_file_name, dataset_file_name=dataset_file_name, schedule_file_name=schedule_file_name)
|
||||
print(_base_)
|
||||
@@ -0,0 +1,19 @@
|
||||
from .Initial_Alg_Gen_Tool import get_var_from_file
|
||||
|
||||
# 单卡 norm_cfg = dict(type='BN')
|
||||
# 多卡 norm_cfg = dict(type='SyncBN')
|
||||
def generate_norm_cfg(GPU_num = 2):
|
||||
GPU_num = int(GPU_num)
|
||||
if GPU_num == 1:
|
||||
return dict(type='BN')
|
||||
elif GPU_num > 1:
|
||||
return dict(type='SyncBN')
|
||||
else:
|
||||
raise ValueError("GPU_num需要为大于等于1的整数数值")
|
||||
|
||||
if __name__ == '__main__':
|
||||
# 示例用法
|
||||
GPU_num = 1
|
||||
|
||||
norm_cfg = generate_norm_cfg(GPU_num=GPU_num)
|
||||
print(norm_cfg)
|
||||
@@ -0,0 +1,23 @@
|
||||
from .Initial_Alg_Gen_Tool import get_var_from_file
|
||||
|
||||
# crop_size 数据预处理是分割大小
|
||||
# crop_size = (512,512)
|
||||
def generate_data_preprocessor(crop_size=None, mean=None, std=None, bgr_to_rgb=False):
|
||||
|
||||
data_preprocessor = {}
|
||||
if crop_size != None:
|
||||
data_preprocessor['size']=crop_size
|
||||
if mean != None:
|
||||
data_preprocessor['mean']=mean
|
||||
if std != None:
|
||||
data_preprocessor['std']=std
|
||||
if bgr_to_rgb != None:
|
||||
data_preprocessor['bgr_to_rgb']=bgr_to_rgb
|
||||
return data_preprocessor
|
||||
|
||||
if __name__ == '__main__':
|
||||
# 示例用法
|
||||
crop_size = (512,512)
|
||||
|
||||
data_preprocessor = generate_data_preprocessor(crop_size = crop_size)
|
||||
print(data_preprocessor)
|
||||
@@ -0,0 +1,119 @@
|
||||
from .Initial_Alg_Gen_Tool import get_var_from_file, format_dict
|
||||
|
||||
# 1. 修改pretrained
|
||||
# pretrained_pth = './My_Local_Model/open_mmlab/resnet50_v1c.pth')
|
||||
def generate_model_pretrained(pretrained_pth=None):
|
||||
pretrained = pretrained_pth
|
||||
if pretrained_pth == None:
|
||||
return None
|
||||
return pretrained
|
||||
|
||||
# 2. 修改backbone
|
||||
# depth=50
|
||||
def generate_model_backbone(depth=None):
|
||||
backbone = {}
|
||||
if depth != None:
|
||||
backbone['depth']=depth
|
||||
return backbone
|
||||
|
||||
# 3. 修改model_data_preprocessor
|
||||
# model_data_preprocessor='data_preprocessor'
|
||||
def generate_model_data_preprocessor(model_data_preprocessor=None):
|
||||
if model_data_preprocessor != None:
|
||||
model_data_preprocessor = model_data_preprocessor
|
||||
return model_data_preprocessor
|
||||
|
||||
# 4. 修改decode_head
|
||||
# num_classes='36' # 分割数
|
||||
# decode_head_loss_decode_dict=dict(type='DiceLoss', use_sigmoid=False, loss_weight=1.0) # 直接传入dict
|
||||
# align_corners=False是否需要角对齐
|
||||
def generate_model_decode_head(num_classes=None, decode_head_loss_decode_dict=None, align_corners=None):
|
||||
decode_head = {}
|
||||
if num_classes != None:
|
||||
decode_head['num_classes']=num_classes
|
||||
if decode_head_loss_decode_dict != None:
|
||||
decode_head['loss_decode']=decode_head_loss_decode_dict
|
||||
if align_corners != None:
|
||||
decode_head['align_corners']=align_corners
|
||||
|
||||
return decode_head
|
||||
|
||||
# 5. 修改auxiliary_head
|
||||
# num_classes='36' # 分割数
|
||||
# auxiliary_head_loss_decode_dict=dict(type='DiceLoss', use_sigmoid=False, loss_weight=0.4) # 直接传入dict
|
||||
# align_corners=False是否需要角对齐
|
||||
def generate_model_auxiliary_head(num_classes=None, auxiliary_head_loss_decode_dict=None, align_corners=None):
|
||||
auxiliary_head = {}
|
||||
if num_classes != None:
|
||||
auxiliary_head['num_classes']=num_classes
|
||||
if auxiliary_head_loss_decode_dict != None:
|
||||
auxiliary_head['loss_decode']=auxiliary_head_loss_decode_dict
|
||||
if align_corners != None:
|
||||
auxiliary_head['align_corners']=align_corners
|
||||
|
||||
return auxiliary_head
|
||||
|
||||
# 6. 修改train_cfg
|
||||
def generate_model_train_cfg():
|
||||
train_cfg = {}
|
||||
|
||||
return train_cfg
|
||||
|
||||
# 6. 修改test_cfg
|
||||
# mode='slide'
|
||||
# crop_size = (767,767)
|
||||
# test_cfg_crop_div_stride = 1.5
|
||||
def generate_model_test_cfg(test_cfg_mode=None, crop_size=None, test_cfg_crop_div_stride=None):
|
||||
test_cfg = {}
|
||||
if test_cfg_mode != None:
|
||||
test_cfg['mode'] = test_cfg_mode
|
||||
if crop_size != None:
|
||||
test_cfg['crop_size'] = crop_size
|
||||
if test_cfg_crop_div_stride != None:
|
||||
test_cfg['stride'] = (int(crop_size[0]/test_cfg_crop_div_stride), int(crop_size[1]/test_cfg_crop_div_stride))
|
||||
|
||||
return test_cfg
|
||||
|
||||
if __name__ == '__main__':
|
||||
# 1. 修改pretrained
|
||||
pretrained_pth = './My_Local_Model/open_mmlab/resnet50_v1c.pth'
|
||||
pretrained = generate_model_pretrained(pretrained_pth = pretrained_pth)
|
||||
|
||||
# 2. 修改backbone
|
||||
depth = 50
|
||||
backbone = generate_model_backbone(depth=depth)
|
||||
|
||||
# 3. 修改data_preprocessor
|
||||
model_data_preprocessor = 'data_preprocessor'
|
||||
model_data_preprocessor = generate_model_data_preprocessor(model_data_preprocessor=model_data_preprocessor)
|
||||
|
||||
# 4.5. 修改decode_head、auxiliary_head
|
||||
num_classes='36' # 分割数
|
||||
decode_head_loss_decode_dict=dict(type='DiceLoss', use_sigmoid=False, loss_weight=1.0) # 直接传入dict
|
||||
align_corners=False # 是否需要角对齐
|
||||
auxiliary_head_loss_decode_dict=dict(type='DiceLoss', use_sigmoid=False, loss_weight=0.4)
|
||||
|
||||
decode_head = generate_model_decode_head(num_classes=num_classes, decode_head_loss_decode_dict=decode_head_loss_decode_dict, align_corners=align_corners)
|
||||
auxiliary_head = generate_model_auxiliary_head(num_classes=num_classes, auxiliary_head_loss_decode_dict=auxiliary_head_loss_decode_dict, align_corners=align_corners)
|
||||
|
||||
# 6. 修改train_cfg
|
||||
train_cfg = generate_model_train_cfg()
|
||||
|
||||
# 6. 修改test_cfg
|
||||
mode='slide'
|
||||
crop_size = (767,767)
|
||||
test_cfg_crop_div_stride = 1.5
|
||||
test_cfg = generate_model_test_cfg(mode=mode, crop_size=crop_size, test_cfg_crop_div_stride=test_cfg_crop_div_stride)
|
||||
|
||||
# 汇总为model
|
||||
model = dict(
|
||||
pretrained = pretrained,
|
||||
backbone = backbone,
|
||||
data_preprocessor = model_data_preprocessor,
|
||||
decode_head = decode_head,
|
||||
auxiliary_head = auxiliary_head,
|
||||
train_cfg = train_cfg,
|
||||
test_cfg = test_cfg,
|
||||
)
|
||||
model_ = format_dict(model)
|
||||
print(model_)
|
||||
@@ -0,0 +1,173 @@
|
||||
# optimizer(优化器设计)TODO
|
||||
# type_of_back_bone = "Vit"
|
||||
def generate_optim_wrapper(type_of_back_bone=None):
|
||||
|
||||
# optim_wrapper 配置-1
|
||||
optim_wrapper_1 = {
|
||||
'type': 'OptimWrapper', # 表示这是一个优化器包装器(OptimWrapper)
|
||||
'_delete_': True, # 通常用于在继承配置时删除旧的优化器配置,替换为新的优化器配置
|
||||
'optimizer': {
|
||||
'type': 'AdamW', # 优化器类型为 AdamW
|
||||
'lr': 0.0001, # 学习率,通常设定为一个较小的值
|
||||
'weight_decay': 0.0005 # 权重衰减系数
|
||||
},
|
||||
'clip_grad': {
|
||||
'max_norm': 1, # 梯度裁剪的最大范数
|
||||
'norm_type': 2 # L2 范数
|
||||
}
|
||||
}
|
||||
# optim_wrapper 配置-2
|
||||
optim_wrapper_2 = {
|
||||
'type': 'OptimWrapper', # 表示这是一个优化器包装器(OptimWrapper)
|
||||
'_delete_': True, # 通常用于在继承配置时删除旧的优化器配置,替换为新的优化器配置
|
||||
'optimizer': {
|
||||
'type': 'SGD', # 优化器类型为 SGD
|
||||
'lr': 0.05, # 学习率
|
||||
'weight_decay': 0.0005, # 权重衰减系数
|
||||
'momentum': 0.9
|
||||
},
|
||||
'clip_grad': {
|
||||
'max_norm': 1, # 梯度裁剪的最大范数
|
||||
'norm_type': 2 # L2 范数
|
||||
}
|
||||
}
|
||||
|
||||
optim_wrapper_list = [optim_wrapper_1, optim_wrapper_2]
|
||||
|
||||
# 打印所有可用的优化器选项
|
||||
while True:
|
||||
print("请选择 optim_wrapper (按 Enter 使用默认):")
|
||||
for i, scheduler in enumerate(optim_wrapper_list, 1):
|
||||
optimizer_type = scheduler['optimizer']['type']
|
||||
learning_rate = scheduler['optimizer']['lr']
|
||||
print(f"{i}. Optimizer: {optimizer_type}, LR: {learning_rate}")
|
||||
|
||||
# 如果是vit网络,则需要额外的paramwise_cfg
|
||||
if type_of_back_bone != None and (type_of_back_bone.lower() == 'vit' or type_of_back_bone.lower() == 'visiontransformer'):
|
||||
custom_keys={
|
||||
'pos_embed': dict(decay_mult=0.), # 位置嵌入(positional embeddings)。decay_mult=0. 意味着对这些嵌入不应用权重衰减
|
||||
'cls_token': dict(decay_mult=0.), # 是在某些模型(如 Transformer 或 BERT)中,用于分类任务的特定 token
|
||||
'norm': dict(decay_mult=0.) # 对归一化层的参数也禁用了权重衰减
|
||||
}
|
||||
optim_wrapper_list[i-1]['paramwise_cfg'] = dict(custom_keys)
|
||||
optim_wrapper_list[i-1]['_delete_'] = True
|
||||
|
||||
if type_of_back_bone != None and (type_of_back_bone.lower() == 'swin'):
|
||||
custom_keys={
|
||||
'absolute_pos_embed': dict(decay_mult=0.),
|
||||
'relative_position_bias_table': dict(decay_mult=0.),
|
||||
'norm': dict(decay_mult=0.)
|
||||
}
|
||||
optim_wrapper_list[i-1]['paramwise_cfg'] = dict(custom_keys)
|
||||
optim_wrapper_list[i-1]['_delete_'] = True
|
||||
|
||||
choice = input(f"输入 1 到 {len(optim_wrapper_list)} 进行选择(默认 1): ")
|
||||
|
||||
# 如果用户按下 Enter 或选择 1,默认返回 optim_wrapper_1
|
||||
if choice == '' or choice == '1':
|
||||
return optim_wrapper_list[0]
|
||||
elif choice.isdigit() and 1 <= int(choice) <= len(optim_wrapper_list):
|
||||
return optim_wrapper_list[int(choice) - 1]
|
||||
else:
|
||||
print(f"无效输入,请输入 1 到 {len(optim_wrapper_list)} 或按 Enter 使用默认值")
|
||||
|
||||
def generate_param_scheduler(train_type, train_time_or_epoch):
|
||||
"""
|
||||
根据训练模式(epoch或iteration)动态生成并选择学习率调度器。
|
||||
|
||||
Args:
|
||||
train_type (str): 训练模式, 'epoch' 或 'iteration'。
|
||||
train_time_or_epoch (int): 训练的总轮数(epochs)或总迭代次数(k)。
|
||||
"""
|
||||
# 1. 根据训练模式确定核心参数
|
||||
if train_type.lower() == 'epoch':
|
||||
by_epoch = True
|
||||
end_value = train_time_or_epoch
|
||||
warmup_end = 10 # 为epoch模式设置一个合理的10个epoch的预热期
|
||||
# 按比例调整MultiStepLR的milestones
|
||||
milestones = [int(end_value * 0.75), int(end_value * 0.9)]
|
||||
else: # 默认为 iteration 模式
|
||||
by_epoch = False
|
||||
end_value = train_time_or_epoch * 1000
|
||||
warmup_end = 1500 # iteration模式使用原有的1500次迭代预热
|
||||
# MultiStepLR的milestones, 同样适配总时长
|
||||
milestones = [int(end_value * 0.75), int(end_value * 0.9)]
|
||||
|
||||
# 2. 使用动态参数定义调度器模板
|
||||
# param_scheduler 配置-1
|
||||
param_scheduler_1 = [
|
||||
dict(
|
||||
type='LinearLR',
|
||||
start_factor=1e-6,
|
||||
by_epoch=by_epoch,
|
||||
begin=0,
|
||||
end=warmup_end),
|
||||
dict(
|
||||
type='PolyLR',
|
||||
power=0.9,
|
||||
begin=warmup_end,
|
||||
end=end_value,
|
||||
eta_min=1e-5,
|
||||
by_epoch=by_epoch)
|
||||
]
|
||||
# param_scheduler 配置-2
|
||||
param_scheduler_2 = [
|
||||
dict(
|
||||
type='LinearLR',
|
||||
start_factor=1e-6,
|
||||
by_epoch=by_epoch,
|
||||
begin=0,
|
||||
end=warmup_end),
|
||||
dict(
|
||||
type='PolyLR',
|
||||
power=1.0,
|
||||
begin=warmup_end,
|
||||
end=end_value,
|
||||
eta_min=0.0,
|
||||
by_epoch=by_epoch)
|
||||
]
|
||||
# param_scheduler 配置-3
|
||||
param_scheduler_3 = [
|
||||
dict(
|
||||
type='LinearLR',
|
||||
start_factor=0.001,
|
||||
by_epoch=by_epoch,
|
||||
begin=0,
|
||||
end=warmup_end),
|
||||
dict(
|
||||
type='MultiStepLR',
|
||||
begin=warmup_end,
|
||||
end=end_value,
|
||||
milestones=milestones,
|
||||
by_epoch=by_epoch)
|
||||
]
|
||||
param_scheduler_list = [param_scheduler_1, param_scheduler_2, param_scheduler_3]
|
||||
|
||||
# 3. 打印并让用户选择 (这部分逻辑不变)
|
||||
while True:
|
||||
print("请选择 param_scheduler (学习率调度器):")
|
||||
for i, schedulers in enumerate(param_scheduler_list, 1):
|
||||
print(f"Scheduler {i}:")
|
||||
for scheduler in schedulers:
|
||||
# 提取字段并处理缺省情况
|
||||
scheduler_type = scheduler.get('type', '/')
|
||||
begin = scheduler.get('begin', '/')
|
||||
end = scheduler.get('end', '/')
|
||||
power = scheduler.get('power', '/')
|
||||
eta_min = scheduler.get('eta_min', '/')
|
||||
milestones_val = scheduler.get('milestones', '/')
|
||||
|
||||
# 根据类型决定显示哪些信息
|
||||
if scheduler_type == 'MultiStepLR':
|
||||
print(f" - {scheduler_type}: begin={begin}, end={end}, milestones={milestones_val}")
|
||||
else:
|
||||
print(f" - {scheduler_type}: begin={begin}, end={end}, power={power}, eta_min={eta_min}")
|
||||
|
||||
choice = input(f"输入 1 到 {len(param_scheduler_list)} 进行选择(默认 1): ").strip()
|
||||
|
||||
if choice == '' or choice == '1':
|
||||
return param_scheduler_list[0]
|
||||
elif choice.isdigit() and 1 <= int(choice) <= len(param_scheduler_list):
|
||||
return param_scheduler_list[int(choice) - 1]
|
||||
else:
|
||||
print(f"无效输入,请输入 1 到 {len(param_scheduler_list)} 或按 Enter 使用默认值\n")
|
||||
@@ -0,0 +1,59 @@
|
||||
# train_dataloader TODO
|
||||
def generate_train_dataloader(batch_size_default=None, num_workers_default=None):
|
||||
batch_size = None
|
||||
if batch_size_default != None:
|
||||
while True:
|
||||
user_input = input(f"请输入 batch size (默认为 {batch_size_default}): ")
|
||||
|
||||
# 如果用户没有输入内容,使用默认值
|
||||
if not user_input.strip():
|
||||
batch_size = batch_size_default
|
||||
break
|
||||
|
||||
# 尝试将输入转换为整数
|
||||
try:
|
||||
batch_size = int(user_input)
|
||||
if batch_size > 0:
|
||||
break # 输入正确,退出循环
|
||||
else:
|
||||
print("Batch size 必须是正整数,请重新输入。")
|
||||
except ValueError:
|
||||
print("输入无效,请输入一个有效的整数。")
|
||||
print(f"将train_dataloader的batch_size设置为{batch_size}")
|
||||
|
||||
num_workers = None
|
||||
if num_workers_default != None:
|
||||
while True:
|
||||
user_input = input(f"请输入 num workers (默认为 {num_workers_default}): ")
|
||||
|
||||
# 如果用户没有输入内容,使用默认值
|
||||
if not user_input.strip():
|
||||
num_workers = num_workers_default
|
||||
break
|
||||
|
||||
# 尝试将输入转换为整数
|
||||
try:
|
||||
num_workers = int(user_input)
|
||||
if num_workers > 0:
|
||||
break # 输入正确,退出循环
|
||||
else:
|
||||
print("Num workers 必须是正整数,请重新输入。")
|
||||
except ValueError:
|
||||
print("输入无效,请输入一个有效的整数。")
|
||||
print(f"将train_dataloader的num_workers设置为{num_workers}")
|
||||
|
||||
# 返回包含 batch_size 的字典
|
||||
train_dataloader = {}
|
||||
if num_workers == None:
|
||||
train_dataloader['batch_size'] = batch_size
|
||||
return train_dataloader, batch_size
|
||||
|
||||
if batch_size == None:
|
||||
train_dataloader['num_workers'] = num_workers
|
||||
return train_dataloader, num_workers
|
||||
|
||||
train_dataloader['batch_size'] = batch_size
|
||||
train_dataloader['num_workers'] = num_workers
|
||||
return train_dataloader, batch_size, num_workers
|
||||
|
||||
|
||||
@@ -0,0 +1,264 @@
|
||||
import ast, subprocess, os
|
||||
|
||||
# 从文件中获取特定变量信息,返回变量直和类型
|
||||
def get_var_from_file(filename, var_name="norm_cfg"):
|
||||
with open(filename, 'r', encoding='utf-8') as file:
|
||||
code = file.read()
|
||||
|
||||
# 解析文件的 AST 树
|
||||
tree = ast.parse(code)
|
||||
|
||||
# 遍历 AST 树,查找指定变量
|
||||
for node in ast.walk(tree):
|
||||
if isinstance(node, ast.Assign):
|
||||
for target in node.targets:
|
||||
if isinstance(target, ast.Name) and target.id == var_name:
|
||||
# 将 AST 节点转换为 Python 对象
|
||||
var_value = ast.literal_eval(node.value)
|
||||
# 返回变量的值和类型
|
||||
return var_value, type(var_value)
|
||||
|
||||
# 如果没有找到指定的变量,返回 None 和 None
|
||||
return None, None
|
||||
|
||||
def get_var_from_py_file(file_path, var_name="auxiliary_head"):
|
||||
# 定义一个字典用于保存文件中的变量
|
||||
context = {}
|
||||
|
||||
# 读取并执行文件
|
||||
with open(file_path, 'r') as file:
|
||||
exec(file.read(), context)
|
||||
|
||||
# 获取 auxiliary_head 变量
|
||||
if var_name in context:
|
||||
return context[var_name]
|
||||
else:
|
||||
raise AttributeError(f"文件中没有找到 {var_name} 变量")
|
||||
|
||||
# 更新list的dict变量
|
||||
def update_list_dict_var(var, var_new):
|
||||
# 判断 var 和 var_new 的长度是否相等
|
||||
if len(var) != len(var_new):
|
||||
raise ValueError(f"var {len(var)} 和 var_new {len(var_new)} 的大小不相等,无法更新")
|
||||
|
||||
# 遍历 var_new,按索引更新 var
|
||||
for i, new_entry in enumerate(var_new):
|
||||
# 将 new_entry 中的键值覆盖 var[i] 中的相同键
|
||||
var[i].update(new_entry)
|
||||
return var
|
||||
|
||||
if __name__ == '__main__':
|
||||
# 示例调用
|
||||
var_value, var_type = get_var_from_file('./configs/_base_/models', 'norm_cfg')
|
||||
if var_value is not None:
|
||||
print(f"变量名: norm_cfg\n值: {var_value}\n类型: {var_type}")
|
||||
else:
|
||||
print("未找到指定变量")
|
||||
|
||||
# 将文件以dict格式输出
|
||||
def format_dict(d, indent_level=1):
|
||||
formatted_items = []
|
||||
indent = ' ' * indent_level # 根据缩进级别生成空格
|
||||
|
||||
for key, value in d.items():
|
||||
if isinstance(value, dict):
|
||||
# 如果值是字典,递归调用 format_dict_as_func 并增加缩进级别
|
||||
formatted_value = format_dict(value, indent_level + 1)
|
||||
formatted_items.append(f"{indent}{key}={formatted_value}")
|
||||
elif isinstance(value, str):
|
||||
# 如果是字符串,格式化时加引号
|
||||
formatted_items.append(f"{indent}{key}='{value}'")
|
||||
else:
|
||||
# 其他类型(如数值等)不加引号
|
||||
formatted_items.append(f"{indent}{key}={value}")
|
||||
|
||||
# 将所有键值对合成为多行的格式,并加上结尾逗号
|
||||
formatted_str = ",\n".join(formatted_items)
|
||||
|
||||
# 返回更美观的 dict() 的格式,保留缩进和换行
|
||||
return f"dict(\n{formatted_str},\n{' ' * (indent_level - 1)})"
|
||||
|
||||
# 所有格式正确输出
|
||||
def format_all_data_old(data, indent_level=0):
|
||||
indent = ' ' * indent_level # 根据缩进级别生成空格
|
||||
if isinstance(data, dict):
|
||||
formatted_items = []
|
||||
for key, value in data.items():
|
||||
formatted_value = format_all_data(value, indent_level + 1)
|
||||
formatted_items.append(f"{indent} {key}={formatted_value}")
|
||||
formatted_str = ",\n".join(formatted_items)
|
||||
return f"dict(\n{formatted_str},\n{indent})"
|
||||
|
||||
elif isinstance(data, list):
|
||||
formatted_items = [f"{indent} {format_all_data(item, indent_level + 1)}" for item in data]
|
||||
formatted_str = ",\n".join(formatted_items)
|
||||
return f"[\n{formatted_str},\n{indent}]"
|
||||
|
||||
elif isinstance(data, tuple):
|
||||
return f"{data}"
|
||||
|
||||
elif isinstance(data, str):
|
||||
# 如果字符串包含单引号,则使用双引号,否则使用单引号
|
||||
if "'" in data:
|
||||
return f'"{data}"'
|
||||
else:
|
||||
return f"'{data}'"
|
||||
|
||||
else:
|
||||
return str(data)
|
||||
|
||||
# 所有格式正确输出,加入键值中带有"."的处理
|
||||
def format_all_data(data, indent_level=0):
|
||||
indent = ' ' * indent_level # 根据缩进级别生成空格
|
||||
dot_key_items = [] # 用于存储带点号的键
|
||||
regular_items = [] # 用于存储普通键
|
||||
|
||||
if isinstance(data, dict):
|
||||
for key, value in data.items():
|
||||
formatted_value = format_all_data(value, indent_level + 1)
|
||||
# 如果键包含点号,收集到 dot_key_items 中
|
||||
if '.' in key:
|
||||
dot_key_items.append(f"('{key}', {formatted_value})")
|
||||
else:
|
||||
regular_items.append(f"{indent} {key}={formatted_value}")
|
||||
|
||||
# 生成带点号键的部分,并放在最前面
|
||||
if dot_key_items:
|
||||
dot_key_str = f"{indent} [{', '.join(dot_key_items)}]"
|
||||
else:
|
||||
dot_key_str = ""
|
||||
|
||||
# 生成普通键的部分
|
||||
regular_str = ",\n".join(regular_items)
|
||||
|
||||
# 返回最终的组合字符串,带点号键放在普通键前面
|
||||
if dot_key_str:
|
||||
return f"dict(\n{dot_key_str},\n{regular_str},\n{indent})"
|
||||
else:
|
||||
return f"dict(\n{regular_str},\n{indent})"
|
||||
|
||||
elif isinstance(data, list):
|
||||
formatted_items = [f"{indent} {format_all_data(item, indent_level + 1)}" for item in data]
|
||||
formatted_str = ",\n".join(formatted_items)
|
||||
return f"[\n{formatted_str},\n{indent}]"
|
||||
|
||||
elif isinstance(data, tuple):
|
||||
return f"{data}"
|
||||
|
||||
elif isinstance(data, str):
|
||||
# 如果字符串包含单引号,则使用双引号,否则使用单引号
|
||||
if "'" in data:
|
||||
return f'"{data}"'
|
||||
else:
|
||||
return f"'{data}'"
|
||||
|
||||
else:
|
||||
return str(data)
|
||||
|
||||
# # 批量将参数内容写入文件 # V1 传统版
|
||||
# def write_config_to_file(output_file, **kwargs):
|
||||
# """
|
||||
# 将传入的任意数量的参数写入指定文件,并格式化输出。
|
||||
|
||||
# :param output_file: 要写入的文件路径
|
||||
# :param kwargs: 任意数量的配置项,格式为 key=value
|
||||
# """
|
||||
# with open(output_file, 'w', encoding='utf-8') as file:
|
||||
# # 遍历 kwargs,将每个 key, value 写入文件
|
||||
# for key, value in kwargs.items():
|
||||
# file.write(f"{key} = {format_all_data(value)}\n\n")
|
||||
|
||||
# # 打印成功信息
|
||||
# print(f"\033[93m{output_file} file generated successfully\033[0m")
|
||||
|
||||
# 批量将参数内容写入文件 # V2 加入训练过程可视化 TODO
|
||||
def format_all_data(value):
|
||||
"""
|
||||
一个辅助函数,用于将 Python 对象格式化为适合写入配置文件的字符串。
|
||||
使用 repr() 可以确保字符串、列表、字典等都保持其 Python 语法格式。
|
||||
"""
|
||||
return repr(value)
|
||||
def write_config_to_file(output_file, **kwargs):
|
||||
"""
|
||||
将传入的任意数量的参数以及一个自动生成的 visualizer 配置写入指定文件。
|
||||
|
||||
:param output_file: 要写入的文件路径 (e.g., 'configs/exp1.py')
|
||||
:param kwargs: 任意数量的配置项,格式为 key=value
|
||||
"""
|
||||
# --- 1. 从 output_file 路径中提取实验名称 ---
|
||||
# 首先获取基本文件名 (e.g., 'exp1.py')
|
||||
base_name = os.path.basename(output_file)
|
||||
# 然后去掉文件扩展名,得到纯净的实验名 (e.g., 'exp1')
|
||||
experiment_name, _ = os.path.splitext(base_name)
|
||||
|
||||
# --- 2. 构建 visualizer 配置字典 ---
|
||||
vis_backends = [
|
||||
dict(type='LocalVisBackend'),
|
||||
dict(type='TensorboardVisBackend'),
|
||||
dict(
|
||||
type='WandbVisBackend',
|
||||
init_kwargs=dict(
|
||||
project='Seg_MMSeg_Test', # 你的 wandb 项目名称
|
||||
name=experiment_name # 使用上面提取的文件名作为实验名
|
||||
)
|
||||
)
|
||||
]
|
||||
visualizer = dict(
|
||||
name='visualizer',
|
||||
type='SegLocalVisualizer',
|
||||
vis_backends=vis_backends
|
||||
)
|
||||
|
||||
# --- 3. 将自动生成的 visualizer 添加到要写入的内容中 ---
|
||||
# 如果 kwargs 中已经有 'visualizer',它将会被新的配置覆盖
|
||||
kwargs['vis_backends'] = vis_backends
|
||||
kwargs['visualizer'] = visualizer
|
||||
|
||||
# --- 4. 将所有配置项写入文件 ---
|
||||
with open(output_file, 'w', encoding='utf-8') as file:
|
||||
# 遍历所有配置项(包括新加入的 visualizer),写入文件
|
||||
for key, value in kwargs.items():
|
||||
# 使用 format_all_data 保证格式正确
|
||||
file.write(f"{key} = {format_all_data(value)}\n\n")
|
||||
|
||||
# 打印成功信息
|
||||
print(f"\033[93mConfiguration saved to '{output_file}' successfully.\033[0m")
|
||||
print(f"\033[96mWandB experiment name will be: '{experiment_name}'\033[0m")
|
||||
|
||||
# 获取系统中所有 GPU 的可用显存信息。
|
||||
def get_gpu_info():
|
||||
"""
|
||||
获取系统中所有 GPU 的可用显存信息。
|
||||
:return: GPU 显存信息的列表,列表中的每一项是一个 (GPU编号, 剩余显存) 元组
|
||||
"""
|
||||
try:
|
||||
# 使用 nvidia-smi 命令获取 GPU 显存信息
|
||||
result = subprocess.run(
|
||||
['nvidia-smi', '--query-gpu=index,memory.free', '--format=csv,noheader,nounits'],
|
||||
stdout=subprocess.PIPE,
|
||||
encoding='utf-8'
|
||||
)
|
||||
|
||||
# 解析结果
|
||||
gpu_info = []
|
||||
lines = result.stdout.strip().split('\n')
|
||||
for line in lines:
|
||||
index, memory_free = line.split(',')
|
||||
gpu_info.append((int(index.strip()), int(memory_free.strip())))
|
||||
|
||||
return gpu_info
|
||||
except FileNotFoundError:
|
||||
print("\033[91mError: nvidia-smi 命令未找到,请确保 NVIDIA 驱动正确安装。\033[0m")
|
||||
return []
|
||||
|
||||
# 批量生成dict
|
||||
def create_dict_by_kwargs(**kwargs):
|
||||
# 批量生成字典,kwargs 会自动收集所有传入的命名参数
|
||||
return {key: value for key, value in kwargs.items() if value is not None}
|
||||
|
||||
if __name__ == '__main__':
|
||||
# 示例字典
|
||||
my_dict = {'pretrained': './My_Local_Model/open_mmlab/resnet50_v1c.pth'}
|
||||
|
||||
# 打印成 dict() 的格式
|
||||
print(format_dict(my_dict))
|
||||
@@ -0,0 +1,236 @@
|
||||
# 选择crop_size大小
|
||||
def select_crop_size(predefined_options=[(512, 512), (512, 1024), (769, 769)]):
|
||||
"""
|
||||
让用户选择裁剪大小。默认选择 (512, 512),用户可以选择其他预定义的裁剪大小,
|
||||
或者输入自定义的大小,且自定义的数字不能为负数。
|
||||
|
||||
:return: 选择的裁剪大小 (width, height)
|
||||
"""
|
||||
# 预定义的裁剪大小选项
|
||||
predefined_options = {str(i+1): option for i, option in enumerate(predefined_options)}
|
||||
|
||||
# 显示可选的裁剪大小
|
||||
print("可用的裁剪大小选项:")
|
||||
for key, value in predefined_options.items():
|
||||
print(f"{key}. {value}", end=" ")
|
||||
print(f"{len(predefined_options)+1}. 自定义大小")
|
||||
|
||||
# 用户选择
|
||||
choice = input("请选择裁剪大小选项 (默认 1): ").strip()
|
||||
|
||||
# 如果用户没有输入,使用默认值
|
||||
if choice == "" or choice == "1":
|
||||
return predefined_options["1"]
|
||||
|
||||
# 如果用户选择了预定义的选项
|
||||
if choice in predefined_options:
|
||||
return predefined_options[choice]
|
||||
|
||||
# 如果用户选择自定义大小
|
||||
if choice == f"{len(predefined_options)+1}":
|
||||
while True:
|
||||
try:
|
||||
width = int(input("请输入裁剪宽度 (正整数): "))
|
||||
height = int(input("请输入裁剪高度 (正整数): "))
|
||||
if width > 0 and height > 0:
|
||||
return (width, height)
|
||||
else:
|
||||
print("宽度和高度必须是正整数。")
|
||||
except ValueError:
|
||||
print("输入无效,请输入有效的正整数。")
|
||||
|
||||
# 如果用户输入无效,返回默认值
|
||||
print("无效选择,使用默认值 (512, 512)")
|
||||
return predefined_options["1"]
|
||||
|
||||
# 大字典,包含所有模型的信息
|
||||
pretrained_models_dict = {
|
||||
"openmmlab/resnet18_v1c": {'pth': './My_Local_Model/open_mmlab/resnet18_v1c.pth', 'depth': 18, 'type':'ResNetV1c'},
|
||||
"openmmlab/resnet50_v1c": {'pth': './My_Local_Model/open_mmlab/resnet50_v1c.pth', 'depth': 50, 'type':'ResNetV1c'},
|
||||
"openmmlab/resnet101_v1c": {'pth': './My_Local_Model/open_mmlab/resnet101_v1c.pth', 'depth': 101, 'type':'ResNetV1c'},
|
||||
"torchvision://resnet18": {'pth': './My_Local_Model/torchvision_012/resnet18.pth', 'depth': 18, 'type':'ResNet'},
|
||||
"torchvision://resnet50": {'pth': './My_Local_Model/torchvision_012/resnet50.pth', 'depth': 50, 'type':'ResNet'},
|
||||
"torchvision://resnet101": {'pth': './My_Local_Model/torchvision_012/resnet101.pth', 'depth': 101, 'type':'ResNet'},
|
||||
"openmmlab/pidnet-s":{'pth': './My_Local_Model/open_mmlab/pidnet-s.pth', 'type':'pidnet', 'size':'small'},
|
||||
"openmmlab/pidnet-m":{'pth': './My_Local_Model/open_mmlab/pidnet-m.pth', 'type':'pidnet', 'size':'medium'},
|
||||
"openmmlab/pidnet-l":{'pth': './My_Local_Model/open_mmlab/pidnet-l.pth', 'type':'pidnet', 'size':'large'},
|
||||
"openmmlab/ddrnet23-s":{'pth': './My_Local_Model/open_mmlab/ddrnet23-s.pth', 'type':'ddrnet', 'size':'small'},
|
||||
"openmmlab/ddrnet23":{'pth': './My_Local_Model/open_mmlab/ddrnet23.pth', 'type':'ddrnet', 'size':'normal'},
|
||||
"openmmlab/stdc1":{'pth': './My_Local_Model/open_mmlab/stdc1.pth', 'type':'stdc', 'size':'V1'},
|
||||
"openmmlab/stdc2":{'pth': './My_Local_Model/open_mmlab/stdc2.pth', 'type':'stdc', 'size':'V2'},
|
||||
"pretrain/vit-b16_p16_224-80ecf9dd.pth":{'pth': './My_Local_Model/pretrain/vit-b16_p16_224-80ecf9dd.pth'},
|
||||
"pretrain/beit_base_patch16_224_pt22k_ft22k.pth":{'pth': './My_Local_Model/pretrain/beit_base_patch16_224_pt22k_ft22k.pth'},
|
||||
"pretrain/beit_large_patch16_224_pt22k_ft22k.pth":{'pth': './My_Local_Model/pretrain/beit_large_patch16_224_pt22k_ft22k.pth'},
|
||||
"pretrain/swin_large-d5bdebaf.pth":{'pth': './My_Local_Model/pretrain/swin_large_patch4_window7_224_22k_20220308-d5bdebaf.pth', 'type':'swin', 'size':'large'},
|
||||
"pretrain/swin_tiny-f41b89d3.pth":{'pth': './My_Local_Model/pretrain/swin_tiny_patch4_window7_224_20220308-f41b89d3.pth', 'type':'swin', 'size':'tiny'},
|
||||
"mae_pretrain_vit_base_mmcls.pth":{'pth': './My_Local_Model/pretrain/mae_pretrain_vit_base_mmcls.pth'},
|
||||
"open-mmlab://msra/hrnetv2_w18":{'pth': './My_Local_Model/open_mmlab/msra/hrnetv2_w18.pth'},
|
||||
"open-mmlab://msra/hrnetv2_w18_small":{'pth': './My_Local_Model/open_mmlab/msra/hrnetv2_w18_small.pth'},
|
||||
'open-mmlab://msra/hrnetv2_w48':{'pth': './My_Local_Model/open_mmlab/msra/hrnetv2_w48.pth'},
|
||||
"pretrain/swin_large-6580f57d.pth":{'pth': './My_Local_Model/pretrain/swin_large_patch4_window12_384_22k_20220412-6580f57d.pth', 'type':'swin', 'size':'large'}, # mask2former
|
||||
"pretrain/swin_base-e5c09f74.pth":{'pth': './My_Local_Model/pretrain/swin_base_patch4_window12_384_22k_20220317-e5c09f74.pth', 'type':'swin', 'size':'base'}, # mask2former
|
||||
"pretrain/swin_small-7ba6d6dd.pth":{'pth': './My_Local_Model/pretrain/swin_small_patch4_window7_224_20220317-7ba6d6dd.pth', 'type':'swin', 'size':'small'}, # mask2former
|
||||
"pretrain/swin_tiny-1cdeb081.pth":{'pth': './My_Local_Model/pretrain/swin_tiny_patch4_window7_224_20220317-1cdeb081.pth', 'type':'swin', 'size':'tiny'}, # mask2former
|
||||
# 这里可以包含更多模型信息...
|
||||
}
|
||||
|
||||
# 选择预训练模型
|
||||
def select_pretrained_model(model_list, need_select_pretrained = False, pretrained_models_dict = pretrained_models_dict):
|
||||
"""
|
||||
让用户从给定的模型列表中选择预训练模型、是否选择预训练(否-默认开启预训练),并返回对应的 pth 路径和 其他 信息。
|
||||
|
||||
:param model_list: 可用的模型名称列表,例如 ['openmmlab/resnet50_v1c', 'openmmlab/resnet101_v1c']
|
||||
:return: (pretrained_pth, depth)
|
||||
"""
|
||||
|
||||
# 过滤传入的模型列表,确保它们在字典中有信息
|
||||
valid_models = {key: pretrained_models_dict[key] for key in model_list if key in pretrained_models_dict}
|
||||
|
||||
if not valid_models:
|
||||
print("错误:提供的模型列表中没有有效的模型信息。")
|
||||
return None, None
|
||||
|
||||
# 显示可用的预训练模型
|
||||
print("可用的预训练模型类型:")
|
||||
for i, (model_name, model_info) in enumerate(valid_models.items(), 1):
|
||||
print(f"{i}. {model_name}")
|
||||
|
||||
# 用户选择
|
||||
choice = input(f"请选择预训练模型编号 (1-{len(valid_models)}, 默认 1): ").strip()
|
||||
|
||||
# 如果用户没有输入,或输入无效,使用默认值
|
||||
if not choice.isdigit() or not (1 <= int(choice) <= len(valid_models)):
|
||||
choice = "1"
|
||||
|
||||
# 获取用户选择的模型信息
|
||||
selected_model_name = list(valid_models.keys())[int(choice) - 1]
|
||||
selected_model_info = valid_models[selected_model_name]
|
||||
|
||||
# TODO 选择特定信息
|
||||
# # 返回模型的 pth 路径和 depth
|
||||
pretrained_pth = selected_model_info['pth']
|
||||
# depth = selected_model_info['depth']
|
||||
|
||||
# print(f" 已选择模型: {selected_model_name} depth: {depth} pth: {pretrained_pth}")
|
||||
print(f" 已选择模型: {selected_model_name} pth: {pretrained_pth}")
|
||||
|
||||
if need_select_pretrained == False:
|
||||
print("默认开启预训练")
|
||||
select_pretrained = True
|
||||
return selected_model_name, select_pretrained, pretrained_pth, selected_model_info
|
||||
else:
|
||||
# 提示用户输入 True 或 False,或者直接按 Enter 默认使用预训练模型
|
||||
choice = input("是否使用预训练模型?输入 Y(使用)或 N(不使用),直接按 Enter 默认使用:True:")
|
||||
while True:
|
||||
# 如果用户没有输入,默认使用预训练模型
|
||||
if choice == '':
|
||||
select_pretrained = True
|
||||
break
|
||||
# 转换输入为布尔值
|
||||
elif choice.lower() == 'y':
|
||||
select_pretrained = True
|
||||
break
|
||||
elif choice.lower() == 'n':
|
||||
select_pretrained = False
|
||||
break
|
||||
else:
|
||||
print("无效输入,请输入 'True' 或 'False',或直接按 Enter 选择默认值")
|
||||
|
||||
return selected_model_name, select_pretrained, pretrained_pth, selected_model_info
|
||||
|
||||
# 大字典,包含所有模型的信息
|
||||
samplers_dict = {
|
||||
"OHEMPixelSampler": dict(type='OHEMPixelSampler', thresh=0.7, min_kept=10000),
|
||||
# 这里可以包含更多采样函数模型信息...
|
||||
}
|
||||
|
||||
# 选择采样函数
|
||||
def select_sampler(sampler_list, samplers_dict = samplers_dict):
|
||||
|
||||
# 提示用户输入 True 或 False,或者直接按 Enter 默认使用预训练模型
|
||||
choice = input("是否使用采样函数?输入 Y(使用)或 N(不使用),直接按 Enter 默认不使用:False:")
|
||||
while True:
|
||||
# 如果用户没有输入,默认使用预训练模型
|
||||
if choice == '':
|
||||
use_sampler = False
|
||||
return None, use_sampler, None
|
||||
# 转换输入为布尔值
|
||||
elif choice.lower() == 'y':
|
||||
use_sampler = True
|
||||
break
|
||||
elif choice.lower() == 'n':
|
||||
use_sampler = False
|
||||
return None, use_sampler, None
|
||||
else:
|
||||
print("无效输入,请输入 'True' 或 'False',或直接按 Enter 选择默认值")
|
||||
|
||||
# 过滤传入的模型列表,确保它们在字典中有信息
|
||||
valid_samplers = {key: samplers_dict[key] for key in sampler_list if key in samplers_dict}
|
||||
|
||||
if not valid_samplers:
|
||||
print("错误:提供的采样函数列表中没有有效的采样函数信息。")
|
||||
return None, None
|
||||
|
||||
# 如果只有一个可用的采样函数,直接选择
|
||||
if len(valid_samplers) != 1:
|
||||
# 显示可用的采样函数
|
||||
print("可用的采样函数:")
|
||||
for i, (sampler_name, sampler_info) in enumerate(valid_samplers.items(), 1):
|
||||
print(f"{i}. {sampler_name}")
|
||||
|
||||
# 用户选择
|
||||
choice = input(f"请选择采样函数编号 (1-{len(valid_samplers)}, 默认 1): ").strip()
|
||||
|
||||
# 如果用户没有输入,或输入无效,使用默认值
|
||||
if not choice.isdigit() or not (1 <= int(choice) <= len(valid_samplers)):
|
||||
choice = "1"
|
||||
else:
|
||||
choice = "1"
|
||||
|
||||
# 获取用户选择的模型信息
|
||||
selected_sampler_name = list(valid_samplers.keys())[int(choice) - 1]
|
||||
selected_sampler_info = valid_samplers[selected_sampler_name]
|
||||
|
||||
print(f" 已选择采样函数: {selected_sampler_name}")
|
||||
|
||||
return selected_sampler_name, use_sampler, selected_sampler_info
|
||||
|
||||
# 选择test_cfg中是否滑动,是否默认选择select_slide
|
||||
def select_test_cfg_slide(crop_size, select_slide=False):
|
||||
"""
|
||||
让用户选择是否使用滑动窗口,并根据选择设置相应的模式和参数。
|
||||
|
||||
:param crop_size: 输入的裁剪大小 (宽, 高)
|
||||
:return: test_cfg_mode, test_cfg_crop_div_stride, crop_size
|
||||
"""
|
||||
if select_slide == False:
|
||||
# 提示用户是否选择滑动窗口模式
|
||||
use_slide = input("是否选择滑动窗口模式?(y/n, 默认 n): ").strip().lower()
|
||||
else:
|
||||
use_slide = 'y' # 使用滑动窗口模式
|
||||
|
||||
# 默认不使用滑动窗口模式
|
||||
if use_slide == 'y':
|
||||
test_cfg_mode = 'slide'
|
||||
# 提示用户输入 test_cfg_crop_div_stride,默认值为 1.5
|
||||
try:
|
||||
test_cfg_crop_div_stride = input("请输入滑动窗口的 stride 除以 crop_size 比例 (默认 1.5): ").strip()
|
||||
test_cfg_crop_div_stride = float(test_cfg_crop_div_stride) if test_cfg_crop_div_stride else 1.5
|
||||
except ValueError:
|
||||
print("输入无效,使用默认比例 1.5")
|
||||
test_cfg_crop_div_stride = 1.5
|
||||
|
||||
# 计算 stride
|
||||
stride = tuple(int(c / test_cfg_crop_div_stride) for c in crop_size)
|
||||
print(f" 已选择滑动窗口模式: {test_cfg_mode}")
|
||||
print(f"crop_size: {crop_size}")
|
||||
print(f"stride: {stride}")
|
||||
else:
|
||||
test_cfg_mode = None # 默认模式
|
||||
test_cfg_crop_div_stride = None
|
||||
stride = None
|
||||
print(f" 关闭滑动窗口模式")
|
||||
|
||||
return test_cfg_mode, test_cfg_crop_div_stride
|
||||
|
||||
|
||||
@@ -0,0 +1,46 @@
|
||||
from PIL import Image
|
||||
import os
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
|
||||
def calculate_pic_std_and_mean(dataset_dir = r'./My_Data/A_Ori'):
|
||||
# 获取所有jpg图像文件
|
||||
image_files = [os.path.join(dataset_dir, filename) for filename in os.listdir(dataset_dir) if filename.lower().endswith(('.jpg', '.png', '.tiff', '.jpeg', '.bmp'))]
|
||||
|
||||
# 初始化用于存储累积的像素值
|
||||
sum_pixels_normalized = np.zeros(3)
|
||||
sum_squared_pixels_normalized = np.zeros(3)
|
||||
num_pixels = 0
|
||||
|
||||
# 使用tqdm创建一个进度条
|
||||
for image_file in tqdm(image_files, desc="Calculating mean and std"):
|
||||
image = Image.open(image_file).convert('RGB') # 确保图像为RGB
|
||||
image = np.array(image) # 原图像像素范围[0, 255]
|
||||
|
||||
# 归一化到[0, 1]范围
|
||||
image_normalized = image / 255.0
|
||||
|
||||
# 累积归一化像素值和归一化像素平方值
|
||||
sum_pixels_normalized += np.sum(image_normalized, axis=(0, 1)) # 按通道累积
|
||||
sum_squared_pixels_normalized += np.sum(image_normalized ** 2, axis=(0, 1)) # 按通道累积像素平方值
|
||||
num_pixels += image.shape[0] * image.shape[1] # 累积总像素数
|
||||
|
||||
# 计算整个数据集的归一化后的均值
|
||||
mean_normalized = sum_pixels_normalized / num_pixels
|
||||
|
||||
# 计算整个数据集的归一化后的标准差
|
||||
variance_normalized = sum_squared_pixels_normalized / num_pixels - mean_normalized ** 2
|
||||
variance_normalized = np.maximum(variance_normalized, 0) # 防止负数
|
||||
std_normalized = np.sqrt(variance_normalized)
|
||||
|
||||
# 反归一化回[0, 255]范围
|
||||
mean = mean_normalized * 255
|
||||
std = std_normalized * 255
|
||||
|
||||
print(f"\033[93m计算得图片均值-Mean: {mean} 计算得图片方差-Std: {std}\033[0m")
|
||||
|
||||
return mean, std
|
||||
|
||||
if __name__ == '__main__':
|
||||
# 示例调用
|
||||
calculate_pic_std_and_mean()
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user