first commit

This commit is contained in:
admin
2026-05-20 15:05:35 +08:00
commit ac09b26253
2048 changed files with 189478 additions and 0 deletions

92
.gitignore vendored Normal file
View File

@@ -0,0 +1,92 @@
# Python caches
__pycache__/
*.py[cod]
*$py.class
.pytest_cache/
.mypy_cache/
.ruff_cache/
.ipynb_checkpoints/
# Conda / virtual environments
.conda/
.venv/
venv/
env/
# Build artifacts
build/
dist/
*.egg-info/
# Logs and temporary files
*.log
.goutputstream-*
*~
*.tmp
*.bak
*.swp
logs_parallel_*/
predict_logs_parallel_*/
yolo_train_logs_parallel_*/
yolo_predict_logs_parallel_*/
wandb/
mlruns/
.cache/
tmp/
temp/
# Large datasets, predictions, training outputs, and backups
DataSet_Public/
DataSet_Public_outputs/
BestMode_Predict_Results_DataSet_Public/
Hardisk/
Nas_BackUp_Seg/
# Keep DataSet_Own preprocessing code/manuals, ignore actual image data
DataSet_Own/*
!DataSet_Own/1. 图片预处理(内含使用手册)/
!DataSet_Own/1. 图片预处理(内含使用手册)/**
DataSet_Own/1. 图片预处理(内含使用手册)/error*.txt
# Local model caches and heavy pretrained weights
Seg_All_In_One_MMSeg/My_Local_Model/
Seg_All_In_One_MMSeg/work_dirs/
Seg_All_In_One_MMSeg/flops_results/
Seg_All_In_One_MMSeg/tests/data/
# Generated analysis figures; keep CSV/SVG summaries
Seg_All_In_One_Analysis/*/*.png
# Demo/output media and generated visual data
Seg_Predict_Own_Video_V2/*.mp4
Seg_Predict_YoloModel/*.mp4
Seg_Predict_YoloModel/output_*/
Seg_Predict_YoloModel/YOLO*/
Seg_All_In_One_YoloModel/Yolo数据集构建/Data/
Seg_All_In_One_YoloModel/Yolo数据集构建/Label/
Seg_All_In_One_YoloModel/Yolo数据集构建/ORI/
Seg_All_In_One_YoloModel/Yolo数据集构建/ORI_GT_label_fold/
Seg_All_In_One_YoloModel/Yolo数据集构建/ORI_pro_label_fold/
Tool-图片堆叠/ori/
Tool-图片堆叠/label/
Tool-图片堆叠/result_*/
Tool-可视化/Data/
Tool-可视化/runs/
Tool-可视化/0_图片Labels生成/save_*_label_fold/
Tool-可视化/0_图片Labels生成/*.png
# Large model/video/archive formats
*.pt
*.pth
*.onnx
*.engine
*.mp4
*.avi
*.mov
*.mkv
*.zip
*.tar
*.tar.gz
*.tgz
*.7z
*.rar

104
Back_Up.sh Normal file
View File

@@ -0,0 +1,104 @@
#!/usr/bin/env bash
# Back_Up.sh — interactive sync script (v3).
# Every path is derived from the directory that contains this script. The
# menu offers two backup modes:
#   1. refresh the algorithm directories into Hardisk, then mirror those
#      same directories from Hardisk to the NAS directory;
#   2. mirror the whole Hardisk directory to the NAS directory.
# Both modes use `rsync --delete`, i.e. files missing from the source are
# removed from the destination.
set -euo pipefail

# --- Root directory ---
# Absolute path of the directory holding this script.
SCRIPT_DIR=$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd)

# --- Configuration ---
# Algorithm source directories (relative to the script location).
SRC_DIRS=(
  "$SCRIPT_DIR/Seg_All_In_One_YoloModel"
  "$SCRIPT_DIR/Seg_All_In_One_SegModel"
  "$SCRIPT_DIR/Seg_All_In_One_MMSeg"
  "$SCRIPT_DIR/Seg_All_In_One_Analysis"
)
# Local mirror / staging directory.
LOCAL_DST_ROOT="$SCRIPT_DIR/Hardisk"
# NAS backup target directory.
NAS_DST_ROOT="$SCRIPT_DIR/Nas_BackUp_Seg"

# Fail early with a clear message instead of half-way through a sync.
if ! command -v rsync >/dev/null 2>&1; then
  echo "错误: 未找到 rsync 命令,无法继续。" >&2
  exit 1
fi

#######################################
# Mirror the contents of one directory into another with rsync --delete.
# Arguments:
#   $1 - source directory (must exist)
#   $2 - destination directory (created when missing)
#   $3 - verb used in the progress line ("同步" or "备份")
#######################################
mirror_dir() {
  local src=$1 dst=$2 verb=$3 dst_real
  mkdir -p -- "$dst"
  dst_real=$(realpath -- "$dst")
  echo " >>> 正在${verb} $src -> $dst_real"
  rsync -avh --delete -- "$src/" "$dst_real/"
}

# --- Ask the user what to do ---
echo "--- 请选择要执行的同步操作 ---"
echo " 1. [更新并备份算法文件] 从源头更新算法文件到Hardisk并立即备份到NAS"
echo " 2. [备份Hardisk] 将整个Hardisk目录的当前内容完全拷贝到NAS【--delete】"
echo " 3. [退出] 不执行任何操作"
echo "------------------------------------------------"
# -r keeps backslashes literal; on EOF fall through to the "invalid" branch
# instead of letting `set -e` abort without any message.
read -r -p "请输入选项 [1, 2, 或 3]: " choice || choice=""

case "$choice" in
  1)
    echo "--- 您选择了 [1]: 更新并备份算法文件 (源->Hardisk->NAS) ---"
    # --- Step 1: refresh Hardisk from the source directories ---
    echo ""
    echo "--> (1/2) 正在从源目录更新文件到 $LOCAL_DST_ROOT..."
    for src_path in "${SRC_DIRS[@]}"; do
      if [[ ! -d "$src_path" ]]; then
        echo " 警告: 源目录 '$src_path' 不存在,已跳过。"
        continue
      fi
      mirror_dir "$src_path" "$LOCAL_DST_ROOT/$(basename -- "$src_path")" "同步"
    done
    echo "--> (1/2) 本地 Hardisk 更新完成。"
    # --- Step 2: back the refreshed directories up from Hardisk to NAS ---
    echo ""
    echo "--> (2/2) 正在将更新后的算法文件从 Hardisk 备份到 $NAS_DST_ROOT..."
    for dir_full_path in "${SRC_DIRS[@]}"; do
      dir_name=$(basename -- "$dir_full_path")
      src_from_hardisk="$LOCAL_DST_ROOT/$dir_name"
      if [[ ! -d "$src_from_hardisk" ]]; then
        echo " 警告: 源目录 '$src_from_hardisk' 在 Hardisk 中不存在,已跳过备份。"
        continue
      fi
      mirror_dir "$src_from_hardisk" "$NAS_DST_ROOT/$dir_name" "备份"
    done
    echo "--> (2/2) 指定的算法文件已成功备份到 NAS"
    ;;
  2)
    echo "--- 您选择了 [2]: 仅备份Hardisk (Hardisk->NAS) ---"
    src="$LOCAL_DST_ROOT/"
    dst="$NAS_DST_ROOT/"
    if [[ ! -d "$src" ]]; then
      echo "错误: 源目录 '$src' 不存在,无法继续。"
      exit 1
    fi
    mkdir -p -- "$dst"
    echo ">>> 正在将 $src 的全部内容备份到 $dst"
    # Incremental copy: rsync only transfers what changed.
    rsync -avh --delete -- "$src" "$dst"
    echo ">>> Hardisk 目录已完全备份到 NAS"
    ;;
  3)
    echo "--- 您选择了 [3]: 退出 ---"
    echo "操作已取消。"
    # Nothing was done, so do not fall through to the "all tasks done" line.
    exit 0
    ;;
  *)
    echo "无效选项 '$choice'。请输入 1, 2, 或 3。"
    exit 1
    ;;
esac
echo ""
echo ">>> 全部任务完成!"

56
Check_Graph_Card.sh Normal file
View File

@@ -0,0 +1,56 @@
#!/usr/bin/env bash
# Show how many GPUs the system sees and whether the driver reports errors.

printf '%s\n' "======== PCIe 在位 =========="
lspci -d 10de: | grep -i vga
printf '\n'

printf '%s\n' "======== 驱动认到几卡 ========"
nvidia-smi --list-gpus | wc -l
printf '\n'

# Example of the kernel messages this check is looking for:
#   nvidia: probe of 0000:04:06.0 failed with error -1
#   NVRM: The NVIDIA probe routine failed for 1 device(s).
printf '%s\n' "======== dmesg 关键报错 ======"
dmesg | grep -iE 'nvidia.*fail|nvidia.*error|Xid.*79|GSP.*timeout' | tail -10

# Recorded state at the time of writing:
# The 7 cards recognized in Ubuntu
#   card 0  00000000:04:00.0
#   card 1  00000000:04:02.0
#   card 2  00000000:04:04.0
#   bad card (first incident)   card 3  00000000:04:06.0
#   card 4  00000000:05:00.0
#   bad card (second incident)  card 5  00000000:05:02.0
#   card 6  00000000:05:04.0
#   card 7  00000000:05:06.0
# The 7 cards recognized by the ESXi server
#   card 0  0000:16:00.0
#   card 1  0000:38:00.0
#   card 2  0000:49:00.0
#   bad slot (first incident)   card 3  0000:5a:00.0
#   card 4  0000:98:00.0
#   bad slot (second incident)  card 5  0000:b8:00.0
#   card 6  0000:c8:00.0
#   card 7  0000:d8:00.0
#
# Fixes that were tried
# V1.
#   1. Turn off the graphical session
#        sudo systemctl set-default multi-user.target
#        sudo systemctl set-default multi-user.target
#   2. Disable GSP right away and reload the driver
#        sudo modprobe -r nvidia_drm nvidia_modeset nvidia nvidia_uvm
#        sudo modprobe nvidia NVreg_EnableGpuFirmware=0
#        sudo nvidia-smi
#   3. If all 8 cards show up -> it is a GSP problem; make it permanent:
#        echo "options nvidia NVreg_EnableGpuFirmware=0" | sudo tee /etc/modprobe.d/nvidia-disable-gsp.conf
#        sudo update-initramfs -u
# V2.
#   1. Check whether BAR space is exhausted
#        sudo dmesg | grep -i "BAR 0\|resource 0" | grep 04:06.0 # TODO replace with the address of the affected card
#        [ 4.747643] pci 0000:04:06.0: BAR 0 [mem 0xea000000-0xeaffffff]
#      The kernel assigned BAR 0 (16 MiB) to 04:06.0 without any
#      "can't allocate" or "failed" message, so a BAR-space shortage /
#      Above 4G Decoding problem can be ruled out.
# V3.
#   Full cold reset
#   Host machine or cloud console -> power off for 10 seconds, then power on
#

View File

@@ -0,0 +1,308 @@
#!/bin/bash
# Print the three-line command-line help of the rename script to stdout.
usage() {
  printf '%s\n' \
    "Usage: $0 -i <ori_image_directory> -l <ori_label_directory> [-h]" \
    "对image图片和label图片进行处理" \
    "-i:原始图片的路径,-l:原始标签的路径,-h:帮助"
}
# Raw directory arguments exactly as typed by the user.
ori_image_directorys=""
ori_label_directorys=""

# Command line: -i <image dir>, -l <label dir>, -h help.
while getopts "hl:i:" opt; do
  case "$opt" in
    h)
      usage
      exit 0
      ;;
    i)
      ori_image_directorys=$OPTARG
      ;;
    l)
      ori_label_directorys=$OPTARG
      ;;
    *)
      echo -e '\033[31m!!! Error, Illegal input !!!\033[0m'
      usage
      exit 1
      ;;
  esac
done

# At least one of -i / -l has to be given.
if [[ -z "$ori_label_directorys" && -z "$ori_image_directorys" ]]; then
  echo -e "\033[31m输入地址 -i -l 都为空\033[0m"
  usage
  exit 1
fi

# Resolve the arguments to absolute paths. An argument that was not supplied
# stays empty instead of being handed to readlink.
ori_image_directory=""
ori_label_directory=""
if [[ -n "$ori_image_directorys" ]]; then
  ori_image_directory=$(readlink -f -- "$ori_image_directorys") || ori_image_directory=""
fi
if [[ -n "$ori_label_directorys" ]]; then
  ori_label_directory=$(readlink -f -- "$ori_label_directorys") || ori_label_directory=""
fi

# The script can work on a single directory, so it only gives up when
# neither path is usable. The messages show what the user typed, because the
# resolved values are empty in exactly this situation.
if [[ -z "$ori_label_directory" && -z "$ori_image_directory" ]]; then
  echo "无法解析地址,程序退出"
  echo -e "\033[31mori_image_directory\033[0m: $ori_image_directorys"
  echo -e "\033[31mori_label_directory\033[0m: $ori_label_directorys"
  exit 1
fi
if [[ ! -d "$ori_label_directory" && ! -d "$ori_image_directory" ]]; then
  echo "image、label两目录都不存在程序退出"
  echo -e "\033[31mori_image_directory\033[0m: $ori_image_directorys"
  echo -e "\033[31mori_label_directory\033[0m: $ori_label_directorys"
  exit 1
fi

# Work from the directory that contains this script.
script_path=$(dirname -- "$0")
cd -- "$script_path" || {
  echo "无法进入脚本目录: $script_path" >&2
  exit 1
}
echo -e "\033[32m______ 1_rename_data.sh _____\033[0m"
# ---------------------------------------------------------------------------
# Interactive rename menu.
#
# Each menu entry is applied to the image directory and then to the label
# directory (whichever of the two exist). Renames are performed directly
# with `mv`; user input is treated as literal text, never as a regular
# expression and never as shell code. An existing file is never overwritten.
# ---------------------------------------------------------------------------

# Succeed when file name $1 ends in one of the handled image extensions.
is_image_name() {
  [[ $1 =~ \.(png|jpg|PNG|JPG|BMP|bmp)$ ]]
}

# Rename entry $2 of directory $1 to $3, printing the mv command first.
# Nothing happens when the name would not change or would be empty; an
# existing target is reported and left alone.
rename_entry() {
  local dir=$1 old=$2 new=$3
  if [[ -z "$new" || "$new" == "$old" ]]; then
    return 0
  fi
  if [[ -e "$dir/$new" ]]; then
    echo "skip, target already exists: \"$dir/$old\" -> \"$dir/$new\"" >&2
    return 0
  fi
  echo "mv \"$dir/$old\" \"$dir/$new\""
  mv -- "$dir/$old" "$dir/$new"
}

# Apply a name transformation to the top-level entries of a directory.
#   $1 - filter: "any" (every entry), "image" (image extensions only) or
#        "nodot" (regular files whose name contains no dot)
#   $2 - function that prints the new name for the name passed as $1 and
#        returns non-zero when the entry must be left untouched
#   $3 - directory to process
rename_each() {
  local filter=$1 transform=$2 dir=$3
  local path name new
  for path in "$dir"/*; do
    # An unmatched glob stays literal; skip it.
    [[ -e "$path" || -L "$path" ]] || continue
    name=${path##*/}
    case "$filter" in
      image) is_image_name "$name" || continue ;;
      nodot) [[ -f "$path" && "$name" != *.* ]] || continue ;;
    esac
    new=$("$transform" "$name") || continue
    rename_entry "$dir" "$name" "$new"
  done
}

# --- name transformations (each prints the new name for $1) ---

# Option 2: drop every space.
name_without_spaces() {
  printf '%s' "${1// /}"
}

# Option 3: append ".jpg".
name_with_jpg() {
  printf '%s' "$1.jpg"
}

# Option 4: replace every literal occurrence of $Del_str with $Replace_str.
name_replaced() {
  local result
  result=${1//"$Del_str"/"$Replace_str"}
  printf '%s' "$result"
}

# Option 5: keep only the text between the last $prefix and the last
# $suffix of the name (extension preserved). An empty prefix/suffix means
# "from the start"/"up to the extension". Fails when the markers are absent
# or nothing would be left.
name_between() {
  local name=$1 ext stem head core
  ext=.${name##*.}
  stem=${name%.*}
  [[ "$stem" == *"$suffix"* ]] || return 1
  head=${stem%"$suffix"*}
  if [[ -n "$prefix" ]]; then
    [[ "$head" == *"$prefix"* ]] || return 1
    core=${head##*"$prefix"}
  else
    core=$head
  fi
  [[ -n "$core" ]] || return 1
  printf '%s' "$core$ext"
}

# Option 6: put $Group_str in front of the name.
name_prefixed() {
  printf '%s' "$Group_str$1"
}

# Option 7: put $Group_str between the stem and the extension.
name_suffixed() {
  printf '%s' "${1%.*}$Group_str.${1##*.}"
}

# Option 1: copy every PNG found in sub-directories of $1 (except those
# below an "error" directory) into $1 itself. Existing files are kept
# (cp -n). Despite the menu wording this copies; the nested originals stay.
collect_pngs() {
  local dir=$1 file
  while IFS= read -r -d '' file; do
    echo cp -n "$file" "$dir"
    cp -n -- "$file" "$dir"
  done < <(find "$dir" -mindepth 2 -not -path "*/error/*" -iname "*.png" -type f -print0)
}

# Run "$@ <dir>" for the image directory and then the label directory,
# reporting a directory that does not exist instead of processing it.
for_each_dir() {
  local kind dir
  for kind in image label; do
    if [[ "$kind" == image ]]; then
      dir=$ori_image_directory
    else
      dir=$ori_label_directory
    fi
    if [[ -d "$dir" ]]; then
      echo "**** Processing ori_${kind}_directory: $dir ****"
      "$@" "$dir"
    else
      echo "**** ${kind}图片目录不存在: $dir ****"
    fi
    echo -e ""
  done
}

PS3='Please enter your choice: '
options=("Move all pics in dir to dir" "Delete all space in filename" "Add jpg to no suffix filename" "Replace content in filename" "Extract filename between prefix and suffix" "Add content before filename" "Add content behand filename" "Quit")

while true; do
  echo "一般处理流程1(移动文件)、2(删除所有空格)、3(文件后缀添加.jpg)、4(删除内容)\"-恢复的\"/\"-副本\"、5(取出前缀、后缀中的内容)(方案1:Still/\"\"方案2:\"\"/.Still)、6(添加前缀)、7(添加后缀)、8(关闭)"
  answered=0
  select opt in "${options[@]}"; do
    answered=1
    case "$opt" in
      "Move all pics in dir to dir")
        for_each_dir collect_pngs
        ;;
      "Delete all space in filename")
        for_each_dir rename_each any name_without_spaces
        ;;
      "Add jpg to no suffix filename")
        for_each_dir rename_each nodot name_with_jpg
        ;;
      "Replace content in filename")
        echo -n "Please input the content to be deleted = "
        read -r Del_str
        echo -n "Please input the content to be replace(default is None) = "
        read -r Replace_str
        if [[ -z "$Del_str" ]]; then
          echo "Nothing to delete was entered; no file renamed."
          echo -e ""
        else
          for_each_dir rename_each image name_replaced
        fi
        ;;
      "Extract filename between prefix and suffix")
        echo -n "Please input the prefix to be deleted = "
        read -r prefix
        echo -n "Please input the suffix to be deleted = "
        read -r suffix
        for_each_dir rename_each image name_between
        ;;
      "Add content before filename")
        echo -n "Please input the content to be added = "
        read -r Group_str
        for_each_dir rename_each image name_prefixed
        ;;
      "Add content behand filename")
        echo -n "Please input the content to be added = "
        read -r Group_str
        for_each_dir rename_each image name_suffixed
        ;;
      "Quit")
        echo "Exiting..."
        exit 0
        ;;
      *)
        echo "Invalid option: $REPLY"
        echo -e ""
        ;;
    esac
    # Leave `select` after every answer so the menu is printed again.
    break
  done
  # `select` ends without running its body when stdin reaches EOF; stop
  # instead of re-printing the menu forever.
  if (( answered == 0 )); then
    echo ""
    echo "Exiting..."
    exit 0
  fi
done

View File

@@ -0,0 +1,73 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*
import os, cv2, sys
def transform(input_path, output_path):
    """Convert every .jpg/.JPG/.bmp/.BMP file below ``input_path`` to PNG.

    For each matching file the function
      1. runs ImageMagick ``convert <file> <input_path>/<name>`` (its result
         is not checked, as before),
      2. decodes the file with OpenCV and writes ``<stem>.png`` either into
         ``output_path`` (source kept) or next to the source (source removed
         once the PNG has been written).
    A file that cannot be decoded or written is logged to ``error.txt`` in
    the current working directory and moved into an ``error`` directory
    beside it. Directories named ``error`` are not processed.

    Args:
        input_path: root directory that is walked recursively.
        output_path: directory receiving the PNG files, or a falsy value to
            convert in place.

    Returns:
        None.
    """
    import subprocess  # only needed here; keeps the module imports untouched

    def quarantine(root, file, name):
        # Record the failure and move the offending file out of the way.
        with open("error.txt", "a", encoding="utf-8") as log:
            log.write(file + "\n")
        error_dir = os.path.join(root, 'error')
        os.makedirs(error_dir, exist_ok=True)
        try:
            os.replace(file, os.path.join(error_dir, name))
        except OSError as exc:
            print("could not move", file, "to", error_dir, ":", exc)

    for root, dirs, files in os.walk(input_path):
        for name in files:
            file = os.path.join(root, name)
            print("2_1_正在检索", file)
            if not name.endswith(('.jpg', '.JPG', '.bmp', '.BMP')):
                continue
            # Files already quarantined are left alone.
            if os.path.basename(root) == 'error':
                break

            # Let ImageMagick re-encode the picture into the top-level input
            # directory. NOTE(review): for nested files this creates a copy
            # in input_path rather than normalising in place — confirm that
            # this flattening is intended.
            flat_copy = os.path.join(input_path, name)
            print("convert \"" + file + "\" \"" + flat_copy + "\"")
            try:
                subprocess.run(["convert", file, flat_copy], check=False)
            except OSError as exc:
                print("convert could not be started:", exc)

            # Only the extension is swapped; "jpg"/"bmp" elsewhere in the
            # name is left untouched.
            png_name = name[:-4] + '.png'
            target = os.path.join(output_path if output_path else root, png_name)
            try:
                im = cv2.imread(file)
                if im is None:
                    # imread signals failure by returning None, not by raising.
                    raise ValueError("cannot decode " + file)
                if not output_path:
                    print('transform' + target)
                # The second PNG_COMPRESSION entry (0) is presumably the one
                # that takes effect — kept exactly as before.
                params = [cv2.IMWRITE_PNG_COMPRESSION, 100, cv2.IMWRITE_PNG_COMPRESSION, 0]
                if not cv2.imwrite(target, im, params):
                    raise OSError("cannot write " + target)
                if not output_path:
                    # Remove the source only after the PNG exists.
                    os.remove(file)
            except Exception as exc:
                print("failed:", file, "-", exc)
                quarantine(root, file, name)
if __name__ == '__main__':
    # Usage: python 2_1_Trans_to_png.py <input_dir>
    # e.g.   python 2_1_Trans_to_png.py './2.C组/Group_label_C0'
    if len(sys.argv) < 2:
        # Without this check a missing argument ended in an IndexError.
        print("Usage: python " + sys.argv[0] + " <input_dir>")
        sys.exit(1)
    input_path = sys.argv[1]
    # Conversion happens in place; no separate output directory is used.
    output_path = None
    if not os.path.exists(input_path):
        print("文件夹不存在!")
    else:
        print("Start to transform 2_1_Trans_to_png!")
        transform(input_path, output_path)
        print("Transform end 2_1_Trans_to_png!")
    print()

View File

@@ -0,0 +1,103 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*
import cv2, sys
import os
# Optional resume point: name of the file at which Resize() starts working.
# None means there is no resume point and every file is handled.
break_up = None
# Becomes True inside Resize() as soon as the resume point has been seen
# (immediately when break_up is None); files before that are skipped.
break_up_point = False
# Resize the images of a directory tree to one fixed size.
def Resize(input_path, output_path, interpolation=False, default_height=1920, default_width=1080):
    """Resize every .png/.jpg/.PNG/.JPG file below ``input_path``.

    Each file is first re-encoded in place with ImageMagick ``convert``
    (skipped with a warning when the tool is not installed); a file for which
    ``convert`` fails is quarantined. Files that already have the requested
    size are left alone. Otherwise the resized picture is written to
    ``output_path`` or, when that is falsy, over the original file. The
    in-place write goes through a temporary file so the original is only
    replaced after a successful write. Failures are logged to ``error.txt``
    in the current working directory and the file is moved into an ``error``
    directory beside it. Directories named ``error`` are not processed.

    Reads and updates the module globals ``break_up`` / ``break_up_point``
    (resume point handling).

    Args:
        input_path: root directory that is walked recursively.
        output_path: directory for the results, or a falsy value for in-place.
        interpolation: truthy -> nearest-neighbour, falsy -> OpenCV default.
        default_height: required image height in pixels.
        default_width: required image width in pixels.
            NOTE(review): the two defaults look swapped compared with the
            1920x1080 (width x height) used by the __main__ block — confirm.

    Returns:
        None.
    """
    import shutil      # local imports: only this function needs them
    import subprocess

    global break_up, break_up_point

    convert_bin = shutil.which("convert")
    if convert_bin is None:
        print("warning: ImageMagick 'convert' not found, re-encode step skipped")

    def quarantine(root, file, name):
        # Record the failure and move the offending file out of the way.
        with open("error.txt", "a", encoding="utf-8") as log:
            log.write(file + "\n")
        error_dir = os.path.join(root, 'error')
        os.makedirs(error_dir, exist_ok=True)
        try:
            os.replace(file, os.path.join(error_dir, name))
        except OSError as exc:
            print("could not move", file, "to", error_dir, ":", exc)

    # Walk the input tree.
    for root, dirs, files in os.walk(input_path):
        for name in files:
            file = os.path.join(root, name)
            # Start working at the resume point (or at once if there is none).
            if name == break_up or break_up is None:
                break_up_point = True
            if not break_up_point:
                continue
            if not name.endswith(('.png', '.jpg', '.PNG', '.JPG')):
                continue
            if os.path.basename(root) == 'error':
                break
            print("2_2_正在处理", file)
            try:
                if convert_bin is not None:
                    print("Processing: ", "convert \"" + file + "\" \"" + file + "\"")
                    status = subprocess.run([convert_bin, file, file], check=False).returncode
                    # A non-zero status marks a broken picture: quarantine it.
                    if status != 0:
                        quarantine(root, file, name)
                        print("此文件是问题文件 ", "mv \"" + file + "\" \"" + os.path.join(root, 'error') + '\"')
                        continue

                im = cv2.imread(file)
                if im is None:
                    # imread signals failure by returning None, not by raising.
                    raise ValueError("cannot decode " + file)
                height, width = im.shape[:2]
                if height == default_height and width == default_width:
                    print("符合要求")
                    continue

                if interpolation:
                    im = cv2.resize(im, (default_width, default_height), interpolation=cv2.INTER_NEAREST)
                else:
                    im = cv2.resize(im, (default_width, default_height))

                # The second PNG_COMPRESSION entry (0) is presumably the one
                # that takes effect — kept exactly as before.
                params = [cv2.IMWRITE_PNG_COMPRESSION, 100, cv2.IMWRITE_PNG_COMPRESSION, 0]
                if output_path:
                    if not cv2.imwrite(os.path.join(output_path, name), im, params):
                        raise OSError("cannot write into " + output_path)
                else:
                    print('Resize' + file)
                    # Same extension so OpenCV picks the same encoder.
                    tmp_file = os.path.join(root, ".resize_tmp_" + name)
                    if not cv2.imwrite(tmp_file, im, params):
                        if os.path.exists(tmp_file):
                            os.remove(tmp_file)
                        raise OSError("cannot write " + file)
                    os.replace(tmp_file, file)
            except Exception as exc:
                print("failed:", file, "-", exc)
                quarantine(root, file, name)
if __name__ == '__main__':
    # Usage: python 2_2_Resize.py <input_dir> <True|False> [width] [height]
    #   <True|False>  True -> nearest-neighbour interpolation
    #   width/height  default to 1920 / 1080 when missing or not an integer
    if len(sys.argv) < 3:
        # Without this check a missing argument ended in an IndexError.
        print("Usage: python " + sys.argv[0] + " <input_dir> <True|False> [width] [height]")
        sys.exit(1)

    input_path = sys.argv[1]  # e.g. './2.C组/Group_label_C0'
    output_path = None        # results overwrite the input files

    try:
        default_width = int(sys.argv[3])
    except (IndexError, ValueError):
        default_width = 1920  # default width
    try:
        default_height = int(sys.argv[4])
    except (IndexError, ValueError):
        default_height = 1080  # default height
    # Zero and negative sizes cannot be resized to.
    if default_width <= 0 or default_height <= 0:
        print("发生错误default_width、default_height不应该为0")
        sys.exit(1)

    interpolation = sys.argv[2]  # nearest-neighbour switch
    if interpolation == "False":
        interpolation = False
    elif interpolation == "True":
        interpolation = True
    else:
        print("interpolation must be True or False!")
        # The former bare `quit` was a no-op and let the run continue.
        sys.exit(1)

    if not os.path.exists(input_path):
        print(input_path)
        print("文件夹不存在!")
    else:
        print("Start to transform_2_2_Resize.py!")
        print(input_path, output_path, default_height, default_width)
        Resize(input_path, output_path, interpolation=interpolation, default_height=default_height, default_width=default_width)
        print("Transform end_2_2_Resize.py!")

View File

@@ -0,0 +1,118 @@
#!/bin/bash
# Print the three-line command-line help of the reformat script to stdout.
usage() {
  printf '%s\n' \
    "Usage: $0 -i <ori_image_directory> -l <ori_label_directory> -w <width_of_pic> -h <height_of_pic> [-help]" \
    "对image图片和label图片进行处理将其转为PNG格式并调整图片的宽和高和格式" \
    "-i:原始图片的路径,-l:原始标签的路径,-w:图片宽度,-h:图片高度,-help帮助"
}
# Raw directory arguments exactly as typed by the user.
ori_image_directorys=""
ori_label_directorys=""
# Target size in pixels.
pic_width=1920
pic_height=1080

while getopts "l:i:h:w:" opt; do
  case "$opt" in
    h)
      # getopts knows no long options: "-help" arrives as -h with the
      # argument "elp".
      if [[ "$OPTARG" == 'elp' ]]; then
        usage
        exit 0
      elif [[ "$OPTARG" =~ ^[0-9]+$ ]] && (( 10#$OPTARG > 0 )); then
        pic_height=$OPTARG
        echo pic_height is "$pic_height"
      else
        echo "-h(pic_height)必须为正整数"
        usage
        exit 1
      fi
      ;;
    i)
      ori_image_directorys=$OPTARG
      ;;
    l)
      ori_label_directorys=$OPTARG
      ;;
    w)
      if [[ "$OPTARG" =~ ^[0-9]+$ ]] && (( 10#$OPTARG > 0 )); then
        pic_width=$OPTARG
        echo pic_width is "$pic_width"
      else
        echo "-w(pic_width)必须为正整数"
        usage
        exit 1
      fi
      ;;
    *)
      echo -e '\033[31m!!! Error, Illegal input !!!\033[0m'
      usage
      exit 1
      ;;
  esac
done

# At least one of -i / -l has to be given.
if [[ -z "$ori_label_directorys" && -z "$ori_image_directorys" ]]; then
  echo -e "\033[31m输入地址 -i -l 都为空\033[0m"
  usage
  exit 1
fi

# Resolve the arguments to absolute paths (an omitted one stays empty).
ori_image_directory=""
ori_label_directory=""
if [[ -n "$ori_image_directorys" ]]; then
  ori_image_directory=$(readlink -f -- "$ori_image_directorys") || ori_image_directory=""
fi
if [[ -n "$ori_label_directorys" ]]; then
  ori_label_directory=$(readlink -f -- "$ori_label_directorys") || ori_label_directory=""
fi
if [[ -z "$ori_label_directory" && -z "$ori_image_directory" ]]; then
  echo "无法解析地址,程序退出"
  echo -e "\033[31mori_image_directory\033[0m: $ori_image_directorys"
  echo -e "\033[31mori_label_directory\033[0m: $ori_label_directorys"
  exit 1
fi
if [[ ! -d "$ori_label_directory" && ! -d "$ori_image_directory" ]]; then
  echo "image、label两目录都不存在程序退出"
  echo -e "\033[31mori_image_directory\033[0m: $ori_image_directorys"
  echo -e "\033[31mori_label_directory\033[0m: $ori_label_directorys"
  exit 1
fi
echo -e "\033[32m_____ 2_reformate_data.sh _____\033[0m"

# The Python helpers are addressed relative to this script's directory.
script_path=$(dirname -- "$0")
cd -- "$script_path" || {
  echo "无法进入脚本目录: $script_path" >&2
  exit 1
}

# Activate the conda environment that provides OpenCV for the helpers.
conda_bin="/home/$USER/miniconda/bin"
if [[ -f "$conda_bin/activate" ]]; then
  # shellcheck disable=SC1091
  source "$conda_bin/activate" Deal_pics
else
  echo "warning: $conda_bin/activate not found, using the current python" >&2
fi

# --- image directory: convert to PNG, then resize with default interpolation ---
if [[ -d "$ori_image_directory" ]]; then
  echo "**** Processing ori_image_directory: $ori_image_directory ****"
  echo "1.Trans pics to png"
  python 2_1_Trans_to_png.py "$ori_image_directory" \
    || echo "2_1_Trans_to_png.py failed for $ori_image_directory" >&2
  echo -e ""
  echo "2.Resize image pics without nearest"
  echo -e "\033[35m运行\033[0mpython 2_2_Resize.py $ori_image_directory False $pic_width $pic_height "
  # False: do not use nearest-neighbour interpolation
  python 2_2_Resize.py "$ori_image_directory" False "$pic_width" "$pic_height" \
    || echo "2_2_Resize.py failed for $ori_image_directory" >&2
  echo -e ""
else
  echo "**** image图片目录不存在: $ori_image_directory ****"
  echo -e ""
fi

# --- label directory: convert to PNG, then resize with nearest neighbour ---
if [[ -d "$ori_label_directory" ]]; then
  echo "**** Processing ori_label_directory: $ori_label_directory ****"
  echo -e "\033[33m__ 1.Trans pics to png __\033[0m"
  echo -e "\033[35m运行\033[0mpython 2_1_Trans_to_png.py $ori_label_directory"
  python 2_1_Trans_to_png.py "$ori_label_directory" \
    || echo "2_1_Trans_to_png.py failed for $ori_label_directory" >&2
  echo -e ""
  echo -e "\033[33m__ 2.Resize label pics with nearest __\033[0m"
  echo -e "\033[35m运行\033[0mpython 2_2_Resize.py $ori_label_directory True $pic_width $pic_height"
  # True: use nearest-neighbour interpolation
  python 2_2_Resize.py "$ori_label_directory" True "$pic_width" "$pic_height" \
    || echo "2_2_Resize.py failed for $ori_label_directory" >&2
  echo -e ""
else
  echo -e "\033[33m**** label图片目录不存在: $ori_label_directory ****\033[0m"
  echo -e ""
fi

# Leave the conda environment again when the helper script is available.
if [[ -f "$conda_bin/deactivate" ]]; then
  # shellcheck disable=SC1091
  source "$conda_bin/deactivate"
fi

View File

@@ -0,0 +1,186 @@
#!/bin/bash
# Print the four-line command-line help of the pairing script to stdout.
usage() {
  printf '%s\n' \
    "Usage: $0 -i <ori_image_directory> -l <ori_label_directory> [ -p <prefix> -s <suffix> -h]" \
    "对image图片和label图片进行匹配-i、-l均不能为空-p -s默认为空) " \
    "-i:原始image的路径-l:原始label的路径-p:前缀内容,-s:后缀内容(不用管文件后缀名)-h:帮助" \
    "e.g. 3_pair_ori_label.sh -i ./C组未标注 -l ./C组标注图片 -p Group_C_ -s _label"
}
# Raw arguments exactly as typed by the user.
ori_image_directorys=""
ori_label_directorys=""
# Optional text surrounding the shared part of a file name.
prefix=""
suffix=""

while getopts "hl:i:p:s:" opt; do
  case "$opt" in
    h)
      usage
      exit 0
      ;;
    i)
      ori_image_directorys=$OPTARG
      ;;
    l)
      ori_label_directorys=$OPTARG
      ;;
    p)
      prefix=$OPTARG
      ;;
    s)
      suffix=$OPTARG
      ;;
    *)
      echo "$opt"
      echo -e '\033[31m!!! Error, Illegal input !!!\033[0m'
      usage
      exit 1
      ;;
  esac
done

# Both directories are required for pairing.
if [[ -z "$ori_label_directorys" || -z "$ori_image_directorys" ]]; then
  echo -e "\033[31m输入地址 -i -l 存在空地址\033[0m"
  usage
  exit 1
fi

# Resolve to absolute paths.
ori_image_directory=$(readlink -f -- "$ori_image_directorys") || ori_image_directory=""
ori_label_directory=$(readlink -f -- "$ori_label_directorys") || ori_label_directory=""
if [[ -z "$ori_label_directory" || -z "$ori_image_directory" ]]; then
  echo "无法解析地址,程序退出"
  echo -e "\033[31mori_image_directory\033[0m: $ori_image_directorys"
  echo -e "\033[31mori_label_directory\033[0m: $ori_label_directorys"
  exit 1
fi
if [[ ! -d "$ori_label_directory" || ! -d "$ori_image_directory" ]]; then
  echo "image、label两目录有一个不存在程序退出"
  echo -e "\033[31mori_image_directory\033[0m: $ori_image_directorys"
  echo -e "\033[31mori_label_directory\033[0m: $ori_label_directorys"
  exit 1
fi

# Work from the directory that contains this script.
script_path=$(dirname -- "$0")
cd -- "$script_path" || {
  echo "无法进入脚本目录: $script_path" >&2
  exit 1
}
echo -e "\033[32m_____ 3_pair_ori_label.sh _____\033[0m"

#######################################
# Print the pairing key of a file name.
# When the name contains $prefix followed later by $suffix, the key is the
# text between the last such prefix and the last suffix; with an empty
# prefix it is everything before the last suffix. Otherwise the name is
# used as is. A trailing image extension is removed from the result.
# All text is compared literally (no regular expressions).
# Arguments: $1 - file name without directory
#######################################
extract_key() {
  local name=$1 key=$1 head
  if [[ "$name" == *"$suffix"* ]]; then
    head=${name%"$suffix"*}
    if [[ -z "$prefix" ]]; then
      key=$head
    elif [[ "$head" == *"$prefix"* ]]; then
      key=${head##*"$prefix"}
    fi
  fi
  if [[ "$key" =~ ^(.*)\.(jpg|png|bmp|JPG|PNG|BMP)$ ]]; then
    key=${BASH_REMATCH[1]}
  fi
  printf '%s\n' "$key"
}

# Print the pairing key of every entry of directory $1, one per line.
list_keys() {
  local dir=$1 path
  for path in "$dir"/*; do
    [[ -e "$path" ]] || continue   # unmatched glob stays literal
    extract_key "${path##*/}"
  done
}

#######################################
# Move every image of one directory that has no partner with the same key
# in the other directory into <dir>/Not_pair_pics and log its name in
# <dir>/Not_pair_pics/not_pair.txt.
# Arguments:
#   $1 - directory to clean up
#   $2 - directory that must contain the partner
#   $3 - wording used in the "not found" message
#######################################
move_unpaired() {
  local src_dir=$1 other_dir=$2 wording=$3
  local other_keys path name key
  other_keys=$(list_keys "$other_dir")
  for path in "$src_dir"/*; do
    [[ -f "$path" ]] || continue
    name=${path##*/}
    # Only image files take part in the pairing.
    [[ "$name" =~ \.(jpg|png|bmp|JPG|PNG|BMP)$ ]] || continue
    key=$(extract_key "$name")
    echo "$name -> $key"
    # Exact, literal comparison of whole keys: "1" no longer pairs with "10".
    if [[ -n "$other_keys" ]] && grep -Fxq -- "$key" <<<"$other_keys"; then
      continue
    fi
    echo "$name ${wording}未在${other_dir}搜索到"
    mkdir -p -- "$src_dir/Not_pair_pics"
    echo "$name" >> "$src_dir/Not_pair_pics/not_pair.txt"
    mv -- "$path" "$src_dir/Not_pair_pics"
  done
}

# Pass 1: images without a label.
move_unpaired "$ori_image_directory" "$ori_label_directory" "image中内容"
# Pass 2: labels without an image.
move_unpaired "$ori_label_directory" "$ori_image_directory" "label中对应内容"
# Pass 3: re-check the images against the labels that are left.
move_unpaired "$ori_image_directory" "$ori_label_directory" "image中内容"

View File

@@ -0,0 +1,445 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*
import os,time,sys,threading, colorsys, argparse
import asyncio, cv2, multiprocessing, random
from PIL import Image
import numpy as np
from Tool_deal_labels import edge_detection, detect_connected_regions, Tool_color_connected_array, fill_white_regions, color_connected_regions
def getFileList(dir, Filelist=None, ext=None, Max_layer=1, layer=0, Donot_Search=('1_边缘检测并膨胀', '2_连通区域检测', '3_分水岭算法填充')):
    """
    Collect file paths below ``dir``.

    Args:
        dir: a file path or the root directory to walk.
        Filelist: list that receives the results; a fresh list is created
            when omitted, so separate calls never share state.
        ext: extension to keep, without the dot (e.g. ``"png"``); ``None`` or
            ``""`` keeps every file. The comparison is case-sensitive.
        Max_layer: a directory found at depth ``layer`` is descended into
            only while ``layer <= Max_layer``.
        layer: depth of ``dir`` itself (0 for the root call).
        Donot_Search: directory base names that are skipped entirely.

    Returns:
        ``Filelist`` with the matching paths appended, in ``os.listdir`` order.
    """
    if Filelist is None:
        Filelist = []
    if os.path.isfile(dir):
        # Compare against the full extension so that "jpeg" or other
        # extensions longer than three characters also work.
        if not ext or dir.endswith('.' + ext.lstrip('.')):
            Filelist.append(dir)
    elif os.path.isdir(dir):
        # Directories on the skip list are never entered.
        if os.path.basename(dir) in Donot_Search:
            return Filelist
        if layer <= Max_layer:
            for entry in os.listdir(dir):
                # Pass Donot_Search on so a caller-supplied skip list also
                # applies to nested directories.
                getFileList(os.path.join(dir, entry), Filelist, ext, Max_layer, layer + 1, Donot_Search)
    return Filelist
class Deal_image():
    """Batch converter for segmentation label images.

    Three independent jobs are offered:

    * ``Conver_ori_label_pic_2_pro_pic_all`` - raw colour labels -> cleaned
      "pro" labels (edge detection, connected regions, white-region fill and
      recolouring; the image operations come from ``Tool_deal_labels``).
    * ``Conver_pro_label_pic_2_GT_pic_all`` - "pro" labels -> ground-truth
      index maps in which every palette colour becomes a class number.
    * ``Merge_ori_pic_and_label_pic_all`` - paint the labelled pixels of a
      result image over the matching original picture.
    """

    # Upper bound on the number of worker processes used by the batch methods.
    MAX_PROCESSES = 20

    def __init__(
        self,
        Annotate_CLASSES=('肝脏', '胆囊'),
        Annotate_PALETTE=[[255, 91, 0], [255, 234, 0]],
        src_label_fold="./Label",
        save_pro_label_fold="./LABEL_PNG_new",
        save_GT_label_fold="./Label_Generate",
        GT_channel=1,
        pro_append_name="_label",
        GT_append_name="_gtFine_labelTrainIds",
        ori_img_folder="./ORI_PNG",
        res_label_folder="./Result_label",
        save_merge_pic_folder="./Result_merge",
        back_gnd_color=0,
        first_class_color=1,
        pic_type="png",
        Max_width=10000,
        Label_Max_Search_layer=1000,
        save_process_pics=False,
        bg_PALETTE=[0, 0, 0],
    ):
        """Store the configuration and scan the input folders.

        Args:
            Annotate_CLASSES: class names, in palette order.
            Annotate_PALETTE: one RGB triple per class.
            src_label_fold: folder holding the raw label images.
            save_pro_label_fold: folder receiving the cleaned "pro" labels.
            save_GT_label_fold: folder receiving the GT index maps.
            GT_channel: channels of the GT output, 1 or 3.
            pro_append_name: suffix appended to "pro" label file names.
            GT_append_name: suffix appended to GT file names.
            ori_img_folder: folder holding the original pictures.
            res_label_folder: folder holding predicted label images.
            save_merge_pic_folder: folder receiving the merged pictures.
            back_gnd_color: value written for pixels that match no class.
            first_class_color: value written for the first class; later
                classes count up from it.
            pic_type: image extension without the dot.
            Max_width: kept for backward compatibility; colour matching no
                longer depends on it.
            Label_Max_Search_layer: maximum folder depth when scanning.
            save_process_pics: also save the intermediate images of the
                raw -> "pro" conversion.
            bg_PALETTE: RGB triple of the background; it is excluded from the
                classes when GT maps are generated.
        """
        self.bg_PALETTE = bg_PALETTE
        self.Annotate_CLASSES = Annotate_CLASSES
        self.Annotate_PALETTE = np.array(Annotate_PALETTE)
        self.Annotate_CLASSES_NUM = np.shape(Annotate_CLASSES)[0]
        self.save_process_pics = save_process_pics
        self.src_label_fold = src_label_fold
        self.save_pro_label_fold = save_pro_label_fold
        self.save_GT_label_fold = save_GT_label_fold
        self.ori_img_folder = ori_img_folder
        self.res_label_folder = res_label_folder
        self.save_merge_pic_folder = save_merge_pic_folder
        self.pro_append_name = pro_append_name
        self.GT_append_name = GT_append_name
        self.GT_channel = GT_channel
        self.Max_width = Max_width
        self.pic_type = pic_type
        self.back_gnd_color = back_gnd_color
        self.first_class_color = first_class_color
        self.Label_Max_Search_layer = Label_Max_Search_layer

        # Scanning is best effort: a folder that cannot be read leaves the
        # corresponding list as None instead of aborting construction.
        try:
            self.labellist_src = getFileList(src_label_fold, [], pic_type, self.Label_Max_Search_layer)
            print('本次执行检索到ori_label图片 '+str(len(self.labellist_src))+' 张图像')
        except Exception:
            self.labellist_src = None
            print("没有ori_label相关文件")
        try:
            self.labellist_pro = getFileList(save_pro_label_fold, [], pic_type, self.Label_Max_Search_layer)
            print('本次执行检索到pro_label图片 '+str(len(self.labellist_pro))+' 张图像')
        except Exception:
            self.labellist_pro = None
            print("没有pro_label相关文件")
        try:
            self.imglist_src = getFileList(ori_img_folder, [], pic_type, self.Label_Max_Search_layer)
            self.reslist_src = getFileList(res_label_folder, [], pic_type, self.Label_Max_Search_layer)
            print('本次执行检索到ori原始图片 '+str(len(self.imglist_src))+' 张图像')
            print('本次执行检索到训练train_result图片 '+str(len(self.reslist_src))+' 张图像')
        except Exception:
            self.imglist_src = None
            self.reslist_src = None
            print("没有train_result和原始图片相关文件")

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _output_stem(name):
        """Base name of ``name`` without a trailing .jpg/.png extension."""
        base = os.path.basename(name)
        if name.lower().endswith(('.jpg', '.png')):
            return base.rpartition('.')[0]
        return base

    @staticmethod
    def _ensure_dir(path):
        """Create ``path`` (including parents) and report if it already existed."""
        already_there = os.path.isdir(path)
        # exist_ok keeps this safe when several workers create it at once.
        os.makedirs(path, exist_ok=True)
        if already_there:
            print("已有"+path)

    @staticmethod
    def _pixels_matching(r, g, b, color):
        """Row and column indices of the pixels whose RGB value equals ``color``."""
        return np.nonzero((r == color[0]) & (g == color[1]) & (b == color[2]))

    def _run_in_pool(self, worker, jobs):
        """Run ``worker(*job)`` for every job in a process pool and wait for all."""
        if not jobs:
            return
        pool = multiprocessing.Pool(processes=min(self.MAX_PROCESSES, len(jobs)))
        try:
            pool.starmap(worker, jobs)
        finally:
            # Always release the workers, even when a job raised.
            pool.close()
            pool.join()

    def _save_stage(self, subfolder, imgname, tag, picture, label, time_start):
        """Write one intermediate picture of the raw -> "pro" conversion."""
        target_dir = os.path.join(self.save_pro_label_fold, subfolder)
        os.makedirs(target_dir, exist_ok=True)
        save_dir = os.path.join(target_dir, self._output_stem(imgname)+self.pro_append_name+tag+'.'+self.pic_type)
        cv2.imwrite(save_dir, picture)
        print(label, save_dir)
        print('time cost', time.time()-time_start, 's')

    # ------------------------------------------------------------------
    # Image loading
    # ------------------------------------------------------------------
    def get_single_pic_rgb(self, imgpath):
        """Load ``imgpath`` as RGB and return ``(image, r, g, b)``.

        ``image`` is the PIL image; ``r``, ``g`` and ``b`` are its channels
        as NumPy arrays.
        """
        print(imgpath)
        # The context manager closes the file once the pixels are converted.
        with Image.open(imgpath) as source:
            image = source.convert('RGB')
        r, g, b = (np.array(channel) for channel in image.split())
        return image, r, g, b

    # ------------------------------------------------------------------
    # "pro" label -> GT index map
    # ------------------------------------------------------------------
    def Conver_pro_label_pic_2_GT_pic(self, imgpath, imgname):
        """Convert one "pro" label image into a GT index map and save it.

        Every non-background palette colour is replaced by a class number,
        starting at ``first_class_color``; all other pixels receive
        ``back_gnd_color``.

        Raises:
            ValueError: if ``GT_channel`` is neither 1 nor 3.
        """
        time_start = time.time()
        channels = int(self.GT_channel)
        if channels not in (1, 3):
            # Fail before any work is done instead of writing a bogus file.
            print("GT_channel 必须为1或3")
            raise ValueError("GT_channel must be 1 or 3, got %r" % (self.GT_channel,))

        image, r, g, b = self.get_single_pic_rgb(imgpath)
        result_gt = np.ones(np.shape(image))*self.back_gnd_color

        # The background colour is not a class of its own.
        palette_no_bg = self.Annotate_PALETTE[~np.all(self.Annotate_PALETTE == self.bg_PALETTE, axis=1)]
        for gt_number, color in enumerate(palette_no_bg, start=self.first_class_color):
            rows, cols = self._pixels_matching(r, g, b, color)
            result_gt[rows, cols, :] = gt_number

        if channels == 1:
            result_gt = result_gt[:, :, 0]
        else:
            # OpenCV writes BGR, so swap the channel order.
            result_gt = cv2.cvtColor(np.float32(result_gt), cv2.COLOR_RGB2BGR)

        self._ensure_dir(self.save_GT_label_fold)
        save_dir = os.path.join(self.save_GT_label_fold, self._output_stem(imgname)+self.GT_append_name+'.'+self.pic_type)
        cv2.imwrite(save_dir, result_gt)
        print("GT图片已保存", save_dir)
        print('time cost', time.time()-time_start, 's')

    def Conver_pro_label_pic_2_GT_pic_all(self):
        """Convert every "pro" label found in ``save_pro_label_fold`` to GT."""
        print("\033[33m**** 进行转换将Pro_label_pic转换为GT_label_pic ****\033[0m")
        print("\033[33mPro_label_pic存储位置为\033[0m", self.save_pro_label_fold)
        print("\033[33mGT_label_pic生成位置为\033[0m", self.save_GT_label_fold)
        # Rescan: the "pro" folder may have been filled after construction.
        try:
            self.labellist_pro = getFileList(self.save_pro_label_fold, [], self.pic_type, self.Label_Max_Search_layer)
            print('本次执行检索到pro_label图片 '+str(len(self.labellist_pro))+' 张图像')
        except Exception:
            self.labellist_pro = None
            print("没有pro_label相关文件")
        self._ensure_dir(self.save_GT_label_fold)

        jobs = [
            (imgpath, os.path.basename(imgpath).rpartition('.')[0].replace(self.pro_append_name, ""))
            for imgpath in (self.labellist_pro or [])
        ]
        self._run_in_pool(self.Conver_pro_label_pic_2_GT_pic, jobs)

    # ------------------------------------------------------------------
    # raw label -> "pro" label
    # ------------------------------------------------------------------
    def Conver_ori_label_pic_2_pro_pic(self, imgpath, imgname):
        """Clean one raw label image and save it as a "pro" label.

        An image that OpenCV cannot read is skipped and its name is appended
        to ``error_1.txt`` in the working directory.
        """
        time_start = time.time()
        image = cv2.imread(imgpath)
        if image is None:
            # cv2.imread signals failure by returning None, not by raising.
            print("Cannot read image, skipped:", imgpath)
            with open("error_1.txt", "a", encoding="utf-8") as error_log:
                error_log.write(imgname + "\n")
            return

        # 1. Edge detection followed by dilation.
        dilated_image = edge_detection(image)
        if self.save_process_pics == True:
            self._save_stage('1_边缘检测并膨胀', imgname, '_Edge', dilated_image,
                             "中间态-边缘检测并膨胀 图片已保存", time_start)

        # 2. Connected regions.
        filtered_labeled_array, _ = detect_connected_regions(dilated_image)
        colored_image_filtered = Tool_color_connected_array(filtered_labeled_array)
        if self.save_process_pics == True:
            self._save_stage('2_连通区域检测', imgname, '_Region', colored_image_filtered,
                             "中间态-连通区域检测 图片已保存", time_start)

        # 3. Fill the white regions.
        filled_labeled_array = fill_white_regions(filtered_labeled_array)
        colored_image_filled = Tool_color_connected_array(filled_labeled_array)
        if self.save_process_pics == True:
            self._save_stage('3_分水岭算法填充', imgname, '_FillEdge', colored_image_filled,
                             "中间态-分水岭算法填充 图片已保存", time_start)

        # 4. Final colouring of the regions with the annotation palette.
        result_pro = color_connected_regions(filled_labeled_array, filtered_labeled_array, image, self.Annotate_PALETTE)
        os.makedirs(self.save_pro_label_fold, exist_ok=True)
        save_dir = os.path.join(self.save_pro_label_fold, self._output_stem(imgname)+self.pro_append_name+'.'+self.pic_type)
        print("Pro图片已保存", save_dir)
        cv2.imwrite(save_dir, result_pro)
        print('time cost', time.time()-time_start, 's')

    def Conver_ori_label_pic_2_pro_pic_all(self):
        """Convert every raw label found at construction time to a "pro" label."""
        print("\033[33m**** 进行转换将Ori_label_pic转换为Pro_label_pic ****\033[0m")
        print("\033[33mOri_label_pic存储位置为\033[0m", self.src_label_fold)
        print("\033[33mPro_label_pic生成位置为\033[0m", self.save_pro_label_fold)
        self._ensure_dir(self.save_pro_label_fold)
        if self.save_process_pics == True:
            for subfolder in ('1_边缘检测并膨胀', '2_连通区域检测', '3_分水岭算法填充'):
                self._ensure_dir(os.path.join(self.save_pro_label_fold, subfolder))

        jobs = []
        for imgpath in (self.labellist_src or []):
            imgname = self._output_stem(imgpath).replace(self.pro_append_name, "")
            print("Processing: ", imgname, "...")
            jobs.append((imgpath, imgname))
        self._run_in_pool(self.Conver_ori_label_pic_2_pro_pic, jobs)

    # ------------------------------------------------------------------
    # label overlay
    # ------------------------------------------------------------------
    def Merge_ori_pic_and_label_pic(self, res_img_path, res_imgname):
        """Paint the labelled pixels of one result image over its original.

        Returns:
            -1 when the original picture does not exist, otherwise ``None``.
        """
        time_start = time.time()
        ori_img_path = os.path.join(self.ori_img_folder, res_imgname+'.'+self.pic_type)
        if not os.path.exists(ori_img_path):
            print("****照片不存在:****", ori_img_path)
            return -1
        ori_image, _, _, _ = self.get_single_pic_rgb(ori_img_path)
        _, res_r, res_g, res_b = self.get_single_pic_rgb(res_img_path)
        merge_img = np.array(ori_image)  # start from the original picture

        for color in self.Annotate_PALETTE:
            rows, cols = self._pixels_matching(res_r, res_g, res_b, color)
            merge_img[rows, cols, 0] = color[0]
            merge_img[rows, cols, 1] = color[1]
            merge_img[rows, cols, 2] = color[2]

        # OpenCV writes BGR, so swap the channel order.
        merge_img = cv2.cvtColor(np.float32(merge_img), cv2.COLOR_RGB2BGR)
        self._ensure_dir(self.save_merge_pic_folder)
        save_dir = os.path.join(self.save_merge_pic_folder, self._output_stem(res_imgname)+'.'+self.pic_type)
        cv2.imwrite(save_dir, merge_img)
        print("Merge图片已保存", save_dir)
        print('time cost', time.time()-time_start, 's')

    def Merge_ori_pic_and_label_pic_all(self):
        """Overlay every result image found at construction time."""
        for res_img_path in (self.reslist_src or []):
            res_imgname = self._output_stem(res_img_path).replace(self.pro_append_name, "")
            print("Processing: ", res_imgname, "...")
            self.Merge_ori_pic_and_label_pic(res_img_path, res_imgname)
if __name__ == "__main__":
    # Class names in palette order; the last entry is the background.
    Annotate_CLASSES = ('肝脏','胆囊','分离钳','止血海绵','肝总管','胆总管','吸引器','剪刀','止血纱布','生物夹','无损伤钳','喷洒','胆囊管','胆囊动脉','电凝','标本袋','引流管','纱布','金属钛夹','术中超声','吻合器','乳胶管','推结器','肝带','钳夹','超声刀','脂肪','双极电凝','棉球','血管阻断夹','肿瘤','针','线','韧带','胆囊静脉','背景')
    # One RGB triple per class, same order as Annotate_CLASSES.
    Annotate_PALETTE = [[255,91,0],[255,234,0],[85, 111, 181],[181, 227, 14],[72, 0, 255],[0, 155, 33],[255,0,255],[29, 32, 136],[160, 15, 95],[0,160,233],[52,184,178],[90,120,41],[255,0,0],[177,0,0],[167,24,233],[112,113,150],[0,255,0],[255,255,255],[0,255,255],[138,251,213],[136,162,196],[197,83,181],[202,202,200],[113,102,140],[66,115,82],[240,16,116],[155,132,0],[155,62,0],[146,175,236],[255,172,159],[245,161,0],[134,124,118], [0,157,142], [181,85,105], [42,8,66],[0,0,0]]
    # RGB value of the background.
    bg_PALETTE = [0,0,0]

    # Command-line interface. argparse applies type=int itself, so the
    # integer options need no further conversion afterwards.
    parser = argparse.ArgumentParser(description='Process some files.')
    parser.add_argument('-src_fold', dest='src_label_fold', default='', help='source label folder')
    parser.add_argument('-save_pro_fold', dest='save_pro_label_fold', default='./save_pro_label_fold', help='processed label folder')
    parser.add_argument('-save_GT_fold', dest='save_GT_label_fold', default='./save_GT_label_fold', help='ground truth folder')
    parser.add_argument('-fold_search_depth', dest='Label_Max_Search_layer', default=1000, type=int, help='Folder Search Depth')
    parser.add_argument('-pro_suffix_name', dest='pro_append_name', default='_label', help='Pro file suffix')
    parser.add_argument('-GT_suffix_name', dest='GT_append_name', default='_gtFine_labelTrainIds', help='GT file suffix')
    parser.add_argument('-GT_channel', dest='GT_channel', default=1, type=int, help='GT file channel(1 or 3)')
    parser.add_argument('-back_gnd_color', dest='back_gnd_color', default=0, type=int, help='Color of "Back ground"(0 or 255)')
    parser.add_argument('-first_class_color', dest='first_class_color', default=1, type=int, help='Color of "First Class"')
    parser.add_argument('-pic_type', dest='pic_type', default='png', help='type of picture(Do not add ".")')
    parser.add_argument('-Max_width', dest='Max_width', default=10000, type=int, help='Max width of picture')
    parser.add_argument('-Rebuild_from', dest='Rebuild_from', default='label', help='Source to Rebuild Labels(label/pro)')
    parser.add_argument('-Rebuild_to', dest='Rebuild_to', default='GT', help='Destination of Rebuild Labels(pro/GT)')
    parser.add_argument('-save_process_pics', dest='save_process_pics', default='False', help='Save the processed pics(e.g.Gray_pics,Color_pics) in generating pro_pics')
    args = parser.parse_args()

    # These stay module-level names on purpose: code elsewhere in this file
    # may read Max_width, save_pro_label_fold and pic_type as globals.
    src_label_fold = args.src_label_fold
    save_pro_label_fold = args.save_pro_label_fold
    save_GT_label_fold = args.save_GT_label_fold
    Label_Max_Search_layer = args.Label_Max_Search_layer
    pro_append_name = args.pro_append_name
    GT_append_name = args.GT_append_name
    GT_channel = args.GT_channel
    back_gnd_color = args.back_gnd_color
    first_class_color = args.first_class_color
    pic_type = args.pic_type
    Max_width = args.Max_width
    Rebuild_from = args.Rebuild_from
    Rebuild_to = args.Rebuild_to
    # Only the literal "true" (any letter case) enables the option.
    save_process_pics = args.save_process_pics.lower() == 'true'

    D = Deal_image(
        Annotate_CLASSES=Annotate_CLASSES,
        Annotate_PALETTE=Annotate_PALETTE,
        src_label_fold=src_label_fold,
        save_pro_label_fold=save_pro_label_fold,
        save_GT_label_fold=save_GT_label_fold,
        GT_channel=GT_channel,
        pro_append_name=pro_append_name,
        GT_append_name=GT_append_name,
        back_gnd_color=back_gnd_color,
        first_class_color=first_class_color,
        pic_type=pic_type,
        Max_width=Max_width,
        Label_Max_Search_layer=Label_Max_Search_layer,
        save_process_pics=save_process_pics,
        bg_PALETTE=bg_PALETTE,
    )

    if Rebuild_from == 'label':
        # Step 1: raw labels -> "pro" labels.
        D.Conver_ori_label_pic_2_pro_pic_all()
    if Rebuild_to == 'GT':
        # Step 2: "pro" labels -> GT index maps.
        D.Conver_pro_label_pic_2_GT_pic_all()

View File

@@ -0,0 +1,463 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*
import os,time,sys,threading, colorsys, argparse
import asyncio, cv2, multiprocessing
from PIL import Image
import numpy as np
def getFileList(dir, Filelist=None, ext=None, Max_layer=1, layer=0, Donot_Search=('1_颜色预处理', '2_灰度化', '3_边缘化')):
    """
    Collect file paths below ``dir``.

    Args:
        dir: a file path or the root directory to walk.
        Filelist: list that receives the results; a fresh list is created
            when omitted, so separate calls never share state.
        ext: extension to keep, without the dot (e.g. ``"png"``); ``None`` or
            ``""`` keeps every file. The comparison is case-sensitive.
        Max_layer: a directory found at depth ``layer`` is descended into
            only while ``layer <= Max_layer``.
        layer: depth of ``dir`` itself (0 for the root call).
        Donot_Search: directory base names that are skipped entirely.

    Returns:
        ``Filelist`` with the matching paths appended, in ``os.listdir`` order.
    """
    if Filelist is None:
        Filelist = []
    if os.path.isfile(dir):
        # Compare against the full extension so that "jpeg" or other
        # extensions longer than three characters also work.
        if not ext or dir.endswith('.' + ext.lstrip('.')):
            Filelist.append(dir)
    elif os.path.isdir(dir):
        # Directories on the skip list are never entered.
        if os.path.basename(dir) in Donot_Search:
            return Filelist
        if layer <= Max_layer:
            for entry in os.listdir(dir):
                # Pass Donot_Search on so a caller-supplied skip list also
                # applies to nested directories.
                getFileList(os.path.join(dir, entry), Filelist, ext, Max_layer, layer + 1, Donot_Search)
    return Filelist
class Deal_image():
    """Batch converter for segmentation label images (colour-snapping variant).

    Three independent jobs are offered:

    * ``Conver_ori_label_pic_2_pro_pic_all`` - raw colour labels -> cleaned
      "pro" labels: every pixel is snapped to the nearest colour of
      ``src_PALETTE`` and pixels on detected edges are replaced by the most
      frequent colour around them.
    * ``Conver_pro_label_pic_2_GT_pic_all`` - "pro" labels -> ground-truth
      index maps in which every palette colour becomes a class number.
    * ``Merge_ori_pic_and_label_pic_all`` - paint the labelled pixels of a
      result image over the matching original picture.
    """

    # Upper bound on the number of worker processes used by the batch methods.
    MAX_PROCESSES = 20
    # Side length of the square neighbourhood used to repaint edge pixels.
    EDGE_WINDOW = 5

    def __init__(
        self,
        Annotate_CLASSES=('肝脏', '胆囊'),
        Annotate_PALETTE=[[255, 91, 0], [255, 234, 0]],
        src_label_fold="./Label",
        save_pro_label_fold="./LABEL_PNG_new",
        save_GT_label_fold="./Label_Generate",
        GT_channel=1,
        pro_append_name="_label",
        GT_append_name="_gtFine_labelTrainIds",
        ori_img_folder="./ORI_PNG",
        res_label_folder="./Result_label",
        save_merge_pic_folder="./Result_merge",
        back_gnd_color=0,
        first_class_color=1,
        pic_type="png",
        Max_width=10000,
        Label_Max_Search_layer=1000,
        save_process_pics=False,
    ):
        """Store the configuration and scan the input folders.

        Args:
            Annotate_CLASSES: class names, in palette order.
            Annotate_PALETTE: one RGB triple per class.
            src_label_fold: folder holding the raw label images.
            save_pro_label_fold: folder receiving the cleaned "pro" labels.
            save_GT_label_fold: folder receiving the GT index maps.
            GT_channel: channels of the GT output, 1 or 3.
            pro_append_name: suffix appended to "pro" label file names.
            GT_append_name: suffix appended to GT file names.
            ori_img_folder: folder holding the original pictures.
            res_label_folder: folder holding predicted label images.
            save_merge_pic_folder: folder receiving the merged pictures.
            back_gnd_color: value written for pixels that match no class.
            first_class_color: value written for the first class; later
                classes count up from it.
            pic_type: image extension without the dot.
            Max_width: kept for backward compatibility; colour matching no
                longer depends on it.
            Label_Max_Search_layer: maximum folder depth when scanning.
            save_process_pics: also save the intermediate images of the
                raw -> "pro" conversion.
        """
        # Colours that raw labels are snapped to; the last one is the background.
        self.src_CLASSES = ('肝脏','胆囊','分离钳','止血海绵','肝总管','胆总管','吸引器','剪刀','止血纱布','生物夹','无损伤钳','喷洒','胆囊管','胆囊动脉','电凝','标本袋','引流管','纱布','金属钛夹','背景')
        self.src_PALETTE = np.array([[255,91,0],[255,234,0],[85, 107, 179],[181, 227, 14],[72, 0, 255],[0, 155, 33],[255,0,255],[29, 32, 136],[160, 15, 95],[0,160,233],[52,184,178],[90,120,41],[255,0,0],[177,0,0],[167,24,233],[112,113,150],[0,255,0],[255,255,255],[0,255,255],[0,0,0]])
        self.src_CLASSES_NUM = np.shape(self.src_CLASSES)[0]
        self.Annotate_CLASSES = Annotate_CLASSES
        self.Annotate_PALETTE = np.array(Annotate_PALETTE)
        self.Annotate_CLASSES_NUM = np.shape(Annotate_CLASSES)[0]
        self.save_process_pics = save_process_pics
        self.src_label_fold = src_label_fold
        self.save_pro_label_fold = save_pro_label_fold
        self.save_GT_label_fold = save_GT_label_fold
        self.ori_img_folder = ori_img_folder
        self.res_label_folder = res_label_folder
        self.save_merge_pic_folder = save_merge_pic_folder
        self.pro_append_name = pro_append_name
        self.GT_append_name = GT_append_name
        self.GT_channel = GT_channel
        self.Max_width = Max_width
        self.pic_type = pic_type
        self.back_gnd_color = back_gnd_color
        self.first_class_color = first_class_color
        self.Label_Max_Search_layer = Label_Max_Search_layer

        # Scanning is best effort: a folder that cannot be read leaves the
        # corresponding list as None instead of aborting construction.
        try:
            self.labellist_src = getFileList(src_label_fold, [], pic_type, self.Label_Max_Search_layer)
            print('本次执行检索到ori_label图片 '+str(len(self.labellist_src))+' 张图像')
        except Exception:
            self.labellist_src = None
            print("没有ori_label相关文件")
        try:
            self.labellist_pro = getFileList(save_pro_label_fold, [], pic_type, self.Label_Max_Search_layer)
            print('本次执行检索到pro_label图片 '+str(len(self.labellist_pro))+' 张图像')
        except Exception:
            self.labellist_pro = None
            print("没有pro_label相关文件")
        try:
            self.imglist_src = getFileList(ori_img_folder, [], pic_type, self.Label_Max_Search_layer)
            self.reslist_src = getFileList(res_label_folder, [], pic_type, self.Label_Max_Search_layer)
            print('本次执行检索到ori原始图片 '+str(len(self.imglist_src))+' 张图像')
            print('本次执行检索到训练train_result图片 '+str(len(self.reslist_src))+' 张图像')
        except Exception:
            self.imglist_src = None
            self.reslist_src = None
            print("没有train_result和原始图片相关文件")

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _ensure_dir(path):
        """Create ``path`` (including parents) and report if it already existed."""
        already_there = os.path.isdir(path)
        # exist_ok keeps this safe when several workers create it at once.
        os.makedirs(path, exist_ok=True)
        if already_there:
            print("已有"+path)

    @staticmethod
    def _pixels_matching(r, g, b, color):
        """Row and column indices of the pixels whose RGB value equals ``color``."""
        return np.nonzero((r == color[0]) & (g == color[1]) & (b == color[2]))

    def _run_in_pool(self, worker, jobs):
        """Run ``worker(*job)`` for every job in a process pool and wait for all."""
        if not jobs:
            return
        pool = multiprocessing.Pool(processes=min(self.MAX_PROCESSES, len(jobs)))
        try:
            pool.starmap(worker, jobs)
        finally:
            # Always release the workers, even when a job raised.
            pool.close()
            pool.join()

    def _save_stage(self, subfolder, imgname, tag, picture, label, time_start):
        """Write one intermediate picture of the raw -> "pro" conversion."""
        target_dir = os.path.join(self.save_pro_label_fold, subfolder)
        os.makedirs(target_dir, exist_ok=True)
        save_dir = os.path.join(target_dir, os.path.splitext(imgname)[0]+self.pro_append_name+tag+'.'+self.pic_type)
        cv2.imwrite(save_dir, picture)
        print(label, save_dir)
        print('time cost', time.time()-time_start, 's')

    # ------------------------------------------------------------------
    # Image loading
    # ------------------------------------------------------------------
    def get_single_pic_rgb(self, imgpath):
        """Load ``imgpath`` as RGB and return ``(image, r, g, b)``.

        ``image`` is the PIL image; ``r``, ``g`` and ``b`` are its channels
        as NumPy arrays.
        """
        print(imgpath)
        # The context manager closes the file once the pixels are converted.
        with Image.open(imgpath) as source:
            image = source.convert('RGB')
        r, g, b = (np.array(channel) for channel in image.split())
        return image, r, g, b

    # ------------------------------------------------------------------
    # "pro" label -> GT index map
    # ------------------------------------------------------------------
    def Conver_pro_label_pic_2_GT_pic(self, imgpath, imgname):
        """Convert one "pro" label image into a GT index map and save it.

        Every colour of ``Annotate_PALETTE`` is replaced by a class number,
        starting at ``first_class_color``; all other pixels receive
        ``back_gnd_color``.

        Raises:
            ValueError: if ``GT_channel`` is neither 1 nor 3.
        """
        time_start = time.time()
        channels = int(self.GT_channel)
        if channels not in (1, 3):
            # Fail before any work is done instead of writing a bogus file.
            print("GT_channel 必须为1或3")
            raise ValueError("GT_channel must be 1 or 3, got %r" % (self.GT_channel,))

        image, r, g, b = self.get_single_pic_rgb(imgpath)
        result_gt = np.ones(np.shape(image))*self.back_gnd_color

        for gt_number, color in enumerate(self.Annotate_PALETTE, start=self.first_class_color):
            rows, cols = self._pixels_matching(r, g, b, color)
            result_gt[rows, cols, :] = gt_number

        if channels == 1:
            result_gt = result_gt[:, :, 0]
        else:
            # OpenCV writes BGR, so swap the channel order.
            result_gt = cv2.cvtColor(np.float32(result_gt), cv2.COLOR_RGB2BGR)

        self._ensure_dir(self.save_GT_label_fold)
        save_dir = os.path.join(self.save_GT_label_fold, os.path.splitext(imgname)[0]+self.GT_append_name+'.'+self.pic_type)
        cv2.imwrite(save_dir, result_gt)
        print("GT图片已保存", save_dir)
        print('time cost', time.time()-time_start, 's')

    def Conver_pro_label_pic_2_GT_pic_all(self):
        """Convert every "pro" label found in ``save_pro_label_fold`` to GT."""
        print("\033[33m**** 进行转换将Pro_label_pic转换为GT_label_pic ****\033[0m")
        print("\033[33mPro_label_pic存储位置为\033[0m", self.save_pro_label_fold)
        print("\033[33mGT_label_pic生成位置为\033[0m", self.save_GT_label_fold)
        # Rescan: the "pro" folder may have been filled after construction.
        try:
            self.labellist_pro = getFileList(self.save_pro_label_fold, [], self.pic_type, self.Label_Max_Search_layer)
            print('本次执行检索到pro_label图片 '+str(len(self.labellist_pro))+' 张图像')
        except Exception:
            self.labellist_pro = None
            print("没有pro_label相关文件")
        self._ensure_dir(self.save_GT_label_fold)

        jobs = [
            (imgpath, os.path.splitext(os.path.basename(imgpath))[0].replace(self.pro_append_name, ""))
            for imgpath in (self.labellist_pro or [])
        ]
        self._run_in_pool(self.Conver_pro_label_pic_2_GT_pic, jobs)

    # ------------------------------------------------------------------
    # raw label -> "pro" label
    # ------------------------------------------------------------------
    def Conver_ori_label_pic_2_pro_pic(self, imgpath, imgname):
        """Clean one raw label image and save it as a "pro" label.

        Steps: snap every pixel to the nearest ``src_PALETTE`` colour
        (sum of absolute channel differences), detect and dilate edges on a
        grey version, then repaint each edge pixel with the most frequent
        colour of its ``EDGE_WINDOW`` x ``EDGE_WINDOW`` neighbourhood.
        """
        time_start = time.time()
        image, r, g, b = self.get_single_pic_rgb(imgpath)

        # Widen the channels before subtracting: in uint8 the differences
        # can wrap around and yield wrong distances.
        r = r.astype(np.int32)
        g = g.astype(np.int32)
        b = b.astype(np.int32)
        palette = self.src_PALETTE.astype(np.int32)

        # Distance of every pixel to every palette colour: height x width x classes.
        Dis_mat = np.empty((r.shape[0], r.shape[1], self.src_CLASSES_NUM), dtype=np.int32)
        for class_index, colour in enumerate(palette):
            Dis_mat[:, :, class_index] = np.abs(r - colour[0]) + np.abs(g - colour[1]) + np.abs(b - colour[2])
        nearest_class = np.argmin(Dis_mat, axis=2)

        # Colour every pixel with its nearest palette entry, then go to BGR
        # because OpenCV writes BGR.
        result_pro = cv2.cvtColor(np.float32(palette[nearest_class]), cv2.COLOR_RGB2BGR)
        if self.save_process_pics == True:
            self._save_stage('1_颜色预处理', imgname, '_preproc', result_pro,
                             "中间态-颜色图片已保存", time_start)

        pic_height = result_pro.shape[0]
        pic_width = result_pro.shape[1]

        # Grey version used for edge detection.
        gray = cv2.cvtColor(result_pro, cv2.COLOR_BGR2GRAY)
        # NOTE(review): result_pro holds values in 0..255, so gray*255 exceeds
        # the uint8 range before the cast. Kept as is because changing it
        # would alter which edges are detected - confirm the intended scaling.
        gray = (gray*255).astype(np.uint8)
        if self.save_process_pics == True:
            self._save_stage('2_灰度化', imgname, '_gray', gray,
                             "中间态-灰度图片已保存", time_start)

        # Canny edges, thickened by one dilation with a 3x3 kernel.
        edges = cv2.Canny(gray, 50, 150)
        kernel = np.ones((3, 3), np.uint8)
        edges = cv2.dilate(edges, kernel, iterations=1)
        if self.save_process_pics == True:
            self._save_stage('3_边缘化', imgname, '_edge', edges,
                             "中间态-边缘图片已保存", time_start)

        edge_pixels = np.argwhere(edges)
        print("边缘个数:", np.shape(edge_pixels)[0])
        half = (self.EDGE_WINDOW - 1) // 2
        for i, j in edge_pixels:
            # Shift the window centre inwards near the border so the window
            # stays inside the picture; the pixel being repainted is still
            # the edge pixel (i, j) itself.
            ci = min(max(i, half), max(pic_height - half - 1, half))
            cj = min(max(j, half), max(pic_width - half - 1, half))
            # "+ 1" because the slice end is exclusive: this yields the full
            # EDGE_WINDOW x EDGE_WINDOW neighbourhood.
            region = result_pro[ci - half:ci + half + 1, cj - half:cj + half + 1, :].reshape(-1, 3)
            colours, counts = np.unique(region, axis=0, return_counts=True)
            result_pro[i, j, :] = colours[np.argmax(counts)]

        os.makedirs(self.save_pro_label_fold, exist_ok=True)
        save_dir = os.path.join(self.save_pro_label_fold, os.path.splitext(imgname)[0]+self.pro_append_name+'.'+self.pic_type)
        print("Pro图片已保存", save_dir)
        cv2.imwrite(save_dir, result_pro)
        print('time cost', time.time()-time_start, 's')

    def Conver_ori_label_pic_2_pro_pic_all(self):
        """Convert every raw label found at construction time to a "pro" label."""
        print("\033[33m**** 进行转换将Ori_label_pic转换为Pro_label_pic ****\033[0m")
        print("\033[33mOri_label_pic存储位置为\033[0m", self.src_label_fold)
        print("\033[33mPro_label_pic生成位置为\033[0m", self.save_pro_label_fold)
        self._ensure_dir(self.save_pro_label_fold)
        if self.save_process_pics == True:
            for subfolder in ('1_颜色预处理', '2_灰度化', '3_边缘化'):
                self._ensure_dir(os.path.join(self.save_pro_label_fold, subfolder))

        jobs = []
        for imgpath in (self.labellist_src or []):
            imgname = os.path.splitext(os.path.basename(imgpath))[0].replace(self.pro_append_name, "")
            print("Processing: ", imgname, "...")
            jobs.append((imgpath, imgname))
        self._run_in_pool(self.Conver_ori_label_pic_2_pro_pic, jobs)

    # ------------------------------------------------------------------
    # label overlay
    # ------------------------------------------------------------------
    def Merge_ori_pic_and_label_pic(self, res_img_path, res_imgname):
        """Paint the labelled pixels of one result image over its original.

        Returns:
            -1 when the original picture does not exist, otherwise ``None``.
        """
        time_start = time.time()
        ori_img_path = os.path.join(self.ori_img_folder, res_imgname+'.'+self.pic_type)
        if not os.path.exists(ori_img_path):
            print("****照片不存在:****", ori_img_path)
            return -1
        ori_image, _, _, _ = self.get_single_pic_rgb(ori_img_path)
        _, res_r, res_g, res_b = self.get_single_pic_rgb(res_img_path)
        merge_img = np.array(ori_image)  # start from the original picture

        for color in self.Annotate_PALETTE:
            rows, cols = self._pixels_matching(res_r, res_g, res_b, color)
            merge_img[rows, cols, 0] = color[0]
            merge_img[rows, cols, 1] = color[1]
            merge_img[rows, cols, 2] = color[2]

        # OpenCV writes BGR, so swap the channel order.
        merge_img = cv2.cvtColor(np.float32(merge_img), cv2.COLOR_RGB2BGR)
        self._ensure_dir(self.save_merge_pic_folder)
        save_dir = os.path.join(self.save_merge_pic_folder, os.path.splitext(res_imgname)[0]+'.'+self.pic_type)
        cv2.imwrite(save_dir, merge_img)
        print("Merge图片已保存", save_dir)
        print('time cost', time.time()-time_start, 's')

    def Merge_ori_pic_and_label_pic_all(self):
        """Overlay every result image found at construction time."""
        for res_img_path in (self.reslist_src or []):
            res_imgname = os.path.splitext(os.path.basename(res_img_path))[0].replace(self.pro_append_name, "")
            print("Processing: ", res_imgname, "...")
            self.Merge_ori_pic_and_label_pic(res_img_path, res_imgname)
if __name__ == "__main__":
    # Names of the classes to annotate
    Annotate_CLASSES = ('肝脏','胆囊','分离钳','止血海绵','肝总管','胆总管','吸引器','剪刀','止血纱布','生物夹','无损伤钳','喷洒','胆囊管','胆囊动脉','电凝','标本袋','引流管','纱布','金属钛夹')
    # Pixel colour of each class (presumably in the same order as Annotate_CLASSES - TODO confirm)
    Annotate_PALETTE = [[255,91,0],[255,234,0],[85, 107, 179],[181, 227, 14],[72, 0, 255],[0, 155, 33],[255,0,255],[29, 32, 136],[160, 15, 95],[0,160,233],[52,184,178],[90,120,41],[255,0,0],[177,0,0],[167,24,233],[112,113,150],[0,255,0],[255,255,255],[0,255,255]]

    # Command-line interface
    parser = argparse.ArgumentParser(description='Process some files.')
    parser.add_argument('-src_fold', dest='src_label_fold', default='', help='source label folder')
    parser.add_argument('-save_pro_fold', dest='save_pro_label_fold', default='./save_pro_label_fold', help='processed label folder')
    parser.add_argument('-save_GT_fold', dest='save_GT_label_fold', default='./save_GT_label_fold', help='ground truth folder')
    parser.add_argument('-fold_search_depth', dest='Label_Max_Search_layer', default='1000', type=int, help='Folder Search Depth')
    parser.add_argument('-pro_suffix_name', dest='pro_append_name', default='_label', help='Pro file suffix')
    parser.add_argument('-GT_suffix_name', dest='GT_append_name', default='_gtFine_labelTrainIds', help='GT file suffix')
    parser.add_argument('-GT_channel', dest='GT_channel', default='1', type=int, help='GT file channel(1 or 3)')
    parser.add_argument('-back_gnd_color', dest='back_gnd_color', default='0', type=int, help='Color of "Back ground"(0 or 255)')
    parser.add_argument('-first_class_color', dest='first_class_color', default='1', type=int, help='Color of "First Class"')
    parser.add_argument('-pic_type', dest='pic_type', default='png', help='type of picture(Do not add ".")')
    parser.add_argument('-Max_width', dest='Max_width', default='10000', type=int, help='Max width of picture')
    parser.add_argument('-Rebuild_from', dest='Rebuild_from', default='label', help='Source to Rebuild Labels(label/pro)')
    parser.add_argument('-Rebuild_to', dest='Rebuild_to', default='GT', help='Destination of Rebuild Labels(pro/GT)')
    parser.add_argument('-save_process_pics', dest='save_process_pics', default='False', help='Save the processed pics(e.g.Gray_pics,Color_pics) in generating pro_pics')
    args = parser.parse_args()

    # Folder and naming options
    src_label_fold, save_pro_label_fold, save_GT_label_fold = (
        args.src_label_fold, args.save_pro_label_fold, args.save_GT_label_fold)
    pro_append_name, GT_append_name, pic_type = (
        args.pro_append_name, args.GT_append_name, args.pic_type)
    Rebuild_from, Rebuild_to = args.Rebuild_from, args.Rebuild_to

    # Numeric options: argparse applies type=int to both supplied values and
    # string defaults, so these are already integers here.
    Label_Max_Search_layer = args.Label_Max_Search_layer  # folder search depth
    GT_channel = args.GT_channel                          # channel count of GT pictures
    back_gnd_color = args.back_gnd_color                  # background colour (0 or 255)
    first_class_color = args.first_class_color            # colour assigned to the first class
    Max_width = args.Max_width                            # maximum picture width

    # Only the text "true" (any letter case) switches the option on
    save_process_pics = args.save_process_pics.lower() == 'true'

    D = Deal_image(
        Annotate_CLASSES=Annotate_CLASSES,
        Annotate_PALETTE=Annotate_PALETTE,
        src_label_fold=src_label_fold,
        save_pro_label_fold=save_pro_label_fold,
        save_GT_label_fold=save_GT_label_fold,
        GT_channel=GT_channel,
        pro_append_name=pro_append_name,
        GT_append_name=GT_append_name,
        back_gnd_color=back_gnd_color,
        first_class_color=first_class_color,
        pic_type=pic_type,
        Max_width=Max_width,
        Label_Max_Search_layer=Label_Max_Search_layer,
        save_process_pics=save_process_pics,
    )
    # print(D.src_CLASSES_NUM)
    # Step 1: turn the raw label pictures into "pro" pictures
    if Rebuild_from == 'label':
        D.Conver_ori_label_pic_2_pro_pic_all()
    # Step 2: turn the "pro" pictures into GT pictures
    if Rebuild_to == 'GT':
        D.Conver_pro_label_pic_2_GT_pic_all()

View File

@@ -0,0 +1,163 @@
#!/bin/bash
# Print the command-line help of 4_rebuild_labels.sh to stdout.
usage() {
  echo "Usage: $0 -l <ori_label_directory> [ -h ]"
  echo "对label图片进行处理及转化-l不能为空 "
  echo "-l:原始label的路径-h:帮助"
  echo "e.g. 4_rebuild_labels.sh -l ./C组标注图片"
  echo "接下来一路回车就ok"
}
# Command-line parsing: -l <dir> is the raw label directory, -h prints the help.
ori_label_directorys=""
while getopts "hl:" opt; do
  case "$opt" in
    h)
      usage
      exit 0
      ;;
    l)
      ori_label_directorys=$OPTARG
      ;;
    *)
      echo -e '\033[31m!!! Error, Illegal input !!!\033[0m'
      usage
      exit 1
      ;;
  esac
done
# -l is the only path option of this script and it is mandatory
if [[ -z "$ori_label_directorys" ]]; then
  echo -e "\033[31m输入地址 -l 存在空地址\033[0m"
  usage
  exit 1
fi
# Resolve the label directory to an absolute path
ori_label_directory=$(readlink -f -- "$ori_label_directorys")
if [[ -z "$ori_label_directory" ]]; then
  echo -e "\033[31m无法解析地址程序退出\033[0m"
  echo -e "\033[31mori_label_directory\033[0m: $ori_label_directorys"
  exit 1
fi
if [[ ! -d "$ori_label_directory" ]]; then
  echo -e "\033[31mlabel目录不存在程序退出\033[0m"
  echo -e "\033[31mori_label_directory\033[0m: $ori_label_directorys"
  exit 1
fi
echo -e "\033[32m_____ 4_rebuild_labels.sh _____\033[0m"
# Every prompt below takes its documented default when the answer is empty.

# Search depth for label pictures
echo -n "请选择label图片搜索深度(默认为1):"
read -r fold_search_depth
if [[ -z "$fold_search_depth" ]]; then
  fold_search_depth='1'
fi

# Output folder of the processed ("pro") labels; default: <label dir>_pro_label_fold
save_pro_folds="${ori_label_directory%/}_pro_label_fold" # strip a trailing "/"
save_pro_fold=$(readlink -f -- "$save_pro_folds")
echo -n "请选择初步处理label后label_pro图片存储位置(默认为$save_pro_fold):"
read -r save_pro_folds
if [[ -z "$save_pro_folds" ]]; then
  save_pro_folds="${ori_label_directory%/}_pro_label_fold"
fi
save_pro_fold=$(readlink -f -- "$save_pro_folds")

# File-name suffix of the "pro" labels
echo -n "请选择label_pro图片后缀(默认为\"_label\"):"
read -r pro_suffix_name
if [[ -z "$pro_suffix_name" ]]; then
  pro_suffix_name='_label'
fi

# Output folder of the GT labels; default: <label dir>_GT_label_fold.
# The answer is stored in save_GT_folds so a user-supplied folder is honoured.
save_GT_folds="${ori_label_directory%/}_GT_label_fold"
save_GT_fold=$(readlink -f -- "$save_GT_folds")
echo -n "请选择处理label_pro后label_GT图片存储位置(默认为$save_GT_fold):"
read -r save_GT_folds
if [[ -z "$save_GT_folds" ]]; then
  save_GT_folds="${ori_label_directory%/}_GT_label_fold"
fi
save_GT_fold=$(readlink -f -- "$save_GT_folds")

# File-name suffix of the GT labels
echo -n "请选择label_GT图片后缀(默认为\"_gtFine_labelTrainIds\"):"
read -r GT_suffix_name
if [[ -z "$GT_suffix_name" ]]; then
  GT_suffix_name='_gtFine_labelTrainIds'
fi

# Channel count of the GT pictures: only 1 or 3 is accepted
echo -n "请选择label_GT图片通道数(1或3默认为1):"
read -r GT_channel
if [[ -z "$GT_channel" ]]; then
  GT_channel='1'
fi
if [[ "$GT_channel" != '1' && "$GT_channel" != '3' ]]; then
  echo -e "\033[35mGT_channel只能为1或3输入有误将其默认变为1\033[0m"
  GT_channel='1'
fi

# Background colour of the GT pictures
echo -n "请选择GT图片背景颜色(0或255默认为0黑色)):"
read -r back_gnd_color
if [[ -z "$back_gnd_color" ]]; then
  back_gnd_color='0'
fi

# Colour of the first class; its default depends on the background colour
echo -n "请选择GT图片中第一类的颜色(黑色背景(0)下默认为1白色背景(255)下默认为0):"
read -r first_class_color
if [[ -z "$first_class_color" ]]; then
  if [[ "$back_gnd_color" == '255' ]]; then
    first_class_color='0'
  else
    first_class_color='1'
  fi
fi

# Picture type (extension without the dot)
echo -n "请选择图片类型(png或jpg默认为png没有\".\")):"
read -r pic_type
if [[ -z "$pic_type" ]]; then
  pic_type='png'
fi

# Stage to start rebuilding from: label or pro
echo -n "请选择重建起始目录(label或pro默认为label):"
read -r Rebuild_from
if [[ -z "$Rebuild_from" ]]; then
  Rebuild_from='label'
fi
if [[ "$Rebuild_from" != 'label' && "$Rebuild_from" != 'pro' ]]; then
  echo -e "\033[35mRebuild_from只能为label或pro输入有误将其默认变为label\033[0m"
  Rebuild_from='label'
fi

# Stage to stop rebuilding at: pro or GT
echo -n "请选择重建最终目标(pro或GT默认为GT):"
read -r Rebuild_to
if [[ -z "$Rebuild_to" ]]; then
  Rebuild_to='GT'
fi
if [[ "$Rebuild_to" != 'GT' && "$Rebuild_to" != 'pro' ]]; then
  echo -e "\033[35mRebuild_to只能为GT或pro输入有误将其默认变为GT\033[0m"
  Rebuild_to='GT'
fi

# Whether intermediate pictures (e.g. grey-scale) are kept
echo -n "请选择是否保存pro图片生成中间状态e.g.灰度图等)(false或true默认为false):"
read -r save_process_pics
if [[ -z "$save_process_pics" ]]; then
  save_process_pics='false'
fi
if [[ "$save_process_pics" != 'true' && "$save_process_pics" != 'false' ]]; then
  echo -e "\033[35msave_process_pics只能为true或false输入有误将其默认变为false\033[0m"
  save_process_pics='false'
fi
# Work from the script's own directory so 4_deal_labels.py resolves by relative path
script_path=$(dirname "$0")
cd "$script_path" || exit 1
# Activate the conda environment
source /home/"$USER"/miniconda/bin/activate Deal_pics
echo -e "\033[35m运行\033[0mpython 4_deal_labels.py -src_fold $ori_label_directory -save_pro_fold $save_pro_fold -save_GT_fold $save_GT_fold -fold_search_depth $fold_search_depth -pro_suffix_name $pro_suffix_name -GT_suffix_name $GT_suffix_name -GT_channel $GT_channel -back_gnd_color $back_gnd_color -first_class_color $first_class_color -pic_type $pic_type -Rebuild_from $Rebuild_from -Rebuild_to $Rebuild_to -save_process_pics $save_process_pics "
echo ""
# Every value is quoted so directories containing spaces reach Python as one argument
python 4_deal_labels.py \
  -src_fold "$ori_label_directory" \
  -save_pro_fold "$save_pro_fold" \
  -save_GT_fold "$save_GT_fold" \
  -fold_search_depth "$fold_search_depth" \
  -pro_suffix_name "$pro_suffix_name" \
  -GT_suffix_name "$GT_suffix_name" \
  -GT_channel "$GT_channel" \
  -back_gnd_color "$back_gnd_color" \
  -first_class_color "$first_class_color" \
  -pic_type "$pic_type" \
  -Rebuild_from "$Rebuild_from" \
  -Rebuild_to "$Rebuild_to" \
  -save_process_pics "$save_process_pics"
echo "4_rebuild_labels.sh重构完毕"

View File

@@ -0,0 +1,122 @@
#!/bin/bash
# Print the command-line help of 5_TOOL_stack_pics.sh to stdout.
usage() {
  printf '%s\n' \
    "Usage: $0 -i <ori_image_directory> -l <ori_label_directory> -r <stack_result_directory> [ -a <alpha> -p <prefix> -s <suffix> -h]" \
    '对image图片和label图片进行匹配-i、-l -r均不能为空-p -s默认为空 -a默认为"0.3") ' \
    '-i:原始image的路径-l:原始label的路径-p:前缀内容,-s:后缀内容(不用管文件后缀名)-h:帮助' \
    'e.g. 5_TOOL_stack_pics.sh -i ./C组未标注 -l ./C组标注图片 -r ./C组result_0.3透明度 -a 0.3 -p Group_C_ -s _label'
}
# Option defaults
ori_image_directorys=""
ori_label_directorys=""
stack_result_directorys=""
prefix=""
suffix=""
alpha="0.3"
# -i image dir, -l label dir, -r result dir, -p prefix, -s suffix, -a alpha, -h help
while getopts "hl:i:r:p:s:a:" opt; do
  case "$opt" in
    h)
      usage
      exit 0
      ;;
    i)
      ori_image_directorys=$OPTARG
      ;;
    l)
      ori_label_directorys=$OPTARG
      ;;
    p)
      prefix=$OPTARG
      ;;
    s)
      suffix=$OPTARG
      ;;
    r)
      stack_result_directorys=$OPTARG
      ;;
    a)
      alpha=$OPTARG
      ;;
    *)
      echo -e '\033[31m!!! Error, Illegal input !!!\033[0m'
      usage
      exit 1
      ;;
  esac
done
# -i, -l and -r are all mandatory
if [[ -z "$ori_label_directorys" || -z "$ori_image_directorys" || -z "$stack_result_directorys" ]]; then
  echo -e "\033[31m输入地址 -i -l -r 存在空地址\033[0m"
  usage
  exit 1
fi
# Resolve all three directories to absolute paths
ori_image_directory=$(readlink -f -- "$ori_image_directorys")
ori_label_directory=$(readlink -f -- "$ori_label_directorys")
stack_result_directory=$(readlink -f -- "$stack_result_directorys")
if [[ -z "$ori_label_directory" || -z "$ori_image_directory" || -z "$stack_result_directory" ]]; then
  echo "image、label、result存在无法解析地址程序退出"
  echo -e "\033[31mori_image_directory\033[0m: $ori_image_directorys"
  echo -e "\033[31mori_label_directory\033[0m: $ori_label_directorys"
  echo -e "\033[31mstack_result_directory\033[0m: $stack_result_directorys"
  exit 1
fi
# The image and label directories must already exist
if [[ ! -d "$ori_label_directory" || ! -d "$ori_image_directory" ]]; then
  echo "image、label两目录有一个不存在程序退出"
  echo -e "\033[31mori_image_directory\033[0m: $ori_image_directory"
  echo -e "\033[31mori_label_directory\033[0m: $ori_label_directorys"
  exit 1
fi
# Work from the script's own directory so 5_stack_picture.py resolves by relative path
script_path=$(dirname "$0")
cd "$script_path" || exit 1
# Activate the conda environment
source /home/"$USER"/miniconda/bin/activate Deal_pics
echo -e "\033[32m_____ 5_TOOL_stack_pics.sh _____\033[0m"
echo -e "\033[33mimage所在文件夹为$ori_image_directory\nlable所在文件夹为$ori_label_directory\033[0m"
# Walk the label directory and pair every label picture with an image
for file_path in "$ori_label_directory"/*; do
  if [[ ! -f "$file_path" ]]; then
    echo "$file_path不是文件"
    continue
  fi
  file_name=${file_path##*/}
  # Only picture files are handled; the extension must end the file name
  if [[ ! "$file_name" =~ \.(jpg|png|bmp|JPG|PNG|BMP)$ ]]; then
    continue
  fi
  # Extract the pairing key: the text between prefix and suffix when they are
  # given, otherwise the whole name; the picture extension is stripped last.
  # NOTE(review): prefix/suffix are inserted into the sed pattern, so "/" or
  # regex metacharacters in them change its meaning.
  if [[ -z "$prefix" ]]; then
    file_name_extract=$(printf '%s\n' "$file_name" | sed "s/\(.*\)${suffix}.*/\1/" | sed "s/\(.*\)\.\(jpg\|png\|bmp\|JPG\|PNG\|BMP\)$/\1/")
  else
    file_name_extract=$(printf '%s\n' "$file_name" | sed "s/.*${prefix}\(.*\)${suffix}.*/\1/" | sed "s/\(.*\)\.\(jpg\|png\|bmp\|JPG\|PNG\|BMP\)$/\1/")
  fi
  # First entry of the image directory whose name contains the key. The key is
  # quoted inside the glob so it is matched literally; an empty key is treated
  # as "no match" instead of pairing with an arbitrary picture.
  file_name_other=""
  if [[ -n "$file_name_extract" ]]; then
    candidates=("$ori_image_directory"/*"$file_name_extract"*)
    if [[ -e "${candidates[0]}" ]]; then
      file_name_other=${candidates[0]##*/}
    fi
  fi
  if [[ -z "$file_name_other" ]]; then
    # No partner found: keep a copy and log the name under Not_pair_pics
    echo "$file_name label中对应内容未在$ori_image_directory搜索到"
    mkdir -p "$ori_label_directory/Not_pair_pics"
    cp -- "$file_path" "$ori_label_directory/Not_pair_pics"
    echo "$file_name" >> "$ori_label_directory/Not_pair_pics/not_pair.txt"
  else
    # Partner found: blend the two pictures
    echo "image中的$file_name_other与lable中的$file_name"
    mkdir -p "$stack_result_directory"
    python 5_stack_picture.py "$ori_image_directory/$file_name_other" "$ori_label_directory/$file_name" "$stack_result_directory" "$alpha"
    echo ""
  fi
done

View File

@@ -0,0 +1,41 @@
#!/usr/bin/env python3
# -*- coding: UTF-8 -*-
import cv2, os, sys
def Stack_pic(Background_path, Overlay_path, Result_dir, alpha=0.3):
    """Blend an overlay picture onto a background picture and save the result as PNG.

    The overlay is resized to the background's size, then mixed with weight
    ``alpha`` (the background gets ``1 - alpha``). The output file is
    ``<Result_dir>/<background base name>.png``.

    Args:
        Background_path: Path of the bottom picture.
        Overlay_path: Path of the picture drawn on top.
        Result_dir: Directory receiving the blended picture; created if missing.
        alpha: Weight of the overlay picture.

    Raises:
        OSError: If either picture cannot be read or the result cannot be written.
    """
    img1 = cv2.imread(Background_path)  # bottom picture
    img2 = cv2.imread(Overlay_path)  # top picture
    # cv2.imread signals an unreadable file by returning None instead of raising
    for path, image in ((Background_path, img1), (Overlay_path, img2)):
        if image is None:
            raise OSError("cannot read picture: " + str(path))
    Result_name = os.path.splitext(os.path.basename(Background_path))[0]
    # Resize the overlay to the background's width and height
    img2 = cv2.resize(img2, (img1.shape[1], img1.shape[0]))
    # Weighted blend of the two pictures
    overlay = cv2.addWeighted(img1, 1 - alpha, img2, alpha, 0)
    # Save the result
    os.makedirs(Result_dir, exist_ok=True)
    result_path = os.path.join(Result_dir, Result_name+'.png')
    if not cv2.imwrite(result_path, overlay):
        raise OSError("cannot write picture: " + result_path)
    print("堆叠图片写入地址:", result_path)
if __name__ == '__main__':
    # Usage: 5_stack_picture.py <background> <overlay> <result_dir> [alpha]
    Background_path = sys.argv[1]  # path of the background picture
    Overlay_path = sys.argv[2]  # path of the overlay picture
    Result_dir = sys.argv[3]  # directory receiving the result
    # Alpha falls back to 0.3 when missing, not a number, or outside 0~1
    try:
        alpha = float(sys.argv[4])
    except (IndexError, ValueError):
        alpha = 0.3
    else:
        # Written as a positive range test so that NaN is rejected as well
        if not (0 <= alpha <= 1):
            print("alpha 透明度输入不正确其值应该在0~1之间")
            alpha = 0.3
    # Blend the pictures
    Stack_pic(Background_path, Overlay_path, Result_dir, alpha)

View File

@@ -0,0 +1,167 @@
#!/bin/bash
# Print the command-line help of 6_TOOL_stitch_pics.sh to stdout.
usage() {
  printf '%s\n' \
    "Usage: $0 -i <ori_image_directory> -l <ori_label_directory> -r <stitch_result_directory> [ -p <prefix> -s <suffix> -h]" \
    '对image图片和label图片进行拼接-i -r不能为空-l为填写项-p -s默认为空) ' \
    '-i:原始image的路径-l:拼接label的路径-p:前缀内容,-s:后缀内容(不用管文件后缀名)-h:帮助' \
    'e.g. 6_TOOL_stitch_pics.sh -i ./C组未标注 -l ./C组标注图片 -r ./C组result_stitch -p Group_C_ -s _label'
}
# Option defaults
ori_image_directorys=""
ori_label_directorys=""
stitch_result_directorys=""
prefix=""
suffix=""
# -i image dir, -l label dir, -r result dir, -p prefix, -s suffix, -h help
while getopts "hl:i:r:p:s:" opt; do
  case "$opt" in
    h)
      usage
      exit 0
      ;;
    i)
      ori_image_directorys=$OPTARG
      ;;
    l)
      ori_label_directorys=$OPTARG
      ;;
    p)
      prefix=$OPTARG
      ;;
    s)
      suffix=$OPTARG
      ;;
    r)
      stitch_result_directorys=$OPTARG
      ;;
    *)
      echo -e '\033[31m!!! Error, Illegal input !!!\033[0m'
      usage
      exit 1
      ;;
  esac
done
# -i and -l are mandatory; checked first so nobody is prompted for a result
# directory in a run that is going to be rejected anyway
if [[ -z "$ori_label_directorys" || -z "$ori_image_directorys" ]]; then
  echo -e "\033[31m输入地址 -i -l 存在空地址\033[0m"
  usage
  exit 1
fi
# Without -r, offer "<label dir>_拼接" as the default result directory
if [[ -z "$stitch_result_directorys" ]]; then
  stitch_result_directorys="${ori_label_directorys}_拼接"
  echo -n "请输入堆叠结果存储目录(默认为$stitch_result_directorys)"
  read -r temp
  if [[ -n "$temp" ]]; then
    stitch_result_directorys=$temp
  fi
fi
# Resolve all three directories to absolute paths
ori_image_directory=$(readlink -f -- "$ori_image_directorys")
ori_label_directory=$(readlink -f -- "$ori_label_directorys")
stitch_result_directory=$(readlink -f -- "$stitch_result_directorys")
if [[ -z "$ori_label_directory" || -z "$ori_image_directory" || -z "$stitch_result_directory" ]]; then
  echo "image、label、result存在无法解析地址程序退出"
  echo -e "\033[31mori_image_directory\033[0m: $ori_image_directorys"
  echo -e "\033[31mori_label_directory\033[0m: $ori_label_directorys"
  echo -e "\033[31mstitch_result_directory\033[0m: $stitch_result_directorys"
  exit 1
fi
# The image and label directories must already exist
if [[ ! -d "$ori_label_directory" || ! -d "$ori_image_directory" ]]; then
  echo "image、label两目录有一个不存在程序退出"
  echo -e "\033[31mori_image_directory\033[0m: $ori_image_directory"
  echo -e "\033[31mori_label_directory\033[0m: $ori_label_directorys"
  exit 1
fi
echo -e "\033[32m_____ 6_TOOL_stitch_pics.sh _____\033[0m"
# Relative position of image and label in the stitched picture
PS3='Please enter your choice:'
options=("Imge ↑ Label ↓" "Label ↑ Imge ↓" "Imge ← Label →" "Label ← Imge →" )
echo "请选择图片与Label的相对位置默认为1.\"Imge↑Label↓\""
# Pre-set the default so the variable is defined even when select ends on
# end-of-input without running its body
relative_pos="Img_up_label_down"
select opt in "${options[@]}"; do
  # Every value names the image position first and the label position second
  case "$opt" in
    "Imge ↑ Label ↓")
      relative_pos="Img_up_label_down"
      ;;
    "Label ↑ Imge ↓")
      relative_pos="Img_down_label_up"
      ;;
    "Imge ← Label →")
      relative_pos="Img_left_label_right"
      ;;
    "Label ← Imge →")
      relative_pos="Img_right_label_left"
      ;;
    *)
      echo "Set to default 1.\"Imge↑Label↓\""
      relative_pos="Img_up_label_down"
      ;;
  esac
  # One answer is enough, whichever arm handled it
  break
done
# Work from the script's own directory so 6_stitch_picture.py resolves by relative path
script_path=$(dirname "$0")
cd "$script_path" || exit 1
# Activate the conda environment
source /home/"$USER"/miniconda/bin/activate Deal_pics
echo -e "\033[33mimage所在文件夹为$ori_image_directory\nlable所在文件夹为$ori_label_directory\033[0m"
# Walk the label directory and pair every label picture with an image
for file_path in "$ori_label_directory"/*; do
  if [[ ! -f "$file_path" ]]; then
    echo "$file_path不是文件"
    continue
  fi
  file_name=${file_path##*/}
  # Only picture files are handled; the extension must end the file name
  if [[ ! "$file_name" =~ \.(jpg|png|bmp|JPG|PNG|BMP)$ ]]; then
    continue
  fi
  # Extract the pairing key: the text between prefix and suffix when they are
  # given, otherwise the whole name; the picture extension is stripped last.
  # NOTE(review): prefix/suffix are inserted into the sed pattern, so "/" or
  # regex metacharacters in them change its meaning.
  if [[ -z "$prefix" ]]; then
    file_name_extract=$(printf '%s\n' "$file_name" | sed "s/\(.*\)${suffix}.*/\1/" | sed "s/\(.*\)\.\(jpg\|png\|bmp\|JPG\|PNG\|BMP\)$/\1/")
  else
    file_name_extract=$(printf '%s\n' "$file_name" | sed "s/.*${prefix}\(.*\)${suffix}.*/\1/" | sed "s/\(.*\)\.\(jpg\|png\|bmp\|JPG\|PNG\|BMP\)$/\1/")
  fi
  # First entry of the image directory whose name contains the key. The key is
  # quoted inside the glob so it is matched literally; an empty key is treated
  # as "no match" instead of pairing with an arbitrary picture.
  file_name_other=""
  if [[ -n "$file_name_extract" ]]; then
    candidates=("$ori_image_directory"/*"$file_name_extract"*)
    if [[ -e "${candidates[0]}" ]]; then
      file_name_other=${candidates[0]##*/}
    fi
  fi
  if [[ -z "$file_name_other" ]]; then
    # No partner found: keep a copy and log the name under Not_pair_pics
    echo "$file_name label中对应内容未在$ori_image_directory搜索到"
    mkdir -p "$ori_label_directory/Not_pair_pics"
    cp -- "$file_path" "$ori_label_directory/Not_pair_pics"
    echo "$file_name" >> "$ori_label_directory/Not_pair_pics/not_pair.txt"
  else
    # Partner found: stitch the two pictures together
    echo "image中的$file_name_other与lable中的$file_name"
    if [[ ! -d "$stitch_result_directory" ]]; then
      echo "创建stitch_result_directory存储文件夹"
      echo -e "\033[35m运行:\033[0m mkdir -p $stitch_result_directory"
      mkdir -p "$stitch_result_directory"
    fi
    echo -e "\033[35m运行:\033[0mpython 6_stitch_picture.py $ori_image_directory/$file_name_other $ori_label_directory/$file_name $stitch_result_directory $relative_pos"
    python 6_stitch_picture.py "$ori_image_directory/$file_name_other" "$ori_label_directory/$file_name" "$stitch_result_directory" "$relative_pos"
    echo ""
  fi
done

View File

@@ -0,0 +1,75 @@
#!/usr/bin/env python3
# -*- coding: UTF-8 -*-
import cv2, os, sys, re
def Stitch_pic(Ori_image_path, Ori_label_path, Result_dir, img_pos, label_pos):
    """Place an image and its label picture next to each other and save the result as PNG.

    The label picture is resized to the image's size when the shapes differ.
    The output file is ``<Result_dir>/<image base name>.png``.

    Args:
        Ori_image_path: Path of the image picture.
        Ori_label_path: Path of the label picture.
        Result_dir: Directory receiving the stitched picture; created if missing.
        img_pos: Position of the image: 'up', 'down', 'left' or 'right'.
        label_pos: Position of the label; must be the opposite side of img_pos.

    Raises:
        OSError: If either picture cannot be read.
        SystemExit: With status 1 when the position pair is not supported.
    """
    img1 = cv2.imread(Ori_image_path)  # image picture
    img2 = cv2.imread(Ori_label_path)  # label picture
    # cv2.imread signals an unreadable file by returning None instead of raising
    for path, image in ((Ori_image_path, img1), (Ori_label_path, img2)):
        if image is None:
            raise OSError("cannot read picture: " + str(path))
    Result_name = os.path.splitext(os.path.basename(Ori_image_path))[0]
    # Bring the label picture to the image's size
    if img1.shape != img2.shape:
        img2 = cv2.resize(img2, (img1.shape[1], img1.shape[0]))
    # Stitch according to the requested layout
    if img_pos == 'up' and label_pos == 'down':
        result = cv2.vconcat([img1, img2])
    elif img_pos == 'down' and label_pos == 'up':
        result = cv2.vconcat([img2, img1])
    elif img_pos == 'left' and label_pos == 'right':
        result = cv2.hconcat([img1, img2])
    elif img_pos == 'right' and label_pos == 'left':
        result = cv2.hconcat([img2, img1])
    else:
        RED = '\033[91m'
        END = '\033[0m'
        print(RED + "The input of relative_pos is wrong, img_pos is " + img_pos + " label_pos is " + label_pos + END)
        # The os module has no exit(); sys.exit ends the program with a failure status
        sys.exit(1)
    # Save the result
    os.makedirs(Result_dir, exist_ok=True)
    result_path = os.path.join(Result_dir, Result_name+'.png')
    cv2.imwrite(result_path, result)
    print("堆叠图片写入地址:", result_path)
def parse_relative_pos(relative_pos):
    """Derive (img_pos, label_pos) from a layout keyword such as "Img_left_label_right".

    The keyword names the image position before the label position, so of the
    two direction words found, the earlier one belongs to the image.

    Args:
        relative_pos: Layout keyword; letter case is ignored.

    Returns:
        A tuple ``(img_pos, label_pos)``, or None when the keyword contains
        neither both of 'up'/'down' nor both of 'left'/'right'.
    """
    text = relative_pos.lower()
    for first, second in (("up", "down"), ("left", "right")):
        pos_first = text.find(first)
        pos_second = text.find(second)
        if pos_first != -1 and pos_second != -1:
            if pos_first < pos_second:
                return first, second
            return second, first
    return None


if __name__ == '__main__':
    # Usage: 6_stitch_picture.py <image> <label> <result_dir> <relative_pos>
    Ori_image_path = sys.argv[1]  # path of the image picture
    Ori_label_path = sys.argv[2]  # path of the label picture
    Result_dir = sys.argv[3]  # directory receiving the result
    # A missing layout argument falls through to the default below
    relative_pos = str.lower(sys.argv[4]) if len(sys.argv) > 4 else ""
    parsed_pos = parse_relative_pos(relative_pos)
    if parsed_pos is None:
        print("Either 'up'/'down' 'left'/'right' is missing or in the text.")
        print("Set to default img_pos = up label_pos = down")
        img_pos = "up"
        label_pos = "down"
    else:
        img_pos, label_pos = parsed_pos
    # Stitch the pictures
    Stitch_pic(Ori_image_path, Ori_label_path, Result_dir, img_pos, label_pos)

View File

@@ -0,0 +1,246 @@
#!/bin/bash
# Print the command-line help of Seg_data_run.sh to stdout.
usage() {
  printf '%s\n' \
    "Usage: $0 -i <ori_image_directory> -l <ori_label_directory> [-h]" \
    '对image图片和label图片进行统一处理' \
    '-i:原始图片的路径,-l:原始标签的路径,-h:帮助'
}
# Option defaults
ori_image_directorys=""
ori_label_directorys=""
stack_result_directorys=""
stitch_result_directorys=""
# -i image dir, -l label dir, -h help
while getopts "hl:i:" opt; do
  case "$opt" in
    h)
      usage
      exit 0
      ;;
    i)
      ori_image_directorys=$OPTARG
      ;;
    l)
      ori_label_directorys=$OPTARG
      ;;
    *)
      echo '!!! Error, Illegal input !!!'
      usage
      exit 1
      ;;
  esac
done
# At least one of label / image must be given.
# The values are quoted so the echo prints them verbatim (no globbing or word splitting).
echo "$ori_label_directorys" "$ori_image_directorys"
if [[ -z "$ori_label_directorys" && -z "$ori_image_directorys" ]]; then
  echo -e "\033[31mori_label_directory、ori_image_directory不能都为空\033[0m"
  usage
  exit 1
fi
# Resolve to absolute paths
ori_image_directory=$(readlink -f -- "$ori_image_directorys")
ori_label_directory=$(readlink -f -- "$ori_label_directorys")
# Stop only when neither directory exists; a single valid one is enough
if [[ ! -d "$ori_label_directory" && ! -d "$ori_image_directory" ]]; then
  echo "image、label都不存在程序退出"
  echo -e "\033[31mori_image_directory\033[0m: $ori_image_directorys"
  echo "$ori_image_directory"
  echo -e "\033[31mori_label_directory\033[0m: $ori_label_directorys"
  echo "$ori_label_directory"
  exit 1
fi
# Activate the conda environment (disabled)
# source /home/"$USER"/miniconda/bin/activate Pat_infos
# Directory holding this script and the numbered helper scripts
script_path=$(dirname "$0")
echo -e "\033[32m******* 开始Run运行Seg_data_run.sh批量图片处理程序 *******\033[0m"
# Ask for the image directory until an existing one is given (exits on end-of-input).
require_image_directory() {
  while [[ ! -d "$ori_image_directory" ]]; do
    echo -e "\033[31mImage地址为:$ori_image_directorys,其不存在\033[0m"
    echo -n "请输入image所在地址:"
    read -r ori_image_directorys || exit 1
    ori_image_directory=$(readlink -f -- "$ori_image_directorys")
  done
}

# Ask for the label directory until an existing one is given (exits on end-of-input).
require_label_directory() {
  while [[ ! -d "$ori_label_directory" ]]; do
    echo -e "\033[31mLabel地址为:$ori_label_directorys,其不存在\033[0m"
    echo -n "请输入label所在地址:"
    read -r ori_label_directorys || exit 1
    ori_label_directory=$(readlink -f -- "$ori_label_directorys")
  done
}

# Main menu: show the options, run the chosen helper script, then show the menu again
while true; do
  PS3='Please enter your choice: '
  options=("移动图片与重命名" "图片统一大小与类型" "image、label配对检测" "对label进行重建" "TOOL_image、label堆叠" "TOOL_image、label拼接" "Quit")
  echo -e "\033[35m____ Seg_data_run选择 ____\033[0m"
  echo -e "\033[35mImage所在地址\033[0m$ori_image_directory"
  echo -e "\033[35mLabel所在地址\033[0m$ori_label_directory"
  # select leaves without running its body when stdin reaches end-of-input;
  # this flag lets the outer loop stop instead of redrawing the menu forever
  choice_made=false
  select opt in "${options[@]}"; do
    choice_made=true
    case "$opt" in
      ### Option 1: move and rename pictures ###
      "移动图片与重命名")
        echo -e "\033[31mRun运行:\033[0mbash $script_path/1_rename_pics.sh -i $ori_image_directory -l $ori_label_directory"
        bash "$script_path/1_rename_pics.sh" -i "$ori_image_directory" -l "$ori_label_directory"
        echo ""
        break
        ;;
      ### Option 2: unify picture size and type ###
      "图片统一大小与类型")
        echo -n "请输入宽度默认1920:"
        read -r width
        echo -n "请输入高度默认1080:"
        read -r height
        width=${width:-1920}
        height=${height:-1080}
        # Pass only the directories that are actually set
        if [[ -z "$ori_image_directory" ]]; then
          echo -e "\033[31mRun运行:\033[0mbash $script_path/2_reformate_pics.sh -l $ori_label_directory -w $width -h $height "
          bash "$script_path/2_reformate_pics.sh" -l "$ori_label_directory" -w "$width" -h "$height"
        elif [[ -z "$ori_label_directory" ]]; then
          echo -e "\033[31mRun运行:\033[0mbash $script_path/2_reformate_pics.sh -i $ori_image_directory -w $width -h $height "
          bash "$script_path/2_reformate_pics.sh" -i "$ori_image_directory" -w "$width" -h "$height"
        else
          echo -e "\033[31mRun运行:\033[0mbash $script_path/2_reformate_pics.sh -i $ori_image_directory -l $ori_label_directory -w $width -h $height"
          bash "$script_path/2_reformate_pics.sh" -i "$ori_image_directory" -l "$ori_label_directory" -w "$width" -h "$height"
        fi
        echo ""
        break
        ;;
      ### Option 3: check image / label pairing ###
      "image、label配对检测")
        require_image_directory
        require_label_directory
        echo -n "请输入图片前缀文本(默认为\"\":"
        read -r prefix # -r: keep backslashes literal
        echo -n "请输入图片后缀文本(非.png类后缀名默认为\"\":"
        read -r suffix
        echo -e "\033[31mRun运行:\033[0mbash $script_path/3_pair_ori_label.sh -i $ori_image_directory -l $ori_label_directory -p $prefix -s $suffix"
        bash "$script_path/3_pair_ori_label.sh" -i "$ori_image_directory" -l "$ori_label_directory" -p "$prefix" -s "$suffix"
        echo ""
        break
        ;;
      ### Option 4: rebuild labels ###
      "对label进行重建")
        # A label directory is required; existence is checked by the helper script
        while [[ -z "$ori_label_directory" ]]; do
          echo -e "\033[31mLabel地址为:$ori_label_directorys,其存在异常\033[0m"
          echo -n "请输入label所在地址:"
          read -r ori_label_directorys || exit 1
          ori_label_directory=$(readlink -f -- "$ori_label_directorys")
        done
        echo -e "\033[31mRun运行:\033[0mbash $script_path/4_rebuild_labels.sh -l $ori_label_directory"
        bash "$script_path/4_rebuild_labels.sh" -l "$ori_label_directory"
        echo ""
        break
        ;;
      ### Option 5: blend image and label ###
      "TOOL_image、label堆叠")
        require_image_directory
        require_label_directory
        echo -n "请输入堆叠图片透明程度0~1,0为最透明默认为0.3"
        read -r alpha
        echo -n "请输入图片前缀文本(默认为\"\":"
        read -r prefix
        echo -n "请输入图片后缀文本(非.png类后缀名默认为\"\":"
        read -r suffix
        if [[ -z "$alpha" ]]; then
          alpha="0.3"
        fi
        stack_result_directory="${ori_label_directory}_堆叠_${alpha}_透明度"
        echo -n "请输入堆叠结果存储目录(默认为$stack_result_directory)"
        read -r stack_result_directorys
        # A non-empty answer replaces the default result directory
        if [[ -n "$stack_result_directorys" ]]; then
          stack_result_directory=$(readlink -f -- "$stack_result_directorys")
        fi
        echo -e "\033[31mRun运行:\033[0mbash $script_path/5_TOOL_stack_pics.sh -i $ori_image_directory -l $ori_label_directory -p $prefix -s $suffix -a $alpha -r $stack_result_directory"
        bash "$script_path/5_TOOL_stack_pics.sh" -i "$ori_image_directory" -l "$ori_label_directory" -p "$prefix" -s "$suffix" -a "$alpha" -r "$stack_result_directory"
        echo ""
        break
        ;;
      ### Option 6: stitch image and label ###
      "TOOL_image、label拼接")
        require_image_directory
        echo "1.label目前为\"$ori_label_directory\""
        echo -n "是否调整label所在文件夹(默认为不调整,有输入视为调整):"
        read -r temp
        # Any answer means "change it": ask once for the new label directory;
        # the check below keeps asking while it does not exist
        if [[ -n "$temp" ]]; then
          echo -n "请输入label所在地址:"
          read -r ori_label_directorys || exit 1
          ori_label_directory=$(readlink -f -- "$ori_label_directorys")
        fi
        require_label_directory
        stitch_result_directory="${ori_label_directory}_拼接"
        echo -n "2.请输入堆叠结果存储目录(默认为$stitch_result_directory)"
        read -r stitch_result_directorys
        # A non-empty answer replaces the default result directory
        if [[ -n "$stitch_result_directorys" ]]; then
          stitch_result_directory=$(readlink -f -- "$stitch_result_directorys")
        fi
        echo -n "3.请输入图片前缀文本(默认为\"\":"
        read -r prefix
        echo -n "4.请输入图片后缀文本(非.png类后缀名默认为\"\":"
        read -r suffix
        echo -e "\033[31mRun运行:\033[0mbash $script_path/6_TOOL_stitch_pics.sh -i $ori_image_directory -l $ori_label_directory -p $prefix -s $suffix -r $stitch_result_directory"
        bash "$script_path/6_TOOL_stitch_pics.sh" -i "$ori_image_directory" -l "$ori_label_directory" -p "$prefix" -s "$suffix" -r "$stitch_result_directory"
        # NOTE(review): unlike options 1-5 this ends the program instead of
        # returning to the menu - confirm this is intended
        exit 0
        ;;
      ### Option 7: quit ###
      "Quit")
        echo -e "\033[35mSeg_data_run.sh Exiting...\033[0m"
        exit 0
        ;;
      *)
        echo "Invalid option: $REPLY"
        echo -e ""
        break
        ;;
    esac
  done
  # End-of-input on the menu: nothing more can be read, so stop
  if [[ "$choice_made" == false ]]; then
    echo ""
    exit 0
  fi
done

View File

@@ -0,0 +1,14 @@
# 创建环境
conda create -n Deal_pics python=3.8
sudo apt install imagemagick
# 安装包
conda activate Deal_pics
pip install --upgrade pip
pip install opencv-python
# 解决BUG
# 命令：python cv2.Canny(image, 50, 150)
# 错误：cv2.error: OpenCV(4.5.5) /io/opencv/modules/imgproc/src/canny.cpp:829: error: (-215:Assertion failed) _src.depth() == CV_8U in function 'Canny'
# 错误原因：Canny函数只对0~255（uint8）的图像起作用，对0~1的浮点图像不起作用
# 解决方案：image = (image*255).astype(np.uint8)

View File

@@ -0,0 +1,28 @@
# 相关程序所在位置：./1_Preprocess_pics（已加入路径）
1.对于图片批量重命名（允许只输入-i或-l一个参数）
1_rename_pics.sh -i <ori_image_directory> -l <ori_label_directory> [-h]
2.对于图片大小批量修改（允许只输入-i或-l一个参数）
2_reformate_pics.sh -i <ori_image_directory> -l <ori_label_directory> [ -w <width_of_pic> -h <height_of_pic> -help]
默认:-w=1920 -h=1080
3.对于原始图片、标签图片进行配对（必须输入-i和-l两个参数，文件可带前缀或后缀）
3_pair_ori_label.sh -i <ori_image_directory> -l <ori_label_directory> [ -p <prefix> -s <suffix> -h]
默认:-p="" -s=""
4.对于标签图片进行进一步处理
※ 如果Label有变请修改 4_deal_labels.py main 中的 Annotate_CLASSES、Annotate_PALETTE
4_rebuild_labels.sh -l <ori_label_directory> [ -h ]
5.TOOL - 对于image图像、label图片进行堆叠（透明度可调，文件可带前缀或后缀）
5_TOOL_stack_pics.sh -i <ori_image_directory> -l <ori_label_directory> -r <stack_result_directory> [ -a <alpha> -p <prefix> -s <suffix> -h]
默认:-a=0.3 -p="" -s=""
6.TOOL - 对于image图像、label图像进行拼接（上下或左右放置，文件可带前缀或后缀）
6_TOOL_stitch_pics.sh -i <ori_image_directory> -l <ori_label_directory> -r <stitch_result_directory> [ -p <prefix> -s <suffix> -h]
============ 用法 ============
1. 将原始Label与现有label进行匹配查看差异
bash 6_TOOL_stitch_pics.sh -i ./A_Label -l ./A_Label_pro_label_fold -r ./A_Label_Compare -s _label

264
README.md Normal file
View File

@@ -0,0 +1,264 @@
# Seg 图像分割项目使用说明
本项目是一个多路线图像分割实验与推理工程,包含自研 `segmentation_models_pytorch` 训练流程、YOLO 分割流程、MMSegmentation 流程,以及数据预处理、视频抽帧、图片叠加和结果分析工具。
## 1. 项目结构
| 路径 | 用途 |
| --- | --- |
| `Seg_All_In_One_SegModel/` | 基于 `segmentation_models_pytorch` 的语义分割训练、推理、参数量/FLOPs/FPS 统计和输出完整性检查。 |
| `Seg_All_In_One_YoloModel/` | 基于 Ultralytics YOLO 的实例/语义分割训练、推理、热力图可视化和横向对比。 |
| `Seg_All_In_One_MMSeg/` | 基于 OpenMMLab MMSegmentation 的模型配置、训练和结果汇总流程。 |
| `Seg_All_In_One_Analysis/` | 汇总不同模型的指标、FLOPs、FPS生成表格和 mIoU/FPS 图。 |
| `DataSet_Own/1. 图片预处理(内含使用手册)/` | 自有数据的重命名、尺寸统一、图像/标签配对、标签重建、叠加和拼接检查。 |
| `Seg_Predict_Own_Video_V2/` | 将视频按固定间隔抽帧,转换为 `DataSet_Public/<dataset>/images/val` 格式。 |
| `Tool-图片堆叠/` | 快速检查原图和标签是否匹配,并生成透明叠加图。 |
| `Tool-可视化/` | YOLO 标签生成、热图、FPS 等可视化辅助脚本。 |
| `Back_Up.sh` | 将算法目录同步到 `Hardisk/` 和 `Nas_BackUp_Seg/` 的本地/NAS 备份脚本。 |
大体量目录如 `DataSet_Public/`、`DataSet_Own/` 中的数据、`BestMode_Predict_Results_DataSet_Public/`、`Hardisk/`、`Nas_BackUp_Seg/`、模型权重和训练产物不进入 Git。
## 2. Conda 环境
推荐使用独立环境 `seg_smp`。本机已有可运行的 `SMP` 环境时,最快方式是克隆它:
```bash
conda create --name seg_smp --clone SMP -y
conda activate seg_smp
python -V
python -c "import torch; print(torch.cuda.is_available(), torch.__version__)"
```
从零安装时可参考:
```bash
conda create -n seg_smp python=3.9 -y
conda activate seg_smp
pip install torch==2.8.0+cu129 torchvision==0.23.0+cu129 --index-url https://download.pytorch.org/whl/cu129
pip install -r requirements-seg_smp.txt
cd Seg_All_In_One_MMSeg
pip install -v -e .
cd ..
```
如果使用批量脚本,默认会激活 `seg_smp`。如需临时使用旧环境:
```bash
SEG_CONDA_ENV=SMP bash yolo_train.sh
```
环境验证:
```bash
conda run -n seg_smp python -c "import torch, segmentation_models_pytorch, ultralytics, mmcv, mmengine, mmseg, cv2, albumentations; print('ok')"
```
## 3. 数据约定
SegModel 默认读取:
```text
DataSet_Public/<dataset>/
images/train
images/val
labels_GT/train
labels_GT/val
```
YOLO 默认读取:
```text
DataSet_Public/<dataset>/
images/train
images/val
labels/train
labels/val
```
切换数据集时:
- SegModel：修改 `Seg_All_In_One_SegModel/config.py` 中的 `DATA_DIR`、`OUTPUTS_DIR`、`PREDICT_BEST_MODEL_DIR`、类别和图像尺寸。
- YOLO：修改 `Seg_All_In_One_YoloModel/dataset.yaml` 中的 `path`、`train`、`val`、`test`、`names`；必要时修改 `yolo_config.py` 的训练参数。
- MMSeg：参考 `Seg_All_In_One_MMSeg/※使用手册/※2025_9_23_MMSeg使用手册` 生成数据集和算法配置。
## 4. SegModel 使用方式
```bash
cd Seg_All_In_One_SegModel
conda activate seg_smp
# 单模型训练
CUDA_VISIBLE_DEVICES=0 python train.py -a Unet
# 批量训练
bash train.sh
# 单模型推理
CUDA_VISIBLE_DEVICES=0 python 1_predict.py -a Unet
# 批量推理
bash predict.sh
# 参数量、FLOPs、FPS
CUDA_VISIBLE_DEVICES=0 python 2_predict_params_and_FLOPs_V2.py
# 检查预测 raw mask 是否齐全
python 1_predict_raw_masks_check.py
```
可选模型包括:
```text
Unet, UnetPlusPlus, FPN, PSPNet, DeepLabV3, DeepLabV3Plus,
Linknet, MAnet, PAN, UPerNet, Segformer, DPT
```
训练结果先写入 `DataSet_Public_outputs/<dataset>_outputs-SegModel/`,脚本结束后会移动到 `Hardisk/`;推理结果写入 `BestMode_Predict_Results_DataSet_Public/<dataset>_outputs-SegModel/`
## 5. YOLO 使用方式
```bash
cd Seg_All_In_One_YoloModel
conda activate seg_smp
# 检查当前 dataset.yaml 解析出的路径
python yolo_config.py
# 单模型训练
CUDA_VISIBLE_DEVICES=0 python yolo_train.py --model "YOLOv8n-seg"
# 批量训练
bash yolo_train.sh
# 复制最佳权重到预测目录
bash ./Tool_Yolo_Copy_Best_Model.sh --pt_name "best.pt"
# 单模型推理
CUDA_VISIBLE_DEVICES=0 python yolo_predict_V2.py --model "YOLOv8n-seg" --conf 0.2 --pt_name "best.pt"
# 批量推理
bash yolo_predict.sh --conf 0.2 --pt_name "best.pt"
# 批量热图
bash yolo_predict.sh --heatmap_method "All" --pt_name "best.pt"
# 横向对比
python yolo_predict_V2_compare_all.py --pt_name "all"
# 检查预测 raw mask 是否齐全
python yolo_predict_raw_masks_check.py --pt_name "best.pt"
```
常用模型包括:
```text
YOLOv8n-seg, YOLOv8s-seg, YOLOv8m-seg, YOLOv8l-seg, YOLOv8x-seg,
YOLOv9c-seg, YOLOv9e-seg,
YOLO11n-seg, YOLO11s-seg, YOLO11m-seg, YOLO11l-seg, YOLO11x-seg,
YOLO12-seg
```
## 6. MMSeg 使用方式
```bash
cd Seg_All_In_One_MMSeg
conda activate seg_smp
# 首次或新增模块后注册工程
pip install -v -e .
# 下载/保存必要预训练权重
python My_All_In_One/0_Initial_Save_All_Model_locally.py
# 生成数据配置
python My_All_In_One/1_Initial_Data_All_data_from_1_Data_Parameter-V2.py
# 生成算法配置
python My_All_In_One/2_Initial_Alg_All_data_from_2_Alg_Program-V2.py
# 参数量、FLOPs、FPS
CUDA_VISIBLE_DEVICES=0 python My_All_In_One/4_1_predict_params_FLOPs_FPS_V2.py
# 指标汇总
CUDA_VISIBLE_DEVICES=0 python My_All_In_One/4_2_predict_matrics_from_log_V2.py
# 生成预测图和表格
CUDA_VISIBLE_DEVICES=0 python My_All_In_One/4_3_predict_draw_pictures_and_tabels.py
# 提取 loss 和 best mIoU
CUDA_VISIBLE_DEVICES=0 python My_All_In_One/4_4_extract_loss_and_best_miou.py
```
MMSeg 对 `mmcv/mmengine/mmsegmentation` 版本较敏感;若遇到 `mmcv` CUDA 算子或版本错误,优先参考 MMSeg 使用手册中的安装记录。
## 7. 数据预处理与视频抽帧
自有图片预处理:
```bash
cd "DataSet_Own/1. 图片预处理(内含使用手册)"
bash 1_rename_pics.sh -i <ori_image_directory> -l <ori_label_directory>
bash 2_reformate_pics.sh -i <ori_image_directory> -l <ori_label_directory> -w 1920 -h 1080
bash 3_pair_ori_label.sh -i <ori_image_directory> -l <ori_label_directory>
bash 4_rebuild_labels.sh -l <ori_label_directory>
bash 5_TOOL_stack_pics.sh -i <ori_image_directory> -l <ori_label_directory> -r <stack_result_directory> -a 0.3
bash 6_TOOL_stitch_pics.sh -i <ori_image_directory> -l <ori_label_directory> -r <stitch_result_directory>
```
视频抽帧:
```bash
cd Seg_Predict_Own_Video_V2
python 1_Save_Frame_V2.py \
--video ./LC_Video_1.mp4 \
--resize "1920x1080" \
--output_dir "../DataSet_Public/5_Predict_Video" \
--interval 0.5
```
输出路径为:
```text
DataSet_Public/5_Predict_Video/<video_name>/images/val
```
图片叠加检查:
```bash
cd Tool-图片堆叠
python 1_check_picture_pair.py -i ./ori -l ./label
bash 2_TOOL_stack_pics.sh -i ./ori -l ./label -r ./result_0.3透明度 -a 0.3 -s _label
```
## 8. 结果分析
```bash
cd Seg_All_In_One_Analysis
conda activate seg_smp
python 1_Analysis_All.py \
--input_dir ../BestMode_Predict_Results_DataSet_Public \
--output_dir ./
```
脚本会交互选择数据集,合并 SegModel/MMSeg 的指标和速度数据,并生成 CSV、PNG、SVG 等分析结果。
## 9. 备份与 Git
本仓库只提交程序和轻量配置。以下内容不进入 Git
- 数据集:`DataSet_Public/`、除预处理脚本外的 `DataSet_Own/`
- 训练和预测产物：`DataSet_Public_outputs/`、`BestMode_Predict_Results_DataSet_Public/`、`Hardisk/`、`Nas_BackUp_Seg/`
- 大权重和视频：`*.pt`、`*.pth`、`*.onnx`、`*.mp4`
- Python/构建缓存：`__pycache__/`、`.pytest_cache/`、`build/`、`dist/`
常用 Git 检查:
```bash
git status --short
git ls-files | grep -E '\.(pt|pth|mp4|zip)$|DataSet_Public|Hardisk|BestMode' || true
```

View File

@@ -0,0 +1,298 @@
import os
import glob
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import re
import argparse
from collections import defaultdict
def get_model_family(model_name):
    """
    Return the model family that ``model_name`` belongs to.

    A trailing ``_r<digits>`` suffix is removed; a name without such a suffix
    is returned unchanged.

    Examples:
        'my_bisenetv1_r50' -> 'my_bisenetv1'
        'my_fast_scnn'     -> 'my_fast_scnn'
    """
    # The lazy prefix group captures everything in front of the final "_r<digits>".
    suffix_match = re.compile(r'^(.*?)_r\d+$').match(model_name)
    return model_name if suffix_match is None else suffix_match.group(1)
def select_dataset(results_dir):
    """
    Interactively choose one dataset whose result folders will be merged.

    Scans ``results_dir`` for entries ending in ``_outputs-MMSeg`` or
    ``_outputs-SegModel``, keeps the ones that are directories, groups them by
    the name left after stripping that suffix, prints a numbered menu and
    reads the choice from stdin.

    Args:
        results_dir (str): directory that holds the per-dataset result folders.

    Returns:
        tuple: ``(selected_dirs, selected_name)`` where ``selected_dirs`` is the
        list of source directories of the chosen dataset. ``(None, None)`` is
        returned when no dataset directory is found, when scanning fails, or
        when the user cancels with Ctrl-C / EOF.
    """
    print("正在扫描可用的数据集...")
    datasets_map = defaultdict(list)
    try:
        # Collect every path carrying one of the two known suffixes.
        all_dirs = glob.glob(os.path.join(results_dir, '*_outputs-MMSeg')) + \
                   glob.glob(os.path.join(results_dir, '*_outputs-SegModel'))

        # Group the directories by their base dataset name,
        # e.g. '1_CholecSeg8k-13Type-1920x1080'.
        for dir_path in all_dirs:
            if os.path.isdir(dir_path):
                base_name = re.sub(r'_outputs-(MMSeg|SegModel)$', '', os.path.basename(dir_path))
                datasets_map[base_name].append(dir_path)
    except Exception as e:
        print(f"扫描目录 '{results_dir}' 时出错: {e}")
        return None, None

    # Checked after the directory filter: matches that are plain files must not
    # lead to an empty menu whose prompt could never accept any answer.
    if not datasets_map:
        print(f"'{results_dir}' 中未找到任何数据集目录 (以 '_outputs-MMSeg''_outputs-SegModel' 结尾)。")
        return None, None

    sorted_dataset_names = sorted(datasets_map.keys())

    print("\n请选择要合并分析的数据集:")
    for i, name in enumerate(sorted_dataset_names):
        # Show how many source folders contribute to each dataset.
        source_count = len(datasets_map[name])
        print(f" [{i+1}] {name} ({source_count}个源)")

    while True:
        try:
            choice = input(f"\n请输入选项编号 (1-{len(sorted_dataset_names)}): ")
            choice_idx = int(choice) - 1
            if 0 <= choice_idx < len(sorted_dataset_names):
                selected_name = sorted_dataset_names[choice_idx]
                # Every directory associated with the chosen dataset.
                selected_dirs = datasets_map[selected_name]
                return selected_dirs, selected_name
            else:
                print("无效的选项，请输入列表中的编号。")
        except (ValueError, IndexError):
            print("无效的输入，请输入一个数字编号。")
        except (KeyboardInterrupt, EOFError):
            print("\n操作已取消。")
            return None, None
def F1_plot_performance_speed(selected_dirs, dataset_name, output_base_dir):
    """
    Load and merge the CSV summaries of one dataset from every selected source
    directory, write the summary tables, and save the mIoU-vs-FPS figure.

    Args:
        selected_dirs (list): all source directories chosen by the user.
        dataset_name (str): dataset name extracted from the directory names; it
            is part of the input CSV file names and of every output file name.
        output_base_dir (str): root directory under which a ``<dataset_name>``
            sub-folder is created for all outputs.

    Returns:
        None. A message is printed and the function returns early when no data
        could be loaded, when a required column is missing, or when the merge
        of metrics and speed data is empty.
    """
    print(f"\n正在为数据集 '{dataset_name}' 合并数据并生成图表...")

    # Every output of this dataset goes into its own sub-folder.
    dataset_output_dir = os.path.join(output_base_dir, dataset_name)
    os.makedirs(dataset_output_dir, exist_ok=True)
    print(f"所有输出文件将被保存到: {dataset_output_dir}")

    # --- Load the two CSV files from each source directory ---
    all_metrics = []
    all_fps = []

    print("正在读取以下来源的数据:")
    for selected_dir in selected_dirs:
        print(f" - {os.path.basename(selected_dir)}")
        metrics_file = os.path.join(selected_dir, f"{dataset_name}_metrics_summary_wide.csv")
        fps_file = os.path.join(selected_dir, f"{dataset_name}_flops_params_fps_summary.csv")

        # A source without both files is skipped, not treated as fatal.
        if not os.path.exists(metrics_file) or not os.path.exists(fps_file):
            print(f" -> 警告: 在目录 '{os.path.basename(selected_dir)}' 中缺少数据文件，已跳过。")
            continue

        try:
            metrics_df_part = pd.read_csv(metrics_file)
            all_metrics.append(metrics_df_part)

            fps_df_part = pd.read_csv(fps_file)
            all_fps.append(fps_df_part)
        except Exception as e:
            print(f" -> 错误: 读取CSV文件时出错: {e}")
            continue

    if not all_metrics or not all_fps:
        print("\n错误: 未能从任何有效的源目录中加载数据，无法继续生成报告。")
        return

    # Stack the rows coming from all sources.
    metrics_df = pd.concat(all_metrics, ignore_index=True)
    fps_df = pd.concat(all_fps, ignore_index=True)
    print("\n数据合并完成。")

    # The join keys are needed by the de-duplication and the merge below;
    # report their absence instead of failing with a KeyError.
    missing_keys = []
    if 'Algorithm' not in metrics_df.columns:
        missing_keys.append('Algorithm')
    if 'Model' not in fps_df.columns:
        missing_keys.append('Model')
    if missing_keys:
        print(f"错误: CSV 文件中缺少用于合并的关键列: {missing_keys}")
        return

    # De-duplicate per model; when an 'Epoch' column exists the row with the
    # highest epoch is the one that is kept.
    if 'Epoch' in metrics_df.columns:
        metrics_df = metrics_df.sort_values('Epoch', ascending=False).drop_duplicates('Algorithm')
    else:
        metrics_df = metrics_df.drop_duplicates('Algorithm')
    fps_df = fps_df.drop_duplicates('Model')

    # Join accuracy metrics with speed/size data on the model name.
    merged_df = pd.merge(metrics_df, fps_df, left_on='Algorithm', right_on='Model', how='inner')

    if merged_df.empty:
        print("错误: 数据合并失败。请检查 'Algorithm''Model' 列中的模型名称是否完全匹配。")
        print(f" - 指标文件中的模型: {metrics_df['Algorithm'].unique()}")
        print(f" - 性能文件中的模型: {fps_df['Model'].unique()}")
        return

    # Summary table (T1) and per-class IoU table (T2) go to the output folder.
    T1_create_and_save_summary_table(merged_df, dataset_output_dir, dataset_name)
    T2_extract_and_save_iou_data(metrics_df, dataset_output_dir, dataset_name)

    # The scatter plot reads these two columns directly.
    missing_plot_cols = [col for col in ('Average_FPS', 'mIoU') if col not in merged_df.columns]
    if missing_plot_cols:
        print(f"错误: 合并后的数据缺少绘图所需的列: {missing_plot_cols}")
        return

    # Group models into families so related variants share colour and marker.
    merged_df['Family'] = merged_df['Model'].apply(get_model_family)

    # --- Plot ---
    # The style is applied through a context manager so the process-wide
    # matplotlib settings are restored afterwards.
    with plt.style.context('seaborn-v0_8-whitegrid'):
        fig, ax = plt.subplots(figsize=(16, 10))
        try:
            families = sorted(merged_df['Family'].unique())
            palette = sns.color_palette("husl", len(families))
            markers = ['o', 's', 'X', 'D', '^', 'P', '*', 'v', '<', '>']

            for i, family in enumerate(families):
                family_df = merged_df[merged_df['Family'] == family].sort_values('Average_FPS')
                color = palette[i]
                # Markers are reused cyclically when there are more families than markers.
                marker = markers[i % len(markers)]

                ax.scatter(family_df['Average_FPS'], family_df['mIoU'],
                           color=color, marker=marker, s=150, label=family, zorder=3)

                # Connect the variants of one family with a dashed line.
                if len(family_df) > 1:
                    ax.plot(family_df['Average_FPS'], family_df['mIoU'],
                            color=color, linestyle='--', linewidth=1.5, zorder=2)

                # Label each point with the full model name, shifted 1% to the right.
                for _, row in family_df.iterrows():
                    ax.text(row['Average_FPS'] * 1.01, row['mIoU'], row['Model'],
                            fontsize=9, verticalalignment='center')

            ax.set_title(f'Model Performance vs. Inference Speed ({dataset_name})', fontsize=18, pad=20)
            ax.set_xlabel('Inference Speed (FPS)', fontsize=14)
            ax.set_ylabel('Mean IoU (%)', fontsize=14)
            ax.legend(title='Model Family', bbox_to_anchor=(1.02, 1), loc='upper left', borderaxespad=0.)
            # Leave room on the right for the legend placed outside the axes.
            fig.tight_layout(rect=[0, 0, 0.88, 1])
            ax.grid(True, which='both', linestyle='--', linewidth=0.5)

            # Save a PNG and an SVG copy into the dataset output folder.
            output_filename_png = f"F1_{dataset_name}_mIoU_vs_FPS.png"
            save_file_path_png = os.path.join(dataset_output_dir, output_filename_png)
            fig.savefig(save_file_path_png, dpi=600)
            output_filename_svg = f"F1_{dataset_name}_mIoU_vs_FPS.svg"
            save_file_path_svg = os.path.join(dataset_output_dir, output_filename_svg)
            fig.savefig(save_file_path_svg)
            print(f"\n图表已成功生成并保存为: {save_file_path_svg} 和 {save_file_path_png}")
        finally:
            # Always release the figure, even when drawing or saving fails.
            plt.close(fig)
def T1_create_and_save_summary_table(merged_df, output_dir, dataset_name):
    """
    Build the performance summary table from the merged data and save it as
    ``T1_<dataset_name>_performance_summary.csv`` inside ``output_dir``.

    The table keeps the columns Model, mIoU, mAcc, aAcc, Average_FPS, FLOPs and
    Params, converts FLOPs/Params to floats after removing a ``G`` / ``M`` unit,
    renames the last three to FPS, FLOPs(G) and Params(M), and is sorted by
    mIoU in descending order. ``merged_df`` itself is not modified. If any of
    the required columns is missing, a message is printed and nothing is saved.
    """
    print("正在创建摘要表格...")

    required_columns = ['Model', 'mIoU', 'mAcc', 'aAcc', 'Average_FPS', 'FLOPs', 'Params']
    absent = [name for name in required_columns if name not in merged_df.columns]
    if absent:
        print("错误: DataFrame中缺少必要的列。请检查CSV文件内容。")
        print(f" - 需要的列: {required_columns}")
        print(f" - 实际的列: {merged_df.columns.tolist()}")
        return

    def _as_number(series, unit_pattern):
        # Drop the unit letter (with optional leading whitespace) and parse as float.
        return series.astype(str).str.replace(unit_pattern, '', regex=True).astype(float)

    # Each step returns a new frame, so the caller's DataFrame stays untouched.
    table = (
        merged_df[required_columns]
        .assign(
            FLOPs=lambda frame: _as_number(frame['FLOPs'], r'\s*G'),
            Params=lambda frame: _as_number(frame['Params'], r'\s*M'),
        )
        .rename(columns={
            'Average_FPS': 'FPS',
            'FLOPs': 'FLOPs(G)',
            'Params': 'Params(M)',
        })
        .sort_values(by='mIoU', ascending=False)
    )

    summary_save_path = os.path.join(output_dir, f"T1_{dataset_name}_performance_summary.csv")
    try:
        table.to_csv(summary_save_path, index=False, float_format='%.3f')
        print(f"摘要表格已成功保存到: {summary_save_path}")
    except Exception as e:
        print(f"保存摘要表格时出错: {e}")
def T2_extract_and_save_iou_data(metrics_df, output_dir, dataset_name):
    """
    Extract mIoU and every per-class ``*_IoU`` column from the metrics
    DataFrame and save them as ``T2_<dataset_name>_all_iou_summary.csv``.

    Args:
        metrics_df (pandas.DataFrame): metrics table; must contain an
            'Algorithm' column. 'mIoU' is exported and used as the descending
            sort key when present, and is skipped otherwise.
        output_dir (str): directory the CSV file is written to.
        dataset_name (str): dataset name used in the output file name.

    Returns:
        None. Prints a message and returns without writing anything when the
        'Algorithm' column is missing.
    """
    print("正在提取所有 mIoU 和 Class_IoU 数据...")

    if 'Algorithm' not in metrics_df.columns:
        print("错误: 'Algorithm' 列未找到，无法继续。")
        return

    # Select only columns that exist: 'mIoU' is optional, so that a table
    # holding only per-class IoU values no longer fails with a KeyError.
    iou_columns = ['Algorithm']
    if 'mIoU' in metrics_df.columns:
        iou_columns.append('mIoU')
    iou_columns += [col for col in metrics_df.columns
                    if str(col).endswith('_IoU') and col != 'mIoU']

    # Remove repeated names while keeping the first-seen order.
    iou_columns = list(dict.fromkeys(iou_columns))

    iou_df = metrics_df[iou_columns].copy()

    # Best models first, when an overall score is available.
    if 'mIoU' in iou_df.columns:
        iou_df = iou_df.sort_values(by='mIoU', ascending=False)

    iou_filename = f"T2_{dataset_name}_all_iou_summary.csv"
    iou_save_path = os.path.join(output_dir, iou_filename)

    try:
        iou_df.to_csv(iou_save_path, index=False, float_format='%.2f')
        print(f"所有IoU数据已成功保存到: {iou_save_path}")
    except Exception as e:
        print(f"保存IoU数据时出错: {e}")
if __name__ == '__main__':
    # Script entry point: parse the two directory options, let the user pick a
    # dataset interactively, then build the tables and the figure for it.
    cli = argparse.ArgumentParser(description="从模型评估结果生成性能与速度对比图和摘要表。")
    cli.add_argument('--input_dir', type=str,
                     default='../BestMode_Predict_Results_DataSet_Public',
                     help="包含所有数据集结果的根目录 (例如 '..._outputs-MMSeg''..._outputs-SegModel' 的父目录)。")
    cli.add_argument('--output_dir', type=str,
                     default='./',
                     help="用于存储所有生成的图表和表格的根目录。")
    options = cli.parse_args()

    # The output root must exist before anything is written below it.
    os.makedirs(options.output_dir, exist_ok=True)

    chosen_dirs, chosen_name = select_dataset(options.input_dir)

    # Both values are None when nothing was selected; only proceed on success.
    if chosen_dirs and chosen_name:
        F1_plot_performance_speed(chosen_dirs, chosen_name, options.output_dir)

File diff suppressed because it is too large Load Diff

After

Width:  |  Height:  |  Size: 86 KiB

View File

@@ -0,0 +1,21 @@
Model,mIoU,mAcc,aAcc,FPS,FLOPs(G),Params(M)
UnetPlusPlus,96.860,95.430,99.750,11.940,590.910,26.080
UPerNet,96.670,95.380,99.740,17.250,574.480,29.600
MAnet,96.630,94.960,99.740,23.480,271.820,31.790
Unet,96.590,95.520,99.730,26.290,253.380,24.440
DeepLabV3Plus,96.500,94.990,99.730,33.210,252.410,22.440
Linknet,96.460,94.550,99.720,32.820,161.800,21.770
Segformer,96.450,94.880,99.720,21.020,209.450,21.880
DeepLabV3,96.420,94.730,99.720,13.860,871.240,26.010
FPN,96.410,94.740,99.720,34.920,219.570,23.160
PAN,96.370,94.480,99.720,37.630,238.120,21.480
DPT,96.310,94.900,99.710,1.900,1696.580,137.810
PSPNet,96.010,94.610,99.690,79.510,76.810,21.490
my_fastfcn_r50,89.740,94.210,97.830,10.620,1032.000,66.346
my_icnet_r50,88.840,93.150,97.780,58.690,122.000,47.527
my_icnet_r18,85.760,92.400,96.600,101.260,73.869,24.873
my_bisenetv1_r50,82.640,89.980,95.690,13.630,784.000,56.867
my_bisenetv1_r18,82.610,89.220,94.890,66.760,118.000,13.274
my_bisenetv2,74.610,82.580,92.090,68.050,97.578,3.353
my_fast_scnn,69.290,76.970,93.650,179.900,7.426,1.400
my_en_bisenetv2,30.950,44.500,67.960,66.090,62.729,2.776
1 Model mIoU mAcc aAcc FPS FLOPs(G) Params(M)
2 UnetPlusPlus 96.860 95.430 99.750 11.940 590.910 26.080
3 UPerNet 96.670 95.380 99.740 17.250 574.480 29.600
4 MAnet 96.630 94.960 99.740 23.480 271.820 31.790
5 Unet 96.590 95.520 99.730 26.290 253.380 24.440
6 DeepLabV3Plus 96.500 94.990 99.730 33.210 252.410 22.440
7 Linknet 96.460 94.550 99.720 32.820 161.800 21.770
8 Segformer 96.450 94.880 99.720 21.020 209.450 21.880
9 DeepLabV3 96.420 94.730 99.720 13.860 871.240 26.010
10 FPN 96.410 94.740 99.720 34.920 219.570 23.160
11 PAN 96.370 94.480 99.720 37.630 238.120 21.480
12 DPT 96.310 94.900 99.710 1.900 1696.580 137.810
13 PSPNet 96.010 94.610 99.690 79.510 76.810 21.490
14 my_fastfcn_r50 89.740 94.210 97.830 10.620 1032.000 66.346
15 my_icnet_r50 88.840 93.150 97.780 58.690 122.000 47.527
16 my_icnet_r18 85.760 92.400 96.600 101.260 73.869 24.873
17 my_bisenetv1_r50 82.640 89.980 95.690 13.630 784.000 56.867
18 my_bisenetv1_r18 82.610 89.220 94.890 66.760 118.000 13.274
19 my_bisenetv2 74.610 82.580 92.090 68.050 97.578 3.353
20 my_fast_scnn 69.290 76.970 93.650 179.900 7.426 1.400
21 my_en_bisenetv2 30.950 44.500 67.960 66.090 62.729 2.776

View File

@@ -0,0 +1,21 @@
Algorithm,mIoU,10_IoU,11_IoU,12_IoU,1_IoU,2_IoU,3_IoU,4_IoU,5_IoU,6_IoU,7_IoU,8_IoU,9_IoU,背景_IoU
UnetPlusPlus,96.86,84.08,71.33,99.41,96.18,96.27,94.45,91.76,90.90,92.24,93.28,79.37,98.01,97.99
UPerNet,96.67,83.18,71.32,99.32,95.95,96.10,94.12,90.87,90.75,92.12,92.90,78.68,97.92,97.85
MAnet,96.63,83.84,69.46,99.34,95.94,96.05,94.00,91.09,90.14,91.92,92.85,78.50,98.05,97.82
Unet,96.59,83.11,71.34,99.34,95.88,96.01,93.90,90.62,90.34,91.67,92.45,79.29,97.73,97.81
DeepLabV3Plus,96.50,83.07,71.81,99.28,95.86,95.85,93.77,89.92,90.31,91.77,92.17,78.45,97.69,97.73
Linknet,96.46,81.22,71.54,99.33,95.69,95.83,93.64,89.90,90.38,91.54,92.29,78.48,97.57,97.71
Segformer,96.45,81.30,67.75,99.27,95.74,95.85,93.56,89.94,89.98,91.87,92.42,78.16,97.80,97.73
DeepLabV3,96.42,82.35,67.37,99.26,95.76,95.78,93.51,88.96,90.20,91.91,92.48,76.75,97.86,97.75
FPN,96.41,82.43,69.98,99.24,95.71,95.74,93.61,89.88,90.22,91.74,92.23,77.50,97.76,97.69
PAN,96.37,81.33,71.47,99.24,95.65,95.77,93.51,89.62,90.08,91.60,92.43,77.88,97.81,97.61
DPT,96.31,83.17,70.59,99.27,95.49,95.64,93.48,89.71,90.34,91.36,92.53,78.20,97.63,97.53
PSPNet,96.01,81.24,68.49,99.13,95.33,95.30,92.75,87.61,89.36,91.01,91.76,76.03,97.51,97.45
my_fastfcn_r50,89.74,83.23,64.19,97.57,95.43,95.86,94.40,91.26,89.54,90.49,93.22,76.89,97.98,96.56
my_icnet_r50,88.84,80.89,61.34,97.85,94.99,95.16,93.92,89.80,89.45,89.54,91.77,75.20,97.76,97.26
my_icnet_r18,85.76,79.53,60.25,94.73,93.08,94.07,92.81,84.45,87.61,83.91,91.24,73.36,84.31,95.45
my_bisenetv1_r50,82.64,80.56,67.31,97.75,91.70,91.28,85.29,83.91,84.91,73.37,80.36,71.34,74.45,92.07
my_bisenetv1_r18,82.61,74.28,59.56,94.47,88.80,91.20,88.70,87.56,84.82,78.50,86.37,68.96,80.99,89.71
my_bisenetv2,74.61,71.97,0.00,88.12,89.12,85.45,84.42,77.65,77.16,75.13,87.12,65.45,86.14,82.19
my_fast_scnn,69.29,0.00,0.00,92.37,86.24,88.04,82.33,78.77,76.54,80.84,84.43,63.59,76.35,91.32
my_en_bisenetv2,30.95,0.00,0.00,79.41,56.32,44.01,28.35,27.32,47.21,20.52,26.05,27.07,0.00,46.09
1 Algorithm mIoU 10_IoU 11_IoU 12_IoU 1_IoU 2_IoU 3_IoU 4_IoU 5_IoU 6_IoU 7_IoU 8_IoU 9_IoU 背景_IoU
2 UnetPlusPlus 96.86 84.08 71.33 99.41 96.18 96.27 94.45 91.76 90.90 92.24 93.28 79.37 98.01 97.99
3 UPerNet 96.67 83.18 71.32 99.32 95.95 96.10 94.12 90.87 90.75 92.12 92.90 78.68 97.92 97.85
4 MAnet 96.63 83.84 69.46 99.34 95.94 96.05 94.00 91.09 90.14 91.92 92.85 78.50 98.05 97.82
5 Unet 96.59 83.11 71.34 99.34 95.88 96.01 93.90 90.62 90.34 91.67 92.45 79.29 97.73 97.81
6 DeepLabV3Plus 96.50 83.07 71.81 99.28 95.86 95.85 93.77 89.92 90.31 91.77 92.17 78.45 97.69 97.73
7 Linknet 96.46 81.22 71.54 99.33 95.69 95.83 93.64 89.90 90.38 91.54 92.29 78.48 97.57 97.71
8 Segformer 96.45 81.30 67.75 99.27 95.74 95.85 93.56 89.94 89.98 91.87 92.42 78.16 97.80 97.73
9 DeepLabV3 96.42 82.35 67.37 99.26 95.76 95.78 93.51 88.96 90.20 91.91 92.48 76.75 97.86 97.75
10 FPN 96.41 82.43 69.98 99.24 95.71 95.74 93.61 89.88 90.22 91.74 92.23 77.50 97.76 97.69
11 PAN 96.37 81.33 71.47 99.24 95.65 95.77 93.51 89.62 90.08 91.60 92.43 77.88 97.81 97.61
12 DPT 96.31 83.17 70.59 99.27 95.49 95.64 93.48 89.71 90.34 91.36 92.53 78.20 97.63 97.53
13 PSPNet 96.01 81.24 68.49 99.13 95.33 95.30 92.75 87.61 89.36 91.01 91.76 76.03 97.51 97.45
14 my_fastfcn_r50 89.74 83.23 64.19 97.57 95.43 95.86 94.40 91.26 89.54 90.49 93.22 76.89 97.98 96.56
15 my_icnet_r50 88.84 80.89 61.34 97.85 94.99 95.16 93.92 89.80 89.45 89.54 91.77 75.20 97.76 97.26
16 my_icnet_r18 85.76 79.53 60.25 94.73 93.08 94.07 92.81 84.45 87.61 83.91 91.24 73.36 84.31 95.45
17 my_bisenetv1_r50 82.64 80.56 67.31 97.75 91.70 91.28 85.29 83.91 84.91 73.37 80.36 71.34 74.45 92.07
18 my_bisenetv1_r18 82.61 74.28 59.56 94.47 88.80 91.20 88.70 87.56 84.82 78.50 86.37 68.96 80.99 89.71
19 my_bisenetv2 74.61 71.97 0.00 88.12 89.12 85.45 84.42 77.65 77.16 75.13 87.12 65.45 86.14 82.19
20 my_fast_scnn 69.29 0.00 0.00 92.37 86.24 88.04 82.33 78.77 76.54 80.84 84.43 63.59 76.35 91.32
21 my_en_bisenetv2 30.95 0.00 0.00 79.41 56.32 44.01 28.35 27.32 47.21 20.52 26.05 27.07 0.00 46.09

File diff suppressed because it is too large Load Diff

After

Width:  |  Height:  |  Size: 84 KiB

View File

@@ -0,0 +1,21 @@
Model,mIoU,mAcc,aAcc,FPS,FLOPs(G),Params(M)
DeepLabV3,80.740,83.990,99.520,14.030,871.240,26.010
PSPNet,79.980,83.730,99.500,79.650,76.810,21.490
UPerNet,79.960,85.130,99.500,17.440,574.480,29.600
PAN,79.730,83.990,99.500,38.020,238.120,21.480
DeepLabV3Plus,79.610,85.070,99.480,33.420,252.410,22.440
Segformer,79.250,83.200,99.480,21.050,209.450,21.880
FPN,78.990,83.980,99.470,35.060,219.570,23.160
MAnet,77.380,82.040,99.420,23.610,271.820,31.790
UnetPlusPlus,77.250,81.010,99.440,12.080,590.910,26.080
Unet,76.160,83.160,99.380,26.410,253.380,24.440
Linknet,75.510,81.050,99.380,33.040,161.800,21.770
my_fastfcn_r50,71.040,79.630,92.280,10.610,1032.000,66.346
my_icnet_r50,70.900,78.660,94.020,59.150,122.000,47.526
my_icnet_r18,64.370,76.040,91.130,102.830,73.857,24.873
DPT,58.120,62.610,98.840,1.910,1696.580,137.810
my_bisenetv1_r50,49.540,70.640,85.890,13.670,784.000,56.864
my_bisenetv1_r18,43.630,51.500,86.400,67.190,118.000,13.273
my_fast_scnn,35.470,53.070,78.230,178.010,7.426,1.400
my_bisenetv2,30.770,46.870,67.040,68.880,97.479,3.350
my_en_bisenetv2,21.060,28.780,81.280,66.830,62.629,2.773
1 Model mIoU mAcc aAcc FPS FLOPs(G) Params(M)
2 DeepLabV3 80.740 83.990 99.520 14.030 871.240 26.010
3 PSPNet 79.980 83.730 99.500 79.650 76.810 21.490
4 UPerNet 79.960 85.130 99.500 17.440 574.480 29.600
5 PAN 79.730 83.990 99.500 38.020 238.120 21.480
6 DeepLabV3Plus 79.610 85.070 99.480 33.420 252.410 22.440
7 Segformer 79.250 83.200 99.480 21.050 209.450 21.880
8 FPN 78.990 83.980 99.470 35.060 219.570 23.160
9 MAnet 77.380 82.040 99.420 23.610 271.820 31.790
10 UnetPlusPlus 77.250 81.010 99.440 12.080 590.910 26.080
11 Unet 76.160 83.160 99.380 26.410 253.380 24.440
12 Linknet 75.510 81.050 99.380 33.040 161.800 21.770
13 my_fastfcn_r50 71.040 79.630 92.280 10.610 1032.000 66.346
14 my_icnet_r50 70.900 78.660 94.020 59.150 122.000 47.526
15 my_icnet_r18 64.370 76.040 91.130 102.830 73.857 24.873
16 DPT 58.120 62.610 98.840 1.910 1696.580 137.810
17 my_bisenetv1_r50 49.540 70.640 85.890 13.670 784.000 56.864
18 my_bisenetv1_r18 43.630 51.500 86.400 67.190 118.000 13.273
19 my_fast_scnn 35.470 53.070 78.230 178.010 7.426 1.400
20 my_bisenetv2 30.770 46.870 67.040 68.880 97.479 3.350
21 my_en_bisenetv2 21.060 28.780 81.280 66.830 62.629 2.773

View File

@@ -0,0 +1,21 @@
Algorithm,mIoU,1_IoU,2_IoU,3_IoU,4_IoU,5_IoU,6_IoU,7_IoU,8_IoU,9_IoU,背景_IoU
DeepLabV3,80.74,68.94,77.29,80.52,86.86,67.33,78.49,40.91,85.00,80.77,95.97
PSPNet,79.98,66.02,78.51,79.55,86.79,65.31,81.79,43.71,88.19,78.72,95.59
UPerNet,79.96,68.17,79.22,79.58,88.29,66.22,76.92,48.90,87.01,78.32,95.69
PAN,79.73,68.17,79.87,80.10,87.75,67.79,80.61,45.17,85.84,77.10,95.55
DeepLabV3Plus,79.61,67.65,80.67,79.04,86.41,67.82,78.38,45.17,84.93,78.48,95.51
Segformer,79.25,70.26,80.48,79.32,86.77,64.76,77.48,40.30,86.94,76.90,95.67
FPN,78.99,64.32,76.67,77.73,85.13,66.86,80.62,41.37,86.36,78.77,95.61
MAnet,77.38,68.36,75.96,76.54,85.39,64.29,75.99,42.33,80.07,76.40,95.13
UnetPlusPlus,77.25,68.41,80.79,78.11,88.39,61.29,75.66,43.16,78.42,73.51,95.01
Unet,76.16,65.81,75.72,77.40,86.54,64.59,78.09,41.00,86.14,71.37,94.50
Linknet,75.51,67.53,72.66,77.15,85.43,62.20,66.99,42.97,80.32,72.87,94.71
my_fastfcn_r50,71.04,72.94,75.20,83.46,86.34,75.83,79.99,15.77,77.42,52.27,91.19
my_icnet_r50,70.90,70.59,79.88,82.49,88.10,75.12,84.08,0.00,76.83,58.64,93.31
my_icnet_r18,64.37,68.27,66.15,76.69,80.02,69.07,77.24,0.00,67.25,48.86,90.15
DPT,58.12,35.86,57.90,52.58,69.75,25.48,51.48,11.85,64.40,61.94,90.98
my_bisenetv1_r50,49.54,41.45,52.43,51.70,70.38,31.42,58.09,4.18,58.32,42.20,85.24
my_bisenetv1_r18,43.63,40.48,33.63,57.69,68.59,18.73,41.95,5.47,46.50,37.73,85.56
my_fast_scnn,35.47,17.13,31.73,32.59,67.72,13.07,44.63,0.00,39.41,30.97,77.47
my_bisenetv2,30.77,13.45,36.67,16.43,60.36,13.90,36.03,0.00,37.48,28.07,65.35
my_en_bisenetv2,21.06,1.90,15.66,34.99,36.61,5.25,18.82,0.00,14.46,0.37,82.57
1 Algorithm mIoU 1_IoU 2_IoU 3_IoU 4_IoU 5_IoU 6_IoU 7_IoU 8_IoU 9_IoU 背景_IoU
2 DeepLabV3 80.74 68.94 77.29 80.52 86.86 67.33 78.49 40.91 85.00 80.77 95.97
3 PSPNet 79.98 66.02 78.51 79.55 86.79 65.31 81.79 43.71 88.19 78.72 95.59
4 UPerNet 79.96 68.17 79.22 79.58 88.29 66.22 76.92 48.90 87.01 78.32 95.69
5 PAN 79.73 68.17 79.87 80.10 87.75 67.79 80.61 45.17 85.84 77.10 95.55
6 DeepLabV3Plus 79.61 67.65 80.67 79.04 86.41 67.82 78.38 45.17 84.93 78.48 95.51
7 Segformer 79.25 70.26 80.48 79.32 86.77 64.76 77.48 40.30 86.94 76.90 95.67
8 FPN 78.99 64.32 76.67 77.73 85.13 66.86 80.62 41.37 86.36 78.77 95.61
9 MAnet 77.38 68.36 75.96 76.54 85.39 64.29 75.99 42.33 80.07 76.40 95.13
10 UnetPlusPlus 77.25 68.41 80.79 78.11 88.39 61.29 75.66 43.16 78.42 73.51 95.01
11 Unet 76.16 65.81 75.72 77.40 86.54 64.59 78.09 41.00 86.14 71.37 94.50
12 Linknet 75.51 67.53 72.66 77.15 85.43 62.20 66.99 42.97 80.32 72.87 94.71
13 my_fastfcn_r50 71.04 72.94 75.20 83.46 86.34 75.83 79.99 15.77 77.42 52.27 91.19
14 my_icnet_r50 70.90 70.59 79.88 82.49 88.10 75.12 84.08 0.00 76.83 58.64 93.31
15 my_icnet_r18 64.37 68.27 66.15 76.69 80.02 69.07 77.24 0.00 67.25 48.86 90.15
16 DPT 58.12 35.86 57.90 52.58 69.75 25.48 51.48 11.85 64.40 61.94 90.98
17 my_bisenetv1_r50 49.54 41.45 52.43 51.70 70.38 31.42 58.09 4.18 58.32 42.20 85.24
18 my_bisenetv1_r18 43.63 40.48 33.63 57.69 68.59 18.73 41.95 5.47 46.50 37.73 85.56
19 my_fast_scnn 35.47 17.13 31.73 32.59 67.72 13.07 44.63 0.00 39.41 30.97 77.47
20 my_bisenetv2 30.77 13.45 36.67 16.43 60.36 13.90 36.03 0.00 37.48 28.07 65.35
21 my_en_bisenetv2 21.06 1.90 15.66 34.99 36.61 5.25 18.82 0.00 14.46 0.37 82.57

File diff suppressed because it is too large Load Diff

After

Width:  |  Height:  |  Size: 81 KiB

View File

@@ -0,0 +1,21 @@
Model,mIoU,mAcc,aAcc,FPS,FLOPs(G),Params(M)
FPN,77.150,91.110,99.500,204.850,27.550,23.160
DeepLabV3,77.110,92.370,99.490,94.790,109.330,26.010
PAN,76.600,90.580,99.480,234.070,29.880,21.480
UPerNet,75.930,90.910,99.450,107.590,72.100,29.600
UnetPlusPlus,75.800,90.720,99.460,89.430,74.150,26.080
Segformer,75.080,88.260,99.450,151.560,26.280,21.880
PSPNet,74.850,86.440,99.450,573.660,9.640,21.490
Unet,73.860,89.190,99.410,173.520,31.800,24.440
DeepLabV3Plus,73.830,86.410,99.420,208.780,31.680,22.440
Linknet,73.790,87.770,99.410,197.430,20.300,21.770
MAnet,73.630,89.900,99.400,152.330,33.850,31.790
my_fastfcn_r50,61.430,90.120,97.480,71.100,130.000,66.346
DPT,61.420,82.190,99.070,30.180,212.980,137.810
my_bisenetv1_r50,59.590,84.700,95.770,88.970,98.945,56.862
my_icnet_r50,57.840,80.930,94.950,179.660,15.428,47.526
my_icnet_r18,57.350,88.250,96.590,268.050,9.360,24.873
my_bisenetv1_r18,56.730,84.610,96.770,310.400,14.827,13.273
my_bisenetv2,45.950,73.530,94.010,223.740,12.311,3.348
my_fast_scnn,41.590,67.120,92.870,314.130,0.936,1.400
my_en_bisenetv2,26.470,47.770,88.930,167.350,7.907,2.771
1 Model mIoU mAcc aAcc FPS FLOPs(G) Params(M)
2 FPN 77.150 91.110 99.500 204.850 27.550 23.160
3 DeepLabV3 77.110 92.370 99.490 94.790 109.330 26.010
4 PAN 76.600 90.580 99.480 234.070 29.880 21.480
5 UPerNet 75.930 90.910 99.450 107.590 72.100 29.600
6 UnetPlusPlus 75.800 90.720 99.460 89.430 74.150 26.080
7 Segformer 75.080 88.260 99.450 151.560 26.280 21.880
8 PSPNet 74.850 86.440 99.450 573.660 9.640 21.490
9 Unet 73.860 89.190 99.410 173.520 31.800 24.440
10 DeepLabV3Plus 73.830 86.410 99.420 208.780 31.680 22.440
11 Linknet 73.790 87.770 99.410 197.430 20.300 21.770
12 MAnet 73.630 89.900 99.400 152.330 33.850 31.790
13 my_fastfcn_r50 61.430 90.120 97.480 71.100 130.000 66.346
14 DPT 61.420 82.190 99.070 30.180 212.980 137.810
15 my_bisenetv1_r50 59.590 84.700 95.770 88.970 98.945 56.862
16 my_icnet_r50 57.840 80.930 94.950 179.660 15.428 47.526
17 my_icnet_r18 57.350 88.250 96.590 268.050 9.360 24.873
18 my_bisenetv1_r18 56.730 84.610 96.770 310.400 14.827 13.273
19 my_bisenetv2 45.950 73.530 94.010 223.740 12.311 3.348
20 my_fast_scnn 41.590 67.120 92.870 314.130 0.936 1.400
21 my_en_bisenetv2 26.470 47.770 88.930 167.350 7.907 2.771

View File

@@ -0,0 +1,21 @@
Algorithm,mIoU,1_IoU,2_IoU,3_IoU,4_IoU,6_IoU,背景_IoU,5_IoU,7_IoU
FPN,77.15,68.36,56.28,89.32,75.57,90.63,97.53,0.00,0.00
DeepLabV3,77.11,70.40,56.63,88.94,75.66,91.14,97.59,0.00,0.00
PAN,76.60,67.15,58.83,88.96,66.32,93.62,97.63,0.00,0.00
UPerNet,75.93,70.51,55.94,87.54,66.97,93.12,97.42,0.00,0.00
UnetPlusPlus,75.80,71.10,53.32,89.23,62.62,92.12,97.61,0.00,0.00
Segformer,75.08,69.28,51.50,89.63,60.44,89.39,97.46,0.00,0.00
PSPNet,74.85,69.13,59.87,89.43,41.93,88.80,97.14,0.00,0.00
Unet,73.86,70.85,47.93,88.16,65.93,81.26,97.40,0.00,0.00
DeepLabV3Plus,73.83,67.50,55.09,88.72,43.60,89.56,97.30,0.00,0.00
Linknet,73.79,70.46,51.28,88.14,61.32,74.38,97.41,0.00,0.00
MAnet,73.63,69.37,50.95,86.99,64.69,88.41,97.27,0.00,0.00
my_fastfcn_r50,61.43,75.72,61.00,89.72,74.27,92.82,97.88,,
DPT,61.42,53.96,44.36,74.11,42.46,72.76,95.73,0.00,0.00
my_bisenetv1_r50,59.59,55.03,44.88,81.05,56.63,82.73,96.78,,
my_icnet_r50,57.84,55.85,40.23,81.41,59.23,72.62,95.51,,
my_icnet_r18,57.35,67.17,51.83,88.81,66.89,87.06,97.04,,
my_bisenetv1_r18,56.73,65.55,56.03,89.65,54.60,90.81,97.18,,
my_bisenetv2,45.95,35.33,43.64,75.24,32.28,86.09,95.00,,
my_fast_scnn,41.59,33.67,19.74,60.73,38.78,84.75,95.02,,
my_en_bisenetv2,26.47,21.72,3.36,51.58,1.01,42.74,91.37,,
1 Algorithm mIoU 1_IoU 2_IoU 3_IoU 4_IoU 6_IoU 背景_IoU 5_IoU 7_IoU
2 FPN 77.15 68.36 56.28 89.32 75.57 90.63 97.53 0.00 0.00
3 DeepLabV3 77.11 70.40 56.63 88.94 75.66 91.14 97.59 0.00 0.00
4 PAN 76.60 67.15 58.83 88.96 66.32 93.62 97.63 0.00 0.00
5 UPerNet 75.93 70.51 55.94 87.54 66.97 93.12 97.42 0.00 0.00
6 UnetPlusPlus 75.80 71.10 53.32 89.23 62.62 92.12 97.61 0.00 0.00
7 Segformer 75.08 69.28 51.50 89.63 60.44 89.39 97.46 0.00 0.00
8 PSPNet 74.85 69.13 59.87 89.43 41.93 88.80 97.14 0.00 0.00
9 Unet 73.86 70.85 47.93 88.16 65.93 81.26 97.40 0.00 0.00
10 DeepLabV3Plus 73.83 67.50 55.09 88.72 43.60 89.56 97.30 0.00 0.00
11 Linknet 73.79 70.46 51.28 88.14 61.32 74.38 97.41 0.00 0.00
12 MAnet 73.63 69.37 50.95 86.99 64.69 88.41 97.27 0.00 0.00
13 my_fastfcn_r50 61.43 75.72 61.00 89.72 74.27 92.82 97.88
14 DPT 61.42 53.96 44.36 74.11 42.46 72.76 95.73 0.00 0.00
15 my_bisenetv1_r50 59.59 55.03 44.88 81.05 56.63 82.73 96.78
16 my_icnet_r50 57.84 55.85 40.23 81.41 59.23 72.62 95.51
17 my_icnet_r18 57.35 67.17 51.83 88.81 66.89 87.06 97.04
18 my_bisenetv1_r18 56.73 65.55 56.03 89.65 54.60 90.81 97.18
19 my_bisenetv2 45.95 35.33 43.64 75.24 32.28 86.09 95.00
20 my_fast_scnn 41.59 33.67 19.74 60.73 38.78 84.75 95.02
21 my_en_bisenetv2 26.47 21.72 3.36 51.58 1.01 42.74 91.37

File diff suppressed because it is too large Load Diff

After

Width:  |  Height:  |  Size: 84 KiB

View File

@@ -0,0 +1,21 @@
Model,mIoU,mAcc,aAcc,FPS,FLOPs(G),Params(M)
MAnet,93.900,84.990,99.210,151.600,33.850,31.790
Segformer,93.280,82.550,99.130,151.930,26.280,21.880
UnetPlusPlus,92.970,82.760,99.090,90.260,74.150,26.080
FPN,92.670,82.440,99.050,207.350,27.550,23.160
DeepLabV3,92.550,82.140,99.030,92.870,109.330,26.010
Unet,92.530,79.230,99.030,177.100,31.800,24.440
PAN,92.480,81.380,99.020,232.430,29.880,21.480
UPerNet,92.180,80.020,98.980,105.860,72.100,29.600
Linknet,92.060,79.170,98.970,199.680,20.300,21.770
PSPNet,91.940,76.940,98.950,578.550,9.640,21.490
DeepLabV3Plus,91.500,78.630,98.890,213.500,31.680,22.440
DPT,87.840,72.480,98.380,30.860,212.980,137.810
my_fastfcn_r50,55.340,84.060,96.630,71.420,130.000,66.346
my_icnet_r50,50.400,78.440,95.210,202.090,15.428,47.526
my_bisenetv1_r50,49.620,78.030,96.150,88.850,98.945,56.862
my_icnet_r18,47.540,76.700,94.100,275.600,9.360,24.873
my_bisenetv1_r18,45.020,67.190,95.580,346.950,14.827,13.273
my_bisenetv2,38.850,65.830,93.150,243.230,12.311,3.348
my_fast_scnn,36.200,61.550,92.870,381.410,0.936,1.400
my_en_bisenetv2,21.760,41.090,86.700,203.200,7.907,2.771
1 Model mIoU mAcc aAcc FPS FLOPs(G) Params(M)
2 MAnet 93.900 84.990 99.210 151.600 33.850 31.790
3 Segformer 93.280 82.550 99.130 151.930 26.280 21.880
4 UnetPlusPlus 92.970 82.760 99.090 90.260 74.150 26.080
5 FPN 92.670 82.440 99.050 207.350 27.550 23.160
6 DeepLabV3 92.550 82.140 99.030 92.870 109.330 26.010
7 Unet 92.530 79.230 99.030 177.100 31.800 24.440
8 PAN 92.480 81.380 99.020 232.430 29.880 21.480
9 UPerNet 92.180 80.020 98.980 105.860 72.100 29.600
10 Linknet 92.060 79.170 98.970 199.680 20.300 21.770
11 PSPNet 91.940 76.940 98.950 578.550 9.640 21.490
12 DeepLabV3Plus 91.500 78.630 98.890 213.500 31.680 22.440
13 DPT 87.840 72.480 98.380 30.860 212.980 137.810
14 my_fastfcn_r50 55.340 84.060 96.630 71.420 130.000 66.346
15 my_icnet_r50 50.400 78.440 95.210 202.090 15.428 47.526
16 my_bisenetv1_r50 49.620 78.030 96.150 88.850 98.945 56.862
17 my_icnet_r18 47.540 76.700 94.100 275.600 9.360 24.873
18 my_bisenetv1_r18 45.020 67.190 95.580 346.950 14.827 13.273
19 my_bisenetv2 38.850 65.830 93.150 243.230 12.311 3.348
20 my_fast_scnn 36.200 61.550 92.870 381.410 0.936 1.400
21 my_en_bisenetv2 21.760 41.090 86.700 203.200 7.907 2.771

View File

@@ -0,0 +1,21 @@
Algorithm,mIoU,1_IoU,2_IoU,3_IoU,4_IoU,6_IoU,背景_IoU,5_IoU,7_IoU
MAnet,93.90,70.53,46.86,88.84,50.14,73.06,97.91,0.00,0.00
Segformer,93.28,68.10,50.06,85.56,36.53,75.76,97.56,0.00,0.00
UnetPlusPlus,92.97,64.84,47.57,81.19,59.57,62.17,97.61,0.00,0.00
FPN,92.67,63.45,39.01,89.55,41.62,55.78,97.70,0.00,0.00
DeepLabV3,92.55,68.69,38.16,83.63,42.43,56.29,97.59,0.00,0.00
Unet,92.53,64.81,44.44,81.98,34.55,60.02,97.52,0.00,0.00
PAN,92.48,64.05,37.24,81.88,47.61,63.51,97.47,0.00,0.00
UPerNet,92.18,68.07,37.63,83.54,39.83,44.67,97.68,0.00,0.00
Linknet,92.06,57.25,42.14,86.36,31.33,61.26,97.65,0.00,0.00
PSPNet,91.94,62.48,37.34,82.50,19.62,61.92,97.44,0.00,0.00
DeepLabV3Plus,91.50,62.77,36.12,76.23,40.89,55.95,97.32,0.00,0.00
DPT,87.84,52.03,29.80,67.47,24.96,44.50,94.84,0.00,0.00
my_fastfcn_r50,55.34,71.20,55.69,85.87,60.47,72.04,97.47,,
my_icnet_r50,50.40,61.35,42.53,78.21,62.38,62.30,96.39,,
my_bisenetv1_r50,49.62,63.38,40.38,82.82,47.85,64.80,97.78,,
my_icnet_r18,47.54,48.47,27.63,83.13,55.34,70.40,95.37,,
my_bisenetv1_r18,45.02,60.13,35.97,81.80,30.29,54.89,97.10,,
my_bisenetv2,38.85,47.72,28.29,73.87,24.03,41.84,95.06,,
my_fast_scnn,36.20,36.31,19.69,59.68,31.48,46.85,95.56,,
my_en_bisenetv2,21.76,19.78,7.06,38.84,0.06,18.16,90.16,,
1 Algorithm mIoU 1_IoU 2_IoU 3_IoU 4_IoU 6_IoU 背景_IoU 5_IoU 7_IoU
2 MAnet 93.90 70.53 46.86 88.84 50.14 73.06 97.91 0.00 0.00
3 Segformer 93.28 68.10 50.06 85.56 36.53 75.76 97.56 0.00 0.00
4 UnetPlusPlus 92.97 64.84 47.57 81.19 59.57 62.17 97.61 0.00 0.00
5 FPN 92.67 63.45 39.01 89.55 41.62 55.78 97.70 0.00 0.00
6 DeepLabV3 92.55 68.69 38.16 83.63 42.43 56.29 97.59 0.00 0.00
7 Unet 92.53 64.81 44.44 81.98 34.55 60.02 97.52 0.00 0.00
8 PAN 92.48 64.05 37.24 81.88 47.61 63.51 97.47 0.00 0.00
9 UPerNet 92.18 68.07 37.63 83.54 39.83 44.67 97.68 0.00 0.00
10 Linknet 92.06 57.25 42.14 86.36 31.33 61.26 97.65 0.00 0.00
11 PSPNet 91.94 62.48 37.34 82.50 19.62 61.92 97.44 0.00 0.00
12 DeepLabV3Plus 91.50 62.77 36.12 76.23 40.89 55.95 97.32 0.00 0.00
13 DPT 87.84 52.03 29.80 67.47 24.96 44.50 94.84 0.00 0.00
14 my_fastfcn_r50 55.34 71.20 55.69 85.87 60.47 72.04 97.47
15 my_icnet_r50 50.40 61.35 42.53 78.21 62.38 62.30 96.39
16 my_bisenetv1_r50 49.62 63.38 40.38 82.82 47.85 64.80 97.78
17 my_icnet_r18 47.54 48.47 27.63 83.13 55.34 70.40 95.37
18 my_bisenetv1_r18 45.02 60.13 35.97 81.80 30.29 54.89 97.10
19 my_bisenetv2 38.85 47.72 28.29 73.87 24.03 41.84 95.06
20 my_fast_scnn 36.20 36.31 19.69 59.68 31.48 46.85 95.56
21 my_en_bisenetv2 21.76 19.78 7.06 38.84 0.06 18.16 90.16

File diff suppressed because it is too large Load Diff

After

Width:  |  Height:  |  Size: 84 KiB

View File

@@ -0,0 +1,21 @@
Model,mIoU,mAcc,aAcc,FPS,FLOPs(G),Params(M)
MAnet,92.710,58.340,99.310,152.290,33.850,31.790
my_fastfcn_r50,37.810,51.890,96.480,71.040,130.000,66.346
my_icnet_r50,35.930,54.750,95.850,193.830,15.430,47.526
my_icnet_r18,34.840,47.480,95.570,286.580,9.362,24.873
my_bisenetv1_r50,33.600,48.180,95.950,88.460,98.957,56.865
PAN,32.230,52.260,99.570,238.750,29.880,21.480
FPN,31.480,50.620,99.500,208.170,27.550,23.160
UPerNet,30.810,56.260,99.540,108.960,72.100,29.600
PSPNet,30.460,48.110,99.580,586.440,9.640,21.490
DeepLabV3,30.440,45.870,99.550,96.470,109.330,26.010
DeepLabV3Plus,30.390,53.560,99.520,218.940,31.680,22.440
Segformer,30.280,50.430,99.550,153.490,26.280,21.880
my_bisenetv1_r18,29.660,36.130,96.230,316.660,14.830,13.273
Unet,29.560,48.490,99.500,177.730,31.800,24.440
UnetPlusPlus,29.020,46.550,99.530,91.560,74.150,26.080
Linknet,27.440,45.720,99.520,202.960,20.300,21.770
my_bisenetv2,26.790,48.100,94.290,220.480,12.323,3.351
my_fast_scnn,24.240,38.810,94.450,318.670,0.936,1.400
DPT,12.740,27.010,99.490,30.930,212.980,137.810
my_en_bisenetv2,12.710,29.950,85.020,202.950,7.919,2.774
1 Model mIoU mAcc aAcc FPS FLOPs(G) Params(M)
2 MAnet 92.710 58.340 99.310 152.290 33.850 31.790
3 my_fastfcn_r50 37.810 51.890 96.480 71.040 130.000 66.346
4 my_icnet_r50 35.930 54.750 95.850 193.830 15.430 47.526
5 my_icnet_r18 34.840 47.480 95.570 286.580 9.362 24.873
6 my_bisenetv1_r50 33.600 48.180 95.950 88.460 98.957 56.865
7 PAN 32.230 52.260 99.570 238.750 29.880 21.480
8 FPN 31.480 50.620 99.500 208.170 27.550 23.160
9 UPerNet 30.810 56.260 99.540 108.960 72.100 29.600
10 PSPNet 30.460 48.110 99.580 586.440 9.640 21.490
11 DeepLabV3 30.440 45.870 99.550 96.470 109.330 26.010
12 DeepLabV3Plus 30.390 53.560 99.520 218.940 31.680 22.440
13 Segformer 30.280 50.430 99.550 153.490 26.280 21.880
14 my_bisenetv1_r18 29.660 36.130 96.230 316.660 14.830 13.273
15 Unet 29.560 48.490 99.500 177.730 31.800 24.440
16 UnetPlusPlus 29.020 46.550 99.530 91.560 74.150 26.080
17 Linknet 27.440 45.720 99.520 202.960 20.300 21.770
18 my_bisenetv2 26.790 48.100 94.290 220.480 12.323 3.351
19 my_fast_scnn 24.240 38.810 94.450 318.670 0.936 1.400
20 DPT 12.740 27.010 99.490 30.930 212.980 137.810
21 my_en_bisenetv2 12.710 29.950 85.020 202.950 7.919 2.774

View File

@@ -0,0 +1,21 @@
Algorithm,mIoU,10_IoU,1_IoU,2_IoU,4_IoU,5_IoU,6_IoU,7_IoU,8_IoU,9_IoU,背景_IoU,3_IoU
MAnet,92.71,35.32,26.82,27.63,69.93,17.24,1.35,68.73,39.18,18.98,96.27,0.00
my_fastfcn_r50,37.81,38.43,20.81,34.67,71.51,20.69,0.00,63.65,45.78,23.86,96.53,
my_icnet_r50,35.93,37.37,25.03,26.86,63.67,21.21,1.01,64.49,41.90,17.85,95.88,
my_icnet_r18,34.84,37.50,29.12,26.10,67.33,17.58,0.17,64.91,38.94,5.98,95.65,
my_bisenetv1_r50,33.60,33.75,23.22,33.49,57.73,7.82,0.54,61.04,39.39,16.52,96.05,
PAN,32.23,24.44,27.98,17.51,55.65,14.65,1.28,52.96,33.81,14.59,95.93,0.00
FPN,31.48,26.46,27.38,16.61,54.84,13.75,0.10,52.75,30.84,14.13,95.44,0.00
UPerNet,30.81,25.82,23.85,18.94,48.88,19.74,3.77,55.93,28.24,13.85,95.60,0.00
PSPNet,30.46,22.59,24.88,13.77,52.96,10.33,2.15,57.02,32.44,11.20,95.99,0.00
DeepLabV3,30.44,20.32,23.59,9.83,53.98,12.70,0.00,54.00,33.47,12.89,95.82,0.00
DeepLabV3Plus,30.39,21.50,25.78,16.46,55.69,14.54,1.83,49.53,31.38,9.97,95.64,0.00
Segformer,30.28,29.15,22.06,13.37,52.94,11.50,0.52,58.03,40.69,16.09,95.76,0.00
my_bisenetv1_r18,29.66,31.09,13.25,24.55,57.82,7.06,0.00,41.10,38.13,16.98,96.25,
Unet,29.56,24.81,26.57,8.55,47.35,16.95,0.00,41.73,32.39,14.65,95.35,0.00
UnetPlusPlus,29.02,15.46,25.18,9.08,49.10,17.60,0.33,56.31,31.64,12.25,95.60,0.00
Linknet,27.44,20.81,22.23,14.15,48.85,15.84,0.01,54.10,28.56,10.97,95.59,0.00
my_bisenetv2,26.79,15.92,20.58,16.00,53.72,10.36,0.42,37.73,38.00,7.52,94.43,
my_fast_scnn,24.24,0.45,20.35,10.26,57.55,0.00,0.47,33.80,24.96,0.00,94.61,
DPT,12.74,0.00,2.63,0.00,30.18,0.00,0.22,13.80,12.37,0.00,95.15,0.00
my_en_bisenetv2,12.71,0.00,12.19,0.00,35.31,0.00,0.14,1.78,4.69,0.00,85.72,
1 Algorithm mIoU 10_IoU 1_IoU 2_IoU 4_IoU 5_IoU 6_IoU 7_IoU 8_IoU 9_IoU 背景_IoU 3_IoU
2 MAnet 92.71 35.32 26.82 27.63 69.93 17.24 1.35 68.73 39.18 18.98 96.27 0.00
3 my_fastfcn_r50 37.81 38.43 20.81 34.67 71.51 20.69 0.00 63.65 45.78 23.86 96.53
4 my_icnet_r50 35.93 37.37 25.03 26.86 63.67 21.21 1.01 64.49 41.90 17.85 95.88
5 my_icnet_r18 34.84 37.50 29.12 26.10 67.33 17.58 0.17 64.91 38.94 5.98 95.65
6 my_bisenetv1_r50 33.60 33.75 23.22 33.49 57.73 7.82 0.54 61.04 39.39 16.52 96.05
7 PAN 32.23 24.44 27.98 17.51 55.65 14.65 1.28 52.96 33.81 14.59 95.93 0.00
8 FPN 31.48 26.46 27.38 16.61 54.84 13.75 0.10 52.75 30.84 14.13 95.44 0.00
9 UPerNet 30.81 25.82 23.85 18.94 48.88 19.74 3.77 55.93 28.24 13.85 95.60 0.00
10 PSPNet 30.46 22.59 24.88 13.77 52.96 10.33 2.15 57.02 32.44 11.20 95.99 0.00
11 DeepLabV3 30.44 20.32 23.59 9.83 53.98 12.70 0.00 54.00 33.47 12.89 95.82 0.00
12 DeepLabV3Plus 30.39 21.50 25.78 16.46 55.69 14.54 1.83 49.53 31.38 9.97 95.64 0.00
13 Segformer 30.28 29.15 22.06 13.37 52.94 11.50 0.52 58.03 40.69 16.09 95.76 0.00
14 my_bisenetv1_r18 29.66 31.09 13.25 24.55 57.82 7.06 0.00 41.10 38.13 16.98 96.25
15 Unet 29.56 24.81 26.57 8.55 47.35 16.95 0.00 41.73 32.39 14.65 95.35 0.00
16 UnetPlusPlus 29.02 15.46 25.18 9.08 49.10 17.60 0.33 56.31 31.64 12.25 95.60 0.00
17 Linknet 27.44 20.81 22.23 14.15 48.85 15.84 0.01 54.10 28.56 10.97 95.59 0.00
18 my_bisenetv2 26.79 15.92 20.58 16.00 53.72 10.36 0.42 37.73 38.00 7.52 94.43
19 my_fast_scnn 24.24 0.45 20.35 10.26 57.55 0.00 0.47 33.80 24.96 0.00 94.61
20 DPT 12.74 0.00 2.63 0.00 30.18 0.00 0.22 13.80 12.37 0.00 95.15 0.00
21 my_en_bisenetv2 12.71 0.00 12.19 0.00 35.31 0.00 0.14 1.78 4.69 0.00 85.72

View File

@@ -0,0 +1,8 @@
cff-version: 1.2.0
message: "If you use this software, please cite it as below."
authors:
- name: "MMSegmentation Contributors"
title: "OpenMMLab Semantic Segmentation Toolbox and Benchmark"
date-released: 2020-07-10
url: "https://github.com/open-mmlab/mmsegmentation"
license: Apache-2.0

View File

@@ -0,0 +1,203 @@
Copyright 2020 The MMSegmentation Authors. All rights reserved.
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright 2020 The MMSegmentation Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

View File

@@ -0,0 +1,5 @@
include requirements/*.txt
include mmseg/.mim/model-index.yml
include mmseg/utils/bpe_simple_vocab_16e6.txt.gz
recursive-include mmseg/.mim/configs *.py *.yaml
recursive-include mmseg/.mim/tools *.py *.sh

View File

@@ -0,0 +1,297 @@
import os, requests, hashlib
from tqdm import tqdm
# Link source:
# https://github.com/open-mmlab/mmcv/blob/master/mmcv/model_zoo/[deprecated.json | mmcls.json | open_mmlab.json | torchvision_0.12.json]
#
# Entries copied from open_mmlab.json: maps a model key to the URL of its
# pretrained checkpoint (.pth) on download.openmmlab.com. Insertion order is
# kept as in the original table, since dict iteration follows it.
open_mmlab_model_urls = {
    # Checkpoints under pretrain/third_party/.
    "vgg16_caffe": "https://download.openmmlab.com/pretrain/third_party/vgg16_caffe-292e1171.pth",
    "detectron/resnet50_caffe": "https://download.openmmlab.com/pretrain/third_party/resnet50_caffe-788b5fa3.pth",
    "detectron2/resnet50_caffe": "https://download.openmmlab.com/pretrain/third_party/resnet50_msra-5891d200.pth",
    "detectron/resnet101_caffe": "https://download.openmmlab.com/pretrain/third_party/resnet101_caffe-3ad79236.pth",
    "detectron2/resnet101_caffe": "https://download.openmmlab.com/pretrain/third_party/resnet101_msra-6cc46731.pth",
    "detectron2/resnext101_32x8d": "https://download.openmmlab.com/pretrain/third_party/resnext101_32x8d-1516f1aa.pth",
    "resnext50_32x4d": "https://download.openmmlab.com/pretrain/third_party/resnext50-32x4d-0ab1a123.pth",
    "resnext101_32x4d": "https://download.openmmlab.com/pretrain/third_party/resnext101_32x4d-a5af3160.pth",
    "resnext101_64x4d": "https://download.openmmlab.com/pretrain/third_party/resnext101_64x4d-ee2c6f71.pth",
    "contrib/resnet50_gn": "https://download.openmmlab.com/pretrain/third_party/resnet50_gn_thangvubk-ad1730dd.pth",
    "detectron/resnet50_gn": "https://download.openmmlab.com/pretrain/third_party/resnet50_gn-9186a21c.pth",
    "detectron/resnet101_gn": "https://download.openmmlab.com/pretrain/third_party/resnet101_gn-cac0ab98.pth",
    "jhu/resnet50_gn_ws": "https://download.openmmlab.com/pretrain/third_party/resnet50_gn_ws-15beedd8.pth",
    "jhu/resnet101_gn_ws": "https://download.openmmlab.com/pretrain/third_party/resnet101_gn_ws-3e3c308c.pth",
    "jhu/resnext50_32x4d_gn_ws": "https://download.openmmlab.com/pretrain/third_party/resnext50_32x4d_gn_ws-0d87ac85.pth",
    "jhu/resnext101_32x4d_gn_ws": "https://download.openmmlab.com/pretrain/third_party/resnext101_32x4d_gn_ws-34ac1a9e.pth",
    "jhu/resnext50_32x4d_gn": "https://download.openmmlab.com/pretrain/third_party/resnext50_32x4d_gn-c7e8b754.pth",
    "jhu/resnext101_32x4d_gn": "https://download.openmmlab.com/pretrain/third_party/resnext101_32x4d_gn-ac3bb84e.pth",
    "msra/hrnetv2_w18_small": "https://download.openmmlab.com/pretrain/third_party/hrnetv2_w18_small-b5a04e21.pth",
    "msra/hrnetv2_w18": "https://download.openmmlab.com/pretrain/third_party/hrnetv2_w18-00eb2006.pth",
    "msra/hrnetv2_w32": "https://download.openmmlab.com/pretrain/third_party/hrnetv2_w32-dc9eeb4f.pth",
    "msra/hrnetv2_w40": "https://download.openmmlab.com/pretrain/third_party/hrnetv2_w40-ed0b031c.pth",
    "msra/hrnetv2_w48": "https://download.openmmlab.com/pretrain/third_party/hrnetv2_w48-d2186c55.pth",
    "bninception_caffe": "https://download.openmmlab.com/pretrain/third_party/bn_inception_caffe-ed2e8665.pth",
    "kin400/i3d_r50_f32s2_k400": "https://download.openmmlab.com/pretrain/third_party/i3d_r50_f32s2_k400-2c57e077.pth",
    "kin400/nl3d_r50_f32s2_k400": "https://download.openmmlab.com/pretrain/third_party/nl3d_r50_f32s2_k400-fa7e7caa.pth",
    "res2net101_v1d_26w_4s": "https://download.openmmlab.com/pretrain/third_party/res2net101_v1d_26w_4s_mmdetv2-f0a600f9.pth",
    "regnetx_400mf": "https://download.openmmlab.com/pretrain/third_party/regnetx_400mf-a5b10d96.pth",
    "regnetx_800mf": "https://download.openmmlab.com/pretrain/third_party/regnetx_800mf-1f4be4c7.pth",
    "regnetx_1.6gf": "https://download.openmmlab.com/pretrain/third_party/regnetx_1.6gf-5791c176.pth",
    "regnetx_3.2gf": "https://download.openmmlab.com/pretrain/third_party/regnetx_3.2gf-c2599b0f.pth",
    "regnetx_4.0gf": "https://download.openmmlab.com/pretrain/third_party/regnetx_4.0gf-a88f671e.pth",
    "regnetx_6.4gf": "https://download.openmmlab.com/pretrain/third_party/regnetx_6.4gf-006af45d.pth",
    "regnetx_8.0gf": "https://download.openmmlab.com/pretrain/third_party/regnetx_8.0gf-3c68abe7.pth",
    "regnetx_12gf": "https://download.openmmlab.com/pretrain/third_party/regnetx_12gf-4c2a3350.pth",
    "resnet18_v1c": "https://download.openmmlab.com/pretrain/third_party/resnet18_v1c-b5776b93.pth",
    "resnet50_v1c": "https://download.openmmlab.com/pretrain/third_party/resnet50_v1c-2cccc1ad.pth",
    "resnet101_v1c": "https://download.openmmlab.com/pretrain/third_party/resnet101_v1c-e67eebb6.pth",
    # Checkpoints under mmediting/third_party/ (no hash suffix in the file names).
    "mmedit/vgg16": "https://download.openmmlab.com/mmediting/third_party/vgg_state_dict.pth",
    "mmedit/res34_en_nomixup": "https://download.openmmlab.com/mmediting/third_party/model_best_resnet34_En_nomixup.pth",
    "mmedit/mobilenet_v2": "https://download.openmmlab.com/mmediting/third_party/mobilenet_v2.pth",
    # Further checkpoints under pretrain/third_party/.
    "contrib/mobilenet_v3_large": "https://download.openmmlab.com/pretrain/third_party/mobilenet_v3_large-bc2c3fd3.pth",
    "contrib/mobilenet_v3_small": "https://download.openmmlab.com/pretrain/third_party/mobilenet_v3_small-47085aa1.pth",
    "resnest50": "https://download.openmmlab.com/pretrain/third_party/resnest50_d2-7497a55b.pth",
    "resnest101": "https://download.openmmlab.com/pretrain/third_party/resnest101_d2-f3b931b2.pth",
    "resnest200": "https://download.openmmlab.com/pretrain/third_party/resnest200_d2-ca88e41f.pth",
    "darknet53": "https://download.openmmlab.com/pretrain/third_party/darknet53-a628ea1b.pth",
    # Checkpoint under mmdetection/v2.0/third_party/.
    "mmdet/mobilenet_v2": "https://download.openmmlab.com/mmdetection/v2.0/third_party/mobilenet_v2_batch256_imagenet-ff34753d.pth",
    # Checkpoints under mmsegmentation/v0.5/ (pidnet, ddrnet and stdc paths).
    "pidnet-s": "https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/pidnet/pidnet-s_imagenet1k_20230306-715e6273.pth",
    "pidnet-m": "https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/pidnet/pidnet-m_imagenet1k_20230306-39893c52.pth",
    "pidnet-l": "https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/pidnet/pidnet-l_imagenet1k_20230306-67889109.pth",
    "ddrnet23-s": "https://download.openmmlab.com/mmsegmentation/v0.5/ddrnet/pretrain/ddrnet23s-in1kpre_3rdparty-1ccac5b1.pth",
    "ddrnet23": "https://download.openmmlab.com/mmsegmentation/v0.5/ddrnet/pretrain/ddrnet23-in1kpre_3rdparty-9ca29f62.pth",
    "stdc1": "https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/stdc/stdc1_20220308-5368626c.pth",
    "stdc2": "https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/stdc/stdc2_20220308-7dbd9127.pth",
}
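# Entries from deprecated.json, kept commented out for reference only.
# NOTE(review): the values here are other model keys, not download URLs —
# presumably why this table is disabled; confirm before re-enabling.
# deprecated_model_urls = {
# "resnet50_caffe": "detectron/resnet50_caffe",
# "resnet50_caffe_bgr": "detectron2/resnet50_caffe_bgr",
# "resnet101_caffe": "detectron/resnet101_caffe",
# "resnet101_caffe_bgr": "detectron2/resnet101_caffe_bgr"
# }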
mmcls_model_urls = {
"vgg11": "https://download.openmmlab.com/mmclassification/v0/vgg/vgg11_batch256_imagenet_20210208-4271cd6c.pth",
"vgg13": "https://download.openmmlab.com/mmclassification/v0/vgg/vgg13_batch256_imagenet_20210208-4d1d6080.pth",
"vgg16": "https://download.openmmlab.com/mmclassification/v0/vgg/vgg16_batch256_imagenet_20210208-db26f1a5.pth",
"vgg19": "https://download.openmmlab.com/mmclassification/v0/vgg/vgg19_batch256_imagenet_20210208-e6920e4a.pth",
"vgg11_bn": "https://download.openmmlab.com/mmclassification/v0/vgg/vgg11_bn_batch256_imagenet_20210207-f244902c.pth",
"vgg13_bn": "https://download.openmmlab.com/mmclassification/v0/vgg/vgg13_bn_batch256_imagenet_20210207-1a8b7864.pth",
"vgg16_bn": "https://download.openmmlab.com/mmclassification/v0/vgg/vgg16_bn_batch256_imagenet_20210208-7e55cd29.pth",
"vgg19_bn": "https://download.openmmlab.com/mmclassification/v0/vgg/vgg19_bn_batch256_imagenet_20210208-da620c4f.pth",
"resnet18": "https://download.openmmlab.com/mmclassification/v0/resnet/resnet18_8xb32_in1k_20210831-fbbb1da6.pth",
"resnet34": "https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_8xb32_in1k_20210831-f257d4e6.pth",
"resnet50": "https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb32_in1k_20210831-ea4938fc.pth",
"resnet101": "https://download.openmmlab.com/mmclassification/v0/resnet/resnet101_8xb32_in1k_20210831-539c63f8.pth",
"resnet152": "https://download.openmmlab.com/mmclassification/v0/resnet/resnet152_8xb32_in1k_20210901-4d7582fa.pth",
"resnet50_v1d": "https://download.openmmlab.com/mmclassification/v0/resnet/resnetv1d50_b32x8_imagenet_20210531-db14775a.pth",
"resnet101_v1d": "https://download.openmmlab.com/mmclassification/v0/resnet/resnetv1d101_b32x8_imagenet_20210531-6e13bcd3.pth",
"resnet152_v1d": "https://download.openmmlab.com/mmclassification/v0/resnet/resnetv1d152_b32x8_imagenet_20210531-278cf22a.pth",
"resnext50_32x4d": "https://download.openmmlab.com/mmclassification/v0/resnext/resnext50_32x4d_b32x8_imagenet_20210429-56066e27.pth",
"resnext101_32x4d": "https://download.openmmlab.com/mmclassification/v0/resnext/resnext101_32x4d_b32x8_imagenet_20210506-e0fa3dd5.pth",
"resnext101_32x8d": "https://download.openmmlab.com/mmclassification/v0/resnext/resnext101_32x8d_b32x8_imagenet_20210506-23a247d5.pth",
"resnext152_32x4d": "https://download.openmmlab.com/mmclassification/v0/resnext/resnext152_32x4d_b32x8_imagenet_20210524-927787be.pth",
"se-resnet50": "https://download.openmmlab.com/mmclassification/v0/se-resnet/se-resnet50_batch256_imagenet_20200804-ae206104.pth",
"se-resnet101": "https://download.openmmlab.com/mmclassification/v0/se-resnet/se-resnet101_batch256_imagenet_20200804-ba5b51d4.pth",
"resnest50": "https://download.openmmlab.com/mmclassification/v0/resnest/resnest50_imagenet_converted-1ebf0afe.pth",
"resnest101": "https://download.openmmlab.com/mmclassification/v0/resnest/resnest101_imagenet_converted-032caa52.pth",
"resnest200": "https://download.openmmlab.com/mmclassification/v0/resnest/resnest200_imagenet_converted-581a60f2.pth",
"resnest269": "https://download.openmmlab.com/mmclassification/v0/resnest/resnest269_imagenet_converted-59930960.pth",
"shufflenet_v1": "https://download.openmmlab.com/mmclassification/v0/shufflenet_v1/shufflenet_v1_batch1024_imagenet_20200804-5d6cec73.pth",
"shufflenet_v2": "https://download.openmmlab.com/mmclassification/v0/shufflenet_v2/shufflenet_v2_batch1024_imagenet_20200812-5bf4721e.pth",
"mobilenet_v2": "https://download.openmmlab.com/mmclassification/v0/mobilenet_v2/mobilenet_v2_batch256_imagenet_20200708-3b2dc3af.pth",
"mobilenet_v3_small": "https://download.openmmlab.com/mmclassification/v0/mobilenet_v3/convert/mobilenet_v3_small-8427ecf0.pth",
"mobilenet_v3_large": "https://download.openmmlab.com/mmclassification/v0/mobilenet_v3/convert/mobilenet_v3_large-3ea3c186.pth",
"repvgg_A0": "https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-A0_3rdparty_4xb64-coslr-120e_in1k_20210909-883ab98c.pth",
"repvgg_A1": "https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-A1_3rdparty_4xb64-coslr-120e_in1k_20210909-24003a24.pth",
"repvgg_A2": "https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-A2_3rdparty_4xb64-coslr-120e_in1k_20210909-97d7695a.pth",
"repvgg_B0": "https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-B0_3rdparty_4xb64-coslr-120e_in1k_20210909-446375f4.pth",
"repvgg_B1": "https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-B1_3rdparty_4xb64-coslr-120e_in1k_20210909-750cdf67.pth",
"repvgg_B1g2": "https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-B1g2_3rdparty_4xb64-coslr-120e_in1k_20210909-344f6422.pth",
"repvgg_B1g4": "https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-B1g4_3rdparty_4xb64-coslr-120e_in1k_20210909-d4c1a642.pth",
"repvgg_B2": "https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-B2_3rdparty_4xb64-coslr-120e_in1k_20210909-bd6b937c.pth",
"repvgg_B2g4": "https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-B2g4_3rdparty_4xb64-autoaug-lbs-mixup-coslr-200e_in1k_20210909-7b7955f0.pth",
"repvgg_B3": "https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-B3_3rdparty_4xb64-autoaug-lbs-mixup-coslr-200e_in1k_20210909-dda968bf.pth",
"repvgg_B3g4": "https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-B3g4_3rdparty_4xb64-autoaug-lbs-mixup-coslr-200e_in1k_20210909-4e54846a.pth",
"repvgg_D2se": "https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-D2se_3rdparty_4xb64-autoaug-lbs-mixup-coslr-200e_in1k_20210909-cf3139b7.pth",
"res2net101_w26": "https://download.openmmlab.com/mmclassification/v0/res2net/res2net101-w26-s4_3rdparty_8xb32_in1k_20210927-870b6c36.pth",
"res2net50_w14": "https://download.openmmlab.com/mmclassification/v0/res2net/res2net50-w14-s8_3rdparty_8xb32_in1k_20210927-bc967bf1.pth",
"res2net50_w26": "https://download.openmmlab.com/mmclassification/v0/res2net/res2net50-w26-s8_3rdparty_8xb32_in1k_20210927-f547a94b.pth",
"swin_tiny": "https://download.openmmlab.com/mmclassification/v0/swin-transformer/swin_tiny_224_b16x64_300e_imagenet_20210616_090925-66df6be6.pth",
"swin_small": "https://download.openmmlab.com/mmclassification/v0/swin-transformer/swin_small_224_b16x64_300e_imagenet_20210615_110219-7f9d988b.pth",
"swin_base": "https://download.openmmlab.com/mmclassification/v0/swin-transformer/convert/swin_base_patch4_window7_224_22kto1k-f967f799.pth",
"swin_large": "https://download.openmmlab.com/mmclassification/v0/swin-transformer/convert/swin_large_patch4_window7_224_22kto1k-5f0996db.pth",
"t2t_vit_t_14": "https://download.openmmlab.com/mmclassification/v0/t2t-vit/t2t-vit-t-14_3rdparty_8xb64_in1k_20210928-b7c09b62.pth",
"t2t_vit_t_19": "https://download.openmmlab.com/mmclassification/v0/t2t-vit/t2t-vit-t-19_3rdparty_8xb64_in1k_20210928-7f1478d5.pth",
"t2t_vit_t_24": "https://download.openmmlab.com/mmclassification/v0/t2t-vit/t2t-vit-t-24_3rdparty_8xb64_in1k_20210928-fe95a61b.pth",
"tnt_small": "https://download.openmmlab.com/mmclassification/v0/tnt/tnt-small-p16_3rdparty_in1k_20210903-c56ee7df.pth",
"vit_base_p16": "https://download.openmmlab.com/mmclassification/v0/vit/finetune/vit-base-p16_in21k-pre-3rdparty_ft-64xb64_in1k-384_20210928-98e8652b.pth",
"vit_base_p32": "https://download.openmmlab.com/mmclassification/v0/vit/finetune/vit-base-p32_in21k-pre-3rdparty_ft-64xb64_in1k-384_20210928-9cea8599.pth",
"vit_large_p16": "https://download.openmmlab.com/mmclassification/v0/vit/finetune/vit-large-p16_in21k-pre-3rdparty_ft-64xb64_in1k-384_20210928-b20ba619.pth"
}
# Checkpoint download URLs keyed by model name; every URL points at download.pytorch.org.
# NOTE(review): the variable name suggests the torchvision 0.12 model zoo — confirm.
# A value of None marks a model that has no URL here; download_all_models() skips such entries.
torchvision_012_model_urls = {
"alexnet": "https://download.pytorch.org/models/alexnet-owt-7be5be79.pth",
"densenet121": "https://download.pytorch.org/models/densenet121-a639ec97.pth",
"densenet169": "https://download.pytorch.org/models/densenet169-b2777c0a.pth",
"densenet201": "https://download.pytorch.org/models/densenet201-c1103571.pth",
"densenet161": "https://download.pytorch.org/models/densenet161-8d451a50.pth",
"efficientnet_b0": "https://download.pytorch.org/models/efficientnet_b0_rwightman-3dd342df.pth",
"efficientnet_b1": "https://download.pytorch.org/models/efficientnet_b1_rwightman-533bc792.pth",
"efficientnet_b2": "https://download.pytorch.org/models/efficientnet_b2_rwightman-bcdf34b7.pth",
"efficientnet_b3": "https://download.pytorch.org/models/efficientnet_b3_rwightman-cf984f9c.pth",
"efficientnet_b4": "https://download.pytorch.org/models/efficientnet_b4_rwightman-7eb33cd5.pth",
"efficientnet_b5": "https://download.pytorch.org/models/efficientnet_b5_lukemelas-b6417697.pth",
"efficientnet_b6": "https://download.pytorch.org/models/efficientnet_b6_lukemelas-c76e70fd.pth",
"efficientnet_b7": "https://download.pytorch.org/models/efficientnet_b7_lukemelas-dcc49843.pth",
"googlenet": "https://download.pytorch.org/models/googlenet-1378be20.pth",
"inception_v3_google": "https://download.pytorch.org/models/inception_v3_google-0cc3c7bd.pth",
"mobilenet_v2": "https://download.pytorch.org/models/mobilenet_v2-b0353104.pth",
"mobilenet_v3_large": "https://download.pytorch.org/models/mobilenet_v3_large-8738ca79.pth",
"mobilenet_v3_small": "https://download.pytorch.org/models/mobilenet_v3_small-047dcff4.pth",
"regnet_y_400mf": "https://download.pytorch.org/models/regnet_y_400mf-c65dace8.pth",
"regnet_y_800mf": "https://download.pytorch.org/models/regnet_y_800mf-1b27b58c.pth",
"regnet_y_1_6gf": "https://download.pytorch.org/models/regnet_y_1_6gf-b11a554e.pth",
"regnet_y_3_2gf": "https://download.pytorch.org/models/regnet_y_3_2gf-b5a9779c.pth",
"regnet_y_8gf": "https://download.pytorch.org/models/regnet_y_8gf-d0d0e4a8.pth",
"regnet_y_16gf": "https://download.pytorch.org/models/regnet_y_16gf-9e6ed7dd.pth",
"regnet_y_32gf": "https://download.pytorch.org/models/regnet_y_32gf-4dee3f7a.pth",
"regnet_x_400mf": "https://download.pytorch.org/models/regnet_x_400mf-adf1edd5.pth",
"regnet_x_800mf": "https://download.pytorch.org/models/regnet_x_800mf-ad17e45c.pth",
"regnet_x_1_6gf": "https://download.pytorch.org/models/regnet_x_1_6gf-e3633e7f.pth",
"regnet_x_3_2gf": "https://download.pytorch.org/models/regnet_x_3_2gf-f342aeae.pth",
"regnet_x_8gf": "https://download.pytorch.org/models/regnet_x_8gf-03ceed89.pth",
"regnet_x_16gf": "https://download.pytorch.org/models/regnet_x_16gf-2007eb11.pth",
"regnet_x_32gf": "https://download.pytorch.org/models/regnet_x_32gf-9d47f8d0.pth",
"resnet18": "https://download.pytorch.org/models/resnet18-f37072fd.pth",
"resnet34": "https://download.pytorch.org/models/resnet34-b627a593.pth",
"resnet50": "https://download.pytorch.org/models/resnet50-0676ba61.pth",
"resnet101": "https://download.pytorch.org/models/resnet101-63fe2227.pth",
"resnet152": "https://download.pytorch.org/models/resnet152-394f9c45.pth",
"resnext50_32x4d": "https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth",
"resnext101_32x8d": "https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth",
"wide_resnet50_2": "https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth",
"wide_resnet101_2": "https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth",
"shufflenetv2_x0.5": "https://download.pytorch.org/models/shufflenetv2_x0.5-f707e7126e.pth",
"shufflenetv2_x1.0": "https://download.pytorch.org/models/shufflenetv2_x1-5666bf0f80.pth",
"shufflenetv2_x1.5": None,
"shufflenetv2_x2.0": None,
"squeezenet1_0": "https://download.pytorch.org/models/squeezenet1_0-b66bff10.pth",
"squeezenet1_1": "https://download.pytorch.org/models/squeezenet1_1-b8a52dc0.pth",
"vgg11": "https://download.pytorch.org/models/vgg11-8a719046.pth",
"vgg13": "https://download.pytorch.org/models/vgg13-19584684.pth",
"vgg16": "https://download.pytorch.org/models/vgg16-397923af.pth",
"vgg19": "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth",
"vgg11_bn": "https://download.pytorch.org/models/vgg11_bn-6002323d.pth",
"vgg13_bn": "https://download.pytorch.org/models/vgg13_bn-abd245e5.pth",
"vgg16_bn": "https://download.pytorch.org/models/vgg16_bn-6c64b313.pth",
"vgg19_bn": "https://download.pytorch.org/models/vgg19_bn-c79401a0.pth"
}
def calculate_file_hash(file_path, hash_algorithm='md5'):
    """Return the hexadecimal digest of a file's contents.

    :param file_path: path of the file to hash; it is opened in binary mode.
    :param hash_algorithm: algorithm name passed to ``hashlib.new`` (MD5 by default).
    :return: the digest as a hexadecimal string.
    """
    digest = hashlib.new(hash_algorithm)
    with open(file_path, 'rb') as stream:
        # Feed the file through the hash in 8192-byte pieces; the sentinel b''
        # (what read() returns at end of file) stops the iteration.
        for block in iter(lambda: stream.read(8192), b''):
            digest.update(block)
    return digest.hexdigest()
def download_file(url, output_path, timeout=30):
    """Stream ``url`` into ``output_path`` while showing a tqdm progress bar.

    :param url: address of the file to fetch.
    :param output_path: destination file path; also used as the progress-bar label.
    :param timeout: value handed to ``requests.get(timeout=...)`` so a stalled
        server cannot block the whole batch forever.
    :return: None. Success or failure is reported on stdout.

    The payload is written to ``output_path + ".part"`` first and only moved
    onto ``output_path`` once the transfer finished, so an interrupted download
    never leaves a truncated file under the final name.
    """
    partial_path = output_path + ".part"
    # The context manager releases the connection on every exit path.
    with requests.get(url, stream=True, timeout=timeout) as response:
        if response.status_code != 200:
            print(f"Failed to download {url}")
            return
        # Total size drives the progress bar; 0 when the server does not report it.
        total_size = int(response.headers.get('Content-Length', 0))
        try:
            with tqdm(total=total_size, unit='B', unit_scale=True, desc=output_path, ncols=100) as pbar:
                with open(partial_path, 'wb') as f:
                    for chunk in response.iter_content(chunk_size=1024):
                        if chunk:
                            f.write(chunk)
                            pbar.update(len(chunk))  # advance the progress bar
        except BaseException:
            # Drop the incomplete file (also on Ctrl-C), then let the error propagate.
            if os.path.exists(partial_path):
                os.remove(partial_path)
            raise
        os.replace(partial_path, output_path)
    print(f"Downloaded {output_path}")
def file_exists_and_same(url, output_path, timeout=30):
    """Tell whether ``output_path`` already holds a same-sized copy of ``url``.

    :param url: remote file address, queried with an HTTP HEAD request.
    :param output_path: local file path to compare against.
    :param timeout: value handed to ``requests.head(timeout=...)``.
    :return: True only when the local file exists and its size equals the
        ``Content-Length`` reported by the server; False otherwise.

    Only sizes are compared, not contents.
    """
    if not os.path.exists(output_path):
        return False
    # requests.head() does not follow redirects unless asked to; follow them so
    # the size belongs to the real file rather than to a redirect response.
    response = requests.head(url, allow_redirects=True, timeout=timeout)
    if not response.ok:
        return False
    # A missing or malformed Content-Length means the size is unknown. Treat it
    # as "not the same" instead of 0, otherwise an empty local file would be
    # mistaken for a complete download.
    try:
        remote_file_size = int(response.headers.get('Content-Length'))
    except (TypeError, ValueError):
        return False
    local_file_size = os.path.getsize(output_path)
    return local_file_size == remote_file_size
    # Disabled alternative kept from the original: compare hashes when sizes match.
    # remote_hash = requests.get(url + ".md5").text.strip() if url.endswith(".pth") else None
    # local_hash = calculate_file_hash(output_path) if remote_hash else None
    # print(remote_hash, local_hash)
    # # True when both hashes are equal
    # return local_hash == remote_hash if remote_hash else False
def download_all_models(model_urls, output_dir):
    """Download every checkpoint listed in ``model_urls`` into ``output_dir``.

    :param model_urls: mapping of model name -> download URL. Entries whose URL
        is None or an empty string are reported and skipped.
    :param output_dir: folder that receives the files; each one is stored as
        ``<model name>.pth``.
    :return: None.

    A file whose local size already matches the remote size is not fetched again.
    """
    for model_name, url in model_urls.items():
        # Skip entries without a usable URL (None as well as ""), since an
        # empty URL cannot be requested.
        if not url:
            print(" ", end='')
            print(f"\033[91m{model_name}后URL为空跳过下载\033[0m")
            continue
        # The file is named after the dictionary key.
        file_name = f"{model_name}.pth"
        output_path = os.path.join(output_dir, file_name)
        # Make sure the destination folder exists; dirname is '' when
        # output_dir is empty, and os.makedirs('') would raise.
        output_folder = os.path.dirname(output_path)
        if output_folder:
            os.makedirs(output_folder, exist_ok=True)
        if file_exists_and_same(url, output_path):
            print(" ", end='')
            print(f"\033[93mFile {output_path} already exists and the size is same, skipping download.\033[0m")
        else:
            print(" ", end='')
            print(f"正在下载 {model_name} : {output_path} 中...")
            download_file(url, output_path)
if __name__ == '__main__':
    # Each entry: (URL table, destination folder, optional banner printed before downloading).
    # The folder is created first, then the banner (if any) is shown, then the table is downloaded.
    download_plan = [
        # 1. open_mmlab checkpoints
        (open_mmlab_model_urls, './My_Local_Model/open_mmlab', f"\033[32m下载open_mmlab数据中...\033[0m"),
        # 2. deprecated checkpoints — disabled in the original script:
        # (deprecated_model_urls, './My_Local_Model/deprecated', None),
        # 3. mmcls checkpoints
        (mmcls_model_urls, './My_Local_Model/mmcls', None),
        # 4. torchvision_012 checkpoints
        (torchvision_012_model_urls, './My_Local_Model/torchvision_012', None),
    ]
    for url_table, target_dir, banner in download_plan:
        os.makedirs(target_dir, exist_ok=True)
        if banner is not None:
            print(banner)
        download_all_models(url_table, target_dir)

View File

@@ -0,0 +1,628 @@
{
"publicdataset_cholecseg8k": {
"train_imgs_num": 6464,
"classes": [
"背景",
"1",
"2",
"3",
"4",
"5",
"6",
"7",
"8",
"9",
"10",
"11",
"12"
],
"palette": [
[
0,
0,
0
],
[
255,
91,
0
],
[
255,
234,
0
],
[
85,
111,
181
],
[
181,
227,
14
],
[
72,
0,
255
],
[
0,
155,
33
],
[
255,
0,
255
],
[
29,
32,
136
],
[
160,
15,
95
],
[
0,
160,
233
],
[
52,
184,
178
],
[
90,
120,
41
]
],
"palette_num": 13,
"mean": [
85.65740418979115,
53.99282220050495,
46.074045888534535
],
"std": [
72.24589167201978,
56.76979155397199,
49.056637115061775
],
"imgs_num": 6464
},
"my_dataset_model": {
"train_imgs_num": 631,
"classes": [
"背景",
"肝脏",
"胆囊",
"分离钳",
"止血海绵",
"肝总管",
"胆总管",
"吸引器",
"剪刀",
"止血纱布",
"生物夹",
"无损伤钳",
"喷洒",
"胆囊管",
"胆囊动脉",
"电凝",
"标本袋",
"引流管",
"纱布",
"金属钛夹",
"术中超声",
"吻合器",
"乳胶管",
"推结器",
"肝带",
"钳夹",
"超声刀",
"脂肪",
"双极电凝",
"棉球",
"血管阻断夹",
"肿瘤",
"针",
"线",
"韧带",
"胆囊静脉"
],
"palette": [
[
0,
0,
0
],
[
255,
91,
0
],
[
255,
234,
0
],
[
85,
111,
181
],
[
181,
227,
14
],
[
72,
0,
255
],
[
0,
155,
33
],
[
255,
0,
255
],
[
29,
32,
136
],
[
160,
15,
95
],
[
0,
160,
233
],
[
52,
184,
178
],
[
90,
120,
41
],
[
255,
0,
0
],
[
177,
0,
0
],
[
167,
24,
233
],
[
112,
113,
150
],
[
0,
255,
0
],
[
255,
255,
255
],
[
0,
255,
255
],
[
138,
251,
213
],
[
136,
162,
196
],
[
197,
83,
181
],
[
202,
202,
200
],
[
113,
102,
140
],
[
66,
115,
82
],
[
240,
16,
116
],
[
155,
132,
0
],
[
155,
62,
0
],
[
146,
175,
236
],
[
255,
172,
159
],
[
245,
161,
0
],
[
134,
124,
118
],
[
0,
157,
142
],
[
181,
85,
105
],
[
42,
8,
66
]
],
"palette_num": 36,
"mean": [
94.94709810464319,
61.729422339499315,
75.93763705236911
],
"std": [
44.00550608113231,
42.695956669847746,
44.99354156225513
],
"imgs_num": 2000
},
"publicdataset_autolaparo": {
"train_imgs_num": 1440,
"classes": [
"背景",
"1",
"2",
"3",
"4",
"5",
"6",
"7",
"8",
"9"
],
"palette": [
[
0,
0,
0
],
[
255,
91,
0
],
[
255,
234,
0
],
[
85,
111,
181
],
[
181,
227,
14
],
[
72,
0,
255
],
[
0,
155,
33
],
[
255,
0,
255
],
[
29,
32,
136
],
[
160,
15,
95
]
],
"palette_num": 10,
"mean": [
123.62464353460942,
85.34836259209033,
82.31539425671558
],
"std": [
47.172211618459315,
47.08256715323592,
48.135121265163605
]
},
"publicdataset_endovis_2017": {
"train_imgs_num": 1800,
"classes": [
"背景",
"1",
"2",
"3",
"4",
"5",
"6",
"7"
],
"palette": [
[
0,
0,
0
],
[
255,
91,
0
],
[
255,
234,
0
],
[
85,
111,
181
],
[
181,
227,
14
],
[
72,
0,
255
],
[
0,
155,
33
],
[
255,
0,
255
]
],
"palette_num": 8,
"mean": [
122.21429912990676,
77.0821859677977,
87.03836664626716
],
"std": [
50.53335800365262,
42.895340354037465,
47.739426483390446
]
},
"publicdataset_dresden": {
"train_imgs_num": 17363,
"classes": [
"背景",
"1",
"2",
"3",
"4",
"5",
"6",
"7",
"8",
"9",
"10"
],
"palette": [
[
0,
0,
0
],
[
255,
91,
0
],
[
255,
234,
0
],
[
85,
111,
181
],
[
181,
227,
14
],
[
72,
0,
255
],
[
0,
155,
33
],
[
255,
0,
255
],
[
29,
32,
136
],
[
160,
15,
95
],
[
0,
160,
233
]
],
"palette_num": 11,
"mean": [
103.172638338208,
61.44762740851152,
51.407770213021976
],
"std": [
75.77031253622098,
54.63616729031377,
49.45572239497569
]
},
"publicdataset_endovis_2018": {
"train_imgs_num": 1800,
"classes": [
"背景",
"1",
"2",
"3",
"4",
"5",
"6",
"7"
],
"palette": [
[
0,
0,
0
],
[
255,
91,
0
],
[
255,
234,
0
],
[
85,
111,
181
],
[
181,
227,
14
],
[
72,
0,
255
],
[
0,
155,
33
],
[
255,
0,
255
]
],
"palette_num": 8,
"mean": [
122.21429912990676,
77.0821859677977,
87.03836664626716
],
"std": [
50.53335800365262,
42.895340354037465,
47.739426483390446
]
}
}

View File

@@ -0,0 +1,50 @@
{
"____一、comment_dataset_info": "定义多个数据集文件名和对应类名",
"dataset_info": {
"dataset_file_name": "my_dataset_model",
"dataset_class_name": "MyDataset_model",
"data_root": "/home/wkmgc/Desktop/Seg/Seg_All_In_One_MMSeg/My_Data",
"img_scale_width": 1920,
"img_scale_height": 1080,
"____#####comment_paths#####": "训练、验证、测试集所在文件夹",
"paths": {
"train_img_path": "A_Ori",
"train_seg_map_path": "A_Label_GT_label_fold",
"val_img_path": "A_Ori",
"val_seg_map_path": "A_Label_GT_label_fold",
"test_img_path": "A_Ori",
"test_seg_map_path": "A_Label_GT_label_fold"
}
},
"____二、comment_label_info": "定义Label图片相关参数",
"label_info": {
"classes": [
"背景", "肝脏", "胆囊", "分离钳", "止血海绵", "肝总管", "胆总管", "吸引器", "剪刀", "止血纱布", "生物夹", "无损伤钳", "喷洒",
"胆囊管", "胆囊动脉", "电凝", "标本袋", "引流管", "纱布", "金属钛夹", "术中超声", "吻合器", "乳胶管", "推结器",
"肝带", "钳夹", "超声刀", "脂肪", "双极电凝", "棉球", "血管阻断夹", "肿瘤", "针", "线", "韧带", "胆囊静脉"
],
"palette": [
[0, 0, 0], [255, 91, 0], [255, 234, 0], [85, 111, 181], [181, 227, 14], [72, 0, 255], [0, 155, 33], [255, 0, 255],
[29, 32, 136], [160, 15, 95], [0, 160, 233], [52, 184, 178], [90, 120, 41], [255, 0, 0], [177, 0, 0],
[167, 24, 233], [112, 113, 150], [0, 255, 0], [255, 255, 255], [0, 255, 255], [138, 251, 213], [136, 162, 196],
[197, 83, 181], [202, 202, 200], [113, 102, 140], [66, 115, 82], [240, 16, 116], [155, 132, 0], [155, 62, 0],
[146, 175, 236], [255, 172, 159], [245, 161, 0], [134, 124, 118], [0, 157, 142], [181, 85, 105], [42, 8, 66]
],
"____#####comment#####": "一般不太会变的参数",
"img_suffix": ".png",
"seg_map_suffix": "_gtFine_labelTrainIds.png",
"____#####comment_reduce_zero_label_1#####": "在第 0 类为无意义黑边时使用reduce_zero_label = True将其和待分类内容分开",
"____#####comment_reduce_zero_label_2#####": "在第 0 类为 background 类别的数据集上如果您最终是需要将背景和您的其余类别分开时是不需要使用reduce_zero_label的 【reduce_zero_label = False】",
"reduce_zero_label": false
},
"____三、comment_training_info": "定义训练相关参数",
"training_info": {
"crop_size_width": 256,
"crop_size_height": 256,
"train_batch_size": 16,
"train_num_workers": 4,
"val_and_test_batch_size": 1,
"val_and_test_num_workers": 4
}
}

View File

@@ -0,0 +1,42 @@
{
"____一、comment_dataset_info": "定义多个数据集文件名和对应类名",
"dataset_info": {
"dataset_file_name": "publicdataset_autolaparo",
"dataset_class_name": "PublicDataSet_AutoLaparo",
"data_root": "/home/wkmgc/Desktop/Seg/DataSet_Public/2_AutoLaparo-10Type-1920x1080",
"img_scale_width": 1920,
"img_scale_height": 1080,
"____#####comment_paths#####": "训练、验证、测试集所在文件夹",
"paths": {
"train_img_path": "images/train",
"train_seg_map_path": "labels_GT/train",
"val_img_path": "images/val",
"val_seg_map_path": "labels_GT/val",
"test_img_path": "images/val",
"test_seg_map_path": "labels_GT/val"
}
},
"____二、comment_label_info": "定义Label图片相关参数",
"label_info": {
"classes": ["背景", "1", "2", "3", "4", "5", "6", "7", "8", "9"],
"palette": [
[0, 0, 0], [255, 91, 0], [255, 234, 0], [85, 111, 181], [181, 227, 14], [72, 0, 255], [0, 155, 33], [255, 0, 255], [29, 32, 136], [160, 15, 95]
],
"____#####comment#####": "一般不太会变的参数",
"img_suffix": ".png",
"seg_map_suffix": ".png",
"____#####comment_reduce_zero_label_1#####": "在第 0 类为无意义黑边时使用reduce_zero_label = True将其和待分类内容分开",
"____#####comment_reduce_zero_label_2#####": "在第 0 类为 background 类别的数据集上如果您最终是需要将背景和您的其余类别分开时是不需要使用reduce_zero_label的 【reduce_zero_label = False】",
"reduce_zero_label": false
},
"____三、comment_training_info": "定义训练相关参数",
"training_info": {
"crop_size_width": 256,
"crop_size_height": 256,
"train_batch_size": 16,
"train_num_workers": 4,
"val_and_test_batch_size": 1,
"val_and_test_num_workers": 4
}
}

View File

@@ -0,0 +1,42 @@
{
"____一、comment_dataset_info": "定义多个数据集文件名和对应类名",
"dataset_info": {
"dataset_file_name": "publicdataset_cholecseg8k",
"dataset_class_name": "PublicDataSet_CholecSeg8k",
"data_root": "/home/wkmgc/Desktop/Seg/DataSet_Public/1_CholecSeg8k-13Type-1920x1080",
"img_scale_width": 1920,
"img_scale_height": 1080,
"____#####comment_paths#####": "训练、验证、测试集所在文件夹",
"paths": {
"train_img_path": "images/train",
"train_seg_map_path": "labels_GT/train",
"val_img_path": "images/val",
"val_seg_map_path": "labels_GT/val",
"test_img_path": "images/val",
"test_seg_map_path": "labels_GT/val"
}
},
"____二、comment_label_info": "定义Label图片相关参数",
"label_info": {
"classes": ["背景", "1", "2", "3", "4", "5", "6", "7", "8", "9", "10", "11", "12"],
"palette": [
[0, 0, 0], [255, 91, 0], [255, 234, 0], [85, 111, 181], [181, 227, 14], [72, 0, 255], [0, 155, 33], [255, 0, 255], [29, 32, 136], [160, 15, 95], [0, 160, 233], [52, 184, 178], [90, 120, 41]
],
"____#####comment#####": "一般不太会变的参数",
"img_suffix": ".png",
"seg_map_suffix": ".png",
"____#####comment_reduce_zero_label_1#####": "在第 0 类为无意义黑边时使用reduce_zero_label = True将其和待分类内容分开",
"____#####comment_reduce_zero_label_2#####": "在第 0 类为 background 类别的数据集上如果您最终是需要将背景和您的其余类别分开时是不需要使用reduce_zero_label的 【reduce_zero_label = False】",
"reduce_zero_label": false
},
"____三、comment_training_info": "定义训练相关参数",
"training_info": {
"crop_size_width": 256,
"crop_size_height": 256,
"train_batch_size": 16,
"train_num_workers": 4,
"val_and_test_batch_size": 1,
"val_and_test_num_workers": 4
}
}

View File

@@ -0,0 +1,42 @@
{
"____一、comment_dataset_info": "定义多个数据集文件名和对应类名",
"dataset_info": {
"dataset_file_name": "publicdataset_dresden",
"dataset_class_name": "PublicDataSet_Dresden",
"data_root": "/home/wkmgc/Desktop/Seg/DataSet_Public/4_Dresden-11Type-512x512",
"img_scale_width": 512,
"img_scale_height": 512,
"____#####comment_paths#####": "训练、验证、测试集所在文件夹",
"paths": {
"train_img_path": "images/train",
"train_seg_map_path": "labels_GT/train",
"val_img_path": "images/val",
"val_seg_map_path": "labels_GT/val",
"test_img_path": "images/test",
"test_seg_map_path": "labels_GT/test"
}
},
"____二、comment_label_info": "定义Label图片相关参数",
"label_info": {
"classes": ["背景", "1", "2", "3", "4", "5", "6", "7", "8", "9", "10"],
"palette": [
[0, 0, 0], [255, 91, 0], [255, 234, 0], [85, 111, 181], [181, 227, 14], [72, 0, 255], [0, 155, 33], [255, 0, 255], [29, 32, 136], [160, 15, 95], [0, 160, 233]
],
"____#####comment#####": "一般不太会变的参数",
"img_suffix": ".png",
"seg_map_suffix": ".png",
"____#####comment_reduce_zero_label_1#####": "在第 0 类为无意义黑边时使用reduce_zero_label = True将其和待分类内容分开",
"____#####comment_reduce_zero_label_2#####": "在第 0 类为 background 类别的数据集上如果您最终是需要将背景和您的其余类别分开时是不需要使用reduce_zero_label的 【reduce_zero_label = False】",
"reduce_zero_label": false
},
"____三、comment_training_info": "定义训练相关参数",
"training_info": {
"crop_size_width": 256,
"crop_size_height": 256,
"train_batch_size": 16,
"train_num_workers": 4,
"val_and_test_batch_size": 1,
"val_and_test_num_workers": 4
}
}

View File

@@ -0,0 +1,42 @@
{
"____一、comment_dataset_info": "定义多个数据集文件名和对应类名",
"dataset_info": {
"dataset_file_name": "publicdataset_endovis_2017",
"dataset_class_name": "PublicDataSet_Endovis_2017",
"data_root": "/home/wkmgc/Desktop/Seg/DataSet_Public/3_1_Endovis_2017-8Type-512x512",
"img_scale_width": 512,
"img_scale_height": 512,
"____#####comment_paths#####": "训练、验证、测试集所在文件夹",
"paths": {
"train_img_path": "images/train",
"train_seg_map_path": "labels_GT/train",
"val_img_path": "images/val",
"val_seg_map_path": "labels_GT/val",
"test_img_path": "images/val",
"test_seg_map_path": "labels_GT/val"
}
},
"____二、comment_label_info": "定义Label图片相关参数",
"label_info": {
"classes": ["背景", "1", "2", "3", "4", "5", "6", "7"],
"palette": [
[0, 0, 0], [255, 91, 0], [255, 234, 0], [85, 111, 181], [181, 227, 14], [72, 0, 255], [0, 155, 33], [255, 0, 255]
],
"____#####comment#####": "一般不太会变的参数",
"img_suffix": ".bmp",
"seg_map_suffix": ".bmp",
"____#####comment_reduce_zero_label_1#####": "在第 0 类为无意义黑边时使用reduce_zero_label = True将其和待分类内容分开",
"____#####comment_reduce_zero_label_2#####": "在第 0 类为 background 类别的数据集上如果您最终是需要将背景和您的其余类别分开时是不需要使用reduce_zero_label的 【reduce_zero_label = False】",
"reduce_zero_label": false
},
"____三、comment_training_info": "定义训练相关参数",
"training_info": {
"crop_size_width": 256,
"crop_size_height": 256,
"train_batch_size": 16,
"train_num_workers": 4,
"val_and_test_batch_size": 1,
"val_and_test_num_workers": 4
}
}

View File

@@ -0,0 +1,42 @@
{
"____一、comment_dataset_info": "定义多个数据集文件名和对应类名",
"dataset_info": {
"dataset_file_name": "publicdataset_endovis_2018",
"dataset_class_name": "PublicDataSet_Endovis_2018",
"data_root": "/home/wkmgc/Desktop/Seg/DataSet_Public/3_2_Endovis_2018-8Type-512x512",
"img_scale_width": 512,
"img_scale_height": 512,
"____#####comment_paths#####": "训练、验证、测试集所在文件夹",
"paths": {
"train_img_path": "images/train",
"train_seg_map_path": "labels_GT/train",
"val_img_path": "images/val",
"val_seg_map_path": "labels_GT/val",
"test_img_path": "images/val",
"test_seg_map_path": "labels_GT/val"
}
},
"____二、comment_label_info": "定义Label图片相关参数",
"label_info": {
"classes": ["背景", "1", "2", "3", "4", "5", "6", "7"],
"palette": [
[0, 0, 0], [255, 91, 0], [255, 234, 0], [85, 111, 181], [181, 227, 14], [72, 0, 255], [0, 155, 33], [255, 0, 255]
],
"____#####comment#####": "一般不太会变的参数",
"img_suffix": ".bmp",
"seg_map_suffix": ".bmp",
"____#####comment_reduce_zero_label_1#####": "在第 0 类为无意义黑边时使用reduce_zero_label = True将其和待分类内容分开",
"____#####comment_reduce_zero_label_2#####": "在第 0 类为 background 类别的数据集上如果您最终是需要将背景和您的其余类别分开时是不需要使用reduce_zero_label的 【reduce_zero_label = False】",
"reduce_zero_label": false
},
"____三、comment_training_info": "定义训练相关参数",
"training_info": {
"crop_size_width": 256,
"crop_size_height": 256,
"train_batch_size": 16,
"train_num_workers": 4,
"val_and_test_batch_size": 1,
"val_and_test_num_workers": 4
}
}

View File

@@ -0,0 +1,61 @@
#!/usr/bin/env python3
# -*- coding: UTF-8 -*-
import os
from Initial_Data_Program.Initial_Data_Gen_configs_base_datasets_my_dataset import generate_configs_base_datasets_my_dataset_file
from Initial_Data_Program.Initial_Data_Gen_mmseg_utils_class_names import generate_mmseg_utils_class_names_file
from Initial_Data_Program.Initial_Data_Gen_mmseg_datasets_my_dataset import generate_mmseg_datasets_my_dataset_file
from Initial_Data_Program.Initial_Data_Gen_mmseg_datasets_init_ import generate_mmseg_datasets_init_file
if __name__ == '__main__':
    ########### 1.1. Dataset parameters ###########
    # Several dataset file names and their matching class names may be listed here
    dataset_file_names = ["my_dataset_model"]    # e.g. ['my_dataset', 'my_dataset_2'] or ["my_dataset"]
    dataset_class_names = ["MyDataset_model"]    # e.g. ['MyDataset', 'MyDataset2'] or ["MyDataset"]
    dataset_file_name='my_dataset_model'         # dataset module name (<name>.py) TODO
    dataset_class_name='MyDataset_model'         # dataset class name TODO
    data_root='/home/audience/Desktop/Seg_data/Data'     # data root directory
    img_scale=(1920, 1080)                       # image size
    # Folders holding the training, validation and test sets
    train_img_path='A_Ori'
    train_seg_map_path='A_Label_GT_label_fold'
    val_img_path='A_Ori'
    val_seg_map_path='A_Label_GT_label_fold'
    test_img_path='A_Ori'
    test_seg_map_path='A_Label_GT_label_fold'
    ########### 1.2. Label image parameters ###########
    # The entry between '肿瘤' and '线' was an empty string; it is restored to '针',
    # the name the my_dataset_model JSON config gives for palette [134,124,118].
    classes = ['肝脏','胆囊','分离钳','止血海绵','肝总管','胆总管','吸引器','剪刀','止血纱布','生物夹','无损伤钳','喷洒','胆囊管','胆囊动脉','电凝','标本袋','引流管','纱布','金属钛夹','术中超声','吻合器','乳胶管','推结器','肝带','钳夹','超声刀','脂肪','双极电凝','棉球','血管阻断夹','肿瘤','针','线','韧带','胆囊静脉']
    palette = [[255,91,0],[255,234,0],[85, 111, 181],[181, 227, 14],[72, 0, 255],[0, 155, 33],[255,0,255],[29, 32, 136],[160, 15, 95],[0,160,233],[52,184,178],[90,120,41],[255,0,0],[177,0,0],[167,24,233],[112,113,150],[0,255,0],[255,255,255],[0,255,255],[138,251,213],[136,162,196],[197,83,181],[202,202,200],[113,102,140],[66,115,82],[240,16,116],[155,132,0],[155,62,0],[146,175,236],[255,172,159],[245,161,0],[134,124,118],[0,157,142],[181,85,105],[42,8,66]]
    # The class lists here must be the ones already processed by
    # "Initial_Gen_mmseg_datasets_my_dataset.py"; note they carry the leading
    # '背景' entry and its [0,0,0] colour, unlike `classes` / `palette` above.
    classes_all = [
        ['背景','肝脏','胆囊','分离钳','止血海绵','肝总管','胆总管','吸引器','剪刀','止血纱布','生物夹','无损伤钳','喷洒','胆囊管','胆囊动脉','电凝','标本袋','引流管','纱布','金属钛夹','术中超声','吻合器','乳胶管','推结器','肝带','钳夹','超声刀','脂肪','双极电凝','棉球','血管阻断夹','肿瘤','针','线','韧带','胆囊静脉'],
    ]
    palette_all = [
        [[0,0,0],[255,91,0],[255,234,0],[85, 111, 181],[181, 227, 14],[72, 0, 255],[0, 155, 33],[255,0,255],[29, 32, 136],[160, 15, 95],[0,160,233],[52,184,178],[90,120,41],[255,0,0],[177,0,0],[167,24,233],[112,113,150],[0,255,0],[255,255,255],[0,255,255],[138,251,213],[136,162,196],[197,83,181],[202,202,200],[113,102,140],[66,115,82],[240,16,116],[155,132,0],[155,62,0],[146,175,236],[255,172,159],[245,161,0],[134,124,118],[0,157,142],[181,85,105],[42,8,66]],
    ]
    # Parameters that rarely change
    img_suffix = ".png"
    seg_map_suffix = "_gtFine_labelTrainIds.png"
    # When class 0 is a real background class that must be separated from the
    # other classes, reduce_zero_label is not needed (reduce_zero_label = False).
    reduce_zero_label = False
    ########### 1.3. Training parameters ###########
    # Parameters that rarely change
    crop_size=(512, 512)            # crop size
    train_batch_size=4              # training batch size
    train_num_workers=4             # number of training workers
    val_and_test_batch_size=1       # validation/test batch size
    val_and_test_num_workers=4      # number of validation/test workers
    ########### 2. Output file locations ###########
    output_configs_base_datasets_my_dataset=f'./configs/_base_/datasets/{dataset_file_name}.py'
    output_mmseg_datasets_dataset_file_name = os.path.join(f'./mmseg/datasets/{dataset_file_name}.py')
    output_mmseg_datasets_init = os.path.join('./mmseg/datasets/__init__.py')
    output_mmseg_utils_class_names = f'./mmseg/utils/class_names.py'
    ########### 3. Generate the configuration files ###########
    success = generate_configs_base_datasets_my_dataset_file(
        output_file=output_configs_base_datasets_my_dataset,
        dataset_class_name=dataset_class_name,
        data_root=data_root,
        img_scale=img_scale,
        crop_size=crop_size,
        train_batch_size=train_batch_size,
        train_num_workers=train_num_workers,
        val_and_test_batch_size=val_and_test_batch_size,
        val_and_test_num_workers=val_and_test_num_workers,
        train_img_path=train_img_path,
        train_seg_map_path=train_seg_map_path,
        val_img_path=val_img_path,
        val_seg_map_path=val_seg_map_path,
        test_img_path=test_img_path,
        test_seg_map_path=test_seg_map_path,
    )
    success, classes, palette = generate_mmseg_datasets_my_dataset_file(
        output_file=output_mmseg_datasets_dataset_file_name,
        dataset_class_name=dataset_class_name,
        classes=classes,
        palette=palette,
        img_suffix=img_suffix,
        seg_map_suffix=seg_map_suffix,
        reduce_zero_label=reduce_zero_label,
    )
    # Original note: the next steps need the classes and palette from the previous step.
    # NOTE(review): the calls below pass classes_all / palette_all, not the returned values — confirm intent.
    success = generate_mmseg_datasets_init_file(
        output_file=output_mmseg_datasets_init,
        dataset_file_names=dataset_file_names,
        dataset_class_names=dataset_class_names,
    )
    success = generate_mmseg_utils_class_names_file(
        output_file=output_mmseg_utils_class_names,
        dataset_file_names=dataset_file_names,
        classes_all=classes_all,
        palette_all=palette_all,
    )

View File

@@ -0,0 +1,209 @@
#!/usr/bin/env python3
# -*- coding: UTF-8 -*-
import os, json
from Initial_Data_Program.Initial_Data_Calculate_std_and_mean import calculate_pic_std_and_mean
from Initial_Data_Program.Initial_Data_Gen_configs_base_datasets_my_dataset import generate_configs_base_datasets_my_dataset_file
from Initial_Data_Program.Initial_Data_Gen_mmseg_utils_class_names import generate_mmseg_utils_class_names_file
from Initial_Data_Program.Initial_Data_Gen_mmseg_datasets_my_dataset import generate_mmseg_datasets_my_dataset_file
from Initial_Data_Program.Initial_Data_Gen_mmseg_datasets_init_ import generate_mmseg_datasets_init_file
def load_json_files(directory, not_check_list=None):
    """
    Load every JSON file found directly inside a folder, except excluded ones.

    :param directory: folder to scan (not recursive)
    :param not_check_list: file names (without path) to leave out; defaults to no exclusions
    :return: list with one dict per successfully parsed file, in file-name order;
             each dict gains a 'file_name_json' key holding the file name without
             its '.json' extension

    Files that cannot be read or parsed are reported on stdout and skipped.
    """
    if not_check_list is None:
        not_check_list = []
    suffix = '.json'
    # Sorted so the processing order does not depend on the file system.
    json_files = sorted(
        f for f in os.listdir(directory)
        if f.endswith(suffix) and f not in not_check_list
    )
    data_list = []
    for json_file in json_files:
        json_path = os.path.join(directory, json_file)
        try:
            with open(json_path, 'r', encoding='utf-8') as f:
                data = json.load(f)
            # Cut off exactly the '.json' suffix. str.rstrip(".json") would strip
            # any trailing '.', 'j', 's', 'o', 'n' characters and mangle names
            # such as 'publicdataset_autolaparo' or 'publicdataset_dresden'.
            data['file_name_json'] = json_file[:-len(suffix)]
            data_list.append(data)
        except json.JSONDecodeError:
            print(f"\033[91mError decoding JSON file: {json_path}\033[0m")
        except Exception as e:
            print(f"\033[91mError reading file {json_path}: {str(e)}\033[0m")
    return data_list
def process_json_data(json_data):
    """Flatten one dataset-configuration dict into the tuple of parameters the generators use.

    :param json_data: dict with the sections "dataset_info" (including "paths"),
        "label_info" and "training_info".
    :return: a 20-item tuple, in this order: dataset_file_name, dataset_class_name,
        data_root, img_scale, the six train/val/test image and seg-map paths,
        classes, palette, img_suffix, seg_map_suffix, reduce_zero_label,
        crop_size, train_batch_size, train_num_workers,
        val_and_test_batch_size, val_and_test_num_workers.
    :raises KeyError: when a required section or key is missing.
    """
    # --- dataset section ---
    dataset = json_data["dataset_info"]
    file_name = dataset["dataset_file_name"]
    class_name = dataset["dataset_class_name"]
    root = dataset["data_root"]
    # width/height pair becomes an (img_scale_width, img_scale_height) tuple
    scale = (dataset["img_scale_width"], dataset["img_scale_height"])
    folders = dataset["paths"]
    split_paths = tuple(
        folders[key]
        for key in (
            "train_img_path", "train_seg_map_path",
            "val_img_path", "val_seg_map_path",
            "test_img_path", "test_seg_map_path",
        )
    )
    # --- label section ---
    labels = json_data["label_info"]
    label_values = tuple(
        labels[key]
        for key in ("classes", "palette", "img_suffix", "seg_map_suffix", "reduce_zero_label")
    )
    # --- training section ---
    training = json_data["training_info"]
    # width/height pair becomes a (crop_size_width, crop_size_height) tuple
    crop = (training["crop_size_width"], training["crop_size_height"])
    loader_values = tuple(
        training[key]
        for key in (
            "train_batch_size", "train_num_workers",
            "val_and_test_batch_size", "val_and_test_num_workers",
        )
    )
    return (file_name, class_name, root, scale) + split_paths + label_values + (crop,) + loader_values
def save_all_record_to_json(output_file, dataset_file_names, classes_all, palette_all, palette_num_all, mean_all, std_all):
    """Write a per-dataset summary to a JSON file.

    The parallel lists are indexed by position: element ``i`` of every list
    describes ``dataset_file_names[i]``. All lists are expected to be at least
    as long as ``dataset_file_names``.

    :param output_file: path of the JSON file to (over)write.
    :param dataset_file_names: dataset names, used as top-level JSON keys.
    :param classes_all: per-dataset class lists.
    :param palette_all: per-dataset palettes.
    :param palette_num_all: per-dataset class counts.
    :param mean_all: per-dataset mean values (any iterable; stored as a list).
    :param std_all: per-dataset std values (any iterable; stored as a list).
    :return: always ``True``.
    """
    summary = {}
    for position, dataset_name in enumerate(dataset_file_names):
        entry = {}
        entry["classes"] = classes_all[position]
        entry["palette"] = palette_all[position]
        entry["palette_num"] = palette_num_all[position]
        # list(...) turns tuples/arrays into something json can serialise.
        entry["mean"] = list(mean_all[position])
        entry["std"] = list(std_all[position])
        summary[dataset_name] = entry

    # ensure_ascii=False keeps non-ASCII class names readable in the file.
    with open(output_file, 'w', encoding='utf-8') as handle:
        json.dump(summary, handle, ensure_ascii=False, indent=6)
    print(f"\033[93m相关数据汇总到 {output_file} successfully!\033[0m")
    return True
if __name__ == '__main__':
    # Entry point: read every dataset-parameter JSON, generate the per-dataset
    # config / dataset-class files, then regenerate the shared registry files
    # and write the All_Data_Record.json summary.

    ########### 1. Hyper-parameters: location of the dataset parameter folder ###########
    train_parameter_dir = './My_All_In_One/1_Data_Parameter'
    all_data_record_json = "All_Data_Record.json"  # summary file written at the end; excluded from the inputs
    # Output files shared by all datasets. They do not depend on the loop
    # variable, so they are defined once here; previously they were assigned
    # inside the loop, which left them undefined (NameError) when no JSON
    # file was found.
    output_mmseg_datasets_init = os.path.join('./mmseg/datasets/__init__.py')
    output_mmseg_utils_class_names = f'./mmseg/utils/class_names.py'

    ########### 2.1. Walk over every parameter file and generate its dataset files ###########
    # Parallel accumulators: index i of every list describes the same dataset.
    dataset_file_names = []
    dataset_class_names = []
    classes_all = []
    palette_all = []
    palette_num_all = []
    mean_all = []
    std_all = []
    # 2.1.1. Read all JSON files from the parameter folder (the summary file is skipped).
    json_data_list = load_json_files(train_parameter_dir, not_check_list = [all_data_record_json])
    if not json_data_list:
        # Nothing to generate; stop before the shared files are rewritten
        # with an empty dataset list.
        raise SystemExit(f"\033[91mNo dataset JSON files found in {train_parameter_dir}\033[0m")
    # 2.1.2. Generate the config files for each JSON document.
    for json_data in json_data_list:
        # A. Report which file is being processed.
        print(f"\033[32m正在处理{json_data['file_name_json']}.json文件\033[0m")
        # B. Extract the parameters from the JSON document.
        (dataset_file_name, dataset_class_name, data_root, img_scale, train_img_path, train_seg_map_path, val_img_path, val_seg_map_path, test_img_path, test_seg_map_path,
         classes, palette, img_suffix, seg_map_suffix, reduce_zero_label, crop_size, train_batch_size, train_num_workers, val_and_test_batch_size, val_and_test_num_workers,
         ) = process_json_data(json_data)
        # Remember the file name and the dataset class name.
        dataset_file_names.append(dataset_file_name)
        dataset_class_names.append(dataset_class_name)
        # C. Per-dataset output locations.
        output_configs_base_datasets_my_dataset = f'./configs/_base_/datasets/{dataset_file_name}.py'
        output_mmseg_datasets_dataset_file_name = os.path.join(f'./mmseg/datasets/{dataset_file_name}.py')
        # D. Run the generators.
        # Generate ./configs/_base_/datasets/{dataset_file_name}.py
        print(" ",end='')
        success = generate_configs_base_datasets_my_dataset_file(
            output_file=output_configs_base_datasets_my_dataset,
            dataset_class_name=dataset_class_name,
            data_root=data_root,
            img_scale=img_scale,
            crop_size=crop_size,
            train_batch_size=train_batch_size,
            train_num_workers=train_num_workers,
            val_and_test_batch_size=val_and_test_batch_size,
            val_and_test_num_workers=val_and_test_num_workers,
            train_img_path=train_img_path,
            train_seg_map_path=train_seg_map_path,
            val_img_path=val_img_path,
            val_seg_map_path=val_seg_map_path,
            test_img_path=test_img_path,
            test_seg_map_path=test_seg_map_path
        )
        # Generate ./mmseg/datasets/{dataset_file_name}.py
        # The generator hands back classes/palette, and those returned values
        # are what gets recorded below.
        print(" ",end='')
        success, classes, palette = generate_mmseg_datasets_my_dataset_file(
            output_file=output_mmseg_datasets_dataset_file_name,
            dataset_class_name=dataset_class_name,
            classes=classes,
            palette=palette,
            img_suffix=img_suffix,
            seg_map_suffix=seg_map_suffix,
            reduce_zero_label=reduce_zero_label
        )
        # Statistics are computed over the training image directory only.
        mean, std = calculate_pic_std_and_mean(
            dataset_dir = os.path.join(data_root, train_img_path)
        )
        # Record class names, colours and statistics for this dataset.
        classes_all.append(classes)
        palette_all.append(palette)
        palette_num_all.append(len(classes))
        mean_all.append(mean)
        std_all.append(std)

    ########### 2.2. Regenerate the shared __init__ and class_names files ###########
    # Generate ./mmseg/datasets/__init__.py
    print(" ",end='')
    success = generate_mmseg_datasets_init_file(
        output_file=output_mmseg_datasets_init,
        dataset_file_names=dataset_file_names,
        dataset_class_names=dataset_class_names
    )
    # Generate ./mmseg/utils/class_names.py
    print(" ",end='')
    success = generate_mmseg_utils_class_names_file(
        output_file=output_mmseg_utils_class_names,
        dataset_file_names=dataset_file_names,
        classes_all=classes_all,
        palette_all=palette_all
    )

    ########### 2.3. Write the collected information to My_All_In_One/1_Data_Parameter/All_Data_Record.json ###########
    output_all_data_record = os.path.join(train_parameter_dir, all_data_record_json)
    success = save_all_record_to_json(output_file=output_all_data_record, dataset_file_names=dataset_file_names, classes_all=classes_all, palette_all=palette_all, palette_num_all=palette_num_all, mean_all=mean_all, std_all=std_all)

View File

@@ -0,0 +1,212 @@
#!/usr/bin/env python3
# -*- coding: UTF-8 -*-
import os, json
from Initial_Data_Program.Initial_Data_Calculate_std_and_mean import calculate_pic_std_and_mean
from Initial_Data_Program.Initial_Data_Gen_configs_base_datasets_my_dataset import generate_configs_base_datasets_my_dataset_file
from Initial_Data_Program.Initial_Data_Gen_mmseg_utils_class_names import generate_mmseg_utils_class_names_file
from Initial_Data_Program.Initial_Data_Gen_mmseg_datasets_my_dataset import generate_mmseg_datasets_my_dataset_file
from Initial_Data_Program.Initial_Data_Gen_mmseg_datasets_init_ import generate_mmseg_datasets_init_file
def load_json_files(directory, not_check_list=None):
    """
    Load every ``*.json`` file found directly inside ``directory``.

    Files that cannot be read or parsed are reported on stdout (in red) and
    skipped; they never abort the whole run.

    :param directory: folder to scan (not recursive).
    :param not_check_list: file names (without path) to skip. Defaults to
        skipping nothing.
    :return: list of the parsed JSON documents, ordered by file name. Each
        document additionally gets a ``'file_name_json'`` key holding the
        file name without its ``.json`` extension.
    """
    if not_check_list is None:
        not_check_list = []
    suffix = '.json'
    # os.listdir returns entries in arbitrary order; sorting keeps the
    # generated files reproducible from run to run.
    json_files = sorted(
        name for name in os.listdir(directory)
        if name.endswith(suffix) and name not in not_check_list
    )
    data_list = []
    for json_file in json_files:
        json_path = os.path.join(directory, json_file)
        try:
            with open(json_path, 'r', encoding='utf-8') as f:
                data = json.load(f)
            # Drop exactly the trailing ".json". str.rstrip(".json") would
            # strip any trailing run of the characters '.', 'j', 's', 'o',
            # 'n' and mangle names such as "cityscapes.json" -> "cityscape".
            data['file_name_json'] = json_file[:-len(suffix)]
            data_list.append(data)
        except json.JSONDecodeError:
            print(f"\033[91mError decoding JSON file: {json_path}\033[0m")
        except Exception as e:
            # Deliberate best-effort: e.g. unreadable files or a JSON
            # document whose top level is not an object.
            print(f"\033[91mError reading file {json_path}: {str(e)}\033[0m")
    return data_list
def process_json_data(json_data):
    """Flatten one dataset-parameter JSON document into a 20-element tuple.

    :param json_data: dict with the sections ``dataset_info`` (which holds a
        nested ``paths`` dict), ``label_info`` and ``training_info``.
    :return: tuple in this fixed order:
        (dataset_file_name, dataset_class_name, data_root, img_scale,
         train_img_path, train_seg_map_path, val_img_path, val_seg_map_path,
         test_img_path, test_seg_map_path, classes, palette, img_suffix,
         seg_map_suffix, reduce_zero_label, crop_size, train_batch_size,
         train_num_workers, val_and_test_batch_size, val_and_test_num_workers),
        where ``img_scale`` and ``crop_size`` are ``(width, height)`` tuples.
    :raises KeyError: if an expected key is missing.
    """
    # Dataset identity and location.
    dataset_section = json_data["dataset_info"]
    identity = tuple(
        dataset_section[key]
        for key in ("dataset_file_name", "dataset_class_name", "data_root")
    )
    # Width/height are stored as separate keys; pack them as (width, height).
    img_scale = (dataset_section["img_scale_width"], dataset_section["img_scale_height"])

    # Image / segmentation-map paths for the train, val and test splits.
    path_section = dataset_section["paths"]
    split_paths = tuple(
        path_section[key]
        for key in (
            "train_img_path", "train_seg_map_path",
            "val_img_path", "val_seg_map_path",
            "test_img_path", "test_seg_map_path",
        )
    )

    # Label description.
    label_section = json_data["label_info"]
    label_values = tuple(
        label_section[key]
        for key in ("classes", "palette", "img_suffix", "seg_map_suffix", "reduce_zero_label")
    )

    # Training / data-loader parameters.
    training_section = json_data["training_info"]
    crop_size = (training_section["crop_size_width"], training_section["crop_size_height"])
    loader_values = tuple(
        training_section[key]
        for key in (
            "train_batch_size", "train_num_workers",
            "val_and_test_batch_size", "val_and_test_num_workers",
        )
    )

    return identity + (img_scale,) + split_paths + label_values + (crop_size,) + loader_values
def save_all_record_to_json(output_file, dataset_file_names, classes_all, palette_all, palette_num_all, mean_all, std_all, train_imgs_num):
    """Write a per-dataset summary to a JSON file.

    The parallel lists are indexed by position: element ``i`` of every list
    describes ``dataset_file_names[i]``. All lists are expected to be at least
    as long as ``dataset_file_names``.

    :param output_file: path of the JSON file to (over)write.
    :param dataset_file_names: dataset names, used as top-level JSON keys.
    :param classes_all: per-dataset class lists.
    :param palette_all: per-dataset palettes.
    :param palette_num_all: per-dataset class counts.
    :param mean_all: per-dataset mean values (any iterable; stored as a list).
    :param std_all: per-dataset std values (any iterable; stored as a list).
    :param train_imgs_num: per-dataset training image counts.
    :return: always ``True``.
    """
    summary = {}
    for position, dataset_name in enumerate(dataset_file_names):
        entry = {}
        entry["classes"] = classes_all[position]
        entry["palette"] = palette_all[position]
        entry["palette_num"] = palette_num_all[position]
        # list(...) turns tuples/arrays into something json can serialise.
        entry["mean"] = list(mean_all[position])
        entry["std"] = list(std_all[position])
        entry["train_imgs_num"] = train_imgs_num[position]
        summary[dataset_name] = entry

    # ensure_ascii=False keeps non-ASCII class names readable in the file.
    with open(output_file, 'w', encoding='utf-8') as handle:
        json.dump(summary, handle, ensure_ascii=False, indent=6)
    print(f"\033[93m相关数据汇总到 {output_file} successfully!\033[0m")
    return True
if __name__ == '__main__':
    # Entry point: read every dataset-parameter JSON, generate the per-dataset
    # config / dataset-class files, then regenerate the shared registry files
    # and write the All_Data_Record.json summary.

    ########### 1. Hyper-parameters: location of the dataset parameter folder ###########
    train_parameter_dir = './My_All_In_One/1_Data_Parameter'
    all_data_record_json = "All_Data_Record.json"  # summary file written at the end; excluded from the inputs
    # Output files shared by all datasets. They do not depend on the loop
    # variable, so they are defined once here; previously they were assigned
    # inside the loop, which left them undefined (NameError) when no JSON
    # file was found.
    output_mmseg_datasets_init = os.path.join('./mmseg/datasets/__init__.py')
    output_mmseg_utils_class_names = f'./mmseg/utils/class_names.py'

    ########### 2.1. Walk over every parameter file and generate its dataset files ###########
    # Parallel accumulators: index i of every list describes the same dataset.
    dataset_file_names = []
    dataset_class_names = []
    classes_all = []
    palette_all = []
    palette_num_all = []
    mean_all = []
    std_all = []
    train_imgs_num = []
    # 2.1.1. Read all JSON files from the parameter folder (the summary file is skipped).
    json_data_list = load_json_files(train_parameter_dir, not_check_list = [all_data_record_json])
    if not json_data_list:
        # Nothing to generate; stop before the shared files are rewritten
        # with an empty dataset list.
        raise SystemExit(f"\033[91mNo dataset JSON files found in {train_parameter_dir}\033[0m")
    # 2.1.2. Generate the config files for each JSON document.
    for json_data in json_data_list:
        # A. Report which file is being processed.
        print(f"\033[32m正在处理{json_data['file_name_json']}.json文件\033[0m")
        # B. Extract the parameters from the JSON document.
        (dataset_file_name, dataset_class_name, data_root, img_scale, train_img_path, train_seg_map_path, val_img_path, val_seg_map_path, test_img_path, test_seg_map_path,
         classes, palette, img_suffix, seg_map_suffix, reduce_zero_label, crop_size, train_batch_size, train_num_workers, val_and_test_batch_size, val_and_test_num_workers,
         ) = process_json_data(json_data)
        # Remember the file name and the dataset class name.
        dataset_file_names.append(dataset_file_name)
        dataset_class_names.append(dataset_class_name)
        # C. Per-dataset output locations.
        output_configs_base_datasets_my_dataset = f'./configs/_base_/datasets/{dataset_file_name}.py'
        output_mmseg_datasets_dataset_file_name = os.path.join(f'./mmseg/datasets/{dataset_file_name}.py')
        # D. Run the generators.
        # Generate ./configs/_base_/datasets/{dataset_file_name}.py
        print(" ",end='')
        success = generate_configs_base_datasets_my_dataset_file(
            output_file=output_configs_base_datasets_my_dataset,
            dataset_class_name=dataset_class_name,
            data_root=data_root,
            img_scale=img_scale,
            crop_size=crop_size,
            train_batch_size=train_batch_size,
            train_num_workers=train_num_workers,
            val_and_test_batch_size=val_and_test_batch_size,
            val_and_test_num_workers=val_and_test_num_workers,
            train_img_path=train_img_path,
            train_seg_map_path=train_seg_map_path,
            val_img_path=val_img_path,
            val_seg_map_path=val_seg_map_path,
            test_img_path=test_img_path,
            test_seg_map_path=test_seg_map_path
        )
        # Generate ./mmseg/datasets/{dataset_file_name}.py
        # The generator hands back classes/palette, and those returned values
        # are what gets recorded below.
        print(" ",end='')
        success, classes, palette = generate_mmseg_datasets_my_dataset_file(
            output_file=output_mmseg_datasets_dataset_file_name,
            dataset_class_name=dataset_class_name,
            classes=classes,
            palette=palette,
            img_suffix=img_suffix,
            seg_map_suffix=seg_map_suffix,
            reduce_zero_label=reduce_zero_label
        )
        # Statistics and the image count both come from the training image directory.
        train_img_dir = os.path.join(data_root, train_img_path)
        mean, std = calculate_pic_std_and_mean(
            dataset_dir = train_img_dir
        )
        # Record class names, colours and statistics for this dataset.
        classes_all.append(classes)
        palette_all.append(palette)
        palette_num_all.append(len(classes))
        mean_all.append(mean)
        std_all.append(std)
        # NOTE(review): this counts every directory entry, not only image
        # files - confirm the folder holds nothing but training images.
        train_imgs_num.append(len(os.listdir(train_img_dir)))

    ########### 2.2. Regenerate the shared __init__ and class_names files ###########
    # Generate ./mmseg/datasets/__init__.py
    print(" ",end='')
    success = generate_mmseg_datasets_init_file(
        output_file=output_mmseg_datasets_init,
        dataset_file_names=dataset_file_names,
        dataset_class_names=dataset_class_names
    )
    # Generate ./mmseg/utils/class_names.py
    print(" ",end='')
    success = generate_mmseg_utils_class_names_file(
        output_file=output_mmseg_utils_class_names,
        dataset_file_names=dataset_file_names,
        classes_all=classes_all,
        palette_all=palette_all
    )

    ########### 2.3. Write the collected information to My_All_In_One/1_Data_Parameter/All_Data_Record.json ###########
    output_all_data_record = os.path.join(train_parameter_dir, all_data_record_json)
    success = save_all_record_to_json(output_file=output_all_data_record, dataset_file_names=dataset_file_names, classes_all=classes_all, palette_all=palette_all, palette_num_all=palette_num_all, mean_all=mean_all, std_all=std_all, train_imgs_num=train_imgs_num)

View File

@@ -0,0 +1,101 @@
import os, sys, argparse, json
# 添加上级目录到 sys.path
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../')))
from Initial_Alg_Program.Initial_Alg_Select_Tool import select_crop_size, select_pretrained_model, select_test_cfg_slide
from Initial_Alg_Program.Initial_Alg_Gen_Tool import format_all_data, write_config_to_file
from Initial_Alg_Program.Initial_Alg_Gen_0_base_ import generate_base_config
from Initial_Alg_Program.Initial_Alg_Gen_1_norm_cfg import generate_norm_cfg
from Initial_Alg_Program.Initial_Alg_Gen_2_data_preprocessor import generate_data_preprocessor
from Initial_Alg_Program.Initial_Alg_Gen_3_model import generate_model_pretrained, generate_model_backbone, generate_model_data_preprocessor, generate_model_decode_head, generate_model_auxiliary_head, generate_model_train_cfg, generate_model_test_cfg
from Initial_Alg_Program.Initial_Alg_Gen_4_optimizer import generate_optim_wrapper, generate_param_scheduler
if __name__ == '__main__':
    # Entry point: build an ANN training config from command-line arguments
    # plus a few interactive choices, write it under ./configs/ann/ and
    # record the generated file name/path in _temp_.txt.

    ########### 1. Hyper-parameters: default base model file ###########
    alg_file_name = 'ann_r50-d8'

    ########### 2. Argument parsing ###########
    parser = argparse.ArgumentParser(description="Run algorithm with specified parameters.")
    parser.add_argument('--alg_name', type=str, required=True, help="Algorithm name (e.g., 'ann')")
    parser.add_argument('--dataset_file_name', type=str, required=True, help="Dataset file name (e.g., 'my_dataset_model')")
    parser.add_argument('--mean', type=str, required=True, help="Mean of dataset (e.g., [1.2, 1.2, 1.2])")
    # Help text fixed: this option is the standard deviation, not the mean.
    parser.add_argument('--std', type=str, required=True, help="Std of dataset (e.g., [1.2, 1.2, 1.2])")
    parser.add_argument('--dataset_num_classes', type=int, required=True, help="Dataset class num (e.g., 36)")
    parser.add_argument('--GPU_num', type=int, required=True, help="Num of GPU to train (e.g., 1)")
    parser.add_argument('--schedule_file_name', type=str, required=True, help="Schedule file name (e.g., 'schedule_4k_check_400')")
    parser.add_argument('--train_type', type=str, required=True, help="Schedule train type (e.g., Epoch or Iteration)")
    parser.add_argument('--train_time_or_epoch_k', type=int, required=True, help="Schedule train time k (e.g., 4)")
    args = parser.parse_args()

    alg_name = args.alg_name
    dataset_file_name = args.dataset_file_name
    mean = json.loads(args.mean)  # JSON text -> list
    std = json.loads(args.std)  # JSON text -> list
    num_classes = args.dataset_num_classes
    schedule_file_name = args.schedule_file_name
    train_type = args.train_type
    train_time_or_epoch_k = args.train_time_or_epoch_k
    GPU_num = args.GPU_num

    ########### 3. Parameters ###########
    # 3.1. Interactive choices
    crop_size = select_crop_size(predefined_options=[(512, 512), (512, 1024), (769, 769)])  # crop size
    model_list = ['openmmlab/resnet50_v1c', 'openmmlab/resnet101_v1c']
    selected_model_name, select_pretrained, pretrained_pth, selected_model_info = select_pretrained_model(model_list=model_list)
    depth = selected_model_info['depth']
    test_cfg_mode, test_cfg_crop_div_stride = select_test_cfg_slide(crop_size)
    # 3.2. Derived parameters
    # Loss configuration for the decode head and the auxiliary head.
    decode_head_loss_decode_dict=dict(type='DiceLoss', use_sigmoid=False, loss_weight=1.0)  # decode loss  # TODO
    auxiliary_head_loss_decode_dict=dict(type='DiceLoss', use_sigmoid=False, loss_weight=0.4)  # auxiliary loss  # TODO
    # align_corners is switched on only in slide test mode; the mode is also
    # tagged in the generated file name.
    if test_cfg_mode == "slide":
        align_corners = True
        test_slide = 'testslide'
    else:
        align_corners = False
        test_slide = 'no_testslide'

    ########### 4. Build each config section ###########
    _base_ = generate_base_config(alg_file_name=alg_file_name, dataset_file_name=dataset_file_name, schedule_file_name=schedule_file_name)
    # NOTE(review): norm_cfg is built here but is not passed to
    # write_config_to_file below - confirm whether that is intended.
    norm_cfg = generate_norm_cfg(GPU_num=GPU_num)
    data_preprocessor = generate_data_preprocessor(crop_size = crop_size, mean = mean, std = std)
    pretrained = generate_model_pretrained(pretrained_pth = pretrained_pth)
    backbone = generate_model_backbone(depth=depth)
    model_data_preprocessor = generate_model_data_preprocessor(model_data_preprocessor=data_preprocessor)
    decode_head = generate_model_decode_head(num_classes=num_classes, decode_head_loss_decode_dict=decode_head_loss_decode_dict, align_corners=align_corners)
    auxiliary_head = generate_model_auxiliary_head(num_classes=num_classes, auxiliary_head_loss_decode_dict=auxiliary_head_loss_decode_dict, align_corners=align_corners)
    # Assemble the model section.
    model = dict(
        pretrained = pretrained,
        backbone = backbone,
        data_preprocessor = model_data_preprocessor,
        decode_head = decode_head,
        auxiliary_head = auxiliary_head,
    )
    test_cfg = generate_model_test_cfg(test_cfg_mode=test_cfg_mode, crop_size=crop_size, test_cfg_crop_div_stride=test_cfg_crop_div_stride)
    optim_wrapper = generate_optim_wrapper()
    param_scheduler = generate_param_scheduler(train_type, train_time_or_epoch_k)

    ########### 5. Write the config ###########
    # File name layout:
    #   {alg}_r{depth}_{Pre|NoPre}_g{GPU count}-{schedule}_{dataset}-{crop_w}x{crop_h}-{test mode}.py
    # Remove only the literal "schedule_" prefix. str.lstrip('schedule_')
    # strips a *set of characters* and would also eat the start of the
    # remainder (e.g. "schedule_epoch_20" -> "poch_20").
    schedule_prefix = 'schedule_'
    if schedule_file_name.startswith(schedule_prefix):
        schedule_tag = schedule_file_name[len(schedule_prefix):]
    else:
        schedule_tag = schedule_file_name
    alg_file_name = f"{alg_name}_r{depth}_{'Pre' if select_pretrained else 'NoPre'}_g{GPU_num}-{schedule_tag}_{dataset_file_name}-{crop_size[0]}x{crop_size[1]}-{test_slide}.py"
    output_configs_alg_my_alg = os.path.join(f'./configs/ann/', alg_file_name)
    write_config_to_file(output_configs_alg_my_alg, _base_=_base_, crop_size=crop_size, data_preprocessor=data_preprocessor, model=model, test_cfg=test_cfg, optim_wrapper=optim_wrapper, param_scheduler=param_scheduler)
    # Record the generated file name and path in _temp_.txt.
    with open("_temp_.txt", "w") as file:
        json.dump({'alg_infos':{'alg_file_name':alg_file_name,'alg_file_pth':output_configs_alg_my_alg}}, file)

View File

@@ -0,0 +1,102 @@
import os, sys, argparse, json
# 添加上级目录到 sys.path
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../')))
from Initial_Alg_Program.Initial_Alg_Select_Tool import select_crop_size, select_pretrained_model, select_test_cfg_slide
from Initial_Alg_Program.Initial_Alg_Gen_Tool import format_all_data, write_config_to_file
from Initial_Alg_Program.Initial_Alg_Gen_0_base_ import generate_base_config
from Initial_Alg_Program.Initial_Alg_Gen_1_norm_cfg import generate_norm_cfg
from Initial_Alg_Program.Initial_Alg_Gen_2_data_preprocessor import generate_data_preprocessor
from Initial_Alg_Program.Initial_Alg_Gen_3_model import generate_model_pretrained, generate_model_backbone, generate_model_data_preprocessor, generate_model_decode_head, generate_model_auxiliary_head, generate_model_train_cfg, generate_model_test_cfg
from Initial_Alg_Program.Initial_Alg_Gen_4_optimizer import generate_optim_wrapper, generate_param_scheduler
if __name__ == '__main__':
    # Entry point: build an APCNet training config from command-line arguments
    # plus a few interactive choices, write it under ./configs/apcnet/ and
    # record the generated file name/path in _temp_.txt.

    ########### 1. Hyper-parameters: default base model file ###########
    alg_file_name = 'apcnet_r50-d8'

    ########### 2. Argument parsing ###########
    parser = argparse.ArgumentParser(description="Run algorithm with specified parameters.")
    parser.add_argument('--alg_name', type=str, required=True, help="Algorithm name (e.g., 'ann')")
    parser.add_argument('--dataset_file_name', type=str, required=True, help="Dataset file name (e.g., 'my_dataset_model')")
    parser.add_argument('--mean', type=str, required=True, help="Mean of dataset (e.g., [1.2, 1.2, 1.2])")
    # Help text fixed: this option is the standard deviation, not the mean.
    parser.add_argument('--std', type=str, required=True, help="Std of dataset (e.g., [1.2, 1.2, 1.2])")
    parser.add_argument('--dataset_num_classes', type=int, required=True, help="Dataset class num (e.g., 36)")
    parser.add_argument('--GPU_num', type=int, required=True, help="Num of GPU to train (e.g., 1)")
    parser.add_argument('--schedule_file_name', type=str, required=True, help="Schedule file name (e.g., 'schedule_4k_check_400')")
    parser.add_argument('--train_type', type=str, required=True, help="Schedule train type (e.g., Epoch or Iteration)")
    parser.add_argument('--train_time_or_epoch_k', type=int, required=True, help="Schedule train time k (e.g., 4)")
    args = parser.parse_args()

    alg_name = args.alg_name
    dataset_file_name = args.dataset_file_name
    mean = json.loads(args.mean)  # JSON text -> list
    std = json.loads(args.std)  # JSON text -> list
    num_classes = args.dataset_num_classes
    schedule_file_name = args.schedule_file_name
    train_type = args.train_type
    train_time_or_epoch_k = args.train_time_or_epoch_k
    GPU_num = args.GPU_num

    ########### 3. Parameters ###########
    # 3.1. Interactive choices
    crop_size = select_crop_size(predefined_options=[(512, 512), (512, 1024), (769, 769)])  # crop size
    model_list = ['openmmlab/resnet50_v1c', 'openmmlab/resnet101_v1c']
    selected_model_name, select_pretrained, pretrained_pth, selected_model_info = select_pretrained_model(model_list=model_list)
    depth = selected_model_info['depth']
    test_cfg_mode, test_cfg_crop_div_stride = select_test_cfg_slide(crop_size)
    # 3.2. Derived parameters
    # Loss configuration for the decode head and the auxiliary head.
    decode_head_loss_decode_dict=dict(type='DiceLoss', use_sigmoid=False, loss_weight=1.0)  # decode loss  # TODO
    auxiliary_head_loss_decode_dict=dict(type='DiceLoss', use_sigmoid=False, loss_weight=0.4)  # auxiliary loss  # TODO
    # align_corners is switched on only in slide test mode; the mode is also
    # tagged in the generated file name.
    if test_cfg_mode == "slide":
        align_corners = True
        test_slide = 'testslide'
    else:
        align_corners = False
        test_slide = 'no_testslide'

    ########### 4. Build each config section ###########
    _base_ = generate_base_config(alg_file_name=alg_file_name, dataset_file_name=dataset_file_name, schedule_file_name=schedule_file_name)
    norm_cfg = generate_norm_cfg(GPU_num=GPU_num)
    data_preprocessor = generate_data_preprocessor(crop_size = crop_size, mean = mean, std = std)
    pretrained = generate_model_pretrained(pretrained_pth = pretrained_pth)
    backbone = generate_model_backbone(depth=depth)
    model_data_preprocessor = generate_model_data_preprocessor(model_data_preprocessor=data_preprocessor)
    decode_head = generate_model_decode_head(num_classes=num_classes, decode_head_loss_decode_dict=decode_head_loss_decode_dict, align_corners=align_corners)
    auxiliary_head = generate_model_auxiliary_head(num_classes=num_classes, auxiliary_head_loss_decode_dict=auxiliary_head_loss_decode_dict, align_corners=align_corners)
    # Assemble the model section.
    model = dict(
        pretrained = pretrained,
        backbone = backbone,
        data_preprocessor = model_data_preprocessor,
        decode_head = decode_head,
        auxiliary_head = auxiliary_head,
    )
    test_cfg = generate_model_test_cfg(test_cfg_mode=test_cfg_mode, crop_size=crop_size, test_cfg_crop_div_stride=test_cfg_crop_div_stride)
    optim_wrapper = generate_optim_wrapper()
    param_scheduler = generate_param_scheduler(train_type, train_time_or_epoch_k)

    ########### 5. Write the config ###########
    # File name layout:
    #   {alg}_r{depth}_{Pre|NoPre}_g{GPU count}-{schedule}_{dataset}-{crop_w}x{crop_h}-{test mode}.py
    # Remove only the literal "schedule_" prefix. str.lstrip('schedule_')
    # strips a *set of characters* and would also eat the start of the
    # remainder (e.g. "schedule_epoch_20" -> "poch_20").
    schedule_prefix = 'schedule_'
    if schedule_file_name.startswith(schedule_prefix):
        schedule_tag = schedule_file_name[len(schedule_prefix):]
    else:
        schedule_tag = schedule_file_name
    alg_file_name = f"{alg_name}_r{depth}_{'Pre' if select_pretrained else 'NoPre'}_g{GPU_num}-{schedule_tag}_{dataset_file_name}-{crop_size[0]}x{crop_size[1]}-{test_slide}.py"
    output_configs_alg_my_alg = os.path.join(f'./configs/apcnet/', alg_file_name)
    write_config_to_file(output_configs_alg_my_alg, _base_=_base_, norm_cfg=norm_cfg, crop_size=crop_size, data_preprocessor=data_preprocessor, model=model, test_cfg=test_cfg, optim_wrapper=optim_wrapper, param_scheduler=param_scheduler)
    # Record the generated file name and path in _temp_.txt.
    with open("_temp_.txt", "w") as file:
        json.dump({'alg_infos':{'alg_file_name':alg_file_name,'alg_file_pth':output_configs_alg_my_alg}}, file)

View File

@@ -0,0 +1,176 @@
import os, sys, argparse, json
# 添加上级目录到 sys.path
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../')))
from Initial_Alg_Program.Initial_Alg_Select_Tool import select_crop_size, select_pretrained_model
from Initial_Alg_Program.Initial_Alg_Gen_Tool import format_all_data, write_config_to_file, create_dict_by_kwargs, get_var_from_file, update_list_dict_var, get_var_from_py_file
from Initial_Alg_Program.Initial_Alg_Gen_0_base_ import generate_base_config
from Initial_Alg_Program.Initial_Alg_Gen_1_norm_cfg import generate_norm_cfg
from Initial_Alg_Program.Initial_Alg_Gen_2_data_preprocessor import generate_data_preprocessor
from Initial_Alg_Program.Initial_Alg_Gen_3_model import generate_model_pretrained, generate_model_backbone, generate_model_data_preprocessor, generate_model_decode_head, generate_model_auxiliary_head, generate_model_train_cfg, generate_model_test_cfg
from Initial_Alg_Program.Initial_Alg_Gen_4_optimizer import generate_optim_wrapper, generate_param_scheduler
from Initial_Alg_Program.Initial_Alg_Gen_5_train_dataloader import generate_train_dataloader
# Interactive chooser for the decode head.
def select_decode_head(decode_head_choose):
    """Ask the user to pick one entry of ``decode_head_choose`` by number.

    The keys are listed with 1-based numbers; the prompt repeats until a
    valid number is entered.

    :param decode_head_choose: mapping of option name -> decode head value.
    :return: the value stored under the chosen key.
    """
    option_names = list(decode_head_choose.keys())
    print("可用的 decode head 选项:")
    for number, option in enumerate(option_names, start=1):
        print(f"{number}. {option}")

    while True:
        try:
            picked = int(input("请选择需要的 decode head输入编号"))
        except ValueError:
            # Not an integer at all.
            print("输入无效,请输入有效的编号。")
            continue
        if picked < 1 or picked > len(option_names):
            # An integer, but outside the listed range.
            print("输入的编号不正确,请重新输入。")
            continue
        chosen_name = option_names[picked - 1]
        print(f"你选择了: {chosen_name}")
        return decode_head_choose[chosen_name]
if __name__ == '__main__':
    # Entry point: build a BEiT (UPerNet) training config from command-line
    # arguments plus a few interactive choices, write it under
    # ./configs/beit/ and record the generated file name/path in _temp_.txt.

    ########### 1. Hyper-parameters: default base model file ###########
    alg_file_name = 'upernet_beit'
    alg_file_pth = f"configs/_base_/models/{alg_file_name}.py"

    ########### 2. Argument parsing ###########
    parser = argparse.ArgumentParser(description="Run algorithm with specified parameters.")
    parser.add_argument('--alg_name', type=str, required=True, help="Algorithm name (e.g., 'ann')")
    parser.add_argument('--dataset_file_name', type=str, required=True, help="Dataset file name (e.g., 'my_dataset_model')")
    parser.add_argument('--mean', type=str, required=True, help="Mean of dataset (e.g., [1.2, 1.2, 1.2])")
    # Help text fixed: this option is the standard deviation, not the mean.
    parser.add_argument('--std', type=str, required=True, help="Std of dataset (e.g., [1.2, 1.2, 1.2])")
    parser.add_argument('--dataset_num_classes', type=int, required=True, help="Dataset class num (e.g., 36)")
    parser.add_argument('--GPU_num', type=int, required=True, help="Num of GPU to train (e.g., 1)")
    parser.add_argument('--schedule_file_name', type=str, required=True, help="Schedule file name (e.g., 'schedule_4k_check_400')")
    parser.add_argument('--train_type', type=str, required=True, help="Schedule train type (e.g., Epoch or Iteration)")
    parser.add_argument('--train_time_or_epoch_k', type=int, required=True, help="Schedule train time k (e.g., 4)")
    args = parser.parse_args()

    alg_name = args.alg_name
    dataset_file_name = args.dataset_file_name
    mean = json.loads(args.mean)  # JSON text -> list
    std = json.loads(args.std)  # JSON text -> list
    num_classes = args.dataset_num_classes
    schedule_file_name = args.schedule_file_name
    train_type = args.train_type
    train_time_or_epoch_k = args.train_time_or_epoch_k
    GPU_num = args.GPU_num

    ########### 3.1. _base_ section ###########
    _base_ = generate_base_config(alg_file_name=alg_file_name, dataset_file_name=dataset_file_name, schedule_file_name=schedule_file_name)
    ########### 3.2. norm_cfg section ###########
    norm_cfg = generate_norm_cfg(GPU_num=GPU_num)
    ########### 3.3. data_preprocessor section ###########
    crop_size = select_crop_size(predefined_options=[(640, 640)])  # crop size
    data_preprocessor = generate_data_preprocessor(crop_size = crop_size, mean = mean, std = std)
    ########### 3.4. model section ###########
    model_data_preprocessor = data_preprocessor
    # Choose the backbone checkpoint (the helper also asks whether to use pretraining).
    model_list = ['pretrain/beit_base_patch16_224_pt22k_ft22k.pth', 'pretrain/beit_large_patch16_224_pt22k_ft22k.pth']
    selected_model_name, select_pretrained, pretrained_pth, selected_model_info = select_pretrained_model(model_list=model_list)
    backbone = create_dict_by_kwargs(type='BEiT', img_size=crop_size,)
    # Losses in use ("Way 1"): DiceLoss. The original setting ("Way Ori") was
    # CrossEntropyLoss with the same use_sigmoid / loss_weight values.
    decode_head_loss_decode=dict(type='DiceLoss', use_sigmoid=False, loss_weight=1.0)
    auxiliary_head_loss_decode=dict(type='DiceLoss', use_sigmoid=False, loss_weight=0.4)
    if selected_model_name == 'pretrain/beit_base_patch16_224_pt22k_ft22k.pth':
        backbone_new=dict(
            # patch_size=16,
            # in_channels=3,
            embed_dims=768,
            num_layers=12,
            num_heads=12,
            # mlp_ratio=4,
            out_indices=(3, 5, 7, 11),
            # qv_bias=True,
            # attn_drop_rate=0.0,
            drop_path_rate=0.1,
            # norm_cfg=dict(type='LN', eps=1e-6),
            # act_cfg=dict(type='GELU'),
            # norm_eval=False,
            init_values=0.1)
        neck = dict(type='Feature2Pyramid', embed_dim=768, rescales=[4, 2, 1, 0.5])
        decode_head=dict(in_channels=[768, 768, 768, 768], num_classes=num_classes, channels=768, norm_cfg=norm_cfg, loss_decode=decode_head_loss_decode)
        auxiliary_head=dict(norm_cfg=norm_cfg, in_channels=768, num_classes=num_classes, loss_decode = auxiliary_head_loss_decode)
        optim_wrapper = dict(
            _delete_=True,
            type='AmpOptimWrapper', # type='OptimWrapper', # TODO Ori
            optimizer=dict(
                type='AdamW', lr=3e-5, betas=(0.9, 0.999), weight_decay=0.05),
            constructor='LayerDecayOptimizerConstructor',
            paramwise_cfg=dict(num_layers=12, layer_decay_rate=0.9))
        model_size = 'base'
        train_dataloader, batch_size = generate_train_dataloader(2)
    elif selected_model_name == 'pretrain/beit_large_patch16_224_pt22k_ft22k.pth':
        backbone_new=dict(
            embed_dims=1024,
            num_layers=24,
            num_heads=16,
            # mlp_ratio=4,
            # qv_bias=True,
            init_values=1e-6,
            drop_path_rate=0.2,
            out_indices=[7, 11, 15, 23])
        neck=dict(type='Feature2Pyramid', embed_dim=1024, rescales=[4, 2, 1, 0.5])
        decode_head=dict(in_channels=[1024, 1024, 1024, 1024], num_classes=num_classes, channels=1024, norm_cfg=norm_cfg, loss_decode=decode_head_loss_decode)
        auxiliary_head=dict(norm_cfg=norm_cfg, in_channels=1024, num_classes=num_classes, loss_decode = auxiliary_head_loss_decode)
        optim_wrapper = dict(
            _delete_=True,
            type='AmpOptimWrapper',
            optimizer=dict(
                type='AdamW', lr=2e-5, betas=(0.9, 0.999), weight_decay=0.05),
            constructor='LayerDecayOptimizerConstructor',
            paramwise_cfg=dict(num_layers=24, layer_decay_rate=0.95),
            accumulative_counts=2)
        model_size = 'large'
        train_dataloader, batch_size = generate_train_dataloader(1)
    else:
        # Without this branch an unexpected selection surfaced later as a
        # NameError on backbone_new; fail here with a clear message instead.
        raise ValueError(f"Unsupported BEiT checkpoint selected: {selected_model_name!r}")
    # Merge the size-specific settings into the backbone.
    backbone.update(backbone_new)
    model = dict(
        backbone = backbone,
        data_preprocessor = model_data_preprocessor,
        pretrained = pretrained_pth,
        neck = neck,
        decode_head = decode_head,
        auxiliary_head = auxiliary_head,
    )
    ########### 3.5. param_scheduler section (optim_wrapper was chosen above) ###########
    param_scheduler = generate_param_scheduler(train_type, train_time_or_epoch_k)

    ########### 4. Write the config ###########
    # File name layout:
    #   {alg}-{base|large}_b{batch size}_g{GPU count}-{schedule}_{dataset}-{crop_w}x{crop_h}.py
    # Remove only the literal "schedule_" prefix. str.lstrip('schedule_')
    # strips a *set of characters* and would also eat the start of the
    # remainder (e.g. "schedule_epoch_20" -> "poch_20").
    schedule_prefix = 'schedule_'
    if schedule_file_name.startswith(schedule_prefix):
        schedule_tag = schedule_file_name[len(schedule_prefix):]
    else:
        schedule_tag = schedule_file_name
    alg_file_name = f"{alg_name}-{model_size}_b{batch_size}_g{GPU_num}-{schedule_tag}_{dataset_file_name}-{crop_size[0]}x{crop_size[1]}.py"
    output_configs_alg_my_alg = os.path.join(f'./configs/beit/', alg_file_name)
    write_config_to_file(output_configs_alg_my_alg, _base_=_base_, norm_cfg=norm_cfg, crop_size=crop_size, data_preprocessor=data_preprocessor, model=model, optim_wrapper=optim_wrapper, param_scheduler=param_scheduler, train_dataloader = train_dataloader)
    # Record the generated file name and path in _temp_.txt.
    with open("_temp_.txt", "w") as file:
        json.dump({'alg_infos':{'alg_file_name':alg_file_name,'alg_file_pth':output_configs_alg_my_alg}}, file)

View File

@@ -0,0 +1,133 @@
import os, sys, argparse, json
# 添加上级目录到 sys.path
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../')))
from Initial_Alg_Program.Initial_Alg_Select_Tool import select_crop_size, select_pretrained_model
from Initial_Alg_Program.Initial_Alg_Gen_Tool import format_all_data, write_config_to_file, create_dict_by_kwargs, get_var_from_file, update_list_dict_var, get_var_from_py_file
from Initial_Alg_Program.Initial_Alg_Gen_0_base_ import generate_base_config
from Initial_Alg_Program.Initial_Alg_Gen_1_norm_cfg import generate_norm_cfg
from Initial_Alg_Program.Initial_Alg_Gen_2_data_preprocessor import generate_data_preprocessor
from Initial_Alg_Program.Initial_Alg_Gen_3_model import generate_model_pretrained, generate_model_backbone, generate_model_data_preprocessor, generate_model_decode_head, generate_model_auxiliary_head, generate_model_train_cfg, generate_model_test_cfg
from Initial_Alg_Program.Initial_Alg_Gen_4_optimizer import generate_optim_wrapper, generate_param_scheduler
if __name__ == '__main__':
    # Generates an MMSegmentation config file for BiSeNetV1 from command-line
    # arguments plus a few interactive selections (crop size, backbone,
    # pretrained or not), then records the generated file name/path in
    # "_temp_.txt" as JSON.

    ########### 1. Hyper-parameters: algorithm defaults ###########
    alg_file_name = 'bisenetv1_r18-d32'
    alg_file_pth = f"configs/_base_/models/{alg_file_name}.py"

    ########### 2. Argument parsing ###########
    parser = argparse.ArgumentParser(description="Run algorithm with specified parameters.")
    parser.add_argument('--alg_name', type=str, required=True, help="Algorithm name (e.g., 'ann')")
    parser.add_argument('--dataset_file_name', type=str, required=True, help="Dataset file name (e.g., 'my_dataset_model')")
    parser.add_argument('--mean', type=str, required=True, help="Mean of dataset (e.g., [1.2, 1.2, 1.2])")
    # Fixed: the help text previously described this option as the mean.
    parser.add_argument('--std', type=str, required=True, help="Std of dataset (e.g., [1.2, 1.2, 1.2])")
    parser.add_argument('--dataset_num_classes', type=int, required=True, help="Dataset class num (e.g., 36)")
    parser.add_argument('--GPU_num', type=int, required=True, help="Num of GPU to train (e.g., 1)")
    parser.add_argument('--schedule_file_name', type=str, required=True, help="Schedule file name (e.g., 'schedule_4k_check_400')")
    parser.add_argument('--train_type', type=str, required=True, help="Schedule train type (e.g., Epoch or Iteration)")
    parser.add_argument('--train_time_or_epoch_k', type=int, required=True, help="Schedule train time k (e.g., 4)")
    args = parser.parse_args()

    alg_name = args.alg_name
    dataset_file_name = args.dataset_file_name
    mean = json.loads(args.mean)  # JSON text -> list
    std = json.loads(args.std)    # JSON text -> list
    num_classes = args.dataset_num_classes
    schedule_file_name = args.schedule_file_name
    train_type = args.train_type
    train_time_or_epoch_k = args.train_time_or_epoch_k
    GPU_num = args.GPU_num

    ########### 3.1. _base_ and norm_cfg ###########
    _base_ = generate_base_config(alg_file_name=alg_file_name, dataset_file_name=dataset_file_name, schedule_file_name=schedule_file_name)
    norm_cfg = generate_norm_cfg(GPU_num=GPU_num)

    ########### 3.2. data_preprocessor ###########
    crop_size = select_crop_size(predefined_options=[(512, 512), (1024, 1024)])  # choose the crop size
    data_preprocessor = generate_data_preprocessor(crop_size=crop_size, mean=mean, std=std)

    ########### 3.3. model ###########
    model_data_preprocessor = data_preprocessor
    # Choose the backbone and whether to start from pretrained weights.
    model_list = ['openmmlab/resnet18_v1c', 'openmmlab/resnet50_v1c', 'openmmlab/resnet101_v1c']
    selected_model_name, select_pretrained, pretrained_pth, selected_model_info = select_pretrained_model(model_list=model_list, need_select_pretrained=True)
    depth = selected_model_info['depth']

    # Channel widths depend on the selected ResNet variant.
    if selected_model_name == 'openmmlab/resnet18_v1c':
        backbone_context_channels = (128, 256, 512)
        backbone_spatial_channels = (64, 64, 64, 128)
        backbone_out_channels = 256
        decode_head_in_channels = 256
        decode_head_channels = 256
        auxiliary_head_in_channels = 128
        auxiliary_head_channels = 64
    elif selected_model_name in ('openmmlab/resnet50_v1c', 'openmmlab/resnet101_v1c'):
        backbone_context_channels = (512, 1024, 2048)
        backbone_spatial_channels = (256, 256, 256, 512)
        backbone_out_channels = 1024
        decode_head_in_channels = 1024
        decode_head_channels = 1024
        auxiliary_head_in_channels = 512
        auxiliary_head_channels = 256
    else:
        # Previously an unknown name fell through and failed later with a
        # NameError on the channel variables; fail here with a clear message.
        raise ValueError(f"Unsupported backbone for bisenetv1: {selected_model_name!r}")

    # Inner ResNet config; init_cfg is only added when pretrained weights are
    # requested. Truthiness is used here so the config always agrees with the
    # 'Pre'/'NoPre' tag in the output file name below.
    backbone_cfg_kwargs = dict(type='ResNet', norm_cfg=norm_cfg, depth=depth)
    if select_pretrained:
        backbone_cfg_kwargs['init_cfg'] = dict(type='Pretrained', checkpoint=pretrained_pth)
    backbone_backbone_cfg = create_dict_by_kwargs(**backbone_cfg_kwargs)
    backbone = create_dict_by_kwargs(context_channels=backbone_context_channels, norm_cfg=norm_cfg, spatial_channels=backbone_spatial_channels, out_channels=backbone_out_channels, backbone_cfg=backbone_backbone_cfg)

    # Decode head.
    # NOTE(review): this loss dict is defined but NOT passed to decode_head
    # below, so the head keeps whatever loss the base model config declares.
    # Confirm whether loss_decode=decode_head_loss_decode_dict was intended. TODO
    decode_head_loss_decode_dict = dict(type='DiceLoss', use_sigmoid=False, loss_weight=1.0)
    decode_head = create_dict_by_kwargs(type='FCNHead', in_channels=decode_head_in_channels, channels=decode_head_channels, num_classes=num_classes)

    # Auxiliary heads. The loss can be switched to DiceLoss here, e.g.
    # dict(type='DiceLoss', use_sigmoid=False, loss_weight=0.4). TODO
    auxiliary_head_loss_decode_dict = dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)
    # Start from the auxiliary heads declared in the base model file ...
    auxiliary_head = get_var_from_py_file(alg_file_pth, var_name="model")["auxiliary_head"]
    # ... and apply two identical override entries. (The original built this
    # list twice; the first copy was overwritten before use and is dropped.)
    auxiliary_head_new = [
        create_dict_by_kwargs(type='FCNHead', norm_cfg=norm_cfg, in_channels=auxiliary_head_in_channels, channels=auxiliary_head_channels, num_classes=num_classes, align_corners=False, loss_decode=auxiliary_head_loss_decode_dict)
        for _ in range(2)
    ]
    auxiliary_head = update_list_dict_var(auxiliary_head, auxiliary_head_new)

    model = dict(
        backbone=backbone,
        data_preprocessor=model_data_preprocessor,
        decode_head=decode_head,
        auxiliary_head=auxiliary_head,
    )

    ########### 3.4. optim_wrapper and param_scheduler ###########
    optim_wrapper = generate_optim_wrapper()
    param_scheduler = generate_param_scheduler(train_type, train_time_or_epoch_k)

    ########### 4. Write the config ###########
    # File name layout:
    #   {alg}_r{depth}_{Pre|NoPre}_g{GPU count}-{schedule}_{dataset}-{crop_h}x{crop_w}.py
    # Remove the literal "schedule_" prefix. str.lstrip('schedule_') strips a
    # *set of characters*, which mangled names such as 'schedule_epoch_100'
    # into 'poch_100'.
    schedule_prefix = 'schedule_'
    if schedule_file_name.startswith(schedule_prefix):
        schedule_tag = schedule_file_name[len(schedule_prefix):]
    else:
        schedule_tag = schedule_file_name
    alg_file_name = f"{alg_name}_r{depth}_{'Pre' if select_pretrained else 'NoPre'}_g{GPU_num}-{schedule_tag}_{dataset_file_name}-{crop_size[0]}x{crop_size[1]}.py"
    output_configs_alg_my_alg = os.path.join('./configs/bisenetv1/', alg_file_name)
    write_config_to_file(output_configs_alg_my_alg, _base_=_base_, norm_cfg=norm_cfg, crop_size=crop_size, data_preprocessor=data_preprocessor, model=model, optim_wrapper=optim_wrapper, param_scheduler=param_scheduler)
    # Record the generated config name/path in a temp file as JSON.
    with open("_temp_.txt", "w") as file:
        json.dump({'alg_infos': {'alg_file_name': alg_file_name, 'alg_file_pth': output_configs_alg_my_alg}}, file)

View File

@@ -0,0 +1,111 @@
import os, sys, argparse, json
# 添加上级目录到 sys.path
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../')))
from Initial_Alg_Program.Initial_Alg_Select_Tool import select_crop_size, select_pretrained_model, select_sampler
from Initial_Alg_Program.Initial_Alg_Gen_Tool import format_all_data, write_config_to_file, create_dict_by_kwargs, get_var_from_file, update_list_dict_var, get_var_from_py_file
from Initial_Alg_Program.Initial_Alg_Gen_0_base_ import generate_base_config
from Initial_Alg_Program.Initial_Alg_Gen_1_norm_cfg import generate_norm_cfg
from Initial_Alg_Program.Initial_Alg_Gen_2_data_preprocessor import generate_data_preprocessor
from Initial_Alg_Program.Initial_Alg_Gen_3_model import generate_model_pretrained, generate_model_backbone, generate_model_data_preprocessor, generate_model_decode_head, generate_model_auxiliary_head, generate_model_train_cfg, generate_model_test_cfg
from Initial_Alg_Program.Initial_Alg_Gen_4_optimizer import generate_optim_wrapper, generate_param_scheduler
if __name__ == '__main__':
    # Generates an MMSegmentation config file for BiSeNetV2 from command-line
    # arguments plus interactive selections (crop size, optional pixel
    # sampler), then records the generated file name/path in "_temp_.txt".

    ########### 1. Hyper-parameters: algorithm defaults ###########
    alg_file_name = 'bisenetv2'
    alg_file_pth = f"configs/_base_/models/{alg_file_name}.py"

    ########### 2. Argument parsing ###########
    parser = argparse.ArgumentParser(description="Run algorithm with specified parameters.")
    parser.add_argument('--alg_name', type=str, required=True, help="Algorithm name (e.g., 'ann')")
    parser.add_argument('--dataset_file_name', type=str, required=True, help="Dataset file name (e.g., 'my_dataset_model')")
    parser.add_argument('--mean', type=str, required=True, help="Mean of dataset (e.g., [1.2, 1.2, 1.2])")
    # Fixed: the help text previously described this option as the mean.
    parser.add_argument('--std', type=str, required=True, help="Std of dataset (e.g., [1.2, 1.2, 1.2])")
    parser.add_argument('--dataset_num_classes', type=int, required=True, help="Dataset class num (e.g., 36)")
    parser.add_argument('--GPU_num', type=int, required=True, help="Num of GPU to train (e.g., 1)")
    parser.add_argument('--schedule_file_name', type=str, required=True, help="Schedule file name (e.g., 'schedule_4k_check_400')")
    parser.add_argument('--train_type', type=str, required=True, help="Schedule train type (e.g., Epoch or Iteration)")
    parser.add_argument('--train_time_or_epoch_k', type=int, required=True, help="Schedule train time k (e.g., 4)")
    args = parser.parse_args()

    alg_name = args.alg_name
    dataset_file_name = args.dataset_file_name
    mean = json.loads(args.mean)  # JSON text -> list
    std = json.loads(args.std)    # JSON text -> list
    num_classes = args.dataset_num_classes
    schedule_file_name = args.schedule_file_name
    train_type = args.train_type
    train_time_or_epoch_k = args.train_time_or_epoch_k
    GPU_num = args.GPU_num

    ########### 3.1. _base_ and norm_cfg ###########
    _base_ = generate_base_config(alg_file_name=alg_file_name, dataset_file_name=dataset_file_name, schedule_file_name=schedule_file_name)
    norm_cfg = generate_norm_cfg(GPU_num=GPU_num)

    ########### 3.2. data_preprocessor ###########
    crop_size = select_crop_size(predefined_options=[(512, 512)])  # choose the crop size
    data_preprocessor = generate_data_preprocessor(crop_size=crop_size, mean=mean, std=std)

    ########### 3.3. model ###########
    # 3.3.1. data_preprocessor used inside the model dict
    model_data_preprocessor = data_preprocessor
    # 3.3.2. optional pixel sampler, shared by the decode and auxiliary heads
    selected_sampler_name, use_sampler, selected_sampler_info = select_sampler(sampler_list=['OHEMPixelSampler'])
    if use_sampler == True:
        decode_head_sampler = selected_sampler_info
        auxiliary_head_sampler = selected_sampler_info
        use_sampler_txt = 'ohempSampler'  # tag embedded in the output file name
    else:
        decode_head_sampler = None
        auxiliary_head_sampler = None
        use_sampler_txt = ''
    # 3.3.3. decode head
    # Alternative loss: dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0). TODO
    decode_head_loss_decode_dict = dict(type='DiceLoss', use_sigmoid=False, loss_weight=1.0)
    decode_head = create_dict_by_kwargs(type='FCNHead', loss_decode=decode_head_loss_decode_dict, sampler=decode_head_sampler, num_classes=num_classes)
    # 3.3.4. auxiliary heads
    # Alternative loss: dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0). TODO
    auxiliary_head_loss_decode_dict = dict(type='DiceLoss', use_sigmoid=False, loss_weight=0.4, reduction="none")
    # Start from the auxiliary heads declared in the base model file ...
    auxiliary_head = get_var_from_py_file(alg_file_pth, var_name="model")["auxiliary_head"]
    # ... and apply four identical override entries (one per auxiliary head).
    auxiliary_head_new = [
        create_dict_by_kwargs(type='FCNHead', sampler=auxiliary_head_sampler, norm_cfg=norm_cfg, align_corners=False, loss_decode=auxiliary_head_loss_decode_dict, num_classes=num_classes)
        for _ in range(4)
    ]
    auxiliary_head = update_list_dict_var(auxiliary_head, auxiliary_head_new)
    # 3.3.5. assemble the model overrides
    model = dict(
        data_preprocessor=model_data_preprocessor,
        decode_head=decode_head,
        auxiliary_head=auxiliary_head,
    )

    ########### 3.4. optim_wrapper and param_scheduler ###########
    optim_wrapper = generate_optim_wrapper()
    param_scheduler = generate_param_scheduler(train_type, train_time_or_epoch_k)

    ########### 4. Write the config ###########
    # 4.1. File name layout:
    #   {alg}_fcn_[{sampler tag}_]g{GPU count}-{schedule}_{dataset}-{crop_h}x{crop_w}.py
    # Remove the literal "schedule_" prefix. str.lstrip('schedule_') strips a
    # *set of characters*, which mangled names such as 'schedule_epoch_100'
    # into 'poch_100'.
    schedule_prefix = 'schedule_'
    if schedule_file_name.startswith(schedule_prefix):
        schedule_tag = schedule_file_name[len(schedule_prefix):]
    else:
        schedule_tag = schedule_file_name
    sampler_tag = use_sampler_txt + '_' if use_sampler_txt != '' else ''
    alg_file_name = f"{alg_name}_fcn_{sampler_tag}g{GPU_num}-{schedule_tag}_{dataset_file_name}-{crop_size[0]}x{crop_size[1]}.py"
    output_configs_alg_my_alg = os.path.join('./configs/bisenetv2/', alg_file_name)
    write_config_to_file(output_configs_alg_my_alg, _base_=_base_, norm_cfg=norm_cfg, crop_size=crop_size, data_preprocessor=data_preprocessor, model=model, optim_wrapper=optim_wrapper, param_scheduler=param_scheduler)
    # 4.2. Record the generated config name/path in a temp file as JSON.
    with open("_temp_.txt", "w") as file:
        json.dump({'alg_infos': {'alg_file_name': alg_file_name, 'alg_file_pth': output_configs_alg_my_alg}}, file)

View File

@@ -0,0 +1,102 @@
import os, sys, argparse, json
# 添加上级目录到 sys.path
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../')))
from Initial_Alg_Program.Initial_Alg_Select_Tool import select_crop_size, select_pretrained_model, select_test_cfg_slide
from Initial_Alg_Program.Initial_Alg_Gen_Tool import format_all_data, write_config_to_file
from Initial_Alg_Program.Initial_Alg_Gen_0_base_ import generate_base_config
from Initial_Alg_Program.Initial_Alg_Gen_1_norm_cfg import generate_norm_cfg
from Initial_Alg_Program.Initial_Alg_Gen_2_data_preprocessor import generate_data_preprocessor
from Initial_Alg_Program.Initial_Alg_Gen_3_model import generate_model_pretrained, generate_model_backbone, generate_model_data_preprocessor, generate_model_decode_head, generate_model_auxiliary_head, generate_model_train_cfg, generate_model_test_cfg
from Initial_Alg_Program.Initial_Alg_Gen_4_optimizer import generate_optim_wrapper, generate_param_scheduler
if __name__ == '__main__':
    # Generates an MMSegmentation config file for CCNet from command-line
    # arguments plus interactive selections (crop size, backbone, test mode),
    # then records the generated file name/path in "_temp_.txt" as JSON.

    ########### 1. Hyper-parameters: algorithm defaults ###########
    alg_file_name = 'ccnet_r50-d8'

    ########### 2. Argument parsing ###########
    parser = argparse.ArgumentParser(description="Run algorithm with specified parameters.")
    parser.add_argument('--alg_name', type=str, required=True, help="Algorithm name (e.g., 'ann')")
    parser.add_argument('--dataset_file_name', type=str, required=True, help="Dataset file name (e.g., 'my_dataset_model')")
    parser.add_argument('--mean', type=str, required=True, help="Mean of dataset (e.g., [1.2, 1.2, 1.2])")
    # Fixed: the help text previously described this option as the mean.
    parser.add_argument('--std', type=str, required=True, help="Std of dataset (e.g., [1.2, 1.2, 1.2])")
    parser.add_argument('--dataset_num_classes', type=int, required=True, help="Dataset class num (e.g., 36)")
    parser.add_argument('--GPU_num', type=int, required=True, help="Num of GPU to train (e.g., 1)")
    parser.add_argument('--schedule_file_name', type=str, required=True, help="Schedule file name (e.g., 'schedule_4k_check_400')")
    parser.add_argument('--train_type', type=str, required=True, help="Schedule train type (e.g., Epoch or Iteration)")
    parser.add_argument('--train_time_or_epoch_k', type=int, required=True, help="Schedule train time k (e.g., 4)")
    args = parser.parse_args()

    alg_name = args.alg_name
    dataset_file_name = args.dataset_file_name
    mean = json.loads(args.mean)  # JSON text -> list
    std = json.loads(args.std)    # JSON text -> list
    num_classes = args.dataset_num_classes
    schedule_file_name = args.schedule_file_name
    train_type = args.train_type
    train_time_or_epoch_k = args.train_time_or_epoch_k
    GPU_num = args.GPU_num

    ########### 3. Parameters ###########
    # 3.1. Interactive choices
    crop_size = select_crop_size(predefined_options=[(512, 512), (512, 1024), (769, 769)])  # choose the crop size
    model_list = ['openmmlab/resnet50_v1c', 'openmmlab/resnet101_v1c']
    selected_model_name, select_pretrained, pretrained_pth, selected_model_info = select_pretrained_model(model_list=model_list)
    depth = selected_model_info['depth']
    test_cfg_mode, test_cfg_crop_div_stride = select_test_cfg_slide(crop_size)
    # 3.2. Derived parameters
    # Losses for the decode and auxiliary heads. TODO
    decode_head_loss_decode_dict = dict(type='DiceLoss', use_sigmoid=False, loss_weight=1.0)
    auxiliary_head_loss_decode_dict = dict(type='DiceLoss', use_sigmoid=False, loss_weight=0.4)
    # align_corners is enabled only for "slide" test mode; the mode is also
    # tagged in the output file name.
    if test_cfg_mode == "slide":
        align_corners = True
        test_slide = 'testslide'
    else:
        align_corners = False
        test_slide = 'no_testslide'

    ########### 3. Build each config section ###########
    _base_ = generate_base_config(alg_file_name=alg_file_name, dataset_file_name=dataset_file_name, schedule_file_name=schedule_file_name)
    norm_cfg = generate_norm_cfg(GPU_num=GPU_num)
    data_preprocessor = generate_data_preprocessor(crop_size=crop_size, mean=mean, std=std)
    pretrained = generate_model_pretrained(pretrained_pth=pretrained_pth)
    backbone = generate_model_backbone(depth=depth)
    model_data_preprocessor = generate_model_data_preprocessor(model_data_preprocessor=data_preprocessor)
    decode_head = generate_model_decode_head(num_classes=num_classes, decode_head_loss_decode_dict=decode_head_loss_decode_dict, align_corners=align_corners)
    auxiliary_head = generate_model_auxiliary_head(num_classes=num_classes, auxiliary_head_loss_decode_dict=auxiliary_head_loss_decode_dict, align_corners=align_corners)
    # Assemble the model overrides.
    model = dict(
        pretrained=pretrained,
        backbone=backbone,
        data_preprocessor=model_data_preprocessor,
        decode_head=decode_head,
        auxiliary_head=auxiliary_head,
    )
    test_cfg = generate_model_test_cfg(test_cfg_mode=test_cfg_mode, crop_size=crop_size, test_cfg_crop_div_stride=test_cfg_crop_div_stride)
    optim_wrapper = generate_optim_wrapper()
    param_scheduler = generate_param_scheduler(train_type, train_time_or_epoch_k)

    ########### 4. Write the config ###########
    # File name layout:
    #   {alg}_r{depth}_{Pre|NoPre}_g{GPU count}-{schedule}_{dataset}-{crop_h}x{crop_w}-{test mode tag}.py
    # Remove the literal "schedule_" prefix. str.lstrip('schedule_') strips a
    # *set of characters*, which mangled names such as 'schedule_epoch_100'
    # into 'poch_100'.
    schedule_prefix = 'schedule_'
    if schedule_file_name.startswith(schedule_prefix):
        schedule_tag = schedule_file_name[len(schedule_prefix):]
    else:
        schedule_tag = schedule_file_name
    alg_file_name = f"{alg_name}_r{depth}_{'Pre' if select_pretrained else 'NoPre'}_g{GPU_num}-{schedule_tag}_{dataset_file_name}-{crop_size[0]}x{crop_size[1]}-{test_slide}.py"
    output_configs_alg_my_alg = os.path.join('./configs/ccnet/', alg_file_name)
    write_config_to_file(output_configs_alg_my_alg, _base_=_base_, norm_cfg=norm_cfg, crop_size=crop_size, data_preprocessor=data_preprocessor, model=model, test_cfg=test_cfg, optim_wrapper=optim_wrapper, param_scheduler=param_scheduler)
    # Record the generated config name/path in a temp file as JSON.
    with open("_temp_.txt", "w") as file:
        json.dump({'alg_infos': {'alg_file_name': alg_file_name, 'alg_file_pth': output_configs_alg_my_alg}}, file)

View File

@@ -0,0 +1,95 @@
import os, sys, argparse, json
# 添加上级目录到 sys.path
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../')))
from Initial_Alg_Program.Initial_Alg_Select_Tool import select_crop_size, select_pretrained_model, select_test_cfg_slide
from Initial_Alg_Program.Initial_Alg_Gen_Tool import format_all_data, write_config_to_file, get_var_from_py_file
from Initial_Alg_Program.Initial_Alg_Gen_0_base_ import generate_base_config
from Initial_Alg_Program.Initial_Alg_Gen_1_norm_cfg import generate_norm_cfg
from Initial_Alg_Program.Initial_Alg_Gen_2_data_preprocessor import generate_data_preprocessor
from Initial_Alg_Program.Initial_Alg_Gen_3_model import generate_model_pretrained, generate_model_backbone, generate_model_data_preprocessor, generate_model_decode_head, generate_model_auxiliary_head, generate_model_train_cfg, generate_model_test_cfg
from Initial_Alg_Program.Initial_Alg_Gen_4_optimizer import generate_optim_wrapper, generate_param_scheduler
if __name__ == '__main__':
    # Generates an MMSegmentation config file for CGNet from command-line
    # arguments plus an interactive crop-size selection, then records the
    # generated file name/path in "_temp_.txt" as JSON.

    ########### 1. Hyper-parameters: algorithm defaults ###########
    alg_file_name = 'cgnet'

    ########### 2. Argument parsing ###########
    parser = argparse.ArgumentParser(description="Run algorithm with specified parameters.")
    parser.add_argument('--alg_name', type=str, required=True, help="Algorithm name (e.g., 'ann')")
    parser.add_argument('--dataset_file_name', type=str, required=True, help="Dataset file name (e.g., 'my_dataset_model')")
    parser.add_argument('--mean', type=str, required=True, help="Mean of dataset (e.g., [1.2, 1.2, 1.2])")
    # Fixed: the help text previously described this option as the mean.
    parser.add_argument('--std', type=str, required=True, help="Std of dataset (e.g., [1.2, 1.2, 1.2])")
    parser.add_argument('--dataset_num_classes', type=int, required=True, help="Dataset class num (e.g., 36)")
    parser.add_argument('--GPU_num', type=int, required=True, help="Num of GPU to train (e.g., 1)")
    parser.add_argument('--schedule_file_name', type=str, required=True, help="Schedule file name (e.g., 'schedule_4k_check_400')")
    parser.add_argument('--train_type', type=str, required=True, help="Schedule train type (e.g., Epoch or Iteration)")
    parser.add_argument('--train_time_or_epoch_k', type=int, required=True, help="Schedule train time k (e.g., 4)")
    args = parser.parse_args()

    alg_name = args.alg_name
    dataset_file_name = args.dataset_file_name
    mean = json.loads(args.mean)  # JSON text -> list
    std = json.loads(args.std)    # JSON text -> list
    num_classes = args.dataset_num_classes
    schedule_file_name = args.schedule_file_name
    train_type = args.train_type
    train_time_or_epoch_k = args.train_time_or_epoch_k
    GPU_num = args.GPU_num

    ########### 3. Parameters ###########
    # 3.1. Interactive choices
    crop_size = select_crop_size(predefined_options=[(680, 680), (512, 1024)])  # choose the crop size
    # 3.2. Derived parameters
    # Start from the decode head declared in the base model file.
    decode_head = get_var_from_py_file(os.path.join('./configs/_base_/models', alg_file_name+'.py'), 'model')['decode_head']
    # The original script also built the loss below and immediately
    # overwrote it; it is kept here only as a reference ("Way Ori"):
    #   dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0,
    #        class_weight=[2.5959933, 6.7415504, 3.5354059, 9.8663225, 9.690899,
    #                      9.369352, 10.289121, 9.953208, 4.3097677, 9.490387,
    #                      7.674431, 9.396905, 10.347791, 6.3927646, 10.226669,
    #                      10.241062, 10.280587, 10.396974, 10.055647])
    # Way 1: DiceLoss. Both the loss dict and the decode head carry
    # _delete_=True. TODO
    decode_head_loss_decode_dict = dict(_delete_=True, type='DiceLoss', use_sigmoid=False, loss_weight=1.0)
    decode_head['_delete_'] = True
    decode_head['loss_decode'] = decode_head_loss_decode_dict

    ########### 3. Build each config section ###########
    _base_ = generate_base_config(alg_file_name=alg_file_name, dataset_file_name=dataset_file_name, schedule_file_name=schedule_file_name)
    norm_cfg = generate_norm_cfg(GPU_num=GPU_num)
    data_preprocessor = generate_data_preprocessor(crop_size=crop_size, mean=mean, std=std)
    model_data_preprocessor = generate_model_data_preprocessor(model_data_preprocessor=data_preprocessor)
    # Assemble the model overrides (no test_cfg is generated for this algorithm).
    model = dict(data_preprocessor=model_data_preprocessor, decode_head=decode_head)
    optim_wrapper = generate_optim_wrapper()
    param_scheduler = generate_param_scheduler(train_type, train_time_or_epoch_k)

    ########### 4. Write the config ###########
    # File name layout:
    #   {alg}_fcn_g{GPU count}-{schedule}_{dataset}-{crop_h}x{crop_w}.py
    # Remove the literal "schedule_" prefix. str.lstrip('schedule_') strips a
    # *set of characters*, which mangled names such as 'schedule_epoch_100'
    # into 'poch_100'.
    schedule_prefix = 'schedule_'
    if schedule_file_name.startswith(schedule_prefix):
        schedule_tag = schedule_file_name[len(schedule_prefix):]
    else:
        schedule_tag = schedule_file_name
    alg_file_name = f"{alg_name}_fcn_g{GPU_num}-{schedule_tag}_{dataset_file_name}-{crop_size[0]}x{crop_size[1]}.py"
    output_configs_alg_my_alg = os.path.join('./configs/cgnet/', alg_file_name)
    write_config_to_file(output_configs_alg_my_alg, _base_=_base_, norm_cfg=norm_cfg, crop_size=crop_size, data_preprocessor=data_preprocessor, model=model, optim_wrapper=optim_wrapper, param_scheduler=param_scheduler)
    # Record the generated config name/path in a temp file as JSON.
    with open("_temp_.txt", "w") as file:
        json.dump({'alg_infos': {'alg_file_name': alg_file_name, 'alg_file_pth': output_configs_alg_my_alg}}, file)

View File

@@ -0,0 +1,102 @@
import os, sys, argparse, json
# 添加上级目录到 sys.path
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../')))
from Initial_Alg_Program.Initial_Alg_Select_Tool import select_crop_size, select_pretrained_model, select_test_cfg_slide
from Initial_Alg_Program.Initial_Alg_Gen_Tool import format_all_data, write_config_to_file
from Initial_Alg_Program.Initial_Alg_Gen_0_base_ import generate_base_config
from Initial_Alg_Program.Initial_Alg_Gen_1_norm_cfg import generate_norm_cfg
from Initial_Alg_Program.Initial_Alg_Gen_2_data_preprocessor import generate_data_preprocessor
from Initial_Alg_Program.Initial_Alg_Gen_3_model import generate_model_pretrained, generate_model_backbone, generate_model_data_preprocessor, generate_model_decode_head, generate_model_auxiliary_head, generate_model_train_cfg, generate_model_test_cfg
from Initial_Alg_Program.Initial_Alg_Gen_4_optimizer import generate_optim_wrapper, generate_param_scheduler
if __name__ == '__main__':
    # Generates an MMSegmentation config file for DANet from command-line
    # arguments plus interactive selections (crop size, backbone, test mode),
    # then records the generated file name/path in "_temp_.txt" as JSON.

    ########### 1. Hyper-parameters: algorithm defaults ###########
    alg_file_name = 'danet_r50-d8'

    ########### 2. Argument parsing ###########
    parser = argparse.ArgumentParser(description="Run algorithm with specified parameters.")
    parser.add_argument('--alg_name', type=str, required=True, help="Algorithm name (e.g., 'ann')")
    parser.add_argument('--dataset_file_name', type=str, required=True, help="Dataset file name (e.g., 'my_dataset_model')")
    parser.add_argument('--mean', type=str, required=True, help="Mean of dataset (e.g., [1.2, 1.2, 1.2])")
    # Fixed: the help text previously described this option as the mean.
    parser.add_argument('--std', type=str, required=True, help="Std of dataset (e.g., [1.2, 1.2, 1.2])")
    parser.add_argument('--dataset_num_classes', type=int, required=True, help="Dataset class num (e.g., 36)")
    parser.add_argument('--GPU_num', type=int, required=True, help="Num of GPU to train (e.g., 1)")
    parser.add_argument('--schedule_file_name', type=str, required=True, help="Schedule file name (e.g., 'schedule_4k_check_400')")
    parser.add_argument('--train_type', type=str, required=True, help="Schedule train type (e.g., Epoch or Iteration)")
    parser.add_argument('--train_time_or_epoch_k', type=int, required=True, help="Schedule train time k (e.g., 4)")
    args = parser.parse_args()

    alg_name = args.alg_name
    dataset_file_name = args.dataset_file_name
    mean = json.loads(args.mean)  # JSON text -> list
    std = json.loads(args.std)    # JSON text -> list
    num_classes = args.dataset_num_classes
    schedule_file_name = args.schedule_file_name
    train_type = args.train_type
    train_time_or_epoch_k = args.train_time_or_epoch_k
    GPU_num = args.GPU_num

    ########### 3. Parameters ###########
    # 3.1. Interactive choices
    crop_size = select_crop_size(predefined_options=[(512, 512), (512, 1024), (769, 769)])  # choose the crop size
    model_list = ['openmmlab/resnet50_v1c', 'openmmlab/resnet101_v1c']
    selected_model_name, select_pretrained, pretrained_pth, selected_model_info = select_pretrained_model(model_list=model_list)
    depth = selected_model_info['depth']
    test_cfg_mode, test_cfg_crop_div_stride = select_test_cfg_slide(crop_size)
    # 3.2. Derived parameters
    # Losses for the decode and auxiliary heads. TODO
    decode_head_loss_decode_dict = dict(type='DiceLoss', use_sigmoid=False, loss_weight=1.0)
    auxiliary_head_loss_decode_dict = dict(type='DiceLoss', use_sigmoid=False, loss_weight=0.4)
    # align_corners is enabled only for "slide" test mode; the mode is also
    # tagged in the output file name.
    if test_cfg_mode == "slide":
        align_corners = True
        test_slide = 'testslide'
    else:
        align_corners = False
        test_slide = 'no_testslide'

    ########### 3. Build each config section ###########
    _base_ = generate_base_config(alg_file_name=alg_file_name, dataset_file_name=dataset_file_name, schedule_file_name=schedule_file_name)
    norm_cfg = generate_norm_cfg(GPU_num=GPU_num)
    data_preprocessor = generate_data_preprocessor(crop_size=crop_size, mean=mean, std=std)
    pretrained = generate_model_pretrained(pretrained_pth=pretrained_pth)
    backbone = generate_model_backbone(depth=depth)
    model_data_preprocessor = generate_model_data_preprocessor(model_data_preprocessor=data_preprocessor)
    decode_head = generate_model_decode_head(num_classes=num_classes, decode_head_loss_decode_dict=decode_head_loss_decode_dict, align_corners=align_corners)
    auxiliary_head = generate_model_auxiliary_head(num_classes=num_classes, auxiliary_head_loss_decode_dict=auxiliary_head_loss_decode_dict, align_corners=align_corners)
    # Assemble the model overrides.
    model = dict(
        pretrained=pretrained,
        backbone=backbone,
        data_preprocessor=model_data_preprocessor,
        decode_head=decode_head,
        auxiliary_head=auxiliary_head,
    )
    test_cfg = generate_model_test_cfg(test_cfg_mode=test_cfg_mode, crop_size=crop_size, test_cfg_crop_div_stride=test_cfg_crop_div_stride)
    optim_wrapper = generate_optim_wrapper()
    param_scheduler = generate_param_scheduler(train_type, train_time_or_epoch_k)

    ########### 4. Write the config ###########
    # File name layout:
    #   {alg}_r{depth}_{Pre|NoPre}_g{GPU count}-{schedule}_{dataset}-{crop_h}x{crop_w}-{test mode tag}.py
    # Remove the literal "schedule_" prefix. str.lstrip('schedule_') strips a
    # *set of characters*, which mangled names such as 'schedule_epoch_100'
    # into 'poch_100'.
    schedule_prefix = 'schedule_'
    if schedule_file_name.startswith(schedule_prefix):
        schedule_tag = schedule_file_name[len(schedule_prefix):]
    else:
        schedule_tag = schedule_file_name
    alg_file_name = f"{alg_name}_r{depth}_{'Pre' if select_pretrained else 'NoPre'}_g{GPU_num}-{schedule_tag}_{dataset_file_name}-{crop_size[0]}x{crop_size[1]}-{test_slide}.py"
    output_configs_alg_my_alg = os.path.join('./configs/danet/', alg_file_name)
    write_config_to_file(output_configs_alg_my_alg, _base_=_base_, norm_cfg=norm_cfg, crop_size=crop_size, data_preprocessor=data_preprocessor, model=model, test_cfg=test_cfg, optim_wrapper=optim_wrapper, param_scheduler=param_scheduler)
    # Record the generated config name/path in a temp file as JSON.
    with open("_temp_.txt", "w") as file:
        json.dump({'alg_infos': {'alg_file_name': alg_file_name, 'alg_file_pth': output_configs_alg_my_alg}}, file)

View File

@@ -0,0 +1,95 @@
import os, sys, argparse, json
import importlib.util
# 添加上级目录到 sys.path
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../')))
from Initial_Alg_Program.Initial_Alg_Select_Tool import select_crop_size, select_pretrained_model, select_sampler
from Initial_Alg_Program.Initial_Alg_Gen_Tool import format_all_data, write_config_to_file, create_dict_by_kwargs, get_var_from_file, update_list_dict_var, get_var_from_py_file
from Initial_Alg_Program.Initial_Alg_Gen_0_base_ import generate_base_config
from Initial_Alg_Program.Initial_Alg_Gen_1_norm_cfg import generate_norm_cfg
from Initial_Alg_Program.Initial_Alg_Gen_2_data_preprocessor import generate_data_preprocessor
from Initial_Alg_Program.Initial_Alg_Gen_3_model import generate_model_pretrained, generate_model_backbone, generate_model_data_preprocessor, generate_model_decode_head, generate_model_auxiliary_head, generate_model_train_cfg, generate_model_test_cfg
from Initial_Alg_Program.Initial_Alg_Gen_4_optimizer import generate_optim_wrapper, generate_param_scheduler
if __name__ == '__main__':
    # Script entry point: assembles the sections of a DDRNet training config
    # from command-line arguments and from the values returned by the project's
    # select_* / generate_* helpers, writes the config under ./configs/ddrnet/,
    # and records the generated file name and path in _temp_.txt.
    ########### 1. Hyperparameters: algorithm defaults ###########
    alg_file_name = 'ddrnet'
    alg_file_pth = f"configs/_base_/models/{alg_file_name}.py"
    ########### 2. Argument parsing ###########
    parser = argparse.ArgumentParser(description="Run algorithm with specified parameters.")
    # Register the command-line arguments (all of them are required).
    parser.add_argument('--alg_name', type=str, required=True, help="Algorithm name (e.g., 'ann')")
    parser.add_argument('--dataset_file_name', type=str, required=True, help="Dataset file name (e.g., 'my_dataset_model')")
    parser.add_argument('--mean', type=str, required=True, help="Mean of dataset (e.g., [1.2, 1.2, 1.2])")
    # Help text fixed: this option is the standard deviation, not the mean.
    parser.add_argument('--std', type=str, required=True, help="Std of dataset (e.g., [1.2, 1.2, 1.2])")
    parser.add_argument('--dataset_num_classes', type=int, required=True, help="Dataset class num (e.g., 36)")
    parser.add_argument('--GPU_num', type=int, required=True, help="Num of GPU to train (e.g., 1)")
    parser.add_argument('--schedule_file_name', type=str, required=True, help="Schedule file name (e.g., 'schedule_4k_check_400')")
    parser.add_argument('--train_type', type=str, required=True, help="Schedule train type (e.g., Epoch or Iteration)")
    parser.add_argument('--train_time_or_epoch_k', type=int, required=True, help="Schedule train time k (e.g., 4)")
    # Parse the command line.
    args = parser.parse_args()
    # Copy the parsed values into local variables.
    alg_name = args.alg_name
    dataset_file_name = args.dataset_file_name
    mean = json.loads(args.mean)  # JSON text -> Python list
    std = json.loads(args.std)  # JSON text -> Python list
    num_classes = args.dataset_num_classes
    schedule_file_name = args.schedule_file_name
    train_type = args.train_type
    train_time_or_epoch_k = args.train_time_or_epoch_k
    GPU_num = args.GPU_num
    ########### 3.1. Build the _base_ section ###########
    _base_ = generate_base_config(alg_file_name=alg_file_name, dataset_file_name=dataset_file_name, schedule_file_name=schedule_file_name)
    ########### 3.1. Build the norm_cfg section ###########
    norm_cfg = generate_norm_cfg(GPU_num=GPU_num)
    ########### 3.2. Build the data_preprocessor section ###########
    crop_size = select_crop_size(predefined_options=[(512, 512)])  # choose the crop size
    data_preprocessor = generate_data_preprocessor(crop_size=crop_size, mean=mean, std=std)
    ########### 3.3. Build the model section ###########
    # 3.3.1. data_preprocessor used inside the model dict
    model_data_preprocessor = data_preprocessor
    # 3.3.2. backbone and decode_head, sized according to the selected variant
    model_list = ["openmmlab/ddrnet23-s", "openmmlab/ddrnet23"]
    selected_model_name, select_pretrained, pretrained_pth, selected_model_info = select_pretrained_model(model_list=model_list)
    if selected_model_name == "openmmlab/ddrnet23-s":
        backbone = dict(channels=32, init_cfg=dict(type='Pretrained', checkpoint=pretrained_pth))
        decode_head = dict(in_channels=32 * 4, channels=64, num_classes=num_classes)
        model_size = 'small'
    elif selected_model_name == "openmmlab/ddrnet23":
        backbone = dict(channels=64, init_cfg=dict(type='Pretrained', checkpoint=pretrained_pth))
        decode_head = dict(in_channels=64 * 4, channels=128, num_classes=num_classes)
        model_size = 'normal'
    else:
        # sys.exit instead of quit(): quit is injected by the `site` module and
        # is not guaranteed to exist (e.g. under `python -S`).
        sys.exit("Error: 未知的模型名称")
    # 3.3.4. Assemble the model dict
    model = dict(
        data_preprocessor=model_data_preprocessor,
        backbone=backbone,
        decode_head=decode_head,
    )
    ########### 3.4. Build the optim_wrapper and param_scheduler sections ###########
    optim_wrapper = generate_optim_wrapper()
    param_scheduler = generate_param_scheduler(train_type, train_time_or_epoch_k)
    ########### 4. Write the outputs ###########
    # 4.1. Config file name scheme:
    #   <alg>_<model size>_g<GPU count>-<schedule>_<dataset>-<crop H>x<crop W>.py
    # Remove the literal 'schedule_' prefix only. str.lstrip('schedule_') strips
    # any leading run of the characters s/c/h/e/d/u/l/_ and would, for example,
    # turn 'schedule_epoch_100' into 'poch_100'.
    schedule_prefix = 'schedule_'
    if schedule_file_name.startswith(schedule_prefix):
        schedule_tag = schedule_file_name[len(schedule_prefix):]
    else:
        schedule_tag = schedule_file_name
    alg_file_name = f"{alg_name}_{model_size}_g{GPU_num}-{schedule_tag}_{dataset_file_name}-{crop_size[0]}x{crop_size[1]}.py"
    output_configs_alg_my_alg = os.path.join('./configs/ddrnet/', alg_file_name)
    write_config_to_file(output_configs_alg_my_alg, _base_=_base_, norm_cfg=norm_cfg, crop_size=crop_size, data_preprocessor=data_preprocessor, model=model, optim_wrapper=optim_wrapper, param_scheduler=param_scheduler)
    # 4.2. Record the generated file name and path in a temporary file
    # (presumably read back by the calling script — TODO confirm).
    with open("_temp_.txt", "w") as file:
        json.dump({'alg_infos': {'alg_file_name': alg_file_name, 'alg_file_pth': output_configs_alg_my_alg}}, file)

View File

@@ -0,0 +1,135 @@
import os, sys, argparse, json
# 添加上级目录到 sys.path
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../')))
from Initial_Alg_Program.Initial_Alg_Select_Tool import select_crop_size, select_pretrained_model, select_test_cfg_slide, select_sampler
from Initial_Alg_Program.Initial_Alg_Gen_Tool import format_all_data, write_config_to_file, create_dict_by_kwargs
from Initial_Alg_Program.Initial_Alg_Gen_0_base_ import generate_base_config
from Initial_Alg_Program.Initial_Alg_Gen_1_norm_cfg import generate_norm_cfg
from Initial_Alg_Program.Initial_Alg_Gen_2_data_preprocessor import generate_data_preprocessor
from Initial_Alg_Program.Initial_Alg_Gen_3_model import generate_model_pretrained, generate_model_backbone, generate_model_data_preprocessor, generate_model_decode_head, generate_model_auxiliary_head, generate_model_train_cfg, generate_model_test_cfg
from Initial_Alg_Program.Initial_Alg_Gen_4_optimizer import generate_optim_wrapper, generate_param_scheduler
if __name__ == '__main__':
    # Script entry point: assembles the sections of a DeepLabV3 training config
    # from command-line arguments and from the values returned by the project's
    # select_* / generate_* helpers, writes the config under
    # ./configs/deeplabv3/, and records the generated file name and path in
    # _temp_.txt.
    ########### 1. Hyperparameters: algorithm defaults ###########
    alg_file_name = 'deeplabv3_r50-d8'
    ########### 2. Argument parsing ###########
    parser = argparse.ArgumentParser(description="Run algorithm with specified parameters.")
    # Register the command-line arguments (all of them are required).
    parser.add_argument('--alg_name', type=str, required=True, help="Algorithm name (e.g., 'ann')")
    parser.add_argument('--dataset_file_name', type=str, required=True, help="Dataset file name (e.g., 'my_dataset_model')")
    parser.add_argument('--mean', type=str, required=True, help="Mean of dataset (e.g., [1.2, 1.2, 1.2])")
    # Help text fixed: this option is the standard deviation, not the mean.
    parser.add_argument('--std', type=str, required=True, help="Std of dataset (e.g., [1.2, 1.2, 1.2])")
    parser.add_argument('--dataset_num_classes', type=int, required=True, help="Dataset class num (e.g., 36)")
    parser.add_argument('--GPU_num', type=int, required=True, help="Num of GPU to train (e.g., 1)")
    parser.add_argument('--schedule_file_name', type=str, required=True, help="Schedule file name (e.g., 'schedule_4k_check_400')")
    parser.add_argument('--train_type', type=str, required=True, help="Schedule train type (e.g., Epoch or Iteration)")
    parser.add_argument('--train_time_or_epoch_k', type=int, required=True, help="Schedule train time k (e.g., 4)")
    # Parse the command line.
    args = parser.parse_args()
    # Copy the parsed values into local variables.
    alg_name = args.alg_name
    dataset_file_name = args.dataset_file_name
    mean = json.loads(args.mean)  # JSON text -> Python list
    std = json.loads(args.std)  # JSON text -> Python list
    num_classes = args.dataset_num_classes
    schedule_file_name = args.schedule_file_name
    train_type = args.train_type
    train_time_or_epoch_k = args.train_time_or_epoch_k
    GPU_num = args.GPU_num
    ########### 3. Parameter setup ###########
    # 3.1. User-selected parameters
    crop_size = select_crop_size(predefined_options=[(512, 512), (512, 1024), (769, 769), (1280, 1280)])  # choose the crop size
    model_list = ['openmmlab/resnet18_v1c', 'openmmlab/resnet50_v1c', 'openmmlab/resnet101_v1c', 'torchvision://resnet18', 'torchvision://resnet50', 'torchvision://resnet101', ]
    selected_model_name, select_pretrained, pretrained_pth, selected_model_info = select_pretrained_model(model_list=model_list)
    depth = selected_model_info['depth']
    type_model = selected_model_info['type']
    test_cfg_mode, test_cfg_crop_div_stride = select_test_cfg_slide(crop_size)
    # 3.2. Derived parameters
    # Loss settings for the decode head and the auxiliary head.
    decode_head_loss_decode_dict = dict(type='DiceLoss', use_sigmoid=False, loss_weight=1.0)  # decode head loss  # TODO
    auxiliary_head_loss_decode_dict = dict(type='DiceLoss', use_sigmoid=False, loss_weight=0.4)  # auxiliary head loss  # TODO
    # align_corners is enabled only when the test mode is "slide"; the choice
    # is also tagged in the generated file name.
    if test_cfg_mode == "slide":
        align_corners = True
        test_slide = 'testslide'
    else:
        align_corners = False
        test_slide = 'no_testslide'
    # 3.3.2. Pixel sampler
    selected_sampler_name, use_sampler, selected_sampler_info = select_sampler(sampler_list=['OHEMPixelSampler'])
    if use_sampler == True:
        decode_head_sampler = selected_sampler_info
        # NOTE(review): auxiliary_head_sampler is assigned but never passed to
        # the auxiliary head below — confirm whether that is intended.
        auxiliary_head_sampler = selected_sampler_info
        use_sampler_txt = 'ohempSampler'  # tag used in the file name
        backbone_dilations = (1, 1, 1, 2)
        backbone_strides = (1, 2, 2, 1)
        backbone_multi_grid = (1, 2, 4)
        decode_head_dilations = (1, 6, 12, 18)
    else:
        # The tuples in the trailing comments are the original author's notes;
        # presumably the base-config values that apply when None is passed —
        # TODO confirm against create_dict_by_kwargs.
        decode_head_sampler = None
        auxiliary_head_sampler = None
        use_sampler_txt = ''
        backbone_dilations = None  # (1, 1, 2, 4)
        backbone_strides = None  # (1, 2, 1, 1)
        backbone_multi_grid = None  # none
        decode_head_dilations = None  # (1, 12, 24, 36)
    ########### 3. Build each config section ###########
    _base_ = generate_base_config(alg_file_name=alg_file_name, dataset_file_name=dataset_file_name, schedule_file_name=schedule_file_name)
    norm_cfg = generate_norm_cfg(GPU_num=GPU_num)
    data_preprocessor = generate_data_preprocessor(crop_size=crop_size, mean=mean, std=std)
    pretrained = generate_model_pretrained(pretrained_pth=pretrained_pth)
    backbone = create_dict_by_kwargs(depth=depth, type=type_model, dilations=backbone_dilations, strides=backbone_strides, multi_grid=backbone_multi_grid)  # generate_model_backbone(depth=depth,)
    model_data_preprocessor = data_preprocessor
    model_data_preprocessor = generate_model_data_preprocessor(model_data_preprocessor=model_data_preprocessor)
    # Head channel widths depend on the backbone depth.
    if depth == 18:
        decode_in_channels, decode_channels = 512, 128
        aux_in_channels, aux_channels = 256, 64
    elif depth == 50 or depth == 101:
        decode_in_channels, decode_channels = 2048, 512
        aux_in_channels, aux_channels = 1024, 256
    else:
        # Previously an unexpected depth surfaced later as a NameError on
        # decode_head; fail here with an explicit message instead.
        sys.exit(f"Error: unsupported backbone depth: {depth!r}")
    decode_head = create_dict_by_kwargs(sampler=decode_head_sampler, dilations=decode_head_dilations, in_channels=decode_in_channels, channels=decode_channels, num_classes=num_classes, loss_decode=decode_head_loss_decode_dict, align_corners=align_corners, norm_cfg=norm_cfg,)
    # Fix: the auxiliary head now receives its own loss dict (loss_weight=0.4).
    # It was previously given the decode head's dict (loss_weight=1.0), which
    # left auxiliary_head_loss_decode_dict unused.
    auxiliary_head = create_dict_by_kwargs(in_channels=aux_in_channels, channels=aux_channels, num_classes=num_classes, loss_decode=auxiliary_head_loss_decode_dict, align_corners=align_corners, norm_cfg=norm_cfg,)
    # Assemble the model dict
    model = dict(
        pretrained=pretrained,
        backbone=backbone,
        data_preprocessor=model_data_preprocessor,
        decode_head=decode_head,
        auxiliary_head=auxiliary_head,
    )
    test_cfg = generate_model_test_cfg(test_cfg_mode=test_cfg_mode, crop_size=crop_size, test_cfg_crop_div_stride=test_cfg_crop_div_stride)
    optim_wrapper = generate_optim_wrapper()
    param_scheduler = generate_param_scheduler(train_type, train_time_or_epoch_k)
    ########### 4. Write the outputs ###########
    # Config file name scheme:
    #   <alg>_r<depth>_<Pre|NoPre>_[<sampler tag>_]g<GPU count>-<schedule>_<dataset>-<crop H>x<crop W>-<test mode tag>.py
    # Remove the literal 'schedule_' prefix only. str.lstrip('schedule_') strips
    # any leading run of the characters s/c/h/e/d/u/l/_ and would, for example,
    # turn 'schedule_epoch_100' into 'poch_100'.
    schedule_prefix = 'schedule_'
    if schedule_file_name.startswith(schedule_prefix):
        schedule_tag = schedule_file_name[len(schedule_prefix):]
    else:
        schedule_tag = schedule_file_name
    alg_file_name = f"{alg_name}_r{depth}_{'Pre' if select_pretrained else 'NoPre'}_{use_sampler_txt+'_' if use_sampler_txt != '' else ''}g{GPU_num}-{schedule_tag}_{dataset_file_name}-{crop_size[0]}x{crop_size[1]}-{test_slide}.py"
    output_configs_alg_my_alg = os.path.join('./configs/deeplabv3/', alg_file_name)
    write_config_to_file(output_configs_alg_my_alg, _base_=_base_, norm_cfg=norm_cfg, crop_size=crop_size, data_preprocessor=data_preprocessor, model=model, test_cfg=test_cfg, optim_wrapper=optim_wrapper, param_scheduler=param_scheduler)
    # Record the generated file name and path in a temporary file
    # (presumably read back by the calling script — TODO confirm).
    with open("_temp_.txt", "w") as file:
        json.dump({'alg_infos': {'alg_file_name': alg_file_name, 'alg_file_pth': output_configs_alg_my_alg}}, file)

View File

@@ -0,0 +1,135 @@
import os, sys, argparse, json
# 添加上级目录到 sys.path
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../')))
from Initial_Alg_Program.Initial_Alg_Select_Tool import select_crop_size, select_pretrained_model, select_test_cfg_slide, select_sampler
from Initial_Alg_Program.Initial_Alg_Gen_Tool import format_all_data, write_config_to_file, create_dict_by_kwargs
from Initial_Alg_Program.Initial_Alg_Gen_0_base_ import generate_base_config
from Initial_Alg_Program.Initial_Alg_Gen_1_norm_cfg import generate_norm_cfg
from Initial_Alg_Program.Initial_Alg_Gen_2_data_preprocessor import generate_data_preprocessor
from Initial_Alg_Program.Initial_Alg_Gen_3_model import generate_model_pretrained, generate_model_backbone, generate_model_data_preprocessor, generate_model_decode_head, generate_model_auxiliary_head, generate_model_train_cfg, generate_model_test_cfg
from Initial_Alg_Program.Initial_Alg_Gen_4_optimizer import generate_optim_wrapper, generate_param_scheduler
if __name__ == '__main__':
    # Script entry point: assembles the sections of a DeepLabV3+ training
    # config from command-line arguments and from the values returned by the
    # project's select_* / generate_* helpers, writes the config under
    # ./configs/deeplabv3plus/, and records the generated file name and path in
    # _temp_.txt.
    ########### 1. Hyperparameters: algorithm defaults ###########
    alg_file_name = 'deeplabv3plus_r50-d8'
    ########### 2. Argument parsing ###########
    parser = argparse.ArgumentParser(description="Run algorithm with specified parameters.")
    # Register the command-line arguments (all of them are required).
    parser.add_argument('--alg_name', type=str, required=True, help="Algorithm name (e.g., 'ann')")
    parser.add_argument('--dataset_file_name', type=str, required=True, help="Dataset file name (e.g., 'my_dataset_model')")
    parser.add_argument('--mean', type=str, required=True, help="Mean of dataset (e.g., [1.2, 1.2, 1.2])")
    # Help text fixed: this option is the standard deviation, not the mean.
    parser.add_argument('--std', type=str, required=True, help="Std of dataset (e.g., [1.2, 1.2, 1.2])")
    parser.add_argument('--dataset_num_classes', type=int, required=True, help="Dataset class num (e.g., 36)")
    parser.add_argument('--GPU_num', type=int, required=True, help="Num of GPU to train (e.g., 1)")
    parser.add_argument('--schedule_file_name', type=str, required=True, help="Schedule file name (e.g., 'schedule_4k_check_400')")
    parser.add_argument('--train_type', type=str, required=True, help="Schedule train type (e.g., Epoch or Iteration)")
    parser.add_argument('--train_time_or_epoch_k', type=int, required=True, help="Schedule train time k (e.g., 4)")
    # Parse the command line.
    args = parser.parse_args()
    # Copy the parsed values into local variables.
    alg_name = args.alg_name
    dataset_file_name = args.dataset_file_name
    mean = json.loads(args.mean)  # JSON text -> Python list
    std = json.loads(args.std)  # JSON text -> Python list
    num_classes = args.dataset_num_classes
    schedule_file_name = args.schedule_file_name
    train_type = args.train_type
    train_time_or_epoch_k = args.train_time_or_epoch_k
    GPU_num = args.GPU_num
    ########### 3. Parameter setup ###########
    # 3.1. User-selected parameters
    crop_size = select_crop_size(predefined_options=[(512, 512), (512, 1024), (769, 769), (1280, 1280)])  # choose the crop size
    model_list = ['openmmlab/resnet18_v1c', 'openmmlab/resnet50_v1c', 'openmmlab/resnet101_v1c', 'torchvision://resnet18', 'torchvision://resnet50', 'torchvision://resnet101', ]
    selected_model_name, select_pretrained, pretrained_pth, selected_model_info = select_pretrained_model(model_list=model_list)
    depth = selected_model_info['depth']
    type_model = selected_model_info['type']
    test_cfg_mode, test_cfg_crop_div_stride = select_test_cfg_slide(crop_size)
    # 3.2. Derived parameters
    # Loss settings for the decode head and the auxiliary head.
    decode_head_loss_decode_dict = dict(type='DiceLoss', use_sigmoid=False, loss_weight=1.0)  # decode head loss  # TODO
    auxiliary_head_loss_decode_dict = dict(type='DiceLoss', use_sigmoid=False, loss_weight=0.4)  # auxiliary head loss  # TODO
    # align_corners is enabled only when the test mode is "slide"; the choice
    # is also tagged in the generated file name.
    if test_cfg_mode == "slide":
        align_corners = True
        test_slide = 'testslide'
    else:
        align_corners = False
        test_slide = 'no_testslide'
    # 3.3.2. Pixel sampler
    selected_sampler_name, use_sampler, selected_sampler_info = select_sampler(sampler_list=['OHEMPixelSampler'])
    if use_sampler == True:
        decode_head_sampler = selected_sampler_info
        # NOTE(review): auxiliary_head_sampler is assigned but never passed to
        # the auxiliary head below — confirm whether that is intended.
        auxiliary_head_sampler = selected_sampler_info
        use_sampler_txt = 'ohempSampler'  # tag used in the file name
        backbone_dilations = (1, 1, 1, 2)
        backbone_strides = (1, 2, 2, 1)
        backbone_multi_grid = (1, 2, 4)
        decode_head_dilations = (1, 6, 12, 18)
    else:
        # The tuples in the trailing comments are the original author's notes;
        # presumably the base-config values that apply when None is passed —
        # TODO confirm against create_dict_by_kwargs.
        decode_head_sampler = None
        auxiliary_head_sampler = None
        use_sampler_txt = ''
        backbone_dilations = None  # (1, 1, 2, 4)
        backbone_strides = None  # (1, 2, 1, 1)
        backbone_multi_grid = None  # none
        decode_head_dilations = None  # (1, 12, 24, 36)
    ########### 3. Build each config section ###########
    _base_ = generate_base_config(alg_file_name=alg_file_name, dataset_file_name=dataset_file_name, schedule_file_name=schedule_file_name)
    norm_cfg = generate_norm_cfg(GPU_num=GPU_num)
    data_preprocessor = generate_data_preprocessor(crop_size=crop_size, mean=mean, std=std)
    pretrained = generate_model_pretrained(pretrained_pth=pretrained_pth)
    backbone = create_dict_by_kwargs(depth=depth, type=type_model, dilations=backbone_dilations, strides=backbone_strides, multi_grid=backbone_multi_grid)  # generate_model_backbone(depth=depth,)
    model_data_preprocessor = data_preprocessor
    model_data_preprocessor = generate_model_data_preprocessor(model_data_preprocessor=model_data_preprocessor)
    # Head channel widths (including the c1 low-level branch) depend on the
    # backbone depth.
    if depth == 18:
        c1_in_channels, c1_channels = 64, 12
        decode_in_channels, decode_channels = 512, 128
        aux_in_channels, aux_channels = 256, 64
    elif depth == 50 or depth == 101:
        c1_in_channels, c1_channels = 256, 48
        decode_in_channels, decode_channels = 2048, 512
        aux_in_channels, aux_channels = 1024, 256
    else:
        # Previously an unexpected depth surfaced later as a NameError on
        # decode_head; fail here with an explicit message instead.
        sys.exit(f"Error: unsupported backbone depth: {depth!r}")
    decode_head = create_dict_by_kwargs(sampler=decode_head_sampler, dilations=decode_head_dilations, c1_in_channels=c1_in_channels, c1_channels=c1_channels, in_channels=decode_in_channels, channels=decode_channels, num_classes=num_classes, loss_decode=decode_head_loss_decode_dict, align_corners=align_corners, norm_cfg=norm_cfg,)
    # Fix: the auxiliary head now receives its own loss dict (loss_weight=0.4).
    # It was previously given the decode head's dict (loss_weight=1.0), which
    # left auxiliary_head_loss_decode_dict unused.
    auxiliary_head = create_dict_by_kwargs(in_channels=aux_in_channels, channels=aux_channels, num_classes=num_classes, loss_decode=auxiliary_head_loss_decode_dict, align_corners=align_corners, norm_cfg=norm_cfg,)
    # Assemble the model dict
    model = dict(
        pretrained=pretrained,
        backbone=backbone,
        data_preprocessor=model_data_preprocessor,
        decode_head=decode_head,
        auxiliary_head=auxiliary_head,
    )
    test_cfg = generate_model_test_cfg(test_cfg_mode=test_cfg_mode, crop_size=crop_size, test_cfg_crop_div_stride=test_cfg_crop_div_stride)
    optim_wrapper = generate_optim_wrapper()
    param_scheduler = generate_param_scheduler(train_type, train_time_or_epoch_k)
    ########### 4. Write the outputs ###########
    # Config file name scheme:
    #   <alg>_r<depth>_<Pre|NoPre>_[<sampler tag>_]g<GPU count>-<schedule>_<dataset>-<crop H>x<crop W>-<test mode tag>.py
    # Remove the literal 'schedule_' prefix only. str.lstrip('schedule_') strips
    # any leading run of the characters s/c/h/e/d/u/l/_ and would, for example,
    # turn 'schedule_epoch_100' into 'poch_100'.
    schedule_prefix = 'schedule_'
    if schedule_file_name.startswith(schedule_prefix):
        schedule_tag = schedule_file_name[len(schedule_prefix):]
    else:
        schedule_tag = schedule_file_name
    alg_file_name = f"{alg_name}_r{depth}_{'Pre' if select_pretrained else 'NoPre'}_{use_sampler_txt+'_' if use_sampler_txt != '' else ''}g{GPU_num}-{schedule_tag}_{dataset_file_name}-{crop_size[0]}x{crop_size[1]}-{test_slide}.py"
    output_configs_alg_my_alg = os.path.join('./configs/deeplabv3plus/', alg_file_name)
    write_config_to_file(output_configs_alg_my_alg, _base_=_base_, norm_cfg=norm_cfg, crop_size=crop_size, data_preprocessor=data_preprocessor, model=model, test_cfg=test_cfg, optim_wrapper=optim_wrapper, param_scheduler=param_scheduler)
    # Record the generated file name and path in a temporary file
    # (presumably read back by the calling script — TODO confirm).
    with open("_temp_.txt", "w") as file:
        json.dump({'alg_infos': {'alg_file_name': alg_file_name, 'alg_file_pth': output_configs_alg_my_alg}}, file)

View File

@@ -0,0 +1,102 @@
import os, sys, argparse, json
# 添加上级目录到 sys.path
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../')))
from Initial_Alg_Program.Initial_Alg_Select_Tool import select_crop_size, select_pretrained_model, select_test_cfg_slide
from Initial_Alg_Program.Initial_Alg_Gen_Tool import format_all_data, write_config_to_file
from Initial_Alg_Program.Initial_Alg_Gen_0_base_ import generate_base_config
from Initial_Alg_Program.Initial_Alg_Gen_1_norm_cfg import generate_norm_cfg
from Initial_Alg_Program.Initial_Alg_Gen_2_data_preprocessor import generate_data_preprocessor
from Initial_Alg_Program.Initial_Alg_Gen_3_model import generate_model_pretrained, generate_model_backbone, generate_model_data_preprocessor, generate_model_decode_head, generate_model_auxiliary_head, generate_model_train_cfg, generate_model_test_cfg
from Initial_Alg_Program.Initial_Alg_Gen_4_optimizer import generate_optim_wrapper, generate_param_scheduler
if __name__ == '__main__':
    # Script entry point: assembles the sections of a DNL (dnl_r50-d8 base)
    # training config from command-line arguments and from the values returned
    # by the project's select_* / generate_* helpers, writes the config under
    # ./configs/dnlnet/, and records the generated file name and path in
    # _temp_.txt.
    ########### 1. Hyperparameters: algorithm defaults ###########
    alg_file_name = 'dnl_r50-d8'
    ########### 2. Argument parsing ###########
    parser = argparse.ArgumentParser(description="Run algorithm with specified parameters.")
    # Register the command-line arguments (all of them are required).
    parser.add_argument('--alg_name', type=str, required=True, help="Algorithm name (e.g., 'ann')")
    parser.add_argument('--dataset_file_name', type=str, required=True, help="Dataset file name (e.g., 'my_dataset_model')")
    parser.add_argument('--mean', type=str, required=True, help="Mean of dataset (e.g., [1.2, 1.2, 1.2])")
    # Help text fixed: this option is the standard deviation, not the mean.
    parser.add_argument('--std', type=str, required=True, help="Std of dataset (e.g., [1.2, 1.2, 1.2])")
    parser.add_argument('--dataset_num_classes', type=int, required=True, help="Dataset class num (e.g., 36)")
    parser.add_argument('--GPU_num', type=int, required=True, help="Num of GPU to train (e.g., 1)")
    parser.add_argument('--schedule_file_name', type=str, required=True, help="Schedule file name (e.g., 'schedule_4k_check_400')")
    parser.add_argument('--train_type', type=str, required=True, help="Schedule train type (e.g., Epoch or Iteration)")
    parser.add_argument('--train_time_or_epoch_k', type=int, required=True, help="Schedule train time k (e.g., 4)")
    # Parse the command line.
    args = parser.parse_args()
    # Copy the parsed values into local variables.
    alg_name = args.alg_name
    dataset_file_name = args.dataset_file_name
    mean = json.loads(args.mean)  # JSON text -> Python list
    std = json.loads(args.std)  # JSON text -> Python list
    num_classes = args.dataset_num_classes
    schedule_file_name = args.schedule_file_name
    train_type = args.train_type
    train_time_or_epoch_k = args.train_time_or_epoch_k
    GPU_num = args.GPU_num
    ########### 3. Parameter setup ###########
    # 3.1. User-selected parameters
    crop_size = select_crop_size(predefined_options=[(512, 512), (512, 1024), (769, 769)])  # choose the crop size
    model_list = ['openmmlab/resnet50_v1c', 'openmmlab/resnet101_v1c']
    selected_model_name, select_pretrained, pretrained_pth, selected_model_info = select_pretrained_model(model_list=model_list)
    depth = selected_model_info['depth']
    test_cfg_mode, test_cfg_crop_div_stride = select_test_cfg_slide(crop_size)
    # 3.2. Derived parameters
    # Loss settings for the decode head and the auxiliary head.
    decode_head_loss_decode_dict = dict(type='DiceLoss', use_sigmoid=False, loss_weight=1.0)  # decode head loss  # TODO
    auxiliary_head_loss_decode_dict = dict(type='DiceLoss', use_sigmoid=False, loss_weight=0.4)  # auxiliary head loss  # TODO
    # align_corners is enabled only when the test mode is "slide"; the choice
    # is also tagged in the generated file name.
    if test_cfg_mode == "slide":
        align_corners = True
        test_slide = 'testslide'
    else:
        align_corners = False
        test_slide = 'no_testslide'
    ########### 3. Build each config section ###########
    _base_ = generate_base_config(alg_file_name=alg_file_name, dataset_file_name=dataset_file_name, schedule_file_name=schedule_file_name)
    norm_cfg = generate_norm_cfg(GPU_num=GPU_num)
    data_preprocessor = generate_data_preprocessor(crop_size=crop_size, mean=mean, std=std)
    pretrained = generate_model_pretrained(pretrained_pth=pretrained_pth)
    backbone = generate_model_backbone(depth=depth)
    model_data_preprocessor = data_preprocessor
    model_data_preprocessor = generate_model_data_preprocessor(model_data_preprocessor=model_data_preprocessor)
    decode_head = generate_model_decode_head(num_classes=num_classes, decode_head_loss_decode_dict=decode_head_loss_decode_dict, align_corners=align_corners)
    auxiliary_head = generate_model_auxiliary_head(num_classes=num_classes, auxiliary_head_loss_decode_dict=auxiliary_head_loss_decode_dict, align_corners=align_corners)
    # Assemble the model dict
    model = dict(
        pretrained=pretrained,
        backbone=backbone,
        data_preprocessor=model_data_preprocessor,
        decode_head=decode_head,
        auxiliary_head=auxiliary_head,
    )
    test_cfg = generate_model_test_cfg(test_cfg_mode=test_cfg_mode, crop_size=crop_size, test_cfg_crop_div_stride=test_cfg_crop_div_stride)
    optim_wrapper = generate_optim_wrapper()
    param_scheduler = generate_param_scheduler(train_type, train_time_or_epoch_k)
    ########### 4. Write the outputs ###########
    # Config file name scheme:
    #   <alg>_r<depth>_<Pre|NoPre>_g<GPU count>-<schedule>_<dataset>-<crop H>x<crop W>-<test mode tag>.py
    # Remove the literal 'schedule_' prefix only. str.lstrip('schedule_') strips
    # any leading run of the characters s/c/h/e/d/u/l/_ and would, for example,
    # turn 'schedule_epoch_100' into 'poch_100'.
    schedule_prefix = 'schedule_'
    if schedule_file_name.startswith(schedule_prefix):
        schedule_tag = schedule_file_name[len(schedule_prefix):]
    else:
        schedule_tag = schedule_file_name
    alg_file_name = f"{alg_name}_r{depth}_{'Pre' if select_pretrained else 'NoPre'}_g{GPU_num}-{schedule_tag}_{dataset_file_name}-{crop_size[0]}x{crop_size[1]}-{test_slide}.py"
    output_configs_alg_my_alg = os.path.join('./configs/dnlnet/', alg_file_name)
    write_config_to_file(output_configs_alg_my_alg, _base_=_base_, norm_cfg=norm_cfg, crop_size=crop_size, data_preprocessor=data_preprocessor, model=model, test_cfg=test_cfg, optim_wrapper=optim_wrapper, param_scheduler=param_scheduler)
    # Record the generated file name and path in a temporary file
    # (presumably read back by the calling script — TODO confirm).
    with open("_temp_.txt", "w") as file:
        json.dump({'alg_infos': {'alg_file_name': alg_file_name, 'alg_file_pth': output_configs_alg_my_alg}}, file)

View File

@@ -0,0 +1,100 @@
import os, sys, argparse, json
# 添加上级目录到 sys.path
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../')))
from Initial_Alg_Program.Initial_Alg_Select_Tool import select_crop_size, select_pretrained_model, select_sampler
from Initial_Alg_Program.Initial_Alg_Gen_Tool import format_all_data, write_config_to_file, create_dict_by_kwargs, get_var_from_file, update_list_dict_var, get_var_from_py_file
from Initial_Alg_Program.Initial_Alg_Gen_0_base_ import generate_base_config
from Initial_Alg_Program.Initial_Alg_Gen_1_norm_cfg import generate_norm_cfg
from Initial_Alg_Program.Initial_Alg_Gen_2_data_preprocessor import generate_data_preprocessor
from Initial_Alg_Program.Initial_Alg_Gen_3_model import generate_model_pretrained, generate_model_backbone, generate_model_data_preprocessor, generate_model_decode_head, generate_model_auxiliary_head, generate_model_train_cfg, generate_model_test_cfg
from Initial_Alg_Program.Initial_Alg_Gen_4_optimizer import generate_optim_wrapper, generate_param_scheduler
from Initial_Alg_Program.Initial_Alg_Gen_5_train_dataloader import generate_train_dataloader
if __name__ == '__main__':
    # Script entry point: assembles the sections of a DPT (dpt_vit-b16 base)
    # training config from command-line arguments and from the values returned
    # by the project's select_* / generate_* helpers, writes the config under
    # ./configs/dpt/, and records the generated file name and path in
    # _temp_.txt.
    ########### 1. Hyperparameters: algorithm defaults ###########
    alg_file_name = 'dpt_vit-b16'
    alg_file_pth = f"configs/_base_/models/{alg_file_name}.py"
    ########### 2. Argument parsing ###########
    parser = argparse.ArgumentParser(description="Run algorithm with specified parameters.")
    # Register the command-line arguments (all of them are required).
    parser.add_argument('--alg_name', type=str, required=True, help="Algorithm name (e.g., 'ann')")
    parser.add_argument('--dataset_file_name', type=str, required=True, help="Dataset file name (e.g., 'my_dataset_model')")
    parser.add_argument('--mean', type=str, required=True, help="Mean of dataset (e.g., [1.2, 1.2, 1.2])")
    # Help text fixed: this option is the standard deviation, not the mean.
    parser.add_argument('--std', type=str, required=True, help="Std of dataset (e.g., [1.2, 1.2, 1.2])")
    parser.add_argument('--dataset_num_classes', type=int, required=True, help="Dataset class num (e.g., 36)")
    parser.add_argument('--GPU_num', type=int, required=True, help="Num of GPU to train (e.g., 1)")
    parser.add_argument('--schedule_file_name', type=str, required=True, help="Schedule file name (e.g., 'schedule_4k_check_400')")
    parser.add_argument('--train_type', type=str, required=True, help="Schedule train type (e.g., Epoch or Iteration)")
    parser.add_argument('--train_time_or_epoch_k', type=int, required=True, help="Schedule train time k (e.g., 4)")
    # Parse the command line.
    args = parser.parse_args()
    # Copy the parsed values into local variables.
    alg_name = args.alg_name
    dataset_file_name = args.dataset_file_name
    mean = json.loads(args.mean)  # JSON text -> Python list
    std = json.loads(args.std)  # JSON text -> Python list
    num_classes = args.dataset_num_classes
    schedule_file_name = args.schedule_file_name
    train_type = args.train_type
    train_time_or_epoch_k = args.train_time_or_epoch_k
    GPU_num = args.GPU_num
    ########### 3.1. Build the _base_ section ###########
    _base_ = generate_base_config(alg_file_name=alg_file_name, dataset_file_name=dataset_file_name, schedule_file_name=schedule_file_name)
    ########### 3.1. Build the norm_cfg section ###########
    norm_cfg = generate_norm_cfg(GPU_num=GPU_num)
    ########### 3.2. Build the data_preprocessor section ###########
    crop_size = select_crop_size(predefined_options=[(1024, 1024)])  # choose the crop size
    data_preprocessor = generate_data_preprocessor(crop_size=crop_size, mean=mean, std=std)
    ########### 3.3. Build the model section ###########
    # 3.3.1. data_preprocessor used inside the model dict, plus pretrained weights
    # (the original assigned model_data_preprocessor twice with the same value;
    # the redundant second assignment is dropped).
    model_data_preprocessor = data_preprocessor
    model_list = ['pretrain/vit-b16_p16_224-80ecf9dd.pth']
    selected_model_name, select_pretrained, pretrained_pth, selected_model_info = select_pretrained_model(model_list=model_list)
    pretrained = generate_model_pretrained(pretrained_pth=pretrained_pth)
    # 3.3.2. decode_head
    # Loss setting for the decode head.  TODO
    decode_head_loss_decode_dict = dict(type='DiceLoss', use_sigmoid=False, loss_weight=1.0)  # decode head loss  # TODO
    # decode_head_loss_decode_dict=dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0) # default TODO
    decode_head = create_dict_by_kwargs(loss_decode=decode_head_loss_decode_dict, num_classes=num_classes)
    # 3.3.3. auxiliary_head: intentionally empty and not added to the model dict.
    auxiliary_head = None
    # 3.3.4. Assemble the model dict
    model = dict(
        data_preprocessor=model_data_preprocessor,
        pretrained=pretrained,
        decode_head=decode_head,
        # auxiliary_head = auxiliary_head,
    )
    ########### 3.4. Build the optim_wrapper and param_scheduler sections ###########
    optim_wrapper = generate_optim_wrapper(type_of_back_bone="Vit")
    param_scheduler = generate_param_scheduler(train_type, train_time_or_epoch_k)
    ########### 3.5. Build the train_dataloader section ###########
    # Original author's note: a batch size of 1 already needs about 20G.
    train_dataloader, batch_size = generate_train_dataloader(batch_size_default=1)
    ########### 4. Write the outputs ###########
    # 4.1. Config file name scheme:
    #   <alg>_g<GPU count>_b<batch size>-<schedule>_<dataset>-<crop H>x<crop W>.py
    # Remove the literal 'schedule_' prefix only. str.lstrip('schedule_') strips
    # any leading run of the characters s/c/h/e/d/u/l/_ and would, for example,
    # turn 'schedule_epoch_100' into 'poch_100'.
    schedule_prefix = 'schedule_'
    if schedule_file_name.startswith(schedule_prefix):
        schedule_tag = schedule_file_name[len(schedule_prefix):]
    else:
        schedule_tag = schedule_file_name
    alg_file_name = f"{alg_name}_g{GPU_num}_b{batch_size}-{schedule_tag}_{dataset_file_name}-{crop_size[0]}x{crop_size[1]}.py"
    output_configs_alg_my_alg = os.path.join('./configs/dpt/', alg_file_name)
    write_config_to_file(output_configs_alg_my_alg, _base_=_base_, norm_cfg=norm_cfg, crop_size=crop_size, data_preprocessor=data_preprocessor, model=model, optim_wrapper=optim_wrapper, param_scheduler=param_scheduler, train_dataloader=train_dataloader)
    # 4.2. Record the generated file name and path in a temporary file
    # (presumably read back by the calling script — TODO confirm).
    with open("_temp_.txt", "w") as file:
        json.dump({'alg_infos': {'alg_file_name': alg_file_name, 'alg_file_pth': output_configs_alg_my_alg}}, file)

View File

@@ -0,0 +1,106 @@
import os, sys, argparse, json
# 添加上级目录到 sys.path
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../')))
from Initial_Alg_Program.Initial_Alg_Select_Tool import select_crop_size, select_pretrained_model, select_test_cfg_slide
from Initial_Alg_Program.Initial_Alg_Gen_Tool import format_all_data, write_config_to_file
from Initial_Alg_Program.Initial_Alg_Gen_0_base_ import generate_base_config
from Initial_Alg_Program.Initial_Alg_Gen_1_norm_cfg import generate_norm_cfg
from Initial_Alg_Program.Initial_Alg_Gen_2_data_preprocessor import generate_data_preprocessor
from Initial_Alg_Program.Initial_Alg_Gen_3_model import generate_model_pretrained, generate_model_backbone, generate_model_data_preprocessor, generate_model_decode_head, generate_model_auxiliary_head, generate_model_train_cfg, generate_model_test_cfg
from Initial_Alg_Program.Initial_Alg_Gen_4_optimizer import generate_optim_wrapper, generate_param_scheduler
if __name__ == '__main__':
    # Builds an EMANet config under ./configs/emanet/ from command-line arguments plus a few
    # interactive choices, then records the generated file's name and path in _temp_.txt.
    ########### 1. Hyper-parameters: algorithm defaults ###########
    alg_file_name = 'emanet_r50-d8'  # base model name handed to generate_base_config
    ########### 2. Argument parsing ###########
    parser = argparse.ArgumentParser(description="Run algorithm with specified parameters.")
    # Command-line arguments (all required)
    parser.add_argument('--alg_name', type=str, required=True, help="Algorithm name (e.g., 'ann')")
    parser.add_argument('--dataset_file_name', type=str, required=True, help="Dataset file name (e.g., 'my_dataset_model')")
    parser.add_argument('--mean', type=str, required=True, help="Mean of dataset (e.g., [1.2, 1.2, 1.2])")
    parser.add_argument('--std', type=str, required=True, help="Std of dataset (e.g., [1.2, 1.2, 1.2])")
    parser.add_argument('--dataset_num_classes', type=int, required=True, help="Dataset class num (e.g., 36)")
    parser.add_argument('--GPU_num', type=int, required=True, help="Num of GPU to train (e.g., 1)")
    parser.add_argument('--schedule_file_name', type=str, required=True, help="Schedule file name (e.g., 'schedule_4k_check_400')")
    parser.add_argument('--train_type', type=str, required=True, help="Schedule train type (e.g., Epoch or Iteration)")
    parser.add_argument('--train_time_or_epoch_k', type=int, required=True, help="Schedule train time k (e.g., 4)")
    args = parser.parse_args()
    # Unpack the parsed values
    alg_name = args.alg_name
    dataset_file_name = args.dataset_file_name
    mean = json.loads(args.mean)  # JSON text -> list
    std = json.loads(args.std)  # JSON text -> list
    num_classes = args.dataset_num_classes
    schedule_file_name = args.schedule_file_name
    train_type = args.train_type
    train_time_or_epoch_k = args.train_time_or_epoch_k
    GPU_num = args.GPU_num
    ########### 3. Parameter setup ###########
    # 3.1. Interactively chosen parameters
    crop_size = select_crop_size(predefined_options=[(512, 512), (512, 1024), (769, 769)])  # crop size
    model_list = ['openmmlab/resnet50_v1c', 'openmmlab/resnet101_v1c']
    selected_model_name, select_pretrained, pretrained_pth, selected_model_info = select_pretrained_model(model_list=model_list)
    depth = selected_model_info['depth']
    test_cfg_mode, test_cfg_crop_div_stride = select_test_cfg_slide(crop_size)
    # test_cfg_mode = None
    # test_cfg_crop_div_stride = crop_size
    # 3.2. Derived parameters
    # Loss settings for the decode and auxiliary heads
    decode_head_loss_decode_dict = dict(type='DiceLoss', use_sigmoid=False, loss_weight=1.0)  # decode head loss  # TODO
    # decode_head_loss_decode_dict = dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)  # alternative decode head loss  # TODO
    auxiliary_head_loss_decode_dict = dict(type='DiceLoss', use_sigmoid=False, loss_weight=0.4)  # auxiliary head loss  # TODO
    # auxiliary_head_loss_decode_dict = dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)  # alternative auxiliary head loss  # TODO
    # align_corners is switched on only when slide-mode testing was selected
    if test_cfg_mode == "slide":
        align_corners = True
        test_slide = 'testslide'  # marker embedded in the config file name
    else:
        align_corners = False
        test_slide = 'no_testslide'  # marker embedded in the config file name
    ########### 3. Build each config section ###########
    _base_ = generate_base_config(alg_file_name=alg_file_name, dataset_file_name=dataset_file_name, schedule_file_name=schedule_file_name)
    norm_cfg = generate_norm_cfg(GPU_num=GPU_num)
    data_preprocessor = generate_data_preprocessor(crop_size=crop_size, mean=mean, std=std)
    pretrained = generate_model_pretrained(pretrained_pth=pretrained_pth)
    backbone = generate_model_backbone(depth=depth)
    # The model-level data_preprocessor is derived from the top-level one
    model_data_preprocessor = generate_model_data_preprocessor(model_data_preprocessor=data_preprocessor)
    decode_head = generate_model_decode_head(num_classes=num_classes, decode_head_loss_decode_dict=decode_head_loss_decode_dict, align_corners=align_corners)
    auxiliary_head = generate_model_auxiliary_head(num_classes=num_classes, auxiliary_head_loss_decode_dict=auxiliary_head_loss_decode_dict, align_corners=align_corners)
    # Assemble the model section
    model = dict(
        pretrained=pretrained,
        backbone=backbone,
        data_preprocessor=model_data_preprocessor,
        decode_head=decode_head,
        auxiliary_head=auxiliary_head,
    )
    test_cfg = generate_model_test_cfg(test_cfg_mode=test_cfg_mode, crop_size=crop_size, test_cfg_crop_div_stride=test_cfg_crop_div_stride)
    optim_wrapper = generate_optim_wrapper()
    param_scheduler = generate_param_scheduler(train_type, train_time_or_epoch_k)
    ########### 4. Write the config ###########
    # Reference naming scheme:
    #   emanet[alg]_r50[backbone depth]-d8[_base_/models]_4[GPU count]xb2[batch size]-40k[schedule]_cityscapes[dataset]-512x1024[crop_size].py
    # The name built below is a variation of it.
    # Drop only the literal 'schedule_' prefix. str.lstrip('schedule_') removes any leading run of
    # the characters s/c/h/e/d/u/l/_ and would mangle names such as 'schedule_epoch_4' into 'poch_4'.
    schedule_prefix = 'schedule_'
    schedule_tag = schedule_file_name[len(schedule_prefix):] if schedule_file_name.startswith(schedule_prefix) else schedule_file_name
    alg_file_name = f"{alg_name}_r{depth}_{'Pre' if select_pretrained else 'NoPre'}_g{GPU_num}-{schedule_tag}_{dataset_file_name}-{crop_size[0]}x{crop_size[1]}-{test_slide}.py"
    output_configs_alg_my_alg = os.path.join('./configs/emanet/', alg_file_name)
    write_config_to_file(output_configs_alg_my_alg, _base_=_base_, norm_cfg=norm_cfg, crop_size=crop_size, data_preprocessor=data_preprocessor, model=model, test_cfg=test_cfg, optim_wrapper=optim_wrapper, param_scheduler=param_scheduler)
    # Record the generated file's name and path in _temp_.txt
    with open("_temp_.txt", "w") as file:
        json.dump({'alg_infos': {'alg_file_name': alg_file_name, 'alg_file_pth': output_configs_alg_my_alg}}, file)

View File

@@ -0,0 +1,106 @@
import os, sys, argparse, json
# 添加上级目录到 sys.path
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../')))
from Initial_Alg_Program.Initial_Alg_Select_Tool import select_crop_size, select_pretrained_model, select_test_cfg_slide
from Initial_Alg_Program.Initial_Alg_Gen_Tool import format_all_data, write_config_to_file
from Initial_Alg_Program.Initial_Alg_Gen_0_base_ import generate_base_config
from Initial_Alg_Program.Initial_Alg_Gen_1_norm_cfg import generate_norm_cfg
from Initial_Alg_Program.Initial_Alg_Gen_2_data_preprocessor import generate_data_preprocessor
from Initial_Alg_Program.Initial_Alg_Gen_3_model import generate_model_pretrained, generate_model_backbone, generate_model_data_preprocessor, generate_model_decode_head, generate_model_auxiliary_head, generate_model_train_cfg, generate_model_test_cfg
from Initial_Alg_Program.Initial_Alg_Gen_4_optimizer import generate_optim_wrapper, generate_param_scheduler
if __name__ == '__main__':
    # Builds an EncNet config under ./configs/encnet/ from command-line arguments plus a few
    # interactive choices, then records the generated file's name and path in _temp_.txt.
    ########### 1. Hyper-parameters: algorithm defaults ###########
    alg_file_name = 'encnet_r50-d8'  # base model name handed to generate_base_config
    ########### 2. Argument parsing ###########
    parser = argparse.ArgumentParser(description="Run algorithm with specified parameters.")
    # Command-line arguments (all required)
    parser.add_argument('--alg_name', type=str, required=True, help="Algorithm name (e.g., 'ann')")
    parser.add_argument('--dataset_file_name', type=str, required=True, help="Dataset file name (e.g., 'my_dataset_model')")
    parser.add_argument('--mean', type=str, required=True, help="Mean of dataset (e.g., [1.2, 1.2, 1.2])")
    parser.add_argument('--std', type=str, required=True, help="Std of dataset (e.g., [1.2, 1.2, 1.2])")
    parser.add_argument('--dataset_num_classes', type=int, required=True, help="Dataset class num (e.g., 36)")
    parser.add_argument('--GPU_num', type=int, required=True, help="Num of GPU to train (e.g., 1)")
    parser.add_argument('--schedule_file_name', type=str, required=True, help="Schedule file name (e.g., 'schedule_4k_check_400')")
    parser.add_argument('--train_type', type=str, required=True, help="Schedule train type (e.g., Epoch or Iteration)")
    parser.add_argument('--train_time_or_epoch_k', type=int, required=True, help="Schedule train time k (e.g., 4)")
    args = parser.parse_args()
    # Unpack the parsed values
    alg_name = args.alg_name
    dataset_file_name = args.dataset_file_name
    mean = json.loads(args.mean)  # JSON text -> list
    std = json.loads(args.std)  # JSON text -> list
    num_classes = args.dataset_num_classes
    schedule_file_name = args.schedule_file_name
    train_type = args.train_type
    train_time_or_epoch_k = args.train_time_or_epoch_k
    GPU_num = args.GPU_num
    ########### 3. Parameter setup ###########
    # 3.1. Interactively chosen parameters
    crop_size = select_crop_size(predefined_options=[(512, 512), (512, 1024), (769, 769)])  # crop size
    model_list = ['openmmlab/resnet50_v1c', 'openmmlab/resnet101_v1c']
    selected_model_name, select_pretrained, pretrained_pth, selected_model_info = select_pretrained_model(model_list=model_list)
    depth = selected_model_info['depth']
    test_cfg_mode, test_cfg_crop_div_stride = select_test_cfg_slide(crop_size)
    # test_cfg_mode = None
    # test_cfg_crop_div_stride = crop_size
    # 3.2. Derived parameters
    # Loss settings for the decode and auxiliary heads
    decode_head_loss_decode_dict = dict(type='DiceLoss', use_sigmoid=False, loss_weight=1.0)  # decode head loss  # TODO
    # decode_head_loss_decode_dict = dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)  # alternative decode head loss  # TODO
    auxiliary_head_loss_decode_dict = dict(type='DiceLoss', use_sigmoid=False, loss_weight=0.4)  # auxiliary head loss  # TODO
    # auxiliary_head_loss_decode_dict = dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)  # alternative auxiliary head loss  # TODO
    # align_corners is switched on only when slide-mode testing was selected
    if test_cfg_mode == "slide":
        align_corners = True
        test_slide = 'testslide'  # marker embedded in the config file name
    else:
        align_corners = False
        test_slide = 'no_testslide'  # marker embedded in the config file name
    ########### 3. Build each config section ###########
    _base_ = generate_base_config(alg_file_name=alg_file_name, dataset_file_name=dataset_file_name, schedule_file_name=schedule_file_name)
    norm_cfg = generate_norm_cfg(GPU_num=GPU_num)
    data_preprocessor = generate_data_preprocessor(crop_size=crop_size, mean=mean, std=std)
    pretrained = generate_model_pretrained(pretrained_pth=pretrained_pth)
    backbone = generate_model_backbone(depth=depth)
    # The model-level data_preprocessor is derived from the top-level one
    model_data_preprocessor = generate_model_data_preprocessor(model_data_preprocessor=data_preprocessor)
    decode_head = generate_model_decode_head(num_classes=num_classes, decode_head_loss_decode_dict=decode_head_loss_decode_dict, align_corners=align_corners)
    auxiliary_head = generate_model_auxiliary_head(num_classes=num_classes, auxiliary_head_loss_decode_dict=auxiliary_head_loss_decode_dict, align_corners=align_corners)
    # Assemble the model section
    model = dict(
        pretrained=pretrained,
        backbone=backbone,
        data_preprocessor=model_data_preprocessor,
        decode_head=decode_head,
        auxiliary_head=auxiliary_head,
    )
    test_cfg = generate_model_test_cfg(test_cfg_mode=test_cfg_mode, crop_size=crop_size, test_cfg_crop_div_stride=test_cfg_crop_div_stride)
    optim_wrapper = generate_optim_wrapper()
    param_scheduler = generate_param_scheduler(train_type, train_time_or_epoch_k)
    ########### 4. Write the config ###########
    # Reference naming scheme:
    #   encnet[alg]_r50[backbone depth]-d8[_base_/models]_4[GPU count]xb2[batch size]-40k[schedule]_cityscapes[dataset]-512x1024[crop_size].py
    # The name built below is a variation of it.
    # Drop only the literal 'schedule_' prefix. str.lstrip('schedule_') removes any leading run of
    # the characters s/c/h/e/d/u/l/_ and would mangle names such as 'schedule_epoch_4' into 'poch_4'.
    schedule_prefix = 'schedule_'
    schedule_tag = schedule_file_name[len(schedule_prefix):] if schedule_file_name.startswith(schedule_prefix) else schedule_file_name
    alg_file_name = f"{alg_name}_r{depth}_{'Pre' if select_pretrained else 'NoPre'}_g{GPU_num}-{schedule_tag}_{dataset_file_name}-{crop_size[0]}x{crop_size[1]}-{test_slide}.py"
    output_configs_alg_my_alg = os.path.join('./configs/encnet/', alg_file_name)
    write_config_to_file(output_configs_alg_my_alg, _base_=_base_, norm_cfg=norm_cfg, crop_size=crop_size, data_preprocessor=data_preprocessor, model=model, test_cfg=test_cfg, optim_wrapper=optim_wrapper, param_scheduler=param_scheduler)
    # Record the generated file's name and path in _temp_.txt
    with open("_temp_.txt", "w") as file:
        json.dump({'alg_infos': {'alg_file_name': alg_file_name, 'alg_file_pth': output_configs_alg_my_alg}}, file)

View File

@@ -0,0 +1,83 @@
import os, sys, argparse, json
# 添加上级目录到 sys.path
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../')))
from Initial_Alg_Program.Initial_Alg_Select_Tool import select_crop_size, select_pretrained_model, select_test_cfg_slide
from Initial_Alg_Program.Initial_Alg_Gen_Tool import format_all_data, write_config_to_file
from Initial_Alg_Program.Initial_Alg_Gen_0_base_ import generate_base_config
from Initial_Alg_Program.Initial_Alg_Gen_1_norm_cfg import generate_norm_cfg
from Initial_Alg_Program.Initial_Alg_Gen_2_data_preprocessor import generate_data_preprocessor
from Initial_Alg_Program.Initial_Alg_Gen_3_model import generate_model_pretrained, generate_model_backbone, generate_model_data_preprocessor, generate_model_decode_head, generate_model_auxiliary_head, generate_model_train_cfg, generate_model_test_cfg
from Initial_Alg_Program.Initial_Alg_Gen_4_optimizer import generate_optim_wrapper, generate_param_scheduler
if __name__ == '__main__':
    # Builds an ERFNet config under ./configs/erfnet/ from command-line arguments plus an
    # interactive crop-size choice, then records the generated file's name and path in _temp_.txt.
    ########### 1. Hyper-parameters: algorithm defaults ###########
    alg_file_name = 'erfnet_fcn'  # base model name handed to generate_base_config
    ########### 2. Argument parsing ###########
    parser = argparse.ArgumentParser(description="Run algorithm with specified parameters.")
    # Command-line arguments (all required)
    parser.add_argument('--alg_name', type=str, required=True, help="Algorithm name (e.g., 'ann')")
    parser.add_argument('--dataset_file_name', type=str, required=True, help="Dataset file name (e.g., 'my_dataset_model')")
    parser.add_argument('--mean', type=str, required=True, help="Mean of dataset (e.g., [1.2, 1.2, 1.2])")
    parser.add_argument('--std', type=str, required=True, help="Std of dataset (e.g., [1.2, 1.2, 1.2])")
    parser.add_argument('--dataset_num_classes', type=int, required=True, help="Dataset class num (e.g., 36)")
    parser.add_argument('--GPU_num', type=int, required=True, help="Num of GPU to train (e.g., 1)")
    parser.add_argument('--schedule_file_name', type=str, required=True, help="Schedule file name (e.g., 'schedule_4k_check_400')")
    parser.add_argument('--train_type', type=str, required=True, help="Schedule train type (e.g., Epoch or Iteration)")
    parser.add_argument('--train_time_or_epoch_k', type=int, required=True, help="Schedule train time k (e.g., 4)")
    args = parser.parse_args()
    # Unpack the parsed values
    alg_name = args.alg_name
    dataset_file_name = args.dataset_file_name
    mean = json.loads(args.mean)  # JSON text -> list
    std = json.loads(args.std)  # JSON text -> list
    num_classes = args.dataset_num_classes
    schedule_file_name = args.schedule_file_name
    train_type = args.train_type
    train_time_or_epoch_k = args.train_time_or_epoch_k
    GPU_num = args.GPU_num
    ########### 3. Parameter setup ###########
    # 3.1. Interactively chosen parameters
    crop_size = select_crop_size(predefined_options=[(512, 1024), (512, 512)])  # crop size
    # 3.2. Derived parameters
    # Loss setting for the decode head
    # decode_head_loss_decode_dict = dict(type='CrossEntropyLoss',
    #                                     use_sigmoid=False,
    #                                     loss_weight=1.0)
    decode_head_loss_decode_dict = dict(type='DiceLoss', use_sigmoid=False, loss_weight=1.0)  # decode head loss  # TODO
    decode_head = dict(loss_decode=decode_head_loss_decode_dict)
    ########### 3. Build each config section ###########
    _base_ = generate_base_config(alg_file_name=alg_file_name, dataset_file_name=dataset_file_name, schedule_file_name=schedule_file_name)
    norm_cfg = generate_norm_cfg(GPU_num=GPU_num)
    data_preprocessor = generate_data_preprocessor(crop_size=crop_size, mean=mean, std=std)
    # The model-level data_preprocessor is derived from the top-level one
    model_data_preprocessor = generate_model_data_preprocessor(model_data_preprocessor=data_preprocessor)
    # Assemble the model section
    model = dict(data_preprocessor=model_data_preprocessor, decode_head=decode_head)
    # test_cfg = generate_model_test_cfg(test_cfg_mode=test_cfg_mode, crop_size=crop_size, test_cfg_crop_div_stride=test_cfg_crop_div_stride)
    optim_wrapper = generate_optim_wrapper()
    param_scheduler = generate_param_scheduler(train_type, train_time_or_epoch_k)
    ########### 4. Write the config ###########
    # Reference naming scheme:
    #   erfnet[alg]_r50[backbone depth]-d8[_base_/models]_4[GPU count]xb2[batch size]-40k[schedule]_cityscapes[dataset]-512x1024[crop_size].py
    # The name built below is a variation of it.
    # Drop only the literal 'schedule_' prefix. str.lstrip('schedule_') removes any leading run of
    # the characters s/c/h/e/d/u/l/_ and would mangle names such as 'schedule_epoch_4' into 'poch_4'.
    schedule_prefix = 'schedule_'
    schedule_tag = schedule_file_name[len(schedule_prefix):] if schedule_file_name.startswith(schedule_prefix) else schedule_file_name
    alg_file_name = f"{alg_name}_fcn_g{GPU_num}-{schedule_tag}_{dataset_file_name}-{crop_size[0]}x{crop_size[1]}.py"
    output_configs_alg_my_alg = os.path.join('./configs/erfnet/', alg_file_name)
    write_config_to_file(output_configs_alg_my_alg, _base_=_base_, norm_cfg=norm_cfg, crop_size=crop_size, data_preprocessor=data_preprocessor, model=model, optim_wrapper=optim_wrapper, param_scheduler=param_scheduler)
    # Record the generated file's name and path in _temp_.txt
    with open("_temp_.txt", "w") as file:
        json.dump({'alg_infos': {'alg_file_name': alg_file_name, 'alg_file_pth': output_configs_alg_my_alg}}, file)

View File

@@ -0,0 +1,89 @@
import os, sys, argparse, json
# 添加上级目录到 sys.path
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../')))
from Initial_Alg_Program.Initial_Alg_Select_Tool import select_crop_size, select_pretrained_model, select_test_cfg_slide
from Initial_Alg_Program.Initial_Alg_Gen_Tool import format_all_data, write_config_to_file, get_var_from_py_file
from Initial_Alg_Program.Initial_Alg_Gen_0_base_ import generate_base_config
from Initial_Alg_Program.Initial_Alg_Gen_1_norm_cfg import generate_norm_cfg
from Initial_Alg_Program.Initial_Alg_Gen_2_data_preprocessor import generate_data_preprocessor
from Initial_Alg_Program.Initial_Alg_Gen_3_model import generate_model_pretrained, generate_model_backbone, generate_model_data_preprocessor, generate_model_decode_head, generate_model_auxiliary_head, generate_model_train_cfg, generate_model_test_cfg
from Initial_Alg_Program.Initial_Alg_Gen_4_optimizer import generate_optim_wrapper, generate_param_scheduler
if __name__ == '__main__':
    # Builds a Fast-SCNN config under ./configs/fastscnn/ from command-line arguments plus an
    # interactive crop-size choice, then records the generated file's name and path in _temp_.txt.
    ########### 1. Hyper-parameters: algorithm defaults ###########
    alg_file_name = 'fast_scnn'  # base model name; also used to locate configs/_base_/models/<name>.py
    ########### 2. Argument parsing ###########
    parser = argparse.ArgumentParser(description="Run algorithm with specified parameters.")
    # Command-line arguments (all required)
    parser.add_argument('--alg_name', type=str, required=True, help="Algorithm name (e.g., 'ann')")
    parser.add_argument('--dataset_file_name', type=str, required=True, help="Dataset file name (e.g., 'my_dataset_model')")
    parser.add_argument('--mean', type=str, required=True, help="Mean of dataset (e.g., [1.2, 1.2, 1.2])")
    parser.add_argument('--std', type=str, required=True, help="Std of dataset (e.g., [1.2, 1.2, 1.2])")
    parser.add_argument('--dataset_num_classes', type=int, required=True, help="Dataset class num (e.g., 36)")
    parser.add_argument('--GPU_num', type=int, required=True, help="Num of GPU to train (e.g., 1)")
    parser.add_argument('--schedule_file_name', type=str, required=True, help="Schedule file name (e.g., 'schedule_4k_check_400')")
    parser.add_argument('--train_type', type=str, required=True, help="Schedule train type (e.g., Epoch or Iteration)")
    parser.add_argument('--train_time_or_epoch_k', type=int, required=True, help="Schedule train time k (e.g., 4)")
    args = parser.parse_args()
    # Unpack the parsed values
    alg_name = args.alg_name
    dataset_file_name = args.dataset_file_name
    mean = json.loads(args.mean)  # JSON text -> list
    std = json.loads(args.std)  # JSON text -> list
    num_classes = args.dataset_num_classes
    schedule_file_name = args.schedule_file_name
    train_type = args.train_type
    train_time_or_epoch_k = args.train_time_or_epoch_k
    GPU_num = args.GPU_num
    ########### 3. Parameter setup ###########
    # 3.1. Interactively chosen parameters
    crop_size = select_crop_size(predefined_options=[(512, 1024), (512, 512)])  # crop size
    ########### 3. Build each config section ###########
    # 3.1. _base_, norm_cfg and data_preprocessor
    _base_ = generate_base_config(alg_file_name=alg_file_name, dataset_file_name=dataset_file_name, schedule_file_name=schedule_file_name)
    norm_cfg = generate_norm_cfg(GPU_num=GPU_num)
    data_preprocessor = generate_data_preprocessor(crop_size=crop_size, mean=mean, std=std)
    # 3.2. model
    # The model-level data_preprocessor is derived from the top-level one
    model_data_preprocessor = generate_model_data_preprocessor(model_data_preprocessor=data_preprocessor)
    # Decode head loss (the CrossEntropyLoss variant is kept as a commented-out alternative;
    # it used to be assigned and then immediately overwritten by the DiceLoss dict)
    # decode_head_loss_decode_dict = dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)
    decode_head_loss_decode_dict = dict(type='DiceLoss', use_sigmoid=False, loss_weight=1.0)  # decode head loss  # TODO
    decode_head = dict(loss_decode=decode_head_loss_decode_dict)
    # Auxiliary heads: start from the 'auxiliary_head' entry of the base model file and
    # override loss type, class count and norm_cfg on every head
    auxiliary_head = get_var_from_py_file(os.path.join('./configs/_base_/models', alg_file_name+'.py'), 'model')['auxiliary_head']
    for aux_head in auxiliary_head:
        aux_head['loss_decode']['use_sigmoid'] = False
        aux_head['loss_decode']['type'] = 'DiceLoss'
        aux_head['num_classes'] = num_classes
        aux_head['norm_cfg'] = norm_cfg
    # Assemble the model section
    model = dict(data_preprocessor=model_data_preprocessor, decode_head=decode_head, auxiliary_head=auxiliary_head)
    # 3.3. optim_wrapper and param_scheduler
    optim_wrapper = generate_optim_wrapper()
    param_scheduler = generate_param_scheduler(train_type, train_time_or_epoch_k)
    ########### 4. Write the config ###########
    # Reference naming scheme:
    #   fastscnn[alg]_r50[backbone depth]-d8[_base_/models]_4[GPU count]xb2[batch size]-40k[schedule]_cityscapes[dataset]-512x1024[crop_size].py
    # The name built below is a variation of it.
    # Drop only the literal 'schedule_' prefix. str.lstrip('schedule_') removes any leading run of
    # the characters s/c/h/e/d/u/l/_ and would mangle names such as 'schedule_epoch_4' into 'poch_4'.
    schedule_prefix = 'schedule_'
    schedule_tag = schedule_file_name[len(schedule_prefix):] if schedule_file_name.startswith(schedule_prefix) else schedule_file_name
    alg_file_name = f"{alg_name}_g{GPU_num}-{schedule_tag}_{dataset_file_name}-{crop_size[0]}x{crop_size[1]}.py"
    output_configs_alg_my_alg = os.path.join('./configs/fastscnn/', alg_file_name)
    write_config_to_file(output_configs_alg_my_alg, _base_=_base_, norm_cfg=norm_cfg, crop_size=crop_size, data_preprocessor=data_preprocessor, model=model, optim_wrapper=optim_wrapper, param_scheduler=param_scheduler)
    # Record the generated file's name and path in _temp_.txt
    with open("_temp_.txt", "w") as file:
        json.dump({'alg_infos': {'alg_file_name': alg_file_name, 'alg_file_pth': output_configs_alg_my_alg}}, file)

View File

@@ -0,0 +1,189 @@
import os, sys, argparse, json
# 添加上级目录到 sys.path
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../')))
from Initial_Alg_Program.Initial_Alg_Select_Tool import select_crop_size, select_pretrained_model
from Initial_Alg_Program.Initial_Alg_Gen_Tool import format_all_data, write_config_to_file, create_dict_by_kwargs, get_var_from_file, update_list_dict_var, get_var_from_py_file
from Initial_Alg_Program.Initial_Alg_Gen_0_base_ import generate_base_config
from Initial_Alg_Program.Initial_Alg_Gen_1_norm_cfg import generate_norm_cfg
from Initial_Alg_Program.Initial_Alg_Gen_2_data_preprocessor import generate_data_preprocessor
from Initial_Alg_Program.Initial_Alg_Gen_3_model import generate_model_pretrained, generate_model_backbone, generate_model_data_preprocessor, generate_model_decode_head, generate_model_auxiliary_head, generate_model_train_cfg, generate_model_test_cfg
from Initial_Alg_Program.Initial_Alg_Gen_4_optimizer import generate_optim_wrapper, generate_param_scheduler
# Interactive chooser for the decode head
def select_decode_head(decode_head_choose):
    """Ask the user to pick one entry of ``decode_head_choose`` by its menu number.

    The keys are printed as a 1-based numbered menu. The prompt repeats until the
    user types an integer inside the menu range; non-integer input and out-of-range
    numbers each print their own message and re-prompt.

    Args:
        decode_head_choose: mapping from option label to the value to return.

    Returns:
        The value stored under the selected label.
    """
    option_names = list(decode_head_choose.keys())
    print("可用的 decode head 选项:")
    for number, name in enumerate(option_names, start=1):
        print(f"{number}. {name}")
    while True:
        try:
            choice = int(input("请选择需要的 decode head输入编号"))
        except ValueError:
            # Input was not an integer
            print("输入无效,请输入有效的编号。")
            continue
        if choice < 1 or choice > len(decode_head_choose):
            # Integer, but outside the menu range
            print("输入的编号不正确,请重新输入。")
            continue
        selected_key = option_names[choice - 1]
        print(f"你选择了: {selected_key}")
        return decode_head_choose[selected_key]
if __name__ == '__main__':
    # Builds a FastFCN config under ./configs/fastfcn/ from command-line arguments plus
    # interactive choices (crop size, backbone, decode head), then records the generated
    # file's name and path in _temp_.txt.
    ########### 1. Hyper-parameters: algorithm defaults ###########
    alg_file_name = 'fastfcn_r50-d32_jpu_psp'  # base model name handed to generate_base_config
    alg_file_pth = f"configs/_base_/models/{alg_file_name}.py"  # not used further in this script
    ########### 2. Argument parsing ###########
    parser = argparse.ArgumentParser(description="Run algorithm with specified parameters.")
    # Command-line arguments (all required)
    parser.add_argument('--alg_name', type=str, required=True, help="Algorithm name (e.g., 'ann')")
    parser.add_argument('--dataset_file_name', type=str, required=True, help="Dataset file name (e.g., 'my_dataset_model')")
    parser.add_argument('--mean', type=str, required=True, help="Mean of dataset (e.g., [1.2, 1.2, 1.2])")
    parser.add_argument('--std', type=str, required=True, help="Std of dataset (e.g., [1.2, 1.2, 1.2])")
    parser.add_argument('--dataset_num_classes', type=int, required=True, help="Dataset class num (e.g., 36)")
    parser.add_argument('--GPU_num', type=int, required=True, help="Num of GPU to train (e.g., 1)")
    parser.add_argument('--schedule_file_name', type=str, required=True, help="Schedule file name (e.g., 'schedule_4k_check_400')")
    parser.add_argument('--train_type', type=str, required=True, help="Schedule train type (e.g., Epoch or Iteration)")
    parser.add_argument('--train_time_or_epoch_k', type=int, required=True, help="Schedule train time k (e.g., 4)")
    args = parser.parse_args()
    # Unpack the parsed values
    alg_name = args.alg_name
    dataset_file_name = args.dataset_file_name
    mean = json.loads(args.mean)  # JSON text -> list
    std = json.loads(args.std)  # JSON text -> list
    num_classes = args.dataset_num_classes
    schedule_file_name = args.schedule_file_name
    train_type = args.train_type
    train_time_or_epoch_k = args.train_time_or_epoch_k
    GPU_num = args.GPU_num
    ########### 3.1. _base_ ###########
    _base_ = generate_base_config(alg_file_name=alg_file_name, dataset_file_name=dataset_file_name, schedule_file_name=schedule_file_name)
    ########### 3.1. norm_cfg ###########
    norm_cfg = generate_norm_cfg(GPU_num=GPU_num)
    ########### 3.2. data_preprocessor ###########
    crop_size = select_crop_size(predefined_options=[(512, 512), (1024, 1024)])  # crop size
    data_preprocessor = generate_data_preprocessor(crop_size=crop_size, mean=mean, std=std)
    ########### 3.3. model ###########
    model_data_preprocessor = data_preprocessor  # the model reuses the top-level data_preprocessor
    # Backbone choice and whether pretrained weights are used
    model_list = ['openmmlab/resnet50_v1c', 'openmmlab/resnet101_v1c']
    selected_model_name, select_pretrained, pretrained_pth, selected_model_info = select_pretrained_model(model_list=model_list)
    depth = selected_model_info['depth']
    backbone = create_dict_by_kwargs(norm_cfg=norm_cfg, depth=depth)
    # Selectable decode heads. 'decode_head' is the config dict that gets written;
    # 'decode_head_name' is the tag embedded in the generated file name.
    decode_head_choose = {
        'psp-PSPHead': {
            'decode_head': dict(
                type='PSPHead',
                in_channels=2048,
                in_index=2,
                channels=512,
                pool_scales=(1, 2, 3, 6),
                dropout_ratio=0.1,
                num_classes=num_classes,  # was hard-coded to 19; use the dataset's class count like the other heads
                norm_cfg=norm_cfg,
                align_corners=False,
                loss_decode=dict(
                    type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)
            ),
            'decode_head_name': 'psp'  # was 'aspp' (copy-paste), which mislabelled PSP configs
        },
        'enc-EncHead': {
            'decode_head': dict(
                _delete_=True,
                type='EncHead',
                in_channels=[512, 1024, 2048],
                in_index=(0, 1, 2),
                channels=512,
                num_codes=32,
                use_se_loss=True,
                add_lateral=False,
                dropout_ratio=0.1,
                num_classes=num_classes,
                norm_cfg=norm_cfg,
                align_corners=False,
                loss_decode=dict(
                    type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
                loss_se_decode=dict(
                    type='CrossEntropyLoss', use_sigmoid=True, loss_weight=0.2)
            ),
            'decode_head_name': 'enc'  # was 'aspp' (copy-paste), which mislabelled Enc configs
        },
        'aspp-ASPPHead': {
            'decode_head': dict(
                _delete_=True,
                type='ASPPHead',
                in_channels=2048,
                in_index=2,
                channels=512,
                dilations=(1, 12, 24, 36),
                dropout_ratio=0.1,
                num_classes=num_classes,
                norm_cfg=norm_cfg,
                align_corners=False,
                loss_decode=dict(
                    type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)
            ),
            'decode_head_name': 'aspp'
        }
    }
    decode_head_dict = select_decode_head(decode_head_choose=decode_head_choose)
    decode_head = decode_head_dict['decode_head']
    decode_head_name = decode_head_dict['decode_head_name']
    decode_head_loss_decode_type = 'DiceLoss'  # Way 1: replace the loss type
    # decode_head_loss_decode_type = 'CrossEntropyLoss'  # Way ori: keep the original loss type
    # Overwrite the loss type on whichever loss entries the selected head defines.
    # NOTE(review): this also switches EncHead's loss_se_decode (use_sigmoid=True) to the
    # chosen type - confirm that is intended.
    if 'loss_se_decode' in decode_head:
        decode_head['loss_se_decode']['type'] = decode_head_loss_decode_type
    if 'loss_decode' in decode_head:
        decode_head['loss_decode']['type'] = decode_head_loss_decode_type
    # Auxiliary head override (set it to None to leave auxiliary_head out of the model dict)
    auxiliary_head_loss_decode_dict = dict(type='DiceLoss', use_sigmoid=False, loss_weight=0.4)
    auxiliary_head = dict(loss_decode=auxiliary_head_loss_decode_dict, norm_cfg=norm_cfg, num_classes=num_classes)  # Way 1: replace the loss
    # auxiliary_head = dict(norm_cfg=norm_cfg, num_classes=num_classes)  # Way ori: keep the original loss
    # Assemble the model section
    model = dict(
        backbone=backbone,
        data_preprocessor=model_data_preprocessor,
        decode_head=decode_head
    )
    if auxiliary_head is not None:
        model['auxiliary_head'] = auxiliary_head
    ########### 3.4. optim_wrapper and param_scheduler ###########
    optim_wrapper = generate_optim_wrapper()
    param_scheduler = generate_param_scheduler(train_type, train_time_or_epoch_k)
    ########### 4. Write the config ###########
    # Reference naming scheme:
    #   fastfcn[alg]_r50[backbone depth]-d8[_base_/models]_4[GPU count]xb2[batch size]-40k[schedule]_cityscapes[dataset]-512x1024[crop_size].py
    # The name built below is a variation of it.
    # Drop only the literal 'schedule_' prefix. str.lstrip('schedule_') removes any leading run of
    # the characters s/c/h/e/d/u/l/_ and would mangle names such as 'schedule_epoch_4' into 'poch_4'.
    schedule_prefix = 'schedule_'
    schedule_tag = schedule_file_name[len(schedule_prefix):] if schedule_file_name.startswith(schedule_prefix) else schedule_file_name
    alg_file_name = f"{alg_name}_r{depth}_{'Pre' if select_pretrained else 'NoPre'}_jpu_{decode_head_name}_g{GPU_num}-{schedule_tag}_{dataset_file_name}-{crop_size[0]}x{crop_size[1]}.py"
    output_configs_alg_my_alg = os.path.join('./configs/fastfcn/', alg_file_name)
    write_config_to_file(output_configs_alg_my_alg, _base_=_base_, norm_cfg=norm_cfg, crop_size=crop_size, data_preprocessor=data_preprocessor, model=model, optim_wrapper=optim_wrapper, param_scheduler=param_scheduler)
    # Record the generated file's name and path in _temp_.txt
    with open("_temp_.txt", "w") as file:
        json.dump({'alg_infos': {'alg_file_name': alg_file_name, 'alg_file_pth': output_configs_alg_my_alg}}, file)

View File

@@ -0,0 +1,132 @@
import os, sys, argparse, json
# 添加上级目录到 sys.path
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../')))
from Initial_Alg_Program.Initial_Alg_Select_Tool import select_crop_size, select_pretrained_model, select_test_cfg_slide
from Initial_Alg_Program.Initial_Alg_Gen_Tool import format_all_data, write_config_to_file, create_dict_by_kwargs
from Initial_Alg_Program.Initial_Alg_Gen_0_base_ import generate_base_config
from Initial_Alg_Program.Initial_Alg_Gen_1_norm_cfg import generate_norm_cfg
from Initial_Alg_Program.Initial_Alg_Gen_2_data_preprocessor import generate_data_preprocessor
from Initial_Alg_Program.Initial_Alg_Gen_3_model import generate_model_pretrained, generate_model_backbone, generate_model_data_preprocessor, generate_model_decode_head, generate_model_auxiliary_head, generate_model_train_cfg, generate_model_test_cfg
from Initial_Alg_Program.Initial_Alg_Gen_4_optimizer import generate_optim_wrapper, generate_param_scheduler
if __name__ == '__main__':
    # Builds an FCN config under ./configs/fcn/ from command-line arguments plus a few
    # interactive choices, then records the generated file's name and path in _temp_.txt.
    ########### 1. Hyper-parameters: algorithm defaults ###########
    alg_file_name = 'fcn_r50-d8'  # base model name handed to generate_base_config
    ########### 2. Argument parsing ###########
    parser = argparse.ArgumentParser(description="Run algorithm with specified parameters.")
    # Command-line arguments (all required)
    parser.add_argument('--alg_name', type=str, required=True, help="Algorithm name (e.g., 'ann')")
    parser.add_argument('--dataset_file_name', type=str, required=True, help="Dataset file name (e.g., 'my_dataset_model')")
    parser.add_argument('--mean', type=str, required=True, help="Mean of dataset (e.g., [1.2, 1.2, 1.2])")
    parser.add_argument('--std', type=str, required=True, help="Std of dataset (e.g., [1.2, 1.2, 1.2])")
    parser.add_argument('--dataset_num_classes', type=int, required=True, help="Dataset class num (e.g., 36)")
    parser.add_argument('--GPU_num', type=int, required=True, help="Num of GPU to train (e.g., 1)")
    parser.add_argument('--schedule_file_name', type=str, required=True, help="Schedule file name (e.g., 'schedule_4k_check_400')")
    parser.add_argument('--train_type', type=str, required=True, help="Schedule train type (e.g., Epoch or Iteration)")
    parser.add_argument('--train_time_or_epoch_k', type=int, required=True, help="Schedule train time k (e.g., 4)")
    args = parser.parse_args()
    # Unpack the parsed values
    alg_name = args.alg_name
    dataset_file_name = args.dataset_file_name
    mean = json.loads(args.mean)  # JSON text -> list
    std = json.loads(args.std)  # JSON text -> list
    num_classes = args.dataset_num_classes
    schedule_file_name = args.schedule_file_name
    train_type = args.train_type
    train_time_or_epoch_k = args.train_time_or_epoch_k
    GPU_num = args.GPU_num
    ########### 3. Parameter setup ###########
    # 3.1. Interactively chosen parameters
    crop_size = select_crop_size(predefined_options=[(512, 512), (512, 1024), (769, 769)])  # crop size
    # Crops with more pixels than 512x512 use a different dilation/stride preset
    if crop_size[0]*crop_size[1] > 512*512:
        backbone_dilations = (1, 1, 1, 2)
        backbone_strides = (1, 2, 2, 1)  # larger strides
        decode_head_dilation = 6
        auxiliary_head_dilation = 6
    else:
        backbone_dilations = (1, 1, 2, 4)
        backbone_strides = (1, 2, 1, 1)  # smaller strides
        decode_head_dilation = 1
        auxiliary_head_dilation = 1
    model_list = ['torchvision://resnet18', 'torchvision://resnet50', 'torchvision://resnet101', 'open-mmlab/resnet18_v1c', 'openmmlab/resnet50_v1c', 'openmmlab/resnet101_v1c']
    selected_model_name, select_pretrained, pretrained_pth, selected_model_info = select_pretrained_model(model_list=model_list)
    depth = selected_model_info['depth']
    model_type = selected_model_info['type']
    test_cfg_mode, test_cfg_crop_div_stride = select_test_cfg_slide(crop_size)
    # Channel widths that depend on the backbone depth
    if str(depth) == '18':
        backbone_out_channels = 256  # not used further in this script
        decode_head_in_channels = 512
        decode_head_channels = 128
        auxiliary_head_in_channels = 256
        auxiliary_head_channels = 64
    elif str(depth) in ('50', '101'):
        backbone_out_channels = 1024  # not used further in this script
        decode_head_in_channels = 2048
        decode_head_channels = 512
        auxiliary_head_in_channels = 1024
        auxiliary_head_channels = 256
    else:
        # Fail here with a clear message instead of a NameError further down
        raise ValueError(f"Unsupported backbone depth {depth!r}; expected 18, 50 or 101")
    # 3.2. Derived parameters
    # Loss settings for the decode and auxiliary heads
    decode_head_loss_decode_dict = dict(type='DiceLoss', use_sigmoid=False, loss_weight=1.0)  # decode head loss  # TODO
    # decode_head_loss_decode_dict = dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)  # alternative decode head loss  # TODO
    auxiliary_head_loss_decode_dict = dict(type='DiceLoss', use_sigmoid=False, loss_weight=0.4)  # auxiliary head loss  # TODO
    # auxiliary_head_loss_decode_dict = dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)  # alternative auxiliary head loss  # TODO
    # align_corners is switched on only when slide-mode testing was selected
    if test_cfg_mode == "slide":
        align_corners = True
        test_slide = 'testslide'  # marker embedded in the config file name
    else:
        align_corners = False
        test_slide = 'no_testslide'  # marker embedded in the config file name
    ########### 3. Build each config section ###########
    _base_ = generate_base_config(alg_file_name=alg_file_name, dataset_file_name=dataset_file_name, schedule_file_name=schedule_file_name)
    norm_cfg = generate_norm_cfg(GPU_num=GPU_num)
    data_preprocessor = generate_data_preprocessor(crop_size=crop_size, mean=mean, std=std)
    pretrained = generate_model_pretrained(pretrained_pth=pretrained_pth)
    backbone = create_dict_by_kwargs(type=model_type, depth=depth, strides=backbone_strides, dilations=backbone_dilations)
    # The model-level data_preprocessor is derived from the top-level one
    model_data_preprocessor = generate_model_data_preprocessor(model_data_preprocessor=data_preprocessor)
    decode_head = create_dict_by_kwargs(dilation=decode_head_dilation, channels=decode_head_channels, in_channels=decode_head_in_channels, num_classes=num_classes, loss_decode=decode_head_loss_decode_dict, align_corners=align_corners)
    auxiliary_head = create_dict_by_kwargs(dilation=auxiliary_head_dilation, channels=auxiliary_head_channels, in_channels=auxiliary_head_in_channels, num_classes=num_classes, loss_decode=auxiliary_head_loss_decode_dict, align_corners=align_corners)
    # Assemble the model section
    model = dict(
        pretrained=pretrained,
        backbone=backbone,
        data_preprocessor=model_data_preprocessor,
        decode_head=decode_head,
        auxiliary_head=auxiliary_head,
    )
    test_cfg = generate_model_test_cfg(test_cfg_mode=test_cfg_mode, crop_size=crop_size, test_cfg_crop_div_stride=test_cfg_crop_div_stride)
    optim_wrapper = generate_optim_wrapper()
    param_scheduler = generate_param_scheduler(train_type, train_time_or_epoch_k)
    ########### 4. Write the config ###########
    # Reference naming scheme:
    #   fcn[alg]_r50[backbone depth]-d8[_base_/models]_4[GPU count]xb2[batch size]-40k[schedule]_cityscapes[dataset]-512x1024[crop_size].py
    # The name built below is a variation of it.
    # Drop only the literal 'schedule_' prefix. str.lstrip('schedule_') removes any leading run of
    # the characters s/c/h/e/d/u/l/_ and would mangle names such as 'schedule_epoch_4' into 'poch_4'.
    schedule_prefix = 'schedule_'
    schedule_tag = schedule_file_name[len(schedule_prefix):] if schedule_file_name.startswith(schedule_prefix) else schedule_file_name
    alg_file_name = f"{alg_name}_r{depth}_{'Pre' if select_pretrained else 'NoPre'}_g{GPU_num}-{schedule_tag}_{dataset_file_name}-{crop_size[0]}x{crop_size[1]}-{test_slide}.py"
    output_configs_alg_my_alg = os.path.join('./configs/fcn/', alg_file_name)
    write_config_to_file(output_configs_alg_my_alg, _base_=_base_, norm_cfg=norm_cfg, crop_size=crop_size, data_preprocessor=data_preprocessor, model=model, test_cfg=test_cfg, optim_wrapper=optim_wrapper, param_scheduler=param_scheduler)
    # Record the generated file's name and path in _temp_.txt
    with open("_temp_.txt", "w") as file:
        json.dump({'alg_infos': {'alg_file_name': alg_file_name, 'alg_file_pth': output_configs_alg_my_alg}}, file)

View File

@@ -0,0 +1,106 @@
import os, sys, argparse, json
# 添加上级目录到 sys.path
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../')))
from Initial_Alg_Program.Initial_Alg_Select_Tool import select_crop_size, select_pretrained_model, select_test_cfg_slide
from Initial_Alg_Program.Initial_Alg_Gen_Tool import format_all_data, write_config_to_file
from Initial_Alg_Program.Initial_Alg_Gen_0_base_ import generate_base_config
from Initial_Alg_Program.Initial_Alg_Gen_1_norm_cfg import generate_norm_cfg
from Initial_Alg_Program.Initial_Alg_Gen_2_data_preprocessor import generate_data_preprocessor
from Initial_Alg_Program.Initial_Alg_Gen_3_model import generate_model_pretrained, generate_model_backbone, generate_model_data_preprocessor, generate_model_decode_head, generate_model_auxiliary_head, generate_model_train_cfg, generate_model_test_cfg
from Initial_Alg_Program.Initial_Alg_Gen_4_optimizer import generate_optim_wrapper, generate_param_scheduler
if __name__ == '__main__':
    # Script entry point: builds a GCNet config file from command-line arguments
    # plus interactive selections, writes it under ./configs/gcnet/, and records
    # the generated file name/path as JSON in _temp_.txt.

    ########### 1. Hyper-parameters: algorithm defaults ###########
    # Base model name handed to generate_base_config; the variable is reused
    # further down for the generated config's file name.
    alg_file_name = 'gcnet_r50-d8'

    ########### 2. Argument parsing ###########
    parser = argparse.ArgumentParser(description="Run algorithm with specified parameters.")
    # Register the command-line arguments
    parser.add_argument('--alg_name', type=str, required=True, help="Algorithm name (e.g., 'ann')")
    parser.add_argument('--dataset_file_name', type=str, required=True, help="Dataset file name (e.g., 'my_dataset_model')")
    parser.add_argument('--mean', type=str, required=True, help="Mean of dataset (e.g., [1.2, 1.2, 1.2])")
    # Help text fixed: it previously said "Mean of dataset" for --std.
    parser.add_argument('--std', type=str, required=True, help="Std of dataset (e.g., [1.2, 1.2, 1.2])")
    parser.add_argument('--dataset_num_classes', type=int, required=True, help="Dataset class num (e.g., 36)")
    parser.add_argument('--GPU_num', type=int, required=True, help="Num of GPU to train (e.g., 1)")
    parser.add_argument('--schedule_file_name', type=str, required=True, help="Schedule file name (e.g., 'schedule_4k_check_400')")
    parser.add_argument('--train_type', type=str, required=True, help="Schedule train type (e.g., Epoch or Iteration)")
    parser.add_argument('--train_time_or_epoch_k', type=int, required=True, help="Schedule train time k (e.g., 4)")
    # Parse the command line
    args = parser.parse_args()
    # Unpack into local variables
    alg_name = args.alg_name
    dataset_file_name = args.dataset_file_name
    mean = json.loads(args.mean)  # JSON text -> list
    std = json.loads(args.std)  # JSON text -> list
    num_classes = args.dataset_num_classes
    schedule_file_name = args.schedule_file_name
    train_type = args.train_type
    train_time_or_epoch_k = args.train_time_or_epoch_k
    GPU_num = args.GPU_num

    ########### 3. Set parameters ###########
    # 3.1. User-selected parameters
    crop_size = select_crop_size(predefined_options=[(512, 512), (512, 1024), (769, 769)])  # choose the crop size
    model_list = ['openmmlab/resnet50_v1c', 'openmmlab/resnet101_v1c']
    selected_model_name, select_pretrained, pretrained_pth, selected_model_info = select_pretrained_model(model_list=model_list)
    depth = selected_model_info['depth']
    test_cfg_mode, test_cfg_crop_div_stride = select_test_cfg_slide(crop_size)
    # test_cfg_mode = None
    # test_cfg_crop_div_stride = crop_size

    # 3.2. Derived parameters
    # Loss settings for the decode and auxiliary heads
    decode_head_loss_decode_dict = dict(type='DiceLoss', use_sigmoid=False, loss_weight=1.0)  # decode-head loss  # TODO
    # decode_head_loss_decode_dict=dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)  # decode-head loss (alternative)  # TODO
    auxiliary_head_loss_decode_dict = dict(type='DiceLoss', use_sigmoid=False, loss_weight=0.4)  # auxiliary-head loss  # TODO
    # auxiliary_head_loss_decode_dict=dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)  # auxiliary-head loss (alternative)  # TODO
    # align_corners is enabled only when the slide test mode was selected
    if test_cfg_mode == "slide":
        align_corners = True
        test_slide = 'testslide'  # tag embedded in the generated file name
    else:
        align_corners = False
        test_slide = 'no_testslide'  # tag embedded in the generated file name

    ########### 4. Build each config section ###########
    _base_ = generate_base_config(alg_file_name=alg_file_name, dataset_file_name=dataset_file_name, schedule_file_name=schedule_file_name)
    norm_cfg = generate_norm_cfg(GPU_num=GPU_num)
    data_preprocessor = generate_data_preprocessor(crop_size=crop_size, mean=mean, std=std)
    pretrained = generate_model_pretrained(pretrained_pth=pretrained_pth)
    backbone = generate_model_backbone(depth=depth)
    # The model-level preprocessor is derived from the top-level one (same object is passed in, not a copy)
    model_data_preprocessor = data_preprocessor
    model_data_preprocessor = generate_model_data_preprocessor(model_data_preprocessor=model_data_preprocessor)
    decode_head = generate_model_decode_head(num_classes=num_classes, decode_head_loss_decode_dict=decode_head_loss_decode_dict, align_corners=align_corners)
    auxiliary_head = generate_model_auxiliary_head(num_classes=num_classes, auxiliary_head_loss_decode_dict=auxiliary_head_loss_decode_dict, align_corners=align_corners)
    # Assemble the model section
    model = dict(
        pretrained=pretrained,
        backbone=backbone,
        data_preprocessor=model_data_preprocessor,
        decode_head=decode_head,
        auxiliary_head=auxiliary_head,
    )
    test_cfg = generate_model_test_cfg(test_cfg_mode=test_cfg_mode, crop_size=crop_size, test_cfg_crop_div_stride=test_cfg_crop_div_stride)
    optim_wrapper = generate_optim_wrapper()
    param_scheduler = generate_param_scheduler(train_type, train_time_or_epoch_k)

    ########### 5. Write the config ###########
    # Reference naming pattern: gcnet [algorithm] _ r50 [pretrained backbone depth] - d8 [_base_/models]
    # _ 4 [GPU count] x b2 [batch size] - 40k [schedule] _ cityscapes [dataset] - 512x1024 [crop_size] .py
    #
    # Remove the literal 'schedule_' prefix only. str.lstrip('schedule_') treats its
    # argument as a set of characters and would also eat leading letters of the
    # remainder (e.g. 'schedule_epoch_100' -> 'poch_100').
    schedule_prefix = 'schedule_'
    if schedule_file_name.startswith(schedule_prefix):
        schedule_tag = schedule_file_name[len(schedule_prefix):]
    else:
        schedule_tag = schedule_file_name
    alg_file_name = f"{alg_name}_r{depth}_{'Pre' if select_pretrained else 'NoPre'}_g{GPU_num}-{schedule_tag}_{dataset_file_name}-{crop_size[0]}x{crop_size[1]}-{test_slide}.py"
    output_configs_alg_my_alg = os.path.join('./configs/gcnet/', alg_file_name)
    write_config_to_file(output_configs_alg_my_alg, _base_=_base_, norm_cfg=norm_cfg, crop_size=crop_size, data_preprocessor=data_preprocessor, model=model, test_cfg=test_cfg, optim_wrapper=optim_wrapper, param_scheduler=param_scheduler)
    # Record the generated file name and path in _temp_.txt
    with open("_temp_.txt", "w") as file:
        json.dump({'alg_infos': {'alg_file_name': alg_file_name, 'alg_file_pth': output_configs_alg_my_alg}}, file)

View File

@@ -0,0 +1,144 @@
import os, sys, argparse, json
# 添加上级目录到 sys.path
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../')))
from Initial_Alg_Program.Initial_Alg_Select_Tool import select_crop_size, select_pretrained_model, select_test_cfg_slide
from Initial_Alg_Program.Initial_Alg_Gen_Tool import format_all_data, write_config_to_file, create_dict_by_kwargs
from Initial_Alg_Program.Initial_Alg_Gen_0_base_ import generate_base_config
from Initial_Alg_Program.Initial_Alg_Gen_1_norm_cfg import generate_norm_cfg
from Initial_Alg_Program.Initial_Alg_Gen_2_data_preprocessor import generate_data_preprocessor
from Initial_Alg_Program.Initial_Alg_Gen_3_model import generate_model_pretrained, generate_model_backbone, generate_model_data_preprocessor, generate_model_decode_head, generate_model_auxiliary_head, generate_model_train_cfg, generate_model_test_cfg
from Initial_Alg_Program.Initial_Alg_Gen_4_optimizer import generate_optim_wrapper, generate_param_scheduler
if __name__ == '__main__':
    # Script entry point: builds an HRNet config file from command-line arguments
    # plus interactive selections, writes it under ./configs/hrnet/, and records
    # the generated file name/path as JSON in _temp_.txt.

    ########### 1. Hyper-parameters: algorithm defaults ###########
    # Base model name handed to generate_base_config; the variable is reused
    # further down for the generated config's file name.
    alg_file_name = 'fcn_hr18'

    ########### 2. Argument parsing ###########
    parser = argparse.ArgumentParser(description="Run algorithm with specified parameters.")
    # Register the command-line arguments
    parser.add_argument('--alg_name', type=str, required=True, help="Algorithm name (e.g., 'ann')")
    parser.add_argument('--dataset_file_name', type=str, required=True, help="Dataset file name (e.g., 'my_dataset_model')")
    parser.add_argument('--mean', type=str, required=True, help="Mean of dataset (e.g., [1.2, 1.2, 1.2])")
    # Help text fixed: it previously said "Mean of dataset" for --std.
    parser.add_argument('--std', type=str, required=True, help="Std of dataset (e.g., [1.2, 1.2, 1.2])")
    parser.add_argument('--dataset_num_classes', type=int, required=True, help="Dataset class num (e.g., 36)")
    parser.add_argument('--GPU_num', type=int, required=True, help="Num of GPU to train (e.g., 1)")
    parser.add_argument('--schedule_file_name', type=str, required=True, help="Schedule file name (e.g., 'schedule_4k_check_400')")
    parser.add_argument('--train_type', type=str, required=True, help="Schedule train type (e.g., Epoch or Iteration)")
    parser.add_argument('--train_time_or_epoch_k', type=int, required=True, help="Schedule train time k (e.g., 4)")
    # Parse the command line
    args = parser.parse_args()
    # Unpack into local variables
    alg_name = args.alg_name
    dataset_file_name = args.dataset_file_name
    mean = json.loads(args.mean)  # JSON text -> list
    std = json.loads(args.std)  # JSON text -> list
    num_classes = args.dataset_num_classes
    schedule_file_name = args.schedule_file_name
    train_type = args.train_type
    train_time_or_epoch_k = args.train_time_or_epoch_k
    GPU_num = args.GPU_num

    ########### 3. Set parameters ###########
    # 3.1. User-selected parameters
    crop_size = select_crop_size(predefined_options=[(512, 512), (512, 1024), (769, 769)])  # choose the crop size
    # # If crop_size is larger, adjust dilations and strides
    # if crop_size[0]*crop_size[1] > 512*512:
    #     backbone_dilations = (1, 1, 1, 2)  # finer information per step
    #     backbone_strides = (1, 2, 2, 1)  # larger steps
    #     decode_head_dilation = 6  # more information per step
    #     auxiliary_head_dilation = 6  # more information per step
    # else:
    #     backbone_dilations = (1, 1, 2, 4)  # more information per step
    #     backbone_strides = (1, 2, 1, 1)  # smaller steps
    #     decode_head_dilation = 1  # finer information per step
    #     auxiliary_head_dilation = 1  # finer information per step
    model_list = ['open-mmlab://msra/hrnetv2_w18', 'open-mmlab://msra/hrnetv2_w18_small', 'open-mmlab://msra/hrnetv2_w48']
    selected_model_name, select_pretrained, pretrained_pth, selected_model_info = select_pretrained_model(model_list=model_list)
    test_cfg_mode, test_cfg_crop_div_stride = select_test_cfg_slide(crop_size)
    # Per-backbone settings: the backbone `extra` overrides, the decode-head
    # channel layout, and the tag used in the generated file name.
    if selected_model_name == 'open-mmlab://msra/hrnetv2_w18':
        backbone_extra = dict(
            stage1=dict(num_blocks=(4, ), num_channels=(64, )),
            stage2=dict(num_blocks=(4, 4), num_channels=(18, 36)),
            stage3=dict(num_modules=4, num_blocks=(4, 4, 4), num_channels=(18, 36, 72)),
            stage4=dict(num_modules=3, num_blocks=(4, 4, 4, 4), num_channels=(18, 36, 72, 144)))
        decode_head_channels = dict(in_channels=[18, 36, 72, 144], channels=sum([18, 36, 72, 144]))
        alg_text = 'w18'
    elif selected_model_name == 'open-mmlab://msra/hrnetv2_w18_small':
        backbone_extra = dict(
            stage1=dict(num_blocks=(2, )),
            stage2=dict(num_blocks=(2, 2)),
            stage3=dict(num_modules=3, num_blocks=(2, 2, 2)),
            stage4=dict(num_modules=2, num_blocks=(2, 2, 2, 2)))
        decode_head_channels = dict(in_channels=[18, 36, 72, 144], channels=sum([18, 36, 72, 144]))  # same channels as w18
        alg_text = 'w18-small'
    elif selected_model_name == 'open-mmlab://msra/hrnetv2_w48':
        # Only the channel widths change
        backbone_extra = dict(
            stage2=dict(num_channels=(48, 96)),
            stage3=dict(num_channels=(48, 96, 192)),
            stage4=dict(num_channels=(48, 96, 192, 384)))
        decode_head_channels = dict(in_channels=[48, 96, 192, 384], channels=sum([48, 96, 192, 384]))
        alg_text = 'w48'
    else:
        # Without this branch an unexpected name would surface later as a
        # NameError on backbone_extra; fail here with an explicit message.
        raise ValueError(f"Unsupported HRNet backbone: {selected_model_name!r}")

    # 3.2. Derived parameters
    # Loss setting for the decode head (this script configures no auxiliary head)
    # decode_head_loss_decode_dict=dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)  # decode-head loss  # Way: Ori TODO
    decode_head_loss_decode_dict = dict(type='DiceLoss', use_sigmoid=False, loss_weight=1.0)  # decode-head loss  # Way: 1 TODO
    # align_corners is enabled only when the slide test mode was selected
    if test_cfg_mode == "slide":
        align_corners = True
        test_slide = 'testslide'  # tag embedded in the generated file name
    else:
        align_corners = False
        test_slide = 'no_testslide'  # tag embedded in the generated file name

    ########### 4. Build each config section ###########
    _base_ = generate_base_config(alg_file_name=alg_file_name, dataset_file_name=dataset_file_name, schedule_file_name=schedule_file_name)
    norm_cfg = generate_norm_cfg(GPU_num=GPU_num)
    data_preprocessor = generate_data_preprocessor(crop_size=crop_size, mean=mean, std=std)
    pretrained = generate_model_pretrained(pretrained_pth=pretrained_pth)
    backbone = create_dict_by_kwargs(extra=backbone_extra)
    # The model-level preprocessor is derived from the top-level one (same object is passed in, not a copy)
    model_data_preprocessor = data_preprocessor
    model_data_preprocessor = generate_model_data_preprocessor(model_data_preprocessor=model_data_preprocessor)
    decode_head = create_dict_by_kwargs(num_classes=num_classes, loss_decode=decode_head_loss_decode_dict, align_corners=align_corners)
    decode_head.update(decode_head_channels)  # add in_channels / channels
    auxiliary_head = None  # not used: the key is left out of the model dict below
    # Assemble the model section
    model = dict(
        pretrained=pretrained,
        backbone=backbone,
        data_preprocessor=model_data_preprocessor,
        decode_head=decode_head,
        # auxiliary_head = auxiliary_head,
    )
    test_cfg = generate_model_test_cfg(test_cfg_mode=test_cfg_mode, crop_size=crop_size, test_cfg_crop_div_stride=test_cfg_crop_div_stride)
    optim_wrapper = generate_optim_wrapper()
    param_scheduler = generate_param_scheduler(train_type, train_time_or_epoch_k)

    ########### 5. Write the config ###########
    # Reference naming pattern: hrnet [algorithm] _ r50 [pretrained backbone depth] - d8 [_base_/models]
    # _ 4 [GPU count] x b2 [batch size] - 40k [schedule] _ cityscapes [dataset] - 512x1024 [crop_size] .py
    #
    # Remove the literal 'schedule_' prefix only. str.lstrip('schedule_') treats its
    # argument as a set of characters and would also eat leading letters of the
    # remainder (e.g. 'schedule_epoch_100' -> 'poch_100').
    schedule_prefix = 'schedule_'
    if schedule_file_name.startswith(schedule_prefix):
        schedule_tag = schedule_file_name[len(schedule_prefix):]
    else:
        schedule_tag = schedule_file_name
    alg_file_name = f"{alg_name}_{alg_text}_{'Pre' if select_pretrained else 'NoPre'}_g{GPU_num}-{schedule_tag}_{dataset_file_name}-{crop_size[0]}x{crop_size[1]}-{test_slide}.py"
    output_configs_alg_my_alg = os.path.join('./configs/hrnet/', alg_file_name)
    write_config_to_file(output_configs_alg_my_alg, _base_=_base_, norm_cfg=norm_cfg, crop_size=crop_size, data_preprocessor=data_preprocessor, model=model, test_cfg=test_cfg, optim_wrapper=optim_wrapper, param_scheduler=param_scheduler)
    # Record the generated file name and path in _temp_.txt
    with open("_temp_.txt", "w") as file:
        json.dump({'alg_infos': {'alg_file_name': alg_file_name, 'alg_file_pth': output_configs_alg_my_alg}}, file)

View File

@@ -0,0 +1,124 @@
import os, sys, argparse, json
# 添加上级目录到 sys.path
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../')))
from Initial_Alg_Program.Initial_Alg_Select_Tool import select_crop_size, select_pretrained_model, select_test_cfg_slide
from Initial_Alg_Program.Initial_Alg_Gen_Tool import format_all_data, write_config_to_file, create_dict_by_kwargs, get_var_from_py_file
from Initial_Alg_Program.Initial_Alg_Gen_0_base_ import generate_base_config
from Initial_Alg_Program.Initial_Alg_Gen_1_norm_cfg import generate_norm_cfg
from Initial_Alg_Program.Initial_Alg_Gen_2_data_preprocessor import generate_data_preprocessor
from Initial_Alg_Program.Initial_Alg_Gen_3_model import generate_model_pretrained, generate_model_backbone, generate_model_data_preprocessor, generate_model_decode_head, generate_model_auxiliary_head, generate_model_train_cfg, generate_model_test_cfg
from Initial_Alg_Program.Initial_Alg_Gen_4_optimizer import generate_optim_wrapper, generate_param_scheduler
if __name__ == '__main__':
    # Script entry point: builds an ICNet config file from command-line arguments
    # plus interactive selections, writes it under ./configs/icnet/, and records
    # the generated file name/path as JSON in _temp_.txt.

    ########### 1. Hyper-parameters: algorithm defaults ###########
    # Base model name: used for generate_base_config and to locate the base model
    # file under ./configs/_base_/models; the variable is reused further down for
    # the generated config's file name.
    alg_file_name = 'icnet_r50-d8'

    ########### 2. Argument parsing ###########
    parser = argparse.ArgumentParser(description="Run algorithm with specified parameters.")
    # Register the command-line arguments
    parser.add_argument('--alg_name', type=str, required=True, help="Algorithm name (e.g., 'ann')")
    parser.add_argument('--dataset_file_name', type=str, required=True, help="Dataset file name (e.g., 'my_dataset_model')")
    parser.add_argument('--mean', type=str, required=True, help="Mean of dataset (e.g., [1.2, 1.2, 1.2])")
    # Help text fixed: it previously said "Mean of dataset" for --std.
    parser.add_argument('--std', type=str, required=True, help="Std of dataset (e.g., [1.2, 1.2, 1.2])")
    parser.add_argument('--dataset_num_classes', type=int, required=True, help="Dataset class num (e.g., 36)")
    parser.add_argument('--GPU_num', type=int, required=True, help="Num of GPU to train (e.g., 1)")
    parser.add_argument('--schedule_file_name', type=str, required=True, help="Schedule file name (e.g., 'schedule_4k_check_400')")
    parser.add_argument('--train_type', type=str, required=True, help="Schedule train type (e.g., Epoch or Iteration)")
    parser.add_argument('--train_time_or_epoch_k', type=int, required=True, help="Schedule train time k (e.g., 4)")
    # Parse the command line
    args = parser.parse_args()
    # Unpack into local variables
    alg_name = args.alg_name
    dataset_file_name = args.dataset_file_name
    mean = json.loads(args.mean)  # JSON text -> list
    std = json.loads(args.std)  # JSON text -> list
    num_classes = args.dataset_num_classes
    schedule_file_name = args.schedule_file_name
    train_type = args.train_type
    train_time_or_epoch_k = args.train_time_or_epoch_k
    GPU_num = args.GPU_num

    ########### 3. Set parameters ###########
    # 3.1. User-selected parameters
    crop_size = select_crop_size(predefined_options=[(832, 832), (512, 512)])  # choose the crop size
    model_list = ['openmmlab/resnet18_v1c', 'openmmlab/resnet50_v1c', 'openmmlab/resnet101_v1c']
    selected_model_name, select_pretrained, pretrained_pth, selected_model_info = select_pretrained_model(model_list=model_list, need_select_pretrained=True)
    depth = selected_model_info['depth']
    test_cfg_mode, test_cfg_crop_div_stride = select_test_cfg_slide(crop_size)
    # init_cfg of the inner backbone: load the checkpoint only when a pretrained
    # model was requested. Both branches must bind the SAME name; the original
    # else-branch assigned `backbone_cfg_init_cfg`, so the next statement raised
    # NameError whenever no pretrained model was selected.
    if select_pretrained == True:
        backbone_backbone_cfg_init_cfg = dict(type='Pretrained', checkpoint=pretrained_pth)
    else:
        backbone_backbone_cfg_init_cfg = None
    backbone_backbone_cfg = create_dict_by_kwargs(depth=depth, init_cfg=backbone_backbone_cfg_init_cfg)
    # layer_channels passed to the backbone, chosen by ResNet depth
    if str(depth) == '18':
        backbone_layer_channels = (128, 512)
    elif str(depth) == '50' or str(depth) == '101':
        backbone_layer_channels = (512, 2048)
    else:
        # Without this branch an unexpected depth would surface later as a
        # NameError on backbone_layer_channels; fail here with an explicit message.
        raise ValueError(f"Unsupported backbone depth for ICNet: {depth!r}")

    # 3.2. Derived parameters
    # Loss settings for the decode and auxiliary heads
    decode_head_loss_decode_dict = dict(type='DiceLoss', use_sigmoid=False, loss_weight=1.0)  # Way 1: decode-head loss  # TODO
    # decode_head_loss_decode_dict=dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)  # Way Ori: decode-head loss  # TODO
    auxiliary_head_loss_decode_dict = dict(type='DiceLoss', use_sigmoid=False, loss_weight=0.4)  # Way 1: auxiliary-head loss  # TODO
    # auxiliary_head_loss_decode_dict=dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)  # Way Ori: auxiliary-head loss  # TODO
    # align_corners is enabled only when the slide test mode was selected
    if test_cfg_mode == "slide":
        align_corners = True
        test_slide = 'testslide'  # tag embedded in the generated file name
    else:
        align_corners = False
        test_slide = 'no_testslide'  # tag embedded in the generated file name

    ########### 4. Build each config section ###########
    _base_ = generate_base_config(alg_file_name=alg_file_name, dataset_file_name=dataset_file_name, schedule_file_name=schedule_file_name)
    norm_cfg = generate_norm_cfg(GPU_num=GPU_num)
    data_preprocessor = generate_data_preprocessor(crop_size=crop_size, mean=mean, std=std)
    # The model-level preprocessor is derived from the top-level one (same object is passed in, not a copy)
    model_data_preprocessor = data_preprocessor
    model_data_preprocessor = generate_model_data_preprocessor(model_data_preprocessor=model_data_preprocessor)
    backbone = create_dict_by_kwargs(backbone_cfg=backbone_backbone_cfg, layer_channels=backbone_layer_channels, norm_cfg=norm_cfg, align_corners=align_corners)
    # NOTE(review): `neck` is built here but never added to `model` or passed to
    # write_config_to_file below — confirm whether it should be emitted.
    neck = create_dict_by_kwargs(norm_cfg=norm_cfg, align_corners=align_corners)
    decode_head = create_dict_by_kwargs(num_classes=num_classes, loss_decode=decode_head_loss_decode_dict, align_corners=align_corners)
    # Start from the auxiliary heads declared in the base model file and override
    # the dataset-dependent fields on each entry.
    auxiliary_head = get_var_from_py_file(os.path.join('./configs/_base_/models', alg_file_name+'.py'), 'model')['auxiliary_head']
    for i in range(len(auxiliary_head)):
        auxiliary_head[i].update(num_classes=num_classes, loss_decode=auxiliary_head_loss_decode_dict, align_corners=align_corners)
    # Assemble the model section
    model = dict(
        backbone=backbone,
        data_preprocessor=model_data_preprocessor,
        decode_head=decode_head,
        auxiliary_head=auxiliary_head,
    )
    test_cfg = generate_model_test_cfg(test_cfg_mode=test_cfg_mode, crop_size=crop_size, test_cfg_crop_div_stride=test_cfg_crop_div_stride)
    optim_wrapper = generate_optim_wrapper()
    param_scheduler = generate_param_scheduler(train_type, train_time_or_epoch_k)

    ########### 5. Write the config ###########
    # Reference naming pattern: icnet [algorithm] _ r50 [pretrained backbone depth] - d8 [_base_/models]
    # _ 4 [GPU count] x b2 [batch size] - 40k [schedule] _ cityscapes [dataset] - 512x1024 [crop_size] .py
    #
    # Remove the literal 'schedule_' prefix only. str.lstrip('schedule_') treats its
    # argument as a set of characters and would also eat leading letters of the
    # remainder (e.g. 'schedule_epoch_100' -> 'poch_100').
    schedule_prefix = 'schedule_'
    if schedule_file_name.startswith(schedule_prefix):
        schedule_tag = schedule_file_name[len(schedule_prefix):]
    else:
        schedule_tag = schedule_file_name
    alg_file_name = f"{alg_name}_r{depth}_{'Pre' if select_pretrained else 'NoPre'}_g{GPU_num}-{schedule_tag}_{dataset_file_name}-{crop_size[0]}x{crop_size[1]}-{test_slide}.py"
    output_configs_alg_my_alg = os.path.join('./configs/icnet/', alg_file_name)
    write_config_to_file(output_configs_alg_my_alg, _base_=_base_, norm_cfg=norm_cfg, crop_size=crop_size, data_preprocessor=data_preprocessor, model=model, test_cfg=test_cfg, optim_wrapper=optim_wrapper, param_scheduler=param_scheduler)
    # Record the generated file name and path in _temp_.txt
    with open("_temp_.txt", "w") as file:
        json.dump({'alg_infos': {'alg_file_name': alg_file_name, 'alg_file_pth': output_configs_alg_my_alg}}, file)

View File

@@ -0,0 +1,106 @@
import os, sys, argparse, json
# 添加上级目录到 sys.path
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../')))
from Initial_Alg_Program.Initial_Alg_Select_Tool import select_crop_size, select_pretrained_model, select_test_cfg_slide
from Initial_Alg_Program.Initial_Alg_Gen_Tool import format_all_data, write_config_to_file
from Initial_Alg_Program.Initial_Alg_Gen_0_base_ import generate_base_config
from Initial_Alg_Program.Initial_Alg_Gen_1_norm_cfg import generate_norm_cfg
from Initial_Alg_Program.Initial_Alg_Gen_2_data_preprocessor import generate_data_preprocessor
from Initial_Alg_Program.Initial_Alg_Gen_3_model import generate_model_pretrained, generate_model_backbone, generate_model_data_preprocessor, generate_model_decode_head, generate_model_auxiliary_head, generate_model_train_cfg, generate_model_test_cfg
from Initial_Alg_Program.Initial_Alg_Gen_4_optimizer import generate_optim_wrapper, generate_param_scheduler
if __name__ == '__main__':
    # Script entry point: builds an ISANet config file from command-line arguments
    # plus interactive selections, writes it under ./configs/isanet/, and records
    # the generated file name/path as JSON in _temp_.txt.

    ########### 1. Hyper-parameters: algorithm defaults ###########
    # Base model name handed to generate_base_config; the variable is reused
    # further down for the generated config's file name.
    alg_file_name = 'isanet_r50-d8'

    ########### 2. Argument parsing ###########
    parser = argparse.ArgumentParser(description="Run algorithm with specified parameters.")
    # Register the command-line arguments
    parser.add_argument('--alg_name', type=str, required=True, help="Algorithm name (e.g., 'ann')")
    parser.add_argument('--dataset_file_name', type=str, required=True, help="Dataset file name (e.g., 'my_dataset_model')")
    parser.add_argument('--mean', type=str, required=True, help="Mean of dataset (e.g., [1.2, 1.2, 1.2])")
    # Help text fixed: it previously said "Mean of dataset" for --std.
    parser.add_argument('--std', type=str, required=True, help="Std of dataset (e.g., [1.2, 1.2, 1.2])")
    parser.add_argument('--dataset_num_classes', type=int, required=True, help="Dataset class num (e.g., 36)")
    parser.add_argument('--GPU_num', type=int, required=True, help="Num of GPU to train (e.g., 1)")
    parser.add_argument('--schedule_file_name', type=str, required=True, help="Schedule file name (e.g., 'schedule_4k_check_400')")
    parser.add_argument('--train_type', type=str, required=True, help="Schedule train type (e.g., Epoch or Iteration)")
    parser.add_argument('--train_time_or_epoch_k', type=int, required=True, help="Schedule train time k (e.g., 4)")
    # Parse the command line
    args = parser.parse_args()
    # Unpack into local variables
    alg_name = args.alg_name
    dataset_file_name = args.dataset_file_name
    mean = json.loads(args.mean)  # JSON text -> list
    std = json.loads(args.std)  # JSON text -> list
    num_classes = args.dataset_num_classes
    schedule_file_name = args.schedule_file_name
    train_type = args.train_type
    train_time_or_epoch_k = args.train_time_or_epoch_k
    GPU_num = args.GPU_num

    ########### 3. Set parameters ###########
    # 3.1. User-selected parameters
    crop_size = select_crop_size(predefined_options=[(512, 512), (512, 1024), (769, 769)])  # choose the crop size
    model_list = ['openmmlab/resnet50_v1c', 'openmmlab/resnet101_v1c']
    selected_model_name, select_pretrained, pretrained_pth, selected_model_info = select_pretrained_model(model_list=model_list)
    depth = selected_model_info['depth']
    test_cfg_mode, test_cfg_crop_div_stride = select_test_cfg_slide(crop_size)
    # test_cfg_mode = None
    # test_cfg_crop_div_stride = crop_size

    # 3.2. Derived parameters
    # Loss settings for the decode and auxiliary heads
    decode_head_loss_decode_dict = dict(type='DiceLoss', use_sigmoid=False, loss_weight=1.0)  # decode-head loss  # TODO
    # decode_head_loss_decode_dict=dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)  # decode-head loss (alternative)  # TODO
    auxiliary_head_loss_decode_dict = dict(type='DiceLoss', use_sigmoid=False, loss_weight=0.4)  # auxiliary-head loss  # TODO
    # auxiliary_head_loss_decode_dict=dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)  # auxiliary-head loss (alternative)  # TODO
    # align_corners is enabled only when the slide test mode was selected
    if test_cfg_mode == "slide":
        align_corners = True
        test_slide = 'testslide'  # tag embedded in the generated file name
    else:
        align_corners = False
        test_slide = 'no_testslide'  # tag embedded in the generated file name

    ########### 4. Build each config section ###########
    _base_ = generate_base_config(alg_file_name=alg_file_name, dataset_file_name=dataset_file_name, schedule_file_name=schedule_file_name)
    norm_cfg = generate_norm_cfg(GPU_num=GPU_num)
    data_preprocessor = generate_data_preprocessor(crop_size=crop_size, mean=mean, std=std)
    pretrained = generate_model_pretrained(pretrained_pth=pretrained_pth)
    backbone = generate_model_backbone(depth=depth)
    # The model-level preprocessor is derived from the top-level one (same object is passed in, not a copy)
    model_data_preprocessor = data_preprocessor
    model_data_preprocessor = generate_model_data_preprocessor(model_data_preprocessor=model_data_preprocessor)
    decode_head = generate_model_decode_head(num_classes=num_classes, decode_head_loss_decode_dict=decode_head_loss_decode_dict, align_corners=align_corners)
    auxiliary_head = generate_model_auxiliary_head(num_classes=num_classes, auxiliary_head_loss_decode_dict=auxiliary_head_loss_decode_dict, align_corners=align_corners)
    # Assemble the model section
    model = dict(
        pretrained=pretrained,
        backbone=backbone,
        data_preprocessor=model_data_preprocessor,
        decode_head=decode_head,
        auxiliary_head=auxiliary_head,
    )
    test_cfg = generate_model_test_cfg(test_cfg_mode=test_cfg_mode, crop_size=crop_size, test_cfg_crop_div_stride=test_cfg_crop_div_stride)
    optim_wrapper = generate_optim_wrapper()
    param_scheduler = generate_param_scheduler(train_type, train_time_or_epoch_k)

    ########### 5. Write the config ###########
    # Reference naming pattern: isanet [algorithm] _ r50 [pretrained backbone depth] - d8 [_base_/models]
    # _ 4 [GPU count] x b2 [batch size] - 40k [schedule] _ cityscapes [dataset] - 512x1024 [crop_size] .py
    #
    # Remove the literal 'schedule_' prefix only. str.lstrip('schedule_') treats its
    # argument as a set of characters and would also eat leading letters of the
    # remainder (e.g. 'schedule_epoch_100' -> 'poch_100').
    schedule_prefix = 'schedule_'
    if schedule_file_name.startswith(schedule_prefix):
        schedule_tag = schedule_file_name[len(schedule_prefix):]
    else:
        schedule_tag = schedule_file_name
    alg_file_name = f"{alg_name}_r{depth}_{'Pre' if select_pretrained else 'NoPre'}_g{GPU_num}-{schedule_tag}_{dataset_file_name}-{crop_size[0]}x{crop_size[1]}-{test_slide}.py"
    output_configs_alg_my_alg = os.path.join('./configs/isanet/', alg_file_name)
    write_config_to_file(output_configs_alg_my_alg, _base_=_base_, norm_cfg=norm_cfg, crop_size=crop_size, data_preprocessor=data_preprocessor, model=model, test_cfg=test_cfg, optim_wrapper=optim_wrapper, param_scheduler=param_scheduler)
    # Record the generated file name and path in _temp_.txt
    with open("_temp_.txt", "w") as file:
        json.dump({'alg_infos': {'alg_file_name': alg_file_name, 'alg_file_pth': output_configs_alg_my_alg}}, file)

View File

@@ -0,0 +1,276 @@
import os, sys, argparse, json
# 添加上级目录到 sys.path
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../')))
from Initial_Alg_Program.Initial_Alg_Select_Tool import select_crop_size, select_pretrained_model, select_test_cfg_slide
from Initial_Alg_Program.Initial_Alg_Gen_Tool import format_all_data, write_config_to_file, create_dict_by_kwargs, get_var_from_py_file
from Initial_Alg_Program.Initial_Alg_Gen_0_base_ import generate_base_config
from Initial_Alg_Program.Initial_Alg_Gen_1_norm_cfg import generate_norm_cfg
from Initial_Alg_Program.Initial_Alg_Gen_2_data_preprocessor import generate_data_preprocessor
from Initial_Alg_Program.Initial_Alg_Gen_3_model import generate_model_pretrained, generate_model_backbone, generate_model_data_preprocessor, generate_model_decode_head, generate_model_auxiliary_head, generate_model_train_cfg, generate_model_test_cfg
from Initial_Alg_Program.Initial_Alg_Gen_4_optimizer import generate_optim_wrapper, generate_param_scheduler
from Initial_Alg_Program.Initial_Alg_Gen_5_train_dataloader import generate_train_dataloader
if __name__ == '__main__':
########### 1.超参数-算法默认参数 ###########
alg_file_name = 'knet_r50-d8_my'
########### 2.参数解析 ###########
parser = argparse.ArgumentParser(description="Run algorithm with specified parameters.")
# 添加命令行参数
parser.add_argument('--alg_name', type=str, required=True, help="Algorithm name (e.g., 'ann')")
parser.add_argument('--dataset_file_name', type=str, required=True, help="Dataset file name (e.g., 'my_dataset_model')")
parser.add_argument('--mean', type=str, required=True, help="Mean of dataset (e.g., [1.2, 1.2, 1.2])")
parser.add_argument('--std', type=str, required=True, help="Mean of dataset (e.g., [1.2, 1.2, 1.2])")
parser.add_argument('--dataset_num_classes', type=int, required=True, help="Dataset class num (e.g., 36)")
parser.add_argument('--GPU_num', type=int, required=True, help="Num of GPU to train (e.g., 1)")
parser.add_argument('--schedule_file_name', type=str, required=True, help="Schedule file name (e.g., 'schedule_4k_check_400')")
parser.add_argument('--train_type', type=str, required=True, help="Schedule train type (e.g., Epoch or Iteration)")
parser.add_argument('--train_time_or_epoch_k', type=int, required=True, help="Schedule train time k (e.g., 4)")
# 解析命令行参数
args = parser.parse_args()
# 变量赋值
alg_name = args.alg_name
dataset_file_name = args.dataset_file_name
mean = json.loads(args.mean) # 使用json.loads转换成list
std = json.loads(args.std) # 使用json.loads转换成list
num_classes = args.dataset_num_classes
schedule_file_name = args.schedule_file_name
train_type = args.train_type
train_time_or_epoch_k = args.train_time_or_epoch_k
GPU_num = args.GPU_num
########### 3. 设定参数 ###########
_base_ = generate_base_config(alg_file_name=alg_file_name, dataset_file_name=dataset_file_name, schedule_file_name=schedule_file_name) # _base_无算法
crop_size = select_crop_size(predefined_options=[(512, 512), (640,640)]) # 选择切割大小
data_preprocessor = generate_data_preprocessor(crop_size = crop_size, mean = mean, std = std)
model_data_preprocessor = data_preprocessor # 修改data_preprocessor
model_data_preprocessor = generate_model_data_preprocessor(model_data_preprocessor=model_data_preprocessor)
norm_cfg = generate_norm_cfg(GPU_num=GPU_num)
# 选择对应的模型
config_dict = {
'DeepLabV3': {
'decode_head_kernel_generate_head_dict': dict(
_delete_ = True,
type='ASPPHead',
in_channels=2048,
in_index=3,
channels=512,
dilations=(1, 12, 24, 36),
dropout_ratio=0.1,
num_classes=150,
norm_cfg=norm_cfg,
align_corners=False),
'backbone_dilations': (1, 1, 2, 4),
'backbone_strides': (1, 2, 1, 1),
'auxiliary_head_in_channels':1024
},
'PSPNet': {
'decode_head_kernel_generate_head_dict': dict(
_delete_ = True,
type='PSPHead',
in_channels=2048,
in_index=3,
channels=512,
pool_scales=(1, 2, 3, 6),
dropout_ratio=0.1,
num_classes=150,
norm_cfg=norm_cfg,
align_corners=False),
'backbone_dilations': (1, 1, 2, 4),
'backbone_strides': (1, 2, 1, 1),
'auxiliary_head_in_channels':1024
},
'FCN': {
'decode_head_kernel_generate_head_dict': dict(
_delete_ = True,
type='FCNHead',
in_channels=2048,
in_index=3,
channels=512,
num_convs=2,
concat_input=True,
dropout_ratio=0.1,
num_classes=150,
norm_cfg=norm_cfg,
align_corners=False),
'backbone_dilations': (1, 1, 2, 4),
'backbone_strides': (1, 2, 1, 1),
'auxiliary_head_in_channels':1024
},
'UPerNet': {
'decode_head_kernel_generate_head_dict': dict(
type='UPerHead',
in_channels=[256, 512, 1024, 2048],
in_index=[0, 1, 2, 3],
pool_scales=(1, 2, 3, 6),
channels=512,
dropout_ratio=0.1,
num_classes=150,
norm_cfg=norm_cfg,
align_corners=False),
'backbone_dilations': (1, 1, 1, 1),
'backbone_strides': (1, 2, 2, 2),
'auxiliary_head_in_channels':1024
}
}
models = ['DeepLabV3', 'PSPNet', 'FCN', 'UPerNet']
print("请你选择对应的模型:")
while True:
for index, model in enumerate(models, 1):
print(f"{index}. {model}")
try:
user_input = int(input("Enter your choice (1-4): "))
if 1 <= user_input <= 4:
model_choice = models[user_input - 1]
break
else:
print("Invalid input, please enter a number between 1 and 4.")
except ValueError:
print("Invalid input, please enter a valid number.")
decode_head_kernel_generate_head_dict = config_dict[model_choice]['decode_head_kernel_generate_head_dict']
backbone_dilations = config_dict[model_choice]['backbone_dilations']
backbone_strides = config_dict[model_choice]['backbone_strides']
auxiliary_head_in_channels = config_dict[model_choice]['auxiliary_head_in_channels']
model_list = ['openmmlab/resnet50_v1c', 'openmmlab/resnet101_v1c', "pretrain/swin_tiny-f41b89d3.pth", "pretrain/swin_large-d5bdebaf.pth"]
selected_model_name, select_pretrained, pretrained_pth, selected_model_info = select_pretrained_model(model_list=model_list)
if selected_model_info['type'] == 'ResNetV1c':
depth = selected_model_info['depth']
backbone = create_dict_by_kwargs(depth = depth, norm_cfg=norm_cfg, dilations=backbone_dilations, strides=backbone_strides)
elif selected_model_info['type'] == 'swin':
if selected_model_name == 'pretrain/swin_tiny-f41b89d3.pth':
backbone = dict(
_delete_=True,
type='SwinTransformer',
embed_dims=96, #
depths=[2, 2, 6, 2], #
num_heads=[3, 6, 12, 24], #
window_size=7,
mlp_ratio=4,
qkv_bias=True,
qk_scale=None,
drop_rate=0.,
attn_drop_rate=0.,
drop_path_rate=0.3, #
use_abs_pos_embed=False,
patch_norm=True,
out_indices=(0, 1, 2, 3))
if selected_model_name == 'pretrain/swin_large-d5bdebaf.pth':
backbone = dict(
_delete_=True,
type='SwinTransformer',
embed_dims=192, #
depths=[2, 2, 18, 2], #
num_heads=[6, 12, 24, 48], #
window_size=7,
mlp_ratio=4,
qkv_bias=True,
qk_scale=None,
drop_rate=0.,
attn_drop_rate=0.,
drop_path_rate=0.4, #
use_abs_pos_embed=False,
patch_norm=True,
out_indices=(0, 1, 2, 3))
backbone.update(create_dict_by_kwargs(norm_cfg=dict(type='LN'))) # TODO Transformer一类的内容norm_cfg都设置为LN TODO # 没有 dilations 和 strides
test_cfg_mode, test_cfg_crop_div_stride = select_test_cfg_slide(crop_size)
# 3.2. 自动生成参数
# 如果使用slide模式则开启align_corners
if test_cfg_mode == "slide":
align_corners = True
test_slide = 'testslide' # 算法名进行标记
else:
align_corners = False
test_slide = 'no_testslide' # 算法名进行标记
########### 3.生成各个部分 ###########
# 3.2. decode_head
decode_head = get_var_from_py_file(os.path.join('./configs/_base_/models', alg_file_name+'.py'), 'model')['decode_head']
decode_head_kernel_update_head_list = decode_head['kernel_update_head']
for i in range(len(decode_head_kernel_update_head_list)):
decode_head_kernel_update_head_list[i]['num_classes'] = num_classes
decode_head['kernel_update_head'] = decode_head_kernel_update_head_list
decode_head_kernel_generate_head_dict['num_classes'] = num_classes
if model_choice == 'UPerNet':
if selected_model_name == 'pretrain/swin_tiny-f41b89d3.pth':
decode_head_kernel_generate_head_dict['in_channels'] = [96, 192, 384, 768]
auxiliary_head_in_channels = 384
elif selected_model_name == 'pretrain/swin_large-d5bdebaf.pth':
decode_head_kernel_generate_head_dict['in_channels'] = [192, 384, 768, 1536]
auxiliary_head_in_channels = 768
elif model_choice in ['DeepLabV3', 'PSPNet', 'FCN']:
if selected_model_name == 'pretrain/swin_tiny-f41b89d3.pth':
decode_head_kernel_generate_head_dict['in_channels'] = 768
auxiliary_head_in_channels = 384
elif selected_model_name == 'pretrain/swin_large-d5bdebaf.pth':
decode_head_kernel_generate_head_dict['in_channels'] = 1536
auxiliary_head_in_channels = 768
decode_head_kernel_generate_head_dict_loss_decode = dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0) # Way Ori
decode_head_kernel_generate_head_dict_loss_decode = dict(type='DiceLoss', use_sigmoid=False, loss_weight=1.0) # Way 1
decode_head_kernel_generate_head_dict['loss_decode'] = decode_head_kernel_generate_head_dict_loss_decode
decode_head['kernel_generate_head'] = decode_head_kernel_generate_head_dict
decode_head['kernel_generate_head']['align_corners'] = align_corners
# 3.2. auxiliary_head
auxiliary_head = get_var_from_py_file(os.path.join('./configs/_base_/models', alg_file_name+'.py'), 'model')['auxiliary_head']
auxiliary_head_loss_decode = dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4) # Way Ori
auxiliary_head_loss_decode = dict(type='DiceLoss', use_sigmoid=False, loss_weight=0.4) # Way 1
auxiliary_head['loss_decode'] = auxiliary_head_loss_decode
auxiliary_head['align_corners'] = align_corners
auxiliary_head['in_channels'] = auxiliary_head_in_channels
# 综合model
model = dict(
backbone = backbone,
pretrained = pretrained_pth,
data_preprocessor = model_data_preprocessor,
decode_head = decode_head,
auxiliary_head = auxiliary_head,
)
# 4. train_dataloader
train_dataloader, batch_size, num_workers = generate_train_dataloader(batch_size_default=2, num_workers_default=2)
if selected_model_info['type'] == 'ResNetV1c':
optim_wrapper = generate_optim_wrapper()
elif selected_model_info['type'] == 'swin':
optim_wrapper = generate_optim_wrapper('swin')
param_scheduler = generate_param_scheduler(train_type, train_time_or_epoch_k)
########### 4.程序写入 ###########
# 算法名称解析knet【Alg】 _ r50【pretrained模型深度】 - d8【_base_/models】 _ 4【GPU数量】 x b2【Batch Size大小】 - 40k【schedule】 _ cityscapes【数据集】 - 512x1024【crop_size】.py
if selected_model_info['type'] == 'ResNetV1c':
alg_file_name = f"{alg_name}_r{depth}_{model_choice}_{'Pre' if select_pretrained else 'NoPre'}_g{GPU_num}-{schedule_file_name.lstrip('schedule_')}_{dataset_file_name}-{crop_size[0]}x{crop_size[1]}-{test_slide}.py"
elif selected_model_info['type'] == 'swin':
alg_file_name = f"{alg_name}_swin_{selected_model_info['size']}_{model_choice}_{'Pre' if select_pretrained else 'NoPre'}_g{GPU_num}-{schedule_file_name.lstrip('schedule_')}_{dataset_file_name}-{crop_size[0]}x{crop_size[1]}-{test_slide}.py"
output_configs_alg_my_alg = os.path.join(f'./configs/knet/', alg_file_name)
write_config_to_file(output_configs_alg_my_alg, _base_=_base_, norm_cfg=norm_cfg, crop_size=crop_size, data_preprocessor=data_preprocessor, model=model, train_dataloader=train_dataloader, optim_wrapper=optim_wrapper, param_scheduler=param_scheduler)
# 将信息写入文件
with open("_temp_.txt", "w") as file:
json.dump({'alg_infos':{'alg_file_name':alg_file_name,'alg_file_pth':output_configs_alg_my_alg}}, file)

View File

@@ -0,0 +1,136 @@
import os, sys, argparse, json
# 添加上级目录到 sys.path
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../')))
from Initial_Alg_Program.Initial_Alg_Select_Tool import select_crop_size, select_pretrained_model, select_test_cfg_slide
from Initial_Alg_Program.Initial_Alg_Gen_Tool import format_all_data, write_config_to_file, create_dict_by_kwargs, get_var_from_file, update_list_dict_var, get_var_from_py_file
from Initial_Alg_Program.Initial_Alg_Gen_0_base_ import generate_base_config
from Initial_Alg_Program.Initial_Alg_Gen_1_norm_cfg import generate_norm_cfg
from Initial_Alg_Program.Initial_Alg_Gen_2_data_preprocessor import generate_data_preprocessor
from Initial_Alg_Program.Initial_Alg_Gen_3_model import generate_model_pretrained, generate_model_backbone, generate_model_data_preprocessor, generate_model_decode_head, generate_model_auxiliary_head, generate_model_train_cfg, generate_model_test_cfg
from Initial_Alg_Program.Initial_Alg_Gen_4_optimizer import generate_optim_wrapper, generate_param_scheduler
from Initial_Alg_Program.Initial_Alg_Gen_5_train_dataloader import generate_train_dataloader
# 交互式选择 decode_head 的函数
def select_decode_head(decode_head_choose):
    """Interactively pick one entry of ``decode_head_choose`` by number.

    The option names (the dict keys, in iteration order) are printed once
    with 1-based numbers; the user is then prompted until a number inside
    the valid range is entered.

    Args:
        decode_head_choose (dict): maps an option name to the value that
            should be returned when that option is selected.

    Returns:
        The value stored under the selected key.

    Raises:
        ValueError: if ``decode_head_choose`` is empty. Without this guard
            no number could ever be accepted and the prompt would loop
            forever.
    """
    # Snapshot the keys once so numbering and lookup always agree.
    option_names = list(decode_head_choose.keys())
    if not option_names:
        raise ValueError("decode_head_choose must contain at least one option")

    print("可用的 decode head 选项:")
    for number, name in enumerate(option_names, 1):
        print(f"{number}. {name}")

    while True:
        try:
            # Ask the user for the number of the wanted option.
            choice = int(input("请选择需要的 decode head输入编号"))
        except ValueError:
            # Not an integer at all: ask again.
            print("输入无效,请输入有效的编号。")
            continue
        if 1 <= choice <= len(option_names):
            selected_key = option_names[choice - 1]
            print(f"你选择了: {selected_key}")
            return decode_head_choose[selected_key]
        # An integer, but outside 1..len(options): ask again.
        print("输入的编号不正确,请重新输入。")
if __name__ == '__main__':
    # Script entry point: builds a training config on top of the
    # 'upernet_mae' base model, writes it under ./configs/mae/ and records
    # the generated file name/path as JSON in _temp_.txt.
    ########### 1. Hyper-parameters: algorithm defaults ###########
    alg_file_name = 'upernet_mae'
    alg_file_pth = f"configs/_base_/models/{alg_file_name}.py"
    ########### 2. Argument parsing ###########
    parser = argparse.ArgumentParser(description="Run algorithm with specified parameters.")
    # Command-line arguments (all required)
    parser.add_argument('--alg_name', type=str, required=True, help="Algorithm name (e.g., 'ann')")
    parser.add_argument('--dataset_file_name', type=str, required=True, help="Dataset file name (e.g., 'my_dataset_model')")
    parser.add_argument('--mean', type=str, required=True, help="Mean of dataset (e.g., [1.2, 1.2, 1.2])")
    # The help text used to say "Mean" for --std as well.
    parser.add_argument('--std', type=str, required=True, help="Std of dataset (e.g., [1.2, 1.2, 1.2])")
    parser.add_argument('--dataset_num_classes', type=int, required=True, help="Dataset class num (e.g., 36)")
    parser.add_argument('--GPU_num', type=int, required=True, help="Num of GPU to train (e.g., 1)")
    parser.add_argument('--schedule_file_name', type=str, required=True, help="Schedule file name (e.g., 'schedule_4k_check_400')")
    parser.add_argument('--train_type', type=str, required=True, help="Schedule train type (e.g., Epoch or Iteration)")
    parser.add_argument('--train_time_or_epoch_k', type=int, required=True, help="Schedule train time k (e.g., 4)")
    args = parser.parse_args()
    # Unpack the parsed arguments
    alg_name = args.alg_name
    dataset_file_name = args.dataset_file_name
    mean = json.loads(args.mean)  # JSON text -> list
    std = json.loads(args.std)  # JSON text -> list
    num_classes = args.dataset_num_classes
    schedule_file_name = args.schedule_file_name
    train_type = args.train_type
    train_time_or_epoch_k = args.train_time_or_epoch_k
    GPU_num = args.GPU_num
    ########### 3.1. _base_ and norm_cfg ###########
    _base_ = generate_base_config(alg_file_name=alg_file_name, dataset_file_name=dataset_file_name, schedule_file_name=schedule_file_name)
    norm_cfg = generate_norm_cfg(GPU_num=GPU_num)
    ########### 3.2. data_preprocessor ###########
    crop_size = select_crop_size(predefined_options=[(512, 512)])  # choose the crop size
    data_preprocessor = generate_data_preprocessor(crop_size=crop_size, mean=mean, std=std)
    ########### 3.3. model ###########
    model_data_preprocessor = data_preprocessor
    # Choose the backbone checkpoint and whether to use pretrained weights
    model_list = ["mae_pretrain_vit_base_mmcls.pth"]
    selected_model_name, select_pretrained, pretrained_pth, selected_model_info = select_pretrained_model(model_list=model_list)
    backbone = create_dict_by_kwargs(type='MAE', img_size=crop_size, init_values=1.0)
    # Loss selection. "Way 1" (DiceLoss) is the active choice; the original
    # CrossEntropyLoss variant ("Way Ori") is kept for easy switching:
    #   dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)
    #   dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)
    decode_head_loss_decode = dict(type='DiceLoss', use_sigmoid=False, loss_weight=1.0)  # Way 1
    auxiliary_head_loss_decode = dict(type='DiceLoss', use_sigmoid=False, loss_weight=0.4)  # Way 1
    neck = dict(embed_dim=768, rescales=[4, 2, 1, 0.5])
    decode_head = dict(in_channels=[768, 768, 768, 768], channels=768, num_classes=num_classes, norm_cfg=norm_cfg, loss_decode=decode_head_loss_decode)
    auxiliary_head = dict(in_channels=768, norm_cfg=norm_cfg, num_classes=num_classes, loss_decode=auxiliary_head_loss_decode)
    # Only the stride divisor returned here is used: the test mode passed
    # below is always 'slide', matching the "-testslide" file-name suffix.
    test_cfg_mode, test_cfg_crop_div_stride = select_test_cfg_slide(crop_size, select_slide=True)
    test_cfg = generate_model_test_cfg(test_cfg_mode='slide', crop_size=crop_size, test_cfg_crop_div_stride=test_cfg_crop_div_stride)
    ########### 3.4. optimizer ###########
    optim_wrapper = dict(
        _delete_=True,
        type='AmpOptimWrapper',  # original: type='OptimWrapper'
        optimizer=dict(
            type='AdamW', lr=3e-5, betas=(0.9, 0.999), weight_decay=0.05),
        constructor='LayerDecayOptimizerConstructor',
        paramwise_cfg=dict(num_layers=12, layer_decay_rate=0.9))
    model_size = 'base'
    # The sibling generator scripts unpack three values
    # (train_dataloader, batch_size, num_workers) from this helper, so a
    # strict two-name unpack would raise ValueError. The starred target
    # accepts both a 2-tuple and a 3-tuple return.
    train_dataloader, batch_size, *_ = generate_train_dataloader(4)
    model = dict(
        backbone=backbone,
        data_preprocessor=model_data_preprocessor,
        pretrained=pretrained_pth,
        decode_head=decode_head,
        neck=neck,
        auxiliary_head=auxiliary_head,
        test_cfg=test_cfg
    )
    ########### 3.5. fp16 and param_scheduler ###########
    fp16 = dict(loss_scale='dynamic')  # mixed precision
    param_scheduler = generate_param_scheduler(train_type, train_time_or_epoch_k)
    ########### 4. Write the config ###########
    # File-name layout:
    #   <alg>-<model size>_b<batch size>_g<GPU count>-<schedule>_<dataset>-<crop HxW>-testslide.py
    # str.lstrip('schedule_') removes any leading run of the characters
    # s/c/h/e/d/u/l/_ (so 'schedule_epoch_4' became 'poch_4'); strip the
    # literal prefix instead.
    schedule_prefix = 'schedule_'
    if schedule_file_name.startswith(schedule_prefix):
        schedule_tag = schedule_file_name[len(schedule_prefix):]
    else:
        schedule_tag = schedule_file_name
    alg_file_name = f"{alg_name}-{model_size}_b{batch_size}_g{GPU_num}-{schedule_tag}_{dataset_file_name}-{crop_size[0]}x{crop_size[1]}-testslide.py"
    output_configs_alg_my_alg = os.path.join('./configs/mae/', alg_file_name)
    write_config_to_file(output_configs_alg_my_alg, _base_=_base_, norm_cfg=norm_cfg, crop_size=crop_size, data_preprocessor=data_preprocessor, model=model, fp16=fp16, optim_wrapper=optim_wrapper, param_scheduler=param_scheduler, train_dataloader=train_dataloader)
    # Record the generated file name and path in _temp_.txt as JSON
    with open("_temp_.txt", "w") as file:
        json.dump({'alg_infos': {'alg_file_name': alg_file_name, 'alg_file_pth': output_configs_alg_my_alg}}, file)

View File

@@ -0,0 +1,284 @@
import os, sys, argparse, json
# 添加上级目录到 sys.path
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../')))
from Initial_Alg_Program.Initial_Alg_Select_Tool import select_crop_size, select_pretrained_model, select_test_cfg_slide
from Initial_Alg_Program.Initial_Alg_Gen_Tool import format_all_data, write_config_to_file, create_dict_by_kwargs, get_var_from_py_file
from Initial_Alg_Program.Initial_Alg_Gen_0_base_ import generate_base_config
from Initial_Alg_Program.Initial_Alg_Gen_1_norm_cfg import generate_norm_cfg
from Initial_Alg_Program.Initial_Alg_Gen_2_data_preprocessor import generate_data_preprocessor
from Initial_Alg_Program.Initial_Alg_Gen_3_model import generate_model_pretrained, generate_model_backbone, generate_model_data_preprocessor, generate_model_decode_head, generate_model_auxiliary_head, generate_model_train_cfg, generate_model_test_cfg
from Initial_Alg_Program.Initial_Alg_Gen_4_optimizer import generate_optim_wrapper, generate_param_scheduler
from Initial_Alg_Program.Initial_Alg_Gen_5_train_dataloader import generate_train_dataloader
if __name__ == '__main__':
    # Script entry point: builds a training config on top of the
    # 'mask2former_my' base model for a ResNet or Swin backbone, writes it
    # under ./configs/mask2former/ and records the generated file name/path
    # as JSON in _temp_.txt.
    ########### 1. Hyper-parameters: algorithm defaults ###########
    alg_file_name = 'mask2former_my'
    ########### 2. Argument parsing ###########
    parser = argparse.ArgumentParser(description="Run algorithm with specified parameters.")
    # Command-line arguments (all required)
    parser.add_argument('--alg_name', type=str, required=True, help="Algorithm name (e.g., 'ann')")
    parser.add_argument('--dataset_file_name', type=str, required=True, help="Dataset file name (e.g., 'my_dataset_model')")
    parser.add_argument('--mean', type=str, required=True, help="Mean of dataset (e.g., [1.2, 1.2, 1.2])")
    # The help text used to say "Mean" for --std as well.
    parser.add_argument('--std', type=str, required=True, help="Std of dataset (e.g., [1.2, 1.2, 1.2])")
    parser.add_argument('--dataset_num_classes', type=int, required=True, help="Dataset class num (e.g., 36)")
    parser.add_argument('--GPU_num', type=int, required=True, help="Num of GPU to train (e.g., 1)")
    parser.add_argument('--schedule_file_name', type=str, required=True, help="Schedule file name (e.g., 'schedule_4k_check_400')")
    parser.add_argument('--train_type', type=str, required=True, help="Schedule train type (e.g., Epoch or Iteration)")
    parser.add_argument('--train_time_or_epoch_k', type=int, required=True, help="Schedule train time k (e.g., 4)")
    args = parser.parse_args()
    # Unpack the parsed arguments
    alg_name = args.alg_name
    dataset_file_name = args.dataset_file_name
    mean = json.loads(args.mean)  # JSON text -> list
    std = json.loads(args.std)  # JSON text -> list
    num_classes = args.dataset_num_classes
    schedule_file_name = args.schedule_file_name
    train_type = args.train_type
    train_time_or_epoch_k = args.train_time_or_epoch_k
    GPU_num = args.GPU_num
    ########### 3. Settings ###########
    _base_ = generate_base_config(alg_file_name=alg_file_name, dataset_file_name=dataset_file_name, schedule_file_name=schedule_file_name)
    crop_size = select_crop_size(predefined_options=[(512, 512), (640, 640), (512, 1024)])  # choose the crop size
    data_preprocessor = generate_data_preprocessor(crop_size=crop_size, mean=mean, std=std)
    model_data_preprocessor = data_preprocessor
    model_data_preprocessor = generate_model_data_preprocessor(model_data_preprocessor=model_data_preprocessor)
    norm_cfg = generate_norm_cfg(GPU_num=GPU_num)
    # 3.1. Backbone selection
    model_list = ['torchvision://resnet50', 'torchvision://resnet101', 'pretrain/swin_large-6580f57d.pth', 'pretrain/swin_base-e5c09f74.pth', 'pretrain/swin_small-7ba6d6dd.pth', 'pretrain/swin_tiny-1cdeb081.pth']
    selected_model_name, select_pretrained, pretrained_pth, selected_model_info = select_pretrained_model(model_list=model_list)
    backbone_type = selected_model_info['type']
    if backbone_type == 'ResNet':
        depth = selected_model_info['depth']
        backbone = create_dict_by_kwargs(depth=depth, type=selected_model_info['type'], norm_cfg=norm_cfg, num_stages=4, frozen_stages=-1, style='pytorch', init_cfg=dict(type='Pretrained', checkpoint=pretrained_pth))
        in_channels = [256, 512, 1024, 2048]
    elif backbone_type == 'swin':
        # The four Swin variants differ only in the values listed here;
        # everything else in the backbone dict is shared.
        swin_variants = {
            'pretrain/swin_tiny-1cdeb081.pth': dict(
                pretrain_img_size=224, embed_dims=96, depths=[2, 2, 6, 2],
                num_heads=[3, 6, 12, 24], window_size=7,
                in_channels=[96, 192, 384, 768]),
            'pretrain/swin_small-7ba6d6dd.pth': dict(
                pretrain_img_size=224, embed_dims=96, depths=[2, 2, 18, 2],
                num_heads=[3, 6, 12, 24], window_size=7,
                in_channels=[96, 192, 384, 768]),
            'pretrain/swin_base-e5c09f74.pth': dict(
                pretrain_img_size=384, embed_dims=128, depths=[2, 2, 18, 2],
                num_heads=[4, 8, 16, 32], window_size=12,
                in_channels=[128, 256, 512, 1024]),
            'pretrain/swin_large-6580f57d.pth': dict(
                pretrain_img_size=384, embed_dims=192, depths=[2, 2, 18, 2],
                num_heads=[6, 12, 24, 48], window_size=12,
                in_channels=[192, 384, 768, 1536]),
        }
        variant = swin_variants[selected_model_name]
        depths = variant['depths']
        in_channels = variant['in_channels']
        backbone = dict(
            # Previously only the tiny/small variants carried _delete_=True;
            # base/large now get it too so all four are handled alike.
            # NOTE(review): assumes the base config defines model.backbone
            # whose keys must be replaced rather than merged - confirm.
            _delete_=True,
            type='SwinTransformer',
            pretrain_img_size=variant['pretrain_img_size'],
            embed_dims=variant['embed_dims'],
            depths=depths,
            num_heads=variant['num_heads'],
            window_size=variant['window_size'],
            mlp_ratio=4,
            qkv_bias=True,
            qk_scale=None,
            drop_rate=0.,
            attn_drop_rate=0.,
            drop_path_rate=0.3,
            patch_norm=True,
            # out_indices=(0, 1, 2, 3),
            with_cp=False,
            # frozen_stages=-1,
            init_cfg=dict(type='Pretrained', checkpoint=pretrained_pth))
    else:
        # Fail with a clear message instead of a NameError further down.
        raise ValueError(f"Unsupported backbone type: {backbone_type!r}")
    # 3.2. decode_head
    # class_weight gets one extra trailing 0.1 entry after the per-class 1.0s.
    decode_head = dict(num_classes=num_classes, in_channels=in_channels,
                       loss_cls=dict(class_weight=[1.0] * num_classes + [0.1]))
    # 3.3. Optimizer
    # (A second copy of the Swin custom_keys construction used to sit before
    # this section. It read `depths`, which is only defined for Swin
    # backbones, and its result was always rebuilt below, so it was removed.)
    embed_multi = dict(lr_mult=1.0, decay_mult=0.0)
    optimizer = dict(
        type='AdamW', lr=0.0001, weight_decay=0.05, eps=1e-8, betas=(0.9, 0.999))
    if backbone_type == 'ResNet':
        optim_wrapper = dict(
            _delete_=True,
            type='OptimWrapper',
            optimizer=optimizer,
            clip_grad=dict(max_norm=0.01, norm_type=2),
            paramwise_cfg=dict(
                custom_keys={
                    'backbone': dict(lr_mult=0.1, decay_mult=1.0),
                    'query_embed': embed_multi,
                    'query_feat': embed_multi,
                    'level_embed': embed_multi,
                },
                norm_decay_mult=0.0))
        train_dataloader, batch_size, num_workers = generate_train_dataloader(batch_size_default=4, num_workers_default=4)
    elif backbone_type == 'swin':
        # All backbone layers use lr_mult=0.1; norm layers and the
        # position/query/level embeddings use decay_mult=0.0.
        backbone_norm_multi = dict(lr_mult=0.1, decay_mult=0.0)
        backbone_embed_multi = dict(lr_mult=0.1, decay_mult=0.0)
        custom_keys = {
            'backbone': dict(lr_mult=0.1, decay_mult=1.0),
            'backbone.patch_embed.norm': backbone_norm_multi,
            'backbone.norm': backbone_norm_multi,
            'absolute_pos_embed': backbone_embed_multi,
            'relative_position_bias_table': backbone_embed_multi,
            'query_embed': embed_multi,
            'query_feat': embed_multi,
            'level_embed': embed_multi
        }
        # One entry per block norm in every stage ...
        custom_keys.update({
            f'backbone.stages.{stage_id}.blocks.{block_id}.norm': backbone_norm_multi
            for stage_id, num_blocks in enumerate(depths)
            for block_id in range(num_blocks)
        })
        # ... and one per downsample norm for every stage except the last.
        custom_keys.update({
            f'backbone.stages.{stage_id}.downsample.norm': backbone_norm_multi
            for stage_id in range(len(depths) - 1)
        })
        optim_wrapper = dict(
            _delete_=True,
            type='OptimWrapper',
            optimizer=optimizer,
            clip_grad=dict(max_norm=0.01, norm_type=2),
            paramwise_cfg={'custom_keys': custom_keys, 'norm_decay_mult': 0.0})
        train_dataloader, batch_size, num_workers = generate_train_dataloader(batch_size_default=2, num_workers_default=2)
    # 3.4. train_dataloader pipeline
    longest_side = max(crop_size[0], crop_size[1])
    train_pipeline = [
        dict(type='LoadImageFromFile'),
        dict(type='LoadAnnotations'),
        dict(
            type='RandomChoiceResize',
            # 16 candidate scales: 0.5x .. 2.0x of the longer crop side
            scales=[int(longest_side * x * 0.1) for x in range(5, 21)],
            resize_type='ResizeShortestEdge',
            max_size=longest_side * 4),
        dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
        dict(type='RandomFlip', prob=0.5),
        dict(type='PhotoMetricDistortion'),
        dict(type='PackSegInputs')
    ]
    train_dataloader.update(dict(dataset=dict(pipeline=train_pipeline)))
    # 3.5. param_scheduler
    param_scheduler = generate_param_scheduler(train_type, train_time_or_epoch_k)
    # Assemble the model section
    model = dict(
        backbone=backbone,
        data_preprocessor=model_data_preprocessor,
        decode_head=decode_head
    )
    ########### 4. Write the config ###########
    # File-name layout:
    #   <alg>_<backbone>_<Pre|NoPre>_g<GPU count>-<schedule>_<dataset>-<crop HxW>.py
    # str.lstrip('schedule_') removes any leading run of the characters
    # s/c/h/e/d/u/l/_ (so 'schedule_epoch_4' became 'poch_4'); strip the
    # literal prefix instead.
    schedule_prefix = 'schedule_'
    if schedule_file_name.startswith(schedule_prefix):
        schedule_tag = schedule_file_name[len(schedule_prefix):]
    else:
        schedule_tag = schedule_file_name
    pretrained_tag = 'Pre' if select_pretrained else 'NoPre'
    if backbone_type == 'ResNet':
        alg_file_name = f"{alg_name}_r{depth}_{pretrained_tag}_g{GPU_num}-{schedule_tag}_{dataset_file_name}-{crop_size[0]}x{crop_size[1]}.py"
    elif backbone_type == 'swin':
        alg_file_name = f"{alg_name}_swin_{selected_model_info['size']}_{pretrained_tag}_g{GPU_num}-{schedule_tag}_{dataset_file_name}-{crop_size[0]}x{crop_size[1]}.py"
    output_configs_alg_my_alg = os.path.join('./configs/mask2former/', alg_file_name)
    write_config_to_file(output_configs_alg_my_alg, _base_=_base_, norm_cfg=norm_cfg, crop_size=crop_size, data_preprocessor=data_preprocessor, model=model, train_dataloader=train_dataloader, optim_wrapper=optim_wrapper, param_scheduler=param_scheduler)
    # Record the generated file name and path in _temp_.txt as JSON
    with open("_temp_.txt", "w") as file:
        json.dump({'alg_infos': {'alg_file_name': alg_file_name, 'alg_file_pth': output_configs_alg_my_alg}}, file)

View File

@@ -0,0 +1,246 @@
import os, sys, argparse, json
# 添加上级目录到 sys.path
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../')))
from Initial_Alg_Program.Initial_Alg_Select_Tool import select_crop_size, select_pretrained_model, select_test_cfg_slide
from Initial_Alg_Program.Initial_Alg_Gen_Tool import format_all_data, write_config_to_file, create_dict_by_kwargs, get_var_from_py_file
from Initial_Alg_Program.Initial_Alg_Gen_0_base_ import generate_base_config
from Initial_Alg_Program.Initial_Alg_Gen_1_norm_cfg import generate_norm_cfg
from Initial_Alg_Program.Initial_Alg_Gen_2_data_preprocessor import generate_data_preprocessor
from Initial_Alg_Program.Initial_Alg_Gen_3_model import generate_model_pretrained, generate_model_backbone, generate_model_data_preprocessor, generate_model_decode_head, generate_model_auxiliary_head, generate_model_train_cfg, generate_model_test_cfg
from Initial_Alg_Program.Initial_Alg_Gen_4_optimizer import generate_optim_wrapper, generate_param_scheduler
from Initial_Alg_Program.Initial_Alg_Gen_5_train_dataloader import generate_train_dataloader
if __name__ == '__main__':
    # Script entry point: builds a training config on top of the
    # 'maskformer_my' base model for a ResNet or Swin backbone, writes it
    # under ./configs/maskformer/ and records the generated file name/path
    # as JSON in _temp_.txt.
    ########### 1. Hyper-parameters: algorithm defaults ###########
    alg_file_name = 'maskformer_my'
    ########### 2. Argument parsing ###########
    parser = argparse.ArgumentParser(description="Run algorithm with specified parameters.")
    # Command-line arguments (all required)
    parser.add_argument('--alg_name', type=str, required=True, help="Algorithm name (e.g., 'ann')")
    parser.add_argument('--dataset_file_name', type=str, required=True, help="Dataset file name (e.g., 'my_dataset_model')")
    parser.add_argument('--mean', type=str, required=True, help="Mean of dataset (e.g., [1.2, 1.2, 1.2])")
    # The help text used to say "Mean" for --std as well.
    parser.add_argument('--std', type=str, required=True, help="Std of dataset (e.g., [1.2, 1.2, 1.2])")
    parser.add_argument('--dataset_num_classes', type=int, required=True, help="Dataset class num (e.g., 36)")
    parser.add_argument('--GPU_num', type=int, required=True, help="Num of GPU to train (e.g., 1)")
    parser.add_argument('--schedule_file_name', type=str, required=True, help="Schedule file name (e.g., 'schedule_4k_check_400')")
    parser.add_argument('--train_type', type=str, required=True, help="Schedule train type (e.g., Epoch or Iteration)")
    parser.add_argument('--train_time_or_epoch_k', type=int, required=True, help="Schedule train time k (e.g., 4)")
    args = parser.parse_args()
    # Unpack the parsed arguments
    alg_name = args.alg_name
    dataset_file_name = args.dataset_file_name
    mean = json.loads(args.mean)  # JSON text -> list
    std = json.loads(args.std)  # JSON text -> list
    num_classes = args.dataset_num_classes
    schedule_file_name = args.schedule_file_name
    train_type = args.train_type
    train_time_or_epoch_k = args.train_time_or_epoch_k
    GPU_num = args.GPU_num
    ########### 3. Settings ###########
    _base_ = generate_base_config(alg_file_name=alg_file_name, dataset_file_name=dataset_file_name, schedule_file_name=schedule_file_name)
    crop_size = select_crop_size(predefined_options=[(512, 512), (640, 640), (512, 1024)])  # choose the crop size
    data_preprocessor = generate_data_preprocessor(crop_size=crop_size, mean=mean, std=std)
    model_data_preprocessor = data_preprocessor
    model_data_preprocessor = generate_model_data_preprocessor(model_data_preprocessor=model_data_preprocessor)
    norm_cfg = generate_norm_cfg(GPU_num=GPU_num)
    # 3.1. Backbone selection
    model_list = ['torchvision://resnet50', 'torchvision://resnet101', 'pretrain/swin_large-6580f57d.pth', 'pretrain/swin_base-e5c09f74.pth', 'pretrain/swin_small-7ba6d6dd.pth', 'pretrain/swin_tiny-1cdeb081.pth']
    selected_model_name, select_pretrained, pretrained_pth, selected_model_info = select_pretrained_model(model_list=model_list)
    backbone_type = selected_model_info['type']
    if backbone_type == 'ResNet':
        depth = selected_model_info['depth']
        backbone = create_dict_by_kwargs(depth=depth, type=selected_model_info['type'], norm_cfg=norm_cfg, num_stages=4, frozen_stages=-1, style='pytorch', init_cfg=dict(type='Pretrained', checkpoint=pretrained_pth))
        in_channels = [256, 512, 1024, 2048]
    elif backbone_type == 'swin':
        # The four Swin variants differ only in the values listed here;
        # everything else in the backbone dict is shared.
        swin_variants = {
            'pretrain/swin_tiny-1cdeb081.pth': dict(
                pretrain_img_size=224, embed_dims=96, depths=[2, 2, 6, 2],
                num_heads=[3, 6, 12, 24], window_size=7,
                in_channels=[96, 192, 384, 768]),
            'pretrain/swin_small-7ba6d6dd.pth': dict(
                pretrain_img_size=224, embed_dims=96, depths=[2, 2, 18, 2],
                num_heads=[3, 6, 12, 24], window_size=7,
                in_channels=[96, 192, 384, 768]),
            'pretrain/swin_base-e5c09f74.pth': dict(
                pretrain_img_size=384, embed_dims=128, depths=[2, 2, 18, 2],
                num_heads=[4, 8, 16, 32], window_size=12,
                in_channels=[128, 256, 512, 1024]),
            'pretrain/swin_large-6580f57d.pth': dict(
                pretrain_img_size=384, embed_dims=192, depths=[2, 2, 18, 2],
                num_heads=[6, 12, 24, 48], window_size=12,
                in_channels=[192, 384, 768, 1536]),
        }
        variant = swin_variants[selected_model_name]
        depths = variant['depths']
        in_channels = variant['in_channels']
        backbone = dict(
            # Previously only the tiny/small variants carried _delete_=True;
            # base/large now get it too so all four are handled alike.
            # NOTE(review): assumes the base config defines model.backbone
            # whose keys must be replaced rather than merged - confirm.
            _delete_=True,
            type='SwinTransformer',
            pretrain_img_size=variant['pretrain_img_size'],
            embed_dims=variant['embed_dims'],
            depths=depths,
            num_heads=variant['num_heads'],
            window_size=variant['window_size'],
            mlp_ratio=4,
            qkv_bias=True,
            qk_scale=None,
            drop_rate=0.,
            attn_drop_rate=0.,
            drop_path_rate=0.3,
            patch_norm=True,
            act_cfg=dict(type='GELU'),  # ADD
            use_abs_pos_embed=False,  # ADD
            # out_indices=(0, 1, 2, 3),
            with_cp=False,
            init_cfg=dict(type='Pretrained', checkpoint=pretrained_pth))
    else:
        # Fail with a clear message instead of a NameError further down.
        raise ValueError(f"Unsupported backbone type: {backbone_type!r}")
    # 3.2. decode_head
    # class_weight gets one extra trailing 0.1 entry after the per-class 1.0s.
    decode_head = dict(num_classes=num_classes, in_channels=in_channels,
                       loss_cls=dict(class_weight=[1.0] * num_classes + [0.1]))
    # 3.3. Optimizer
    embed_multi = dict(lr_mult=1.0, decay_mult=0.0)
    optimizer = dict(
        type='AdamW', lr=0.0001, weight_decay=0.05, eps=1e-8, betas=(0.9, 0.999))
    if backbone_type == 'ResNet':
        optim_wrapper = dict(
            _delete_=True,
            type='OptimWrapper',
            optimizer=optimizer,
            clip_grad=dict(max_norm=0.01, norm_type=2),
            paramwise_cfg=dict(
                custom_keys={
                    'backbone': dict(lr_mult=0.1, decay_mult=1.0),
                    'query_embed': embed_multi,
                    'query_feat': embed_multi,
                    'level_embed': embed_multi,
                },
                norm_decay_mult=0.0))
        train_dataloader, batch_size, num_workers = generate_train_dataloader(batch_size_default=4, num_workers_default=4)
    elif backbone_type == 'swin':
        # All backbone layers use lr_mult=0.1; norm layers and the
        # position/query/level embeddings use decay_mult=0.0.
        backbone_norm_multi = dict(lr_mult=0.1, decay_mult=0.0)
        backbone_embed_multi = dict(lr_mult=0.1, decay_mult=0.0)
        custom_keys = {
            'backbone': dict(lr_mult=0.1, decay_mult=1.0),
            'backbone.patch_embed.norm': backbone_norm_multi,
            'backbone.norm': backbone_norm_multi,
            'absolute_pos_embed': backbone_embed_multi,
            'relative_position_bias_table': backbone_embed_multi,
            'query_embed': embed_multi,
            'query_feat': embed_multi,
            'level_embed': embed_multi
        }
        # One entry per block norm in every stage ...
        custom_keys.update({
            f'backbone.stages.{stage_id}.blocks.{block_id}.norm': backbone_norm_multi
            for stage_id, num_blocks in enumerate(depths)
            for block_id in range(num_blocks)
        })
        # ... and one per downsample norm for every stage except the last.
        custom_keys.update({
            f'backbone.stages.{stage_id}.downsample.norm': backbone_norm_multi
            for stage_id in range(len(depths) - 1)
        })
        optim_wrapper = dict(
            _delete_=True,
            type='OptimWrapper',
            optimizer=optimizer,
            clip_grad=dict(max_norm=0.01, norm_type=2),
            paramwise_cfg={'custom_keys': custom_keys, 'norm_decay_mult': 0.0})
        train_dataloader, batch_size, num_workers = generate_train_dataloader(batch_size_default=2, num_workers_default=2)
    # 3.5. param_scheduler
    # train_type was missing from this call; the sibling generator scripts
    # all pass (train_type, train_time_or_epoch_k).
    param_scheduler = generate_param_scheduler(train_type, train_time_or_epoch_k)
    # Assemble the model section
    model = dict(
        backbone=backbone,
        data_preprocessor=model_data_preprocessor,
        decode_head=decode_head
    )
    ########### 4. Write the config ###########
    # File-name layout:
    #   <alg>_<backbone>_<Pre|NoPre>_g<GPU count>-<schedule>_<dataset>-<crop HxW>.py
    # str.lstrip('schedule_') removes any leading run of the characters
    # s/c/h/e/d/u/l/_ (so 'schedule_epoch_4' became 'poch_4'); strip the
    # literal prefix instead.
    schedule_prefix = 'schedule_'
    if schedule_file_name.startswith(schedule_prefix):
        schedule_tag = schedule_file_name[len(schedule_prefix):]
    else:
        schedule_tag = schedule_file_name
    pretrained_tag = 'Pre' if select_pretrained else 'NoPre'
    if backbone_type == 'ResNet':
        alg_file_name = f"{alg_name}_r{depth}_{pretrained_tag}_g{GPU_num}-{schedule_tag}_{dataset_file_name}-{crop_size[0]}x{crop_size[1]}.py"
    elif backbone_type == 'swin':
        alg_file_name = f"{alg_name}_swin_{selected_model_info['size']}_{pretrained_tag}_g{GPU_num}-{schedule_tag}_{dataset_file_name}-{crop_size[0]}x{crop_size[1]}.py"
    output_configs_alg_my_alg = os.path.join('./configs/maskformer/', alg_file_name)
    write_config_to_file(output_configs_alg_my_alg, _base_=_base_, norm_cfg=norm_cfg, crop_size=crop_size, data_preprocessor=data_preprocessor, model=model, train_dataloader=train_dataloader, optim_wrapper=optim_wrapper, param_scheduler=param_scheduler)
    # Record the generated file name and path in _temp_.txt as JSON
    with open("_temp_.txt", "w") as file:
        json.dump({'alg_infos': {'alg_file_name': alg_file_name, 'alg_file_pth': output_configs_alg_my_alg}}, file)

View File

@@ -0,0 +1,152 @@
import os, sys, argparse, json
import importlib.util
# 添加上级目录到 sys.path
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../')))
from Initial_Alg_Program.Initial_Alg_Select_Tool import select_crop_size, select_pretrained_model, select_sampler
from Initial_Alg_Program.Initial_Alg_Gen_Tool import format_all_data, write_config_to_file, create_dict_by_kwargs, get_var_from_file, update_list_dict_var, get_var_from_py_file
from Initial_Alg_Program.Initial_Alg_Gen_0_base_ import generate_base_config
from Initial_Alg_Program.Initial_Alg_Gen_1_norm_cfg import generate_norm_cfg
from Initial_Alg_Program.Initial_Alg_Gen_2_data_preprocessor import generate_data_preprocessor
from Initial_Alg_Program.Initial_Alg_Gen_3_model import generate_model_pretrained, generate_model_backbone, generate_model_data_preprocessor, generate_model_decode_head, generate_model_auxiliary_head, generate_model_train_cfg, generate_model_test_cfg
from Initial_Alg_Program.Initial_Alg_Gen_4_optimizer import generate_optim_wrapper, generate_param_scheduler
def get_train_pipeline_from_config(dataset_file_name: str):
    """Load a dataset config file and return its ``train_pipeline`` list.

    The file ``configs/_base_/datasets/<dataset_file_name>.py`` (relative to
    the current working directory) is executed as a Python module and its
    module-level ``train_pipeline`` variable is read.

    Args:
        dataset_file_name (str): Dataset config name, without the ``.py`` suffix.

    Returns:
        list | None: The ``train_pipeline`` object on success; ``None`` when the
        file does not exist, fails to execute, or defines no ``train_pipeline``.
        A message is printed in every failure case.
    """
    # os.path.join keeps the path valid across operating systems.
    config_path = os.path.join('configs/_base_/datasets/', f'{dataset_file_name}.py')
    if not os.path.exists(config_path):
        # \033[91m ... \033[0m = bright red; the escape needs the final 'm'
        # or the terminal prints the raw bytes.
        print(f"\033[91m错误配置文件不存在于 '{config_path}'\033[0m")
        return None
    try:
        # Execute the .py file as a throw-away module so its variables become attributes.
        spec = importlib.util.spec_from_file_location(dataset_file_name, config_path)
        config_module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(config_module)
    except Exception as e:
        # Deliberately broad: the config file is arbitrary Python and may raise anything.
        print(f"读取配置文件时发生未知错误: {e}")
        return None
    train_pipeline = getattr(config_module, 'train_pipeline', None)
    if train_pipeline is None:
        print(f"错误:在 '{config_path}' 文件中未找到 'train_pipeline' 变量。")
        return None
    return train_pipeline
if __name__ == '__main__':
    # Generates a PIDNet training config under ./configs/pidnet/ and records the
    # generated file name/path in _temp_.txt for the calling program.

    ########### 1. Hyper-parameters: algorithm defaults ###########
    alg_file_name = 'pidnet'  # base model config name handed to generate_base_config
    alg_file_pth = f"configs/_base_/models/{alg_file_name}.py"

    ########### 2. Argument parsing ###########
    parser = argparse.ArgumentParser(description="Run algorithm with specified parameters.")
    parser.add_argument('--alg_name', type=str, required=True, help="Algorithm name (e.g., 'ann')")
    parser.add_argument('--dataset_file_name', type=str, required=True, help="Dataset file name (e.g., 'my_dataset_model')")
    parser.add_argument('--mean', type=str, required=True, help="Mean of dataset (e.g., [1.2, 1.2, 1.2])")
    parser.add_argument('--std', type=str, required=True, help="Std of dataset (e.g., [1.2, 1.2, 1.2])")
    parser.add_argument('--dataset_num_classes', type=int, required=True, help="Dataset class num (e.g., 36)")
    parser.add_argument('--GPU_num', type=int, required=True, help="Num of GPU to train (e.g., 1)")
    parser.add_argument('--schedule_file_name', type=str, required=True, help="Schedule file name (e.g., 'schedule_4k_check_400')")
    parser.add_argument('--train_type', type=str, required=True, help="Schedule train type (e.g., Epoch or Iteration)")
    parser.add_argument('--train_time_or_epoch_k', type=int, required=True, help="Schedule train time k (e.g., 4)")
    args = parser.parse_args()

    alg_name = args.alg_name
    dataset_file_name = args.dataset_file_name
    mean = json.loads(args.mean)  # JSON text -> list
    std = json.loads(args.std)    # JSON text -> list
    num_classes = args.dataset_num_classes
    schedule_file_name = args.schedule_file_name
    train_type = args.train_type
    train_time_or_epoch_k = args.train_time_or_epoch_k
    GPU_num = args.GPU_num

    ########### 3.1. _base_, norm_cfg and train_pipeline ###########
    _base_ = generate_base_config(alg_file_name=alg_file_name, dataset_file_name=dataset_file_name, schedule_file_name=schedule_file_name)
    norm_cfg = generate_norm_cfg(GPU_num=GPU_num)
    train_pipeline = get_train_pipeline_from_config(dataset_file_name)
    if train_pipeline is None:
        # The helper already printed the reason; without a pipeline there is
        # nothing to insert GenerateEdge into.
        sys.exit(f"Error: no train_pipeline available for dataset config '{dataset_file_name}'")
    # PIDNet needs an edge map; insert it just before the last pipeline step.
    train_pipeline.insert(-1, dict(type='GenerateEdge', edge_width=4))
    train_dataloader = dict(dataset=dict(pipeline=train_pipeline))

    ########### 3.2. data_preprocessor ###########
    crop_size = select_crop_size(predefined_options=[(512, 512)])  # interactive crop-size choice
    data_preprocessor = generate_data_preprocessor(crop_size = crop_size, mean = mean, std = std)

    ########### 3.3. model ###########
    model_data_preprocessor = data_preprocessor
    # Per-variant backbone / decode-head settings, keyed by pretrained-model name.
    pidnet_variants = {
        'openmmlab/pidnet-s': ('small',
                               dict(channels=32, ppm_channels=96, num_stem_blocks=2, num_branch_blocks=3),
                               dict(in_channels=128, channels=128)),
        'openmmlab/pidnet-m': ('middle',
                               dict(channels=64, ppm_channels=96, num_stem_blocks=2, num_branch_blocks=3),
                               dict(in_channels=256, channels=128)),
        'openmmlab/pidnet-l': ('large',
                               dict(channels=64, ppm_channels=112, num_stem_blocks=3, num_branch_blocks=4),
                               dict(in_channels=256, channels=256)),
    }
    model_list = list(pidnet_variants)
    selected_model_name, select_pretrained, pretrained_pth, selected_model_info = select_pretrained_model(model_list=model_list)
    if selected_model_name not in pidnet_variants:
        sys.exit("Error: 未知的模型名称")
    model_size, backbone_kwargs, head_kwargs = pidnet_variants[selected_model_name]
    backbone = dict(**backbone_kwargs, init_cfg=dict(type='Pretrained', checkpoint=pretrained_pth))
    decode_head = dict(num_classes=num_classes, **head_kwargs)
    model = dict(
        data_preprocessor = model_data_preprocessor,
        backbone = backbone,
        decode_head = decode_head,
    )

    ########### 3.4. optimizer and param_scheduler ###########
    optim_wrapper = generate_optim_wrapper()
    param_scheduler = generate_param_scheduler(train_type, train_time_or_epoch_k)

    ########### 4. Write the config ###########
    # File-name scheme: <alg>_<size>_g<GPUs>-<schedule>_<dataset>-<crop_h>x<crop_w>.py
    # Remove the literal 'schedule_' prefix only; str.lstrip() would strip a
    # character set and could eat the start of the remaining name.
    schedule_prefix = 'schedule_'
    if schedule_file_name.startswith(schedule_prefix):
        schedule_tag = schedule_file_name[len(schedule_prefix):]
    else:
        schedule_tag = schedule_file_name
    alg_file_name = f"{alg_name}_{model_size}_g{GPU_num}-{schedule_tag}_{dataset_file_name}-{crop_size[0]}x{crop_size[1]}.py"
    output_configs_alg_my_alg = os.path.join('./configs/pidnet/', alg_file_name)
    write_config_to_file(output_configs_alg_my_alg, _base_=_base_, norm_cfg=norm_cfg, crop_size=crop_size, data_preprocessor=data_preprocessor, train_pipeline = train_pipeline, train_dataloader = train_dataloader, model=model, optim_wrapper=optim_wrapper, param_scheduler=param_scheduler)
    # Record the generated file name/path in a temporary file.
    with open("_temp_.txt", "w") as file:
        json.dump({'alg_infos':{'alg_file_name':alg_file_name,'alg_file_pth':output_configs_alg_my_alg}}, file)

View File

@@ -0,0 +1,111 @@
import os, sys, argparse, json
# 添加上级目录到 sys.path
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../')))
from Initial_Alg_Program.Initial_Alg_Select_Tool import select_crop_size, select_pretrained_model, select_test_cfg_slide, select_sampler
from Initial_Alg_Program.Initial_Alg_Gen_Tool import format_all_data, write_config_to_file, create_dict_by_kwargs
from Initial_Alg_Program.Initial_Alg_Gen_0_base_ import generate_base_config
from Initial_Alg_Program.Initial_Alg_Gen_1_norm_cfg import generate_norm_cfg
from Initial_Alg_Program.Initial_Alg_Gen_2_data_preprocessor import generate_data_preprocessor
from Initial_Alg_Program.Initial_Alg_Gen_3_model import generate_model_pretrained, generate_model_backbone, generate_model_data_preprocessor, generate_model_decode_head, generate_model_auxiliary_head, generate_model_train_cfg, generate_model_test_cfg
from Initial_Alg_Program.Initial_Alg_Gen_4_optimizer import generate_optim_wrapper, generate_param_scheduler
if __name__ == '__main__':
    # Generates a PSPNet training config under ./configs/pspnet/ and records the
    # generated file name/path in _temp_.txt for the calling program.

    ########### 1. Hyper-parameters: algorithm defaults ###########
    alg_file_name = 'pspnet_r50-d8'  # base model config name handed to generate_base_config

    ########### 2. Argument parsing ###########
    parser = argparse.ArgumentParser(description="Run algorithm with specified parameters.")
    parser.add_argument('--alg_name', type=str, required=True, help="Algorithm name (e.g., 'ann')")
    parser.add_argument('--dataset_file_name', type=str, required=True, help="Dataset file name (e.g., 'my_dataset_model')")
    parser.add_argument('--mean', type=str, required=True, help="Mean of dataset (e.g., [1.2, 1.2, 1.2])")
    parser.add_argument('--std', type=str, required=True, help="Std of dataset (e.g., [1.2, 1.2, 1.2])")
    parser.add_argument('--dataset_num_classes', type=int, required=True, help="Dataset class num (e.g., 36)")
    parser.add_argument('--GPU_num', type=int, required=True, help="Num of GPU to train (e.g., 1)")
    parser.add_argument('--schedule_file_name', type=str, required=True, help="Schedule file name (e.g., 'schedule_4k_check_400')")
    parser.add_argument('--train_type', type=str, required=True, help="Schedule train type (e.g., Epoch or Iteration)")
    parser.add_argument('--train_time_or_epoch_k', type=int, required=True, help="Schedule train time k (e.g., 4)")
    args = parser.parse_args()

    alg_name = args.alg_name
    dataset_file_name = args.dataset_file_name
    mean = json.loads(args.mean)  # JSON text -> list
    std = json.loads(args.std)    # JSON text -> list
    num_classes = args.dataset_num_classes
    schedule_file_name = args.schedule_file_name
    train_type = args.train_type
    train_time_or_epoch_k = args.train_time_or_epoch_k
    GPU_num = args.GPU_num

    ########### 3. Parameters ###########
    # 3.1. Interactive choices
    crop_size = select_crop_size(predefined_options=[(512, 512), (512, 1024), (769, 769),(1280,1280)])
    model_list = ['openmmlab/resnet18_v1c', 'openmmlab/resnet50_v1c', 'openmmlab/resnet101_v1c', 'torchvision://resnet18', 'torchvision://resnet50', 'torchvision://resnet101', ]
    selected_model_name, select_pretrained, pretrained_pth, selected_model_info = select_pretrained_model(model_list=model_list)
    depth = selected_model_info['depth']
    type_model = selected_model_info['type']
    test_cfg_mode, test_cfg_crop_div_stride = select_test_cfg_slide(crop_size)
    # 3.2. Losses for the decode head and the auxiliary head.
    # Alternative: type='CrossEntropyLoss' with the same weights.  # TODO
    decode_head_loss_decode_dict=dict(type='DiceLoss', use_sigmoid=False, loss_weight=1.0)
    auxiliary_head_loss_decode_dict=dict(type='DiceLoss', use_sigmoid=False, loss_weight=0.4)

    ########### 3. Build each config section ###########
    _base_ = generate_base_config(alg_file_name=alg_file_name, dataset_file_name=dataset_file_name, schedule_file_name=schedule_file_name)
    norm_cfg = generate_norm_cfg(GPU_num=GPU_num)
    data_preprocessor = generate_data_preprocessor(crop_size = crop_size, mean = mean, std = std)
    pretrained = generate_model_pretrained(pretrained_pth = pretrained_pth)
    backbone = create_dict_by_kwargs(depth=depth, type=type_model)
    model_data_preprocessor = generate_model_data_preprocessor(model_data_preprocessor=data_preprocessor)
    # Head channel widths depend on the ResNet depth.
    if depth == 18:
        head_in, head_channels, aux_in, aux_channels = 512, 128, 256, 64
    elif depth == 101 or depth == 50:
        head_in, head_channels, aux_in, aux_channels = 2048, 512, 1024, 256
    else:
        # Previously fell through and failed later with a NameError on decode_head.
        sys.exit(f"Error: unsupported backbone depth {depth}")
    decode_head = create_dict_by_kwargs(in_channels=head_in, channels=head_channels, num_classes=num_classes, loss_decode=decode_head_loss_decode_dict, norm_cfg=norm_cfg,)
    # The auxiliary head uses its own (0.4-weighted) loss, not the decode-head loss.
    auxiliary_head = create_dict_by_kwargs(in_channels=aux_in, channels=aux_channels, num_classes=num_classes, loss_decode=auxiliary_head_loss_decode_dict, norm_cfg=norm_cfg,)
    model = dict(
        pretrained = pretrained,
        backbone = backbone,
        data_preprocessor = model_data_preprocessor,
        decode_head = decode_head,
        auxiliary_head = auxiliary_head,
    )
    # NOTE(review): test_cfg is written as a top-level config key below rather than
    # inside `model` -- confirm this is what write_config_to_file / the base config expect.
    test_cfg = generate_model_test_cfg(test_cfg_mode=test_cfg_mode, crop_size=crop_size, test_cfg_crop_div_stride=test_cfg_crop_div_stride)
    optim_wrapper = generate_optim_wrapper()
    param_scheduler = generate_param_scheduler(train_type, train_time_or_epoch_k)

    ########### 4. Write the config ###########
    # File-name scheme: <alg>_r<depth>_<Pre|NoPre>_g<GPUs>-<schedule>_<dataset>-<crop_h>x<crop_w>.py
    # Remove the literal 'schedule_' prefix only; str.lstrip() would strip a
    # character set and could eat the start of the remaining name.
    schedule_prefix = 'schedule_'
    if schedule_file_name.startswith(schedule_prefix):
        schedule_tag = schedule_file_name[len(schedule_prefix):]
    else:
        schedule_tag = schedule_file_name
    alg_file_name = f"{alg_name}_r{depth}_{'Pre' if select_pretrained else 'NoPre'}_g{GPU_num}-{schedule_tag}_{dataset_file_name}-{crop_size[0]}x{crop_size[1]}.py"
    output_configs_alg_my_alg = os.path.join('./configs/pspnet/', alg_file_name)
    write_config_to_file(output_configs_alg_my_alg, _base_=_base_, norm_cfg=norm_cfg, crop_size=crop_size, data_preprocessor=data_preprocessor, model=model, test_cfg=test_cfg, optim_wrapper=optim_wrapper, param_scheduler=param_scheduler)
    # Record the generated file name/path in a temporary file.
    with open("_temp_.txt", "w") as file:
        json.dump({'alg_infos':{'alg_file_name':alg_file_name,'alg_file_pth':output_configs_alg_my_alg}}, file)

View File

@@ -0,0 +1,139 @@
import os, sys, argparse, json
import importlib.util
# 添加上级目录到 sys.path
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../')))
from Initial_Alg_Program.Initial_Alg_Select_Tool import select_crop_size, select_pretrained_model, select_sampler
from Initial_Alg_Program.Initial_Alg_Gen_Tool import format_all_data, write_config_to_file, create_dict_by_kwargs, get_var_from_file, update_list_dict_var, get_var_from_py_file
from Initial_Alg_Program.Initial_Alg_Gen_0_base_ import generate_base_config
from Initial_Alg_Program.Initial_Alg_Gen_1_norm_cfg import generate_norm_cfg
from Initial_Alg_Program.Initial_Alg_Gen_2_data_preprocessor import generate_data_preprocessor
from Initial_Alg_Program.Initial_Alg_Gen_3_model import generate_model_pretrained, generate_model_backbone, generate_model_data_preprocessor, generate_model_decode_head, generate_model_auxiliary_head, generate_model_train_cfg, generate_model_test_cfg
from Initial_Alg_Program.Initial_Alg_Gen_4_optimizer import generate_optim_wrapper, generate_param_scheduler
def get_pretrained_model_choice():
    """Ask on stdin whether a pretrained model should be used.

    The prompt repeats until the answer is "1" or "2"; an empty answer
    (just pressing Enter) counts as "1".

    Returns:
        bool: True for option 1 (use pretrained), False for option 2.
    """
    # answer -> (confirmation message, return value)
    outcomes = {
        "1": ("您已选择1. 是", True),
        "2": ("您已选择2. 否", False),
    }
    while True:
        answer = input("可用的预训练模型选项:\n1. 是\n2. 否\n请选择是否使用预训练模型默认1: ")
        if not answer:
            answer = "1"  # Enter alone selects the default
        if answer in outcomes:
            confirmation, use_it = outcomes[answer]
            print(confirmation)
            return use_it
        # Anything else (e.g. "3"): ask again.
        print("\n无效输入,请输入 1 或 2。\n")
if __name__ == '__main__':
    # Generates an STDC training config under ./configs/stdc/ and records the
    # generated file name/path in _temp_.txt for the calling program.

    ########### 1. Hyper-parameters: algorithm defaults ###########
    alg_file_name = 'stdc'  # base model config name handed to generate_base_config
    alg_file_pth = f"configs/_base_/models/{alg_file_name}.py"

    ########### 2. Argument parsing ###########
    parser = argparse.ArgumentParser(description="Run algorithm with specified parameters.")
    parser.add_argument('--alg_name', type=str, required=True, help="Algorithm name (e.g., 'ann')")
    parser.add_argument('--dataset_file_name', type=str, required=True, help="Dataset file name (e.g., 'my_dataset_model')")
    parser.add_argument('--mean', type=str, required=True, help="Mean of dataset (e.g., [1.2, 1.2, 1.2])")
    parser.add_argument('--std', type=str, required=True, help="Std of dataset (e.g., [1.2, 1.2, 1.2])")
    parser.add_argument('--dataset_num_classes', type=int, required=True, help="Dataset class num (e.g., 36)")
    parser.add_argument('--GPU_num', type=int, required=True, help="Num of GPU to train (e.g., 1)")
    parser.add_argument('--schedule_file_name', type=str, required=True, help="Schedule file name (e.g., 'schedule_4k_check_400')")
    parser.add_argument('--train_type', type=str, required=True, help="Schedule train type (e.g., Epoch or Iteration)")
    parser.add_argument('--train_time_or_epoch_k', type=int, required=True, help="Schedule train time k (e.g., 4)")
    args = parser.parse_args()

    alg_name = args.alg_name
    dataset_file_name = args.dataset_file_name
    mean = json.loads(args.mean)  # JSON text -> list
    std = json.loads(args.std)    # JSON text -> list
    num_classes = args.dataset_num_classes
    schedule_file_name = args.schedule_file_name
    train_type = args.train_type
    train_time_or_epoch_k = args.train_time_or_epoch_k
    GPU_num = args.GPU_num

    ########### 3.1. _base_ and norm_cfg ###########
    _base_ = generate_base_config(alg_file_name=alg_file_name, dataset_file_name=dataset_file_name, schedule_file_name=schedule_file_name)
    norm_cfg = generate_norm_cfg(GPU_num=GPU_num)

    ########### 3.2. data_preprocessor ###########
    crop_size = select_crop_size(predefined_options=[(512, 512)])  # interactive crop-size choice
    data_preprocessor = generate_data_preprocessor(crop_size = crop_size, mean = mean, std = std)

    ########### 3.3. model ###########
    model_data_preprocessor = data_preprocessor
    # Losses ("Way 1": DiceLoss; the original alternative was CrossEntropyLoss
    # with the same arguments).  # TODO
    decode_head_loss_decode=dict(type='DiceLoss', use_sigmoid=False, loss_weight=1.0)
    auxiliary_head_loss_decode_dict=dict(type='DiceLoss', use_sigmoid=False, loss_weight=1.0)
    # Take the auxiliary heads from the base model file and patch every FCNHead
    # with the dataset's class count and the chosen loss.
    auxiliary_head = get_var_from_py_file(alg_file_pth, 'model')['auxiliary_head']
    for head_cfg in auxiliary_head:
        if head_cfg['type'] == 'FCNHead':
            head_cfg.update(num_classes=num_classes, loss_decode=auxiliary_head_loss_decode_dict)

    # Pretrained-model name -> (stdc_type for the backbone, tag used in the file name)
    stdc_variants = {
        "openmmlab/stdc1": ('STDCNet1', 'V1'),
        "openmmlab/stdc2": ('STDCNet2', 'V2'),
    }
    model_list = ["openmmlab/stdc1", "openmmlab/stdc2"]
    selected_model_name, select_pretrained, pretrained_pth, selected_model_info = select_pretrained_model(model_list=model_list)
    use_pretrained = get_pretrained_model_choice()
    if selected_model_name not in stdc_variants:
        sys.exit("Error: 未知的模型名称")
    stdc_type, version_tag = stdc_variants[selected_model_name]
    backbone = dict(backbone_cfg=dict(stdc_type=stdc_type))
    if use_pretrained:
        backbone['init_cfg'] = dict(type='Pretrained', checkpoint=pretrained_pth)
    Pre = "Pre" if use_pretrained else "NoPre"
    decode_head = dict(num_classes=num_classes, loss_decode=decode_head_loss_decode)
    model_size = version_tag + '_' + Pre
    model = dict(
        data_preprocessor = model_data_preprocessor,
        backbone = backbone,
        decode_head = decode_head,
        # Emit the patched auxiliary heads; previously they were built above but
        # never written, so the num_classes/loss patch had no effect.
        auxiliary_head = auxiliary_head,
    )

    ########### 3.4. optimizer and param_scheduler ###########
    optim_wrapper = generate_optim_wrapper()
    param_scheduler = generate_param_scheduler(train_type, train_time_or_epoch_k)

    ########### 4. Write the config ###########
    # File-name scheme: <alg>_<V1|V2>_<Pre|NoPre>_g<GPUs>-<schedule>_<dataset>-<crop_h>x<crop_w>.py
    # Remove the literal 'schedule_' prefix only; str.lstrip() would strip a
    # character set and could eat the start of the remaining name.
    schedule_prefix = 'schedule_'
    if schedule_file_name.startswith(schedule_prefix):
        schedule_tag = schedule_file_name[len(schedule_prefix):]
    else:
        schedule_tag = schedule_file_name
    alg_file_name = f"{alg_name}_{model_size}_g{GPU_num}-{schedule_tag}_{dataset_file_name}-{crop_size[0]}x{crop_size[1]}.py"
    output_configs_alg_my_alg = os.path.join('./configs/stdc/', alg_file_name)
    write_config_to_file(output_configs_alg_my_alg, _base_=_base_, norm_cfg=norm_cfg, crop_size=crop_size, data_preprocessor=data_preprocessor, model=model, optim_wrapper=optim_wrapper, param_scheduler=param_scheduler)
    # Record the generated file name/path in a temporary file.
    with open("_temp_.txt", "w") as file:
        json.dump({'alg_infos':{'alg_file_name':alg_file_name,'alg_file_pth':output_configs_alg_my_alg}}, file)

View File

@@ -0,0 +1,105 @@
import os, sys, argparse, json
# 添加上级目录到 sys.path
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../')))
from Initial_Alg_Program.Initial_Alg_Select_Tool import select_crop_size, select_pretrained_model, select_test_cfg_slide, select_sampler
from Initial_Alg_Program.Initial_Alg_Gen_Tool import format_all_data, write_config_to_file, create_dict_by_kwargs
from Initial_Alg_Program.Initial_Alg_Gen_0_base_ import generate_base_config
from Initial_Alg_Program.Initial_Alg_Gen_1_norm_cfg import generate_norm_cfg
from Initial_Alg_Program.Initial_Alg_Gen_2_data_preprocessor import generate_data_preprocessor
from Initial_Alg_Program.Initial_Alg_Gen_3_model import generate_model_pretrained, generate_model_backbone, generate_model_data_preprocessor, generate_model_decode_head, generate_model_auxiliary_head, generate_model_train_cfg, generate_model_test_cfg
from Initial_Alg_Program.Initial_Alg_Gen_4_optimizer import generate_optim_wrapper, generate_param_scheduler
if __name__ == '__main__':
    # Generates a UNet (DeepLabV3 head) training config under ./configs/unet/ and
    # records the generated file name/path in _temp_.txt for the calling program.

    ########### 1. Hyper-parameters: algorithm defaults ###########
    alg_file_name = 'deeplabv3_unet_s5-d16'  # base model config name handed to generate_base_config

    ########### 2. Argument parsing ###########
    parser = argparse.ArgumentParser(description="Run algorithm with specified parameters.")
    parser.add_argument('--alg_name', type=str, required=True, help="Algorithm name (e.g., 'ann')")
    parser.add_argument('--dataset_file_name', type=str, required=True, help="Dataset file name (e.g., 'my_dataset_model')")
    parser.add_argument('--mean', type=str, required=True, help="Mean of dataset (e.g., [1.2, 1.2, 1.2])")
    parser.add_argument('--std', type=str, required=True, help="Std of dataset (e.g., [1.2, 1.2, 1.2])")
    parser.add_argument('--dataset_num_classes', type=int, required=True, help="Dataset class num (e.g., 36)")
    parser.add_argument('--GPU_num', type=int, required=True, help="Num of GPU to train (e.g., 1)")
    parser.add_argument('--schedule_file_name', type=str, required=True, help="Schedule file name (e.g., 'schedule_4k_check_400')")
    parser.add_argument('--train_type', type=str, required=True, help="Schedule train type (e.g., Epoch or Iteration)")
    parser.add_argument('--train_time_or_epoch_k', type=int, required=True, help="Schedule train time k (e.g., 4)")
    args = parser.parse_args()

    alg_name = args.alg_name
    dataset_file_name = args.dataset_file_name
    mean = json.loads(args.mean)  # JSON text -> list
    std = json.loads(args.std)    # JSON text -> list
    num_classes = args.dataset_num_classes
    schedule_file_name = args.schedule_file_name
    train_type = args.train_type
    train_time_or_epoch_k = args.train_time_or_epoch_k
    GPU_num = args.GPU_num

    ########### 3. Parameters ###########
    # 3.1. Interactive choices
    crop_size = select_crop_size(predefined_options=[(128, 128), (256, 256)])
    test_cfg_mode, test_cfg_crop_div_stride = select_test_cfg_slide(crop_size, select_slide=True)  # slide mode preselected
    # 3.2. Losses. The decode head combines cross-entropy and Dice; earlier
    # single-loss variants (plain CrossEntropyLoss / DiceLoss) were overwritten
    # by this list and are dropped here.  # TODO
    decode_head_loss_decode_dict = [
        dict(type='CrossEntropyLoss', loss_name='loss_ce', loss_weight=1.0),
        dict(type='DiceLoss', loss_name='loss_dice', loss_weight=3.0)]
    auxiliary_head_loss_decode_dict=dict(type='DiceLoss', use_sigmoid=False, loss_weight=0.4)
    # align_corners is switched on for sliding-window testing. Both names are
    # now defined for every mode; before, a non-slide mode raised NameError.
    if test_cfg_mode == "slide":
        align_corners = True
        test_slide = 'testslide'  # marker appended to the config file name
    else:
        align_corners = False
        test_slide = f'test{test_cfg_mode}'

    ########### 3. Build each config section ###########
    _base_ = generate_base_config(alg_file_name=alg_file_name, dataset_file_name=dataset_file_name, schedule_file_name=schedule_file_name)
    norm_cfg = generate_norm_cfg(GPU_num=GPU_num)
    data_preprocessor = generate_data_preprocessor(crop_size = crop_size, mean = mean, std = std)
    model_data_preprocessor = generate_model_data_preprocessor(model_data_preprocessor=data_preprocessor)
    decode_head = create_dict_by_kwargs(num_classes=num_classes, loss_decode=decode_head_loss_decode_dict, align_corners=align_corners, norm_cfg=norm_cfg,)
    auxiliary_head = create_dict_by_kwargs(num_classes=num_classes, norm_cfg=norm_cfg, loss_decode=auxiliary_head_loss_decode_dict)
    test_cfg = generate_model_test_cfg(test_cfg_mode=test_cfg_mode, crop_size=crop_size, test_cfg_crop_div_stride=test_cfg_crop_div_stride)
    model = dict(
        data_preprocessor = model_data_preprocessor,
        decode_head = decode_head,
        test_cfg = test_cfg,
        auxiliary_head = auxiliary_head,
    )
    optim_wrapper = generate_optim_wrapper()
    param_scheduler = generate_param_scheduler(train_type, train_time_or_epoch_k)

    ########### 4. Write the config ###########
    # File-name scheme: <alg>_g<GPUs>-<schedule>_<dataset>-<crop_h>x<crop_w>-<test mode marker>.py
    # Remove the literal 'schedule_' prefix only; str.lstrip() would strip a
    # character set and could eat the start of the remaining name.
    schedule_prefix = 'schedule_'
    if schedule_file_name.startswith(schedule_prefix):
        schedule_tag = schedule_file_name[len(schedule_prefix):]
    else:
        schedule_tag = schedule_file_name
    alg_file_name = f"{alg_name}_g{GPU_num}-{schedule_tag}_{dataset_file_name}-{crop_size[0]}x{crop_size[1]}-{test_slide}.py"
    output_configs_alg_my_alg = os.path.join('./configs/unet/', alg_file_name)
    write_config_to_file(output_configs_alg_my_alg, _base_=_base_, norm_cfg=norm_cfg, crop_size=crop_size, data_preprocessor=data_preprocessor, model=model, optim_wrapper=optim_wrapper, param_scheduler=param_scheduler)
    # Record the generated file name/path in a temporary file.
    with open("_temp_.txt", "w") as file:
        json.dump({'alg_infos':{'alg_file_name':alg_file_name,'alg_file_pth':output_configs_alg_my_alg}}, file)

View File

@@ -0,0 +1,70 @@
#!/usr/bin/env python3
# -*- coding: UTF-8 -*-
import os
from Initial_Alg_Program.Initial_Alg_Gen_Tool import format_all_data, write_config_to_file
from Initial_Alg_Program.Initial_Alg_Gen_0_base_ import generate_base_config
from Initial_Alg_Program.Initial_Alg_Gen_1_norm_cfg import generate_norm_cfg
from Initial_Alg_Program.Initial_Alg_Gen_2_data_preprocessor import generate_data_preprocessor
from Initial_Alg_Program.Initial_Alg_Gen_3_model import generate_model_pretrained, generate_model_backbone, generate_model_data_preprocessor, generate_model_decode_head, generate_model_auxiliary_head, generate_model_train_cfg, generate_model_test_cfg
if __name__ == '__main__':
    # Fixed-parameter smoke test of the config generators: builds an ANN
    # (ResNet-50) config from hard-coded settings and writes it to
    # ./configs/<alg_name>/my_<alg_name>.py.

    ########### 1. Fixed settings ###########
    alg_name = 'ann'                                # algorithm short name under ./configs
    alg_file_name = 'ann_r50-d8'                    # base model file
    dataset_file_name = 'my_dataset_model'          # dataset file
    schedule_file_name = 'schedule_4k_check_400'    # schedule file
    GPU_num = 1                                     # number of GPUs
    crop_size = (512,512)                           # crop size
    pretrained_pth = './My_Local_Model/open_mmlab/resnet50_v1c.pth'  # pretrained weights location
    depth = 50                                      # backbone depth
    num_classes=36                                  # number of classes
    align_corners=False                             # whether corner alignment is needed
    decode_head_loss_decode_dict=dict(type='DiceLoss', use_sigmoid=False, loss_weight=1.0)      # decode-head loss
    auxiliary_head_loss_decode_dict=dict(type='DiceLoss', use_sigmoid=False, loss_weight=0.4)   # auxiliary-head loss

    ########### 2. Build each section (same call order as before) ###########
    # A. _base_
    _base_ = generate_base_config(alg_file_name=alg_file_name, dataset_file_name=dataset_file_name, schedule_file_name=schedule_file_name)
    # B. norm_cfg
    norm_cfg = generate_norm_cfg(GPU_num=GPU_num)
    # C. data_preprocessor
    data_preprocessor = generate_data_preprocessor(crop_size = crop_size)
    # D. model parts; train_cfg / test_cfg generators exist but are not used here
    pretrained = generate_model_pretrained(pretrained_pth = pretrained_pth)
    backbone = generate_model_backbone(depth=depth)
    model_data_preprocessor = generate_model_data_preprocessor(model_data_preprocessor=data_preprocessor)
    decode_head = generate_model_decode_head(num_classes=num_classes, decode_head_loss_decode_dict=decode_head_loss_decode_dict, align_corners=align_corners)
    auxiliary_head = generate_model_auxiliary_head(num_classes=num_classes, auxiliary_head_loss_decode_dict=auxiliary_head_loss_decode_dict, align_corners=align_corners)
    # E. assemble the model
    model = {
        'pretrained': pretrained,
        'backbone': backbone,
        'data_preprocessor': model_data_preprocessor,
        'decode_head': decode_head,
        'auxiliary_head': auxiliary_head,
    }

    ########### 3. Save the file ###########
    output_configs_alg_my_alg = os.path.join(f'./configs/{alg_name}/', f'my_{alg_name}.py')
    write_config_to_file(output_configs_alg_my_alg, _base_=_base_, crop_size=crop_size, data_preprocessor=data_preprocessor, model=model)

View File

@@ -0,0 +1,301 @@
import os, json, subprocess
from Initial_Schedule_Program.Initial_Train_Gen_configs_base_schedules_schedule_XXk import generate_times_configs_base_schedules_schedule_file
from Initial_Schedule_Program.Initial_Train_Gen_configs_base_schedules_schedule_XXe import generate_epochs_configs_base_schedules_schedule_file
from datetime import datetime
from Initial_Alg_Program.Initial_Alg_Gen_Tool import format_all_data, write_config_to_file, get_gpu_info
from Initial_Alg_Program.Initial_Alg_Gen_0_base_ import generate_base_config
from Initial_Alg_Program.Initial_Alg_Gen_1_norm_cfg import generate_norm_cfg
from Initial_Alg_Program.Initial_Alg_Gen_2_data_preprocessor import generate_data_preprocessor
from Initial_Alg_Program.Initial_Alg_Gen_3_model import generate_model_pretrained, generate_model_backbone, generate_model_data_preprocessor, generate_model_decode_head, generate_model_auxiliary_head, generate_model_train_cfg, generate_model_test_cfg
def load_json_file(file_name):
    """Read a UTF-8 JSON file and return the parsed content.

    :param file_name: path of the JSON file to read
    :return: the decoded JSON value (dict or list), or None when the file is
             missing, is not valid JSON, or any other error occurs; a red
             error message is printed in each failure case
    """
    parsed = None
    try:
        with open(file_name, 'r', encoding='utf-8') as handle:
            parsed = json.loads(handle.read())
    except FileNotFoundError:
        print(f"\033[91mError: File {file_name} not found.\033[0m")
    except json.JSONDecodeError:
        print(f"\033[91mError: Failed to decode JSON from {file_name}.\033[0m")
    except Exception as e:
        print(f"\033[91mAn error occurred while reading {file_name}: {str(e)}\033[0m")
    return parsed
def process_data_record_json_data(json_data):
    """Copy the per-dataset fields of interest into one dictionary.

    :param json_data: mapping of dataset name -> dataset info, as loaded from JSON
    :return: all_data_record, a dict keyed by dataset name whose values hold only
             "classes", "palette", "palette_num", "mean" and "std"
             (a missing field raises KeyError)
    """
    kept_fields = ("classes", "palette", "palette_num", "mean", "std")
    return {
        dataset_name: {field: dataset_info[field] for field in kept_fields}
        for dataset_name, dataset_info in json_data.items()
    }
# 选择要处理的数据集
def select_dataset(all_data_record):
    """Let the user pick one dataset from all_data_record on stdin.

    The datasets are listed with 1-based numbers and the prompt repeats until a
    valid number is entered.

    :param all_data_record: dict of dataset name -> info containing at least
                            'palette_num', 'mean' and 'std'
    :return: tuple (dataset_file_name, palette_num, mean, std) of the chosen dataset
    """
    names = list(all_data_record.keys())
    # Show the menu.
    print("选择可用数据集:")
    for number, name in enumerate(names, start=1):
        print(f"{number}. {name} - {all_data_record[name]['palette_num']}")
    # Keep asking until the number maps to a listed dataset.
    while True:
        try:
            index = int(input("请选择数据集编号(输入数字):")) - 1
        except ValueError:
            print("无效输入,请输入数字。")
            continue
        if index < 0 or index >= len(names):
            print(f"输入的编号无效,请输入 1 到 {len(names)} 之间的数字。")
            continue
        dataset_file_name = names[index]
        break
    chosen = all_data_record[dataset_file_name]
    palette_num = chosen['palette_num']
    mean = chosen['mean']
    std = chosen['std']
    print(f" 已选择数据集: {dataset_file_name} ,其对应的分类数为: {palette_num}")
    return dataset_file_name, palette_num, mean, std
# 选择对应的算法
def select_alg(alg_directory):
    """Let the user pick one algorithm script from a directory on stdin.

    All ``*.py`` entries of ``alg_directory`` are listed with 1-based numbers in
    sorted order (``os.listdir`` returns entries in arbitrary order, so without
    sorting the numbering could differ between runs and machines). The prompt
    repeats until a valid number is entered.

    :param alg_directory: directory containing the algorithm ``.py`` files
    :return: tuple (selected file name, alg_directory joined with that file name)
    """
    algorithms = sorted(f for f in os.listdir(alg_directory) if f.endswith('.py'))
    # Show the menu.
    print("选择可用算法:")
    for i, alg in enumerate(algorithms):
        print(f"{i + 1}. {alg}")
    # Keep asking until the number maps to a listed file.
    while True:
        try:
            selection = int(input("请选择算法编号(输入数字):")) - 1
            if 0 <= selection < len(algorithms):
                selected_alg = algorithms[selection]
                break
            else:
                print(f"输入的编号无效,请输入 1 到 {len(algorithms)} 之间的数字。")
        except ValueError:
            print("无效输入,请输入数字。")
    relative_alg_path = os.path.join(alg_directory, selected_alg)
    print(f" 已选择算法: {selected_alg} {relative_alg_path}")
    return selected_alg, relative_alg_path
# Pick the GPUs used for training
def select_GPU():
    """
    Ask the user how many GPUs to use and which ones.

    The GPU list comes from ``get_gpu_info()``, which is iterated as
    ``(index, free_memory)`` pairs. Empty input selects the default. The
    default GPU for each prompt is the first listed GPU that has not been
    chosen yet, so an empty, invalid or duplicate answer never produces a
    duplicated index such as "0,0".

    :return: tuple (comma separated GPU indices such as "0,1", number of GPUs),
        or (None, 0) when ``get_gpu_info()`` reports nothing
    """
    gpu_info = get_gpu_info()
    if not gpu_info:
        return None, 0
    print("可用的 GPU 列表:")
    for idx, mem_free in gpu_info:
        print(f"GPU {idx}: 剩余显存 {mem_free} MB")
    available = [idx for idx, _ in gpu_info]
    default_num_gpus = 1
    # Number of GPUs: empty input -> 1, anything invalid or out of range -> 1.
    try:
        num_gpus = input(f"请选择使用的 GPU 个数 (1-{len(gpu_info)}, 默认为 1): ")
        num_gpus = int(num_gpus) if num_gpus else default_num_gpus
        if not (1 <= num_gpus <= len(gpu_info)):
            print(f"输入无效,使用默认值 1 个 GPU。")
            num_gpus = default_num_gpus
    except ValueError:
        print(f"无效输入,使用默认值 1 个 GPU。")
        num_gpus = default_num_gpus
    # GPU indices, asked one at a time.
    selected_gpus = []
    for i in range(num_gpus):
        # First GPU still free; num_gpus <= len(available) keeps one available
        # as long as the reported indices are distinct.
        default_gpu_idx = next(
            (idx for idx in available if idx not in selected_gpus), available[0])
        answer = input(f"请输入要使用的第 {i + 1} 个 GPU 的编号 (默认为 {default_gpu_idx}): ")
        try:
            gpu_idx = int(answer) if answer else default_gpu_idx
        except ValueError:
            gpu_idx = None
        if gpu_idx not in available or gpu_idx in selected_gpus:
            print(f"无效输入,使用默认 GPU {default_gpu_idx}")
            gpu_idx = default_gpu_idx
        selected_gpus.append(gpu_idx)
    # Return the selection in "0,1" form together with the count.
    gpu_list_str = ','.join(map(str, selected_gpus))
    print(f"\n已选择 GPU: {gpu_list_str},共 {num_gpus} 块 GPU")
    return gpu_list_str, num_gpus
# Pick the training schedule parameters
def select_schedule():
    """
    Ask for the training length, the number of checkpoints and the log
    interval, then generate the matching schedule config file.

    Defaults: train_time_k=40, check_num=10, loggerhook_interval=50. Empty
    input selects the default; non-numeric or non-positive input prints a
    notice and also selects the default (a zero check_num would otherwise
    divide by zero below).

    :return: tuple (schedule file name, train_time_k)
    """
    def _ask_positive_int(prompt, default, fallback_message):
        # One prompt: empty -> default, invalid or <= 0 -> notice + default.
        raw = input(prompt).strip()
        if not raw:
            return default
        try:
            value = int(raw)
        except ValueError:
            value = 0
        if value <= 0:
            print(fallback_message)
            return default
        return value

    train_time_k = _ask_positive_int("请输入训练次数k为单位默认40k", 40, "无效输入,使用默认训练时间 40k")
    # The fallback is 10, matching the prompt and the documented default.
    check_num = _ask_positive_int("请输入检查点数量默认10", 10, "无效输入,使用默认检查点数量 10")
    loggerhook_interval = _ask_positive_int("请输入日志间隔默认50次迭代", 50, "无效输入,使用默认日志间隔 50 次迭代")
    # Fraction of the run between two validations, and the same expressed in iterations.
    val_proportion = 1 / check_num
    # At least 1 so that a very large check_num cannot yield an interval of 0.
    checkpoint_interval = max(1, int(train_time_k * 1000 * val_proportion))
    # File name and location of the generated schedule config.
    schedule_file_name = f'schedule_{train_time_k}k_check_{checkpoint_interval}.py'
    output_configs_base_schedules_schedules_Timek = os.path.join('./configs/_base_/schedules/', schedule_file_name)
    generate_times_configs_base_schedules_schedule_file(
        output_file=output_configs_base_schedules_schedules_Timek,
        train_time_k=train_time_k,
        val_proportion=val_proportion,
        loggerhook_interval=loggerhook_interval
    )
    return schedule_file_name, train_time_k
# Run the chosen generator script to produce the algorithm config
def run_alg_to_gen_alg(relative_alg_path, mean, std, alg_name, dataset_file_name, num_classes, GPU_num, schedule_file_name, train_time_k):
    """
    Launch ``python <relative_alg_path>`` as a child process, passing the
    selected dataset, GPU and schedule settings as command-line options.

    A non-zero exit status of the child is reported on stdout and otherwise
    ignored; the function always returns None.
    """
    # (flag, value) pairs in the order they appear on the command line.
    options = (
        ('--alg_name', alg_name),
        ('--dataset_file_name', dataset_file_name),
        ('--mean', str(mean)),
        ('--std', str(std)),
        ('--dataset_num_classes', str(num_classes)),
        ('--GPU_num', str(GPU_num)),
        ('--schedule_file_name', schedule_file_name),
        ('--train_time_k', str(train_time_k)),
    )
    cmd = ['python', relative_alg_path]
    for flag, value in options:
        cmd.extend((flag, value))
    # Echo the command before running it.
    print(f"\n正在运行: \033[33m{' '.join(cmd)}\033[0m")
    try:
        subprocess.run(cmd, check=True)
    except subprocess.CalledProcessError as e:
        print(f"\033[91m算法生成器 {relative_alg_path} 运行失败,错误信息: {e}\033[0m")
    else:
        print(f"\033[92m算法生成器 {relative_alg_path} 运行成功。\033[0m")
if __name__ == '__main__':
    ########### 1. Settings: dataset record file, generator directory, work dir ###########
    train_parameter_dir = './My_All_In_One/1_Data_Parameter'
    all_data_record_json = "All_Data_Record.json"
    all_data_record_file = os.path.join(train_parameter_dir, all_data_record_json)  # dataset record file
    alg_directory = './My_All_In_One/2_Alg_Program'  # directory holding the algorithm config generators
    work_dir_base = './work_dirs'  # base directory for training outputs
    ########### 2. Load the dataset records and pick a dataset ###########
    print(f"\033[36m{'='*10} 一、选择训练数据集 {'='*10}\033[0m")
    data_record_json_data = load_json_file(file_name = all_data_record_file)
    if not data_record_json_data:
        # Nothing usable was loaded; the steps below cannot work without dataset records.
        raise SystemExit(f"\033[91m无法加载数据集信息: {all_data_record_file}\033[0m")
    all_data_record = process_data_record_json_data(json_data = data_record_json_data)
    dataset_file_name, num_classes, mean, std = select_dataset(all_data_record = all_data_record)
    ########### 3. Pick the GPUs ###########
    print(f"\033[36m{'='*10} 二、选择训练GPU {'='*10}\033[0m")
    gpu_list_str, GPU_num = select_GPU()
    ########### 4. Pick the training schedule ###########
    print(f"\033[36m{'='*10} 三、选择训练批次信息 {'='*10}\033[0m")
    schedule_file_name, train_time_k = select_schedule()
    # Drop the ".py" extension. str.rstrip('.py') would strip any trailing
    # '.', 'p' or 'y' characters rather than the suffix, so splitext is used.
    schedule_file_name = os.path.splitext(schedule_file_name)[0]
    ########### 5. Pick the algorithm ###########
    print(f"\033[36m{'='*10} 四、选择训练算法 {'='*10}\033[0m")
    selected_alg, relative_alg_path = select_alg(alg_directory=alg_directory)
    # Same reason as above: "unet_copy.py".rstrip(".py") would give "unet_co".
    alg_name = os.path.splitext(selected_alg)[0]
    ########### 6. Run the selected generator ###########
    print(f"\033[36m{'='*5} 生成训练算法配置 {'='*5}\033[0m")
    run_alg_to_gen_alg(relative_alg_path=relative_alg_path, alg_name=alg_name, dataset_file_name=dataset_file_name, num_classes=num_classes, GPU_num=GPU_num, mean=mean, std=std,schedule_file_name=schedule_file_name, train_time_k=train_time_k)
    # The generator's result is read back from "_temp_.txt" (presumably written
    # by the generator script - confirm). It is missing when the generator failed.
    temp_info_file = "_temp_.txt"
    if not os.path.exists(temp_info_file):
        raise SystemExit(f"\033[91m未找到算法生成器输出文件 {temp_info_file},无法继续。\033[0m")
    with open(temp_info_file, "r") as file:
        data = json.load(file)
    alg_file_name = data['alg_infos']['alg_file_name']
    alg_file_pth = data['alg_infos']['alg_file_pth']
    os.remove(temp_info_file)
    ########### 7. Print the resulting work dir and training command ###########
    # A timestamp makes every work dir unique.
    data_now = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    work_dir_file = os.path.join(work_dir_base, f"{dataset_file_name}-Class_{num_classes}-Alg_{alg_name}-AlgName_{alg_file_name}-Card_{GPU_num}-Data_{data_now}")
    # Reset the colour at the end so the terminal does not stay cyan.
    print(f"\033[36m训练指令python tools/train.py {alg_file_pth} --work-dir {work_dir_file}\033[0m")

View File

@@ -0,0 +1,357 @@
import os, json, subprocess
from Initial_Schedule_Program.Initial_Train_Gen_configs_base_schedules_schedule_XXk import generate_times_configs_base_schedules_schedule_file
from Initial_Schedule_Program.Initial_Train_Gen_configs_base_schedules_schedule_XXe import generate_epochs_configs_base_schedules_schedule_file
from datetime import datetime
from Initial_Alg_Program.Initial_Alg_Gen_Tool import format_all_data, write_config_to_file, get_gpu_info
from Initial_Alg_Program.Initial_Alg_Gen_0_base_ import generate_base_config
from Initial_Alg_Program.Initial_Alg_Gen_1_norm_cfg import generate_norm_cfg
from Initial_Alg_Program.Initial_Alg_Gen_2_data_preprocessor import generate_data_preprocessor
from Initial_Alg_Program.Initial_Alg_Gen_3_model import generate_model_pretrained, generate_model_backbone, generate_model_data_preprocessor, generate_model_decode_head, generate_model_auxiliary_head, generate_model_train_cfg, generate_model_test_cfg
def load_json_file(file_name):
    """
    Read a UTF-8 JSON file and return its parsed content.

    :param file_name: path of the JSON file to read
    :return: the decoded JSON value (dict or list), or None if the file is
        missing, is not valid JSON, or any other error occurs; the reason is
        printed in red in every failure case
    """
    try:
        with open(file_name, 'r', encoding='utf-8') as handle:
            return json.load(handle)
    except FileNotFoundError:
        problem = f"Error: File {file_name} not found."
    except json.JSONDecodeError:
        problem = f"Error: Failed to decode JSON from {file_name}."
    except Exception as e:
        problem = f"An error occurred while reading {file_name}: {str(e)}"
    # Only reached on failure: report it and signal it with None.
    print(f"\033[91m{problem}\033[0m")
    return None
def process_data_record_json_data(json_data):
    """
    Copy the fields this tool needs from every dataset entry into a new dict.

    :param json_data: dict loaded from the dataset record JSON, mapping
        dataset name -> info dict
    :return: dict mapping dataset name -> dict holding exactly 'classes',
        'palette', 'palette_num', 'mean', 'std' and 'train_imgs_num'
    :raises KeyError: if an entry lacks one of those fields
    """
    wanted_fields = ('classes', 'palette', 'palette_num', 'mean', 'std', 'train_imgs_num')
    return {
        dataset_name: {field: dataset_info[field] for field in wanted_fields}
        for dataset_name, dataset_info in json_data.items()
    }
# Pick the dataset to work on
def select_dataset(all_data_record):
    """
    Show a numbered menu of the known datasets and ask the user to pick one.

    :param all_data_record: dict mapping dataset name -> info dict; the entries
        'palette_num', 'mean', 'std' and 'train_imgs_num' of the chosen dataset are read
    :return: tuple (dataset_file_name, palette_num, mean, std, train_imgs_num)
    """
    names = list(all_data_record)
    print("选择可用数据集:")
    for number, name in enumerate(names, start=1):
        print(f"{number}. {name} - {all_data_record[name]['palette_num']}")
    # Keep prompting until a number inside the menu range is entered.
    while True:
        try:
            index = int(input("请选择数据集编号(输入数字):")) - 1
        except ValueError:
            print("无效输入,请输入数字。")
            continue
        if index < 0 or index >= len(names):
            print(f"输入的编号无效,请输入 1 到 {len(names)} 之间的数字。")
            continue
        dataset_file_name = names[index]
        break
    # Collect the properties of the chosen dataset that the caller needs.
    chosen = all_data_record[dataset_file_name]
    palette_num = chosen['palette_num']
    mean = chosen['mean']
    std = chosen['std']
    train_imgs_num = chosen['train_imgs_num']
    print(f" 已选择数据集: {dataset_file_name} ,其对应的分类数为: {palette_num}")
    return dataset_file_name, palette_num, mean, std, train_imgs_num
# Pick the algorithm generator script
def select_alg(alg_directory):
    """
    List the Python files found directly inside ``alg_directory`` and let the
    user pick one from a numbered menu.

    :param alg_directory: directory scanned (non-recursively) for ``*.py`` files
    :return: tuple (selected file name, ``os.path.join(alg_directory, file name)``)
    :raises FileNotFoundError: if the directory contains no ``.py`` file
    """
    # os.listdir() returns entries in arbitrary order; sort so that the menu
    # numbering is the same on every run and every machine.
    algorithms = sorted(f for f in os.listdir(alg_directory) if f.endswith('.py'))
    if not algorithms:
        # With an empty menu no answer could ever be valid and the prompt
        # below would loop forever, so fail early instead.
        raise FileNotFoundError(f"No algorithm (.py) files found in {alg_directory}")
    print("选择可用算法:")
    for i, alg in enumerate(algorithms):
        print(f"{i + 1}. {alg}")
    # Keep prompting until a number inside the menu range is entered.
    while True:
        try:
            selection = int(input("请选择算法编号(输入数字):")) - 1
        except ValueError:
            print("无效输入,请输入数字。")
            continue
        if 0 <= selection < len(algorithms):
            selected_alg = algorithms[selection]
            break
        print(f"输入的编号无效,请输入 1 到 {len(algorithms)} 之间的数字。")
    # Path of the chosen script, relative to wherever alg_directory is relative to.
    relative_alg_path = os.path.join(alg_directory, selected_alg)
    print(f" 已选择算法: {selected_alg} {relative_alg_path}")
    return selected_alg, relative_alg_path
# Pick the GPUs used for training
def select_GPU():
    """
    Ask the user how many GPUs to use and which ones.

    The GPU list comes from ``get_gpu_info()``, which is iterated as
    ``(index, free_memory)`` pairs. Empty input selects the default. The
    default GPU for each prompt is the first listed GPU that has not been
    chosen yet, so an empty, invalid or duplicate answer never produces a
    duplicated index such as "0,0".

    :return: tuple (comma separated GPU indices such as "0,1", number of GPUs),
        or (None, 0) when ``get_gpu_info()`` reports nothing
    """
    gpu_info = get_gpu_info()
    if not gpu_info:
        return None, 0
    print("可用的 GPU 列表:")
    for idx, mem_free in gpu_info:
        print(f"GPU {idx}: 剩余显存 {mem_free} MB")
    available = [idx for idx, _ in gpu_info]
    default_num_gpus = 1
    # Number of GPUs: empty input -> 1, anything invalid or out of range -> 1.
    try:
        num_gpus = input(f"请选择使用的 GPU 个数 (1-{len(gpu_info)}, 默认为 1): ")
        num_gpus = int(num_gpus) if num_gpus else default_num_gpus
        if not (1 <= num_gpus <= len(gpu_info)):
            print(f"输入无效,使用默认值 1 个 GPU。")
            num_gpus = default_num_gpus
    except ValueError:
        print(f"无效输入,使用默认值 1 个 GPU。")
        num_gpus = default_num_gpus
    # GPU indices, asked one at a time.
    selected_gpus = []
    for i in range(num_gpus):
        # First GPU still free; num_gpus <= len(available) keeps one available
        # as long as the reported indices are distinct.
        default_gpu_idx = next(
            (idx for idx in available if idx not in selected_gpus), available[0])
        answer = input(f"请输入要使用的第 {i + 1} 个 GPU 的编号 (默认为 {default_gpu_idx}): ")
        try:
            gpu_idx = int(answer) if answer else default_gpu_idx
        except ValueError:
            gpu_idx = None
        if gpu_idx not in available or gpu_idx in selected_gpus:
            print(f"无效输入,使用默认 GPU {default_gpu_idx}")
            gpu_idx = default_gpu_idx
        selected_gpus.append(gpu_idx)
    # Return the selection in "0,1" form together with the count.
    gpu_list_str = ','.join(map(str, selected_gpus))
    print(f"\n已选择 GPU: {gpu_list_str},共 {num_gpus} 块 GPU")
    return gpu_list_str, num_gpus
# Pick the training schedule parameters
def select_schedule(train_imgs_num):
    """
    Let the user choose between an iteration-based and an epoch-based schedule,
    collect the parameters of the chosen mode and generate the schedule config.

    Empty input selects the default shown in each prompt. Non-numeric or
    non-positive numbers print a notice and also select the default (a zero
    check count would otherwise divide by zero).

    :param train_imgs_num: number of training images of the selected dataset;
        used only for the default log interval in Epoch mode (``train_imgs_num // 16``)
    :return: tuple (schedule file name, "Iteration" or "Epoch",
        training length: thousands of iterations or number of epochs)
    """
    def _ask_positive_int(prompt, default, fallback_message):
        # One prompt: empty -> default, invalid or <= 0 -> notice + default.
        raw = input(prompt).strip()
        if not raw:
            return default
        try:
            value = int(raw)
        except ValueError:
            value = 0
        if value <= 0:
            print(fallback_message)
            return default
        return value

    # 1. Training mode; an empty answer means Epoch mode.
    while True:
        mode = input("请选择训练模式 (1: Iteration, 2: Epoch) [默认: 2]: ").strip()
        if mode in ['1', '2', '']:
            mode = '2' if mode == '' else mode
            break
        print("无效输入,请输入 1 或 2。")
    # --- Mode 1: iteration-based schedule ---
    if mode == '1':
        print("\n--- 您已选择 Iteration 模式 ---")
        train_time_or_epoch_k = _ask_positive_int("请输入训练次数 (k为单位, 默认40k): ", 40, "无效输入,使用默认训练时间 40k")
        check_num = _ask_positive_int("请输入总的验证/保存次数 (默认10): ", 10, "无效输入,使用默认次数 10")
        loggerhook_interval = _ask_positive_int("请输入日志间隔 (默认50次迭代): ", 50, "无效输入,使用默认日志间隔 50")
        val_proportion = 1 / check_num
        # Interval in iterations between validations/checkpoints, at least 1.
        interval = max(1, int(train_time_or_epoch_k * 1000 * val_proportion))
        schedule_file_name = f'schedule_{train_time_or_epoch_k}k_check_{interval}.py'
        output_path = os.path.join('./configs/_base_/schedules/', schedule_file_name)
        generate_times_configs_base_schedules_schedule_file(
            output_file=output_path,
            train_time_or_epoch_k=train_time_or_epoch_k,
            val_proportion=val_proportion,
            loggerhook_interval=loggerhook_interval
        )
        return schedule_file_name, "Iteration", train_time_or_epoch_k
    # --- Mode 2: epoch-based schedule ---
    print("\n--- 您已选择 Epoch 模式 ---")
    max_epochs = _ask_positive_int("请输入训练总轮数 (Epochs, 默认300): ", 300, "无效输入,使用默认轮数 300")
    val_interval = _ask_positive_int("请输入验证间隔的轮次数 (Epochs, 默认1): ", 1, "无效输入,使用默认验证间隔 1")
    checkpoint_interval = _ask_positive_int("请输入模型保存间隔的轮次数 (Epochs, 默认10): ", 10, "无效输入,使用默认保存间隔 10")
    # Default log interval derived from the dataset size; kept >= 1 so that
    # datasets with fewer than 16 images do not produce an interval of 0.
    loggerhook_interval_default = max(1, train_imgs_num // 16)
    loggerhook_interval = _ask_positive_int(
        f"请输入日志间隔 (默认{loggerhook_interval_default}次迭代): ",
        loggerhook_interval_default,
        f"无效输入,使用默认日志间隔 {loggerhook_interval_default}")
    schedule_file_name = f'schedule_{max_epochs}e_val{val_interval}_check{checkpoint_interval}.py'
    output_path = os.path.join('./configs/_base_/schedules/', schedule_file_name)
    # Epoch mode uses the epoch-based generator.
    generate_epochs_configs_base_schedules_schedule_file(
        output_file=output_path,
        max_epochs=max_epochs,
        val_interval=val_interval,
        checkpoint_interval=checkpoint_interval,
        loggerhook_interval=loggerhook_interval
    )
    return schedule_file_name, "Epoch", max_epochs
# Run the chosen generator script to produce the algorithm config
def run_alg_to_gen_alg(relative_alg_path, mean, std, alg_name, dataset_file_name, num_classes, GPU_num, schedule_file_name, train_type, train_time_or_epoch_k):
    """
    Launch ``python <relative_alg_path>`` as a child process, passing the
    selected dataset, GPU and schedule settings as command-line options.

    A non-zero exit status of the child is reported on stdout and otherwise
    ignored; the function always returns None.
    """
    # (flag, value) pairs in the order they appear on the command line.
    options = (
        ('--alg_name', alg_name),
        ('--dataset_file_name', dataset_file_name),
        ('--mean', str(mean)),
        ('--std', str(std)),
        ('--dataset_num_classes', str(num_classes)),
        ('--GPU_num', str(GPU_num)),
        ('--schedule_file_name', schedule_file_name),
        ('--train_type', train_type),
        ('--train_time_or_epoch_k', str(train_time_or_epoch_k)),
    )
    cmd = ['python', relative_alg_path]
    for flag, value in options:
        cmd.extend((flag, value))
    # Echo the command before running it.
    print(f"\n正在运行: \033[33m{' '.join(cmd)}\033[0m")
    try:
        subprocess.run(cmd, check=True)
    except subprocess.CalledProcessError as e:
        print(f"\033[91m算法生成器 {relative_alg_path} 运行失败,错误信息: {e}\033[0m")
    else:
        print(f"\033[92m算法生成器 {relative_alg_path} 运行成功。\033[0m")
if __name__ == '__main__':
    ########### 1. Settings: dataset record file, generator directory, work dir ###########
    train_parameter_dir = './My_All_In_One/1_Data_Parameter'
    all_data_record_json = "All_Data_Record.json"
    all_data_record_file = os.path.join(train_parameter_dir, all_data_record_json)  # dataset record file
    alg_directory = './My_All_In_One/2_Alg_Program'  # directory holding the algorithm config generators
    work_dir_base = '../DataSet_Public_outputs'  # base directory for training outputs
    ########### 2. Load the dataset records and pick a dataset ###########
    print(f"\033[36m{'='*10} 一、选择训练数据集 {'='*10}\033[0m")
    data_record_json_data = load_json_file(file_name = all_data_record_file)
    if not data_record_json_data:
        # load_json_file returns None on failure; the steps below cannot work
        # without dataset records.
        raise SystemExit(f"\033[91m无法加载数据集信息: {all_data_record_file}\033[0m")
    all_data_record = process_data_record_json_data(json_data = data_record_json_data)
    dataset_file_name, num_classes, mean, std, train_imgs_num = select_dataset(all_data_record = all_data_record)
    ########### 3. Pick the GPUs ###########
    print(f"\033[36m{'='*10} 二、选择训练GPU {'='*10}\033[0m")
    gpu_list_str, GPU_num = select_GPU()
    ########### 4. Pick the training schedule ###########
    print(f"\033[36m{'='*10} 三、选择训练批次信息 {'='*10}\033[0m")
    schedule_file_name, train_type, train_time_or_epoch_k = select_schedule(train_imgs_num)
    # Drop the ".py" extension. str.rstrip('.py') would strip any trailing
    # '.', 'p' or 'y' characters rather than the suffix, so splitext is used.
    schedule_file_name = os.path.splitext(schedule_file_name)[0]
    ########### 5. Pick the algorithm ###########
    print(f"\033[36m{'='*10} 四、选择训练算法 {'='*10}\033[0m")
    selected_alg, relative_alg_path = select_alg(alg_directory=alg_directory)
    # Same reason as above: "unet_copy.py".rstrip(".py") would give "unet_co".
    alg_name = os.path.splitext(selected_alg)[0]
    ########### 6. Run the selected generator ###########
    print(f"\033[36m{'='*5} 生成训练算法配置 {'='*5}\033[0m")
    run_alg_to_gen_alg(relative_alg_path=relative_alg_path, alg_name=alg_name, dataset_file_name=dataset_file_name, num_classes=num_classes, GPU_num=GPU_num, mean=mean, std=std,schedule_file_name=schedule_file_name, train_type = train_type, train_time_or_epoch_k=train_time_or_epoch_k)
    # The generator's result is read back from "_temp_.txt" (presumably written
    # by the generator script - confirm). run_alg_to_gen_alg only reports a
    # failed generator, so the file may be missing here.
    temp_info_file = "_temp_.txt"
    if not os.path.exists(temp_info_file):
        raise SystemExit(f"\033[91m未找到算法生成器输出文件 {temp_info_file},无法继续。\033[0m")
    with open(temp_info_file, "r") as file:
        data = json.load(file)
    alg_file_name = data['alg_infos']['alg_file_name']
    alg_file_pth = data['alg_infos']['alg_file_pth']
    os.remove(temp_info_file)
    ########### 7. Print the resulting work dir and training command ###########
    # A timestamp makes every work dir unique.
    data_now = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    work_dir_file = os.path.join(work_dir_base, f"{dataset_file_name}-Class_{num_classes}-Alg_{alg_name}-AlgName_{alg_file_name}-Card_{GPU_num}-Data_{data_now}")
    # Reset the colour at the end so the terminal does not stay cyan.
    print(f"\033[36m训练指令python tools/train.py {alg_file_pth} --work-dir {work_dir_file}\033[0m")

View File

@@ -0,0 +1,78 @@
import os
import sys
def find_specific_epochs(root_folder):
    """
    Walk ``root_folder`` recursively and collect every file named
    ``epoch_<N>.pth`` whose number N is not a multiple of 10.

    :param root_folder: directory to search
    :return: list of absolute paths of the matching files, in os.walk order
    """
    prefix, suffix = 'epoch_', '.pth'
    matching_files = []
    for folder, _subdirs, names in os.walk(root_folder):
        for name in names:
            if not (name.startswith(prefix) and name.endswith(suffix)):
                continue
            # The part between "epoch_" and ".pth" must consist of digits only.
            digits = name[len(prefix):-len(suffix)]
            if not digits.isdigit():
                continue
            try:
                epoch_number = int(digits)
            except (ValueError, IndexError):
                # str.isdigit() accepts some characters int() rejects (e.g. superscripts).
                continue
            if epoch_number % 10:
                matching_files.append(os.path.abspath(os.path.join(folder, name)))
    return matching_files
if __name__ == "__main__":
    # Resolve the search root relative to this script, so the current working
    # directory does not matter.
    script_dir = os.path.dirname(os.path.abspath(__file__))
    # V1
    target_directory = os.path.abspath(os.path.join(script_dir, '../../Hardisk'))
    # V2
    # target_directory = os.path.abspath(os.path.join(script_dir, '../../DataSet_Public_outputs'))
    # Step 1: collect the candidate files; stop when there is nothing to delete.
    found_paths = find_specific_epochs(target_directory)
    if not found_paths:
        print(f"'{os.path.abspath(target_directory)}' 及其子目录中没有找到符合条件的文件。")
        sys.exit(0)
    # Step 2: show every candidate and warn that deletion is permanent.
    print("找到了以下符合条件的文件:")
    print("\n".join(found_paths))
    separator = "=" * 50
    print("\n" + separator)
    print("警告:接下来的操作将永久删除以上列出的所有文件!")
    print(separator + "\n")
    # Step 3: ask for confirmation; Ctrl+C aborts with exit status 1.
    try:
        confirm = input(f"您确定要删除这 {len(found_paths)} 个文件吗? (请输入 'yes' 进行确认,输入其他任何内容则取消): ")
    except KeyboardInterrupt:
        print("\n\n操作被用户中断。")
        sys.exit(1)
    # Step 4: only the answer "yes" (case-insensitive, surrounding blanks
    # ignored) deletes anything.
    if confirm.lower().strip() != 'yes':
        print("\n操作已取消,没有文件被删除。")
    else:
        print("\n正在开始删除文件...")
        deleted_count = 0
        error_count = 0
        for path in found_paths:
            try:
                os.remove(path)
            except OSError as e:
                print(f"删除失败: {path} (原因: {e})")
                error_count += 1
            else:
                print(f"已删除: {path}")
                deleted_count += 1
        print(f"\n操作完成。成功删除 {deleted_count} 个文件,{error_count} 个文件删除失败。")

View File

@@ -0,0 +1,99 @@
#!/bin/bash
# --- About this script ---
# Purpose: use rsync to copy result folders whose names match fixed patterns
#          from the source directory into per-dataset folders under the
#          destination directory.
#   - All paths are derived from the script's own location, so it can be run
#     from any working directory.
#   - Each task checks that its source exists first and is skipped otherwise.
#
# rsync options used:
#   -a: archive mode, preserves file attributes.
#   -v: verbose output.
#   -h: human-readable sizes.
#   --progress: show transfer progress.
#   --stats: print transfer statistics.
# --- Path configuration ---
# Absolute directory containing this script.
SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
# If the directory could not be resolved, SCRIPT_DIR is empty and both paths
# below would silently be resolved from "/"; stop instead.
if [[ -z "$SCRIPT_DIR" ]]; then
  echo "错误: 无法解析脚本所在目录,已中止。" >&2
  exit 1
fi
# Source and destination roots, relative to the script directory.
SOURCE_DIR="$SCRIPT_DIR/../../DataSet_Public_outputs"
DEST_DIR="$SCRIPT_DIR/../../Hardisk"
# --- Main program ---
echo "脚本执行目录已锁定为: $SCRIPT_DIR"
echo "解析后的源目录 (Source): $SOURCE_DIR"
echo "解析后的目标根目录 (Destination): $DEST_DIR"
echo "========================================"
echo "开始执行 rsync 同步任务(带安全检查)..."
echo ""
# --- Task helper (avoids repeating the same steps for every task) ---
#######################################
# Sync everything matching one glob pattern into a destination directory.
# Arguments:
#   $1 - task name, used only in log messages (e.g. 1_cholecseg8k)
#   $2 - glob pattern selecting the source files/directories
#   $3 - destination directory (created if missing)
# Outputs:
#   progress messages and rsync output on stdout, failures on stderr
# Returns:
#   0 on success or when nothing matched (task skipped);
#   the non-zero status of mkdir/rsync when the copy failed
#######################################
handle_sync_task() {
  local task_name="$1"
  local src_pattern="$2"
  local dest_dir="$3"
  local rc=0
  echo "--> 正在检查任务: $task_name"
  # Expand the pattern with IFS emptied: pathname expansion still happens,
  # but word splitting does not, so a space anywhere in the path no longer
  # breaks the pattern into pieces.
  local IFS=
  # shellcheck disable=SC2206 # globbing of the pattern is intended here
  local sources=($src_pattern)
  # When nothing matches, bash leaves the pattern itself as the only element,
  # and a file with that literal name normally does not exist.
  if [ -e "${sources[0]}" ]; then
    echo " 源文件已找到,准备同步..."
    echo " 目标目录: $dest_dir"
    # Create the destination directory if needed, then copy.
    if mkdir -p -- "$dest_dir"; then
      rsync -avh --progress --stats "${sources[@]}" "$dest_dir/" || rc=$?
    else
      rc=$?
    fi
    # Report success only when the copy really succeeded.
    if [ "$rc" -eq 0 ]; then
      echo "--> 任务 '$task_name' 同步完成。"
    else
      echo "--> 错误: 任务 '$task_name' 同步失败 (退出码: $rc)。" >&2
    fi
  else
    echo " 警告: 源路径 '$src_pattern' 未匹配到任何文件,已跳过此任务。"
  fi
  echo "----------------------------------------"
  return "$rc"
}
# --- Task list ---
# One entry per dataset, written as
#   "<task name>|<source glob below SOURCE_DIR>|<destination folder below DEST_DIR>"
sync_tasks=(
  "1_cholecseg8k|1_cholecseg8k-Class_13-Alg*|1_CholecSeg8k-13Type-1920x1080_outputs-MMSeg"
  "2_autolaparo|2_autolaparo-Class_10-Alg*|2_AutoLaparo-10Type-1920x1080_outputs-MMSeg"
  "3_1_endovis_2017|3_1_endovis_2017-Class_8-Alg*|3_1_Endovis_2017-8Type-512x512_outputs-MMSeg"
  "3_2_endovis_2018|3_2_endovis_2018-Class_8-Alg*|3_2_Endovis_2018-8Type-512x512_outputs-MMSeg"
  "4_dresden|4_dresden-Class_11-Alg*|4_Dresden-11Type-512x512_outputs-MMSeg"
)
# Run the tasks in the listed order; the glob is passed on unexpanded and is
# expanded inside handle_sync_task.
for task_entry in "${sync_tasks[@]}"; do
  IFS='|' read -r task_label src_glob dest_folder <<< "$task_entry"
  handle_sync_task \
    "$task_label" \
    "$SOURCE_DIR/$src_glob" \
    "$DEST_DIR/$dest_folder"
done
echo "========================================"
echo "所有 rsync 同步任务已执行完毕!"

View File

@@ -0,0 +1,427 @@
import os
import glob
import logging
import argparse
import re
import subprocess
import csv
from typing import Dict, Optional, Tuple, List
import time
import numpy as np
import torch
from mmengine import Config
from mmengine.model.utils import revert_sync_batchnorm
from mmengine.registry import init_default_scope
from mmengine.runner import Runner, load_checkpoint
from mmseg.registry import MODELS
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
# --- Helper functions ---
def find_model_files(model_dir: str):
    """
    Locate the config file and the checkpoint to evaluate in a model directory.

    The config is the first ``*.py`` file returned by glob. The checkpoint is
    ``best.pth`` when it exists, otherwise the ``epoch_<N>.pth`` file with the
    largest N.

    Args:
        model_dir (str): root directory of one trained model.
    Returns:
        Optional[Dict[str, str]]: dict with the keys 'config' and 'checkpoint',
        or None (after logging a warning) when either file cannot be found.
    """
    py_files = glob.glob(os.path.join(model_dir, '*.py'))
    if not py_files:
        logging.warning(f"在目录 {model_dir} 中未找到配置文件 (.py)。")
        return None
    config_path = py_files[0]
    best_checkpoint = os.path.join(model_dir, 'best.pth')
    if os.path.exists(best_checkpoint):
        return {'config': config_path, 'checkpoint': best_checkpoint}
    # No best.pth: fall back to the epoch checkpoint with the highest number.
    candidates = glob.glob(os.path.join(model_dir, 'epoch_*.pth'))
    if not candidates:
        logging.warning(f"在目录 {model_dir} 中未找到 'best.pth''epoch_*.pth' 检查点文件。")
        return None
    numbered = []
    for candidate in candidates:
        found = re.search(r'epoch_(\d+)\.pth', os.path.basename(candidate))
        if found:
            numbered.append((int(found.group(1)), candidate))
    if not numbered:
        logging.warning(f"在目录 {model_dir} 中无法确定最新的检查点文件。")
        return None
    # max() returns the first of several equal maxima, i.e. the earliest file in glob order.
    _, newest_checkpoint = max(numbered, key=lambda pair: pair[0])
    return {'config': config_path, 'checkpoint': newest_checkpoint}
def find_model_config(model_dir: str):
    """
    Return the config file (.py) found directly inside a model directory.

    Args:
        model_dir (str): root directory of one trained model.
    Returns:
        Optional[str]: path of the first ``*.py`` file returned by glob, or
        None (after logging a warning) when the directory contains none.
    """
    matches = glob.glob(os.path.join(model_dir, '*.py'))
    if matches:
        return matches[0]
    logging.warning(f"在目录 {model_dir} 中未找到配置文件 (.py)。")
    return None
def get_shape_from_path(path: str):
    """
    Extract a "<width>x<height>" resolution from the last component of a path.

    Args:
        path (str): path of a dataset folder, e.g. ".../Name-1920x1080_outputs".
    Returns:
        Optional[Tuple[int, int]]: (height, width) - note the swapped order,
        which is the H W order the FLOPs tool is given - or None when the last
        path component contains no such pattern.
    """
    folder_name = os.path.basename(path)
    found = re.search(r'(\d+)x(\d+)', folder_name)
    if found is None:
        return None
    width, height = (int(part) for part in found.groups())
    return (height, width)  # H, W
def get_flops_and_params(config_path: str, shape: Tuple[int, int]) -> Optional[Dict[str, str]]:
    """
    Run ``tools/analysis_tools/get_flops.py`` on a config and parse the
    "Flops: ..." and "Params: ..." values from its standard output.

    The FLOPs value is converted to giga units (T, M and K are rescaled; a
    value without unit is taken as G) and returned as "<number> G"; the
    parameter count is returned exactly as printed by the tool.

    Args:
        config_path (str): path of the model's .py config file.
        shape (Tuple[int, int]): input image size as (H, W).
    Returns:
        Optional[Dict[str, str]]: dict with the keys 'params' and 'flops', or
        None when the tool is missing, fails, or prints no parsable values.
    """
    # The tool path is relative, so the current working directory matters.
    tool_script = 'tools/analysis_tools/get_flops.py'
    if not os.path.exists(tool_script):
        logging.error(f"错误: '{tool_script}' 未找到。请确保在 MMSegmentation 项目的根目录下运行此脚本。")
        return None
    height_arg, width_arg = str(shape[0]), str(shape[1])
    command = ['python', tool_script, config_path, '--shape', height_arg, width_arg]
    logging.info(f"执行命令: {' '.join(command)}")
    try:
        completed = subprocess.run(command, capture_output=True, text=True, check=True, encoding='utf-8')
        output = completed.stdout
        # Matches values such as "Flops: 0.118T".
        flops_match = re.search(r"Flops:\s*([0-9.]+\s*[TGMK]?)", output)
        params_match = re.search(r"Params:\s*([0-9.]+\s*[TGMK]?)", output)
        if not (flops_match and params_match):
            logging.warning(f"❌ 无法从命令输出中解析 FLOPs 或 Params。请检查以下输出内容:\n---\n{output}\n---")
            return None
        raw_flops_str = flops_match.group(1).strip()
        params = params_match.group(1).strip()
        # --- Unit conversion to G ---
        last_char = raw_flops_str[-1]
        unit = last_char.upper() if last_char.isalpha() else 'G'
        # (multiplier, divisor) per unit; anything else is treated as G already.
        multiplier, divisor = {'T': (1000, 1), 'M': (1, 1000), 'K': (1, 1_000_000)}.get(unit, (1, 1))
        try:
            value_in_g = float(raw_flops_str.rstrip('TGMKtgmk').strip()) * multiplier / divisor
            # The :g format drops trailing zeros.
            flops = f"{value_in_g:g} G"
        except ValueError:
            # Not a number after all: keep the text as printed by the tool.
            flops = raw_flops_str
        logging.info(f"✅ 解析成功: FLOPs={flops} (原始值: {raw_flops_str}), Params={params}")
        return {'flops': flops, 'params': params}
    except subprocess.CalledProcessError as e:
        logging.error(f"执行 get_flops.py 时出错。返回码: {e.returncode}")
        logging.error(f"错误输出 (stderr):\n---\n{e.stderr}\n---")
        return None
    except FileNotFoundError:
        logging.error("错误: 'python' 命令未找到。请确保 Python 环境已正确配置。")
        return None
def get_benchmark_stats(config_path: str, checkpoint_path: str, repeat_times: int) -> Optional[Dict[str, float]]:
    """
    Measure inference speed, following the logic of mmseg's benchmark.py.

    Every run builds the test dataloader and the model afresh, skips the first
    5 batches as warm-up and times up to 100 batches with batch size 1. When
    the test set has fewer than 100 batches, the FPS is computed from the
    batches that were actually timed instead of producing no result.

    Args:
        config_path (str): Path to the model's .py config file.
        checkpoint_path (str): Path to the model's .pth checkpoint file.
        repeat_times (int): Number of times to run the benchmark.
    Returns:
        Optional[Dict[str, float]]: Dict containing 'average_fps' and
        'fps_variance' over the runs, or None on failure.
    """
    try:
        cfg = Config.fromfile(config_path)
        init_default_scope(cfg.get('default_scope', 'mmseg'))
        torch.backends.cudnn.benchmark = False
        cfg.model.pretrained = None
        cfg.test_dataloader.batch_size = 1  # Crucial for FPS measurement
        cfg.model.train_cfg = None
        use_cuda = torch.cuda.is_available()
        num_warmup = 5
        total_iters = 100  # Reduced from 200 for faster script execution
        overall_fps_list = []
        for time_index in range(repeat_times):
            logging.info(f"--- Starting Benchmark Run {time_index + 1}/{repeat_times} ---")
            data_loader = Runner.build_dataloader(cfg.test_dataloader)
            model = MODELS.build(cfg.model)
            load_checkpoint(model, checkpoint_path, map_location='cpu')
            if use_cuda:
                model = model.cuda(0)
            else:
                logging.warning("CUDA is not available. Benchmarking on CPU, results may be slow.")
            model = revert_sync_batchnorm(model)
            model.eval()
            pure_inf_time = 0
            timed_iters = 0
            for i, data in enumerate(data_loader):
                data = model.data_preprocessor(data, True)
                # Synchronize around the forward pass so that queued GPU work
                # is included in the measured time.
                if use_cuda:
                    torch.cuda.synchronize()
                start_time = time.perf_counter()
                with torch.no_grad():
                    model(data['inputs'], data['data_samples'], mode='predict')
                if use_cuda:
                    torch.cuda.synchronize()
                elapsed = time.perf_counter() - start_time
                if i >= num_warmup:
                    pure_inf_time += elapsed
                    timed_iters += 1
                if (i + 1) == total_iters:
                    break
            # Also covers loaders that ran out before total_iters batches.
            if timed_iters == 0 or pure_inf_time <= 0:
                logging.warning(
                    f"Run {time_index + 1}: test set has no batches left after "
                    f"{num_warmup} warm-up iterations; run skipped.")
                continue
            fps = timed_iters / pure_inf_time
            logging.info(f"Run {time_index + 1} Overall FPS: {fps:.2f} img/s")
            overall_fps_list.append(fps)
        if not overall_fps_list:
            logging.error("Benchmark failed to produce any results.")
            return None
        avg_fps = round(np.mean(overall_fps_list), 2)
        fps_var = round(np.var(overall_fps_list), 4)
        logging.info(f"✅ Benchmark Complete: Average FPS={avg_fps}, Variance={fps_var}")
        return {'average_fps': avg_fps, 'fps_variance': fps_var}
    except Exception as e:
        logging.error(f"An exception occurred during benchmarking: {e}")
        return None
# --- Main function ---
def main(args):
    """
    Entry point: interactively pick a dataset and one or all of its trained
    models, collect Params / FLOPs / FPS for each model and write one CSV.

    A model is recorded as soon as either the FLOPs/Params analysis or the FPS
    benchmark succeeded; the columns of the failed part contain 'N/A'.

    Args:
        args: parsed command-line arguments providing ``input_dir``,
            ``output_dir`` and ``repeat_times``.
    """
    input_root = args.input_dir
    output_root = args.output_dir
    # --- Interactive selection (two-level menu) ---
    if not os.path.isdir(input_root):
        logging.error(f"输入目录不存在: {input_root}")
        return
    # 1. Whitelist of dataset folders that are offered for selection.
    VALID_DATASET_FOLDERS = [
        '1_CholecSeg8k-13Type-1920x1080_outputs-MMSeg',
        '2_AutoLaparo-10Type-1920x1080_outputs-MMSeg',
        '3_1_Endovis_2017-8Type-512x512_outputs-MMSeg',
        '3_2_Endovis_2018-8Type-512x512_outputs-MMSeg',
        '4_Dresden-11Type-512x512_outputs-MMSeg'
    ]
    # Output folder for each dataset key (the model folder name up to the first '-').
    OUTPUT_FOLDER_BY_KEY = {
        '1_cholecseg8k': '1_CholecSeg8k-13Type-1920x1080_outputs-MMSeg',
        '2_autolaparo': '2_AutoLaparo-10Type-1920x1080_outputs-MMSeg',
        '3_1_endovis_2017': '3_1_Endovis_2017-8Type-512x512_outputs-MMSeg',
        '3_2_endovis_2018': '3_2_Endovis_2018-8Type-512x512_outputs-MMSeg',
        '4_dresden': '4_Dresden-11Type-512x512_outputs-MMSeg'
    }
    # 2. Whitelisted folders that actually exist below the input root.
    existing_dataset_dirs = [
        os.path.join(input_root, d) for d in VALID_DATASET_FOLDERS
        if os.path.isdir(os.path.join(input_root, d))
    ]
    if not existing_dataset_dirs:
        logging.error(f"在输入目录 {input_root} 中未找到任何有效的数据集文件夹。")
        return
    # 3. First menu: choose the dataset.
    dataset_menu = {str(i + 1): path for i, path in enumerate(existing_dataset_dirs)}
    print("\n" + "="*50)
    print("--- 步骤 1: 请选择要处理的数据集 ---")
    for key, path in dataset_menu.items():
        print(f"{key}: {os.path.basename(path)}")
    print("="*50)
    choice1 = input("请输入数据集编号并按回车键: ").strip()
    if choice1 not in dataset_menu:
        logging.error("无效的数据集选择,程序已退出。")
        return
    selected_dataset_dir = dataset_menu[choice1]
    logging.info(f"您已选择数据集: [{os.path.basename(selected_dataset_dir)}]")
    # 4. Every sub-directory of the dataset folder is one trained model.
    alg_dirs = sorted([
        d for d in glob.glob(os.path.join(selected_dataset_dir, '*')) if os.path.isdir(d)
    ])
    if not alg_dirs:
        logging.warning(f"{os.path.basename(selected_dataset_dir)} 中未发现任何算法子文件夹。程序退出。")
        return
    # 5. Second menu: choose one model, or 0 for all of them.
    alg_map = {str(i + 1): path for i, path in enumerate(alg_dirs)}
    print("\n" + "="*50)
    print("--- 步骤 2: 请选择要处理的算法 ---")
    print("0: 批量处理当前数据集下的【全部】算法")
    for key, path in alg_map.items():
        print(f"{key}: {os.path.basename(path)}")
    print("="*50)
    choice2 = input("请输入算法编号 (或输入 '0' 处理全部) 并按回车键: ").strip()
    # 6. Resolve the second choice into the list of model directories.
    if choice2 == '0':
        model_dirs = alg_dirs
        logging.info(f"您选择了批量处理全部 {len(model_dirs)} 个算法。")
    elif choice2 in alg_map:
        model_dirs = [alg_map[choice2]]
        logging.info(f"您选择了处理单个算法: {os.path.basename(model_dirs[0])}")
    else:
        logging.error("无效的算法选择,程序已退出。")
        return
    # --- End of interactive selection ---
    results: List[Dict[str, str]] = []
    output_dataset_folder = None
    # The resolution depends only on the dataset folder, so detect it once.
    input_shape = get_shape_from_path(selected_dataset_dir)
    for model_dir in model_dirs:
        model_name = os.path.basename(model_dir)
        logging.info(f"--- 开始处理模型: {model_name} ---")
        files = find_model_files(model_dir)
        if not files:
            logging.warning(f"跳过目录 {model_dir},因为缺少必要文件。")
            continue
        # Output folder derived from the dataset key in the model folder name;
        # unknown keys get "<key>_outputs-MMSeg".
        dataset_key = model_name.split('-')[0]
        output_dataset_folder = OUTPUT_FOLDER_BY_KEY.get(dataset_key, f"{dataset_key}_outputs-MMSeg")
        if not input_shape:
            # Not detectable from the folder name: ask the user (once).
            logging.warning(f"无法从文件夹 '{os.path.basename(selected_dataset_dir)}' 名称中自动检测分辨率。")
            try:
                h_str = input("请输入默认测试高度 (H),例如 512: ").strip()
                w_str = input("请输入默认测试宽度 (W),例如 512: ").strip()
                input_shape = (int(h_str), int(w_str))
            except ValueError:
                logging.error("输入无效,必须是整数。程序退出。")
                return
        logging.info(f"将使用输入形状 (H, W): {input_shape} 进行计算。")
        # Both analyses use the config that find_model_files already located.
        flops_and_params_stats = get_flops_and_params(files['config'], input_shape)
        benchmark_stats = get_benchmark_stats(files['config'], files['checkpoint'], args.repeat_times)
        if not (flops_and_params_stats or benchmark_stats):
            logging.warning(f"未能获取模型 {model_name} 的统计信息。")
            continue
        # Part after "Alg_"; the whole folder name when that marker is absent.
        short_model_name = model_name.split('Alg_', 1)[-1]
        results.append({
            'Model': short_model_name,
            'Params': flops_and_params_stats['params'] if flops_and_params_stats else 'N/A',
            'FLOPs': flops_and_params_stats['flops'] if flops_and_params_stats else 'N/A',
            'Input_Shape (HxW)': f"{input_shape[0]}x{input_shape[1]}",
            'Average_FPS': benchmark_stats['average_fps'] if benchmark_stats else 'N/A',
            'FPS_Variance': benchmark_stats['fps_variance'] if benchmark_stats else 'N/A'
        })
    # --- Write the results to a CSV file ---
    if not results:
        logging.info("没有成功获取任何模型的统计数据,不生成 CSV 文件。")
        return
    final_output_dir = os.path.join(output_root, output_dataset_folder)
    dataset_name = os.path.basename(selected_dataset_dir).split('_outputs-MMSeg')[0]
    output_csv_path = os.path.join(final_output_dir, f'{dataset_name}_flops_params_fps_summary.csv')
    try:
        os.makedirs(final_output_dir, exist_ok=True)
        with open(output_csv_path, 'w', newline='', encoding='utf-8') as csvfile:
            fieldnames = ['Model', 'Params', 'FLOPs', 'Input_Shape (HxW)', 'Average_FPS', 'FPS_Variance']
            writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
            writer.writeheader()
            writer.writerows(results)
        logging.info(f"=== 全部处理完成!结果已成功保存到: {output_csv_path} ===")
    except IOError as e:
        logging.error(f"无法写入 CSV 文件: {output_csv_path}。错误: {e}")
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="MMSegmentation 自动化评估脚本")
    # (flag, value type, default, help text) for every command-line option.
    cli_options = (
        ('--input_dir', str, '../Hardisk', "包含已训练模型文件夹的根目录。"),
        ('--output_dir', str, '../BestMode_Predict_Results_DataSet_Public', "用于存储所有分析结果的根目录。"),
        ('--repeat-times', int, 3, "Number of times to repeat the benchmark for averaging."),
    )
    for flag, value_type, default_value, help_text in cli_options:
        parser.add_argument(flag, type=value_type, default=default_value, help=help_text)
    args = parser.parse_args()
    main(args)

View File

@@ -0,0 +1,332 @@
import os
import glob
import logging
import argparse
import re
import subprocess
import csv
from typing import Dict, Optional, Tuple, List
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
# --- 辅助函数 ---
def find_model_files(model_dir: str) -> Optional[Dict[str, str]]:
    """
    Locate the config file and the checkpoint to evaluate inside a model directory.

    The checkpoint is ``best.pth`` when it exists; otherwise the ``epoch_<N>.pth``
    file with the largest ``N`` is used.

    Args:
        model_dir (str): Root directory of one trained model.

    Returns:
        Optional[Dict[str, str]]: ``{'config': <path>, 'checkpoint': <path>}``,
        or None (after logging a warning) when no ``*.py`` config or no usable
        checkpoint is found.
    """
    # glob returns entries in filesystem order; sort so the chosen config is reproducible.
    config_files = sorted(glob.glob(os.path.join(model_dir, '*.py')))
    if not config_files:
        logging.warning(f"在目录 {model_dir} 中未找到配置文件 (.py)。")
        return None
    config_path = config_files[0]

    checkpoint_path = os.path.join(model_dir, 'best.pth')
    if not os.path.exists(checkpoint_path):
        epoch_files = glob.glob(os.path.join(model_dir, 'epoch_*.pth'))
        if not epoch_files:
            logging.warning(f"在目录 {model_dir} 中未找到 'best.pth''epoch_*.pth' 检查点文件。")
            return None
        # Keep only files whose name carries a numeric epoch, paired with that number.
        numbered = []
        for candidate in epoch_files:
            match = re.search(r'epoch_(\d+)\.pth', os.path.basename(candidate))
            if match:
                numbered.append((int(match.group(1)), candidate))
        if not numbered:
            logging.warning(f"在目录 {model_dir} 中无法确定最新的检查点文件。")
            return None
        # The highest epoch number is treated as the latest checkpoint.
        checkpoint_path = max(numbered, key=lambda item: item[0])[1]

    return {'config': config_path, 'checkpoint': checkpoint_path}
def find_model_config(model_dir: str) -> Optional[str]:
    """
    Find the ``*.py`` config file inside a model directory.

    Args:
        model_dir (str): Root directory of one trained model.

    Returns:
        Optional[str]: Path of the config file (the alphabetically first one
        when several exist), or None after logging a warning when there is none.
    """
    # Sort the matches: glob order is filesystem dependent, so picking index 0
    # of the raw result would not be reproducible with several .py files.
    config_files = sorted(glob.glob(os.path.join(model_dir, '*.py')))
    if not config_files:
        logging.warning(f"在目录 {model_dir} 中未找到配置文件 (.py)。")
        return None
    return config_files[0]
def get_shape_from_path(path: str) -> Optional[Tuple[int, int]]:
    """
    Extract the image resolution encoded as ``<width>x<height>`` in a folder name.

    Args:
        path (str): Path of the dataset folder; only its last component is inspected.

    Returns:
        Optional[Tuple[int, int]]: ``(height, width)`` -- note the swap, the
        downstream tool expects H W order -- or None when the folder name
        contains no ``<digits>x<digits>`` token.
    """
    # normpath drops a trailing separator; without it basename('dir/') is ''
    # and the resolution would never be found.
    folder_name = os.path.basename(os.path.normpath(path))
    match = re.search(r'(\d+)x(\d+)', folder_name)
    if match:
        width, height = int(match.group(1)), int(match.group(2))
        return (height, width)  # H, W
    return None
def _flops_to_giga(raw_flops_str: str) -> str:
    """Convert a FLOPs string such as '0.118T' or '950M' to a '<value> G' string."""
    value_str = raw_flops_str.rstrip('TGMKtgmk').strip()
    # A value without a unit suffix is taken to be in G already.
    unit = raw_flops_str[-1].upper() if raw_flops_str[-1].isalpha() else 'G'
    try:
        value = float(value_str)
    except ValueError:
        return raw_flops_str  # keep the raw text if it cannot be converted
    if unit == 'T':
        value_in_g = value * 1000
    elif unit == 'M':
        value_in_g = value / 1000
    elif unit == 'K':
        value_in_g = value / 1_000_000
    else:
        value_in_g = value
    # ':g' drops insignificant trailing zeros.
    return f"{value_in_g:g} G"


def get_flops_and_params(config_path: str, shape: Tuple[int, int]) -> Optional[Dict[str, str]]:
    """
    Run the ``tools/analysis_tools/get_flops.py`` script and parse its output.

    The parser expects the direct output format, e.g. ``Flops: 0.118T`` and
    ``Params: 13.274M``. FLOPs are normalised to G; Params are returned as printed.

    Args:
        config_path (str): Path of the model's ``.py`` config file.
        shape (Tuple[int, int]): Input image size as ``(H, W)``.

    Returns:
        Optional[Dict[str, str]]: ``{'flops': ..., 'params': ...}``, or None when
        the tool script is missing (relative to the current working directory),
        the command fails, or the output cannot be parsed.
    """
    import sys  # local import: only this function needs it

    tool_script = 'tools/analysis_tools/get_flops.py'
    if not os.path.exists(tool_script):
        logging.error(f"错误: '{tool_script}' 未找到。请确保在 MMSegmentation 项目的根目录下运行此脚本。")
        return None

    # Use the interpreter running this script so the tool sees the same
    # environment; a bare 'python' may be another installation or not exist.
    command = [
        sys.executable or 'python', tool_script, config_path,
        '--shape', str(shape[0]), str(shape[1])
    ]
    logging.info(f"执行命令: {' '.join(command)}")

    try:
        result = subprocess.run(command, capture_output=True, text=True, check=True, encoding='utf-8')
    except subprocess.CalledProcessError as e:
        logging.error(f"执行 get_flops.py 时出错。返回码: {e.returncode}")
        logging.error(f"错误输出 (stderr):\n---\n{e.stderr}\n---")
        return None
    except FileNotFoundError:
        logging.error("错误: 'python' 命令未找到。请确保 Python 环境已正确配置。")
        return None

    output = result.stdout
    flops_match = re.search(r"Flops:\s*([0-9.]+\s*[TGMK]?)", output)
    params_match = re.search(r"Params:\s*([0-9.]+\s*[TGMK]?)", output)
    if not (flops_match and params_match):
        logging.warning(f"❌ 无法从命令输出中解析 FLOPs 或 Params。请检查以下输出内容:\n---\n{output}\n---")
        return None

    raw_flops_str = flops_match.group(1).strip()
    params = params_match.group(1).strip()
    flops = _flops_to_giga(raw_flops_str)
    logging.info(f"✅ 解析成功: FLOPs={flops} (原始值: {raw_flops_str}), Params={params}")
    return {'flops': flops, 'params': params}
# --- 主函数 ---
def main(args):
    """
    Interactive entry point: pick a dataset and algorithm(s), collect FLOPs and
    Params for each selected model and write them to a CSV summary.

    Args:
        args: Parsed command-line namespace; ``args.input_dir`` and
            ``args.output_dir`` are read.

    Returns:
        None. Returns early (after logging) on a missing input directory, an
        invalid menu choice, an invalid manually entered shape, or when no
        statistics could be collected.
    """
    input_root = args.input_dir
    output_root = args.output_dir
    if not os.path.isdir(input_root):
        logging.error(f"输入目录不存在: {input_root}")
        return

    # 1. Whitelist of dataset folders this script knows about.
    VALID_DATASET_FOLDERS = [
        '1_CholecSeg8k-13Type-1920x1080_outputs-MMSeg',
        '2_AutoLaparo-10Type-1920x1080_outputs-MMSeg',
        '3_1_Endovis_2017-8Type-512x512_outputs-MMSeg',
        '3_2_Endovis_2018-8Type-512x512_outputs-MMSeg',
        '4_Dresden-11Type-512x512_outputs-MMSeg'
    ]
    # 2. Keep only the whitelisted folders that exist under the input root.
    existing_dataset_dirs = [
        os.path.join(input_root, d) for d in VALID_DATASET_FOLDERS
        if os.path.isdir(os.path.join(input_root, d))
    ]
    if not existing_dataset_dirs:
        logging.error(f"在输入目录 {input_root} 中未找到任何有效的数据集文件夹。")
        return

    # 3. First-level menu: choose the dataset.
    dataset_menu = {str(i + 1): path for i, path in enumerate(existing_dataset_dirs)}
    print("\n" + "="*50)
    print("--- 步骤 1: 请选择要处理的数据集 ---")
    for key, path in dataset_menu.items():
        print(f"{key}: {os.path.basename(path)}")
    print("="*50)
    choice1 = input("请输入数据集编号并按回车键: ").strip()
    if choice1 not in dataset_menu:
        logging.error("无效的数据集选择,程序已退出。")
        return
    selected_dataset_dir = dataset_menu[choice1]
    logging.info(f"您已选择数据集: [{os.path.basename(selected_dataset_dir)}]")

    # 4. Every sub-directory of the dataset folder is one algorithm.
    alg_dirs = sorted([
        d for d in glob.glob(os.path.join(selected_dataset_dir, '*')) if os.path.isdir(d)
    ])
    if not alg_dirs:
        logging.warning(f"{os.path.basename(selected_dataset_dir)} 中未发现任何算法子文件夹。程序退出。")
        return

    # 5. Second-level menu: choose one algorithm, or '0' for all of them.
    alg_map = {str(i + 1): path for i, path in enumerate(alg_dirs)}
    print("\n" + "="*50)
    print("--- 步骤 2: 请选择要处理的算法 ---")
    print("0: 批量处理当前数据集下的【全部】算法")
    for key, path in alg_map.items():
        print(f"{key}: {os.path.basename(path)}")
    print("="*50)
    choice2 = input("请输入算法编号 (或输入 '0' 处理全部) 并按回车键: ").strip()
    if choice2 == '0':
        model_dirs = alg_dirs
        logging.info(f"您选择了批量处理全部 {len(model_dirs)} 个算法。")
    elif choice2 in alg_map:
        model_dirs = [alg_map[choice2]]
        logging.info(f"您选择了处理单个算法: {os.path.basename(model_dirs[0])}")
    else:
        logging.error("无效的算法选择,程序已退出。")
        return

    # Maps the dataset key (model-name prefix before the first '-') to the output folder.
    # NOTE(review): several folder names here differ from VALID_DATASET_FOLDERS
    # (class counts / resolutions) -- confirm which naming is intended.
    output_folder_map = {
        '1_cholecseg8k': '1_CholecSeg8k-13Type-1920x1080_outputs-MMSeg',
        '2_autolaparo': '2_AutoLaparo-10Type-1280x1024_outputs-MMSeg',
        '3_1_endovis_2017': '3_1_EndoVis_2017-7Type-1280x1024_outputs-MMSeg',
        '3_2_endovis_2018': '3_2_EndoVis_2018-11Type-1280x1024_outputs-MMSeg',
        '4_dresden': '4_Dresden-6Type-1920x1080_outputs-MMSeg'
    }

    results: List[Dict[str, str]] = []
    output_dataset_folder = ""
    input_shape = None  # resolved once, on the first model that has its files
    for model_dir in model_dirs:
        model_name = os.path.basename(model_dir)
        logging.info(f"--- 开始处理模型: {model_name} ---")
        files = find_model_files(model_dir)
        if not files:
            logging.warning(f"跳过目录 {model_dir},因为缺少必要文件。")
            continue

        dataset_key = model_name.split('-')[0]
        output_dataset_folder = output_folder_map.get(dataset_key, f"{dataset_key}_outputs-MMSeg")

        # The shape only depends on the dataset folder, so detect (or ask for)
        # it a single time instead of once per model.
        if input_shape is None:
            input_shape = get_shape_from_path(selected_dataset_dir)
            if not input_shape:
                logging.warning(f"无法从文件夹 '{os.path.basename(selected_dataset_dir)}' 名称中自动检测分辨率。")
                try:
                    h_str = input("请输入默认测试高度 (H),例如 512: ").strip()
                    w_str = input("请输入默认测试宽度 (W),例如 512: ").strip()
                    input_shape = (int(h_str), int(w_str))
                except ValueError:
                    logging.error("输入无效,必须是整数。程序退出。")
                    return
        logging.info(f"将使用输入形状 (H, W): {input_shape} 进行计算。")

        stats = get_flops_and_params(files['config'], input_shape)
        if stats:
            # Folder names without the 'Alg_' marker are used unchanged
            # instead of raising IndexError.
            short_model_name = model_name.split('Alg_', 1)[1] if 'Alg_' in model_name else model_name
            results.append({
                'Model': short_model_name,
                'Params': stats['params'],
                'FLOPs': stats['flops'],
                'Input_Shape (HxW)': f"{input_shape[0]}x{input_shape[1]}"
            })
        else:
            logging.warning(f"未能获取模型 {model_name} 的统计信息。")

    # --- Write the collected rows to a CSV file ---
    if not results:
        logging.info("没有成功获取任何模型的统计数据,不生成 CSV 文件。")
        return
    final_output_dir = os.path.join(output_root, output_dataset_folder)
    os.makedirs(final_output_dir, exist_ok=True)
    dataset_name = os.path.basename(selected_dataset_dir).split('_outputs-MMSeg')[0]
    output_csv_path = os.path.join(final_output_dir, f'{dataset_name}_flops_params_summary.csv')
    try:
        with open(output_csv_path, 'w', newline='', encoding='utf-8') as csvfile:
            fieldnames = ['Model', 'Params', 'FLOPs', 'Input_Shape (HxW)']
            writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
            writer.writeheader()
            writer.writerows(results)
        logging.info(f"=== 全部处理完成!结果已成功保存到: {output_csv_path} ===")
    except IOError as e:
        logging.error(f"无法写入 CSV 文件: {output_csv_path}。错误: {e}")
# Script entry point: build the CLI, parse it and delegate to main().
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="MMSegmentation 自动化评估脚本")
    # Both options are plain strings; only the default and help text differ.
    path_options = {
        '--input_dir': ('../Hardisk', "包含已训练模型文件夹的根目录。"),
        '--output_dir': ('../BestMode_Predict_Results_DataSet_Public', "用于存储所有分析结果的根目录。"),
    }
    for flag, (fallback, description) in path_options.items():
        parser.add_argument(flag, type=str, default=fallback, help=description)
    args = parser.parse_args()
    main(args)

View File

@@ -0,0 +1,322 @@
import os
import glob
import logging
import argparse
import re
import csv
from typing import Dict, Optional, List
# NOTE: this script extracts, for each algorithm, the metrics of the LAST validation found in its logs.
# --- 配置日志记录 ---
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
# --- 辅助函数 ---
def find_all_log_files_sorted(algorithm_dir: str) -> List[str]:
    """
    Collect one ``.log`` file per run sub-directory, newest run first.

    Sub-directories are ordered by name in descending order, which puts the
    newest run first when the names are timestamps.

    Args:
        algorithm_dir (str): Root directory of one algorithm.

    Returns:
        List[str]: Log file paths ordered from newest to oldest run; empty when
        the directory does not exist or has no sub-directories.
    """
    try:
        subdirs = [d for d in os.listdir(algorithm_dir) if os.path.isdir(os.path.join(algorithm_dir, d))]
    except FileNotFoundError:
        logging.error(f"算法目录不存在: {algorithm_dir}")
        return []
    if not subdirs:
        logging.warning(f"在目录 {algorithm_dir} 中未找到任何时间戳子目录。")
        return []

    log_files = []
    for subdir_name in sorted(subdirs, reverse=True):
        # A run directory is expected to hold a single log. If it holds several,
        # take the last one by name so the choice does not depend on glob order.
        logs_in_subdir = sorted(glob.glob(os.path.join(algorithm_dir, subdir_name, '*.log')), reverse=True)
        if logs_in_subdir:
            log_files.append(logs_in_subdir[0])
    return log_files
def parse_log_metrics(log_path: str) -> Optional[Dict]:
    """
    Parse a training log and extract the last complete validation result.

    A validation result is a per-class table (Class / IoU / Acc) followed by an
    ``Iter(val) [...] aAcc: .. mIoU: .. mAcc: ..`` summary line. The epoch is the
    largest number found in ``Saving checkpoint at N epochs`` / ``resumed epoch: N``
    lines that appear before that table.

    Args:
        log_path (str): Path of the log file.

    Returns:
        Optional[Dict]: ``{'epoch': 'epoch_<N>' or 'N/A', 'summary': {'aAcc',
        'mIoU', 'mAcc'}, 'class_wise': [{'Class', 'IoU', 'Acc'}, ...]}`` with all
        metric values kept as strings, or None when the file cannot be read or
        holds no complete validation result.
    """
    try:
        with open(log_path, 'r', encoding='utf-8') as f:
            content = f.read()
    except IOError as e:
        logging.error(f"无法读取日志文件: {log_path}。错误: {e}")
        return None

    summary_pattern = re.compile(
        r"Iter\(val\) \[\d+/\d+\]\s+aAcc:\s*([\d.]+)\s+mIoU:\s*([\d.]+)\s+mAcc:\s*([\d.]+)"
    )
    class_table_pattern = re.compile(
        r"\+(?:-+\+)+\s*\n"                            # top border (any column widths)
        r"\|.*?Class.*?\|.*?IoU.*?\|.*?Acc.*?\|\n"     # header row
        r"\+(?:-+\+)+\s*\n"                            # header separator
        r"((?:\|.*?\|.*?\|.*?\|\n)+)"                  # table body (captured)
        r"\+(?:-+\+)+\s*\n",                           # bottom border
        re.MULTILINE
    )
    epoch_pattern = re.compile(r"Saving checkpoint at (\d+) epochs|resumed epoch: (\d+)")

    summary_matches = list(re.finditer(summary_pattern, content))
    table_matches = list(re.finditer(class_table_pattern, content))
    epoch_matches = list(re.finditer(epoch_pattern, content))
    if not summary_matches or not table_matches:
        logging.warning(f"❌ 在日志 {os.path.basename(log_path)} 中未能找到完整的验证结果。")
        return None

    # The table belonging to the last summary line is the closest one printed before it.
    last_summary_match = summary_matches[-1]
    last_table_match = None
    for table in reversed(table_matches):
        if table.end() < last_summary_match.start():
            last_table_match = table
            break
    if not last_table_match:
        logging.warning(f"❌ 在日志 {os.path.basename(log_path)} 中找到总结行但未能匹配到对应的类别表格。")
        return None

    # Largest epoch number reported before the table.
    last_epoch = "N/A"
    latest_epoch_num = -1
    for epoch_match in epoch_matches:
        if epoch_match.end() < last_table_match.start():
            epoch_str = epoch_match.group(1) or epoch_match.group(2)
            if epoch_str:
                epoch_num = int(epoch_str)
                if epoch_num > latest_epoch_num:
                    latest_epoch_num = epoch_num
                    last_epoch = f"epoch_{epoch_num}"

    summary_groups = last_summary_match.groups()
    results = {
        'epoch': last_epoch,
        'summary': {
            'aAcc': summary_groups[0],
            'mIoU': summary_groups[1],
            'mAcc': summary_groups[2]
        },
        'class_wise': []
    }

    # Class names may contain '-' and '.' (e.g. "L-hook Electrocautery"); such
    # rows must not be dropped, so both characters are allowed in the name.
    row_pattern = re.compile(r"\|\s*([\w\s.-]+?)\s*\|\s*([\d.]+)\s*\|\s*([\d.]+)\s*\|")
    for line in last_table_match.group(1).strip().split('\n'):
        row_match = row_pattern.match(line)
        if row_match:
            class_name, iou, acc = row_match.groups()
            results['class_wise'].append({
                'Class': class_name.strip(),
                'IoU': iou,
                'Acc': acc
            })

    if results['class_wise']:
        logging.info(f"✅ 成功从 {os.path.basename(log_path)} 中解析出 Epoch '{last_epoch}' 的指标。")
        return results
    else:
        logging.warning(f"❌ 在 {os.path.basename(log_path)} 中未能解析出任何类别行。")
        return None
# --- 主函数 ---
def main(args):
    """
    Interactive entry point: pick a dataset and algorithm(s), extract the last
    validation metrics of each algorithm from its logs and write one "wide" CSV
    row per algorithm.

    Args:
        args: Parsed command-line namespace; ``args.input_dir`` and
            ``args.output_dir`` are read.

    Returns:
        None. Returns early (after logging) on a missing input directory, an
        invalid menu choice, or when no metrics could be extracted.
    """
    input_root = args.input_dir
    output_root = args.output_dir
    if not os.path.isdir(input_root):
        logging.error(f"输入目录不存在: {input_root}")
        return

    # --- Interactive menu ---
    all_dataset_dirs = sorted([
        d for d in glob.glob(os.path.join(input_root, '*_outputs-MMSeg')) if os.path.isdir(d)
    ])
    if not all_dataset_dirs:
        logging.error(f"在输入目录 {input_root} 中未找到任何有效的数据集文件夹。")
        return
    dataset_map = {str(i + 1): path for i, path in enumerate(all_dataset_dirs)}
    print("\n" + "="*50)
    print("--- 步骤 1: 请选择要处理的数据集 ---")
    for key, path in dataset_map.items():
        print(f"{key}: {os.path.basename(path)}")
    print("="*50)
    choice1 = input("请输入数据集编号并按回车键: ").strip()
    if choice1 not in dataset_map:
        logging.error("无效的数据集选择,程序已退出。")
        return
    selected_dataset_dir = dataset_map[choice1]
    logging.info(f"您已选择数据集: [{os.path.basename(selected_dataset_dir)}]")

    alg_dirs = sorted([
        d for d in glob.glob(os.path.join(selected_dataset_dir, '*')) if os.path.isdir(d)
    ])
    if not alg_dirs:
        logging.warning(f"{os.path.basename(selected_dataset_dir)} 中未发现任何算法子文件夹。")
        return
    alg_map = {str(i + 1): path for i, path in enumerate(alg_dirs)}
    print("\n" + "="*50)
    print("--- 步骤 2: 请选择要处理的算法 ---")
    print("0: 批量处理当前数据集下的【全部】算法")
    for key, path in alg_map.items():
        print(f"{key}: {os.path.basename(path)}")
    print("="*50)
    choice2 = input("请输入算法编号 (或输入 '0' 处理全部) 并按回车键: ").strip()
    if choice2 == '0':
        model_dirs = alg_dirs
        logging.info(f"您选择了批量处理全部 {len(model_dirs)} 个算法。")
    elif choice2 in alg_map:
        model_dirs = [alg_map[choice2]]
        logging.info(f"您选择了处理单个算法: {os.path.basename(model_dirs[0])}")
    else:
        logging.error("无效的算法选择,程序已退出。")
        return

    # Maps the dataset key (model-name prefix before the first '-') to the output folder.
    dataset_folder_map = {
        '1_cholecseg8k': '1_CholecSeg8k-13Type-1920x1080_outputs-MMSeg',
        '2_autolaparo': '2_AutoLaparo-10Type-1920x1080_outputs-MMSeg',
        '3_1_endovis_2017': '3_1_Endovis_2017-8Type-512x512_outputs-MMSeg',
        '3_2_endovis_2018': '3_2_Endovis_2018-8Type-512x512_outputs-MMSeg',
        '4_dresden': '4_Dresden-11Type-512x512_outputs-MMSeg'
    }

    # --- Process the selected algorithms ---
    csv_rows = []
    output_dataset_folder = ""
    for model_dir in model_dirs:
        model_name = os.path.basename(model_dir)
        logging.info(f"\n--- 开始处理算法: {model_name} ---")
        # The output folder is derived once, from the first processed model.
        if not output_dataset_folder:
            dataset_key = model_name.split('-')[0]
            output_dataset_folder = dataset_folder_map.get(dataset_key, f"{dataset_key}_outputs-MMSeg")

        # 1. All logs, newest first.
        all_logs_sorted = find_all_log_files_sorted(model_dir)
        if not all_logs_sorted:
            logging.warning(f"跳过算法 {model_name},因为未找到任何日志文件。")
            continue

        # 2. Try the logs in order and stop at the first one that yields metrics.
        metrics = None
        for log_file_path in all_logs_sorted:
            logging.info(f"正在尝试解析: {os.path.relpath(log_file_path)}")
            metrics = parse_log_metrics(log_file_path)
            if metrics:
                logging.info(f"{os.path.basename(log_file_path)} 中成功找到并解析了指标。")
                break

        # 3. No log produced metrics: skip this algorithm.
        if not metrics:
            logging.warning(f"❌❌❌跳过算法 {model_name},因为在其所有日志文件中都未能找到有效的指标。❌❌❌")
            continue

        # --- Build one "wide" row ---
        summary = metrics['summary']
        # Folder names without the 'Alg_' marker are used unchanged instead of raising IndexError.
        short_model_name = model_name.split('Alg_', 1)[1] if 'Alg_' in model_name else model_name
        row_data = {
            'Algorithm': short_model_name,
            'Epoch': metrics['epoch'],
            'mIoU': summary['mIoU'],
            'mAcc': summary['mAcc'],
            'aAcc': summary['aAcc']
        }
        # One IoU column and one Acc column per class.
        for class_data in metrics['class_wise']:
            class_name = class_data['Class'].replace(' ', '_')  # make it usable as a header
            row_data[f'{class_name}_IoU'] = class_data['IoU']
            row_data[f'{class_name}_Acc'] = class_data['Acc']
        csv_rows.append(row_data)

    # --- Write the CSV ---
    if not csv_rows:
        logging.info("没有成功获取任何模型的统计数据,不生成 CSV 文件。")
        return

    # Fixed leading columns, then every class column seen in ANY row (sorted).
    # Using only the first row's keys would make DictWriter raise ValueError as
    # soon as another row carries a class column the first one lacks.
    base_fieldnames = ['Algorithm', 'Epoch', 'mIoU', 'mAcc', 'aAcc']
    class_fieldnames = sorted({
        key for row in csv_rows for key in row if key not in base_fieldnames
    })
    final_fieldnames = base_fieldnames + class_fieldnames

    final_output_dir = os.path.join(output_root, output_dataset_folder)
    os.makedirs(final_output_dir, exist_ok=True)
    dataset_name = os.path.basename(selected_dataset_dir).split('_outputs-MMSeg')[0]
    # "_wide" in the file name marks the one-row-per-algorithm layout.
    output_csv_path = os.path.join(final_output_dir, f'{dataset_name}_metrics_summary_wide.csv')
    try:
        with open(output_csv_path, 'w', newline='', encoding='utf-8-sig') as csvfile:
            writer = csv.DictWriter(csvfile, fieldnames=final_fieldnames)
            writer.writeheader()
            writer.writerows(csv_rows)
        logging.info(f"=== 全部处理完成!结果已成功保存到: {output_csv_path} ===")
    except IOError as e:
        logging.error(f"无法写入 CSV 文件: {output_csv_path}。错误: {e}")
# Script entry point: declare the two path options, parse them and run main().
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="MMSegmentation 最终指标提取脚本 (V2)")
    # Each entry is (flag, default value, help text); both options are strings.
    for flag, fallback, description in (
        ('--input_dir', '../Hardisk', "包含数据集输出文件夹 (例如 '..._outputs-MMSeg') 的根目录。"),
        ('--output_dir', '../BestMode_Predict_Results_DataSet_Public', "用于存储所有分析结果的根目录。"),
    ):
        parser.add_argument(flag, type=str, default=fallback, help=description)
    args = parser.parse_args()
    main(args)

View File

@@ -0,0 +1,366 @@
import os
import glob
import logging
import argparse
import re
import csv
from typing import Dict, Optional, List
# NOTE: this script collects every validation result from an algorithm's logs and keeps the one with the highest mIoU.
# --- 配置日志记录 ---
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
# --- 辅助函数 ---
def find_all_log_files_sorted(algorithm_dir: str) -> List[str]:
    """
    Collect one ``.log`` file per run sub-directory, newest run first.

    Sub-directories are ordered by name in descending order, which puts the
    newest run first when the names are timestamps.

    Args:
        algorithm_dir (str): Root directory of one algorithm.

    Returns:
        List[str]: Log file paths ordered from newest to oldest run; empty when
        the directory does not exist or has no sub-directories.
    """
    try:
        subdirs = [d for d in os.listdir(algorithm_dir) if os.path.isdir(os.path.join(algorithm_dir, d))]
    except FileNotFoundError:
        logging.error(f"算法目录不存在: {algorithm_dir}")
        return []
    if not subdirs:
        logging.warning(f"在目录 {algorithm_dir} 中未找到任何时间戳子目录。")
        return []

    log_files = []
    for subdir_name in sorted(subdirs, reverse=True):
        # A run directory is expected to hold a single log. If it holds several,
        # take the last one by name so the choice does not depend on glob order.
        logs_in_subdir = sorted(glob.glob(os.path.join(algorithm_dir, subdir_name, '*.log')), reverse=True)
        if logs_in_subdir:
            log_files.append(logs_in_subdir[0])
    return log_files
def get_max_epochs(config_path: str) -> Optional[int]:
    """
    Read ``max_epochs`` from the ``train_cfg = dict(...)`` entry of a config file.

    The file is scanned as text with a regular expression; it is not executed.

    Args:
        config_path (str): Path of the config.py file.

    Returns:
        Optional[int]: The ``max_epochs`` value, or None when the file does not
        exist, cannot be read, or has no ``max_epochs`` after ``train_cfg = dict(``.
    """
    if not os.path.exists(config_path):
        logging.error(f"配置文件不存在: {config_path}")
        return None
    try:
        with open(config_path, 'r', encoding='utf-8') as f:
            content = f.read()
        # No trailing comma is required after the number, so the value is also
        # found when max_epochs is the last argument: dict(..., max_epochs=100).
        match = re.search(r"train_cfg\s*=\s*dict\(.*?max_epochs\s*=\s*(\d+)", content, re.DOTALL)
        if match:
            max_epochs = int(match.group(1))
            logging.info(f"{os.path.basename(config_path)} 中成功读取 max_epochs: {max_epochs}")
            return max_epochs
        else:
            logging.warning(f"{config_path} 中未找到 'max_epochs'")
            return None
    except Exception as e:
        logging.error(f"解析 {config_path} 时出错: {e}")
        return None
def parse_log_metrics(log_path: str, max_epochs: int) -> List[Dict]:
    """
    Parse a training log and extract every complete validation result.

    A validation result is a per-class table (Class / IoU / Acc) followed by an
    ``Iter(val) [...] aAcc: .. mIoU: .. mAcc: ..`` summary line. The epoch of a
    result is estimated from the closest preceding ``Iter(train) [cur/total]``
    line as ``floor(cur * max_epochs / total)``.

    Args:
        log_path (str): Path of the log file.
        max_epochs (int): Total number of epochs, as read from the config.

    Returns:
        List[Dict]: One dict per validation with keys ``'epoch'`` (int, or
        ``"N/A"`` when it cannot be computed), ``'summary'`` and
        ``'class_wise'``; metric values are kept as strings. Empty when the
        file cannot be read or holds no complete result.
    """
    try:
        with open(log_path, 'r', encoding='utf-8') as f:
            content = f.read()
    except IOError as e:
        logging.error(f"无法读取日志文件: {log_path}。错误: {e}")
        return []

    summary_pattern = re.compile(
        r"Iter\(val\) \[\d+/\d+\]\s+aAcc:\s*([\d.]+)\s+mIoU:\s*([\d.]+)\s+mAcc:\s*([\d.]+)"
    )
    class_table_pattern = re.compile(
        r"\+(?:-+\+)+\s*\n"
        r"\|.*?Class.*?\|.*?IoU.*?\|.*?Acc.*?\|\n"
        r"\+(?:-+\+)+\s*\n"
        r"((?:\|.*?\|.*?\|.*?\|\n)+)"
        r"\+(?:-+\+)+\s*\n",
        re.MULTILINE
    )
    train_iter_pattern = re.compile(r"Iter\(train\)\s*\[\s*(\d+)\s*/\s*(\d+)\]")
    row_pattern = re.compile(r"\|\s*([\w\s.-]+?)\s*\|\s*([\d.]+)\s*\|\s*([\d.]+)\s*\|")

    summary_matches = list(re.finditer(summary_pattern, content))
    table_matches = list(re.finditer(class_table_pattern, content))
    train_iter_matches = list(re.finditer(train_iter_pattern, content))

    all_metrics = []
    previous_summary_start = None
    for summary_match in summary_matches:
        # The table of a summary line is the closest table printed before it.
        table_match = next(
            (t for t in reversed(table_matches) if t.end() < summary_match.start()), None
        )
        # A table located before the previous summary line belongs to an
        # earlier validation and must not be reused.
        already_used = (
            table_match is not None
            and previous_summary_start is not None
            and table_match.end() < previous_summary_start
        )
        previous_summary_start = summary_match.start()
        if table_match is None or already_used:
            continue

        # Closest Iter(train) line before the table, used to estimate the epoch.
        train_iter_match = next(
            (t for t in reversed(train_iter_matches) if t.end() < table_match.start()), None
        )
        epoch = "N/A"
        if train_iter_match and max_epochs:
            current_iter, total_iters = train_iter_match.groups()
            try:
                # Integer arithmetic: the float form int(cur / total * max_epochs)
                # can land one below the true value (29 / 100 * 100 -> 28).
                epoch = int(current_iter) * max_epochs // int(total_iters)
            except (ValueError, ZeroDivisionError) as e:
                logging.warning(f"Epoch 计算失败: {e}")

        summary_groups = summary_match.groups()
        results = {
            'epoch': epoch,
            'summary': {
                'aAcc': summary_groups[0],
                'mIoU': summary_groups[1],
                'mAcc': summary_groups[2]
            },
            'class_wise': []
        }
        # Parse every row of the table body.
        for line in table_match.group(1).strip().split('\n'):
            row_match = row_pattern.match(line)
            if row_match:
                class_name, iou, acc = row_match.groups()
                results['class_wise'].append({
                    'Class': class_name.strip(),
                    'IoU': iou,
                    'Acc': acc
                })
        if results['class_wise']:
            all_metrics.append(results)

    if all_metrics:
        logging.info(f"✅ 成功从 {os.path.basename(log_path)} 中解析出 {len(all_metrics)} 组指标。")
    else:
        logging.warning(f"❌ 在日志 {os.path.basename(log_path)} 中未能找到完整的验证结果。")
    return all_metrics
# --- 主函数 ---
def main(args):
    """
    Interactive entry point: pick a dataset and algorithm(s), gather every
    validation result from each algorithm's logs, keep the one with the highest
    mIoU and write one "wide" CSV row per algorithm.

    Args:
        args: Parsed command-line namespace; ``args.input_dir`` and
            ``args.output_dir`` are read.

    Returns:
        None. Returns early (after logging) on a missing input directory, an
        invalid menu choice, or when no metrics could be extracted.
    """
    input_root = args.input_dir
    output_root = args.output_dir
    if not os.path.isdir(input_root):
        logging.error(f"输入目录不存在: {input_root}")
        return

    # --- Interactive menu ---
    all_dataset_dirs = sorted([
        d for d in glob.glob(os.path.join(input_root, '*_outputs-MMSeg')) if os.path.isdir(d)
    ])
    if not all_dataset_dirs:
        logging.error(f"在输入目录 {input_root} 中未找到任何有效的数据集文件夹。")
        return
    dataset_map = {str(i + 1): path for i, path in enumerate(all_dataset_dirs)}
    print("\n" + "="*50)
    print("--- 步骤 1: 请选择要处理的数据集 ---")
    for key, path in dataset_map.items():
        print(f"{key}: {os.path.basename(path)}")
    print("="*50)
    choice1 = input("请输入数据集编号并按回车键: ").strip()
    if choice1 not in dataset_map:
        logging.error("无效的数据集选择,程序已退出。")
        return
    selected_dataset_dir = dataset_map[choice1]
    logging.info(f"您已选择数据集: [{os.path.basename(selected_dataset_dir)}]")

    alg_dirs = sorted([
        d for d in glob.glob(os.path.join(selected_dataset_dir, '*')) if os.path.isdir(d)
    ])
    if not alg_dirs:
        logging.warning(f"{os.path.basename(selected_dataset_dir)} 中未发现任何算法子文件夹。")
        return
    alg_map = {str(i + 1): path for i, path in enumerate(alg_dirs)}
    print("\n" + "="*50)
    print("--- 步骤 2: 请选择要处理的算法 ---")
    print("0: 批量处理当前数据集下的【全部】算法")
    for key, path in alg_map.items():
        print(f"{key}: {os.path.basename(path)}")
    print("="*50)
    choice2 = input("请输入算法编号 (或输入 '0' 处理全部) 并按回车键: ").strip()
    if choice2 == '0':
        model_dirs = alg_dirs
        logging.info(f"您选择了批量处理全部 {len(model_dirs)} 个算法。")
    elif choice2 in alg_map:
        model_dirs = [alg_map[choice2]]
        logging.info(f"您选择了处理单个算法: {os.path.basename(model_dirs[0])}")
    else:
        logging.error("无效的算法选择,程序已退出。")
        return

    # --- Process the selected algorithms ---
    csv_rows = []
    for model_dir in model_dirs:
        model_name = os.path.basename(model_dir)
        logging.info(f"\n--- 开始处理算法: {model_name} ---")

        all_logs_sorted = find_all_log_files_sorted(model_dir)
        if not all_logs_sorted:
            logging.warning(f"跳过算法 {model_name},因为未找到任何日志文件。")
            continue

        # All runs of one algorithm are assumed to share the same config, so
        # max_epochs is read from the config saved next to the newest log.
        latest_log_path = all_logs_sorted[0]
        config_path = os.path.join(os.path.dirname(latest_log_path), 'vis_data', 'config.py')
        max_epochs = get_max_epochs(config_path)
        if max_epochs is None:
            logging.error(f"无法为算法 {model_name} 找到 max_epochs将跳过。")
            continue

        # Gather every validation result from every log of this algorithm.
        all_metrics_for_model = []
        for log_file_path in all_logs_sorted:
            logging.info(f"正在解析: {os.path.relpath(log_file_path)}")
            metrics_from_log = parse_log_metrics(log_file_path, max_epochs)
            if metrics_from_log:
                all_metrics_for_model.extend(metrics_from_log)
        if not all_metrics_for_model:
            logging.warning(f"❌❌❌跳过算法 {model_name},因为在其所有日志文件中都未能找到有效的指标。❌❌❌")
            continue

        # Keep the record with the highest mIoU.
        try:
            best_metric = max(all_metrics_for_model, key=lambda x: float(x['summary']['mIoU']))
            logging.info(f"找到了最佳指标: Epoch '{best_metric['epoch']}', mIoU: {best_metric['summary']['mIoU']}")
        except (ValueError, TypeError) as e:
            logging.error(f"为算法 {model_name} 寻找最佳mIoU时出错: {e}")
            continue

        # --- Build one "wide" row ---
        summary = best_metric['summary']
        short_model_name = model_name.split('Alg_', 1)[1] if 'Alg_' in model_name else model_name
        row_data = {
            'Algorithm': short_model_name,
            'Epoch': best_metric['epoch'],
            'mIoU': summary['mIoU'],
            'mAcc': summary['mAcc'],
            'aAcc': summary['aAcc']
        }
        for class_data in best_metric['class_wise']:
            class_name = class_data['Class'].replace(' ', '_')
            row_data[f'{class_name}_IoU'] = class_data['IoU']
            row_data[f'{class_name}_Acc'] = class_data['Acc']
        csv_rows.append(row_data)

    # --- Write the CSV ---
    if not csv_rows:
        logging.info("没有成功获取任何模型的统计数据,不生成 CSV 文件。")
        return

    # Fixed leading columns, then every class column seen in ANY row (sorted).
    # Using only the first row's keys would make DictWriter raise ValueError as
    # soon as another row carries a class column the first one lacks.
    base_fieldnames = ['Algorithm', 'Epoch', 'mIoU', 'mAcc', 'aAcc']
    class_fieldnames = sorted({
        key for row in csv_rows for key in row if key not in base_fieldnames
    })
    final_fieldnames = base_fieldnames + class_fieldnames

    # The output folder mirrors the name of the selected dataset folder.
    final_output_dir = os.path.join(output_root, os.path.basename(selected_dataset_dir))
    os.makedirs(final_output_dir, exist_ok=True)
    dataset_name = os.path.basename(selected_dataset_dir).split('_outputs-MMSeg')[0]
    output_csv_path = os.path.join(final_output_dir, f'{dataset_name}_metrics_summary_wide.csv')
    try:
        with open(output_csv_path, 'w', newline='', encoding='utf-8-sig') as csvfile:
            writer = csv.DictWriter(csvfile, fieldnames=final_fieldnames)
            writer.writeheader()
            writer.writerows(csv_rows)
        logging.info(f"\n=== 全部处理完成!最佳结果已成功保存到: {output_csv_path} ===")
    except IOError as e:
        logging.error(f"无法写入 CSV 文件: {output_csv_path}。错误: {e}")
# Script entry point: register the CLI options, parse them and call main().
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="MMSegmentation 最终指标提取脚本 (V2)")
    # Keyword arguments for each option, keyed by its flag.
    option_specs = {
        '--input_dir': dict(
            type=str,
            default='../Hardisk',
            help="包含数据集输出文件夹 (例如 '..._outputs-MMSeg') 的根目录。",
        ),
        '--output_dir': dict(
            type=str,
            default='../BestMode_Predict_Results_DataSet_Public',
            help="用于存储所有分析结果的根目录。",
        ),
    }
    for flag, spec in option_specs.items():
        parser.add_argument(flag, **spec)
    args = parser.parse_args()
    main(args)

View File

@@ -0,0 +1,249 @@
import os
import glob
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import re
def get_model_family(model_name):
    """
    Return the family a model belongs to.

    A trailing ``_r<digits>`` suffix is removed; names without such a suffix
    are returned unchanged.

    Examples:
        'my_bisenetv1_r50' -> 'my_bisenetv1'
        'my_fast_scnn'     -> 'my_fast_scnn'
    """
    # Only the `_r<digits>` suffix is recognised by this pattern.
    suffix_match = re.match(r'^(.*?)_r\d+$', model_name)
    return suffix_match.group(1) if suffix_match else model_name
def select_dataset(results_dir):
    """
    List the dataset folders under ``results_dir`` and let the user pick one.

    Dataset folders are the directories whose name ends with '_outputs-MMSeg'.
    The prompt repeats until a valid number is entered.

    Returns:
        (selected_dir, dataset_name), where dataset_name is the folder name
        with '_outputs-MMSeg' removed; (None, None) when scanning fails, no
        dataset folder exists, or the user cancels (Ctrl-C / EOF).
    """
    print("正在扫描可用的数据集...")
    try:
        matches = glob.glob(os.path.join(results_dir, '*_outputs-MMSeg'))
        all_dataset_dirs = sorted(filter(os.path.isdir, matches))
    except Exception as e:
        print(f"扫描目录 '{results_dir}' 时出错: {e}")
        return None, None

    total = len(all_dataset_dirs)
    if total == 0:
        print(f"'{results_dir}' 中未找到任何数据集目录 (以 '_outputs-MMSeg' 结尾)。")
        print("请确保脚本与 'BestMode_Predict_Results_DataSet_Public' 文件夹在同一级目录下。")
        return None, None

    print("\n请选择要可视化的数据集:")
    for number, dir_path in enumerate(all_dataset_dirs, start=1):
        label = os.path.basename(dir_path).replace('_outputs-MMSeg', '')
        print(f" [{number}] {label}")

    while True:
        try:
            picked = int(input(f"\n请输入选项编号 (1-{total}): ")) - 1
            if picked < 0 or picked >= total:
                print("无效的选项,请输入列表中的编号。")
                continue
            chosen_dir = all_dataset_dirs[picked]
            return chosen_dir, os.path.basename(chosen_dir).replace('_outputs-MMSeg', '')
        except (ValueError, IndexError):
            print("无效的输入,请输入一个数字编号。")
        except (KeyboardInterrupt, EOFError):
            print("\n操作已取消。")
            return None, None
def plot_performance_speed(selected_dir, dataset_name):
    """
    Load the metric and speed summaries of a dataset and plot mIoU against FPS.

    Reads ``<dataset_name>_metrics_summary_wide.csv`` and
    ``<dataset_name>_flops_params_fps_summary.csv`` from ``selected_dir``, joins
    them on the model name, writes the T1/T2 summary tables, then saves the
    scatter plot as PNG and SVG in the same directory and shows it.
    """
    print(f"\n正在为数据集 '{dataset_name}' 生成图表...")

    # Input files produced by the extraction scripts.
    metrics_file = os.path.join(selected_dir, f"{dataset_name}_metrics_summary_wide.csv")
    fps_file = os.path.join(selected_dir, f"{dataset_name}_flops_params_fps_summary.csv")
    if not (os.path.exists(metrics_file) and os.path.exists(fps_file)):
        print(f"错误: 在目录 '{selected_dir}' 中缺少所需的数据文件。")
        print(f" - 检查是否存在: {os.path.basename(metrics_file)}")
        print(f" - 检查是否存在: {os.path.basename(fps_file)}")
        return

    try:
        metrics_df = pd.read_csv(metrics_file)
        # One row per algorithm: after sorting by Epoch (descending) the first
        # occurrence of each algorithm is kept.
        metrics_df = metrics_df.sort_values('Epoch', ascending=False).drop_duplicates('Algorithm')
        fps_df = pd.read_csv(fps_file)
    except FileNotFoundError as e:
        print(f"错误: 无法找到文件 {e.filename}")
        return
    except Exception as e:
        print(f"读取CSV文件时出错: {e}")
        return

    # 'Algorithm' in the metrics table corresponds to 'Model' in the speed table.
    merged_df = pd.merge(metrics_df, fps_df, left_on='Algorithm', right_on='Model')
    if merged_df.empty:
        print("错误: 数据合并失败。请检查 'Algorithm''Model' 列中的模型名称是否匹配。")
        return

    # Side outputs: performance summary table and per-class IoU table.
    T1_create_and_save_summary_table(merged_df, selected_dir, dataset_name)
    T2_extract_and_save_iou_data(metrics_df, selected_dir, dataset_name)

    merged_df['Family'] = merged_df['Model'].apply(get_model_family)

    # --- Plot ---
    plt.style.use('seaborn-v0_8-whitegrid')
    fig, ax = plt.subplots(figsize=(16, 10))

    family_names = sorted(merged_df['Family'].unique())
    family_colors = sns.color_palette("husl", len(family_names))
    marker_cycle = ['o', 's', 'X', 'D', '^', 'P', '*', 'v', '<', '>']

    for idx, family_name in enumerate(family_names):
        members = merged_df[merged_df['Family'] == family_name].sort_values('Average_FPS')
        colour = family_colors[idx]
        symbol = marker_cycle[idx % len(marker_cycle)]
        speeds = members['Average_FPS']
        scores = members['mIoU']

        ax.scatter(speeds, scores,
                   color=colour, marker=symbol, s=150, label=family_name, zorder=3)
        # Join the members of a family with a dashed line when there are several.
        if len(members) > 1:
            ax.plot(speeds, scores,
                    color=colour, linestyle='--', linewidth=1.5, zorder=2)
        # Label every point with the full model name, slightly to its right.
        for _, member in members.iterrows():
            ax.text(member['Average_FPS'] * 1.01, member['mIoU'], member['Model'],
                    fontsize=9, verticalalignment='center')

    ax.set_title(f'Model Performance vs. Inference Speed ({dataset_name})', fontsize=18, pad=20)
    ax.set_xlabel('Inference Speed (FPS)', fontsize=14)
    ax.set_ylabel('Mean IoU (%)', fontsize=14)
    ax.legend(title='Model Family', bbox_to_anchor=(1.02, 1), loc='upper left', borderaxespad=0.)
    plt.tight_layout(rect=[0, 0, 0.88, 1])  # leave room on the right for the legend
    plt.grid(True, which='both', linestyle='--', linewidth=0.5)

    # Save as PNG (600 dpi) and SVG, then display.
    figure_stem = f"F1_{dataset_name}_mIoU_vs_FPS"
    save_file_path_png = os.path.join(selected_dir, figure_stem + ".png")
    plt.savefig(save_file_path_png, dpi=600)
    save_file_path_svg = os.path.join(selected_dir, figure_stem + ".svg")
    plt.savefig(save_file_path_svg)
    print(f"\n图表已成功生成并保存为: {save_file_path_svg}{save_file_path_png}")
    plt.show()
def T1_create_and_save_summary_table(merged_df, output_dir, dataset_name):
    """
    Build the per-model performance summary table and save it as a CSV file.

    The table keeps the columns Model, mIoU, mAcc, aAcc, Average_FPS, FLOPs
    and Params, converts FLOPs/Params from unit-suffixed text to floats,
    renames three columns, sorts by mIoU (descending) and writes
    ``T1_<dataset_name>_performance_summary.csv`` into ``output_dir``.

    Args:
        merged_df (pd.DataFrame): Merged metrics/performance data.
        output_dir (str): Directory the CSV file is written to.
        dataset_name (str): Dataset name used in the output file name.

    Returns:
        None. Problems are reported with ``print`` instead of being raised.
    """
    print("正在创建摘要表格...")
    required_columns = ['Model', 'mIoU', 'mAcc', 'aAcc', 'Average_FPS', 'FLOPs', 'Params']
    if not all(col in merged_df.columns for col in required_columns):
        print("错误: DataFrame中缺少必要的列。请检查CSV文件内容。")
        return

    # Work on a copy so the caller's DataFrame is left untouched.
    summary_df = merged_df[required_columns].copy()

    # '118 G' -> 118.0. The unit is stripped whether or not a space precedes
    # it, and surrounding whitespace is ignored, so '118G' no longer crashes.
    summary_df['FLOPs'] = (
        summary_df['FLOPs'].astype(str).str.strip()
        .str.replace(r'\s*G$', '', regex=True).astype(float)
    )
    # '13.274M' -> 13.274 (same tolerance as above).
    summary_df['Params'] = (
        summary_df['Params'].astype(str).str.strip()
        .str.replace(r'\s*M$', '', regex=True).astype(float)
    )

    summary_df.rename(columns={
        'Average_FPS': 'FPS',
        'FLOPs': 'G(GFLOPS)',
        'Params': 'Para(Params)'
    }, inplace=True)

    # Best model first.
    summary_df = summary_df.sort_values(by='mIoU', ascending=False)

    summary_filename = f"T1_{dataset_name}_performance_summary.csv"
    summary_save_path = os.path.join(output_dir, summary_filename)
    try:
        summary_df.to_csv(summary_save_path, index=False, float_format='%.3f')
        print(f"摘要表格已成功保存到: {summary_save_path}")
    except Exception as e:
        print(f"保存摘要表格时出错: {e}")
def T2_extract_and_save_iou_data(metrics_df, output_dir, dataset_name):
    """
    Extract mIoU and every per-class ``*_IoU`` column and save them as CSV.

    The output file is ``T2_<dataset_name>_all_iou_summary.csv`` inside
    ``output_dir``, sorted by mIoU in descending order.

    Args:
        metrics_df (pd.DataFrame): Raw metrics table; must contain the
            'Algorithm' and 'mIoU' columns.
        output_dir (str): Directory the CSV file is written to.
        dataset_name (str): Dataset name used in the output file name.

    Returns:
        None. Problems are reported with ``print`` instead of being raised.
    """
    print("正在提取所有 mIoU 和 Class_IoU 数据...")

    # Both columns are indexed/sorted on below, so both must be present.
    for required in ('Algorithm', 'mIoU'):
        if required not in metrics_df.columns:
            print(f"错误: '{required}' 列未找到，无法继续。")
            return

    # 'Algorithm', 'mIoU', then every other column whose name ends in '_IoU'.
    # Non-string column labels are skipped instead of crashing on .endswith.
    class_iou_columns = [
        col for col in metrics_df.columns
        if isinstance(col, str) and col.endswith('_IoU') and col != 'mIoU'
    ]
    # dict.fromkeys removes duplicates while preserving order.
    iou_columns = list(dict.fromkeys(['Algorithm', 'mIoU'] + class_iou_columns))

    iou_df = metrics_df[iou_columns].copy()
    iou_df = iou_df.sort_values(by='mIoU', ascending=False)

    iou_filename = f"T2_{dataset_name}_all_iou_summary.csv"
    iou_save_path = os.path.join(output_dir, iou_filename)
    try:
        iou_df.to_csv(iou_save_path, index=False, float_format='%.2f')
        print(f"所有IoU数据已成功保存到: {iou_save_path}")
    except Exception as e:
        print(f"保存IoU数据时出错: {e}")
if __name__ == '__main__':
    # Root directory that holds the prediction results of every dataset.
    results_root_dir = '../BestMode_Predict_Results_DataSet_Public'
    # Ask the user which dataset to analyse.
    selection = select_dataset(results_root_dir)
    selected_directory, selected_dataset_name = selection
    # Plot only when both a directory and a dataset name were chosen.
    if not (selected_directory and selected_dataset_name):
        pass
    else:
        plot_performance_speed(selected_directory, selected_dataset_name)

View File

@@ -0,0 +1,413 @@
import os
import glob
import logging
import argparse
import re
import csv
from typing import Dict, Optional, List
from collections import defaultdict
# --- Logging configuration ---
# Must run before the first logging call: emitting a record on an
# unconfigured root logger installs a default handler, after which
# basicConfig() does nothing and INFO-level messages would be dropped.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
# --- Optional plotting dependencies ---
try:
    import pandas as pd
    import matplotlib.pyplot as plt
    MATPLOTLIB_AVAILABLE = True
except ImportError:
    # Plotting is optional; the CSV summary is still produced without it.
    MATPLOTLIB_AVAILABLE = False
    logging.warning(
        "未找到 'pandas''matplotlib' 库。"
        "脚本将只生成CSV文件无法自动绘图。"
        "请运行 'pip install pandas matplotlib' 来安装它们。"
    )
# --- Helper functions ---
# --- 辅助函数 (与 4_3 版本相同) ---
def find_all_log_files_sorted(algorithm_dir: str) -> List[str]:
    """
    Collect one ``.log`` file per run sub-directory, newest run first.

    Sub-directories are ordered by name in descending order (they are
    expected to be timestamp-named). When a sub-directory holds several
    ``.log`` files, the one whose name sorts last is chosen so the result
    is deterministic (``glob`` itself returns matches in arbitrary order).

    Args:
        algorithm_dir: Directory of one algorithm containing run sub-directories.

    Returns:
        List of log file paths; empty when the directory is missing, is not
        a directory, or contains no sub-directories.
    """
    try:
        subdirs = [
            entry for entry in os.listdir(algorithm_dir)
            if os.path.isdir(os.path.join(algorithm_dir, entry))
        ]
    except (FileNotFoundError, NotADirectoryError):
        logging.error(f"算法目录不存在: {algorithm_dir}")
        return []

    if not subdirs:
        logging.warning(f"在目录 {algorithm_dir} 中未找到任何时间戳子目录。")
        return []

    log_files = []
    for subdir_name in sorted(subdirs, reverse=True):
        candidates = glob.glob(os.path.join(algorithm_dir, subdir_name, '*.log'))
        if candidates:
            # max() by name: deterministic, and the newest for timestamp names.
            log_files.append(max(candidates))
    return log_files
def get_max_epochs(config_path: str) -> Optional[int]:
    """
    Read ``max_epochs`` from the ``train_cfg = dict(...)`` entry of a config file.

    The file is scanned as text with a regular expression; it is not executed.

    Args:
        config_path: Path to the config ``.py`` file.

    Returns:
        The integer value of ``max_epochs``, or None when the file does not
        exist, cannot be read, or holds no matching entry.
    """
    if not os.path.exists(config_path):
        logging.error(f"配置文件不存在: {config_path}")
        return None
    try:
        with open(config_path, 'r', encoding='utf-8') as f:
            content = f.read()
        # No trailing comma is required after the value, so
        # ``dict(type=..., max_epochs=100)`` is matched as well.
        match = re.search(
            r"train_cfg\s*=\s*dict\(.*?max_epochs\s*=\s*(\d+)", content, re.DOTALL
        )
        if match is None:
            logging.warning(f"{config_path} 中未找到 'max_epochs'")
            return None
        max_epochs = int(match.group(1))
        logging.info(f"{os.path.basename(config_path)} 中成功读取 max_epochs: {max_epochs}")
        return max_epochs
    except Exception as e:
        logging.error(f"解析 {config_path} 时出错: {e}")
        return None
def parse_log_file_data(log_path: str, max_epochs: int) -> Dict:
    """
    Extract training losses (grouped by epoch) and validation mIoU values from a log.

    Training lines are recognised by ``Iter(train) [cur/total] ... loss: <value>``.
    Each iteration is mapped to an epoch bucket with
    ``int(cur / total * max_epochs)``, where ``total`` is taken from the first
    ``Iter(train)`` line in the file.

    Args:
        log_path: Path of the log file.
        max_epochs: Number of epochs used to map iterations onto epochs.

    Returns:
        ``{'training_losses': {epoch: [loss, ...]}, 'validation_mious': [miou, ...]}``.
        Both containers are empty when the file cannot be read.
    """
    try:
        with open(log_path, 'r', encoding='utf-8') as f:
            content = f.read()
    except (IOError, UnicodeDecodeError) as e:
        logging.error(f"无法读取日志文件: {log_path}。错误: {e}")
        return {'training_losses': {}, 'validation_mious': []}

    training_losses_by_epoch = defaultdict(list)
    validation_mious = []

    # 1. Training losses. The total iteration count comes from the first
    #    'Iter(train) [cur/total]' occurrence.
    first_train_line = re.search(r"Iter\(train\)\s*\[\s*(\d+)\s*/\s*(\d+)\]", content)
    total_iters = int(first_train_line.group(2)) if first_train_line else 0

    if total_iters > 0:
        # The number pattern accepts plain decimals and scientific notation
        # (e.g. '2.5000e-01'); a bare '[\d.]+' would read that as 2.5.
        train_loss_pattern = re.compile(
            r"Iter\(train\)\s*\[\s*(\d+)\s*/\s*(\d+)\]"
            r"(?:.*?)loss:\s*((?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?)"
        )
        for match in train_loss_pattern.finditer(content):
            try:
                current_iter = int(match.group(1))
                loss = float(match.group(3))
            except ValueError as e:
                logging.warning(f"解析训练损失时出错: {e}")
                continue
            epoch = int((current_iter / total_iters) * max_epochs)
            training_losses_by_epoch[epoch].append(loss)
    else:
        logging.warning(f"{os.path.basename(log_path)} 中未找到有效的 'Iter(train)' 行来确定总迭代次数。")

    # 2. Validation mIoU from the 'Iter(val)' summary lines.
    val_summary_pattern = re.compile(
        r"Iter\(val\) \[\d+/\d+\]\s+aAcc:\s*[\d.]+\s+mIoU:\s*([\d.]+)\s+mAcc:\s*[\d.]+"
    )
    for match in val_summary_pattern.finditer(content):
        try:
            validation_mious.append(float(match.group(1)))
        except ValueError:
            # '[\d.]+' can match malformed text such as '1.2.3'; skip it.
            pass

    log_name = os.path.basename(log_path)
    if training_losses_by_epoch:
        logging.info(f"✅ 从 {log_name} 提取了 {len(training_losses_by_epoch)} 个Epoch的训练损失。")
    if validation_mious:
        logging.info(f"✅ 从 {log_name} 提取了 {len(validation_mious)} 次验证的mIoU。")
    if not training_losses_by_epoch and not validation_mious:
        logging.warning(f"❌ 在 {log_name} 中未找到训练损失或验证mIoU。")

    return {
        'training_losses': training_losses_by_epoch,
        'validation_mious': validation_mious
    }
# --- 新增:从 4_4 脚本中合并过来的绘图函数 ---
def plot_loss_curves(csv_path: str, y_limits=(0, 5)):
    """
    Plot one training-loss curve per algorithm from a loss summary CSV.

    The CSV must contain an 'Algorithm' column and columns named
    ``Epoch_<N>_Loss``. The figure is saved next to the CSV with the same
    base name and a ``.png`` extension.

    Args:
        csv_path (str): Path of the input CSV file.
        y_limits: ``(bottom, top)`` range of the y axis. Defaults to
            ``(0, 5)``, the range that used to be hard-coded; pass None to
            let matplotlib choose the range automatically.

    Returns:
        None. Problems are logged instead of being raised.
    """
    if not MATPLOTLIB_AVAILABLE:
        logging.warning("由于缺少 'pandas''matplotlib'，跳过绘图。")
        return
    if not os.path.exists(csv_path):
        logging.error(f"[绘图] 文件未找到: {csv_path}")
        return

    logging.info(f"[绘图] 正在读取数据: {os.path.basename(csv_path)}")
    try:
        df = pd.read_csv(csv_path)
    except Exception as e:
        logging.error(f"[绘图] 读取CSV时出错: {e}")
        return

    try:
        df = df.set_index('Algorithm')
    except KeyError:
        logging.error("[绘图] CSV文件中未找到 'Algorithm' 列。")
        return

    loss_cols = [col for col in df.columns if col.startswith('Epoch_') and col.endswith('_Loss')]
    if not loss_cols:
        logging.warning("[绘图] 在CSV中未找到任何 'Epoch_X_Loss' 列。")
        return

    try:
        # 'Epoch_12_Loss' -> 12
        epochs = [int(col.split('_')[1]) for col in loss_cols]
    except (ValueError, IndexError) as e:
        logging.error(f"[绘图] 解析Epoch列名时出错: {e}。列名格式应为 'Epoch_N_Loss'")
        return

    # Non-numeric cells become NaN and show up as gaps in the curve.
    df_losses = df[loss_cols].apply(pd.to_numeric, errors='coerce')

    logging.info("[绘图] 正在生成图表...")
    fig, ax = plt.subplots(figsize=(14, 8))
    try:
        for alg_name, row in df_losses.iterrows():
            # Plain solid line, no markers.
            ax.plot(epochs, row.values, label=alg_name, linestyle='-')

        ax.set_xlabel('Epoch', fontsize=12)
        ax.set_ylabel('Average Training Loss', fontsize=12)
        ax.set_title(f'Training Loss per Epoch\n(Source: {os.path.basename(csv_path)})', fontsize=14)
        ax.grid(True, linestyle=':', alpha=0.7)
        if y_limits is not None:
            ax.set_ylim(bottom=y_limits[0], top=y_limits[1])

        # Legend outside the axes on the right; the layout leaves room for it.
        ax.legend(bbox_to_anchor=(1.04, 1), loc='upper left', title="Algorithms")
        fig.tight_layout(rect=[0, 0, 0.85, 1])

        output_png_path = os.path.splitext(csv_path)[0] + '.png'
        try:
            # Save through the figure object so the right figure is written
            # even if another figure has become pyplot's current one.
            fig.savefig(output_png_path, dpi=150, bbox_inches='tight')
            logging.info(f"✅ 图表已成功保存到: {output_png_path}")
        except IOError as e:
            logging.error(f"[绘图] 保存图像时出错: {e}")
    finally:
        # Always release the figure, including when plotting itself fails.
        plt.close(fig)
# --- 主函数 (小幅修改以调用绘图) ---
def main(args):
    """
    Script entry point: pick a dataset and algorithm(s) interactively,
    aggregate their training logs, write the per-epoch loss summary CSV and
    plot it.

    Args:
        args: Parsed command-line namespace; ``input_dir`` and ``output_dir``
            are read from it.
    """
    input_root, output_root = args.input_dir, args.output_dir
    if not os.path.isdir(input_root):
        logging.error(f"输入目录不存在: {input_root}")
        return

    separator = "=" * 50

    # --- Step 1: choose the dataset ---
    all_dataset_dirs = sorted(
        path for path in glob.glob(os.path.join(input_root, '*_outputs-MMSeg'))
        if os.path.isdir(path)
    )
    if not all_dataset_dirs:
        logging.error(f"在输入目录 {input_root} 中未找到任何有效的数据集文件夹。")
        return

    dataset_map = {str(number): path for number, path in enumerate(all_dataset_dirs, start=1)}
    print("\n" + separator)
    print("--- 步骤 1: 请选择要处理的数据集 ---")
    for number, path in dataset_map.items():
        print(f"{number}: {os.path.basename(path)}")
    print(separator)

    choice1 = input("请输入数据集编号并按回车键: ").strip()
    if choice1 not in dataset_map:
        logging.error("无效的数据集选择，程序已退出。")
        return
    selected_dataset_dir = dataset_map[choice1]
    logging.info(f"您已选择数据集: [{os.path.basename(selected_dataset_dir)}]")

    # --- Step 2: choose one algorithm, or all of them ---
    alg_dirs = sorted(
        path for path in glob.glob(os.path.join(selected_dataset_dir, '*'))
        if os.path.isdir(path)
    )
    if not alg_dirs:
        logging.warning(f"{os.path.basename(selected_dataset_dir)} 中未发现任何算法子文件夹。")
        return

    alg_map = {str(number): path for number, path in enumerate(alg_dirs, start=1)}
    print("\n" + separator)
    print("--- 步骤 2: 请选择要处理的算法 ---")
    print("0: 批量处理当前数据集下的【全部】算法")
    for number, path in alg_map.items():
        print(f"{number}: {os.path.basename(path)}")
    print(separator)

    choice2 = input("请输入算法编号 (或输入 '0' 处理全部) 并按回车键: ").strip()
    if choice2 == '0':
        model_dirs = alg_dirs
        logging.info(f"您选择了批量处理全部 {len(model_dirs)} 个算法。")
    elif choice2 in alg_map:
        model_dirs = [alg_map[choice2]]
        logging.info(f"您选择了处理单个算法: {os.path.basename(model_dirs[0])}")
    else:
        logging.error("无效的算法选择，程序已退出。")
        return

    # --- Aggregate the logs of every selected algorithm ---
    csv_rows = []
    for model_dir in model_dirs:
        model_name = os.path.basename(model_dir)
        logging.info(f"\n--- 开始处理算法: {model_name} ---")

        all_logs_sorted = find_all_log_files_sorted(model_dir)
        if not all_logs_sorted:
            logging.warning(f"跳过算法 {model_name}，因为未找到任何日志文件。")
            continue

        # max_epochs is read from the config stored next to the newest log.
        newest_run_dir = os.path.dirname(all_logs_sorted[0])
        config_path = os.path.join(newest_run_dir, 'vis_data', 'config.py')
        max_epochs = get_max_epochs(config_path)
        if max_epochs is None:
            logging.error(f"无法为算法 {model_name} 找到 max_epochs将跳过。")
            continue

        all_losses_for_model = defaultdict(list)
        all_mious_for_model = []
        for log_file_path in all_logs_sorted:
            logging.info(f"正在解析: {os.path.relpath(log_file_path)}")
            parsed = parse_log_file_data(log_file_path, max_epochs)
            for epoch, losses in parsed['training_losses'].items():
                all_losses_for_model[epoch].extend(losses)
            # Extending with an empty list is a no-op, so no emptiness check.
            all_mious_for_model.extend(parsed['validation_mious'])

        if not (all_losses_for_model or all_mious_for_model):
            logging.warning(f"❌❌❌跳过算法 {model_name}，因为在其所有日志文件中都未能找到有效的训练或验证数据。❌❌❌")
            continue

        best_miou = 0.0
        if not all_mious_for_model:
            logging.warning(f"算法 {model_name} 没有找到任何mIoU数据。")
        else:
            try:
                best_miou = max(all_mious_for_model)
                logging.info(f"找到了最佳 mIoU: {best_miou:.4f}")
            except (ValueError, TypeError) as e:
                logging.error(f"为算法 {model_name} 寻找最佳mIoU时出错: {e}")

        # Keep only the part of the directory name after the first 'Alg_'.
        _, marker, tail = model_name.partition('Alg_')
        row_data = {'Algorithm': tail if marker else model_name}

        sorted_epochs = sorted(all_losses_for_model.keys())
        if not sorted_epochs:
            logging.warning(f"算法 {model_name} 没有找到任何训练损失数据。")
        for epoch in sorted_epochs:
            epoch_losses = all_losses_for_model[epoch]
            row_data[f'Epoch_{epoch}_Loss'] = f"{sum(epoch_losses) / len(epoch_losses):.4f}"
        row_data['Best_mIoU'] = f"{best_miou:.4f}"
        csv_rows.append(row_data)

    # --- Write the CSV summary ---
    if not csv_rows:
        logging.info("没有成功获取任何模型的统计数据，不生成 CSV 文件。")
        return

    # Rows may cover different epochs, so the header is the union of all keys.
    all_fieldnames_set = set()
    for row in csv_rows:
        all_fieldnames_set.update(row.keys())
    loss_fields = [name for name in all_fieldnames_set if name.startswith('Epoch_')]
    try:
        loss_fields.sort(key=lambda name: int(name.split('_')[1]))
    except (ValueError, IndexError):
        logging.error("排序Epoch列名时出错将按字母顺序排序。")
        loss_fields.sort()
    final_fieldnames = ['Algorithm'] + loss_fields + ['Best_mIoU']

    dataset_dir_name = os.path.basename(selected_dataset_dir)
    final_output_dir = os.path.join(output_root, dataset_dir_name)
    os.makedirs(final_output_dir, exist_ok=True)
    dataset_name = dataset_dir_name.split('_outputs-MMSeg')[0]
    output_csv_path = os.path.join(final_output_dir, f'{dataset_name}_training_loss_summary.csv')

    try:
        with open(output_csv_path, 'w', newline='', encoding='utf-8-sig') as csvfile:
            writer = csv.DictWriter(csvfile, fieldnames=final_fieldnames, extrasaction='ignore')
            writer.writeheader()
            writer.writerows(csv_rows)
        logging.info(f"\n=== CSV文件已成功保存到: {output_csv_path} ===")
        # Plot only after the CSV has been written successfully.
        plot_loss_curves(output_csv_path)
    except IOError as e:
        logging.error(f"无法写入 CSV 文件: {output_csv_path}。错误: {e}")
    except Exception as e_plot:
        # Any other failure raised while writing or plotting.
        logging.error(f"在主流程中调用绘图时出错: {e_plot}")
if __name__ == '__main__':
    # Command-line interface: both directories are optional and default to
    # the project's relative layout.
    parser = argparse.ArgumentParser(
        description="MMSegmentation 训练损失提取与绘图脚本 (V3-Integrated)"
    )
    cli_options = (
        ('--input_dir', '../Hardisk',
         "包含数据集输出文件夹 (例如 '..._outputs-MMSeg') 的根目录。"),
        ('--output_dir', '../BestMode_Predict_Results_DataSet_Public',
         "用于存储所有分析结果 (CSV和PNG) 的根目录。"),
    )
    for flag, default_value, help_text in cli_options:
        parser.add_argument(flag, type=str, default=default_value, help=help_text)
    args = parser.parse_args()
    main(args)

View File

@@ -0,0 +1,28 @@
from .Initial_Alg_Gen_Tool import get_var_from_file
# 生成 _base_ 变量传入算法配置、数据集配置、schedule配置
def generate_base_config(alg_file_name, dataset_file_name, schedule_file_name):
    """
    Build the ``_base_`` list of config files to inherit from.

    Args:
        alg_file_name: Model base file name without ``.py``; None omits the
            model entry.
        dataset_file_name: Dataset base file name without ``.py``.
        schedule_file_name: Schedule base file name without ``.py``.

    Returns:
        list[str]: Relative config paths in the order model (optional),
        dataset, default runtime, schedule.
    """
    base_config = [
        f'../_base_/datasets/{dataset_file_name}.py',  # user-defined dataset
        '../_base_/default_runtime.py',
        f'../_base_/schedules/{schedule_file_name}.py',
    ]
    if alg_file_name is not None:
        # The model file, when given, always comes first.
        base_config.insert(0, f'../_base_/models/{alg_file_name}.py')
    return base_config
if __name__ == '__main__':
    # Example usage with one model, dataset and schedule base file.
    alg_file_name = 'ann_r50-d8'  # model base file
    dataset_file_name = 'my_dataset_model'  # dataset base file
    schedule_file_name = 'schedule_4k_check_400'  # schedule base file
    _base_ = generate_base_config(
        alg_file_name=alg_file_name,
        dataset_file_name=dataset_file_name,
        schedule_file_name=schedule_file_name,
    )
    print(_base_)

View File

@@ -0,0 +1,19 @@
from .Initial_Alg_Gen_Tool import get_var_from_file
# 单卡 norm_cfg = dict(type='BN')
# 多卡 norm_cfg = dict(type='SyncBN')
def generate_norm_cfg(GPU_num = 2):
    """
    Choose the normalisation layer config from the number of GPUs.

    Args:
        GPU_num: Number of GPUs; converted with ``int()``.

    Returns:
        dict: ``{'type': 'BN'}`` for one GPU, ``{'type': 'SyncBN'}`` for more.

    Raises:
        ValueError: If the GPU count is smaller than 1.
    """
    gpu_count = int(GPU_num)
    if gpu_count < 1:
        raise ValueError("GPU_num需要为大于等于1的整数数值")
    norm_type = 'BN' if gpu_count == 1 else 'SyncBN'
    return dict(type=norm_type)
if __name__ == '__main__':
    # Example usage: a single GPU selects plain BN.
    GPU_num = 1
    norm_cfg = generate_norm_cfg(GPU_num)
    print(norm_cfg)

View File

@@ -0,0 +1,23 @@
from .Initial_Alg_Gen_Tool import get_var_from_file
# crop_size 数据预处理是分割大小
# crop_size = (512,512)
def generate_data_preprocessor(crop_size=None, mean=None, std=None, bgr_to_rgb=False):
    """
    Build the ``data_preprocessor`` overrides from the values that were given.

    Args:
        crop_size: Value stored under 'size', e.g. ``(512, 512)``.
        mean: Value stored under 'mean'.
        std: Value stored under 'std'.
        bgr_to_rgb: Value stored under 'bgr_to_rgb'. It defaults to False, so
            the key is present unless None is passed explicitly.

    Returns:
        dict: Only the keys whose argument is not None.
    """
    candidates = (
        ('size', crop_size),
        ('mean', mean),
        ('std', std),
        ('bgr_to_rgb', bgr_to_rgb),
    )
    # Identity test: ``!= None`` compares element-wise on array-like values
    # and then fails when used as a condition.
    return {key: value for key, value in candidates if value is not None}
if __name__ == '__main__':
    # Example usage: only the crop size is overridden.
    crop_size = (512, 512)
    data_preprocessor = generate_data_preprocessor(crop_size=crop_size)
    print(data_preprocessor)

View File

@@ -0,0 +1,119 @@
from .Initial_Alg_Gen_Tool import get_var_from_file, format_dict
# 1. 修改pretrained
# pretrained_pth = './My_Local_Model/open_mmlab/resnet50_v1c.pth')
def generate_model_pretrained(pretrained_pth=None):
    """
    Return the value to use for the model's ``pretrained`` field.

    Args:
        pretrained_pth: Path of the pretrained weights, or None.

    Returns:
        The given path unchanged, or None when no path was given.
    """
    # The former ``== None`` branch returned None for None input, which is
    # exactly what returning the argument does.
    return pretrained_pth
# 2. 修改backbone
# depth=50
def generate_model_backbone(depth=None):
    """
    Build the ``backbone`` overrides.

    Args:
        depth: Backbone depth (e.g. 50); omitted from the result when None.

    Returns:
        dict: ``{'depth': depth}`` or an empty dict.
    """
    return {} if depth is None else {'depth': depth}
# 3. 修改model_data_preprocessor
# model_data_preprocessor='data_preprocessor'
def generate_model_data_preprocessor(model_data_preprocessor=None):
    """
    Return the value to use for the model's ``data_preprocessor`` field.

    Args:
        model_data_preprocessor: Value to pass through, or None.

    Returns:
        The argument unchanged.
    """
    # The original self-assignment under ``!= None`` had no effect.
    return model_data_preprocessor
# 4. 修改decode_head
# num_classes='36' # 分割数
# decode_head_loss_decode_dict=dict(type='DiceLoss', use_sigmoid=False, loss_weight=1.0) # 直接传入dict
# align_corners=False是否需要角对齐
def generate_model_decode_head(num_classes=None, decode_head_loss_decode_dict=None, align_corners=None):
    """
    Build the ``decode_head`` overrides from the values that were given.

    Args:
        num_classes: Number of segmentation classes.
        decode_head_loss_decode_dict: Loss config stored under 'loss_decode'.
        align_corners: Value stored under 'align_corners'.

    Returns:
        dict: Only the keys whose argument is not None.
    """
    candidates = (
        ('num_classes', num_classes),
        ('loss_decode', decode_head_loss_decode_dict),
        ('align_corners', align_corners),
    )
    # ``is not None`` keeps falsy values such as ``align_corners=False``.
    return {key: value for key, value in candidates if value is not None}
# 5. 修改auxiliary_head
# num_classes='36' # 分割数
# auxiliary_head_loss_decode_dict=dict(type='DiceLoss', use_sigmoid=False, loss_weight=0.4) # 直接传入dict
# align_corners=False是否需要角对齐
def generate_model_auxiliary_head(num_classes=None, auxiliary_head_loss_decode_dict=None, align_corners=None):
    """
    Build the ``auxiliary_head`` overrides from the values that were given.

    Args:
        num_classes: Number of segmentation classes.
        auxiliary_head_loss_decode_dict: Loss config stored under 'loss_decode'.
        align_corners: Value stored under 'align_corners'.

    Returns:
        dict: Only the keys whose argument is not None.
    """
    candidates = (
        ('num_classes', num_classes),
        ('loss_decode', auxiliary_head_loss_decode_dict),
        ('align_corners', align_corners),
    )
    # ``is not None`` keeps falsy values such as ``align_corners=False``.
    return {key: value for key, value in candidates if value is not None}
# 6. 修改train_cfg
def generate_model_train_cfg():
    """Return the model ``train_cfg`` overrides (currently always an empty dict)."""
    return dict()
# 6. 修改test_cfg
# mode='slide'
# crop_size = (767,767)
# test_cfg_crop_div_stride = 1.5
def generate_model_test_cfg(test_cfg_mode=None, crop_size=None, test_cfg_crop_div_stride=None):
    """
    Build the model ``test_cfg`` overrides.

    Args:
        test_cfg_mode: Inference mode stored under 'mode', e.g. 'slide'.
        crop_size: Crop size stored under 'crop_size', e.g. ``(767, 767)``.
        test_cfg_crop_div_stride: Divisor applied to each crop side to get
            the stride, e.g. 1.5 (results are truncated to int).

    Returns:
        dict: Only the keys whose inputs were given.

    Raises:
        ValueError: If a stride divisor is given without ``crop_size``.
    """
    test_cfg = {}
    if test_cfg_mode is not None:
        test_cfg['mode'] = test_cfg_mode
    if crop_size is not None:
        test_cfg['crop_size'] = crop_size
    if test_cfg_crop_div_stride is not None:
        # The stride is derived from the crop size, so the crop size is
        # mandatory here; fail with a clear message instead of an opaque
        # "'NoneType' object is not subscriptable".
        if crop_size is None:
            raise ValueError("crop_size is required when test_cfg_crop_div_stride is given")
        test_cfg['stride'] = (
            int(crop_size[0] / test_cfg_crop_div_stride),
            int(crop_size[1] / test_cfg_crop_div_stride),
        )
    return test_cfg
if __name__ == '__main__':
    # Example usage: build every model sub-config and print the merged result.
    # 1. pretrained weights
    pretrained_pth = './My_Local_Model/open_mmlab/resnet50_v1c.pth'
    pretrained = generate_model_pretrained(pretrained_pth=pretrained_pth)
    # 2. backbone
    depth = 50
    backbone = generate_model_backbone(depth=depth)
    # 3. data_preprocessor
    model_data_preprocessor = 'data_preprocessor'
    model_data_preprocessor = generate_model_data_preprocessor(model_data_preprocessor=model_data_preprocessor)
    # 4/5. decode_head and auxiliary_head
    num_classes = '36'  # number of segmentation classes
    decode_head_loss_decode_dict = dict(type='DiceLoss', use_sigmoid=False, loss_weight=1.0)
    align_corners = False  # whether corner alignment is required
    auxiliary_head_loss_decode_dict = dict(type='DiceLoss', use_sigmoid=False, loss_weight=0.4)
    decode_head = generate_model_decode_head(
        num_classes=num_classes,
        decode_head_loss_decode_dict=decode_head_loss_decode_dict,
        align_corners=align_corners,
    )
    auxiliary_head = generate_model_auxiliary_head(
        num_classes=num_classes,
        auxiliary_head_loss_decode_dict=auxiliary_head_loss_decode_dict,
        align_corners=align_corners,
    )
    # 6. train_cfg
    train_cfg = generate_model_train_cfg()
    # 7. test_cfg
    mode = 'slide'
    crop_size = (767, 767)
    test_cfg_crop_div_stride = 1.5
    # The keyword is ``test_cfg_mode``; passing ``mode=`` raises TypeError.
    test_cfg = generate_model_test_cfg(
        test_cfg_mode=mode,
        crop_size=crop_size,
        test_cfg_crop_div_stride=test_cfg_crop_div_stride,
    )
    # Merge everything into the model config.
    model = dict(
        pretrained=pretrained,
        backbone=backbone,
        data_preprocessor=model_data_preprocessor,
        decode_head=decode_head,
        auxiliary_head=auxiliary_head,
        train_cfg=train_cfg,
        test_cfg=test_cfg,
    )
    model_ = format_dict(model)
    print(model_)

View File

@@ -0,0 +1,173 @@
# optimizer优化器设计TODO
# type_of_back_bone = "Vit"
def generate_optim_wrapper(type_of_back_bone=None):
    """
    Interactively choose one of the predefined ``optim_wrapper`` configs.

    For ViT and Swin backbones a ``paramwise_cfg`` is attached that disables
    weight decay on position-embedding, class-token and normalisation
    parameters.

    Args:
        type_of_back_bone: Backbone type name (case-insensitive), e.g. 'vit',
            'visiontransformer' or 'swin'. Any other value, or None, adds no
            ``paramwise_cfg``.

    Returns:
        dict: The selected optimizer wrapper config. Pressing Enter selects
        the first option; the prompt repeats until the input is valid.
    """
    # Option 1: AdamW
    optim_wrapper_1 = {
        'type': 'OptimWrapper',
        '_delete_': True,  # replace, rather than merge with, an inherited config
        'optimizer': {
            'type': 'AdamW',
            'lr': 0.0001,
            'weight_decay': 0.0005
        },
        'clip_grad': {
            'max_norm': 1,  # maximum gradient norm
            'norm_type': 2  # L2 norm
        }
    }
    # Option 2: SGD with momentum
    optim_wrapper_2 = {
        'type': 'OptimWrapper',
        '_delete_': True,
        'optimizer': {
            'type': 'SGD',
            'lr': 0.05,
            'weight_decay': 0.0005,
            'momentum': 0.9
        },
        'clip_grad': {
            'max_norm': 1,
            'norm_type': 2
        }
    }
    optim_wrapper_list = [optim_wrapper_1, optim_wrapper_2]

    # Backbone-specific parameters that must not receive weight decay.
    backbone_key = type_of_back_bone.lower() if type_of_back_bone is not None else None
    if backbone_key in ('vit', 'visiontransformer'):
        custom_keys = {
            'pos_embed': dict(decay_mult=0.),  # positional embeddings
            'cls_token': dict(decay_mult=0.),  # classification token
            'norm': dict(decay_mult=0.)        # normalisation layers
        }
    elif backbone_key == 'swin':
        custom_keys = {
            'absolute_pos_embed': dict(decay_mult=0.),
            'relative_position_bias_table': dict(decay_mult=0.),
            'norm': dict(decay_mult=0.)
        }
    else:
        custom_keys = None

    if custom_keys is not None:
        for wrapper in optim_wrapper_list:
            # The optimizer-wrapper constructor looks the per-parameter rules
            # up under the 'custom_keys' entry of paramwise_cfg; a flat copy
            # of the mapping (``dict(custom_keys)``) would be ignored.
            wrapper['paramwise_cfg'] = dict(
                custom_keys={name: dict(rule) for name, rule in custom_keys.items()}
            )
            wrapper['_delete_'] = True

    while True:
        print("请选择 optim_wrapper (按 Enter 使用默认):")
        for i, wrapper in enumerate(optim_wrapper_list, 1):
            optimizer_type = wrapper['optimizer']['type']
            learning_rate = wrapper['optimizer']['lr']
            print(f"{i}. Optimizer: {optimizer_type}, LR: {learning_rate}")

        # Surrounding whitespace is ignored, matching generate_param_scheduler.
        choice = input(f"输入 1 到 {len(optim_wrapper_list)} 进行选择（默认 1）: ").strip()
        if choice == '' or choice == '1':
            return optim_wrapper_list[0]
        elif choice.isdigit() and 1 <= int(choice) <= len(optim_wrapper_list):
            return optim_wrapper_list[int(choice) - 1]
        else:
            print(f"无效输入，请输入 1 到 {len(optim_wrapper_list)} 或按 Enter 使用默认值")
def generate_param_scheduler(train_type, train_time_or_epoch):
    """
    Build the candidate learning-rate schedules and let the user pick one.

    Every candidate is a linear warm-up followed by a main schedule
    (PolyLR with power 0.9, PolyLR with power 1.0, or MultiStepLR).

    Args:
        train_type (str): 'epoch' for epoch-based training; any other value
            selects iteration-based training.
        train_time_or_epoch (int): Number of epochs, or number of iterations
            in thousands for iteration-based training.

    Returns:
        list[dict]: The selected scheduler list. Pressing Enter selects the
        first candidate; the prompt repeats until the input is valid.
    """
    # 1. Core parameters for the chosen training mode.
    if train_type.lower() == 'epoch':
        by_epoch = True
        end_value = train_time_or_epoch
        warmup_end = 10  # 10 warm-up epochs
    else:
        by_epoch = False
        end_value = train_time_or_epoch * 1000
        warmup_end = 1500  # 1500 warm-up iterations

    # A warm-up as long as (or longer than) the whole run would make the main
    # schedule start at or after its own end. Shorten it to half the run in
    # that case; runs long enough for the default warm-up are unaffected.
    if warmup_end >= end_value:
        warmup_end = max(int(end_value) // 2, 0)

    # MultiStepLR decays at 75% and 90% of the run.
    milestones = [int(end_value * 0.75), int(end_value * 0.9)]

    def _warmup(start_factor):
        return dict(
            type='LinearLR',
            start_factor=start_factor,
            by_epoch=by_epoch,
            begin=0,
            end=warmup_end)

    def _with_warmup(start_factor, main_schedule):
        # A zero-length warm-up is dropped instead of emitting begin == end.
        if warmup_end > 0:
            return [_warmup(start_factor), main_schedule]
        return [main_schedule]

    # 2. Candidate schedules.
    param_scheduler_1 = _with_warmup(1e-6, dict(
        type='PolyLR',
        power=0.9,
        begin=warmup_end,
        end=end_value,
        eta_min=1e-5,
        by_epoch=by_epoch))
    param_scheduler_2 = _with_warmup(1e-6, dict(
        type='PolyLR',
        power=1.0,
        begin=warmup_end,
        end=end_value,
        eta_min=0.0,
        by_epoch=by_epoch))
    param_scheduler_3 = _with_warmup(0.001, dict(
        type='MultiStepLR',
        begin=warmup_end,
        end=end_value,
        milestones=milestones,
        by_epoch=by_epoch))
    param_scheduler_list = [param_scheduler_1, param_scheduler_2, param_scheduler_3]

    # 3. Show the candidates and ask the user.
    while True:
        print("请选择 param_scheduler (学习率调度器):")
        for i, schedulers in enumerate(param_scheduler_list, 1):
            print(f"Scheduler {i}:")
            for scheduler in schedulers:
                # '/' stands for a field the scheduler does not define.
                scheduler_type = scheduler.get('type', '/')
                begin = scheduler.get('begin', '/')
                end = scheduler.get('end', '/')
                power = scheduler.get('power', '/')
                eta_min = scheduler.get('eta_min', '/')
                milestones_val = scheduler.get('milestones', '/')
                if scheduler_type == 'MultiStepLR':
                    print(f" - {scheduler_type}: begin={begin}, end={end}, milestones={milestones_val}")
                else:
                    print(f" - {scheduler_type}: begin={begin}, end={end}, power={power}, eta_min={eta_min}")

        choice = input(f"输入 1 到 {len(param_scheduler_list)} 进行选择（默认 1）: ").strip()
        if choice == '' or choice == '1':
            return param_scheduler_list[0]
        elif choice.isdigit() and 1 <= int(choice) <= len(param_scheduler_list):
            return param_scheduler_list[int(choice) - 1]
        else:
            print(f"无效输入，请输入 1 到 {len(param_scheduler_list)} 或按 Enter 使用默认值\n")

View File

@@ -0,0 +1,59 @@
# train_dataloader TODO
def _prompt_positive_int(prompt_label, error_label, setting_name, default_value):
    """Ask for a positive integer; empty input selects ``default_value``."""
    while True:
        user_input = input(f"请输入 {prompt_label} (默认为 {default_value}): ")
        if not user_input.strip():
            chosen = default_value
            break
        try:
            chosen = int(user_input)
        except ValueError:
            print("输入无效，请输入一个有效的整数。")
            continue
        if chosen > 0:
            break
        print(f"{error_label} 必须是正整数，请重新输入。")
    print(f"将train_dataloader的{setting_name}设置为{chosen}")
    return chosen


def generate_train_dataloader(batch_size_default=None, num_workers_default=None):
    """
    Interactively collect ``batch_size`` and ``num_workers`` for the train dataloader.

    A setting is prompted for only when its default is not None. Empty input
    accepts the default; anything else must be a positive integer.

    Args:
        batch_size_default: Default batch size, or None to skip the setting.
        num_workers_default: Default worker count, or None to skip the setting.

    Returns:
        tuple: The shape depends on which settings were requested:
            * both given:       ``(train_dataloader, batch_size, num_workers)``
            * only batch size:  ``(train_dataloader, batch_size)``
            * only num workers: ``(train_dataloader, num_workers)``
            * neither:          ``({'batch_size': None}, None)``
    """
    batch_size = None
    if batch_size_default is not None:
        batch_size = _prompt_positive_int(
            'batch size', 'Batch size', 'batch_size', batch_size_default)

    num_workers = None
    if num_workers_default is not None:
        num_workers = _prompt_positive_int(
            'num workers', 'Num workers', 'num_workers', num_workers_default)

    train_dataloader = {}
    # The order of these checks fixes the result for the "neither" case.
    if num_workers is None:
        train_dataloader['batch_size'] = batch_size
        return train_dataloader, batch_size
    if batch_size is None:
        train_dataloader['num_workers'] = num_workers
        return train_dataloader, num_workers
    train_dataloader['batch_size'] = batch_size
    train_dataloader['num_workers'] = num_workers
    return train_dataloader, batch_size, num_workers

View File

@@ -0,0 +1,264 @@
import ast, subprocess, os
# 从文件中获取特定变量信息,返回变量直和类型
def _config_node_to_value(node):
    """Turn an AST value node into a Python object without executing code.

    Handles literals plus keyword-only ``dict(...)`` calls, also when they
    are nested inside list, tuple or dict displays.
    """
    if (isinstance(node, ast.Call) and isinstance(node.func, ast.Name)
            and node.func.id == 'dict' and not node.args):
        if any(keyword.arg is None for keyword in node.keywords):
            raise ValueError("dict(**mapping) cannot be evaluated statically")
        return {keyword.arg: _config_node_to_value(keyword.value) for keyword in node.keywords}
    if isinstance(node, ast.List):
        return [_config_node_to_value(item) for item in node.elts]
    if isinstance(node, ast.Tuple):
        return tuple(_config_node_to_value(item) for item in node.elts)
    if isinstance(node, ast.Dict):
        if any(key is None for key in node.keys):
            raise ValueError("dict unpacking cannot be evaluated statically")
        return {
            _config_node_to_value(key): _config_node_to_value(value)
            for key, value in zip(node.keys, node.values)
        }
    # Everything else must be a plain literal.
    return ast.literal_eval(node)


def get_var_from_file(filename, var_name="norm_cfg"):
    """
    Read the value assigned to a variable in a Python file without running it.

    The file is parsed into an AST, and the first assignment to ``var_name``
    found by ``ast.walk`` is evaluated statically. Besides plain literals,
    keyword-only ``dict(...)`` calls such as ``dict(type='SyncBN')`` are
    supported, which plain ``ast.literal_eval`` rejects.

    Args:
        filename: Path of the Python file (read as UTF-8).
        var_name: Name of the variable to look up.

    Returns:
        tuple: ``(value, type(value))``, or ``(None, None)`` when the
        variable is not assigned in the file.

    Raises:
        ValueError: If the assigned expression cannot be evaluated statically.
    """
    with open(filename, 'r', encoding='utf-8') as file:
        code = file.read()
    tree = ast.parse(code)
    for node in ast.walk(tree):
        if not isinstance(node, ast.Assign):
            continue
        for target in node.targets:
            if isinstance(target, ast.Name) and target.id == var_name:
                var_value = _config_node_to_value(node.value)
                return var_value, type(var_value)
    return None, None
def get_var_from_py_file(file_path, var_name="auxiliary_head"):
    """
    Execute a Python file and return one of the variables it defines.

    Args:
        file_path: Path of the Python file (read as UTF-8).
        var_name: Name of the variable to fetch.

    Returns:
        The value bound to ``var_name`` after the file has been executed.

    Raises:
        AttributeError: If the file does not define ``var_name``.
    """
    # Namespace that receives every name the executed file defines.
    context = {}
    # UTF-8 is explicit so files with non-ASCII comments load regardless of
    # the platform's default encoding (get_var_from_file does the same).
    with open(file_path, 'r', encoding='utf-8') as file:
        source = file.read()
    # NOTE(review): exec runs arbitrary code from the file; only use this on
    # trusted config files.
    exec(source, context)
    if var_name in context:
        return context[var_name]
    else:
        raise AttributeError(f"文件中没有找到 {var_name} 变量")
# 更新list的dict变量
def update_list_dict_var(var, var_new):
    """
    Merge a list of override dicts into a list of dicts, position by position.

    Args:
        var: List of dicts; updated in place.
        var_new: List of dicts whose entries override ``var`` at the same index.

    Returns:
        The same ``var`` object after the update.

    Raises:
        ValueError: If the two lists differ in length.
    """
    old_len, new_len = len(var), len(var_new)
    if old_len != new_len:
        raise ValueError(f"var {old_len} 和 var_new {new_len} 的大小不相等，无法更新")
    position = 0
    while position < new_len:
        # Keys present in the override replace those in the target entry.
        var[position].update(var_new[position])
        position += 1
    return var
if __name__ == '__main__':
    # Example usage. get_var_from_file opens its argument as a file, so the
    # example must point at a config file rather than the models directory.
    # NOTE(review): 'ann_r50-d8.py' is assumed to exist under the models
    # directory - adjust to a file present in this checkout.
    example_config = './configs/_base_/models/ann_r50-d8.py'
    var_value, var_type = get_var_from_file(example_config, 'norm_cfg')
    if var_value is not None:
        print(f"变量名: norm_cfg\n值: {var_value}\n类型: {var_type}")
    else:
        print("未找到指定变量")
# 将文件以dict格式输出
def format_dict(d, indent_level=1):
    """
    Render a dict as multi-line ``dict(key=value, ...)`` source text.

    Nested dicts are rendered recursively with one more level of indentation.
    Strings are emitted through ``repr`` so values containing quotes or
    backslashes still form valid Python literals; other values use ``str``.

    Args:
        d (dict): Mapping to render; keys must be valid identifiers.
        indent_level (int): Nesting level of the entries (1 for top level).

    Returns:
        str: The ``dict(...)`` source text with a trailing comma per entry.
    """
    indent = ' ' * indent_level
    closing_indent = ' ' * (indent_level - 1)
    if not d:
        # An empty mapping has no entries; avoid emitting a lone comma.
        return f"dict(\n{closing_indent})"

    formatted_items = []
    for key, value in d.items():
        if isinstance(value, dict):
            rendered = format_dict(value, indent_level + 1)
        elif isinstance(value, str):
            rendered = repr(value)
        else:
            rendered = str(value)
        formatted_items.append(f"{indent}{key}={rendered}")

    formatted_str = ",\n".join(formatted_items)
    return f"dict(\n{formatted_str},\n{closing_indent})"
# 所有格式正确输出
def format_all_data_old(data, indent_level=0):
    """
    Legacy formatter that renders a value as config source text.

    Dicts become ``dict(key=value, ...)`` and lists become ``[...]``, one
    entry per line. Nested values are delegated to ``format_all_data``.
    Strings are single-quoted unless they contain a single quote, in which
    case they are double-quoted. Everything else uses ``str``.

    Args:
        data: Value to render.
        indent_level (int): Nesting level used for indentation.

    Returns:
        str: The rendered text.
    """
    pad = ' ' * indent_level

    if isinstance(data, str):
        # Pick the quote character that does not clash with the content.
        quote = '"' if "'" in data else "'"
        return f"{quote}{data}{quote}"

    if isinstance(data, tuple):
        return f"{data}"

    if isinstance(data, dict):
        entries = []
        for name, item in data.items():
            rendered = format_all_data(item, indent_level + 1)
            entries.append(f"{pad} {name}={rendered}")
        body = ",\n".join(entries)
        return f"dict(\n{body},\n{pad})"

    if isinstance(data, list):
        entries = []
        for item in data:
            entries.append(f"{pad} {format_all_data(item, indent_level + 1)}")
        body = ",\n".join(entries)
        return f"[\n{body},\n{pad}]"

    return str(data)
# 所有格式正确输出,加入键值中带有"."的处理
def format_all_data(data, indent_level=0):
    """
    Render a value as config source text, with special handling for dict
    keys that contain a dot.

    Dot-containing keys cannot be written as keyword arguments, so they are
    emitted first as a list of ``('key', value)`` pairs, followed by the
    ordinary ``key=value`` entries.

    NOTE(review): a later definition of ``format_all_data`` in this file
    replaces this one at import time.

    Args:
        data: Value to render.
        indent_level (int): Nesting level used for indentation.

    Returns:
        str: The rendered text.
    """
    pad = ' ' * indent_level

    if isinstance(data, str):
        # Pick the quote character that does not clash with the content.
        return f'"{data}"' if "'" in data else f"'{data}'"

    if isinstance(data, tuple):
        return f"{data}"

    if isinstance(data, list):
        rendered_items = []
        for item in data:
            rendered_items.append(f"{pad} {format_all_data(item, indent_level + 1)}")
        joined = ",\n".join(rendered_items)
        return f"[\n{joined},\n{pad}]"

    if isinstance(data, dict):
        dotted_pairs = []    # keys containing '.'
        keyword_items = []   # ordinary keys
        for name, item in data.items():
            rendered = format_all_data(item, indent_level + 1)
            if '.' in name:
                dotted_pairs.append(f"('{name}', {rendered})")
            else:
                keyword_items.append(f"{pad} {name}={rendered}")
        keyword_block = ",\n".join(keyword_items)
        if not dotted_pairs:
            return f"dict(\n{keyword_block},\n{pad})"
        # Dotted keys go first, as one list of (key, value) pairs.
        dotted_block = f"{pad} [{', '.join(dotted_pairs)}]"
        return f"dict(\n{dotted_block},\n{keyword_block},\n{pad})"

    return str(data)
# # 批量将参数内容写入文件 # V1 传统版
# def write_config_to_file(output_file, **kwargs):
# """
# 将传入的任意数量的参数写入指定文件,并格式化输出。
# :param output_file: 要写入的文件路径
# :param kwargs: 任意数量的配置项,格式为 key=value
# """
# with open(output_file, 'w', encoding='utf-8') as file:
# # 遍历 kwargs将每个 key, value 写入文件
# for key, value in kwargs.items():
# file.write(f"{key} = {format_all_data(value)}\n\n")
# # 打印成功信息
# print(f"\033[93m{output_file} file generated successfully\033[0m")
# 批量将参数内容写入文件 # V2 加入训练过程可视化 TODO
def format_all_data(value):
    """
    Format a Python object as text suitable for writing into a config file.

    The result is ``repr(value)``, which keeps strings, lists, dicts and
    other built-in values in Python literal syntax.
    """
    rendered = repr(value)
    return rendered
def write_config_to_file(output_file, **kwargs):
    """
    Write the given config entries, plus generated visualizer settings, to a file.

    Each entry is written as ``key = <formatted value>`` followed by a blank
    line. ``vis_backends`` and ``visualizer`` entries are always generated
    here and replace any values passed under those names. The WandB run name
    is the output file's base name without its extension.

    :param output_file: Path of the file to write (e.g. 'configs/exp1.py').
    :param kwargs: Config entries as ``key=value`` pairs.
    """
    # 'configs/exp1.py' -> 'exp1'
    experiment_name = os.path.splitext(os.path.basename(output_file))[0]

    wandb_backend = dict(
        type='WandbVisBackend',
        init_kwargs=dict(
            project='Seg_MMSeg_Test',  # WandB project name
            name=experiment_name       # run name derived from the file name
        )
    )
    vis_backends = [
        dict(type='LocalVisBackend'),
        dict(type='TensorboardVisBackend'),
        wandb_backend,
    ]

    # Generated entries override caller-supplied ones with the same key.
    kwargs['vis_backends'] = vis_backends
    kwargs['visualizer'] = dict(
        name='visualizer',
        type='SegLocalVisualizer',
        vis_backends=vis_backends
    )

    with open(output_file, 'w', encoding='utf-8') as file:
        for key in kwargs:
            file.write(f"{key} = {format_all_data(kwargs[key])}\n\n")

    print(f"\033[93mConfiguration saved to '{output_file}' successfully.\033[0m")
    print(f"\033[96mWandB experiment name will be: '{experiment_name}'\033[0m")
# 获取系统中所有 GPU 的可用显存信息。
def get_gpu_info():
    """
    Query ``nvidia-smi`` for the free memory of every GPU.

    :return: List of ``(gpu_index, free_memory)`` tuples, with the memory
        value taken as reported by ``--format=csv,noheader,nounits``. The
        list is empty when ``nvidia-smi`` is not installed or prints nothing
        parseable.
    """
    try:
        result = subprocess.run(
            ['nvidia-smi', '--query-gpu=index,memory.free', '--format=csv,noheader,nounits'],
            stdout=subprocess.PIPE,
            encoding='utf-8'
        )
    except FileNotFoundError:
        print("\033[91mError: nvidia-smi 命令未找到，请确保 NVIDIA 驱动正确安装。\033[0m")
        return []

    gpu_info = []
    for line in (result.stdout or '').splitlines():
        fields = [field.strip() for field in line.split(',')]
        # Skip blank or malformed lines (e.g. empty output when no GPU is
        # present) instead of failing while unpacking.
        if len(fields) != 2:
            continue
        try:
            gpu_info.append((int(fields[0]), int(fields[1])))
        except ValueError:
            # Non-numeric fields are ignored.
            continue
    return gpu_info
# Build a dict from keyword arguments in one call.
def create_dict_by_kwargs(**kwargs):
    """
    Collect the keyword arguments into a dict, leaving out ``None`` values.

    :param kwargs: any number of named arguments
    :return: dict of the arguments whose value is not ``None``, in call order
    """
    collected = {}
    for name, item in kwargs.items():
        if item is None:
            continue
        collected[name] = item
    return collected
if __name__ == '__main__':
    # Demo: print a sample mapping in dict(...) form via format_dict.
    sample_cfg = dict(pretrained='./My_Local_Model/open_mmlab/resnet50_v1c.pth')
    print(format_dict(sample_cfg))

View File

@@ -0,0 +1,236 @@
# Choose the crop_size.
def select_crop_size(predefined_options=((512, 512), (512, 1024), (769, 769))):
    """
    Interactively ask the user for a crop size.

    The predefined options are listed as ``1..N`` and a final entry ``N+1``
    lets the user type a custom size. An empty answer selects option 1. An
    unrecognised answer falls back to option 1 and says so. Custom width and
    height are re-requested until both are positive integers. When
    ``predefined_options`` is empty, the only entry is the custom size, so the
    user is always asked to type one.

    :param predefined_options: sequence of crop sizes offered to the user
    :return: the selected crop size; a ``(width, height)`` tuple for custom input
    """
    # Number the predefined options starting at "1".
    options = {str(number): option for number, option in enumerate(predefined_options, 1)}
    custom_key = str(len(options) + 1)

    # Show the menu.
    print("可用的裁剪大小选项:")
    for key, value in options.items():
        print(f"{key}. {value}", end=" ")
    print(f"{custom_key}. 自定义大小")

    # An empty answer means "option 1".
    choice = input("请选择裁剪大小选项 (默认 1): ").strip() or "1"

    if choice in options:
        return options[choice]

    if choice != custom_key:
        if options:
            # Report the real default rather than a hard-coded (512, 512).
            default_size = options["1"]
            print(f"无效选择,使用默认值 {default_size}")
            return default_size
        # Nothing predefined to fall back to: continue with custom entry.

    # Custom size: repeat until both numbers are positive integers.
    while True:
        try:
            width = int(input("请输入裁剪宽度 (正整数): "))
            height = int(input("请输入裁剪高度 (正整数): "))
        except ValueError:
            print("输入无效,请输入有效的正整数。")
            continue
        if width > 0 and height > 0:
            return (width, height)
        print("宽度和高度必须是正整数。")
# Registry of the known pretrained models.
# Key: model identifier; value: dict with the local checkpoint path ('pth')
# and, for some entries, extra fields such as 'depth', 'type' or 'size'.
pretrained_models_dict = {
    "openmmlab/resnet18_v1c": {'pth': './My_Local_Model/open_mmlab/resnet18_v1c.pth', 'depth': 18, 'type': 'ResNetV1c'},
    "openmmlab/resnet50_v1c": {'pth': './My_Local_Model/open_mmlab/resnet50_v1c.pth', 'depth': 50, 'type': 'ResNetV1c'},
    "openmmlab/resnet101_v1c": {'pth': './My_Local_Model/open_mmlab/resnet101_v1c.pth', 'depth': 101, 'type': 'ResNetV1c'},
    "torchvision://resnet18": {'pth': './My_Local_Model/torchvision_012/resnet18.pth', 'depth': 18, 'type': 'ResNet'},
    "torchvision://resnet50": {'pth': './My_Local_Model/torchvision_012/resnet50.pth', 'depth': 50, 'type': 'ResNet'},
    "torchvision://resnet101": {'pth': './My_Local_Model/torchvision_012/resnet101.pth', 'depth': 101, 'type': 'ResNet'},
    "openmmlab/pidnet-s": {'pth': './My_Local_Model/open_mmlab/pidnet-s.pth', 'type': 'pidnet', 'size': 'small'},
    "openmmlab/pidnet-m": {'pth': './My_Local_Model/open_mmlab/pidnet-m.pth', 'type': 'pidnet', 'size': 'medium'},
    "openmmlab/pidnet-l": {'pth': './My_Local_Model/open_mmlab/pidnet-l.pth', 'type': 'pidnet', 'size': 'large'},
    "openmmlab/ddrnet23-s": {'pth': './My_Local_Model/open_mmlab/ddrnet23-s.pth', 'type': 'ddrnet', 'size': 'small'},
    "openmmlab/ddrnet23": {'pth': './My_Local_Model/open_mmlab/ddrnet23.pth', 'type': 'ddrnet', 'size': 'normal'},
    "openmmlab/stdc1": {'pth': './My_Local_Model/open_mmlab/stdc1.pth', 'type': 'stdc', 'size': 'V1'},
    "openmmlab/stdc2": {'pth': './My_Local_Model/open_mmlab/stdc2.pth', 'type': 'stdc', 'size': 'V2'},
    "pretrain/vit-b16_p16_224-80ecf9dd.pth": {'pth': './My_Local_Model/pretrain/vit-b16_p16_224-80ecf9dd.pth'},
    "pretrain/beit_base_patch16_224_pt22k_ft22k.pth": {'pth': './My_Local_Model/pretrain/beit_base_patch16_224_pt22k_ft22k.pth'},
    "pretrain/beit_large_patch16_224_pt22k_ft22k.pth": {'pth': './My_Local_Model/pretrain/beit_large_patch16_224_pt22k_ft22k.pth'},
    "pretrain/swin_large-d5bdebaf.pth": {'pth': './My_Local_Model/pretrain/swin_large_patch4_window7_224_22k_20220308-d5bdebaf.pth', 'type': 'swin', 'size': 'large'},
    "pretrain/swin_tiny-f41b89d3.pth": {'pth': './My_Local_Model/pretrain/swin_tiny_patch4_window7_224_20220308-f41b89d3.pth', 'type': 'swin', 'size': 'tiny'},
    "mae_pretrain_vit_base_mmcls.pth": {'pth': './My_Local_Model/pretrain/mae_pretrain_vit_base_mmcls.pth'},
    "open-mmlab://msra/hrnetv2_w18": {'pth': './My_Local_Model/open_mmlab/msra/hrnetv2_w18.pth'},
    "open-mmlab://msra/hrnetv2_w18_small": {'pth': './My_Local_Model/open_mmlab/msra/hrnetv2_w18_small.pth'},
    'open-mmlab://msra/hrnetv2_w48': {'pth': './My_Local_Model/open_mmlab/msra/hrnetv2_w48.pth'},
    "pretrain/swin_large-6580f57d.pth": {'pth': './My_Local_Model/pretrain/swin_large_patch4_window12_384_22k_20220412-6580f57d.pth', 'type': 'swin', 'size': 'large'},  # mask2former
    "pretrain/swin_base-e5c09f74.pth": {'pth': './My_Local_Model/pretrain/swin_base_patch4_window12_384_22k_20220317-e5c09f74.pth', 'type': 'swin', 'size': 'base'},  # mask2former
    "pretrain/swin_small-7ba6d6dd.pth": {'pth': './My_Local_Model/pretrain/swin_small_patch4_window7_224_20220317-7ba6d6dd.pth', 'type': 'swin', 'size': 'small'},  # mask2former
    "pretrain/swin_tiny-1cdeb081.pth": {'pth': './My_Local_Model/pretrain/swin_tiny_patch4_window7_224_20220317-1cdeb081.pth', 'type': 'swin', 'size': 'tiny'},  # mask2former
    # More models can be added here...
}
# Choose a pretrained model.
def select_pretrained_model(model_list, need_select_pretrained = False, pretrained_models_dict = pretrained_models_dict):
    """
    Let the user pick a pretrained model from ``model_list`` and, optionally,
    whether pretraining should be used at all.

    Names in ``model_list`` that are not keys of ``pretrained_models_dict`` are
    ignored. An empty or out-of-range model number selects the first model.

    :param model_list: candidate model names,
        e.g. ['openmmlab/resnet50_v1c', 'openmmlab/resnet101_v1c']
    :param need_select_pretrained: when False, pretraining is switched on
        without asking; otherwise the user is asked Y/N (Enter means yes)
    :param pretrained_models_dict: registry mapping model names to info dicts
        that contain at least a 'pth' entry
    :return: ``(selected_model_name, select_pretrained, pretrained_pth,
        selected_model_info)``; ``(None, None, None, None)`` when no name in
        ``model_list`` is present in the registry
    """
    # Keep only the requested models the registry knows about.
    valid_models = {key: pretrained_models_dict[key] for key in model_list if key in pretrained_models_dict}
    if not valid_models:
        print("错误:提供的模型列表中没有有效的模型信息。")
        # Same arity as the success path so callers can always unpack four values.
        return None, None, None, None

    # List the available models.
    print("可用的预训练模型类型:")
    for i, model_name in enumerate(valid_models, 1):
        print(f"{i}. {model_name}")

    choice = input(f"请选择预训练模型编号 (1-{len(valid_models)}, 默认 1): ").strip()
    # Empty or invalid input selects the first model. isdecimal() is used
    # because it only accepts characters that int() can parse.
    if not choice.isdecimal() or not (1 <= int(choice) <= len(valid_models)):
        choice = "1"

    selected_model_name = list(valid_models.keys())[int(choice) - 1]
    selected_model_info = valid_models[selected_model_name]
    # TODO: return selected extra fields (e.g. depth) once callers need them.
    pretrained_pth = selected_model_info['pth']
    print(f" 已选择模型: {selected_model_name} pth: {pretrained_pth}")

    if need_select_pretrained == False:
        print("默认开启预训练")
        select_pretrained = True
        return selected_model_name, select_pretrained, pretrained_pth, selected_model_info

    # Ask inside the loop so that an invalid answer prompts again instead of
    # repeating the error message forever.
    while True:
        answer = input("是否使用预训练模型?输入 Y使用或 N不使用直接按 Enter 默认使用True").strip().lower()
        if answer in ('', 'y'):
            select_pretrained = True
            break
        if answer == 'n':
            select_pretrained = False
            break
        print("无效输入,请输入 'Y' 或 'N',或直接按 Enter 选择默认值")
    return selected_model_name, select_pretrained, pretrained_pth, selected_model_info
# Registry of the known pixel samplers: sampler name -> sampler config dict.
samplers_dict = {
    "OHEMPixelSampler": dict(type='OHEMPixelSampler', thresh=0.7, min_kept=10000),
    # More samplers can be added here...
}
# Choose a sampler.
def select_sampler(sampler_list, samplers_dict = samplers_dict):
    """
    Ask the user whether to use a sampler and, if so, which one.

    The user is first asked Y/N (Enter means no). On "yes", the names in
    ``sampler_list`` that exist in ``samplers_dict`` are offered. When exactly
    one is available it is selected without asking. An empty or out-of-range
    number selects the first one.

    :param sampler_list: candidate sampler names
    :param samplers_dict: registry mapping sampler names to their config dicts
    :return: ``(selected_sampler_name, use_sampler, selected_sampler_info)``;
        ``(None, False, None)`` when the user declines or when no name in
        ``sampler_list`` is present in the registry
    """
    # Ask inside the loop so that an invalid answer prompts again instead of
    # repeating the error message forever.
    while True:
        answer = input("是否使用采样函数?输入 Y使用或 N不使用直接按 Enter 默认不使用False").strip().lower()
        if answer in ('', 'n'):
            return None, False, None
        if answer == 'y':
            break
        print("无效输入,请输入 'Y' 或 'N',或直接按 Enter 选择默认值")

    # Keep only the requested samplers the registry knows about.
    valid_samplers = {key: samplers_dict[key] for key in sampler_list if key in samplers_dict}
    if not valid_samplers:
        print("错误:提供的采样函数列表中没有有效的采样函数信息。")
        # Same arity as every other return so callers can always unpack three
        # values; no sampler can be used in this case.
        return None, False, None

    if len(valid_samplers) != 1:
        # Several candidates: let the user choose.
        print("可用的采样函数:")
        for i, sampler_name in enumerate(valid_samplers, 1):
            print(f"{i}. {sampler_name}")
        choice = input(f"请选择采样函数编号 (1-{len(valid_samplers)}, 默认 1): ").strip()
        # Empty or invalid input selects the first sampler. isdecimal() is used
        # because it only accepts characters that int() can parse.
        if not choice.isdecimal() or not (1 <= int(choice) <= len(valid_samplers)):
            choice = "1"
    else:
        # A single candidate is selected directly.
        choice = "1"

    selected_sampler_name = list(valid_samplers.keys())[int(choice) - 1]
    selected_sampler_info = valid_samplers[selected_sampler_name]
    print(f" 已选择采样函数: {selected_sampler_name}")
    return selected_sampler_name, True, selected_sampler_info
# Choose whether test_cfg uses sliding-window inference; select_slide forces it on.
def select_test_cfg_slide(crop_size, select_slide=False):
    """
    Decide whether sliding-window mode is used and with which crop/stride ratio.

    When ``select_slide`` is False the user is asked (default: no); otherwise
    sliding-window mode is enabled without asking. In sliding-window mode the
    user enters the ratio ``crop_size / stride``. An empty, non-numeric,
    non-positive or non-finite ratio falls back to 1.5. The resulting stride is
    only printed for information, not returned.

    :param crop_size: crop size as an iterable of numbers, e.g. (width, height)
    :param select_slide: when not False, skip the question and use slide mode
    :return: ``(test_cfg_mode, test_cfg_crop_div_stride)``, which is
        ``('slide', ratio)`` in sliding-window mode and ``(None, None)`` otherwise
    """
    if select_slide == False:
        # Ask the user; anything other than "y" means no.
        use_slide = input("是否选择滑动窗口模式?(y/n, 默认 n): ").strip().lower()
    else:
        use_slide = 'y'  # forced sliding-window mode

    if use_slide != 'y':
        print(f" 关闭滑动窗口模式")
        return None, None

    test_cfg_mode = 'slide'
    default_ratio = 1.5
    try:
        raw_ratio = input("请输入滑动窗口的 crop_size 除以 stride 比例 (默认 1.5): ").strip()
        test_cfg_crop_div_stride = float(raw_ratio) if raw_ratio else default_ratio
    except ValueError:
        test_cfg_crop_div_stride = None

    # Reject values that cannot give a usable stride: zero would divide by
    # zero, negative and infinite ratios give nonsense strides, and NaN fails
    # both comparisons.
    if test_cfg_crop_div_stride is None or not (0 < test_cfg_crop_div_stride < float('inf')):
        print("输入无效,使用默认比例 1.5")
        test_cfg_crop_div_stride = default_ratio

    # Stride per dimension, shown for information only.
    stride = tuple(int(c / test_cfg_crop_div_stride) for c in crop_size)
    print(f" 已选择滑动窗口模式: {test_cfg_mode}")
    print(f"crop_size: {crop_size}")
    print(f"stride: {stride}")
    return test_cfg_mode, test_cfg_crop_div_stride

View File

@@ -0,0 +1,46 @@
from PIL import Image
import os
import numpy as np
from tqdm import tqdm
def calculate_pic_std_and_mean(dataset_dir = r'./My_Data/A_Ori'):
    """
    Compute the per-channel mean and standard deviation of all images in a folder.

    Files directly inside ``dataset_dir`` whose name ends with .jpg, .png,
    .tiff, .jpeg or .bmp (case-insensitive) are opened, converted to RGB and
    accumulated pixel by pixel, so images of different sizes are weighted by
    their pixel count. The statistics are computed on values scaled to [0, 1]
    and then scaled back to the [0, 255] range.

    :param dataset_dir: directory containing the images
    :return: ``(mean, std)``, two arrays of length 3 (one value per RGB channel)
    :raises ValueError: if the directory contains no matching image file
    """
    image_suffixes = ('.jpg', '.png', '.tiff', '.jpeg', '.bmp')
    image_files = [
        os.path.join(dataset_dir, filename)
        for filename in os.listdir(dataset_dir)
        if filename.lower().endswith(image_suffixes)
    ]
    # With no images the statistics would be 0/0 = NaN; fail loudly instead
    # of returning unusable values.
    if not image_files:
        raise ValueError(f"No image files ({', '.join(image_suffixes)}) found in '{dataset_dir}'")

    # Running per-channel sums of the normalised values and of their squares.
    sum_pixels_normalized = np.zeros(3)
    sum_squared_pixels_normalized = np.zeros(3)
    num_pixels = 0

    for image_file in tqdm(image_files, desc="Calculating mean and std"):
        # The context manager closes the file handle as soon as the pixels are read.
        with Image.open(image_file) as handle:
            image = np.array(handle.convert('RGB'))  # force RGB, values in [0, 255]
        image_normalized = image / 255.0  # scale to [0, 1]
        sum_pixels_normalized += np.sum(image_normalized, axis=(0, 1))
        sum_squared_pixels_normalized += np.sum(image_normalized ** 2, axis=(0, 1))
        num_pixels += image.shape[0] * image.shape[1]

    # Mean and standard deviation over the whole data set: Var = E[x^2] - E[x]^2.
    mean_normalized = sum_pixels_normalized / num_pixels
    variance_normalized = sum_squared_pixels_normalized / num_pixels - mean_normalized ** 2
    variance_normalized = np.maximum(variance_normalized, 0)  # guard against tiny negative round-off
    std_normalized = np.sqrt(variance_normalized)

    # Scale back to the [0, 255] range.
    mean = mean_normalized * 255
    std = std_normalized * 255
    print(f"\033[93m计算得图片均值-Mean: {mean} 计算得图片标准差-Std: {std}\033[0m")
    return mean, std
if __name__ == '__main__':
    # Script entry point: run the statistics on the default dataset directory.
    calculate_pic_std_and_mean()

Some files were not shown because too many files have changed in this diff Show More