first commit
This commit is contained in:
117
Seg_All_In_One_YoloModel/Tool_Yolo_Copy_Best_Model.sh
Normal file
117
Seg_All_In_One_YoloModel/Tool_Yolo_Copy_Best_Model.sh
Normal file
@@ -0,0 +1,117 @@
|
||||
#!/bin/bash
|
||||
|
||||
# +----------------------------------------------------------------------------+
|
||||
# | 脚本: 将指定的 .pt 文件 (默认为 best.pt) 从 Nas_BackUp_Seg 复制到 BestMode... |
|
||||
# | 功能: 1. 自动查找源目录中所有 "weights/[指定文件名].pt" 文件。 |
|
||||
# | 2. 将其复制到目标目录,并从路径中移除 "/weights" 中间层。
|
||||
# | 3. 仅在目标文件不存在或内容不一致时才执行复制,提高效率。
|
||||
# | 4. 提供可视化进度条和最终统计。
|
||||
# | V3.0: «-- 修改: 改为使用 --pt_name 命名参数 |
|
||||
# | 用法: ./script.sh [--pt_name "文件名.pt"] (例如: ./script.sh --pt_name "100.pt") «-- 修改
|
||||
# +----------------------------------------------------------------------------+
|
||||
|
||||
# --- 1. 目录设置 ---
|
||||
SOURCE_BASE_DIR=$(realpath "../Hardisk")
|
||||
DEST_BASE_DIR=$(realpath "../BestMode_Predict_Results_DataSet_Public")
|
||||
|
||||
# --- 1b. 参数设置 --- «-- 修改: 替换为参数解析循环
|
||||
# 设置默认文件名
|
||||
TARGET_FILENAME="best.pt"
|
||||
|
||||
# 解析命令行参数
|
||||
while [[ "$#" -gt 0 ]]; do
|
||||
case $1 in
|
||||
--pt_name)
|
||||
if [[ -n "$2" && ! "$2" =~ ^-- ]]; then
|
||||
TARGET_FILENAME="$2"
|
||||
shift # 消耗 "--pt_name"
|
||||
shift # 消耗它的值
|
||||
else
|
||||
echo "错误: '--pt_name' 需要一个文件名参数。" >&2
|
||||
exit 1
|
||||
fi
|
||||
;;
|
||||
*)
|
||||
echo "未知参数: $1" >&2
|
||||
exit 1
|
||||
;;
|
||||
esac
|
||||
done
|
||||
|
||||
# --- 2. 环境检查 ---
|
||||
if [ ! -d "$SOURCE_BASE_DIR" ]; then
|
||||
echo "错误: 源目录 '$SOURCE_BASE_DIR' 不存在。"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
mkdir -p "$DEST_BASE_DIR"
|
||||
|
||||
echo "正在扫描源目录: $SOURCE_BASE_DIR"
|
||||
echo "目标模型文件: $TARGET_FILENAME" # 告知用户当前在找哪个文件
|
||||
|
||||
# --- 3. 查找所有符合条件的文件 ---
|
||||
echo "正在查找所有符合 '...-Yolo/YOLO*_*/weights/$TARGET_FILENAME' 结构的文件..."
|
||||
readarray -d '' files_to_process < <(find "$SOURCE_BASE_DIR"/*-Yolo -path "*/YOLO*_*/weights/$TARGET_FILENAME" -type f -print0 2>/dev/null)
|
||||
TOTAL_FILES=${#files_to_process[@]}
|
||||
|
||||
if [ "$TOTAL_FILES" -eq 0 ]; then
|
||||
echo "在源目录中没有找到任何 '$TARGET_FILENAME' 文件。脚本退出。"
|
||||
exit 0
|
||||
fi
|
||||
|
||||
echo "总共找到 $TOTAL_FILES 个 '$TARGET_FILENAME' 文件需要处理。"
|
||||
echo "--------------------------------------------------"
|
||||
|
||||
# --- 4. 进度条函数 ---
|
||||
print_progress() {
|
||||
local current=$1
|
||||
local total=$2
|
||||
local term_width=$(tput cols)
|
||||
local bar_width=$((term_width - 35)) # 为文本留出更多空间
|
||||
|
||||
if [ "$bar_width" -lt 10 ]; then
|
||||
bar_width=10
|
||||
fi
|
||||
local percent=$((current * 100 / total))
|
||||
local filled_len=$((bar_width * percent / 100))
|
||||
local bar=""
|
||||
for ((i=0; i<filled_len; i++)); do bar+="="; done
|
||||
for ((i=filled_len; i<bar_width; i++)); do bar+=" "; done
|
||||
|
||||
printf "\r拷贝进度: %3d%% [%s] %d/%d" $percent "$bar" $current $total
|
||||
}
|
||||
|
||||
# --- 5. 执行复制操作 ---
|
||||
CURRENT_FILE=0
|
||||
COPIED_COUNT=0
|
||||
SKIPPED_COUNT=0
|
||||
|
||||
for source_file_path in "${files_to_process[@]}"; do
|
||||
if [[ -z "$source_file_path" ]]; then
|
||||
continue
|
||||
fi
|
||||
|
||||
((CURRENT_FILE++))
|
||||
|
||||
relative_path="${source_file_path#"$SOURCE_BASE_DIR"}"
|
||||
dest_file_path="${DEST_BASE_DIR}${relative_path}"
|
||||
dest_dir_path=$(dirname "$dest_file_path")
|
||||
|
||||
if [ ! -f "$dest_file_path" ] || ! cmp -s "$source_file_path" "$dest_file_path"; then
|
||||
mkdir -p "$dest_dir_path"
|
||||
cp "$source_file_path" "$dest_file_path"
|
||||
((COPIED_COUNT++))
|
||||
else
|
||||
((SKIPPED_COUNT++))
|
||||
fi
|
||||
|
||||
print_progress $CURRENT_FILE $TOTAL_FILES
|
||||
done
|
||||
|
||||
# --- 6. 打印最终总结 ---
|
||||
echo ""
|
||||
echo "--------------------------------------------------"
|
||||
echo "所有操作已完成!"
|
||||
echo " - 总共处理文件: $TOTAL_FILES (目标文件: $TARGET_FILENAME)"
|
||||
echo " - 已复制或更新的文件: $COPIED_COUNT"
|
||||
echo " - 因已存在且内容一致而跳过的文件: $SKIPPED_COUNT"
|
||||
2
Seg_All_In_One_YoloModel/Yolo可视化测试/Yolo可视化测试-使用手册
Normal file
2
Seg_All_In_One_YoloModel/Yolo可视化测试/Yolo可视化测试-使用手册
Normal file
@@ -0,0 +1,2 @@
|
||||
# 本手册主要目的在于判断给Yolo模型的各个层是否能成功生成热图
|
||||
使用方法: python yolo_layer_tester.py --model "YOLOv8m-seg" --cam_method "GradCAM" --pt_name "best.pt"
|
||||
359
Seg_All_In_One_YoloModel/Yolo可视化测试/yolo_layer_tester.py
Normal file
359
Seg_All_In_One_YoloModel/Yolo可视化测试/yolo_layer_tester.py
Normal file
@@ -0,0 +1,359 @@
|
||||
import logging
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
from ultralytics import YOLO
|
||||
import random
|
||||
import sys
|
||||
import os
|
||||
from typing import Dict, Any, Tuple, List # [!!] 确保导入 List
|
||||
|
||||
# --- [!! 关键修改 !!] 动态添加项目根目录到 sys.path ---
|
||||
# (来自您之前的目录结构)
|
||||
script_path = Path(__file__).resolve()
|
||||
project_root = script_path.parent.parent
|
||||
sys.path.insert(0, str(project_root))
|
||||
# ----------------------------------------------------
|
||||
|
||||
# 现在下面这行导入就可以正常工作了
|
||||
import yolo_config as config
|
||||
|
||||
# --- [导入 V3 的 CAM 库] ---
|
||||
from pytorch_grad_cam import (
|
||||
GradCAM, GradCAMPlusPlus, XGradCAM, EigenCAM,
|
||||
HiResCAM, LayerCAM, RandomCAM, EigenGradCAM
|
||||
)
|
||||
from pytorch_grad_cam.utils.image import show_cam_on_image
|
||||
from pytorch_grad_cam.base_cam import BaseCAM
|
||||
|
||||
# --- 日志设置 ---
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s [%(levelname)s] %(message)s",
|
||||
handlers=[logging.StreamHandler()]
|
||||
)
|
||||
logging.getLogger("pytorch_grad_cam").setLevel(logging.WARNING)
|
||||
|
||||
# --- [!! 新增 !!] 从 V1 移植的函数 ---
|
||||
def find_trained_models(outputs_dir: Path, model_key: str, pt_name: str) -> List[str]:
|
||||
"""
|
||||
扫描输出目录,查找特定基础模型的所有有效的、已完成的训练项目。
|
||||
(函数来自 yolo_predict_visualize_nn_V1.py)
|
||||
"""
|
||||
trained_models = []
|
||||
if not outputs_dir.is_dir():
|
||||
logging.warning(f"输出目录不存在: {outputs_dir}")
|
||||
return []
|
||||
|
||||
for project_folder in outputs_dir.iterdir():
|
||||
if project_folder.is_dir() and project_folder.name.startswith(model_key + '_'):
|
||||
if (project_folder / 'weights' / pt_name).exists():
|
||||
trained_models.append(project_folder.name)
|
||||
|
||||
trained_models.sort()
|
||||
return trained_models
|
||||
# -------------------------------------
|
||||
|
||||
# --- [保留] V3 的 CAM 方法字典 ---
|
||||
CAM_METHODS = {
|
||||
"GradCAM": GradCAM,
|
||||
"GradCAMPlusPlus": GradCAMPlusPlus,
|
||||
"XGradCAM": XGradCAM,
|
||||
"EigenCAM": EigenCAM,
|
||||
"HiResCAM": HiResCAM,
|
||||
"LayerCAM": LayerCAM,
|
||||
"RandomCAM": RandomCAM,
|
||||
"EigenGradCAM": EigenGradCAM,
|
||||
}
|
||||
|
||||
# --- [保留] V3 的 Target ---
|
||||
class ActivationMaximizationTarget:
|
||||
def __init__(self, channel=0):
|
||||
self.channel = channel
|
||||
def __call__(self, model_output):
|
||||
if model_output.ndim == 4:
|
||||
return model_output[:, self.channel, :, :].mean()
|
||||
elif model_output.ndim == 3:
|
||||
return model_output[self.channel, :, :].mean()
|
||||
else:
|
||||
raise ValueError(f"Unsupported model_output shape: {model_output.shape}")
|
||||
|
||||
# --- [保留] V3 的预处理 ---
|
||||
def preprocess_image(bgr_img: np.ndarray, imgsz: int, device: str) -> (torch.Tensor, Dict[str, Any]):
|
||||
# ... (此函数内容与上一版完全相同,为节省篇幅已折叠) ...
|
||||
orig_h, orig_w = bgr_img.shape[:2]
|
||||
rgb_img = bgr_img[:, :, ::-1]
|
||||
scale = min(imgsz / orig_w, imgsz / orig_h)
|
||||
new_w, new_h = int(orig_w * scale), int(orig_h * scale)
|
||||
resized_image = cv2.resize(rgb_img, (new_w, new_h), interpolation=cv2.INTER_LINEAR)
|
||||
pad_w, pad_h = imgsz - new_w, imgsz - new_h
|
||||
pad_left, pad_right = pad_w // 2, pad_w - pad_w // 2
|
||||
pad_top, pad_bottom = pad_h // 2, pad_h - pad_h // 2
|
||||
padded_image = cv2.copyMakeBorder(resized_image, pad_top, pad_bottom, pad_left, pad_right,
|
||||
cv2.BORDER_CONSTANT, value=(114, 114, 114))
|
||||
img_float = np.float32(padded_image) / 255.0
|
||||
input_tensor = torch.from_numpy(img_float).permute(2, 0, 1).unsqueeze(0).to(device)
|
||||
padding_info = {
|
||||
"pad_top": pad_top, "new_h": new_h,
|
||||
"pad_left": pad_left, "new_w": new_w,
|
||||
"orig_h": orig_h, "orig_w": orig_w
|
||||
}
|
||||
return input_tensor, padding_info
|
||||
|
||||
|
||||
# --- [保留] V1/V3 的关键修复 ---
|
||||
def reshape_transform(outputs):
|
||||
if isinstance(outputs, tuple):
|
||||
return outputs[0]
|
||||
return outputs
|
||||
|
||||
# --- [!! 核心功能 !!] (此函数本身无需修改) ---
|
||||
def generate_heatmaps_for_all_layers(model_path: str, image_path: str, cam_method_name: str, model_name: str):
|
||||
"""
|
||||
加载模型和图像,遍历所有层,并为每个成功的层生成热图。
|
||||
(此函数内容与上一版完全相同,为节省篇幅已折叠)
|
||||
"""
|
||||
|
||||
# --- 1. 获取 CAM 方法 ---
|
||||
cam_class = CAM_METHODS.get(cam_method_name)
|
||||
if not cam_class:
|
||||
logging.error(f"无效的 CAM 方法: {cam_method_name}")
|
||||
return
|
||||
# --- 2. 定义 Targets ---
|
||||
targets = None
|
||||
if cam_method_name != "EigenCAM":
|
||||
targets = [ActivationMaximizationTarget(channel=0)]
|
||||
logging.info(f"已为 {cam_method_name} 设置 ActivationMaximizationTarget(channel=0)。")
|
||||
else:
|
||||
logging.info(f"将使用 {cam_method_name} (无需 Targets)。")
|
||||
# --- 3. 加载模型 ---
|
||||
try:
|
||||
logging.info(f"正在加载模型: {model_path}")
|
||||
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||
model = YOLO(model_path).to(device)
|
||||
model.model.eval()
|
||||
imgsz = model.model.args['imgsz']
|
||||
logging.info(f"模型加载成功。设备: {device}, 图像尺寸: {imgsz}")
|
||||
except Exception as e:
|
||||
logging.error(f"无法加载模型 {model_path}: {e}", exc_info=True)
|
||||
return
|
||||
# --- 4. 加载和预处理图像 ---
|
||||
try:
|
||||
logging.info(f"正在加载测试图像: {image_path}")
|
||||
orig_img_bgr = cv2.imread(image_path)
|
||||
if orig_img_bgr is None:
|
||||
logging.error(f"无法读取图像: {image_path}")
|
||||
return
|
||||
orig_img_rgb_float = np.float32(orig_img_bgr[:, :, ::-1]) / 255.0
|
||||
input_tensor, pad_info = preprocess_image(orig_img_bgr, imgsz, device)
|
||||
input_tensor.requires_grad_()
|
||||
logging.info(f"图像加载并预处理完成。Tensor shape: {input_tensor.shape}")
|
||||
except Exception as e:
|
||||
logging.error(f"处理图像时出错: {e}", exc_info=True)
|
||||
return
|
||||
# --- 5. 创建输出目录 (您的自定义逻辑) ---
|
||||
image_stem = Path(image_path).stem
|
||||
output_dir_name = f"./{cam_method_name}_{model_name}_{image_stem}_All_HeartMap"
|
||||
output_dir = Path(output_dir_name)
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
logging.info(f"所有热图将保存到: {output_dir.resolve()}")
|
||||
|
||||
# --- 6. 遍历并测试所有层 ---
|
||||
base_model_layers = model.model.model
|
||||
successful_layers, failed_layers = [], []
|
||||
logging.info(f"\n" + "="*50)
|
||||
logging.info(f"开始遍历 {len(base_model_layers)} 个层 (使用 {cam_method_name})")
|
||||
logging.info("\n" + "="*50 + "\n")
|
||||
|
||||
for i, layer in enumerate(base_model_layers):
|
||||
layer_name = f"model.model.model[{i}]"
|
||||
layer_type = layer.__class__.__name__
|
||||
print(f"--- [ {i:02d} / {len(base_model_layers)-1} ] 尝试层: {layer_name:<25} (类型: {layer_type}) ---")
|
||||
try:
|
||||
with cam_class(model=model.model, target_layers=[layer], reshape_transform=reshape_transform) as cam:
|
||||
grayscale_cam = cam(input_tensor=input_tensor, targets=targets)
|
||||
if grayscale_cam is None: raise ValueError("CAM-Method 返回了 None")
|
||||
|
||||
# ... (结果提取、裁剪、保存逻辑) ...
|
||||
if isinstance(grayscale_cam, (list, tuple)): result = grayscale_cam[0]
|
||||
else: result = grayscale_cam
|
||||
if isinstance(result, torch.Tensor): result = result.detach().cpu().numpy()
|
||||
if result.ndim == 4: result = result[0].mean(axis=0)
|
||||
elif result.ndim == 3: result = result[0, :]
|
||||
elif result.ndim != 2: raise TypeError(f"CAM 输出维度异常 {result.shape}")
|
||||
|
||||
cam_cropped = result[pad_info['pad_top'] : pad_info['pad_top'] + pad_info['new_h'],
|
||||
pad_info['pad_left'] : pad_info['pad_left'] + pad_info['new_w']]
|
||||
cam_upscaled = cv2.resize(cam_cropped, (pad_info['orig_w'], pad_info['orig_h']), interpolation=cv2.INTER_LINEAR)
|
||||
cam_normalized = (cam_upscaled - cam_upscaled.min()) / (cam_upscaled.max() + 1e-8)
|
||||
|
||||
visualization = show_cam_on_image(orig_img_rgb_float, cam_normalized, use_rgb=True, image_weight=0.5)
|
||||
viz_bgr = visualization[:, :, ::-1]
|
||||
heatmap_uint8 = np.uint8(255 * cam_normalized)
|
||||
heatmap_color = cv2.applyColorMap(heatmap_uint8, cv2.COLORMAP_JET)
|
||||
|
||||
base_filename = f"Layer_{i:02d}_{layer_type}"
|
||||
cv2.imwrite(str(output_dir / f"{base_filename}_overlay.jpg"), viz_bgr)
|
||||
# cv2.imwrite(str(output_dir / f"{base_filename}_heatmap.jpg"), heatmap_color)
|
||||
|
||||
print(f" [√] 成功! 已保存 {base_filename}_overlay.jpg 和 _heatmap.jpg")
|
||||
successful_layers.append((i, layer_name, layer_type))
|
||||
|
||||
except Exception as e:
|
||||
error_msg = str(e).split('\n')[0]
|
||||
print(f" [X] 失败. 错误: {error_msg}")
|
||||
failed_layers.append((i, layer_name, layer_type, error_msg))
|
||||
|
||||
# --- 7. 打印总结报告 ---
|
||||
# ... (报告逻辑与上一版相同) ...
|
||||
print("\n" + "="*60)
|
||||
print(" 全层热图生成报告")
|
||||
print("="*60)
|
||||
print(f"模型: {Path(model_path).name} (来自: {Path(model_path).parent.parent})")
|
||||
print(f"测试图像: {Path(image_path).name}")
|
||||
print(f"测试方法: {cam_method_name}")
|
||||
print(f"输出目录: {output_dir.resolve()}")
|
||||
print(f"总计: {len(successful_layers)} 个成功 (已生成), {len(failed_layers)} 个失败/跳过")
|
||||
print("\n" + "--- [√] 成功生成的层 ---")
|
||||
if not successful_layers: print(" (无)")
|
||||
else:
|
||||
for i, name, type in successful_layers: print(f" - [{i:02d}] {name:<25} (类型: {type})")
|
||||
print("\n" + "--- [X] 失败或跳过的层 ---")
|
||||
if not failed_layers: print(" (无)")
|
||||
else:
|
||||
for i, name, type, error in failed_layers:
|
||||
print(f" - [{i:02d}] {name:<25} (类型: {type})")
|
||||
print(f" └> 失败原因: {error[:100]}...")
|
||||
print("\n" + "="*60)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
# --- [!! 重构 !!] (以下逻辑来自 yolo_predict_visualize_nn_V1.py) ---
|
||||
|
||||
config.show_config_summary()
|
||||
|
||||
parser = argparse.ArgumentParser(description="全层热图生成器 (V1 模型选择模式)")
|
||||
|
||||
# 1. (来自 V1) --model 参数
|
||||
parser.add_argument(
|
||||
"--model",
|
||||
type=str,
|
||||
required=True,
|
||||
choices=list(config.MODEL_CONFIGS.keys()),
|
||||
help="选择一个基础模型类型来筛选其训练历史。"
|
||||
)
|
||||
# 2. (来自 V1) --pt_name 参数
|
||||
parser.add_argument(
|
||||
"--pt_name",
|
||||
type=str,
|
||||
default="best.pt",
|
||||
help="要使用的权重文件名 (例如 'best.pt' 或 'epoch100.pt')。"
|
||||
)
|
||||
# 3. (来自您) --source 参数
|
||||
parser.add_argument(
|
||||
"--source",
|
||||
type=str,
|
||||
default="random",
|
||||
help="测试图像源: 'random' (从 config.TEST_IMAGE_DIR 随机选), "
|
||||
"'path/to/image.jpg' (指定文件), "
|
||||
"或 'path/to/dir/' (从该目录随机选)"
|
||||
)
|
||||
# 4. (来自您) --cam_method 参数
|
||||
parser.add_argument(
|
||||
"--cam_method",
|
||||
type=str,
|
||||
default="GradCAM",
|
||||
choices=list(CAM_METHODS.keys()),
|
||||
help=f"选择要使用的 CAM 可视化方法。默认为: GradCAM。"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
# --- (来自 V1) 模型搜索和选择逻辑 ---
|
||||
available_runs = find_trained_models(config.PREDICT_BEST_MODEL_DIR, args.model, args.pt_name)
|
||||
|
||||
run_to_use = None
|
||||
|
||||
if not available_runs:
|
||||
logging.error(f"错误:在 {config.PREDICT_BEST_MODEL_DIR} 中未找到模型 '{args.model}' (使用 '{args.pt_name}') 的有效训练记录。")
|
||||
sys.exit(1)
|
||||
|
||||
elif len(available_runs) == 1:
|
||||
run_to_use = available_runs[0]
|
||||
logging.info(f"只找到一个训练版本,已自动选择: {run_to_use}")
|
||||
|
||||
else:
|
||||
print(f"\n为模型 '{args.model}' (使用 '{args.pt_name}') 找到多个训练版本:")
|
||||
for i, run_name in enumerate(available_runs, 1):
|
||||
print(f" [{i}] {run_name}")
|
||||
|
||||
while True:
|
||||
try:
|
||||
choice = input(f"请输入您想使用的版本序号 (1-{len(available_runs)}): ")
|
||||
choice_index = int(choice)
|
||||
if 1 <= choice_index <= len(available_runs):
|
||||
run_to_use = available_runs[choice_index - 1]
|
||||
break
|
||||
else:
|
||||
print("错误:输入无效,请输入列表中的序号。")
|
||||
except ValueError:
|
||||
print("错误:请输入一个数字。")
|
||||
except (KeyboardInterrupt, EOFError):
|
||||
print("\n操作已取消。")
|
||||
sys.exit(0)
|
||||
|
||||
# --- (!! 融合 !!) 执行逻辑 ---
|
||||
if run_to_use:
|
||||
# 1. (来自 V1) 确定最终模型路径
|
||||
model_to_use_path = config.PREDICT_BEST_MODEL_DIR / run_to_use / 'weights' / args.pt_name
|
||||
|
||||
if not model_to_use_path.exists():
|
||||
logging.error(f"严重错误:找不到权重文件 {model_to_use_path}。")
|
||||
sys.exit(1)
|
||||
|
||||
# 2. (来自您) 确定最终图像路径
|
||||
source_input = args.source
|
||||
image_to_test = None
|
||||
source_dir_for_random = None
|
||||
|
||||
if source_input.lower() == "random":
|
||||
source_dir_for_random = config.TEST_IMAGE_DIR
|
||||
logging.info(f"Source='random'。将从 config.TEST_IMAGE_DIR 随机选择图像: {source_dir_for_random}")
|
||||
elif Path(source_input).is_file():
|
||||
image_to_test = source_input
|
||||
logging.info(f"Source 是一个特定文件: {image_to_test}")
|
||||
elif Path(source_input).is_dir():
|
||||
source_dir_for_random = Path(source_input)
|
||||
logging.info(f"Source 是一个目录。将从该目录随机选择图像: {source_dir_for_random}")
|
||||
else:
|
||||
logging.error(f"错误: 指定的源路径无效: {source_input}")
|
||||
sys.exit(1)
|
||||
|
||||
# 如果需要随机选择...
|
||||
if image_to_test is None:
|
||||
if not source_dir_for_random or not source_dir_for_random.is_dir():
|
||||
logging.error(f"错误: 用于随机选择的目录无效: {source_dir_for_random}")
|
||||
sys.exit(1)
|
||||
|
||||
image_paths = sorted(
|
||||
list(source_dir_for_random.glob("*.jpg")) +
|
||||
list(source_dir_for_random.glob("*.jpeg")) +
|
||||
list(source_dir_for_random.glob("*.png"))
|
||||
)
|
||||
if not image_paths:
|
||||
logging.error(f"错误: 在目录 {source_dir_for_random} 中未找到任何图像 (.jpg, .jpeg, .png)。")
|
||||
sys.exit(1)
|
||||
|
||||
image_to_test = str(random.choice(image_paths))
|
||||
logging.info(f"已随机选择图像: {Path(image_to_test).name}")
|
||||
|
||||
# 3. (融合) 调用核心函数
|
||||
if image_to_test:
|
||||
generate_heatmaps_for_all_layers(
|
||||
model_path=str(model_to_use_path), # <-- 使用 V1 找到的路径
|
||||
image_path=image_to_test, # <-- 使用您的逻辑找到的路径
|
||||
cam_method_name=args.cam_method, # <-- 使用您的参数
|
||||
model_name = args.model
|
||||
)
|
||||
261
Seg_All_In_One_YoloModel/Yolo数据集构建/0_1_check_picture_pair.py
Normal file
261
Seg_All_In_One_YoloModel/Yolo数据集构建/0_1_check_picture_pair.py
Normal file
@@ -0,0 +1,261 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
import argparse
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Set, Dict, Tuple
|
||||
|
||||
# <--- 新增: 导入 Pillow 库用于图像处理 ---
|
||||
try:
|
||||
from PIL import Image
|
||||
# 禁用解压缩炸弹检查,以防万一标签图非常大且简单
|
||||
Image.MAX_IMAGE_PIXELS = None
|
||||
except ImportError:
|
||||
print("错误: 本脚本需要 'Pillow' 库来检查图像尺寸。", file=sys.stderr)
|
||||
print("请运行: pip install Pillow", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
# --- 新增结束 ---
|
||||
|
||||
|
||||
# 定义支持的文件扩展名
|
||||
VALID_EXTENSIONS = {'.png', '.jpg', '.jpeg'}
|
||||
|
||||
def process_directory(path: Path,
|
||||
prefix: str = "",
|
||||
suffix: str = "") -> Tuple[Set[str], Dict[str, Path], Dict[Tuple[int, int], int]]: # <--- 修改:更新了返回类型
|
||||
"""
|
||||
处理单个目录,提取所有有效文件的文件名(stem),并根据前缀和后缀进行规范化。
|
||||
同时统计所有图像的尺寸。
|
||||
|
||||
返回:
|
||||
normalized_stems (Set[str]): 规范化处理后的文件名集合。
|
||||
normalized_to_full_path (Dict[str, Path]): 规范化文件名 -> 原始完整路径 的映射。
|
||||
size_counts (Dict[Tuple[int, int], int]): (宽度, 高度) -> 数量 的映射。 # <--- 新增
|
||||
"""
|
||||
if not path.is_dir():
|
||||
print(f"错误: 路径 '{path}' 不是一个有效的目录。", file=sys.stderr)
|
||||
return set(), {}, {} # <--- 修改
|
||||
|
||||
normalized_stems: Set[str] = set()
|
||||
normalized_to_full_path: Dict[str, Path] = {}
|
||||
size_counts: Dict[Tuple[int, int], int] = {} # <--- 新增:用于统计尺寸
|
||||
|
||||
for file_path in path.glob('*'):
|
||||
# 确保是文件,并且扩展名在我们的有效列表中
|
||||
if file_path.is_file() and file_path.suffix.lower() in VALID_EXTENSIONS:
|
||||
|
||||
# <--- 新增: 尝试获取图像尺寸 ---
|
||||
try:
|
||||
# 使用 with...as... 确保文件被正确关闭
|
||||
with Image.open(file_path) as img:
|
||||
size: Tuple[int, int] = img.size
|
||||
size_counts[size] = size_counts.get(size, 0) + 1
|
||||
except Exception as e:
|
||||
print(f"警告: 无法读取图像 '{file_path.name}' 的尺寸. 错误: {e}", file=sys.stderr)
|
||||
# --- 尺寸获取结束 ---
|
||||
|
||||
original_stem = file_path.stem
|
||||
normalized_stem = original_stem
|
||||
|
||||
# 仅当提供了前缀/后缀时才进行处理
|
||||
if prefix and normalized_stem.startswith(prefix):
|
||||
normalized_stem = normalized_stem[len(prefix):]
|
||||
|
||||
if suffix and normalized_stem.endswith(suffix):
|
||||
normalized_stem = normalized_stem[:-len(suffix)]
|
||||
|
||||
# 检查处理后是否重名,如果重名则发出警告
|
||||
if normalized_stem in normalized_to_full_path:
|
||||
print(f"警告: 规范化后文件名发生冲突。")
|
||||
print(f" '{file_path.name}' 和 '{normalized_to_full_path[normalized_stem].name}' 都变成了 '{normalized_stem}'")
|
||||
|
||||
normalized_stems.add(normalized_stem)
|
||||
normalized_to_full_path[normalized_stem] = file_path
|
||||
|
||||
return normalized_stems, normalized_to_full_path, size_counts # <--- 修改
|
||||
|
||||
|
||||
# <--- 新增: 打印尺寸统计的辅助函数 ---
|
||||
def print_size_report(title: str, size_counts: Dict[Tuple[int, int], int]):
|
||||
"""
|
||||
格式化并打印尺寸统计报告。
|
||||
"""
|
||||
print(f"\n{title}")
|
||||
if not size_counts:
|
||||
print(" (未找到或无法读取任何图像文件)")
|
||||
return
|
||||
|
||||
total_files = sum(size_counts.values())
|
||||
print(f" 总共 {total_files} 个文件,分布如下:")
|
||||
|
||||
# 按尺寸 (宽, 高) 排序输出
|
||||
for size, count in sorted(size_counts.items()):
|
||||
width, height = size
|
||||
print(f" - 尺寸 (宽, 高) {width}x{height}: {count} 个文件")
|
||||
# --- 新增结束 ---
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="比较两个文件夹中的文件名是否匹配,并可选择删除不匹配项。")
|
||||
parser.add_argument("-i", "--image",
|
||||
type=Path,
|
||||
default="./ORI",
|
||||
help="Image 文件夹路径")
|
||||
parser.add_argument("-l", "--label",
|
||||
type=Path,
|
||||
default="./Label",
|
||||
help="Label 文件夹路径")
|
||||
parser.add_argument("-p", "--prefix",
|
||||
type=str,
|
||||
default="",
|
||||
help="在 Label 文件名中要忽略的前缀")
|
||||
parser.add_argument("-s", "--suffix",
|
||||
type=str,
|
||||
default="",
|
||||
help="在 Label 文件名中要忽略的后缀")
|
||||
parser.add_argument("-y", "--yes",
|
||||
action="store_true",
|
||||
help="自动确认删除所有不匹配的文件,跳过交互式提示。")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# 1. 处理 Image 文件夹
|
||||
print(f"--- 正在处理 Image 文件夹: {args.image} ---")
|
||||
image_stems, image_norm_to_path, image_size_counts = process_directory(args.image) # <--- 修改
|
||||
if not image_stems:
|
||||
print(f"未在 Image 文件夹中找到任何 .png 或 .jpg 文件。")
|
||||
# <--- 新增: 打印 Image 尺寸报告 ---
|
||||
print_size_report("--- Image 文件夹尺寸统计 ---", image_size_counts)
|
||||
|
||||
# 2. 处理 Label 文件夹
|
||||
print(f"\n--- 正在处理 Label 文件夹: {args.label} ---")
|
||||
print(f"(忽略前缀: '{args.prefix}', 忽略后缀: '{args.suffix}')")
|
||||
label_stems_normalized, label_norm_to_path, label_size_counts = process_directory( # <--- 修改
|
||||
args.label,
|
||||
args.prefix,
|
||||
args.suffix
|
||||
)
|
||||
if not label_stems_normalized:
|
||||
print(f"未在 Label 文件夹中找到任何 .png 或 .jpg 文件。")
|
||||
# <--- 新增: 打印 Label 尺寸报告 ---
|
||||
print_size_report("--- Label 文件夹尺寸统计 ---", label_size_counts)
|
||||
|
||||
|
||||
# 3. 执行比较 (使用集合运算)
|
||||
matching_stems = image_stems.intersection(label_stems_normalized)
|
||||
extra_in_image = image_stems.difference(label_stems_normalized)
|
||||
extra_in_label_normalized = label_stems_normalized.difference(image_stems)
|
||||
extra_in_label_original_stems = {label_norm_to_path[stem].stem for stem in extra_in_label_normalized}
|
||||
|
||||
# 4. 输出结果
|
||||
print("\n" + "="*30)
|
||||
print(" 匹配结果报告")
|
||||
print("="*30)
|
||||
|
||||
print(f"\n匹配的文件总数: {len(matching_stems)}")
|
||||
|
||||
# <--- 新增: 检查匹配项的尺寸是否一致 ---
|
||||
print("\n--- 正在检查匹配文件的尺寸 ---")
|
||||
mismatched_size_pairs = []
|
||||
if not matching_stems:
|
||||
print("(无匹配文件,跳过尺寸检查)")
|
||||
else:
|
||||
print(f"正在检查 {len(matching_stems)} 对文件...")
|
||||
for stem in sorted(matching_stems):
|
||||
try:
|
||||
img_path = image_norm_to_path[stem]
|
||||
lbl_path = label_norm_to_path[stem]
|
||||
|
||||
# 再次打开以比较尺寸
|
||||
with Image.open(img_path) as img:
|
||||
img_size = img.size
|
||||
with Image.open(lbl_path) as lbl:
|
||||
lbl_size = lbl.size
|
||||
|
||||
if img_size != lbl_size:
|
||||
mismatched_size_pairs.append((img_path, img_size, lbl_path, lbl_size))
|
||||
|
||||
except Exception as e:
|
||||
print(f" 错误: 无法比较 '{stem}' 的尺寸. 错误: {e}", file=sys.stderr)
|
||||
|
||||
if not mismatched_size_pairs:
|
||||
print(f"√ 成功: 所有 {len(matching_stems)} 对匹配文件均具有相同的尺寸。")
|
||||
else:
|
||||
print(f"\n!!! 警告: 发现 {len(mismatched_size_pairs)} 对文件尺寸不匹配:")
|
||||
for img_path, img_size, lbl_path, lbl_size in mismatched_size_pairs:
|
||||
print(f" - [Image] {img_path.name} {img_size} != [Label] {lbl_path.name} {lbl_size}")
|
||||
# --- 尺寸检查结束 ---
|
||||
|
||||
print("\n--- Image 文件夹中多余的文件 (共 {} 个) ---".format(len(extra_in_image)))
|
||||
if not extra_in_image:
|
||||
print("(无)")
|
||||
else:
|
||||
for file_stem in sorted(extra_in_image):
|
||||
print(file_stem)
|
||||
|
||||
print("\n--- Label 文件夹中多余的文件 (共 {} 个) ---".format(len(extra_in_label_original_stems)))
|
||||
print("(显示的是原始文件名,非规范化名称)")
|
||||
if not extra_in_label_original_stems:
|
||||
print("(无)")
|
||||
else:
|
||||
for file_stem in sorted(extra_in_label_original_stems):
|
||||
print(file_stem)
|
||||
|
||||
print("\n" + "="*30)
|
||||
|
||||
# 5. 删除交互部分 (无修改)
|
||||
if not extra_in_image and not extra_in_label_normalized:
|
||||
print("\n所有文件均完美匹配,无需删除。")
|
||||
# <--- 新增: 如果文件名匹配,但尺寸不匹配,也在这里提醒一下 ---
|
||||
if mismatched_size_pairs:
|
||||
print("注意:虽然文件名匹配,但有 {} 对文件的尺寸不一致,请检查上面的报告。".format(len(mismatched_size_pairs)))
|
||||
return
|
||||
|
||||
# ... (后续的删除逻辑保持不变) ...
|
||||
confirm_delete = False
|
||||
if args.yes:
|
||||
print("\n检测到 -y/--yes 参数,将自动删除不匹配的文件。")
|
||||
confirm_delete = True
|
||||
else:
|
||||
try:
|
||||
choice = input("\n警告: 是否要永久删除所有上述 '多余' 的文件? (y/n): ").strip().lower()
|
||||
if choice in ['y', 'yes']:
|
||||
confirm_delete = True
|
||||
except (KeyboardInterrupt, EOFError):
|
||||
print("\n操作被用户中断。")
|
||||
confirm_delete = False
|
||||
|
||||
if confirm_delete:
|
||||
print("\n--- 正在删除 Image 文件夹中的多余文件 ---")
|
||||
deleted_count = 0
|
||||
for stem in sorted(extra_in_image):
|
||||
file_path = image_norm_to_path[stem]
|
||||
try:
|
||||
file_path.unlink()
|
||||
print(f" 已删除: {file_path.name}")
|
||||
deleted_count += 1
|
||||
except OSError as e:
|
||||
print(f" 删除失败: {file_path.name} (错误: {e})", file=sys.stderr)
|
||||
print(f"--- Image 文件夹共删除 {deleted_count} 个文件 ---")
|
||||
|
||||
print("\n--- 正在删除 Label 文件夹中的多余文件 ---")
|
||||
deleted_count = 0
|
||||
for norm_stem in sorted(extra_in_label_normalized):
|
||||
file_path = label_norm_to_path[norm_stem]
|
||||
try:
|
||||
file_path.unlink()
|
||||
print(f" 已删除: {file_path.name} (原始文件名)")
|
||||
deleted_count += 1
|
||||
except OSError as e:
|
||||
print(f" 删除失败: {file_path.name} (错误: {e})", file=sys.stderr)
|
||||
print(f"--- Label 文件夹共删除 {deleted_count} 个文件 ---")
|
||||
|
||||
print("\n删除操作完成。")
|
||||
|
||||
else:
|
||||
print("\n操作已取消。未删除任何文件。")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
122
Seg_All_In_One_YoloModel/Yolo数据集构建/0_2_TOOL_stack_pics.sh
Normal file
122
Seg_All_In_One_YoloModel/Yolo数据集构建/0_2_TOOL_stack_pics.sh
Normal file
@@ -0,0 +1,122 @@
|
||||
#!/bin/bash
|
||||
|
||||
usage() {
|
||||
echo "Usage: $0 -i <ori_image_directory> -l <ori_label_directory> -r <stack_result_directory> [ -a <alpha> -p <prefix> -s <suffix> -h]"
|
||||
echo "对image图片和label图片进行匹配(-i、-l -r均不能为空)(-p -s默认为空"" -a默认为\"0.3\") "
|
||||
echo "-i:原始image的路径,-l:原始label的路径,-p:前缀内容,-s:后缀内容(不用管文件后缀名),-h:帮助"
|
||||
echo "e.g. bash 0_2_TOOL_stack_pics.sh -i ./ori -l ./label -r ./result_0.3透明度 -a 0.3 -p Prefix -s _label"
|
||||
}
|
||||
|
||||
ori_image_directorys=""
|
||||
ori_label_directorys=""
|
||||
stack_result_directorys=""
|
||||
prefix=""
|
||||
suffix=""
|
||||
alpha="0.3"
|
||||
|
||||
while getopts "hl:i:r:p:s:a:" opt; do
|
||||
case $opt in
|
||||
h)
|
||||
usage
|
||||
exit 0
|
||||
;;
|
||||
i)
|
||||
ori_image_directorys=$OPTARG
|
||||
;;
|
||||
l)
|
||||
ori_label_directorys=$OPTARG
|
||||
;;
|
||||
p)
|
||||
prefix=$OPTARG
|
||||
;;
|
||||
s)
|
||||
suffix=$OPTARG
|
||||
;;
|
||||
r)
|
||||
stack_result_directorys=$OPTARG
|
||||
;;
|
||||
a)
|
||||
alpha=$OPTARG
|
||||
;;
|
||||
*)
|
||||
echo -e '\033[31m!!! Error, Illegal input !!!\033[0m'
|
||||
usage
|
||||
exit 1
|
||||
;;
|
||||
esac
|
||||
done
|
||||
|
||||
# 判断输入地址是否为空
|
||||
if [ -z "$ori_label_directorys" ] || [ -z "$ori_image_directorys" ] || [ -z "$stack_result_directorys" ]; then
|
||||
echo -e "\033[31m输入地址 -i -l -z 存在空地址\033[0m"
|
||||
usage
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# 地址转化
|
||||
ori_image_directory=$(readlink -f "$ori_image_directorys")
|
||||
ori_label_directory=$(readlink -f "$ori_label_directorys")
|
||||
stack_result_directory=$(readlink -f "$stack_result_directorys")
|
||||
if [ -z "$ori_label_directory" ] || [ -z "$ori_image_directory" ]|| [ -z "$stack_result_directory" ]; then
|
||||
echo "image、label、result存在无法解析地址,程序退出"
|
||||
echo -e "\033[31mori_image_directory\033[0m: $ori_image_directorys"
|
||||
echo -e "\033[31mori_label_directory\033[0m: $ori_label_directorys"
|
||||
echo -e "\033[31mori_label_directory\033[0m: $stack_result_directorys"
|
||||
exit 1
|
||||
fi
|
||||
if [ ! -d "$ori_label_directory" ] || [ ! -d "$ori_image_directory" ]; then
|
||||
echo "image、label两目录有一个不存在,程序退出"
|
||||
echo -e "\033[31mori_image_directory\033[0m: $ori_image_directory"
|
||||
echo -e "\033[31mori_label_directory\033[0m: $ori_label_directorys"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# 获取当前脚本的路径和名称
|
||||
script_path=$(dirname "$0")
|
||||
# 将当前目录更改为脚本所在的路径
|
||||
cd "$script_path"
|
||||
|
||||
# 激活conda环境
|
||||
source /home/"$USER"/miniconda/bin/activate Deal_Data
|
||||
|
||||
echo -e "\033[32m_____ 0_2_TOOL_stack_pics.sh _____\033[0m"
|
||||
echo -e "\033[33mimage所在文件夹为$ori_image_directory\nlable所在文件夹为$ori_label_directory\033[0m"
|
||||
# 遍历label目录
|
||||
for file_path in "$ori_label_directory"/*; do
|
||||
# 判断是否是文件
|
||||
if [[ -f "$file_path" ]]; then
|
||||
file_name=$(basename "$file_path")
|
||||
# 判断文件名是否符合规范
|
||||
if [[ "$file_name" =~ .*\.(jpg|png|bmp|JPG|PNG|BMP) ]]; then # 判断是否有为图片
|
||||
# if [[ "$file_name" =~ "$prefix".*"$suffix".*\.(jpg|png|bmp|JPG|PNG|BMP)$ ]]; then # 判断是否有满足要求的文件名
|
||||
# 抽取文件名(有前缀、后缀的抽取前缀、后缀里面的,没有的返回整个)
|
||||
if [ -z $prefix ];then
|
||||
file_name_extract=$(echo $file_name | sed "s/"$prefix"\(.*\)"$suffix".*/\1/" | sed "s/\(.*\)\.\(jpg\|png\|bmp\|JPG\|PNG\|BMP\)$/\1/")
|
||||
else
|
||||
file_name_extract=$(echo $file_name | sed "s/".*$prefix"\(.*\)"$suffix".*/\1/" | sed "s/\(.*\)\.\(jpg\|png\|bmp\|JPG\|PNG\|BMP\)$/\1/")
|
||||
fi
|
||||
# 从label目录中看是否有此文件
|
||||
file_name_other=$(ls $ori_image_directory | grep "^${file_name_extract}\.")
|
||||
file_name_other=$(echo "$(echo "$file_name_other" | sed '/^$/d')" | head -n1) # 提取出文件名
|
||||
# 如果另一个目录没有此文件的话
|
||||
if [ -z "$file_name_other" ]; then
|
||||
echo "$file_name label中对应内容未在$ori_image_directory搜索到"
|
||||
# 建立相关存储文件夹
|
||||
if [ ! -d "$ori_label_directory/Not_pair_pics" ]; then
|
||||
mkdir -p "$ori_label_directory/Not_pair_pics" # 建立存储文件夹
|
||||
fi
|
||||
# 移动相关文件
|
||||
cp "$ori_label_directory/$file_name" "$ori_label_directory/Not_pair_pics"
|
||||
echo "$file_name" >> "$ori_label_directory/Not_pair_pics/not_pair.txt"
|
||||
else # 如果另一个目录有此配对文件的话,则运行相关程序
|
||||
echo "image中的$file_name_other,与lable中的$file_name"
|
||||
mkdir -p "$stack_result_directory"
|
||||
python 0_2_stack_picture.py "$ori_image_directory/$file_name_other" "$ori_label_directory/$file_name" "$stack_result_directory" "$alpha"
|
||||
echo ""
|
||||
fi
|
||||
|
||||
fi
|
||||
else
|
||||
echo "$file_path不是文件"
|
||||
fi
|
||||
done
|
||||
41
Seg_All_In_One_YoloModel/Yolo数据集构建/0_2_stack_picture.py
Normal file
41
Seg_All_In_One_YoloModel/Yolo数据集构建/0_2_stack_picture.py
Normal file
@@ -0,0 +1,41 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: UTF-8 -*-
|
||||
import cv2, os, sys
|
||||
|
||||
def Stack_pic(Background_path, Overlay_path, Result_dir, alpha=0.3):
|
||||
# 读取两张没有alpha通道的图片
|
||||
img1 = cv2.imread(Background_path) # 底层图片
|
||||
img2 = cv2.imread(Overlay_path) # 顶层图片
|
||||
|
||||
Result_name = os.path.splitext(os.path.basename(Background_path))[0]
|
||||
|
||||
# 将img2调整为与img1大小相同
|
||||
img2 = cv2.resize(img2, (img1.shape[1], img1.shape[0]))
|
||||
|
||||
# 将img2的透明度调整为20%
|
||||
overlay_alpha = alpha
|
||||
|
||||
# 将img2叠加到img1上
|
||||
overlay = cv2.addWeighted(img1, 1 - overlay_alpha, img2, overlay_alpha, 0)
|
||||
|
||||
# 保存结果
|
||||
if not os.path.exists(Result_dir):
|
||||
os.makedirs(Result_dir)
|
||||
cv2.imwrite(os.path.join(Result_dir, Result_name+'.png'), overlay)
|
||||
print("堆叠图片写入地址:", os.path.join(Result_dir, Result_name+'.png'))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
Background_path = sys.argv[1] # 背景所在路径
|
||||
Overlay_path = sys.argv[2] # 上层图片所在路径
|
||||
Result_dir = sys.argv[3] # 结果所在目录
|
||||
# 透明度,默认为0.3
|
||||
try:
|
||||
alpha = float(sys.argv[4])
|
||||
if(alpha > 1 or alpha < 0):
|
||||
print("alpha 透明度输入不正确,其值应该在0~1之间")
|
||||
alpha = 0.3
|
||||
except:
|
||||
alpha = 0.3
|
||||
# 进行对叠程序
|
||||
Stack_pic(Background_path, Overlay_path, Result_dir, alpha)
|
||||
442
Seg_All_In_One_YoloModel/Yolo数据集构建/1_deal_labels.py
Normal file
442
Seg_All_In_One_YoloModel/Yolo数据集构建/1_deal_labels.py
Normal file
@@ -0,0 +1,442 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*
|
||||
import os,time,sys,threading, colorsys, argparse
|
||||
import asyncio, cv2, multiprocessing, random
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
from Tool_deal_labels import edge_detection, detect_connected_regions, Tool_color_connected_array, fill_white_regions, color_connected_regions
|
||||
from Tool_Classes_And_Palette import Annotate_CLASSES, Annotate_PALETTE, bg_PALETTE
|
||||
|
||||
def getFileList(dir,Filelist=[], ext=None, Max_layer=1, layer=0, Donot_Search=['1_边缘检测并膨胀', '2_连通区域检测', '3_分水岭算法填充']):
|
||||
"""
|
||||
获取文件夹及其子文件夹中文件列表
|
||||
输入 dir:文件夹根目录
|
||||
输入 ext: 扩展名
|
||||
返回: 文件路径列表
|
||||
"""
|
||||
newDir = dir
|
||||
if os.path.isfile(dir):
|
||||
if ext is None:
|
||||
Filelist.append(dir)
|
||||
else:
|
||||
if ext in dir[-3:]:
|
||||
Filelist.append(dir)
|
||||
|
||||
elif os.path.isdir(dir):
|
||||
file_name = os.path.basename(dir)
|
||||
# 判断是否在禁搜名单中
|
||||
if file_name in Donot_Search:
|
||||
return Filelist
|
||||
for s in os.listdir(dir):
|
||||
newDir=os.path.join(dir,s)
|
||||
if layer <= Max_layer:
|
||||
getFileList(newDir, Filelist, ext, Max_layer, layer+1)
|
||||
|
||||
return Filelist
|
||||
|
||||
class Deal_image():
|
||||
def __init__(self, Annotate_CLASSES = ('肝脏','胆囊'), Annotate_PALETTE = [[255,91,0],[255,234,0]], src_label_fold = "./Label", save_pro_label_fold = "./LABEL_PNG_new", save_GT_label_fold = "./Label_Generate", GT_channel = 1, pro_append_name="_label", GT_append_name="", ori_img_folder="./ORI_PNG", res_label_folder="./Result_label", save_merge_pic_folder="./Result_merge", back_gnd_color=0, first_class_color=1, pic_type="png", Max_width = 10000, Label_Max_Search_layer=1000, save_process_pics=False, bg_PALETTE = [0,0,0]):
|
||||
# 背景最好放在最后
|
||||
# self.src_CLASSES = ('肝脏','胆囊','分离钳','止血海绵','肝总管','胆总管','吸引器','剪刀','止血纱布','生物夹','无损伤钳','喷洒','胆囊管','胆囊动脉','电凝','标本袋','引流管','纱布','金属钛夹','术中超声','吻合器','乳胶管','推结器','肝带','钳夹','超声刀','脂肪','双极电凝','棉球','血管阻断夹','肿瘤','针','线','韧带','胆囊静脉','背景')
|
||||
# self.src_PALETTE = np.array([[255,91,0],[255,234,0],[85, 111, 181],[181, 227, 14],[72, 0, 255],[0, 155, 33],[255,0,255],[29, 32, 136],[160, 15, 95],[0,160,233],[52,184,178],[90,120,41],[255,0,0],[177,0,0],[167,24,233],[112,113,150],[0,255,0],[255,255,255],[0,255,255],[138,251,213],[136,162,196],[197,83,181],[202,202,200],[113,102,140],[66,115,82],[240,16,116],[155,132,0],[155,62,0],[146,175,236],[255,172,159],[245,161,0],[134,124,118], [0,157,142], [181,85,105], [42,8,66],[0,0,0]])
|
||||
# self.src_CLASSES_NUM = np.shape(self.src_CLASSES)[0]
|
||||
self.bg_PALETTE = bg_PALETTE # 背景颜色 TODO
|
||||
|
||||
self.Annotate_CLASSES = Annotate_CLASSES # 待分类的类
|
||||
self.Annotate_PALETTE = np.array(Annotate_PALETTE) # 每一类的像素直
|
||||
self.Annotate_CLASSES_NUM = np.shape(Annotate_CLASSES)[0] # 类数量
|
||||
|
||||
self.save_process_pics = save_process_pics # 保存中间过程图片
|
||||
|
||||
self.src_label_fold = src_label_fold # 原始标签图片 保存位置
|
||||
self.save_pro_label_fold = save_pro_label_fold # 优化后标签图片 保存位置
|
||||
self.save_GT_label_fold = save_GT_label_fold # GT标签图片 保存位置
|
||||
|
||||
self.ori_img_folder = ori_img_folder # 最原始手术图片 保存位置
|
||||
self.res_label_folder = res_label_folder # 训练出来的label 保存位置
|
||||
self.save_merge_pic_folder = save_merge_pic_folder # 融合图像保存位置
|
||||
|
||||
self.pro_append_name = pro_append_name # 优化后标签图片后缀
|
||||
self.GT_append_name = GT_append_name # GT标签图片后缀
|
||||
self.GT_channel = GT_channel # GT标签图片通道数
|
||||
|
||||
self.Max_width = Max_width # 最大图片宽度(匹配时候用)
|
||||
self.pic_type = pic_type # 图片类型
|
||||
self.back_gnd_color = back_gnd_color # 背景颜色
|
||||
self.first_class_color = first_class_color # 第一类上的颜色
|
||||
self.Label_Max_Search_layer=Label_Max_Search_layer # 文件夹最大搜索深度
|
||||
try:
|
||||
self.labellist_src = getFileList(src_label_fold, [], pic_type, self.Label_Max_Search_layer)
|
||||
print('本次执行检索到ori_label图片 '+str(len(self.labellist_src))+' 张图像')
|
||||
except:
|
||||
self.labellist_src = None
|
||||
print("没有ori_label相关文件")
|
||||
|
||||
try:
|
||||
# print(save_pro_label_fold)
|
||||
self.labellist_pro = getFileList(save_pro_label_fold, [], pic_type, self.Label_Max_Search_layer)
|
||||
print('本次执行检索到pro_label图片 '+str(len(self.labellist_pro))+' 张图像')
|
||||
except:
|
||||
self.labellist_pro = None
|
||||
print("没有pro_label相关文件")
|
||||
|
||||
try:
|
||||
self.imglist_src = getFileList(ori_img_folder, [], pic_type, self.Label_Max_Search_layer)
|
||||
self.reslist_src = getFileList(res_label_folder, [], pic_type, self.Label_Max_Search_layer)
|
||||
print('本次执行检索到ori原始图片 '+str(len(self.imglist_src))+' 张图像')
|
||||
print('本次执行检索到训练train_result图片 '+str(len(self.reslist_src))+' 张图像')
|
||||
except:
|
||||
self.imglist_src = None
|
||||
self.reslist_src = None
|
||||
print("没有train_result和原始图片相关文件")
|
||||
|
||||
# 获取单张图片各个通路信息
|
||||
def get_single_pic_rgb(self, imgpath):
|
||||
print(imgpath)
|
||||
image = Image.open(imgpath).convert('RGB') # 转为RGB图片
|
||||
# 将 RGB 色值分离
|
||||
image.load()
|
||||
r, g, b = image.split()
|
||||
r = np.array(r)
|
||||
g = np.array(g)
|
||||
b = np.array(b)
|
||||
return image, r, g, b
|
||||
|
||||
# 将单个pro图片变成GT图片
|
||||
def Conver_pro_label_pic_2_GT_pic(self, imgpath, imgname):
|
||||
time_start=time.time() # 记录开始时间
|
||||
# 获取单张图片各个通路信息
|
||||
image, r,g,b = self.get_single_pic_rgb(imgpath)
|
||||
|
||||
result_gt = np.ones(np.shape(image))*self.back_gnd_color # 初始化填充内容为back_gnd_color
|
||||
gt_number = self.first_class_color # 第一类上色颜色确定
|
||||
|
||||
# PALETTE中排除掉 '背景' [0,0,0]
|
||||
PALETTE_No_Bg = self.Annotate_PALETTE[~np.all(self.Annotate_PALETTE == self.bg_PALETTE, axis=1)]
|
||||
|
||||
# 遍历所有待识别颜色
|
||||
for [Annotate_PALETTE_r, Annotate_PALETTE_g, Annotate_PALETTE_b] in PALETTE_No_Bg:
|
||||
# 查找三原色匹配位置
|
||||
locate_r = np.where( r == Annotate_PALETTE_r )
|
||||
locate_g = np.where( g == Annotate_PALETTE_g )
|
||||
locate_b = np.where( b == Annotate_PALETTE_b )
|
||||
|
||||
# 查找都匹配位置(交集)
|
||||
# 将矩阵换一种表示形式
|
||||
locate_r = np.array(locate_r[0]) * self.Max_width + np.array(locate_r[1])
|
||||
locate_g = np.array(locate_g[0]) * self.Max_width + np.array(locate_g[1])
|
||||
locate_b = np.array(locate_b[0]) * self.Max_width + np.array(locate_b[1])
|
||||
|
||||
# 用自带函数寻找匹配项
|
||||
matched = np.intersect1d(np.intersect1d(locate_r, locate_g), locate_b)
|
||||
matched = np.concatenate(([matched // self.Max_width], [np.mod(matched, self.Max_width)]), 0)
|
||||
result_gt[matched[0],matched[1], :] = gt_number
|
||||
gt_number = gt_number + 1
|
||||
|
||||
# 输出GT图片
|
||||
if(int(self.GT_channel) == 1):
|
||||
result_gt = result_gt[:,:,0]
|
||||
elif(int(self.GT_channel) == 3):
|
||||
result_gt = cv2.cvtColor(np.float32(result_gt), cv2.COLOR_RGB2BGR) # rgb颜色互换
|
||||
else:
|
||||
print("GT_channel 必须为1或3")
|
||||
quit
|
||||
try: # 新建文件夹
|
||||
os.mkdir(self.save_GT_label_fold)
|
||||
except:
|
||||
print("已有"+self.save_GT_label_fold)
|
||||
if imgname.lower().endswith(('.jpg', '.png')):
|
||||
save_dir = os.path.join(self.save_GT_label_fold, os.path.basename(imgname).rpartition('.')[0]+self.GT_append_name+'.'+self.pic_type)
|
||||
else:
|
||||
save_dir = os.path.join(self.save_GT_label_fold, os.path.basename(imgname)+self.GT_append_name+'.'+self.pic_type)
|
||||
cv2.imwrite(save_dir, result_gt)
|
||||
print("GT图片已保存", save_dir)
|
||||
time_end=time.time() # 输出结束时间
|
||||
print('time cost',time_end-time_start,'s')
|
||||
|
||||
# 将处理好的图片转化为GT图片
|
||||
def Conver_pro_label_pic_2_GT_pic_all(self):
|
||||
print("\033[33m**** 进行转换将Pro_label_pic转换为GT_label_pic ****\033[0m")
|
||||
print("\033[33mPro_label_pic存储位置为:\033[0m", self.save_pro_label_fold)
|
||||
print("\033[33mGT_label_pic生成位置为:\033[0m", self.save_GT_label_fold)
|
||||
try:
|
||||
# print(save_pro_label_fold)
|
||||
self.labellist_pro = getFileList(save_pro_label_fold, [], pic_type, self.Label_Max_Search_layer)
|
||||
print('本次执行检索到pro_label图片 '+str(len(self.labellist_pro))+' 张图像')
|
||||
except:
|
||||
self.labellist_pro = None
|
||||
print("没有pro_label相关文件")
|
||||
try:
|
||||
os.mkdir(self.save_GT_label_fold) # 新建存储文件夹
|
||||
except:
|
||||
print("已有"+self.save_GT_label_fold)
|
||||
|
||||
# 指定最大进程数为 3
|
||||
max_processes = 20
|
||||
# 创建Pool对象
|
||||
pool = multiprocessing.Pool(processes=max_processes)
|
||||
# 创建并启动进程
|
||||
args_list1 = []
|
||||
args_list2 = []
|
||||
|
||||
# 遍历整个文件夹
|
||||
for imgpath in self.labellist_pro:
|
||||
imgname = os.path.basename(imgpath).rpartition('.')[0].replace(self.pro_append_name,"")
|
||||
args_list1.append(imgpath)
|
||||
args_list2.append(imgname)
|
||||
args_list = zip(args_list1, args_list2)
|
||||
# 使用进程池并行执行任务
|
||||
pool.starmap(self.Conver_pro_label_pic_2_GT_pic, args_list)
|
||||
# 关闭进程池
|
||||
pool.close()
|
||||
pool.join()
|
||||
|
||||
def Conver_ori_label_pic_2_pro_pic(self, imgpath, imgname):
|
||||
time_start=time.time() # 记录开始时间
|
||||
# 获取单张图片各个通路信息
|
||||
image = cv2.imread(imgpath)
|
||||
|
||||
# 1. 边缘检测并膨胀
|
||||
dilated_image = edge_detection(image)
|
||||
# 如果需要存储中间态图片
|
||||
if(self.save_process_pics == True):
|
||||
if imgname.lower().endswith(('.jpg', '.png')):
|
||||
save_dir = os.path.join(self.save_pro_label_fold, '1_边缘检测并膨胀', os.path.basename(imgname).rpartition('.')[0]+self.pro_append_name+'_Edge'+'.'+self.pic_type)
|
||||
else:
|
||||
save_dir = os.path.join(self.save_pro_label_fold, '1_边缘检测并膨胀', os.path.basename(imgname)+self.pro_append_name+'_Edge'+'.'+self.pic_type)
|
||||
cv2.imwrite(save_dir, dilated_image)
|
||||
print("中间态-边缘检测并膨胀 图片已保存", save_dir)
|
||||
time_end=time.time() # 输出结束时间
|
||||
print('time cost',time_end-time_start,'s')
|
||||
|
||||
# 2. 检测连通区域
|
||||
filtered_labeled_array, _ = detect_connected_regions(dilated_image)
|
||||
colored_image_filtered = Tool_color_connected_array(filtered_labeled_array)
|
||||
# 如果需要存储中间态图片
|
||||
if(self.save_process_pics == True):
|
||||
if imgname.lower().endswith(('.jpg', '.png')):
|
||||
save_dir = os.path.join(self.save_pro_label_fold, '2_连通区域检测', os.path.basename(imgname).rpartition('.')[0]+self.pro_append_name+'_Region'+'.'+self.pic_type)
|
||||
else:
|
||||
save_dir = os.path.join(self.save_pro_label_fold, '2_连通区域检测', os.path.basename(imgname)+self.pro_append_name+'_Region'+'.'+self.pic_type)
|
||||
cv2.imwrite(save_dir, colored_image_filtered)
|
||||
print("中间态-连通区域检测 图片已保存", save_dir)
|
||||
time_end=time.time() # 输出结束时间
|
||||
print('time cost',time_end-time_start,'s')
|
||||
|
||||
# 3. 分水岭填充白色区域
|
||||
filled_labeled_array = fill_white_regions(filtered_labeled_array)
|
||||
colored_image_filled = Tool_color_connected_array(filled_labeled_array)
|
||||
# 如果需要存储中间态图片
|
||||
if(self.save_process_pics == True):
|
||||
if imgname.lower().endswith(('.jpg', '.png')):
|
||||
save_dir = os.path.join(self.save_pro_label_fold, '3_分水岭算法填充', os.path.basename(imgname).rpartition('.')[0]+self.pro_append_name+'_FillEdge'+'.'+self.pic_type)
|
||||
else:
|
||||
save_dir = os.path.join(self.save_pro_label_fold, '3_分水岭算法填充', os.path.basename(imgname)+self.pro_append_name+'_FillEdge'+'.'+self.pic_type)
|
||||
cv2.imwrite(save_dir, colored_image_filled)
|
||||
print("中间态-分水岭算法填充 图片已保存", save_dir)
|
||||
time_end=time.time() # 输出结束时间
|
||||
print('time cost',time_end-time_start,'s')
|
||||
|
||||
# 4. 对连通区域最终上色
|
||||
ori_labeled_image = image
|
||||
result_pro = color_connected_regions(filled_labeled_array, filtered_labeled_array, ori_labeled_image, self.Annotate_PALETTE)
|
||||
if imgname.lower().endswith(('.jpg', '.png')):
|
||||
save_dir = os.path.join(self.save_pro_label_fold, os.path.basename(imgname).rpartition('.')[0]+self.pro_append_name+'.'+self.pic_type)
|
||||
else:
|
||||
save_dir = os.path.join(self.save_pro_label_fold, os.path.basename(imgname)+self.pro_append_name+'.'+self.pic_type)
|
||||
print("Pro图片已保存", save_dir)
|
||||
cv2.imwrite(save_dir, result_pro)
|
||||
time_end=time.time() # 输出结束时间
|
||||
print('time cost',time_end-time_start,'s')
|
||||
|
||||
# 将原始src图片转化为处理好的pro图片
|
||||
def Conver_ori_label_pic_2_pro_pic_all(self):
|
||||
print("\033[33m**** 进行转换将Ori_label_pic转换为Pro_label_pic ****\033[0m")
|
||||
print("\033[33mOri_label_pic存储位置为:\033[0m", self.src_label_fold)
|
||||
print("\033[33mPro_label_pic生成位置为:\033[0m", self.save_pro_label_fold)
|
||||
# 输出颜色预处理图片
|
||||
try:
|
||||
os.mkdir(self.save_pro_label_fold) # 新建存储文件夹
|
||||
except:
|
||||
print("已有"+self.save_pro_label_fold)
|
||||
if(self.save_process_pics == True):
|
||||
try:
|
||||
os.mkdir(os.path.join(self.save_pro_label_fold, '1_边缘检测并膨胀')) # 新建存储1_边缘检测并膨胀文件夹
|
||||
except:
|
||||
print("已有"+os.path.join(self.save_pro_label_fold, '1_边缘检测并膨胀'))
|
||||
try:
|
||||
os.mkdir(os.path.join(self.save_pro_label_fold, '2_连通区域检测')) # 新建存储2_连通区域检测文件夹
|
||||
except:
|
||||
print("已有"+os.path.join(self.save_pro_label_fold, '2_连通区域检测'))
|
||||
try:
|
||||
os.mkdir(os.path.join(self.save_pro_label_fold, '3_分水岭算法填充')) # 新建存储1_边缘检测并膨胀文件夹
|
||||
except:
|
||||
print("已有"+os.path.join(self.save_pro_label_fold, '3_分水岭算法填充'))
|
||||
|
||||
# 指定最大进程数为 20,多参数函数并行
|
||||
max_processes = 20
|
||||
# 创建Pool对象
|
||||
pool = multiprocessing.Pool(processes=max_processes)
|
||||
# 创建并启动进程
|
||||
args_list1 = []
|
||||
args_list2 = []
|
||||
|
||||
# 遍历整个文件夹
|
||||
for imgpath in self.labellist_src:
|
||||
if imgpath.lower().endswith(('.jpg', '.png')):
|
||||
imgname= os.path.basename(imgpath).rpartition('.')[0].replace(self.pro_append_name,"")
|
||||
else:
|
||||
imgname= os.path.basename(imgpath).replace(self.pro_append_name,"")
|
||||
try:
|
||||
print("Processing: ", imgname, "...")
|
||||
# self.Conver_ori_label_pic_2_pro_pic(imgpath, imgname)s
|
||||
# args_list.append({'imgpath': imgpath, 'imgname': imgname})
|
||||
args_list1.append(imgpath)
|
||||
args_list2.append(imgname)
|
||||
except:
|
||||
os.system("echo "+imgname+" >> error_1.txt")
|
||||
args_list = zip(args_list1, args_list2)
|
||||
# 使用进程池并行执行任务
|
||||
pool.starmap(self.Conver_ori_label_pic_2_pro_pic, args_list) # 使用starmap进行多参数并行
|
||||
# 关闭进程池
|
||||
pool.close()
|
||||
pool.join()
|
||||
|
||||
# 图片堆叠
|
||||
def Merge_ori_pic_and_label_pic(self, res_img_path, res_imgname):
|
||||
time_start=time.time() # 记录开始时间
|
||||
# 获取单张图片各个通路信息
|
||||
ori_img_path = os.path.join(self.ori_img_folder, res_imgname+'.'+self.pic_type)
|
||||
if not os.path.exists(ori_img_path):
|
||||
print("****照片不存在:****", ori_img_path)
|
||||
return -1
|
||||
ori_image, ori_r, ori_g, ori_b = self.get_single_pic_rgb(ori_img_path)
|
||||
res_image, res_r, res_g, res_b = self.get_single_pic_rgb(res_img_path)
|
||||
|
||||
merge_img = np.array(ori_image) # merge图片初始化,默认图片背景为0.0.0
|
||||
|
||||
# 遍历所有待识别颜色
|
||||
for [Annotate_PALETTE_r, Annotate_PALETTE_g, Annotate_PALETTE_b] in self.Annotate_PALETTE:
|
||||
# 查找三原色匹配位置
|
||||
locate_r = np.where( res_r == Annotate_PALETTE_r )
|
||||
locate_g = np.where( res_g == Annotate_PALETTE_g )
|
||||
locate_b = np.where( res_b == Annotate_PALETTE_b )
|
||||
|
||||
# 查找都匹配位置(交集)
|
||||
# 将矩阵换一种表示形式
|
||||
locate_r = np.array(locate_r[0]) * self.Max_width + np.array(locate_r[1])
|
||||
locate_g = np.array(locate_g[0]) * self.Max_width + np.array(locate_g[1])
|
||||
locate_b = np.array(locate_b[0]) * self.Max_width + np.array(locate_b[1])
|
||||
|
||||
# 用自带函数寻找匹配项
|
||||
matched = np.intersect1d(np.intersect1d(locate_r, locate_g), locate_b)
|
||||
matched = np.concatenate(([matched // self.Max_width], [np.mod(matched, self.Max_width)]), 0)
|
||||
merge_img[matched[0],matched[1], 0] = Annotate_PALETTE_r
|
||||
merge_img[matched[0],matched[1], 1] = Annotate_PALETTE_g
|
||||
merge_img[matched[0],matched[1], 2] = Annotate_PALETTE_b
|
||||
|
||||
# 转成cv2形式
|
||||
merge_img = cv2.cvtColor(np.float32(merge_img), cv2.COLOR_RGB2BGR)
|
||||
|
||||
try: # 新建文件夹
|
||||
os.mkdir(self.save_merge_pic_folder)
|
||||
except:
|
||||
print("已有"+self.save_merge_pic_folder)
|
||||
if res_imgname.lower().endswith(('.jpg', '.png')):
|
||||
save_dir = os.path.join(self.save_merge_pic_folder, os.path.basename(res_imgname).rpartition('.')[0]+'.'+self.pic_type)
|
||||
else:
|
||||
save_dir = os.path.join(self.save_merge_pic_folder, os.path.basename(res_imgname)+'.'+self.pic_type)
|
||||
|
||||
|
||||
cv2.imwrite(save_dir, merge_img)
|
||||
print("Merge图片已保存", save_dir)
|
||||
time_end=time.time() # 输出结束时间
|
||||
print('time cost',time_end-time_start,'s')
|
||||
|
||||
# 将label图片与原图片重合
|
||||
def Merge_ori_pic_and_label_pic_all(self):
|
||||
# 遍历整个文件夹
|
||||
for res_img_path in self.reslist_src:
|
||||
if res_img_path.lower().endswith(('.jpg', '.png')):
|
||||
res_imgname = os.path.basename(res_img_path).rpartition('.')[0].replace(self.pro_append_name,"")
|
||||
else:
|
||||
res_imgname = os.path.basename(res_img_path).replace(self.pro_append_name,"")
|
||||
print("Processing: ", res_imgname, "...")
|
||||
self.Merge_ori_pic_and_label_pic(res_img_path, res_imgname)
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
# 创建参数解析器
|
||||
parser = argparse.ArgumentParser(description='Process some files.')
|
||||
# 添加参数选项
|
||||
parser.add_argument('-src_fold', dest='src_label_fold', default='./', help='source label folder')
|
||||
parser.add_argument('-save_pro_fold', dest='save_pro_label_fold', default='./ORI_pro_label_fold', help='processed label folder')
|
||||
parser.add_argument('-save_GT_fold', dest='save_GT_label_fold', default='./ORI_GT_label_fold', help='ground truth folder')
|
||||
parser.add_argument('-fold_search_depth', dest='Label_Max_Search_layer', default='1000', type=int, help='Folder Search Depth')
|
||||
parser.add_argument('-pro_suffix_name', dest='pro_append_name', default='_label', help='Pro file suffix')
|
||||
parser.add_argument('-GT_suffix_name', dest='GT_append_name', default='', help='GT file suffix')
|
||||
parser.add_argument('-GT_channel', dest='GT_channel', default='1', type=int, help='GT file channel(1 or 3)')
|
||||
parser.add_argument('-back_gnd_color', dest='back_gnd_color', default='0', type=int, help='Color of "Back ground"(0 or 255)')
|
||||
parser.add_argument('-first_class_color', dest='first_class_color', default='1', type=int, help='Color of "First Class"')
|
||||
parser.add_argument('-pic_type', dest='pic_type', default='png', help='type of picture(Do not add ".")')
|
||||
parser.add_argument('-Max_width', dest='Max_width', default='10000', type=int, help='Max width of picture')
|
||||
parser.add_argument('-Rebuild_from', dest='Rebuild_from', default='label', help='Source to Rebuild Labels(label/pro)')
|
||||
parser.add_argument('-Rebuild_to', dest='Rebuild_to', default='GT', help='Destination of Rebuild Labels(pro/GT)')
|
||||
parser.add_argument('-save_process_pics', dest='save_process_pics', default='False', help='Save the processed pics(e.g.Gray_pics,Color_pics) in generating pro_pics')
|
||||
|
||||
# 解析命令行参数
|
||||
args = parser.parse_args()
|
||||
|
||||
src_label_fold = args.src_label_fold
|
||||
save_pro_label_fold = args.save_pro_label_fold
|
||||
save_GT_label_fold = args.save_GT_label_fold
|
||||
Label_Max_Search_layer = args.Label_Max_Search_layer
|
||||
pro_append_name = args.pro_append_name
|
||||
GT_append_name = args.GT_append_name
|
||||
GT_channel = args.GT_channel
|
||||
back_gnd_color = args.back_gnd_color
|
||||
first_class_color = args.first_class_color
|
||||
pic_type = args.pic_type
|
||||
Max_width = args.Max_width
|
||||
Rebuild_from = args.Rebuild_from
|
||||
Rebuild_to = args.Rebuild_to
|
||||
save_process_pics = args.save_process_pics
|
||||
|
||||
|
||||
try: # 遍历文件深度,最小为1
|
||||
Label_Max_Search_layer=int(Label_Max_Search_layer)
|
||||
except:
|
||||
Label_Max_Search_layer=1000
|
||||
try: # GT标签图片通道数
|
||||
GT_channel=int(GT_channel)
|
||||
except:
|
||||
GT_channel=1
|
||||
try: # 背景颜色(背景选择0或255)
|
||||
back_gnd_color=int(back_gnd_color)
|
||||
except:
|
||||
back_gnd_color=0
|
||||
try: # 第一类上的颜色(如果背景为0,选择1;)
|
||||
first_class_color=int(first_class_color)
|
||||
except:
|
||||
first_class_color=1
|
||||
try: # 最大图片宽度(匹配时候用)
|
||||
Max_width=int(Max_width)
|
||||
except:
|
||||
Max_width=10000
|
||||
if(save_process_pics.lower() == 'false'):
|
||||
save_process_pics = False
|
||||
elif(save_process_pics.lower() == 'true'):
|
||||
save_process_pics = True
|
||||
else:
|
||||
save_process_pics = False
|
||||
|
||||
D = Deal_image(Annotate_CLASSES=Annotate_CLASSES, Annotate_PALETTE=Annotate_PALETTE, src_label_fold=src_label_fold, save_pro_label_fold=save_pro_label_fold, save_GT_label_fold=save_GT_label_fold, GT_channel=GT_channel, pro_append_name=pro_append_name, GT_append_name=GT_append_name, back_gnd_color=back_gnd_color, first_class_color=first_class_color, pic_type=pic_type, Max_width=Max_width, Label_Max_Search_layer=Label_Max_Search_layer, save_process_pics=save_process_pics, bg_PALETTE = bg_PALETTE)
|
||||
# print(D.src_CLASSES_NUM)
|
||||
if Rebuild_from == 'label':
|
||||
# 1.先将所有原始图片转为pro图片
|
||||
D.Conver_ori_label_pic_2_pro_pic_all()
|
||||
pass
|
||||
if Rebuild_to == 'GT':
|
||||
# 2.再将pro图片转为GT图片
|
||||
D.Conver_pro_label_pic_2_GT_pic_all()
|
||||
pass
|
||||
@@ -0,0 +1,169 @@
|
||||
import os
|
||||
import cv2
|
||||
import numpy as np
|
||||
from collections import Counter, defaultdict
|
||||
from PIL import Image
|
||||
from Tool_Classes_And_Palette import Annotate_CLASSES, Annotate_PALETTE, bg_PALETTE
|
||||
|
||||
# ※ 需要修改输入输出路径 ※ #
|
||||
input_dir = './ORI_GT_label_fold'
|
||||
output_dir_1 = 'Data/labels/train'
|
||||
output_dir_2 = 'Data/labels/val'
|
||||
|
||||
# ----------------------------------------------------
|
||||
# ※※※ 1. 修改: 移除背景项并统一格式 ※※※
|
||||
# ----------------------------------------------------
|
||||
# (1) 从 Annotate_CLASSES (元组) 中移除 '背景'
|
||||
if Annotate_CLASSES and Annotate_CLASSES[-1] == '背景':
|
||||
Annotate_CLASSES = Annotate_CLASSES[:-1]
|
||||
# (2) 从 Annotate_PALETTE (列表) 中移除背景颜色 [0,0,0]
|
||||
if Annotate_PALETTE and (Annotate_PALETTE[-1] == [0, 0, 0] or Annotate_PALETTE[-1] == (0, 0, 0)):
|
||||
Annotate_PALETTE.pop()
|
||||
# (3) 确保调色板中的所有颜色都是元组 (Tuple),以便后续查找
|
||||
Annotate_PALETTE = [tuple(color) for color in Annotate_PALETTE]
|
||||
# (4) 确保 bg_PALETTE 是元组,以匹配 Counter 中的键类型
|
||||
bg_PALETTE = tuple(bg_PALETTE)
|
||||
# (5) 检查类别和调色板长度是否一致
|
||||
if len(Annotate_CLASSES) != len(Annotate_PALETTE):
|
||||
print(f"[警告] 移除背景后,类别 ({len(Annotate_CLASSES)}) 和调色板 ({len(Annotate_PALETTE)}) 长度不一致!")
|
||||
else:
|
||||
print(f"--- 成功移除背景项,剩余 {len(Annotate_CLASSES)} 个有效类别。 ---")
|
||||
|
||||
os.makedirs(output_dir_1, exist_ok=True)
|
||||
os.makedirs(output_dir_2, exist_ok=True)
|
||||
|
||||
# 全局统计颜色频率与 class 像素频率
|
||||
global_color_counter = Counter()
|
||||
global_class_counter = Counter() # 将用于统计(ori_label 模式)
|
||||
color_class_counter = defaultdict(int) # (R,G,B) → count
|
||||
color_to_old_class = {} # (R,G,B) → class_id
|
||||
# *** ori_label 模式: 移除了 remap_class_dict ***
|
||||
|
||||
# 自动提取颜色映射(跳过背景)
|
||||
def extract_color_mapping(img_path):
|
||||
img = Image.open(img_path).convert('RGB')
|
||||
pixels = list(img.getdata())
|
||||
counter = Counter(pixels)
|
||||
color_map = {}
|
||||
for color, count in counter.items():
|
||||
global_color_counter[color] += count
|
||||
if color != bg_PALETTE and color[0] == color[1] == color[2]: # 使用元组 bg_PALETTE
|
||||
class_id = color[0] - 1
|
||||
if class_id >= 0:
|
||||
color_map[color] = class_id
|
||||
global_class_counter[class_id] += count # 使用 global_class_counter 统计
|
||||
color_class_counter[color] += count
|
||||
color_to_old_class[color] = class_id
|
||||
return color_map
|
||||
|
||||
# 处理单张图片
|
||||
def process_image(img_path, save_path_list):
|
||||
color_to_class = extract_color_mapping(img_path)
|
||||
if not color_to_class:
|
||||
print(f"[跳过] {os.path.basename(img_path)} 无有效目标")
|
||||
return
|
||||
|
||||
img = cv2.imread(img_path)
|
||||
h, w = img.shape[:2]
|
||||
img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
||||
|
||||
lines = []
|
||||
for rgb, old_class_id in color_to_class.items():
|
||||
mask = np.all(img_rgb == rgb, axis=-1).astype(np.uint8) * 255
|
||||
contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
||||
for contour in contours:
|
||||
if contour.shape[0] < 3:
|
||||
continue
|
||||
norm_pts = contour.squeeze(1).astype(np.float32)
|
||||
norm_pts[:, 0] /= w
|
||||
norm_pts[:, 1] /= h
|
||||
flat = norm_pts.flatten()
|
||||
|
||||
# *** ori_label 模式: 直接使用 old_class_id ***
|
||||
line = f"{old_class_id} " + " ".join([f"{x:.6f}" for x in flat])
|
||||
lines.append(line)
|
||||
|
||||
if lines:
|
||||
for save_path in save_path_list:
|
||||
with open(save_path, 'w') as f:
|
||||
for line in lines:
|
||||
f.write(line + "\n")
|
||||
print(f"[✔] 转换成功: {os.path.basename(save_path)},共 {len(lines)} 个实例")
|
||||
else:
|
||||
print(f"[⚠] {os.path.basename(img_path)} 没有轮廓")
|
||||
|
||||
# 第一次遍历图像,仅提取颜色信息用于构建 class 映射
|
||||
for fname in os.listdir(input_dir):
|
||||
if fname.lower().endswith('.png'):
|
||||
img_path = os.path.join(input_dir, fname)
|
||||
extract_color_mapping(img_path)
|
||||
|
||||
# *** ori_label 模式: 移除了重映射 (remap) 逻辑块 ***
|
||||
# (保留 class_pixel_count 用于统计)
|
||||
class_pixel_count = {color_to_old_class[c]: count for c, count in color_class_counter.items()}
|
||||
|
||||
# 第二次遍历图像,正式处理并生成标签
|
||||
for fname in os.listdir(input_dir):
|
||||
if fname.lower().endswith('.png'):
|
||||
img_path = os.path.join(input_dir, fname)
|
||||
base_name = os.path.splitext(fname)[0]
|
||||
txt_path_1 = os.path.join(output_dir_1, base_name + '.txt')
|
||||
txt_path_2 = os.path.join(output_dir_2, base_name + '.txt')
|
||||
process_image(img_path, [txt_path_1, txt_path_2])
|
||||
|
||||
# ----------------------------------------------------
|
||||
# ※※※ 2. 修改: 打印详细的颜色统计 ※※※
|
||||
# ----------------------------------------------------
|
||||
print("\n📊 所有图像颜色统计:")
|
||||
|
||||
def get_label_info(class_id):
|
||||
"""辅助函数:根据 class_id 获取标签名和颜色"""
|
||||
if 0 <= class_id < len(Annotate_CLASSES) and 0 <= class_id < len(Annotate_PALETTE):
|
||||
return Annotate_CLASSES[class_id], Annotate_PALETTE[class_id]
|
||||
else:
|
||||
return "未知标签", (255, 255, 255) # 返回一个默认值
|
||||
|
||||
for color, count in global_color_counter.most_common():
|
||||
if color == bg_PALETTE: # 使用 bg_PALETTE 变量
|
||||
print(f"背景颜色 {color} 出现次数: {count}")
|
||||
elif color[0] == color[1] == color[2]:
|
||||
class_id = color[0] - 1
|
||||
label_name, label_color = get_label_info(class_id)
|
||||
print(f"颜色 {color} → class {class_id} (标签: '{label_name}', 颜色: {label_color}),出现次数: {count}")
|
||||
else:
|
||||
# 此情况理论上不应出现,因为 extract_color_mapping 已过滤
|
||||
print(f"[⚠] 非灰阶颜色 {color},出现次数: {count} (此颜色不应被处理)")
|
||||
|
||||
# ----------------------------------------------------
|
||||
# ※※※ 3. 修改: 打印详细的类别统计 (ori_label 模式) ※※※
|
||||
# ----------------------------------------------------
|
||||
print("\n✅ 有效类别统计(按 原 class_id 排序 → 总像素数):")
|
||||
# 按照 old_id (原类别索引) 排序输出
|
||||
sorted_classes = sorted(class_pixel_count.items(), key=lambda item: item[0])
|
||||
|
||||
for old_id, pixel_count in sorted_classes:
|
||||
label_name, label_color = get_label_info(old_id)
|
||||
print(f"class {old_id} (标签: '{label_name}', 颜色: {label_color}): {pixel_count} pixels")
|
||||
|
||||
# ----------------------------------------------------
|
||||
# ※※※ 4. 修改: 额外输出 原 class_id 到颜色的映射 (ori_label 模式) ※※※
|
||||
# ----------------------------------------------------
|
||||
print(f"\n🎨 找到的 原 class_id 与标签颜色 (Annotate_PALETTE) 映射表:")
|
||||
|
||||
# 1. 按照 old_id (0, 1, 2...) 排序并按指定格式打印
|
||||
# 我们使用 class_pixel_count.keys() 来获取所有实际找到的 old_id
|
||||
sorted_old_ids = sorted(class_pixel_count.keys())
|
||||
|
||||
for old_id in sorted_old_ids:
|
||||
if 0 <= old_id < len(Annotate_PALETTE):
|
||||
color_tuple = Annotate_PALETTE[old_id]
|
||||
color_list = list(color_tuple)
|
||||
print(f" {old_id}: {color_list}")
|
||||
else:
|
||||
# 预防性代码
|
||||
print(f" {old_id}: [颜色未在调色板中定义]")
|
||||
# ----------------------------------------------------
|
||||
# ※※※ 修改结束 ※※※
|
||||
# ----------------------------------------------------
|
||||
|
||||
print(f"\n✅ 全部图像处理完毕。标签输出目录:{output_dir_1}、{output_dir_2}")
|
||||
@@ -0,0 +1,204 @@
|
||||
import os
|
||||
import cv2
|
||||
import numpy as np
|
||||
from collections import Counter, defaultdict
|
||||
from PIL import Image
|
||||
from Tool_Classes_And_Palette import Annotate_CLASSES, Annotate_PALETTE, bg_PALETTE
|
||||
|
||||
# ※ 需要修改输入输出路径 ※ #
|
||||
input_dir = 'ORI_GT_label_fold'
|
||||
output_dir_1 = 'Data/labels/train'
|
||||
output_dir_2 = 'Data/labels/val'
|
||||
output_GT_dir_1 = 'Data/labels_GT/train'
|
||||
output_GT_dir_2 = 'Data/labels_GT/val'
|
||||
|
||||
# 1. --- 您提供的类别和调色板 ---
|
||||
# (1) 从 Annotate_CLASSES (元组) 中移除 '背景'
|
||||
if Annotate_CLASSES and Annotate_CLASSES[-1] == '背景':
|
||||
Annotate_CLASSES = Annotate_CLASSES[:-1]
|
||||
# (2) 从 Annotate_PALETTE (列表) 中移除背景颜色 [0,0,0]
|
||||
if Annotate_PALETTE and (Annotate_PALETTE[-1] == [0, 0, 0] or Annotate_PALETTE[-1] == (0, 0, 0)):
|
||||
Annotate_PALETTE.pop()
|
||||
# (3) 确保调色板中的所有颜色都是元组 (Tuple),以便后续查找
|
||||
Annotate_PALETTE = [tuple(color) for color in Annotate_PALETTE]
|
||||
# (4) 确保 bg_PALETTE 是元组,以匹配 Counter 中的键类型
|
||||
bg_PALETTE = tuple(bg_PALETTE)
|
||||
# (5) 检查类别和调色板长度是否一致
|
||||
if len(Annotate_CLASSES) != len(Annotate_PALETTE):
|
||||
print(f"[警告] 移除背景后,类别 ({len(Annotate_CLASSES)}) 和调色板 ({len(Annotate_PALETTE)}) 长度不一致!")
|
||||
else:
|
||||
print(f"--- 成功移除背景项,剩余 {len(Annotate_CLASSES)} 个有效类别。 ---")
|
||||
|
||||
os.makedirs(output_dir_1, exist_ok=True)
|
||||
os.makedirs(output_dir_2, exist_ok=True)
|
||||
os.makedirs(output_GT_dir_1, exist_ok=True)
|
||||
os.makedirs(output_GT_dir_2, exist_ok=True)
|
||||
|
||||
# 全局统计颜色频率与 class 像素频率
|
||||
global_color_counter = Counter()
|
||||
global_class_counter = Counter()
|
||||
color_class_counter = defaultdict(int) # (R,G,B) → count
|
||||
color_to_old_class = {} # (R,G,B) → class_id
|
||||
remap_class_dict = {} # old_class_id → new_class_id
|
||||
|
||||
# 自动提取颜色映射(跳过背景)
|
||||
def extract_color_mapping(img_path):
|
||||
img = Image.open(img_path).convert('RGB')
|
||||
pixels = list(img.getdata())
|
||||
counter = Counter(pixels)
|
||||
color_map = {}
|
||||
for color, count in counter.items():
|
||||
global_color_counter[color] += count
|
||||
if color != (0, 0, 0) and color[0] == color[1] == color[2]:
|
||||
class_id = color[0] - 1
|
||||
if class_id >= 0:
|
||||
color_map[color] = class_id
|
||||
global_class_counter[class_id] += count
|
||||
color_class_counter[color] += count
|
||||
color_to_old_class[color] = class_id
|
||||
return color_map
|
||||
|
||||
# --- 修改:函数签名增加了 gt_save_paths ---
|
||||
def process_image(img_path, txt_save_paths, gt_save_paths):
|
||||
color_to_class = extract_color_mapping(img_path)
|
||||
if not color_to_class:
|
||||
print(f"[跳过] {os.path.basename(img_path)} 无有效目标")
|
||||
return
|
||||
|
||||
img = cv2.imread(img_path)
|
||||
h, w = img.shape[:2]
|
||||
img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
||||
|
||||
# --- 新增:创建新的 GT 掩码 (灰度图),0 为背景 ---
|
||||
new_gt_mask = np.zeros((h, w), dtype=np.uint8)
|
||||
lines = []
|
||||
|
||||
for rgb, old_class_id in color_to_class.items():
|
||||
# --- 修改:获取 new_class_id 并填充 new_gt_mask ---
|
||||
new_class_id = remap_class_dict.get(old_class_id, old_class_id)
|
||||
|
||||
# 创建布尔掩码
|
||||
mask_bool = np.all(img_rgb == rgb, axis=-1)
|
||||
|
||||
# --- 新增:在 new_gt_mask 上填充新的 class_id
|
||||
# (使用 new_class_id + 1,因为 0 是背景)
|
||||
new_gt_mask[mask_bool] = new_class_id + 1
|
||||
|
||||
# --- 修改:基于布尔掩码创建 255 掩码用于 findContours ---
|
||||
mask_255 = mask_bool.astype(np.uint8) * 255
|
||||
|
||||
contours, _ = cv2.findContours(mask_255, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
||||
|
||||
for contour in contours:
|
||||
if contour.shape[0] < 3:
|
||||
continue
|
||||
norm_pts = contour.squeeze(1).astype(np.float32)
|
||||
norm_pts[:, 0] /= w
|
||||
norm_pts[:, 1] /= h
|
||||
flat = norm_pts.flatten()
|
||||
|
||||
# --- 修改:此处 new_class_id 已在循环外层定义 ---
|
||||
line = f"{new_class_id} " + " ".join([f"{x:.6f}" for x in flat])
|
||||
lines.append(line)
|
||||
|
||||
# --- 新增:保存处理后的 GT 掩码图像 ---
|
||||
# 只要存在有效类别 (color_to_class 非空),就保存 GT 掩码
|
||||
for save_path in gt_save_paths:
|
||||
try:
|
||||
# 使用 PIL 保存灰度 PNG 图像
|
||||
Image.fromarray(new_gt_mask).save(save_path)
|
||||
print(f"[✔] GT 掩码保存成功: {os.path.basename(save_path)}")
|
||||
except Exception as e:
|
||||
print(f"[✘] GT 掩码保存失败: {os.path.basename(save_path)} - {e}")
|
||||
|
||||
# 保存 .txt 标签文件 (仅当找到轮廓时)
|
||||
if lines:
|
||||
for save_path in txt_save_paths:
|
||||
with open(save_path, 'w') as f:
|
||||
for line in lines:
|
||||
f.write(line + "\n")
|
||||
print(f"[✔] 转换成功: {os.path.basename(save_path)},共 {len(lines)} 个实例")
|
||||
else:
|
||||
# 即使没有轮廓,GT 掩码也已保存
|
||||
print(f"[⚠] {os.path.basename(img_path)} 没有轮廓 (但 GT 掩码已保存)")
|
||||
|
||||
# 第一次遍历图像,仅提取颜色信息用于构建 class 映射
|
||||
for fname in os.listdir(input_dir):
|
||||
if fname.lower().endswith('.png'):
|
||||
img_path = os.path.join(input_dir, fname)
|
||||
extract_color_mapping(img_path)
|
||||
|
||||
# 构建 class_id 重映射表(按像素数从大到小排序)
|
||||
class_pixel_count = {color_to_old_class[c]: count for c, count in color_class_counter.items()}
|
||||
sorted_classes = sorted(class_pixel_count.items(), key=lambda x: x[1], reverse=True)
|
||||
remap_class_dict = {old_cls: new_idx for new_idx, (old_cls, _) in enumerate(sorted_classes)}
|
||||
|
||||
# 第二次遍历图像,正式处理并生成标签
|
||||
for fname in os.listdir(input_dir):
|
||||
if fname.lower().endswith('.png'):
|
||||
img_path = os.path.join(input_dir, fname)
|
||||
base_name = os.path.splitext(fname)[0]
|
||||
|
||||
txt_path_1 = os.path.join(output_dir_1, base_name + '.txt')
|
||||
txt_path_2 = os.path.join(output_dir_2, base_name + '.txt')
|
||||
|
||||
# --- 新增:定义 GT 掩码输出路径 (保存为 .png) ---
|
||||
gt_path_1 = os.path.join(output_GT_dir_1, base_name + '.png')
|
||||
gt_path_2 = os.path.join(output_GT_dir_2, base_name + '.png')
|
||||
|
||||
process_image(img_path, [txt_path_1, txt_path_2], [gt_path_1, gt_path_2])
|
||||
|
||||
# 打印颜色统计和类别映射
|
||||
print("\n📊 所有图像颜色统计:")
|
||||
|
||||
def get_label_info(class_id):
|
||||
"""辅助函数:根据 class_id 获取标签名和颜色"""
|
||||
if 0 <= class_id < len(Annotate_CLASSES) and 0 <= class_id < len(Annotate_PALETTE):
|
||||
return Annotate_CLASSES[class_id], Annotate_PALETTE[class_id]
|
||||
else:
|
||||
return "未知标签", (255, 255, 255) # 返回一个默认值
|
||||
|
||||
for color, count in global_color_counter.most_common():
|
||||
if color == bg_PALETTE: # 使用 bg_PALETTE 变量
|
||||
print(f"背景颜色 {color} 出现次数: {count}")
|
||||
elif color[0] == color[1] == color[2]:
|
||||
class_id = color[0] - 1
|
||||
label_name, label_color = get_label_info(class_id)
|
||||
print(f"颜色 {color} → class {class_id} (标签: '{label_name}', 颜色: {label_color}),出现次数: {count}")
|
||||
else:
|
||||
# 此情况理论上不应出现,因为 extract_color_mapping 已过滤
|
||||
print(f"[⚠] 非灰阶颜色 {color},出现次数: {count} (此颜色不应被处理)")
|
||||
|
||||
# 打印详细的类别统计
|
||||
print("\n✅ 有效类别统计(原 class_id → 新 class_id → 总像素数):")
|
||||
# 按照 new_id (新类别索引) 排序输出
|
||||
sorted_remap = sorted(remap_class_dict.items(), key=lambda item: item[1])
|
||||
for old_id, new_id in sorted_remap:
|
||||
label_name, label_color = get_label_info(old_id)
|
||||
print(f"class {old_id} (标签: '{label_name}', 颜色: {label_color}) → {new_id}: {class_pixel_count[old_id]} pixels")
|
||||
|
||||
# 4. 额外输出新 class_id 到颜色的映射 ※※※
|
||||
print(f"\n🎨 新 class_id 与标签颜色 (Annotate_PALETTE) 映射表:")
|
||||
# 1. 创建一个 new_id -> color 的映射
|
||||
new_id_to_color = {}
|
||||
for old_id, new_id in remap_class_dict.items():
|
||||
if 0 <= old_id < len(Annotate_PALETTE):
|
||||
# 从 Annotate_PALETTE 获取原始颜色(它是一个元组)
|
||||
color_tuple = Annotate_PALETTE[old_id]
|
||||
# 转换为列表 [R, G, B] 以匹配您要的格式
|
||||
new_id_to_color[new_id] = list(color_tuple)
|
||||
else:
|
||||
# 预防性代码,以防 old_id 超出范围
|
||||
new_id_to_color[new_id] = [-1, -1, -1] # 表示错误/未找到
|
||||
|
||||
# 2. 按照 new_id (0, 1, 2...) 排序并按指定格式打印
|
||||
sorted_new_ids = sorted(new_id_to_color.keys())
|
||||
for new_id in sorted_new_ids:
|
||||
color_list = new_id_to_color[new_id]
|
||||
# 格式化输出: {id}: {color_list}
|
||||
# (注意:颜色列表的格式会自然包含逗号和空格)
|
||||
print(f" {new_id}: {color_list}")
|
||||
|
||||
print(f"\n✅ 全部图像处理完毕。")
|
||||
print(f" 标签 (.txt) 输出目录: {output_dir_1}, {output_dir_2}")
|
||||
print(f" GT 掩码 (.png) 输出目录: {output_GT_dir_1}, {output_GT_dir_2}")
|
||||
@@ -0,0 +1,20 @@
|
||||
# # 胆囊标注
|
||||
# Annotate_CLASSES = ('肝脏','胆囊','分离钳','止血海绵','肝总管','胆总管','吸引器','剪刀','止血纱布','生物夹','无损伤钳','喷洒','胆囊管','胆囊动脉','电凝','标本袋','引流管','纱布','金属钛夹','术中超声','吻合器','乳胶管','推结器','肝带','钳夹','超声刀','脂肪','双极电凝','棉球','血管阻断夹','肿瘤','针','线','韧带','胆囊静脉','背景') # 待分类的类
|
||||
# Annotate_PALETTE = [[255,91,0],[255,234,0],[85, 111, 181],[181, 227, 14],[72, 0, 255],[0, 155, 33],[255,0,255],[29, 32, 136],[160, 15, 95],[0,160,233],[52,184,178],[90,120,41],[255,0,0],[117,0,0],[167,24,233],[112,113,150],[0,255,0],[255,255,255],[0,255,255],[138,251,213],[136,162,196],[197,83,181],[202,202,200],[113,102,140],[66,115,82],[240,16,116],[155,132,0],[155,62,0],[146,175,236],[255,172,159],[245,161,0],[134,124,118], [0,157,142], [181,85,105], [42,8,66],[0,0,0]] # 每一类的像素直
|
||||
# # 甲状腺标注
|
||||
# Annotate_CLASSES = ('甲状腺', '针', '双极电凝', '止血海绵', '止血纱布', '喷洒', '标本袋', '肝蒂', '神经', '线', '肝脏', '肝总管', '胆总管', '胆囊管', '引流管', '推结器', '肌肉', '肿瘤', '胆囊', '吸引器', '生物夹', '动脉', '纱布', '乳胶管', '血管阻断夹', '分离钳', '肌肉', '无损伤钳', '电凝', '金属钛夹', '吻合器', '棉球', '脂肪', '超声刀', '钳夹', '静脉', '韧带', '术中超声','背景') # 待分类的类
|
||||
# Annotate_PALETTE = [[255, 148, 81], [134, 124, 118], [155, 62, 0], [181, 227, 14], [160, 15, 95], [90, 120, 41], [112, 113, 150], [133, 102, 140], [168, 162, 252], [0, 157, 142], [255, 91, 0], [72, 0, 255], [0, 155, 33], [255, 0, 0], [0, 255, 0], [202, 202, 200], [254, 141, 179], [254, 161, 0], [255, 234, 0], [255, 0, 255], [0, 160, 233], [117, 0, 0], [255, 255, 255], [197, 83, 181],[255, 172, 159], [85, 111, 181], [29, 32, 136], [52, 184, 178], [166, 24, 232], [0, 254, 254], [136, 162, 196], [146, 175, 236], [155, 132, 0], [240, 16, 116], [66, 115, 82], [42, 8, 66], [181, 85, 105], [138, 251, 213],[0,0,0]] # 每一类的像素直
|
||||
# # 胃癌标注+去雾影像标注
|
||||
Annotate_CLASSES = ('胃', '针', '双极电凝', '止血海绵', '止血纱布', '喷洒', '标本袋', '肝蒂', '小肠', '线', '肝脏', '脾脏', '胆总管', '胆囊管', '引流管', '推结器', '淋巴结', '胰腺', '胆囊', '吸引器', '生物夹', '动脉', '纱布', '乳胶管', '分离钳', '超声刀', '无损伤钳', '电凝', '金属钛夹', '吻合器', '脂肪', '剪刀', '钳夹', '静脉', '韧带','背景') # 待分类的类
|
||||
Annotate_PALETTE = [(237, 35, 85), (134, 124, 118), (155, 62, 0), (187, 227, 14), (160, 15, 95), (90, 120, 41), (112, 113, 150), (133, 102, 140), (110, 255, 166), (0, 157, 142), (255, 91, 0), (72, 0, 255), (0, 155, 33), (255, 0, 0), (0, 255, 0), (202, 202, 200), (201, 255, 74), (245, 161, 0), (255, 234, 0), (255, 0, 255), (0, 160, 233), (117, 0, 0), (255, 255, 255), (197, 83, 181), (85, 111, 181), (29, 32, 136), (52, 184, 178), (167, 24, 233), (0, 255, 255), (136, 162, 196), (155, 132, 0), (240, 16, 116), (66, 115, 82), (42, 8, 66), (181, 85, 105), [0,0,0]] # 每一类的像素直
|
||||
# # 甲状腺标注
|
||||
# Annotate_CLASSES = ('甲状旁腺', '喉返神经', '电凝', '无损伤钳', '超声刀', '分离钳', '纱布', '背景') # 待分类的类
|
||||
# Annotate_PALETTE = [(238, 25, 30), (24, 124, 248), (198, 24, 248), (24, 248, 240), (248, 119, 24), (24, 248, 114), (255,255,255), [0,0,0]] # 每一类的像素值
|
||||
# # 磁器械标注
|
||||
# Annotate_CLASSES = ( '双极电凝', '止血海绵', '止血纱布', '喷洒', '标本袋', '肝脏', '肝总管', '胆总管', '胆囊管', '引流管', '胆囊', '磁器械', '生物夹', '胆囊动脉', '纱布', '分离钳', '剪刀', '无损伤钳','电凝', '金属钛夹', '脂肪', '背景') # 待分类的类
|
||||
# Annotate_PALETTE = [ (155, 62, 0), (181, 227, 14), (160, 15, 95), (90, 120, 41), (112, 113, 150), (255, 91, 0), (72, 0, 255), (0, 155, 33), (255, 0, 0), (0, 255, 0), (255, 234, 0), (255, 0, 255), (0, 160, 233), (117, 0, 0), (255, 255, 255), (85, 111, 181), (29, 32, 136), (52, 184, 178), (167, 24, 233), (0, 255, 255), (155, 132, 0), [0,0,0]] # 每一类的像素值
|
||||
# # 二分类标注
|
||||
# Annotate_CLASSES = ( '特殊部位', '背景') # 待分类的类
|
||||
# Annotate_PALETTE = [ [255,255,255], [0,0,0]] # 每一类的像素值
|
||||
|
||||
bg_PALETTE = [0,0,0] # 背景的RGB
|
||||
@@ -0,0 +1,167 @@
|
||||
import os
|
||||
import argparse
|
||||
from PIL import Image
|
||||
import concurrent.futures
|
||||
|
||||
# 确保使用的是 os.cpu_count() 来自动检测核心数
|
||||
try:
|
||||
DEFAULT_WORKERS = os.cpu_count() or 1
|
||||
except AttributeError:
|
||||
try:
|
||||
import multiprocessing
|
||||
DEFAULT_WORKERS = multiprocessing.cpu_count()
|
||||
except (ImportError, NotImplementedError):
|
||||
DEFAULT_WORKERS = 4
|
||||
|
||||
def _process_single_file(source_path, png_path, delete_source):
|
||||
"""
|
||||
工作函数:处理单个文件的转换、缩放,并保存为PNG。
|
||||
(此函数与上一版完全相同)
|
||||
"""
|
||||
filename = os.path.basename(source_path)
|
||||
png_filename = os.path.basename(png_path)
|
||||
|
||||
try:
|
||||
with Image.open(source_path) as img:
|
||||
width, height = img.size
|
||||
min_side = min(width, height)
|
||||
resize_info = ""
|
||||
|
||||
if min_side > 1080:
|
||||
scale_factor = 1080 / min_side
|
||||
new_width = int(width * scale_factor)
|
||||
new_height = int(height * scale_factor)
|
||||
img = img.resize((new_width, new_height), Image.Resampling.LANCZOS)
|
||||
resize_info = f" (已缩放至 {new_width}x{new_height})"
|
||||
|
||||
img.save(png_path, 'PNG')
|
||||
|
||||
if delete_source:
|
||||
os.remove(source_path)
|
||||
return "success_del", filename, f"{png_filename}{resize_info}"
|
||||
else:
|
||||
return "success", filename, f"{png_filename}{resize_info}"
|
||||
|
||||
except Exception as e:
|
||||
return "fail", filename, str(e)
|
||||
|
||||
def process_images(folder_path, delete_source=False, max_workers=DEFAULT_WORKERS):
|
||||
"""
|
||||
批量将指定文件夹 (及其所有子文件夹) 内的 .bmp, .jpg, .jpeg 图像
|
||||
转换为 .png, 并按比例缩放 (最小边<=1080px)。
|
||||
"""
|
||||
|
||||
if not os.path.isdir(folder_path):
|
||||
print(f"错误:文件夹 '{folder_path}' 不存在或不是一个有效的目录。")
|
||||
return
|
||||
|
||||
print(f"--- 开始递归处理文件夹: {folder_path} ---")
|
||||
print(f"--- 使用最多 {max_workers} 个进程并行处理 ---")
|
||||
if delete_source:
|
||||
print("警告:已启用源文件删除模式。")
|
||||
|
||||
# --- 1. 收集所有需要处理的任务 ---
|
||||
tasks = []
|
||||
supported_extensions = ('.bmp', '.jpg', '.jpeg')
|
||||
|
||||
# <--- 更改:从 os.listdir() 切换到 os.walk() 以支持递归
|
||||
print("--- 正在扫描所有子文件夹...")
|
||||
for dirpath, dirnames, filenames in os.walk(folder_path):
|
||||
for filename in filenames:
|
||||
if filename.lower().endswith(supported_extensions):
|
||||
|
||||
# 构建完整的文件路径
|
||||
source_path = os.path.join(dirpath, filename)
|
||||
|
||||
# 构建PNG输出路径 (保持在同一个子文件夹内)
|
||||
base_name = os.path.splitext(filename)[0]
|
||||
png_path = os.path.join(dirpath, base_name + '.png')
|
||||
|
||||
# 避免重复转换
|
||||
if os.path.exists(png_path):
|
||||
# 打印相对路径,使其更清晰
|
||||
relative_path = os.path.relpath(png_path, folder_path)
|
||||
print(f" [跳过] {relative_path} (目标文件已存在)")
|
||||
continue
|
||||
|
||||
tasks.append((source_path, png_path, delete_source))
|
||||
# <--- 更改结束 ---
|
||||
|
||||
if not tasks:
|
||||
print(f"未在文件夹中找到新的 {supported_extensions} 文件。")
|
||||
return
|
||||
|
||||
print(f"--- 发现 {len(tasks)} 个新图像文件,开始转换... ---")
|
||||
|
||||
converted_count = 0
|
||||
failed_count = 0
|
||||
|
||||
# --- 2. 使用 ProcessPoolExecutor 执行任务 (此部分不变) ---
|
||||
with concurrent.futures.ProcessPoolExecutor(max_workers=max_workers) as executor:
|
||||
future_to_source = {
|
||||
executor.submit(_process_single_file, source, png, delete): source
|
||||
for source, png, delete in tasks
|
||||
}
|
||||
|
||||
# --- 3. 实时获取已完成的结果 (此部分不变) ---
|
||||
for future in concurrent.futures.as_completed(future_to_source):
|
||||
source_path_orig = future_to_source[future]
|
||||
|
||||
try:
|
||||
status, name, result = future.result()
|
||||
|
||||
# 获取相对路径以便于阅读
|
||||
relative_dir = os.path.relpath(os.path.dirname(source_path_orig), folder_path)
|
||||
# 如果是根目录,relative_dir 会是 ".",我们将其替换为空字符串
|
||||
if relative_dir == ".":
|
||||
log_name = name
|
||||
else:
|
||||
log_name = os.path.join(relative_dir, name)
|
||||
|
||||
if status == "success":
|
||||
print(f" [成功] {log_name} -> {result}")
|
||||
converted_count += 1
|
||||
elif status == "success_del":
|
||||
print(f" [成功] {log_name} -> {result} (并已删除源文件)")
|
||||
converted_count += 1
|
||||
elif status == "fail":
|
||||
print(f" [失败] 转换 {log_name} 时出错: {result} (源文件未删除)")
|
||||
failed_count += 1
|
||||
|
||||
except Exception as e:
|
||||
print(f" [严重失败] 处理 {os.path.basename(source_path_orig)} 时进程出错: {e}")
|
||||
failed_count += 1
|
||||
|
||||
print("--- 处理完毕 ---")
|
||||
if converted_count > 0:
|
||||
print(f"成功转换 {converted_count} 个文件。")
|
||||
if failed_count > 0:
|
||||
print(f"失败 {failed_count} 个文件。")
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
description="将指定文件夹(及其所有子文件夹)中的所有 .bmp, .jpg, .jpeg 文件批量转换为 .png 文件 (并行加速),并按比例缩放 (最小边<=1080px)。"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"folder",
|
||||
type=str,
|
||||
help="包含 .bmp, .jpg, .jpeg 图像的 (根) 目标文件夹路径 (例如 ./ABC)"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"-d", "--delete-source",
|
||||
action="store_true",
|
||||
help="在成功转换为 .png 后,删除原始的 .bmp/.jpg/.jpeg 文件。"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"-w", "--workers",
|
||||
type=int,
|
||||
default=DEFAULT_WORKERS,
|
||||
help=f"指定用于转换的工作进程数 (默认: {DEFAULT_WORKERS}, 即本机CPU核心数)"
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
process_images(args.folder, args.delete_source, args.workers)
|
||||
180
Seg_All_In_One_YoloModel/Yolo数据集构建/Tool_deal_labels.py
Normal file
180
Seg_All_In_One_YoloModel/Yolo数据集构建/Tool_deal_labels.py
Normal file
@@ -0,0 +1,180 @@
|
||||
import cv2
|
||||
import random
|
||||
import numpy as np
|
||||
from scipy.ndimage import label, distance_transform_edt
|
||||
|
||||
def skeletonize(image):
|
||||
"""骨架化函数,确保线条连通性并缩减为1像素宽"""
|
||||
skeleton = np.zeros_like(image)
|
||||
temp_image = np.copy(image)
|
||||
|
||||
while True:
|
||||
eroded = cv2.erode(temp_image, None) # 腐蚀操作
|
||||
temp_dilate = cv2.dilate(eroded, None) # 膨胀操作
|
||||
temp = cv2.subtract(temp_image, temp_dilate) # 提取边缘
|
||||
skeleton = cv2.bitwise_or(skeleton, temp) # 将边缘加入骨架
|
||||
temp_image = np.copy(eroded)
|
||||
if cv2.countNonZero(temp_image) == 0:
|
||||
break
|
||||
return skeleton
|
||||
|
||||
# 1. *** 边缘检测并膨胀 ***
|
||||
def edge_detection(image):
|
||||
"""对图像的各个通道进行边缘检测并进行膨胀处理"""
|
||||
b_channel, g_channel, r_channel = cv2.split(image)
|
||||
gray_image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
|
||||
|
||||
edges_b = cv2.Canny(b_channel, 100, 200)
|
||||
edges_g = cv2.Canny(g_channel, 100, 200)
|
||||
edges_r = cv2.Canny(r_channel, 100, 200)
|
||||
edges_gray = cv2.Canny(gray_image, 100, 200)
|
||||
|
||||
# 合并所有边缘检测结果
|
||||
edges = cv2.bitwise_or(edges_b, edges_g)
|
||||
edges = cv2.bitwise_or(edges, edges_r)
|
||||
edges = cv2.bitwise_or(edges, edges_gray)
|
||||
|
||||
# 创建膨胀核并进行膨胀操作
|
||||
kernel = np.ones((3, 3), np.uint8)
|
||||
dilated_image = cv2.dilate(edges, kernel, iterations=1)
|
||||
|
||||
return dilated_image
|
||||
|
||||
# 2. ** 检测连通区域 ***
|
||||
def detect_connected_regions(dilated_image):
|
||||
"""检测图像中的连通区域并过滤掉小区域"""
|
||||
_, binary = cv2.threshold(dilated_image, 1, 255, cv2.THRESH_BINARY_INV)
|
||||
binary[binary > 0] = 1 # 转换为二值图像
|
||||
|
||||
# 标记连通区域
|
||||
structure = np.ones((3, 3), dtype=int)
|
||||
labeled_array, num_features = label(binary, structure=structure)
|
||||
|
||||
# 清除掉小于100像素的区域
|
||||
filtered_labeled_array = np.copy(labeled_array)
|
||||
for label_num in range(1, num_features + 1):
|
||||
area = np.sum(labeled_array == label_num)
|
||||
if area < 100:
|
||||
filtered_labeled_array[filtered_labeled_array == label_num] = 0
|
||||
|
||||
# 重新标记过滤后的连通区域
|
||||
filtered_labeled_array, num_features = label(filtered_labeled_array, structure=structure)
|
||||
return filtered_labeled_array, num_features # 返回过滤后的labeled_array和未过滤的labeled_array(用于寻找颜色)
|
||||
|
||||
# 3. *** 分水岭填充白色区域 ***
|
||||
def fill_white_regions(filtered_labeled_array):
|
||||
"""使用分水岭算法填充白色区域"""
|
||||
# 准备三通道图像作为分水岭算法的输入
|
||||
color_image = np.zeros((filtered_labeled_array.shape[0], filtered_labeled_array.shape[1], 3), dtype=np.uint8)
|
||||
|
||||
# 将过滤后的 labeled_array 转化为 32 位整型,作为分水岭的 markers
|
||||
markers = np.copy(filtered_labeled_array).astype(np.int32)
|
||||
markers[markers == 0] = -1 # 背景标记为 -1
|
||||
|
||||
# 执行分水岭算法
|
||||
cv2.watershed(color_image, markers)
|
||||
|
||||
# 更新分水岭结果
|
||||
filled_labeled_array = markers.astype(int)
|
||||
filled_labeled_array[filled_labeled_array == -1] = 0
|
||||
|
||||
# 使用距离变换,计算边缘像素(0)最近的非零值
|
||||
non_zero_mask = filled_labeled_array != 0
|
||||
distance, nearest_indices = distance_transform_edt(non_zero_mask == 0, return_indices=True)
|
||||
nearest_values = filled_labeled_array[tuple(nearest_indices)]
|
||||
filled_labeled_array[filled_labeled_array == 0] = nearest_values[filled_labeled_array == 0]
|
||||
|
||||
return filled_labeled_array
|
||||
|
||||
# 4. *** 对连通区域上色(使用“filtered_labeled_array”作为颜色判断,给“filled_labeled_array”上色) ***
|
||||
def color_connected_regions(filled_labeled_array, filtered_labeled_array, ori_labeled_image, Annotate_PALETTE):
|
||||
"""根据原始图像的颜色和注解调色板给连通区域上色"""
|
||||
|
||||
# 初始化一个三通道的彩色图像
|
||||
colored_image = np.zeros((*filled_labeled_array.shape, 3), dtype=np.uint8)
|
||||
|
||||
# 遍历filtered_labeled_array中的每个标签
|
||||
unique_labels = np.unique(filtered_labeled_array)
|
||||
|
||||
for label_num in unique_labels:
|
||||
if label_num == 0:
|
||||
continue # 跳过背景标签
|
||||
|
||||
# 找到filtered_labeled_array中等于当前标签的区域
|
||||
mask_filtered = (filtered_labeled_array == label_num)
|
||||
|
||||
# 获取ori_labeled_image中对应区域的RGB值
|
||||
region_rgb_values = ori_labeled_image[mask_filtered]
|
||||
|
||||
if len(region_rgb_values) == 0:
|
||||
continue
|
||||
|
||||
# 计算区域的RGB平均值
|
||||
average_rgb = np.mean(region_rgb_values, axis=0)
|
||||
|
||||
# 找到Annotate_PALETTE中与average_rgb最接近的颜色
|
||||
closest_palette_color = find_closest_palette_color(average_rgb, Annotate_PALETTE)
|
||||
|
||||
# 将该颜色赋给filled_labeled_array对应区域的元素
|
||||
mask_filled = (filled_labeled_array == label_num)
|
||||
colored_image[mask_filled] = closest_palette_color
|
||||
|
||||
return colored_image
|
||||
|
||||
# 寻找最近邻颜色
|
||||
def find_closest_palette_color(average_rgb, Annotate_PALETTE):
|
||||
"""根据平均RGB值找到Annotate_PALETTE中最接近的颜色"""
|
||||
Annotate_PALETTE = [[color[2], color[1], color[0]] for color in Annotate_PALETTE]
|
||||
|
||||
average_rgb = np.array(average_rgb)
|
||||
min_distance = float('inf')
|
||||
closest_color = None
|
||||
|
||||
# 遍历调色板,计算每个颜色与平均RGB的欧几里得距离
|
||||
for palette_color in Annotate_PALETTE:
|
||||
palette_color = np.array(palette_color)
|
||||
distance = np.linalg.norm(average_rgb - palette_color) # 欧几里得距离
|
||||
|
||||
if distance < min_distance:
|
||||
min_distance = distance
|
||||
closest_color = palette_color
|
||||
|
||||
return closest_color
|
||||
|
||||
# 5. 对Array区域上色(4的简化版)
|
||||
def Tool_color_connected_array(Array):
|
||||
colored_image = np.zeros((*Array.shape, 3), dtype=np.uint8)
|
||||
for label_num in range(1, np.max(Array) + 1):
|
||||
color = [np.random.randint(0, 254) for _ in range(3)]
|
||||
colored_image[Array == label_num] = color
|
||||
return colored_image
|
||||
|
||||
if __name__ == '__main__':
|
||||
"""超参数"""
|
||||
image_path = './2023_02_03_09_13_48.00_08_04_21.Still085.png'
|
||||
|
||||
"""主函数,处理图像并保存结果"""
|
||||
# 读取图像
|
||||
image = cv2.imread(image_path)
|
||||
|
||||
# 1. 边缘检测并膨胀
|
||||
dilated_image = edge_detection(image)
|
||||
cv2.imwrite('./1_1_range_image.png', dilated_image)
|
||||
|
||||
# 2. 检测连通区域
|
||||
filtered_labeled_array, _ = detect_connected_regions(dilated_image)
|
||||
colored_image_filtered = Tool_color_connected_array(filtered_labeled_array)
|
||||
cv2.imwrite('./2_colored_image_filtered.png', colored_image_filtered)
|
||||
|
||||
# 3. 分水岭填充白色区域
|
||||
filled_labeled_array = fill_white_regions(filtered_labeled_array)
|
||||
colored_image_filled = Tool_color_connected_array(filled_labeled_array)
|
||||
cv2.imwrite('./3_colored_image_filled.png', colored_image_filled)
|
||||
|
||||
# 4. 对连通区域上色
|
||||
ori_labeled_image = image
|
||||
colored_image_final = color_connected_regions(filled_labeled_array, filtered_labeled_array, ori_labeled_image, Annotate_PALETTE)
|
||||
cv2.imwrite('./4_color_image_Final.png', colored_image_final)
|
||||
|
||||
|
||||
print("处理后的图片已保存。")
|
||||
190
Seg_All_In_One_YoloModel/Yolo数据集构建/Tool_resize_pics.py
Normal file
190
Seg_All_In_One_YoloModel/Yolo数据集构建/Tool_resize_pics.py
Normal file
@@ -0,0 +1,190 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
import os
|
||||
import argparse
|
||||
from PIL import Image
|
||||
import concurrent.futures
|
||||
import time
|
||||
|
||||
# --- 1. 复制自您脚本中的 CPU 核心数检测逻辑 ---
|
||||
try:
|
||||
DEFAULT_WORKERS = os.cpu_count() or 1
|
||||
except AttributeError:
|
||||
try:
|
||||
import multiprocessing
|
||||
DEFAULT_WORKERS = multiprocessing.cpu_count()
|
||||
except (ImportError, NotImplementedError):
|
||||
DEFAULT_WORKERS = 4
|
||||
|
||||
# --- 2. 定义支持的图片格式 ---
|
||||
SUPPORTED_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.bmp', '.webp', '.tiff', '.tif')
|
||||
|
||||
# --- 3. 确定Pillow的重采样过滤器 ---
|
||||
try:
|
||||
RESAMPLE_FILTER = Image.Resampling.LANCZOS
|
||||
except AttributeError:
|
||||
RESAMPLE_FILTER = Image.LANCZOS
|
||||
|
||||
|
||||
def _process_single_image(image_path, target_min_side):
|
||||
"""
|
||||
工作函数:处理单个文件的缩放。(此函数与上一版完全相同)
|
||||
|
||||
:param image_path: 完整图片路径
|
||||
:param target_min_side: 目标最小边长 (例如 1080)
|
||||
:return: (状态, 原始文件名, 结果/错误信息)
|
||||
"""
|
||||
filename = os.path.basename(image_path)
|
||||
|
||||
try:
|
||||
with Image.open(image_path) as img:
|
||||
img_info = img.info.copy()
|
||||
width, height = img.size
|
||||
min_side = min(width, height)
|
||||
|
||||
# --- 核心逻辑:检查是否需要缩放 ---
|
||||
if min_side <= target_min_side:
|
||||
# 返回相对路径以便于阅读
|
||||
return "skipped", image_path, f"最小边 ({min_side}px) 已 <= {target_min_side}px"
|
||||
|
||||
# --- 计算新尺寸 ---
|
||||
scale_ratio = target_min_side / min_side
|
||||
new_width = int(width * scale_ratio)
|
||||
new_height = int(height * scale_ratio)
|
||||
|
||||
# --- 执行缩放 ---
|
||||
resized_img = img.resize((new_width, new_height), RESAMPLE_FILTER)
|
||||
|
||||
# --- 覆盖保存 ---
|
||||
resized_img.save(image_path, **img_info)
|
||||
|
||||
# 返回相对路径以便于阅读
|
||||
return "success", image_path, f"从 {width}x{height} -> {new_width}x{new_height}"
|
||||
|
||||
except Exception as e:
|
||||
# 返回相对路径以便于阅读
|
||||
return "fail", image_path, str(e)
|
||||
|
||||
def resize_images_in_folder(folder_path, target_min_side=1080, max_workers=DEFAULT_WORKERS):
|
||||
"""
|
||||
*** [已更新] ***
|
||||
批量将文件夹及其所有子文件夹中最小边 > target_min_side 的图片等比例缩小。
|
||||
|
||||
:param folder_path: 目标根文件夹路径
|
||||
:param target_min_side: 目标最小边长
|
||||
:param max_workers: 使用的进程数
|
||||
"""
|
||||
|
||||
if not os.path.isdir(folder_path):
|
||||
print(f"错误:文件夹 '{folder_path}' 不存在或不是一个有效的目录。")
|
||||
return
|
||||
|
||||
print(f"--- 开始 **递归** 处理文件夹: {folder_path} ---")
|
||||
print(f"--- 目标最小边: {target_min_side}px ---")
|
||||
print(f"--- 使用最多 {max_workers} 个进程并行处理 ---")
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
# --- 1. 收集所有需要处理的任务 (*** [核心修改] ***) ---
|
||||
tasks = []
|
||||
print("--- 正在递归扫描所有子文件夹... ---")
|
||||
|
||||
# 使用 os.walk 递归遍历
|
||||
for dirpath, dirnames, filenames in os.walk(folder_path):
|
||||
for filename in filenames:
|
||||
# 检查文件扩展名是否在支持的列表中
|
||||
if filename.lower().endswith(SUPPORTED_EXTENSIONS):
|
||||
# 构造完整的文件路径
|
||||
image_path = os.path.join(dirpath, filename)
|
||||
tasks.append((image_path, target_min_side))
|
||||
|
||||
if not tasks:
|
||||
print(f"未在 {folder_path} 及其子文件夹中找到支持的图片文件。")
|
||||
return
|
||||
|
||||
print(f"--- 发现 {len(tasks)} 个图片文件,开始处理... ---")
|
||||
|
||||
resized_count = 0
|
||||
skipped_count = 0
|
||||
failed_count = 0
|
||||
|
||||
# 转换为相对路径,使输出更简洁
|
||||
base_folder_path = os.path.abspath(folder_path)
|
||||
|
||||
# --- 2. 使用 ProcessPoolExecutor 执行任务 (与上一版相同) ---
|
||||
with concurrent.futures.ProcessPoolExecutor(max_workers=max_workers) as executor:
|
||||
future_to_path = {
|
||||
executor.submit(_process_single_image, path, size): path
|
||||
for path, size in tasks
|
||||
}
|
||||
|
||||
# --- 3. 实时获取已完成的结果 (与上一版相同) ---
|
||||
for future in concurrent.futures.as_completed(future_to_path):
|
||||
try:
|
||||
# status, path_or_name, result
|
||||
status, image_path, result = future.result()
|
||||
|
||||
# 转换为相对路径
|
||||
try:
|
||||
display_path = os.path.relpath(image_path, base_folder_path)
|
||||
except ValueError:
|
||||
display_path = image_path # 如果不在同一驱动器(Windows),则显示完整路径
|
||||
|
||||
if status == "success":
|
||||
print(f" [成功] {display_path}: {result}")
|
||||
resized_count += 1
|
||||
elif status == "skipped":
|
||||
print(f" [跳过] {display_path}: {result}")
|
||||
skipped_count += 1
|
||||
elif status == "fail":
|
||||
print(f" [失败] {display_path}: {result}")
|
||||
failed_count += 1
|
||||
|
||||
except Exception as e:
|
||||
orig_path = future_to_path[future]
|
||||
print(f" [严重失败] 处理 {orig_path} 时进程出错: {e}")
|
||||
failed_count += 1
|
||||
|
||||
end_time = time.time()
|
||||
print("--- 处理完毕 ---")
|
||||
if resized_count > 0:
|
||||
print(f"成功缩放 {resized_count} 个文件。")
|
||||
if skipped_count > 0:
|
||||
print(f"跳过 {skipped_count} 个文件 (无需缩放)。")
|
||||
if failed_count > 0:
|
||||
print(f"失败 {failed_count} 个文件。")
|
||||
print(f"总耗时: {end_time - start_time:.2f} 秒。")
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
description="*** [已更新] *** 批量 **递归** 缩放文件夹及其子文件夹中的图片,使最小边不超过指定值。"
|
||||
)
|
||||
|
||||
# 位置参数:文件夹路径 (必需)
|
||||
parser.add_argument(
|
||||
"folder",
|
||||
type=str,
|
||||
help="包含图片的 **根** 文件夹路径 (例如 ./MyImages),将递归处理所有子文件夹。"
|
||||
)
|
||||
|
||||
# 选项参数:目标尺寸 (可选)
|
||||
parser.add_argument(
|
||||
"-s", "--size",
|
||||
type=int,
|
||||
default=1080,
|
||||
help="指定目标最小边长 (默认: 1080)。如果图片最小边已小于此值,则跳过。"
|
||||
)
|
||||
|
||||
# 选项参数:控制进程数 (可选)
|
||||
parser.add_argument(
|
||||
"-w", "--workers",
|
||||
type=int,
|
||||
default=DEFAULT_WORKERS,
|
||||
help=f"指定用于转换的工作进程数 (默认: {DEFAULT_WORKERS}, 即本机CPU核心数)"
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# 将文件夹路径、"目标尺寸" 和 "工作进程数" 传递给函数
|
||||
resize_images_in_folder(args.folder, args.size, args.workers)
|
||||
69
Seg_All_In_One_YoloModel/Yolo数据集构建/Yolo数据集构建_使用手册
Normal file
69
Seg_All_In_One_YoloModel/Yolo数据集构建/Yolo数据集构建_使用手册
Normal file
@@ -0,0 +1,69 @@
|
||||
#### 0. 准备工作 ####
|
||||
# 清理旧文件
|
||||
rm -r ./Label/* ./ORI/*
|
||||
rm -r ./result_stack_*透明度 ./Data
|
||||
rm -r ./ORI_GT_label_fold ./ORI_pro_label_fold ./__pycache__
|
||||
conda activate SMP
|
||||
|
||||
# A. 图像移动
|
||||
cp 磁辅助分割图像-有效图/* Label/ # TODO TODO 将标注后图片放入Label中 # 保证Label中图片与ORI中图片一一对应,且命名相同
|
||||
cp 磁辅助分割-原图/* ORI/ # TODO TODO 将原始图片放入ORI中
|
||||
# B.1 修改类别颜色
|
||||
vim Tool_Classes_And_Palette.py # 修改 Annotate_CLASSES、Annotate_PALETTE 以匹配标注时的类别与颜色
|
||||
# B.2 将bmp、jpg图片转为png
|
||||
python Tool_convert_bmp_jpg_to_png.py ./Label --delete-source
|
||||
python Tool_convert_bmp_jpg_to_png.py ./ORI --delete-source
|
||||
# B.3 将图片转为最大边限制为1080
|
||||
python Tool_resize_pics.py ./Label # -s 1080
|
||||
python Tool_resize_pics.py ./ORI # -s 1080
|
||||
# C. 检测图片是否匹配
|
||||
python 0_1_check_picture_pair.py # -i "./ORI" -l "./Label" -p "" -s ""
|
||||
# python 0_1_check_picture_pair.py # -i "../../DataSet_Public/6_CWK_2_cfz/images/train" -l "../../DataSet_Public/6_CWK_2_cfz/labels_GT/train" -p "" -s ""
|
||||
# D. 生成堆叠图片(可视化标签效果)
|
||||
bash 0_2_TOOL_stack_pics.sh -i "./ORI" -l ./Label -r ./result_stack_0.3透明度 -a 0.3 -p "" -s "_label"
|
||||
# E. ※下载图片,查看匹配的是否有问题※
|
||||
|
||||
#### 1. 批量化生成训练、测试集图片 ####
|
||||
# A. 将图片转为GT图片
|
||||
python 1_deal_labels.py -src_fold ./Label
|
||||
# B. 新建数据最终存储文件夹
|
||||
rm -r ./Data
|
||||
mkdir -p Data/images/train Data/images/val
|
||||
cp ORI/* Data/images/train/
|
||||
cp ORI/* Data/images/val/ # 或根据需要分配训练集、验证集
|
||||
mkdir -p Data/labels/train Data/labels/val
|
||||
mkdir -p Data/labels_GT/train Data/labels_GT/val
|
||||
# C. 生成labels 、 labels_GT图片到 Data/labels_GT/train 中
|
||||
# Way 1:※推荐※(将类别压缩)
|
||||
python 2_Check_and_Gen_Txt_Label_sort_label.py
|
||||
# Way 2:(使用原始类别)
|
||||
# python 2_Check_and_Gen_Txt_Label_ori_label.py # Way 2
|
||||
# cp ORI_GT_label_fold/* Data/labels_GT/train/ # Way 2
|
||||
# cp ORI_GT_label_fold/* Data/labels_GT/val/ # Way 2
|
||||
|
||||
#### 2. 纳入到Yolo训练体系 ####
|
||||
# A. 移动数据
|
||||
cp -r ./Data ../../DataSet_Public/8_Haze_Baidu_Plus # TODO TODO 将 Data 文件夹改名后 放入 ../../DataSet_Public 下
|
||||
# B. 设置输出颜色参考
|
||||
将 程序输出的信息存入 ../../Seg_All_In_One_YoloModel/dataset.yaml 的 color中作为最终输出颜色的参考
|
||||
# C. 根据 Seg_All_In_One_YoloModel 中的使用手册新增数据集、进行配置、训练
|
||||
修改 dataset.yaml 中 训练数据集、修改 yolo_config.py 中 EPOCHS、PATIENCE
|
||||
|
||||
############################## 进行训练推理(整体流程) ############################################
|
||||
# 1. 批量化训练
|
||||
conda activate SMP
|
||||
cd ~/Desktop/Seg/Seg_All_In_One_YoloModel
|
||||
bash yolo_train.sh
|
||||
# 2. 复制最优模型到预测文件夹
|
||||
bash ./Tool_Yolo_Copy_Best_Model.sh --pt_name "best.pt" && bash ./Tool_Yolo_Copy_Best_Model.sh --pt_name "epoch100.pt" && bash ./Tool_Yolo_Copy_Best_Model.sh --pt_name "epoch50.pt" && bash ./Tool_Yolo_Copy_Best_Model.sh --pt_name "epoch150.pt"
|
||||
# 3. 批量化预测+热度图可视化
|
||||
bash yolo_predict.sh --conf 0.2 --pt_name "epoch100.pt" && bash yolo_predict.sh --conf 0.2 && bash yolo_predict.sh --conf 0.2 --pt_name "epoch50.pt" && bash yolo_predict.sh --conf 0.2 --pt_name "epoch150.pt"
|
||||
bash ./yolo_predict.sh --pt_name "best.pt" --heatmap_method "All" # && bash ./yolo_predict.sh --pt_name "epoch100.pt" --heatmap_method "All"
|
||||
# 4. 横向对比结果
|
||||
python yolo_predict_V2_compare_all.py
|
||||
# 5. 打包预测结果(不包含*.pt模型文件)
|
||||
cd /home/wkmgc/Desktop/Seg/BestMode_Predict_Results_DataSet_Public/
|
||||
zip -r /home/wkmgc/Desktop/8_Haze_Baidu_Plus-Yolo.zip 8_Haze_Baidu_Plus*-Yolo -x "*.pt" # TODO TODO
|
||||
# 6. 打包训练结果(只有.png、.jpg、.csv文件)
|
||||
cd /home/wkmgc/Desktop/Seg/Hardisk/
|
||||
zip -r /home/wkmgc/Desktop/8_Haze_Baidu_Plus-Yolo_train.zip 8_Haze_Baidu_Plus-Yolo -i \*.png \*.jpg \*.csv # TODO TODO
|
||||
318
Seg_All_In_One_YoloModel/dataset.yaml
Normal file
318
Seg_All_In_One_YoloModel/dataset.yaml
Normal file
@@ -0,0 +1,318 @@
|
||||
############ 第一部分:数据集跟路径 ############
|
||||
# # V1: 1_CholecSeg8k-13Type-1920x1080
|
||||
# path: ../DataSet_Public/1_CholecSeg8k-13Type-1920x1080
|
||||
# # V2:2_AutoLaparo-10Type-1920x1080
|
||||
# path: ../DataSet_Public/2_AutoLaparo-10Type-1920x1080
|
||||
# # V3:3_1_Endovis_2017-8Type-512x512
|
||||
# path: ../DataSet_Public/3_1_Endovis_2017-8Type-512x512
|
||||
# # V4:3_2_Endovis_2018-8Type-512x512
|
||||
# path: ../DataSet_Public/3_2_Endovis_2018-8Type-512x512
|
||||
# # V5:4_Dresden-11Type-512x512
|
||||
# path: ../DataSet_Public/4_Dresden-11Type-512x512
|
||||
|
||||
# # V6:5_LC_1_blood_verssel
|
||||
# path: ../DataSet_Public/5_LC_1_blood_verssel
|
||||
# # V7:5_LC_2_artery
|
||||
# path: ../DataSet_Public/5_LC_2_artery
|
||||
# # V8:5_LC_3_cystic_duct
|
||||
# path: ../DataSet_Public/5_LC_3_cystic_duct
|
||||
# # V9:5_LC_4_foreigner
|
||||
# path: ../DataSet_Public/5_LC_4_foreigner
|
||||
# # V10:5_LC_5_stop_bleed
|
||||
# path: ../DataSet_Public/5_LC_5_stop_bleed
|
||||
# # V11:6_CWK_1_yws
|
||||
# path: ../DataSet_Public/6_CWK_1_yws
|
||||
# # V12:6_CWK_2_cfz
|
||||
# path: ../DataSet_Public/6_CWK_2_cfz
|
||||
# # V13: 5_TQY # TODO
|
||||
# path: ../DataSet_Public/5_TQY
|
||||
# # V14: 5_Haze_ori、6_Haze_AOD_Net、7_Haze_Baidu、8_Haze_Baidu_Plus # TODO
|
||||
# path: ../DataSet_Public/8_Haze_Baidu_Plus
|
||||
|
||||
# 默认使用本机已存在的公开数据集,便于 yolo_config.py smoke test 直接通过。
|
||||
# 切换实验数据集时,请同步修改 path/test/train/val 和 names。
|
||||
path: ../DataSet_Public/3_1_Endovis_2017-8Type-512x512
|
||||
|
||||
# # # Test_V1:5_Predict_Video
|
||||
# path: ../DataSet_Public/5_Predict_Video/LC_Video_1
|
||||
|
||||
############ 第二部分:测试集相对路径 ############
|
||||
# 训练集和验证集图片路径 (相对于 'path')
|
||||
# # V1: 1_CholecSeg8k-13Type-1920x1080
|
||||
# test: images/val
|
||||
# # V2:2_AutoLaparo-10Type-1920x1080
|
||||
# test: images/val
|
||||
# # V3:3_1_Endovis_2017-8Type-512x512
|
||||
# test: images/val
|
||||
# # V4:3_2_Endovis_2018-8Type-512x512
|
||||
# test: images/val
|
||||
# # V5:4_Dresden-11Type-512x512
|
||||
# test: images/val # images/test
|
||||
|
||||
# # V6:5_LC_1_blood_verssel
|
||||
test: images/val
|
||||
# # V7:5_LC_2_artery
|
||||
# test: images/val
|
||||
# # V8:5_LC_3_cystic_duct
|
||||
# test: images/val
|
||||
# # V9:5_LC_4_foreigner
|
||||
# test: images/val
|
||||
# # V10:5_LC_5_stop_bleed
|
||||
# test: images/val
|
||||
# # V11:6_CWK_1_yws
|
||||
# test: images/val
|
||||
# # V12:6_CWK_2_cfz
|
||||
# test: images/val
|
||||
# # V13:5_TQY
|
||||
# test: images/val
|
||||
# # V14:5_Haze_ori、6_Haze_AOD_Net、7_Haze_Baidu、8_Haze_Baidu_Plus
|
||||
# test: images/val
|
||||
|
||||
# # Test_V1:5_Predict_Video
|
||||
# test: images/val
|
||||
|
||||
############ 第三部分:训练集、验证集相对路径 ############
|
||||
train: images/train
|
||||
val: images/val
|
||||
|
||||
############ 第四部分:类别名称【从0开始】 ############
|
||||
# # V1: 1_CholecSeg8k-13Type-1920x1080
|
||||
# names:
|
||||
# 0: background
|
||||
# 1: 1
|
||||
# 2: 2
|
||||
# 3: 3
|
||||
# 4: 4
|
||||
# 5: 5
|
||||
# 6: 6
|
||||
# 7: 7
|
||||
# 8: 8
|
||||
# 9: 9
|
||||
# 10: 10
|
||||
# 11: 11
|
||||
# 12: 12
|
||||
# V2:2_AutoLaparo-10Type-1920x1080
|
||||
# names:
|
||||
# 0: background
|
||||
# 1: 1
|
||||
# 2: 2
|
||||
# 3: 3
|
||||
# 4: 4
|
||||
# 5: 5
|
||||
# 6: 6
|
||||
# 7: 7
|
||||
# 8: 8
|
||||
# 9: 9
|
||||
# # V3:3_1_Endovis_2017-8Type-512x512
|
||||
# names:
|
||||
# 0: background
|
||||
# 1: 1
|
||||
# 2: 2
|
||||
# 3: 3
|
||||
# 4: 4
|
||||
# 5: 5
|
||||
# 6: 6
|
||||
# 7: 7
|
||||
# # V4:3_2_Endovis_2018-8Type-512x512
|
||||
# names:
|
||||
# 0: background
|
||||
# 1: 1
|
||||
# 2: 2
|
||||
# 3: 3
|
||||
# 4: 4
|
||||
# 5: 5
|
||||
# 6: 6
|
||||
# 7: 7
|
||||
# # V5:4_Dresden-11Type-512x512
|
||||
# names:
|
||||
# 0: background
|
||||
# 1: 1
|
||||
# 2: 2
|
||||
# 3: 3
|
||||
# 4: 4
|
||||
# 5: 5
|
||||
# 6: 6
|
||||
# 7: 7
|
||||
# 8: 8
|
||||
# 9: 9
|
||||
# 10: 10
|
||||
|
||||
# # V6:5_LC_1_blood_verssel # TODO
|
||||
# names:
|
||||
# 0: 0
|
||||
# 1: 1
|
||||
# 2: 2
|
||||
# 3: 3
|
||||
# 4: 4
|
||||
# 5: 5
|
||||
# # V7:5_LC_2_artery # TODO
|
||||
# names:
|
||||
# 0: 0
|
||||
# 1: 1
|
||||
# 2: 2
|
||||
# 3: 3
|
||||
# 4: 4
|
||||
# 5: 5
|
||||
# 6: 6
|
||||
# 7: 7
|
||||
# 8: 8
|
||||
# 9: 9
|
||||
# 10: 10
|
||||
# 11: 11
|
||||
# 12: 12
|
||||
# 13: 13
|
||||
# # V8:5_LC_3_cystic_duct # TODO
|
||||
# names:
|
||||
# 0: 0
|
||||
# 1: 1
|
||||
# 2: 2
|
||||
# 3: 3
|
||||
# 4: 4
|
||||
# 5: 5
|
||||
# 6: 6
|
||||
# 7: 7
|
||||
# 8: 8
|
||||
# 9: 9
|
||||
# # V9:5_LC_4_foreigner # TODO
|
||||
# names:
|
||||
# 0: 0
|
||||
# 1: 1
|
||||
# 2: 2
|
||||
# 3: 3
|
||||
# 4: 4
|
||||
# 5: 5
|
||||
# # V10:5_LC_5_stop_bleed # TODO
|
||||
# names:
|
||||
# 0: 0
|
||||
# 1: 1
|
||||
# 2: 2
|
||||
# 3: 3
|
||||
# 4: 4
|
||||
# 5: 5
|
||||
# 6: 6
|
||||
# 7: 7
|
||||
# 8: 8
|
||||
# 9: 9
|
||||
# 10: 10
|
||||
# 11: 11
|
||||
# # V11:6_CWK_1_yws # TODO
|
||||
# names:
|
||||
# 0: 0
|
||||
# 1: 1
|
||||
# 2: 2
|
||||
# 3: 3
|
||||
# 4: 4
|
||||
# 5: 5
|
||||
# # V12:6_CWK_2_cfz # TODO
|
||||
# names:
|
||||
# 0: 0
|
||||
# 1: 1
|
||||
# 2: 2
|
||||
# 3: 3
|
||||
# 4: 4
|
||||
# 5: 5
|
||||
# # V13:5_TQY # TODO
|
||||
# names:
|
||||
# 0: 0
|
||||
# # V14:5_Haze_ori、6_Haze_AOD_Net、7_Haze_Baidu、8_Haze_Baidu_Plus # TODO
|
||||
# names:
|
||||
# 0: 0
|
||||
# 1: 1
|
||||
# 2: 2
|
||||
# 3: 3
|
||||
|
||||
# 当前默认:V3 3_1_Endovis_2017-8Type-512x512
|
||||
names:
|
||||
0: background
|
||||
1: 1
|
||||
2: 2
|
||||
3: 3
|
||||
4: 4
|
||||
5: 5
|
||||
6: 6
|
||||
7: 7
|
||||
|
||||
############ 第五部分:最终上色 ############
|
||||
# # V6:5_LC_1_blood_verssel # TODO
|
||||
# colors:
|
||||
# 0: [255, 91, 0]
|
||||
# 1: [255, 234, 0]
|
||||
# 2: [167, 24, 233]
|
||||
# 3: [52, 184, 178]
|
||||
# 4: [255, 0, 0]
|
||||
# 5: [0, 155, 33]
|
||||
# # V7:5_LC_2_artery # TODO
|
||||
# colors:
|
||||
# 0: [255, 91, 0]
|
||||
# 1: [255, 234, 0]
|
||||
# 2: [255, 0, 0]
|
||||
# 3: [0, 160, 233]
|
||||
# 4: [0, 155, 33]
|
||||
# 5: [52, 184, 178]
|
||||
# 6: [167, 24, 233]
|
||||
# 7: [255, 255, 255]
|
||||
# 8: [117, 0, 0]
|
||||
# 9: [72, 0, 255]
|
||||
# 10: [85, 111, 181]
|
||||
# 11: [0, 255, 255]
|
||||
# 12: [42, 8, 66]
|
||||
# 13: [66, 115, 82]
|
||||
# # V8:5_LC_3_cystic_duct # TODO
|
||||
# colors:
|
||||
# 0: [255, 91, 0]
|
||||
# 1: [255, 234, 0]
|
||||
# 2: [255, 0, 0]
|
||||
# 3: [167, 24, 233]
|
||||
# 4: [0, 155, 33]
|
||||
# 5: [0, 160, 233]
|
||||
# 6: [72, 0, 255]
|
||||
# 7: [52, 184, 178]
|
||||
# 8: [255, 255, 255]
|
||||
# 9: [117, 0, 0]
|
||||
# # V9:5_LC_4_foreigner # TODO
|
||||
# colors:
|
||||
# 0: [255, 91, 0]
|
||||
# 1: [255, 234, 0]
|
||||
# 2: [255, 0, 0]
|
||||
# 3: [0, 155, 33]
|
||||
# 4: [52, 184, 178]
|
||||
# 5: [167, 24, 233]
|
||||
# # V10:5_LC_5_stop_bleed # TODO
|
||||
# colors:
|
||||
# 0: [255, 91, 0]
|
||||
# 1: [181, 227, 14]
|
||||
# 2: [160, 15, 95]
|
||||
# 3: [85, 111, 181]
|
||||
# 4: [52, 184, 178]
|
||||
# 5: [255, 0, 0]
|
||||
# 6: [255, 255, 255]
|
||||
# 7: [255, 0, 255]
|
||||
# 8: [167, 24, 233]
|
||||
# 9: [255, 234, 0]
|
||||
# 10: [0, 160, 233]
|
||||
# 11: [113, 102, 140]
|
||||
# # V11:6_CWK_1_yws # TODO
|
||||
# colors:
|
||||
# 0: [255, 91, 0]
|
||||
# 1: [155, 132, 0]
|
||||
# 2: [255, 234, 0]
|
||||
# 3: [255, 0, 255]
|
||||
# 4: [52, 184, 178]
|
||||
# 5: [85, 111, 181]
|
||||
# # # V12:6_CWK_2_cfz # TODO
|
||||
# colors:
|
||||
# 0: [255, 91, 0]
|
||||
# 1: [155, 132, 0]
|
||||
# 2: [255, 234, 0]
|
||||
# 3: [52, 184, 178]
|
||||
# 4: [255, 0, 255]
|
||||
# 5: [85, 111, 181]
|
||||
# # V13: 5_TQY # TODO
|
||||
# colors:
|
||||
# 0: [225, 182, 193]
|
||||
# V14:5_Haze_ori、6_Haze_AOD_Net、7_Haze_Baidu、8_Haze_Baidu_Plus
|
||||
colors:
|
||||
0: [167, 24, 233]
|
||||
1: [52, 184, 178]
|
||||
2: [255, 255, 255]
|
||||
3: [0, 160, 233]
|
||||
48
Seg_All_In_One_YoloModel/yolo12-seg.yaml
Normal file
48
Seg_All_In_One_YoloModel/yolo12-seg.yaml
Normal file
@@ -0,0 +1,48 @@
|
||||
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
||||
|
||||
# YOLO12-seg instance segmentation model with P3/8 - P5/32 outputs
|
||||
# Model docs: https://docs.ultralytics.com/models/yolo12
|
||||
# Task docs: https://docs.ultralytics.com/tasks/segment
|
||||
|
||||
# Parameters
|
||||
nc: 80 # number of classes
|
||||
scales: # model compound scaling constants, i.e. 'model=yolo12n-seg.yaml' will call yolo12-seg.yaml with scale 'n'
|
||||
# [depth, width, max_channels]
|
||||
n: [0.50, 0.25, 1024] # summary: 294 layers, 2,855,056 parameters, 2,855,040 gradients, 10.6 GFLOPs
|
||||
s: [0.50, 0.50, 1024] # summary: 294 layers, 9,938,592 parameters, 9,938,576 gradients, 35.7 GFLOPs
|
||||
m: [0.50, 1.00, 512] # summary: 314 layers, 22,505,376 parameters, 22,505,360 gradients, 123.5 GFLOPs
|
||||
l: [1.00, 1.00, 512] # summary: 510 layers, 28,756,992 parameters, 28,756,976 gradients, 145.1 GFLOPs
|
||||
x: [1.00, 1.50, 512] # summary: 510 layers, 64,387,264 parameters, 64,387,248 gradients, 324.6 GFLOPs
|
||||
|
||||
# YOLO12n backbone
|
||||
backbone:
|
||||
# [from, repeats, module, args]
|
||||
- [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
|
||||
- [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
|
||||
- [-1, 2, C3k2, [256, False, 0.25]]
|
||||
- [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
|
||||
- [-1, 2, C3k2, [512, False, 0.25]]
|
||||
- [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
|
||||
- [-1, 4, A2C2f, [512, True, 4]]
|
||||
- [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32
|
||||
- [-1, 4, A2C2f, [1024, True, 1]] # 8
|
||||
|
||||
# YOLO12n head
|
||||
head:
|
||||
- [-1, 1, nn.Upsample, [None, 2, "nearest"]]
|
||||
- [[-1, 6], 1, Concat, [1]] # cat backbone P4
|
||||
- [-1, 2, A2C2f, [512, False, -1]] # 11
|
||||
|
||||
- [-1, 1, nn.Upsample, [None, 2, "nearest"]]
|
||||
- [[-1, 4], 1, Concat, [1]] # cat backbone P3
|
||||
- [-1, 2, A2C2f, [256, False, -1]] # 14
|
||||
|
||||
- [-1, 1, Conv, [256, 3, 2]]
|
||||
- [[-1, 11], 1, Concat, [1]] # cat head P4
|
||||
- [-1, 2, A2C2f, [512, False, -1]] # 17
|
||||
|
||||
- [-1, 1, Conv, [512, 3, 2]]
|
||||
- [[-1, 8], 1, Concat, [1]] # cat head P5
|
||||
- [-1, 2, C3k2, [1024, True]] # 20 (P5/32-large)
|
||||
|
||||
- [[14, 17, 20], 1, Segment, [nc, 32, 256]] # Detect(P3, P4, P5)
|
||||
123
Seg_All_In_One_YoloModel/yolo_config.py
Normal file
123
Seg_All_In_One_YoloModel/yolo_config.py
Normal file
@@ -0,0 +1,123 @@
|
||||
import yaml, sys
|
||||
import torch
|
||||
from pathlib import Path
|
||||
|
||||
# --- 工具函数:动态设备选择 (Dynamic Device Selection) ---
|
||||
def get_auto_device(verbose = False):
|
||||
"""自动检测并返回最合适的设备"""
|
||||
if torch.cuda.is_available():
|
||||
gpu_count = torch.cuda.device_count()
|
||||
if gpu_count > 1:
|
||||
# 如果有多个GPU,使用所有GPU
|
||||
device_str = ",".join(str(i) for i in range(gpu_count))
|
||||
if verbose:
|
||||
print(f"检测到 {gpu_count} 个可用的GPU。将使用所有GPU: {device_str}")
|
||||
return device_str
|
||||
else:
|
||||
# 如果只有1个GPU
|
||||
if verbose:
|
||||
print("检测到 1 个可用的GPU。将使用 GPU: 0")
|
||||
return '0'
|
||||
else:
|
||||
# 如果没有可用的GPU,使用CPU
|
||||
if verbose:
|
||||
print("未检测到可用的GPU。将使用CPU。")
|
||||
return 'cpu'
|
||||
# --- 1. 核心目录设置 (Core Directories) ---
|
||||
HARDISK_DIR = Path.home() / "Desktop" / "Seg" / "Hardisk" # 硬盘根目录
|
||||
BASE_DIR = Path(__file__).parent.parent # 当前脚本所在目录上级
|
||||
DATASET_YAML_PATH = Path(__file__).parent / "dataset.yaml" # 数据集的 YAML 配置文件路径
|
||||
OUTPUTS_DIR = BASE_DIR / "DataSet_Public_outputs" # 输出结果目录 # 最终为:OUTPUTS_DIR / dataset_name
|
||||
PREDICT_ALL_BEST_MODELS_DIR = BASE_DIR / 'BestMode_Predict_Results_DataSet_Public' # 所有最佳模型存放目录
|
||||
|
||||
TEST_IMAGE_DIR = None # 初始化 TEST_IMAGE_DIR # 测试文件地址 dataset.path / TODO "val" / "test" TODO
|
||||
# try:
|
||||
# 3. 读取并解析 YAML 文件
|
||||
with open(DATASET_YAML_PATH, 'r', encoding='utf-8') as f:
|
||||
yaml_data = yaml.safe_load(f)
|
||||
# 4. 从解析后的数据中获取 'path' 的值
|
||||
relative_path_from_yaml = yaml_data.get('path')
|
||||
test_path_from_yaml = yaml_data.get('test')
|
||||
dataset_name = Path(relative_path_from_yaml).name
|
||||
if relative_path_from_yaml:
|
||||
# 5. 【核心步骤】构建绝对路径
|
||||
# relative_path_from_yaml 是相对于 .yaml 文件本身的路径。
|
||||
# 所以,我们需要获取 .yaml 文件所在的目录,然后与这个相对路径拼接。
|
||||
yaml_file_directory = DATASET_YAML_PATH.parent
|
||||
TEST_IMAGE_DIR = (DATASET_YAML_PATH.parent / relative_path_from_yaml / test_path_from_yaml).resolve() # 这里val 或 test在dataset.yaml中定义
|
||||
# 使用 .exists() 方法来检查路径是否存在
|
||||
if not TEST_IMAGE_DIR.exists():
|
||||
# 如果路径不存在,则执行这里的代码
|
||||
print(f"警告: 测试图片目录不存在: {TEST_IMAGE_DIR}")
|
||||
sys.exit(1)
|
||||
# 6. 获取OUTPUTS_DIR
|
||||
if dataset_name and dataset_name not in {'.', '..'}:
|
||||
dataset_name_ = dataset_name + "-Yolo"
|
||||
# 设定输出路径
|
||||
OUTPUTS_DIR = OUTPUTS_DIR / dataset_name_
|
||||
# ../BestMode_Predict_Results_DataSet_Public/"dataset_name+"-Yolo"" # 需要用到 Tool_Yolo_Copy_Best_Model.sh
|
||||
PREDICT_BEST_MODEL_DIR = PREDICT_ALL_BEST_MODELS_DIR / dataset_name_ # 最优模型保存位置 PREDICT_ALL_BEST_MODELS_DIR / dataset_name_ / weights / best.pt
|
||||
else:
|
||||
print(f"警告: 提取的dataset_name: '{dataset_name}' 无效(为空、'.' 或 '..')。")
|
||||
sys.exit(1)
|
||||
else:
|
||||
print(f"警告: 在 '{DATASET_YAML_PATH}' 文件中没有找到 'path' 键。")
|
||||
sys.exit(1)
|
||||
# except FileNotFoundError:
|
||||
# print(f"错误: YAML 配置文件未找到: '{DATASET_YAML_PATH}'")
|
||||
# except Exception as e:
|
||||
# print(f"读取或解析 YAML 文件时发生错误: {e}")
|
||||
|
||||
|
||||
# --- 修改:模型、图像大小和批处理大小的集成配置 ---
|
||||
# 将所有模型的配置(权重、图像大小、批处理大小)集中管理
|
||||
MODEL_CONFIGS = {
|
||||
# YOLOv8
|
||||
'YOLOv8n-seg': {'weights': 'yolov8n-seg.pt', 'image_size': 640, 'batch_size': 16},
|
||||
'YOLOv8s-seg': {'weights': 'yolov8s-seg.pt', 'image_size': 640, 'batch_size': 16},
|
||||
'YOLOv8m-seg': {'weights': 'yolov8m-seg.pt', 'image_size': 640, 'batch_size': 16},
|
||||
'YOLOv8l-seg': {'weights': 'yolov8l-seg.pt', 'image_size': 640, 'batch_size': 16}, # 640,16,12.5GB
|
||||
'YOLOv8x-seg': {'weights': 'yolov8x-seg.pt', 'image_size': 640, 'batch_size': 16}, # 640,16,15.5GB # 示例:X模型使用1280分辨率
|
||||
# YOLOv9
|
||||
'YOLOv9c-seg': {'weights': 'yolov9c-seg.pt', 'image_size': 640, 'batch_size': 16}, # 640,16,13GB
|
||||
'YOLOv9e-seg': {'weights': 'yolov9e-seg.pt', 'image_size': 640, 'batch_size': 8}, # 640,16,内存超了 # 示例:E模型(最大)使用1280分辨率和更小的batch
|
||||
# YOLOv11 (假设)
|
||||
'YOLO11n-seg': {'weights': 'yolo11n-seg.pt', 'image_size': 640, 'batch_size': 16},
|
||||
'YOLO11s-seg': {'weights': 'yolo11s-seg.pt', 'image_size': 640, 'batch_size': 16},
|
||||
'YOLO11m-seg': {'weights': 'yolo11m-seg.pt', 'image_size': 640, 'batch_size': 16},
|
||||
'YOLO11l-seg': {'weights': 'yolo11l-seg.pt', 'image_size': 640, 'batch_size': 16}, # 640,16,12.5GB
|
||||
'YOLO11x-seg': {'weights': 'yolo11x-seg.pt', 'image_size': 640, 'batch_size': 16}, # 640,16,19.5GB
|
||||
# YOLOv12 (假设)
|
||||
'YOLO12-seg': {'weights': str(Path(__file__).parent / 'yolo12-seg.yaml'), 'image_size': 640, 'batch_size': 16}, # 640,16,3GB
|
||||
}
|
||||
|
||||
# --- 4. 训练超参数 (Training Hyperparameters) ---
|
||||
DEVICE = get_auto_device(verbose=False) # 调用函数来设置设备
|
||||
EPOCHS = 300 # 训练轮次 # TODO
|
||||
PATIENCE = 150 # 提前停止训练的轮数
|
||||
SAVE_PERIOD = 10 # 每隔多少轮保存一次模型 # TODO
|
||||
# BATCH_SIZE 已在上面根据模型自动设置
|
||||
LEARNING_RATE = 0.01 # 初始学习率('SGD:0.01', 'Adam:1e-4', 'AdamW:1e-4', 'auto:0.01')
|
||||
OPTIMIZER = 'auto' # 优化器 (TODO 'SGD', 'Adam', 'AdamW', 'auto' TODO)
|
||||
WORKERS = 4 # 数据加载的工作线程数
|
||||
|
||||
# --- 5. 预测设置 (Prediction Settings) ---
|
||||
SHOW_LABELS = True # 是否在预测结果上显示类别标签
|
||||
SHOW_CONF = True # 是否在预测结果上显示置信度和边界框
|
||||
SAVE_PREDICTIONS = True # 是否保存预测结果图像
|
||||
|
||||
# ==============================================================================
|
||||
# --- 主执行块:只在直接运行时才打印配置信息 ---
|
||||
# ==============================================================================
|
||||
def show_config_summary():
|
||||
"""打印所有配置摘要信息"""
|
||||
print(f"设定输出路径为: '{str(OUTPUTS_DIR)}'")
|
||||
print(f"最佳模型路径为: '{str(PREDICT_BEST_MODEL_DIR)}'")
|
||||
# 再次调用 get_auto_device 并设置 verbose=True 来打印设备信息
|
||||
get_auto_device(verbose=True)
|
||||
|
||||
|
||||
# 当这个脚本被直接执行时(python yolo_config.py),__name__ 的值是 '__main__'
|
||||
# 当它被其他脚本import时,__name__ 的值是 'yolo_config'
|
||||
if __name__ == '__main__':
|
||||
show_config_summary()
|
||||
167
Seg_All_In_One_YoloModel/yolo_predict.sh
Normal file
167
Seg_All_In_One_YoloModel/yolo_predict.sh
Normal file
@@ -0,0 +1,167 @@
|
||||
#!/bin/bash
|
||||
|
||||
# =================================================================
|
||||
# YOLO 模型批量并行预测脚本
|
||||
# =================================================================
|
||||
# - 此脚本会自动为每个模型架构查找其训练好的 'best.pt'
|
||||
# - 使用 'echo "1" |' 来自动选择找到的第一个训练版本
|
||||
# - 在不同的指定GPU上并行执行预测任务
|
||||
# =================================================================
|
||||
|
||||
# --- 1. Conda 环境设置 ---
|
||||
CONDA_BASE_PATH="/home/wkmgc/miniconda3" # <--- 在这里修改为您自己的 Conda 路径
|
||||
CONDA_ENV_NAME="${SEG_CONDA_ENV:-seg_smp}" # 可用 SEG_CONDA_ENV=SMP bash yolo_predict.sh 临时覆盖
|
||||
pt_name="best.pt" # <--- 在这里修改为您想使用的权重文件名,例如 "best.pt" 或 "epoch100.pt"
|
||||
conf_threshold=0.2 # <--- [新增] 默认的置信度阈值
|
||||
heatmap_method="None" # <--- [!! 新增 !!] 默认不运行热度图
|
||||
|
||||
# 循环解析参数
|
||||
while [[ $# -gt 0 ]]; do
|
||||
key="$1"
|
||||
case $key in
|
||||
--pt_name)
|
||||
if [ -n "$2" ] && [[ "$2" != -* ]]; then
|
||||
pt_name="$2"
|
||||
shift # 移过 --pt_name
|
||||
shift # 移过它的值
|
||||
else
|
||||
echo "错误: --pt_name 参数需要一个值。" >&2
|
||||
exit 1
|
||||
fi
|
||||
;;
|
||||
--conf)
|
||||
if [ -n "$2" ] && [[ "$2" != -* ]]; then
|
||||
conf_threshold="$2"
|
||||
shift # 移过 --conf
|
||||
shift # 移过它的值
|
||||
else
|
||||
echo "错误: --conf 参数需要一个值。" >&2
|
||||
exit 1
|
||||
fi
|
||||
;;
|
||||
--heatmap_method)
|
||||
if [ -n "$2" ] && [[ "$2" != -* ]]; then
|
||||
heatmap_method="$2"
|
||||
shift # 移过 --heatmap_method
|
||||
shift # 移过它的值
|
||||
else
|
||||
echo "错误: --heatmap_method 参数需要一个值 (例如 'GradCAM' 或 'All')。" >&2
|
||||
exit 1
|
||||
fi
|
||||
;;
|
||||
*)
|
||||
# 移过未知参数,不报错
|
||||
shift
|
||||
;;
|
||||
esac
|
||||
done
|
||||
|
||||
# 初始化并激活 Conda 环境
|
||||
if [ -f "${CONDA_BASE_PATH}/etc/profile.d/conda.sh" ]; then
|
||||
source "${CONDA_BASE_PATH}/etc/profile.d/conda.sh"
|
||||
conda activate "${CONDA_ENV_NAME}"
|
||||
if [ $? -ne 0 ]; then
|
||||
echo "错误: 激活 Conda 环境 '${CONDA_ENV_NAME}' 失败!"
|
||||
exit 1
|
||||
fi
|
||||
echo "Conda 环境 '${CONDA_ENV_NAME}' 已成功激活。"
|
||||
else
|
||||
echo "错误: 找不到 conda.sh 脚本。请检查您的 CONDA_BASE_PATH 设置是否正确。"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# --- 2. 模型与 GPU 配置 ---
|
||||
# 此处的分组应与 train.sh 保持一致,以确保能正确找到模型并分配资源
|
||||
GPUS_GROUP_0="0"
|
||||
GPUS_GROUP_1="1"
|
||||
GPUS_GROUP_2="2"
|
||||
GPUS_GROUP_3="3"
|
||||
# TODO #
|
||||
GPUS_GROUP_4="0"
|
||||
GPUS_GROUP_5="1"
|
||||
GPUS_GROUP_6="2"
|
||||
GPUS_GROUP_7="3"
|
||||
|
||||
# 从 yolo_config.py/train.sh 中选择的模型列表
|
||||
GROUP_0_MODELS=("YOLO11l-seg")
|
||||
GROUP_1_MODELS=("YOLOv8n-seg" "YOLOv8m-seg")
|
||||
GROUP_2_MODELS=("YOLO11n-seg" "YOLO11s-seg" "YOLO11m-seg")
|
||||
GROUP_3_MODELS=("YOLOv9e-seg")
|
||||
GROUP_4_MODELS=("YOLO11x-seg")
|
||||
GROUP_5_MODELS=("YOLOv9c-seg" "YOLOv8s-seg")
|
||||
GROUP_6_MODELS=("YOLOv8l-seg" "YOLO12-seg")
|
||||
GROUP_7_MODELS=("YOLOv8x-seg")
|
||||
|
||||
# 1. 从 config.py 中读取 PREDICT_BEST_MODEL_DIR 的值
|
||||
PREDICT_BEST_MODEL_DIR=$(python -c "from yolo_config import PREDICT_BEST_MODEL_DIR; print(PREDICT_BEST_MODEL_DIR)")
|
||||
# 检查是否成功获取了 PREDICT_BEST_MODEL_DIR
|
||||
if [ -z "$PREDICT_BEST_MODEL_DIR" ] || [ ! -e "$PREDICT_BEST_MODEL_DIR" ]; then
|
||||
echo "PREDICT_BEST_MODEL_DIR: $PREDICT_BEST_MODEL_DIR"
|
||||
echo "Error: Could not read PREDICT_BEST_MODEL_DIR from yolo_config.py. Exiting."
|
||||
echo "Error 2: Or the directory specified by PREDICT_BEST_MODEL_DIR does not exist. Please create it first."
|
||||
exit 1
|
||||
fi
|
||||
# 2. 定义带有时间戳的日志目录名
|
||||
LOG_DIR_NAME="yolo_predict_logs_parallel_$(date +%Y-%m-%d_%H-%M-%S)"
|
||||
# 3. 拼接成最终的完整路径
|
||||
LOG_DIR="$PREDICT_BEST_MODEL_DIR/$LOG_DIR_NAME"
|
||||
mkdir -p "${LOG_DIR}"
|
||||
echo "所有模型的预测日志将保存在 ./${LOG_DIR}/ 目录中。"
|
||||
echo "----------------------------------------------------"
|
||||
|
||||
|
||||
# --- 3. 预测执行函数 ---
|
||||
# 定义一个函数来启动一组预测,以避免代码重复
|
||||
start_prediction_group() {
|
||||
# 使用 nameref (引用) 来传递数组
|
||||
local -n models=$1
|
||||
local gpus=$2
|
||||
local group_name=$3
|
||||
|
||||
echo ">>> 准备启动 ${group_name} 的预测任务 (后台运行)..."
|
||||
# 遍历指定组中的所有模型
|
||||
for model_key in "${models[@]}"; do
|
||||
if [ "${heatmap_method}" == "None" ]; then
|
||||
# --- 模式 1: 运行标准预测 (原有逻辑) ---
|
||||
echo " -> 正在后台启动 [标准预测]: ${model_key} on GPUs: ${gpus}"
|
||||
# [注意] 我为您添加了 --conf 参数,您原有的脚本 没有传递它
|
||||
echo "1" | CUDA_VISIBLE_DEVICES=${gpus} python yolo_predict_V2.py --model "${model_key}" --pt_name "${pt_name}" --conf "${conf_threshold}" > "${LOG_DIR}/${model_key}_predict.log" 2>&1 &
|
||||
echo " - 模型 ${model_key} 的预测已在后台启动。日志文件: ${LOG_DIR}/${model_key}_predict.log"
|
||||
else
|
||||
# --- 模式 2: 运行热度图可视化 ---
|
||||
echo " -> 正在后台启动 [热度图可视化]: ${model_key} on GPUs: ${gpus} (Method: ${heatmap_method})"
|
||||
# [注意] 我们使用 yolo_predict_visualize_nn.py 并传递新参数
|
||||
echo "1" | CUDA_VISIBLE_DEVICES=${gpus} python yolo_predict_visualize_nn.py --model "${model_key}" --target_layers "default" --cam_method "${heatmap_method}" --pt_name "${pt_name}" > "${LOG_DIR}/${model_key}_heatmap.log" 2>&1 &
|
||||
echo " - 模型 ${model_key} 的热度图已在后台启动。日志文件: ${LOG_DIR}/${model_key}_heatmap.log"
|
||||
fi
|
||||
echo " - 等待 5 秒,确保 GPU 资源稳定分配..."
|
||||
sleep 5
|
||||
done
|
||||
echo ">>> ${group_name} 的所有模型均已启动。"
|
||||
echo "----------------------------------------------------"
|
||||
}
|
||||
|
||||
# --- 4. 依次启动所有预测任务 ---
|
||||
# 脚本将快速地按顺序启动每一组任务到后台
|
||||
start_prediction_group GROUP_0_MODELS "${GPUS_GROUP_0}" "第零组"
|
||||
start_prediction_group GROUP_1_MODELS "${GPUS_GROUP_1}" "第一组"
|
||||
start_prediction_group GROUP_2_MODELS "${GPUS_GROUP_2}" "第二组"
|
||||
start_prediction_group GROUP_3_MODELS "${GPUS_GROUP_3}" "第三组"
|
||||
start_prediction_group GROUP_4_MODELS "${GPUS_GROUP_4}" "第四组"
|
||||
start_prediction_group GROUP_5_MODELS "${GPUS_GROUP_5}" "第五组"
|
||||
start_prediction_group GROUP_6_MODELS "${GPUS_GROUP_6}" "第六组"
|
||||
start_prediction_group GROUP_7_MODELS "${GPUS_GROUP_7}" "第七组"
|
||||
|
||||
|
||||
# --- 5. 等待所有后台任务完成 ---
|
||||
echo ""
|
||||
echo "--- 所有模型均已在后台启动。现在等待所有预测任务完成... ---"
|
||||
# 'wait' 命令会暂停脚本,直到所有由此脚本启动的后台子进程全部执行完毕
|
||||
wait
|
||||
echo "--- 所有后台预测任务已全部完成! ---"
|
||||
|
||||
|
||||
# --- 6. 退出脚本 ---
|
||||
echo "预测流程结束。"
|
||||
conda deactivate
|
||||
echo "已取消激活 Conda 环境。"
|
||||
238
Seg_All_In_One_YoloModel/yolo_predict_V1.py
Normal file
238
Seg_All_In_One_YoloModel/yolo_predict_V1.py
Normal file
@@ -0,0 +1,238 @@
|
||||
import logging, sys, argparse, yaml
|
||||
from pathlib import Path
|
||||
from ultralytics import YOLO
|
||||
import yolo_config as config
|
||||
config.show_config_summary() # 显示配置信息
|
||||
from typing import List
|
||||
import cv2
|
||||
import numpy as np
|
||||
# +----------------------------------------------------------------------------+
|
||||
# ../BestMode_Predict_Results_DataSet_Public/
|
||||
# └── YOLOv8n-seg_2025-09-20_10-00-00/ # The folder for the trained model you selected
|
||||
# ├── weights/
|
||||
# │ └── best.pt
|
||||
# │
|
||||
# ├── prediction/ # Visualized results (image with colored mask overlay)
|
||||
# │ ├── image1.jpg
|
||||
# │ └── ...
|
||||
# │
|
||||
# └── prediction_masks_combined/ # <-- Your new unique folder for combined masks
|
||||
# ├── image1.png # Grayscale mask, original size, pixel values are class IDs
|
||||
# ├── image2.png # Grayscale mask
|
||||
# └── ...
|
||||
# +----------------------------------------------------------------------------+
|
||||
|
||||
# --- 日志设置 ---
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s [%(levelname)s] %(message)s",
|
||||
handlers=[logging.StreamHandler()]
|
||||
)
|
||||
|
||||
# --- 使用带 model_key 过滤的版本 ---
|
||||
def find_trained_models(outputs_dir: Path, model_key: str) -> List[str]:
|
||||
"""
|
||||
扫描输出目录,查找特定基础模型的所有有效的、已完成的训练项目。
|
||||
"""
|
||||
trained_models = []
|
||||
if not outputs_dir.is_dir():
|
||||
logging.warning(f"输出目录不存在: {outputs_dir}")
|
||||
return []
|
||||
|
||||
for project_folder in outputs_dir.iterdir():
|
||||
if project_folder.is_dir() and project_folder.name.startswith(model_key + '_'):
|
||||
if (project_folder / 'weights' / 'best.pt').exists():
|
||||
trained_models.append(project_folder.name)
|
||||
|
||||
trained_models.sort()
|
||||
return trained_models
|
||||
|
||||
# 老版predict函数,只有默认输出
|
||||
# def predict_old(model_path: str, source: str, project_name: str):
|
||||
# """
|
||||
# 使用指定的模型对图像源进行预测。
|
||||
# """
|
||||
# if not Path(source).exists():
|
||||
# logging.error(f"错误:预测源路径不存在: {source}")
|
||||
# return
|
||||
|
||||
# try:
|
||||
# logging.info(f"正在加载模型: {model_path}")
|
||||
# model = YOLO(model_path)
|
||||
# logging.info("模型加载成功。")
|
||||
|
||||
# logging.info(f"正在对源进行预测: {source}")
|
||||
# results = model.predict(
|
||||
# source=source,
|
||||
# save=config.SAVE_PREDICTIONS,
|
||||
# show_labels=config.SHOW_LABELS,
|
||||
# show_conf=config.SHOW_CONF,
|
||||
# project=project_name,
|
||||
# name="prediction",
|
||||
# exist_ok=True,
|
||||
# retina_masks=True, # <<< ADDED: 保存与原图一样大的分割图
|
||||
# )
|
||||
|
||||
# final_save_dir = Path(project_name) / "prediction"
|
||||
# logging.info(f"预测完成。结果保存在: {final_save_dir}")
|
||||
|
||||
# except Exception as e:
|
||||
# logging.error(f"预测过程中发生错误: {e}")
|
||||
|
||||
def predict(model_path: str, source: str, project_name: str):
|
||||
"""
|
||||
使用指定模型进行预测,并将所有类别的掩码合并到一个单通道灰度图中。
|
||||
在该图中,每个像素的值对应其类别ID (0=背景, 1=类别1, 2=类别2, ...)。
|
||||
"""
|
||||
# --- 1. 加载类别名称 (用于日志记录) ---
|
||||
try:
|
||||
with open(config.DATASET_YAML_PATH, 'r', encoding='utf-8') as f:
|
||||
class_names = yaml.safe_load(f)['names']
|
||||
logging.info(f"成功从 {config.DATASET_YAML_PATH} 加载类别名称: {class_names}")
|
||||
except Exception as e:
|
||||
logging.error(f"加载或解析 dataset.yaml 失败: {e}")
|
||||
return
|
||||
|
||||
if not Path(source).exists():
|
||||
logging.error(f"错误:预测源路径不存在: {source}")
|
||||
return
|
||||
|
||||
try:
|
||||
# --- 2. 加载模型并执行预测 ---
|
||||
logging.info(f"正在加载模型: {model_path}")
|
||||
model = YOLO(model_path)
|
||||
logging.info("模型加载成功。")
|
||||
|
||||
logging.info(f"正在对源进行预测: {source}")
|
||||
results = model.predict(
|
||||
source=source,
|
||||
stream=True, # 使用流式处理以节省内存,但是会减速
|
||||
save=config.SAVE_PREDICTIONS,
|
||||
show_labels=config.SHOW_LABELS,
|
||||
show_conf=config.SHOW_CONF,
|
||||
project=project_name,
|
||||
name="prediction",
|
||||
exist_ok=True,
|
||||
retina_masks=True,
|
||||
conf=0.1
|
||||
)
|
||||
|
||||
# --- 3. 创建新的输出目录并处理结果 ---
|
||||
mask_save_dir = Path(project_name) / "predicted_raw_masks"
|
||||
mask_save_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
logging.info(f"预测完成。正在处理结果并保存组合掩码图于: {mask_save_dir}")
|
||||
|
||||
for result in results:
|
||||
original_image_path = Path(result.path)
|
||||
logging.info(f"正在处理图片: {original_image_path.name}")
|
||||
|
||||
h, w = result.orig_shape
|
||||
# 创建一个空白画布,用于存储所有类别的组合掩码。0 代表背景。
|
||||
combined_mask = np.zeros((h, w), dtype=np.uint8)
|
||||
|
||||
if result.masks is None:
|
||||
logging.warning(f" -> 在图片 {original_image_path.name} 中未检测到任何物体,将创建一张全黑的掩码图。")
|
||||
# 如果没有检测到物体,直接保存空白掩码图并继续处理下一张图片
|
||||
mask_filename = mask_save_dir / f"{original_image_path.stem}.png"
|
||||
cv2.imwrite(str(mask_filename), combined_mask)
|
||||
continue
|
||||
|
||||
# --- 核心:创建组合掩码 ---
|
||||
masks_data = result.masks.data
|
||||
class_ids = result.boxes.cls.int().cpu().numpy()
|
||||
|
||||
if len(masks_data) != len(class_ids):
|
||||
logging.error(f" -> 掩码和类别ID数量不匹配,跳过图片 {original_image_path.name}")
|
||||
continue
|
||||
|
||||
# YOLO结果通常按置信度从高到低排序。
|
||||
# 我们反向迭代(从低置信度到高置信度),以确保在掩码重叠区域,
|
||||
# 置信度更高的物体类别能够覆盖置信度较低的。
|
||||
for i in reversed(range(len(masks_data))):
|
||||
instance_mask = masks_data[i].cpu().numpy().astype(bool)
|
||||
class_id = class_ids[i]
|
||||
|
||||
# 将掩码区域的像素值设置为其对应的类别ID。
|
||||
# 我们假设类别ID 0 是背景,所以即使模型预测了ID为0的物体,它也会被视为背景。
|
||||
if class_id != 0:
|
||||
combined_mask[instance_mask] = class_id
|
||||
|
||||
# --- 保存最终的组合掩码图 ---
|
||||
mask_filename = mask_save_dir / f"{original_image_path.stem}.png"
|
||||
cv2.imwrite(str(mask_filename), combined_mask)
|
||||
|
||||
logging.info(f"所有组合掩码图已成功保存。")
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"预测或手动保存过程中发生错误: {e}", exc_info=True)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 1. 创建解析器,现在只需要 --model 和 --source
|
||||
parser = argparse.ArgumentParser(description="使用已训练的YOLO模型进行预测。")
|
||||
|
||||
parser.add_argument(
|
||||
"--model",
|
||||
type=str,
|
||||
required=True,
|
||||
choices=list(config.MODEL_CONFIGS.keys()),
|
||||
help="选择一个基础模型类型,以筛选其训练历史。"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--source",
|
||||
type=str,
|
||||
default=str(config.TEST_IMAGE_DIR),
|
||||
help="图片或图片文件夹的路径。"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
# 2. 根据 --model 参数查找对应的训练历史
|
||||
available_runs = find_trained_models(config.PREDICT_BEST_MODEL_DIR, args.model)
|
||||
|
||||
run_to_use = None
|
||||
|
||||
# 3. 根据找到的结果数量,决定下一步操作
|
||||
if not available_runs:
|
||||
logging.error(f"错误:在 {config.PREDICT_BEST_MODEL_DIR} 中没有找到任何关于模型 '{args.model}' 的有效训练记录。")
|
||||
sys.exit(1)
|
||||
|
||||
elif len(available_runs) == 1:
|
||||
# 如果只有一个结果,自动选择
|
||||
run_to_use = available_runs[0]
|
||||
logging.info(f"只找到一个训练版本,已自动选择: {run_to_use}")
|
||||
|
||||
else:
|
||||
# 如果有多个结果,进入交互式选择模式
|
||||
print(f"\n为模型 '{args.model}' 找到多个训练版本:")
|
||||
for i, run_name in enumerate(available_runs, 1):
|
||||
print(f" [{i}] {run_name}")
|
||||
|
||||
while True:
|
||||
try:
|
||||
choice = input(f"请输入您想使用的版本序号 (1-{len(available_runs)}): ")
|
||||
choice_index = int(choice)
|
||||
if 1 <= choice_index <= len(available_runs):
|
||||
run_to_use = available_runs[choice_index - 1]
|
||||
break
|
||||
else:
|
||||
print("错误:输入无效,请输入列表中的序号。")
|
||||
except ValueError:
|
||||
print("错误:请输入一个数字。")
|
||||
except (KeyboardInterrupt, EOFError):
|
||||
print("\n操作已取消。")
|
||||
sys.exit(0)
|
||||
|
||||
# 4. 使用最终确定的 run_to_use 来执行预测
|
||||
if run_to_use:
|
||||
model_to_use = config.PREDICT_BEST_MODEL_DIR / run_to_use / 'weights' / 'best.pt'
|
||||
project_to_save_in = str(config.PREDICT_BEST_MODEL_DIR / run_to_use)
|
||||
|
||||
if not model_to_use.exists():
|
||||
logging.error(f"严重错误:找不到权重文件 {model_to_use}。")
|
||||
else:
|
||||
predict(
|
||||
model_path=str(model_to_use),
|
||||
source=args.source,
|
||||
project_name=project_to_save_in
|
||||
)
|
||||
538
Seg_All_In_One_YoloModel/yolo_predict_V2.py
Normal file
538
Seg_All_In_One_YoloModel/yolo_predict_V2.py
Normal file
@@ -0,0 +1,538 @@
|
||||
import logging, sys, argparse, yaml
|
||||
from pathlib import Path
|
||||
from ultralytics import YOLO
|
||||
import yolo_config as config
|
||||
config.show_config_summary() # 显示配置信息
|
||||
from typing import List
|
||||
import cv2
|
||||
import numpy as np
|
||||
# +----------------------------------------------------------------------------+
|
||||
# ../BestMode_Predict_Results_DataSet_Public/
|
||||
# └── YOLOv8n-seg_2025-09-20_10-00-00/ # The folder for the trained model you selected
|
||||
# ├── weights/
|
||||
# │ └── best.pt
|
||||
# │
|
||||
# ├── prediction/ # Visualized results (image with colored mask overlay)
|
||||
# │ ├── image1.jpg
|
||||
# │ └── ...
|
||||
# │
|
||||
# └── prediction_masks_combined/ # <-- Your new unique folder for combined masks
|
||||
# ├── image1.png # Grayscale mask, original size, pixel values are class IDs
|
||||
# ├── image2.png # Grayscale mask
|
||||
# └── ...
|
||||
# +----------------------------------------------------------------------------+
|
||||
|
||||
# --- 日志设置 ---
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s [%(levelname)s] %(message)s",
|
||||
handlers=[logging.StreamHandler()]
|
||||
)
|
||||
|
||||
# --- 使用带 model_key 过滤的版本 ---
|
||||
def find_trained_models(outputs_dir: Path, model_key: str, pt_name: str) -> List[str]:
|
||||
"""
|
||||
扫描输出目录,查找特定基础模型的所有有效的、已完成的训练项目。
|
||||
"""
|
||||
trained_models = []
|
||||
if not outputs_dir.is_dir():
|
||||
logging.warning(f"输出目录不存在: {outputs_dir}")
|
||||
return []
|
||||
|
||||
for project_folder in outputs_dir.iterdir():
|
||||
if project_folder.is_dir() and project_folder.name.startswith(model_key + '_'):
|
||||
if (project_folder / 'weights' / pt_name).exists():
|
||||
trained_models.append(project_folder.name)
|
||||
|
||||
trained_models.sort()
|
||||
return trained_models
|
||||
|
||||
# 老版predict函数,只有默认输出
|
||||
# def predict_old(model_path: str, source: str, project_name: str):
|
||||
# """
|
||||
# 使用指定模型进行预测。
|
||||
# 1. 保存原始灰度掩码 (predicted_raw_masks),像素值为类别ID。
|
||||
# 2. 保存彩色可视化掩码 (predicted_color_masks),根据 color_map.yaml 上色。
|
||||
# """
|
||||
|
||||
# # --- 1. 加载类别名称 (用于日志记录和确定类别数量) ---
|
||||
# try:
|
||||
# with open(config.DATASET_YAML_PATH, 'r', encoding='utf-8') as f:
|
||||
# class_names = yaml.safe_load(f)['names']
|
||||
# logging.info(f"成功从 {config.DATASET_YAML_PATH} 加载类别名称: {class_names}")
|
||||
# num_classes = len(class_names)
|
||||
# logging.info(f"检测到 {num_classes} 个类别 (ID 0 到 {num_classes-1})。")
|
||||
# except Exception as e:
|
||||
# logging.error(f"加载或解析 dataset.yaml 失败: {e}")
|
||||
# return
|
||||
|
||||
# # --- [新增] 2. 加载颜色查找表 (LUT) ---
|
||||
# color_lut = np.zeros((num_classes, 3), dtype=np.uint8) # 默认全黑
|
||||
# color_map_path = 'color_map.yaml'
|
||||
|
||||
# try:
|
||||
# with open(color_map_path, 'r', encoding='utf-8') as f:
|
||||
# color_data = yaml.safe_load(f)['colors']
|
||||
# logging.info(f"正在从 {color_map_path} 加载颜色...")
|
||||
|
||||
# for class_id, bgr_value in color_data.items():
|
||||
# if 0 <= class_id < num_classes:
|
||||
# color_lut[class_id] = bgr_value
|
||||
# else:
|
||||
# logging.warning(f"color_map.yaml 中的类别ID {class_id} 超出范围 (0-{num_classes-1}),已忽略。")
|
||||
# logging.info("颜色查找表 (LUT) 创建成功。")
|
||||
|
||||
# except FileNotFoundError:
|
||||
# logging.error(f"错误:未找到 {color_map_path}。")
|
||||
# logging.warning("将使用随机颜色作为备用方案(背景除外)。")
|
||||
# # 创建随机颜色作为备用
|
||||
# for i in range(1, num_classes): # 保持 0 (背景) 为黑色 [0,0,0]
|
||||
# color_lut[i] = np.random.randint(0, 255, 3, dtype=np.uint8)
|
||||
# except Exception as e:
|
||||
# logging.error(f"加载或解析 {color_map_path} 失败: {e}", exc_info=True)
|
||||
# return
|
||||
|
||||
# if not Path(source).exists():
|
||||
# logging.error(f"错误:预测源路径不存在: {source}")
|
||||
# return
|
||||
|
||||
# try:
|
||||
# # --- 3. 加载模型并执行预测 ---
|
||||
# logging.info(f"正在加载模型: {model_path}")
|
||||
# model = YOLO(model_path)
|
||||
# logging.info("模型加载成功。")
|
||||
|
||||
# logging.info(f"正在对源进行预测: {source}")
|
||||
# results = model.predict(
|
||||
# source=source,
|
||||
# stream=True,
|
||||
# save=config.SAVE_PREDICTIONS,
|
||||
# show_labels=config.SHOW_LABELS,
|
||||
# show_conf=config.SHOW_CONF,
|
||||
# project=project_name,
|
||||
# name="prediction",
|
||||
# exist_ok=True,
|
||||
# retina_masks=True,
|
||||
# conf=0.1
|
||||
# )
|
||||
|
||||
# # --- 4. 创建新的输出目录并处理结果 ---
|
||||
# # [修改] 定义两个保存目录
|
||||
# raw_mask_save_dir = Path(project_name) / "predicted_raw_masks"
|
||||
# color_mask_save_dir = Path(project_name) / "predicted_color_masks" # [新增] 彩色图目录
|
||||
|
||||
# raw_mask_save_dir.mkdir(parents=True, exist_ok=True)
|
||||
# color_mask_save_dir.mkdir(parents=True, exist_ok=True) # [新增]
|
||||
|
||||
# # [修改] 更新日志信息
|
||||
# logging.info(f"预测完成。正在处理结果...")
|
||||
# logging.info(f" -> 原始灰度掩码将保存于: {raw_mask_save_dir}")
|
||||
# logging.info(f" -> 彩色可视化掩码将保存于: {color_mask_save_dir}")
|
||||
|
||||
# for result in results:
|
||||
# original_image_path = Path(result.path)
|
||||
# logging.info(f"正在处理图片: {original_image_path.name}")
|
||||
|
||||
# h, w = result.orig_shape
|
||||
# # 创建一个空白画布,用于存储所有类别的组合掩码。0 代表背景。
|
||||
# combined_mask = np.zeros((h, w), dtype=np.uint8)
|
||||
|
||||
# if result.masks is None:
|
||||
# logging.warning(f" -> 在图片 {original_image_path.name} 中未检测到任何物体,将创建一张全黑的掩码图。")
|
||||
|
||||
# # --- 保存原始灰度掩码 (全黑) ---
|
||||
# raw_mask_filename = raw_mask_save_dir / f"{original_image_path.stem}.png"
|
||||
# cv2.imwrite(str(raw_mask_filename), combined_mask)
|
||||
|
||||
# # --- [新增] 保存彩色掩码 (全黑) ---
|
||||
# # 即使是全黑的,也应用LUT,结果还是全黑
|
||||
# colorized_mask = color_lut[combined_mask] # (H, W, 3)
|
||||
# color_mask_filename = color_mask_save_dir / f"{original_image_path.stem}.png"
|
||||
# cv2.imwrite(str(color_mask_filename), colorized_mask)
|
||||
|
||||
# continue # 继续处理下一张图片
|
||||
|
||||
# # --- 核心:创建组合掩码 ---
|
||||
# masks_data = result.masks.data
|
||||
# class_ids = result.boxes.cls.int().cpu().numpy()
|
||||
|
||||
# if len(masks_data) != len(class_ids):
|
||||
# logging.error(f" -> 掩码和类别ID数量不匹配,跳过图片 {original_image_path.name}")
|
||||
# continue
|
||||
|
||||
# # (核心逻辑保持不变)
|
||||
# # 反向迭代,确保高置信度覆盖低置信度
|
||||
# for i in reversed(range(len(masks_data))):
|
||||
# instance_mask = masks_data[i].cpu().numpy().astype(bool)
|
||||
# class_id = class_ids[i]
|
||||
|
||||
# if class_id != 0: # 假设 0 是背景
|
||||
# combined_mask[instance_mask] = class_id
|
||||
|
||||
# # --- 保存 1: 最终的组合掩码图 (灰度) ---
|
||||
# raw_mask_filename = raw_mask_save_dir / f"{original_image_path.stem}.png"
|
||||
# cv2.imwrite(str(raw_mask_filename), combined_mask)
|
||||
|
||||
# # --- [新增] 保存 2: 最终的彩色掩码图 (RGB) ---
|
||||
# # 这是最高效的上色方法:
|
||||
# # color_lut (14, 3)
|
||||
# # combined_mask (H, W),值 0-13
|
||||
# # NumPy 会自动将 (H, W) 中的每个值作为索引去 color_lut 中取 (B,G,R) 元组
|
||||
# colorized_mask = color_lut[combined_mask] # 结果维度 (H, W, 3)
|
||||
|
||||
# color_mask_filename = color_mask_save_dir / f"{original_image_path.stem}.png"
|
||||
# cv2.imwrite(str(color_mask_filename), colorized_mask)
|
||||
|
||||
# logging.info(f"所有掩码图已成功保存。")
|
||||
|
||||
# except Exception as e:
|
||||
# logging.error(f"预测或手动保存过程中发生错误: {e}", exc_info=True)
|
||||
|
||||
def predict(model_path: str, source: str, project_name: str, pt_name: str, conf_threshold: float):
|
||||
"""
|
||||
使用指定模型进行预测。
|
||||
1. 保存原始灰度掩码 (predicted_raw_masks),像素值为类别ID。
|
||||
2. 保存彩色可视化掩码 (predicted_color_masks)。
|
||||
3. 保存三合一对比图 (predicted_comparison) [Original | Ground Truth | Prediction]。
|
||||
"""
|
||||
#
|
||||
pt_name_raw = pt_name.replace('.pt','') # pt_name去掉.pt后缀,方便文件命名使用
|
||||
# [新增] 标志,用于决定是否将(类别0..N-1) 偏移到 (1..N),为 "真背景" 腾出 0
|
||||
perform_id_shift = False
|
||||
|
||||
# --- 1. 加载类别名称和颜色 ---
|
||||
try:
|
||||
with open(config.DATASET_YAML_PATH, 'r', encoding='utf-8') as f:
|
||||
yaml_data = yaml.safe_load(f)
|
||||
class_names = yaml_data.get('names')
|
||||
color_data_from_dataset = yaml_data.get('colors')
|
||||
|
||||
if not class_names:
|
||||
logging.error(f"在 {config.DATASET_YAML_PATH} 中未找到 'names' 键。")
|
||||
return
|
||||
|
||||
logging.info(f"成功从 {config.DATASET_YAML_PATH} 加载类别名称: {class_names}")
|
||||
num_classes = len(class_names)
|
||||
logging.info(f"检测到 {num_classes} 个类别 (ID 0 到 {num_classes-1})。")
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"加载或解析 dataset.yaml 失败: {e}")
|
||||
return
|
||||
|
||||
# --- 2. 创建颜色查找表 (LUT) ---
|
||||
# [调整] 先创建一个临时的 (N) 维 LUT
|
||||
temp_lut = np.zeros((num_classes, 3), dtype=np.uint8)
|
||||
|
||||
if color_data_from_dataset:
|
||||
logging.info(f"正在从 {config.DATASET_YAML_PATH} 加载颜色...")
|
||||
try:
|
||||
for class_id, rgb_value in color_data_from_dataset.items():
|
||||
if 0 <= class_id < num_classes:
|
||||
temp_lut[class_id] = rgb_value[::-1] # BGR
|
||||
logging.info("临时颜色查找表 (LUT) 已成功从 dataset.yaml 创建。")
|
||||
except Exception as e:
|
||||
logging.error(f"解析 'colors' 键失败: {e}。将尝试备用方案。")
|
||||
color_data_from_dataset = None # 强制使用备用方案
|
||||
|
||||
if not color_data_from_dataset:
|
||||
# 备用方案 1: 尝试 color_map.yaml
|
||||
color_map_path = 'color_map.yaml'
|
||||
logging.warning(f"未从 dataset.yaml 加载颜色,正在尝试备用方案: {color_map_path}")
|
||||
try:
|
||||
with open(color_map_path, 'r', encoding='utf-8') as f:
|
||||
color_data = yaml.safe_load(f)['colors']
|
||||
for class_id, bgr_value in color_data.items():
|
||||
if 0 <= class_id < num_classes:
|
||||
color_lut[class_id] = bgr_value
|
||||
logging.info(f"颜色查找表 (LUT) 已成功从 {color_map_path} 创建。")
|
||||
except Exception:
|
||||
# 备用方案 2: 随机颜色
|
||||
logging.warning(f"{color_map_path} 未找到或解析失败。将使用随机颜色。")
|
||||
for i in range(1, num_classes): # 保持 0 (背景) 为黑色 [0,0,0]
|
||||
color_lut[i] = np.random.randint(0, 255, 3, dtype=np.uint8)
|
||||
|
||||
# --- [新增] 核心条件判断 ---
|
||||
# 检查 类别0 (索引0) 的颜色是否为 BGR(0,0,0)
|
||||
if np.any(temp_lut[0] != [0, 0, 0]):
|
||||
# 颜色不是黑色,执行 "ID+1" 偏移
|
||||
logging.info("检测到 类别0 的颜色不是黑色。将执行 ID+1 偏移,使 0 成为专用背景。")
|
||||
perform_id_shift = True
|
||||
# [调整] 创建一个 (N+1) 维的最终 LUT
|
||||
color_lut = np.zeros((num_classes + 1, 3), dtype=np.uint8)
|
||||
# 将原 0..N-1 的颜色 复制到 1..N
|
||||
color_lut[1:] = temp_lut
|
||||
# 索引 0 保持为 [0, 0, 0] (真背景)
|
||||
else:
|
||||
# 颜色是黑色,不执行偏移
|
||||
logging.info("检测到 类别0 的颜色是黑色。将 类别0 视为背景。")
|
||||
perform_id_shift = False
|
||||
color_lut = temp_lut # [调整] 使用 (N) 维的原始 LUT
|
||||
|
||||
if not Path(source).exists():
|
||||
logging.error(f"错误:预测源路径不存在: {source}")
|
||||
return
|
||||
|
||||
try:
|
||||
# --- 3. 加载模型并执行预测 ---
|
||||
logging.info(f"正在加载模型: {model_path}")
|
||||
model = YOLO(model_path)
|
||||
logging.info("模型加载成功。")
|
||||
|
||||
logging.info(f"正在对源进行预测: {source}")
|
||||
results = model.predict(
|
||||
source=source,
|
||||
stream=True,
|
||||
save=False, # <-- 关闭保存,之后手动处理
|
||||
show_labels=config.SHOW_LABELS,
|
||||
show_conf=config.SHOW_CONF,
|
||||
# project=project_name, # <-- 不再需要,我们手动处理
|
||||
# name="prediction", # <-- 不再需要,我们手动处理
|
||||
exist_ok=True,
|
||||
retina_masks=True,
|
||||
conf=conf_threshold # TODO
|
||||
)
|
||||
|
||||
# --- 4. [扩展] 创建所有输出目录 ---
|
||||
yolo_prediction_dir = Path(project_name) / "prediction"
|
||||
raw_mask_save_dir = Path(project_name) / "predicted_raw_masks"
|
||||
color_mask_save_dir = Path(project_name) / "predicted_color_masks"
|
||||
comparison_save_dir = Path(project_name) / "predicted_comparison" # [新增]
|
||||
|
||||
logging.info(f"预测完成。正在处理结果...")
|
||||
logging.info(f" -> Yolo预测结果将保存于: {yolo_prediction_dir}")
|
||||
logging.info(f" -> 原始灰度掩码将保存于: {raw_mask_save_dir}")
|
||||
logging.info(f" -> 彩色预测掩码将保存于: {color_mask_save_dir}")
|
||||
logging.info(f" -> 三合一对比图将保存于: {comparison_save_dir}") # [新增]
|
||||
|
||||
yolo_prediction_dir.mkdir(parents=True, exist_ok=True)
|
||||
raw_mask_save_dir.mkdir(parents=True, exist_ok=True)
|
||||
color_mask_save_dir.mkdir(parents=True, exist_ok=True)
|
||||
comparison_save_dir.mkdir(parents=True, exist_ok=True) # [新增]
|
||||
|
||||
# --- 辅助函数:在图像上添加标签 ---
|
||||
def add_label_to_image(img, label):
|
||||
"""在图像左上角添加一个带背景的标签"""
|
||||
img_with_label = img.copy()
|
||||
h, w = img_with_label.shape[:2]
|
||||
# 动态调整字体大小和粗细
|
||||
font_scale = max(0.8, w // 1000)
|
||||
thickness = max(1, w // 500)
|
||||
|
||||
(text_w, text_h), baseline = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, font_scale, thickness)
|
||||
|
||||
# 绘制黑色背景矩形
|
||||
cv2.rectangle(img_with_label, (0, 0), (text_w + 20, text_h + 30), (0, 0, 0), -1)
|
||||
# 绘制白色文字
|
||||
cv2.putText(img_with_label, label, (10, text_h + 15), cv2.FONT_HERSHEY_SIMPLEX, font_scale, (255, 255, 255), thickness, cv2.LINE_AA)
|
||||
return img_with_label
|
||||
|
||||
# --- 5. [扩展] 处理结果循环 ---
|
||||
for result in results:
|
||||
original_image_path = Path(result.path)
|
||||
logging.info(f"正在处理图片: {original_image_path.name}")
|
||||
|
||||
h, w = result.orig_shape
|
||||
|
||||
# --- 5a. 保存 Yolo-predict 本身结果 ---
|
||||
if config.SAVE_PREDICTIONS == True:
|
||||
# a. 获取原始图片的文件名和后缀
|
||||
# result.path 是原始图片的完整路径
|
||||
original_path = Path(result.path)
|
||||
image_stem = original_path.stem # 这是 "图片名"
|
||||
image_suffix = original_path.suffix # 这是 ".后缀" (例如 .jpg)
|
||||
# b. 构建你的新文件名
|
||||
# 格式: 图片名_{pt_name_raw}.后缀
|
||||
new_filename = f"{image_stem}_{pt_name_raw}{image_suffix}"
|
||||
# c. 构建完整的保存路径
|
||||
save_path = yolo_prediction_dir / new_filename
|
||||
# d. 获取绘制了检测框的图像 (NumPy 数组)
|
||||
# result.plot() 会自动使用你在 predict 中设置的 show_labels, show_conf 等参数
|
||||
annotated_image = result.plot()
|
||||
# e. 使用 OpenCV (cv2) 保存图像
|
||||
cv2.imwrite(str(save_path), annotated_image)
|
||||
print(f"已保存: {save_path}")
|
||||
|
||||
# --- 5b. [新增] 查找并加载 Ground Truth 掩码 ---
|
||||
gt_mask = None
|
||||
try:
|
||||
# 假设的路径结构: .../DataSet_Public/5_.../images/val/image.png
|
||||
# 我们要找: .../DataSet_Public/5_.../labels_GT/val/image.png
|
||||
image_filename = original_image_path.name
|
||||
image_parent_dir_name = original_image_path.parent.name # 'val' or 'train'
|
||||
images_dir = original_image_path.parent.parent # '.../images'
|
||||
dataset_root_dir = images_dir.parent # '.../5_My_Gastric_2025_10_29'
|
||||
|
||||
gt_path = dataset_root_dir / "labels_GT" / image_parent_dir_name / image_filename
|
||||
|
||||
if not gt_path.exists():
|
||||
logging.warning(f" -> 未找到 Ground Truth 文件: {gt_path}")
|
||||
else:
|
||||
gt_mask = cv2.imread(str(gt_path), cv2.IMREAD_GRAYSCALE)
|
||||
if gt_mask is None:
|
||||
logging.warning(f" -> 读取 Ground Truth 文件失败: {gt_path}")
|
||||
elif gt_mask.shape != (h, w):
|
||||
logging.warning(f" -> GT 掩码形状 {gt_mask.shape} 与图像形状 {(h, w)} 不匹配。正在调整GT大小...")
|
||||
gt_mask = cv2.resize(gt_mask, (w, h), interpolation=cv2.INTER_NEAREST)
|
||||
except Exception as e:
|
||||
logging.error(f" -> 查找或加载GT文件时出错: {e}", exc_info=True)
|
||||
gt_mask = None
|
||||
|
||||
# --- 5c. 创建预测掩码 (与之前相同) ---
|
||||
combined_mask = np.zeros((h, w), dtype=np.uint8)
|
||||
|
||||
if result.masks is None:
|
||||
logging.warning(f" -> 在图片 {original_image_path.name} 中未检测到任何物体。")
|
||||
else:
|
||||
masks_data = result.masks.data
|
||||
class_ids = result.boxes.cls.int().cpu().numpy()
|
||||
|
||||
if len(masks_data) != len(class_ids):
|
||||
logging.error(f" -> 掩码和类别ID数量不匹配,跳过图片 {original_image_path.name}")
|
||||
continue
|
||||
|
||||
for i in reversed(range(len(masks_data))):
|
||||
instance_mask = masks_data[i].cpu().numpy().astype(bool)
|
||||
class_id = class_ids[i]
|
||||
# [调整] 根据标志执行不同逻辑
|
||||
if perform_id_shift:
|
||||
# 偏移模式:将所有 ID (包括0) 都 +1
|
||||
combined_mask[instance_mask] = class_id + 1
|
||||
else:
|
||||
# 原始模式:忽略 ID 0
|
||||
if class_id != 0:
|
||||
combined_mask[instance_mask] = class_id
|
||||
|
||||
# --- 5d. [扩展] 保存所有三种输出 ---
|
||||
|
||||
# --- 保存 1: 原始灰度掩码 (预测) ---
|
||||
raw_mask_filename = raw_mask_save_dir / f"{original_image_path.stem}_{pt_name_raw}.png"
|
||||
cv2.imwrite(str(raw_mask_filename), combined_mask)
|
||||
|
||||
# --- 保存 2: 彩色预测掩码 ---
|
||||
colorized_pred_mask = color_lut[combined_mask] # (H, W, 3)
|
||||
color_mask_filename = color_mask_save_dir / f"{original_image_path.stem}_{pt_name_raw}.png"
|
||||
cv2.imwrite(str(color_mask_filename), colorized_pred_mask)
|
||||
|
||||
# --- [新增] 保存 3: 三合一对比图 ---
|
||||
try:
|
||||
# 1. 获取原始图
|
||||
original_image_bgr = result.orig_img # YOLO 结果中自带 BGR 格式原图
|
||||
|
||||
# 2. 获取彩色 GT 图
|
||||
if gt_mask is not None:
|
||||
colorized_gt_mask = color_lut[gt_mask]
|
||||
else:
|
||||
# 如果GT不存在,创建一个黑色占位图
|
||||
logging.warning(f" -> 在对比图中将使用黑色图像作为 Ground Truth 占位符。")
|
||||
colorized_gt_mask = np.zeros_like(original_image_bgr) # (H, W, 3)
|
||||
|
||||
# 3. 获取彩色预测图 (已在上面生成: colorized_pred_mask)
|
||||
|
||||
# 确保所有图像形状一致 (理论上应该一致)
|
||||
if not (original_image_bgr.shape == colorized_gt_mask.shape == colorized_pred_mask.shape):
|
||||
logging.error(f" -> 形状不匹配! 原图: {original_image_bgr.shape}, GT: {colorized_gt_mask.shape}, 预测: {colorized_pred_mask.shape}。跳过对比图。")
|
||||
continue
|
||||
|
||||
# 为每张图添加标签
|
||||
img_labeled = add_label_to_image(original_image_bgr, "Original")
|
||||
gt_labeled = add_label_to_image(colorized_gt_mask, "Ground Truth")
|
||||
pred_labeled = add_label_to_image(colorized_pred_mask, "Prediction")
|
||||
|
||||
# 水平拼接
|
||||
comparison_image = cv2.hconcat([img_labeled, gt_labeled, pred_labeled])
|
||||
|
||||
# 保存 (使用 .jpg 格式以节省空间)
|
||||
comparison_filename = comparison_save_dir / f"{original_image_path.stem}_{pt_name_raw}.jpg"
|
||||
cv2.imwrite(str(comparison_filename), comparison_image)
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f" -> 创建或保存对比图失败: {e}", exc_info=True)
|
||||
|
||||
|
||||
logging.info(f"所有掩码图和对比图已成功保存。")
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"预测或手动保存过程中发生错误: {e}", exc_info=True)
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 1. 创建解析器,现在只需要 --model 和 --source
|
||||
parser = argparse.ArgumentParser(description="使用已训练的YOLO模型进行预测。")
|
||||
|
||||
parser.add_argument(
|
||||
"--model",
|
||||
type=str,
|
||||
required=True,
|
||||
choices=list(config.MODEL_CONFIGS.keys()),
|
||||
help="选择一个基础模型类型,以筛选其训练历史。"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--source",
|
||||
type=str,
|
||||
default=str(config.TEST_IMAGE_DIR),
|
||||
help="图片或图片文件夹的路径。"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pt_name",
|
||||
type=str,
|
||||
default=str("best.pt"),
|
||||
help="图片或图片文件夹的路径。"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--conf",
|
||||
type=float, # 1. 类型应为浮点数
|
||||
default=0.2,
|
||||
help="设置预测的置信度阈值 (例如: 0.25)" # 2. 补充 help 信息
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
# 2. 根据 --model 参数查找对应的训练历史
|
||||
available_runs = find_trained_models(config.PREDICT_BEST_MODEL_DIR, args.model, args.pt_name)
|
||||
|
||||
run_to_use = None
|
||||
|
||||
# 3. 根据找到的结果数量,决定下一步操作
|
||||
if not available_runs:
|
||||
logging.error(f"错误:在 {config.PREDICT_BEST_MODEL_DIR} 中没有找到任何关于模型 '{args.model}' 的有效训练记录。")
|
||||
sys.exit(1)
|
||||
|
||||
elif len(available_runs) == 1:
|
||||
# 如果只有一个结果,自动选择
|
||||
run_to_use = available_runs[0]
|
||||
logging.info(f"只找到一个训练版本,已自动选择: {run_to_use}")
|
||||
|
||||
else:
|
||||
# 如果有多个结果,进入交互式选择模式
|
||||
print(f"\n为模型 '{args.model}' 找到多个训练版本:")
|
||||
for i, run_name in enumerate(available_runs, 1):
|
||||
print(f" [{i}] {run_name}")
|
||||
|
||||
while True:
|
||||
try:
|
||||
choice = input(f"请输入您想使用的版本序号 (1-{len(available_runs)}): ")
|
||||
choice_index = int(choice)
|
||||
if 1 <= choice_index <= len(available_runs):
|
||||
run_to_use = available_runs[choice_index - 1]
|
||||
break
|
||||
else:
|
||||
print("错误:输入无效,请输入列表中的序号。")
|
||||
except ValueError:
|
||||
print("错误:请输入一个数字。")
|
||||
except (KeyboardInterrupt, EOFError):
|
||||
print("\n操作已取消。")
|
||||
sys.exit(0)
|
||||
|
||||
# 4. 使用最终确定的 run_to_use 来执行预测
|
||||
if run_to_use:
|
||||
model_to_use = config.PREDICT_BEST_MODEL_DIR / run_to_use / 'weights' / args.pt_name
|
||||
project_to_save_in = str(config.PREDICT_BEST_MODEL_DIR / run_to_use)
|
||||
|
||||
if not model_to_use.exists():
|
||||
logging.error(f"严重错误:找不到权重文件 {model_to_use}。")
|
||||
else:
|
||||
predict(
|
||||
model_path=str(model_to_use),
|
||||
source=args.source,
|
||||
project_name=project_to_save_in,
|
||||
pt_name = args.pt_name,
|
||||
conf_threshold=args.conf
|
||||
)
|
||||
383
Seg_All_In_One_YoloModel/yolo_predict_V2_compare_all.py
Normal file
383
Seg_All_In_One_YoloModel/yolo_predict_V2_compare_all.py
Normal file
@@ -0,0 +1,383 @@
|
||||
import cv2
|
||||
import numpy as np
|
||||
import yaml
|
||||
import logging
|
||||
import argparse
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from tqdm import tqdm
|
||||
|
||||
# --- 日志设置 ---
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s [%(levelname)s] %(message)s",
|
||||
handlers=[logging.StreamHandler()]
|
||||
)
|
||||
|
||||
try:
|
||||
import yolo_config as config
|
||||
logging.info(f"成功加载配置 yolo_config.py")
|
||||
logging.info(f" -> 最佳模型目录: {config.PREDICT_BEST_MODEL_DIR}")
|
||||
logging.info(f" -> 数据集 YAML: {config.DATASET_YAML_PATH}")
|
||||
logging.info(f" -> 测试图片目录: {config.TEST_IMAGE_DIR}")
|
||||
except ImportError:
|
||||
logging.error("错误: yolo_config.py 未找到。请确保它在同一目录下。")
|
||||
sys.exit(1)
|
||||
except AttributeError as e:
|
||||
logging.error(f"错误: yolo_config.py 中缺少必要的配置。 {e}")
|
||||
sys.exit(1)
|
||||
|
||||
# --- 辅助函数: 1. 添加标签 ---
|
||||
def add_label_to_image(img, label):
|
||||
"""在图像左上角添加一个带背景的标签"""
|
||||
img_with_label = img.copy()
|
||||
h, w = img_with_label.shape[:2]
|
||||
# 动态调整字体大小和粗细
|
||||
font_scale = max(0.8, w // 1000)
|
||||
thickness = max(1, w // 500)
|
||||
|
||||
(text_w, text_h), baseline = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, font_scale, thickness)
|
||||
|
||||
# 绘制黑色背景矩形
|
||||
cv2.rectangle(img_with_label, (0, 0), (text_w + 20, text_h + 30), (0, 0, 0), -1)
|
||||
# 绘制白色文字
|
||||
cv2.putText(img_with_label, label, (10, text_h + 15), cv2.FONT_HERSHEY_SIMPLEX, font_scale, (255, 255, 255), thickness, cv2.LINE_AA)
|
||||
return img_with_label
|
||||
|
||||
# --- 辅助函数: 2. 加载颜色LUT ---
|
||||
def load_color_lut(dataset_yaml_path):
|
||||
"""
|
||||
加载颜色查找表 (LUT),精确复制 yolo_predict_V2.py 的逻辑
|
||||
"""
|
||||
try:
|
||||
with open(dataset_yaml_path, 'r', encoding='utf-8') as f:
|
||||
yaml_data = yaml.safe_load(f)
|
||||
class_names = yaml_data.get('names')
|
||||
color_data_from_dataset = yaml_data.get('colors')
|
||||
|
||||
if not class_names:
|
||||
logging.error(f"在 {dataset_yaml_path} 中未找到 'names' 键。")
|
||||
return None
|
||||
|
||||
num_classes = len(class_names)
|
||||
logging.info(f"检测到 {num_classes} 个类别 (ID 0 到 {num_classes-1})。")
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"加载或解析 dataset.yaml 失败: {e}")
|
||||
return None
|
||||
|
||||
# 创建 (N) 维临时 LUT
|
||||
temp_lut = np.zeros((num_classes, 3), dtype=np.uint8)
|
||||
|
||||
if color_data_from_dataset:
|
||||
logging.info(f"正在从 {dataset_yaml_path} 加载颜色...")
|
||||
try:
|
||||
for class_id, rgb_value in color_data_from_dataset.items():
|
||||
if 0 <= class_id < num_classes:
|
||||
temp_lut[class_id] = rgb_value[::-1] # BGR
|
||||
logging.info("临时颜色查找表 (LUT) 已成功从 dataset.yaml 创建。")
|
||||
except Exception as e:
|
||||
logging.error(f"解析 'colors' 键失败: {e}。将尝试备用方案。")
|
||||
color_data_from_dataset = None
|
||||
|
||||
if not color_data_from_dataset:
|
||||
color_map_path = 'color_map.yaml'
|
||||
logging.warning(f"未从 dataset.yaml 加载颜色,正在尝试备用方案: {color_map_path}")
|
||||
try:
|
||||
with open(color_map_path, 'r', encoding='utf-8') as f:
|
||||
color_data = yaml.safe_load(f)['colors']
|
||||
for class_id, bgr_value in color_data.items():
|
||||
if 0 <= class_id < num_classes:
|
||||
temp_lut[class_id] = bgr_value # 假设 color_map.yaml 存的是 BGR
|
||||
logging.info(f"颜色查找表 (LUT) 已成功从 {color_map_path} 创建。")
|
||||
except Exception:
|
||||
logging.warning(f"{color_map_path} 未找到或解析失败。将使用随机颜色。")
|
||||
for i in range(1, num_classes):
|
||||
temp_lut[i] = np.random.randint(0, 255, 3, dtype=np.uint8)
|
||||
|
||||
# --- 核心条件判断 (与 yolo_predict_V2.py 完全一致) ---
|
||||
if np.any(temp_lut[0] != [0, 0, 0]):
|
||||
logging.info("检测到 类别0 的颜色不是黑色。将创建 (N+1) LUT,使 0 成为专用背景。")
|
||||
color_lut = np.zeros((num_classes + 1, 3), dtype=np.uint8)
|
||||
color_lut[1:] = temp_lut
|
||||
else:
|
||||
logging.info("检测到 类别0 的颜色是黑色。将使用 (N) LUT。")
|
||||
color_lut = temp_lut
|
||||
|
||||
logging.info(f"最终颜色LUT创建成功,维度: {color_lut.shape}")
|
||||
return color_lut
|
||||
|
||||
# --- 辅助函数: 3. 查找模型目录 ---
|
||||
def find_model_dirs(base_dir, required_subdir):
|
||||
"""查找所有包含 'required_subdir' 子目录的有效模型文件夹"""
|
||||
model_dirs = []
|
||||
if not base_dir.is_dir():
|
||||
logging.error(f"基础目录不存在: {base_dir}")
|
||||
return []
|
||||
|
||||
for d in base_dir.iterdir():
|
||||
# 必须是一个目录,并且包含所需的预测结果子目录
|
||||
if d.is_dir() and (d / required_subdir).exists():
|
||||
model_dirs.append(d)
|
||||
|
||||
model_dirs.sort()
|
||||
return model_dirs
|
||||
|
||||
# --- 辅助函数: 4. 查找图像名称 ---
|
||||
def find_image_names(model_dirs, sub_dir_name, file_glob_patterns):
|
||||
"""
|
||||
从第一个模型目录的指定子目录(sub_dir_name)中获取所有待处理的图像文件名
|
||||
使用 file_glob_patterns 列表 (例如 ['*.jpg', '*.png'])
|
||||
"""
|
||||
image_names = set() # 使用集合避免重复
|
||||
if not model_dirs:
|
||||
return []
|
||||
|
||||
first_model_target_dir = model_dirs[0] / sub_dir_name
|
||||
|
||||
for pattern in file_glob_patterns:
|
||||
for f in first_model_target_dir.glob(pattern):
|
||||
if f.is_file():
|
||||
image_names.add(f.name)
|
||||
|
||||
sorted_names = sorted(list(image_names))
|
||||
return sorted_names
|
||||
|
||||
# --- 辅助函数: 5. 获取原图和GT路径 ---
|
||||
def get_gt_and_original_paths(image_name):
|
||||
"""
|
||||
根据输出的掩码文件名,反向推导原始图像和GT图像的路径。
|
||||
假设文件名格式为: {original_stem}_{pt_name_raw}.png
|
||||
"""
|
||||
try:
|
||||
# 1. 从 config 中获取关键路径
|
||||
test_image_dir = config.TEST_IMAGE_DIR
|
||||
images_dir = test_image_dir.parent
|
||||
dataset_root_dir = images_dir.parent
|
||||
image_parent_dir_name = test_image_dir.name # e.g., 'val'
|
||||
|
||||
# 2. 推导原始图像的 stem
|
||||
if '_' not in image_name:
|
||||
logging.warning(f" -> 图像 {image_name} 似乎不含 '_{{pt_name}}' 后缀,将尝试使用完整文件名作为 stem。")
|
||||
original_stem = Path(image_name).stem
|
||||
else:
|
||||
# 从最后一个 '_' 拆分
|
||||
original_stem = image_name[:image_name.rfind('_')]
|
||||
|
||||
|
||||
# 3. 在 test_image_dir 中查找原始图像
|
||||
original_image_path = None
|
||||
# glob 查找, 匹配 .png, .jpg, .bmp 等
|
||||
for f in test_image_dir.glob(f"{original_stem}.*"):
|
||||
if f.is_file() and f.suffix.lower() in ['.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff']:
|
||||
original_image_path = f
|
||||
break # 找到第一个匹配项
|
||||
|
||||
if not original_image_path:
|
||||
# 尝试在上一级目录查找 (兼容某些奇怪的 val/test 结构)
|
||||
for f in test_image_dir.parent.glob(f"{original_stem}.*"):
|
||||
if f.is_file() and f.suffix.lower() in ['.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff']:
|
||||
original_image_path = f
|
||||
break
|
||||
|
||||
if not original_image_path:
|
||||
logging.warning(f" -> 在 {test_image_dir} (及其父目录) 中未找到原始图像 (如 {original_stem}.*)")
|
||||
return None, None
|
||||
|
||||
# 4. 推导GT路径 (逻辑复制自 yolo_predict_V2.py)
|
||||
image_filename = original_image_path.name
|
||||
|
||||
gt_path = dataset_root_dir / "labels_GT" / image_parent_dir_name / image_filename
|
||||
gt_path_png = gt_path.with_suffix('.png') # 备用
|
||||
|
||||
if gt_path.exists():
|
||||
return original_image_path, gt_path
|
||||
elif gt_path_png.exists():
|
||||
return original_image_path, gt_path_png
|
||||
else:
|
||||
# 返回预期的路径,主循环将处理 "文件未找到"
|
||||
return original_image_path, gt_path
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f" -> 推导路径时出错: {e}", exc_info=True)
|
||||
return None, None
|
||||
|
||||
|
||||
# --- [新增] 核心处理函数 ---
|
||||
def generate_comparison_images(base_dir, output_dir_name, source_subdir_name,
|
||||
gt_panel_label, use_gt_overlay,
|
||||
file_glob_patterns, color_lut):
|
||||
"""
|
||||
核心处理函数,用于生成一种类型的对比图。
|
||||
"""
|
||||
logging.info(f"--- 🚀 开始生成 '{output_dir_name}' ---")
|
||||
|
||||
output_dir = base_dir / output_dir_name
|
||||
output_dir.mkdir(exist_ok=True)
|
||||
|
||||
# 3. 查找模型目录
|
||||
model_dirs = find_model_dirs(base_dir, source_subdir_name)
|
||||
if not model_dirs:
|
||||
logging.error(f"在 {base_dir} 中未找到任何包含 '{source_subdir_name}' 的有效模型预测目录。")
|
||||
logging.warning(f"--- ⚠️ 跳过 '{output_dir_name}' 的生成 ---")
|
||||
return
|
||||
logging.info(f"找到 {len(model_dirs)} 个模型进行对比:")
|
||||
for d in model_dirs:
|
||||
logging.info(f" -> {d.name}")
|
||||
|
||||
# 4. 查找图像列表
|
||||
image_names = find_image_names(model_dirs, source_subdir_name, file_glob_patterns)
|
||||
if not image_names:
|
||||
logging.error(f"未找到任何匹配 {file_glob_patterns} 的图像文件。")
|
||||
logging.error(f"请检查 {model_dirs[0] / source_subdir_name} 目录。")
|
||||
logging.warning(f"--- ⚠️ 跳过 '{output_dir_name}' 的生成 ---")
|
||||
return
|
||||
logging.info(f"找到 {len(image_names)} 张图像进行处理。")
|
||||
|
||||
# 5. --- 遍历每张图像 ---
|
||||
for image_name in tqdm(image_names, desc=f"生成 {output_dir_name}"):
|
||||
|
||||
comparison_panels = []
|
||||
|
||||
# 5a. 获取路径
|
||||
orig_path, gt_path = get_gt_and_original_paths(image_name)
|
||||
|
||||
if not orig_path:
|
||||
logging.warning(f"\n跳过: 未能找到 {image_name} 的原始图像。")
|
||||
continue
|
||||
|
||||
# 5b. 加载原图
|
||||
orig_img = cv2.imread(str(orig_path))
|
||||
if orig_img is None:
|
||||
logging.warning(f"\n跳过: 无法读取原始图像 {orig_path}")
|
||||
continue
|
||||
|
||||
h, w = orig_img.shape[:2]
|
||||
|
||||
# 1. ORI_PIC
|
||||
comparison_panels.append(add_label_to_image(orig_img, "ORI_PIC"))
|
||||
|
||||
# 5c. 加载 Ground Truth
|
||||
gt_color = np.zeros_like(orig_img) # 默认黑色
|
||||
if gt_path and gt_path.exists():
|
||||
gt_gray = cv2.imread(str(gt_path), cv2.IMREAD_GRAYSCALE)
|
||||
if gt_gray is not None:
|
||||
if gt_gray.shape != (h, w):
|
||||
gt_gray = cv2.resize(gt_gray, (w, h), interpolation=cv2.INTER_NEAREST)
|
||||
gt_color = color_lut[gt_gray]
|
||||
else:
|
||||
logging.warning(f"\n无法读取GT图像 {gt_path},使用黑色占位符。")
|
||||
else:
|
||||
logging.warning(f"\nGT图像未找到 (预期路径: {gt_path}),使用黑色占位符。")
|
||||
|
||||
# 2. GT Panel (Mask or Overlay)
|
||||
if use_gt_overlay:
|
||||
gt_panel = cv2.addWeighted(orig_img, 0.6, gt_color, 0.4, 0)
|
||||
else:
|
||||
gt_panel = gt_color
|
||||
comparison_panels.append(add_label_to_image(gt_panel, gt_panel_label))
|
||||
|
||||
# 5d. 加载每个模型的预测结果
|
||||
for model_dir in model_dirs:
|
||||
pred_path = model_dir / source_subdir_name / image_name
|
||||
pred_img = np.zeros_like(orig_img) # 默认黑色
|
||||
|
||||
if pred_path.exists():
|
||||
img = cv2.imread(str(pred_path))
|
||||
if img is not None:
|
||||
if img.shape != (h, w, 3):
|
||||
logging.warning(f"\n预测图 {pred_path.name} 尺寸 {img.shape} 与原图 {(h,w,3)} 不匹配。正在调整大小...")
|
||||
img = cv2.resize(img, (w, h), interpolation=cv2.INTER_AREA)
|
||||
pred_img = img
|
||||
else:
|
||||
logging.warning(f"\n无法读取预测图 {pred_path},使用黑色占位符。")
|
||||
else:
|
||||
logging.warning(f"\n预测图未找到 {pred_path},使用黑色占位符。")
|
||||
|
||||
short_label = model_dir.name.split('_')[0]
|
||||
comparison_panels.append(add_label_to_image(pred_img, short_label))
|
||||
|
||||
# 5e. 拼接并保存
|
||||
try:
|
||||
final_image = cv2.hconcat(comparison_panels)
|
||||
save_stem = Path(image_name).stem
|
||||
save_path = output_dir / (save_stem + ".png")
|
||||
cv2.imwrite(str(save_path), final_image)
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"\n--- 拼接或保存 {image_name} 失败: {e} ---")
|
||||
logging.error(f"面板尺寸 (H, W, C): {[p.shape for p in comparison_panels]}")
|
||||
logging.warning("请检查所有图像是否具有相同的高度。")
|
||||
|
||||
logging.info(f"--- ✅ '{output_dir_name}' 生成完毕 ---")
|
||||
|
||||
|
||||
# --- [修改] 主函数 (重构) ---
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="将所有模型的预测结果 (Mask 和 Yolo) 与原图和GT图拼接成对比图。"
|
||||
)
|
||||
# 模式参数已移除
|
||||
parser.add_argument(
|
||||
"--pt_name",
|
||||
type=str,
|
||||
default="all",
|
||||
help="您想比较的权重文件名称 (例如 'best.pt' 或 'epoch100.pt','all'表示对全部权重进行处理)"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
base_dir = config.PREDICT_BEST_MODEL_DIR
|
||||
|
||||
# 1. 加载颜色LUT (只需一次)
|
||||
color_lut = load_color_lut(config.DATASET_YAML_PATH)
|
||||
if color_lut is None:
|
||||
logging.error("无法加载颜色LUT,脚本终止。")
|
||||
return
|
||||
|
||||
# 2. 定义通用文件模式
|
||||
pt_name_raw = args.pt_name.replace('.pt', '')
|
||||
all_image_types = ["*.jpg", "*.jpeg", "*.png", "*.bmp", "*.tif", "*.tiff"]
|
||||
|
||||
# 3. --- 执行 "Mask" 模式 ---
|
||||
mask_patterns = []
|
||||
if args.pt_name.lower() == "all":
|
||||
mask_patterns = ["*.png"]
|
||||
else:
|
||||
mask_patterns = [f"*_{pt_name_raw}.png"]
|
||||
|
||||
generate_comparison_images(
|
||||
base_dir=base_dir,
|
||||
output_dir_name="compare_all_masks",
|
||||
source_subdir_name="predicted_color_masks",
|
||||
gt_panel_label="Ground_Truth",
|
||||
use_gt_overlay=False,
|
||||
file_glob_patterns=mask_patterns,
|
||||
color_lut=color_lut
|
||||
)
|
||||
|
||||
# 4. --- 执行 "Yolo" 模式 ---
|
||||
yolo_patterns = []
|
||||
if args.pt_name.lower() == "all":
|
||||
yolo_patterns = all_image_types
|
||||
else:
|
||||
yolo_patterns = [
|
||||
f"*_{pt_name_raw}.jpg", f"*_{pt_name_raw}.jpeg",
|
||||
f"*_{pt_name_raw}.png", f"*_{pt_name_raw}.bmp",
|
||||
f"*_{pt_name_raw}.tif", f"*_{pt_name_raw}.tiff"
|
||||
]
|
||||
|
||||
generate_comparison_images(
|
||||
base_dir=base_dir,
|
||||
output_dir_name="compare_all_masks_Yolo",
|
||||
source_subdir_name="prediction",
|
||||
gt_panel_label="GT_Overlay",
|
||||
use_gt_overlay=True,
|
||||
file_glob_patterns=yolo_patterns,
|
||||
color_lut=color_lut
|
||||
)
|
||||
|
||||
logging.info("--- 🏁 所有对比图任务执行完毕 ---")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
233
Seg_All_In_One_YoloModel/yolo_predict_raw_masks_check.py
Normal file
233
Seg_All_In_One_YoloModel/yolo_predict_raw_masks_check.py
Normal file
@@ -0,0 +1,233 @@
|
||||
import sys
|
||||
import yaml
|
||||
from pathlib import Path
|
||||
from typing import Set, Tuple
|
||||
import argparse # <-- [新增]
|
||||
import logging # <-- [新增]
|
||||
|
||||
try:
|
||||
# Import config to get base directories and dataset configuration path
|
||||
import yolo_config as config
|
||||
except ImportError:
|
||||
print("错误:无法导入 'yolo_config.py'。请确保此脚本与 config.py 在同一目录下。")
|
||||
sys.exit(1)
|
||||
|
||||
def get_file_info(directory: Path) -> Tuple[Set[str], Set[str]]:
|
||||
"""
|
||||
安全地获取目录中所有文件的基本名(stem)和后缀(suffix)。
|
||||
|
||||
返回:
|
||||
一个包含 stems 集合和 suffixes 集合的元组。
|
||||
"""
|
||||
if not directory.is_dir():
|
||||
return set(), set()
|
||||
|
||||
stems = set()
|
||||
suffixes = set()
|
||||
for item in directory.iterdir():
|
||||
if item.is_file():
|
||||
stems.add(item.stem)
|
||||
suffixes.add(item.suffix)
|
||||
return stems, suffixes
|
||||
|
||||
def get_source_labels_dir() -> Path:
|
||||
"""
|
||||
从 dataset.yaml 解析并构建源标签目录的绝对路径。
|
||||
路径的最后一部分 (例如 'train', 'val' 或 'test') 将根据 dataset.yaml 中的 'test' 键动态确定。
|
||||
"""
|
||||
try:
|
||||
with open(config.DATASET_YAML_PATH, 'r', encoding='utf-8') as f:
|
||||
yaml_data = yaml.safe_load(f)
|
||||
|
||||
# 1. 获取数据集的根目录相对路径
|
||||
relative_path_from_yaml = yaml_data.get('path')
|
||||
if not relative_path_from_yaml:
|
||||
print(f"错误: 在 '{config.DATASET_YAML_PATH}' 文件中没有找到 'path' 键。")
|
||||
sys.exit(1)
|
||||
|
||||
# 2. 【新增】获取 'test' 键的值 (例如 'images/val')
|
||||
test_path_from_yaml = yaml_data.get('test')
|
||||
if not test_path_from_yaml:
|
||||
print(f"错误: 在 '{config.DATASET_YAML_PATH}' 文件中没有找到 'test' 键。")
|
||||
sys.exit(1)
|
||||
|
||||
# 3. 【新增】从 'images/val' 中提取最后一部分 'val'
|
||||
# Path(...).name 可以轻松实现这个功能
|
||||
source_split = Path(test_path_from_yaml).name
|
||||
|
||||
# 4. 构建数据集的绝对路径
|
||||
dataset_dir = (config.DATASET_YAML_PATH.parent / relative_path_from_yaml).resolve()
|
||||
|
||||
# 5. 【已修改】使用动态提取的 source_split 构建并返回标签目录路径
|
||||
labels_dir = dataset_dir / "labels" / source_split
|
||||
return labels_dir
|
||||
|
||||
except FileNotFoundError:
|
||||
print(f"错误: YAML 配置文件未找到: '{config.DATASET_YAML_PATH}'")
|
||||
sys.exit(1)
|
||||
except Exception as e:
|
||||
print(f"读取或解析 YAML 文件时发生错误: {e}")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def main():
|
||||
"""
|
||||
主函数,用于详细检查源训练标签集和所有预测输出的原始掩码
|
||||
目录之间的文件数量和文件名一致性。
|
||||
"""
|
||||
# --- [新增] ---
|
||||
parser = argparse.ArgumentParser(description="检查预测掩码与源标签的一致性。")
|
||||
parser.add_argument(
|
||||
"--pt_name",
|
||||
type=str,
|
||||
default="best.pt",
|
||||
help="要检查的权重文件名称 (例如 'best.pt' 或 'epoch100.pt')。脚本将查找以此为后缀的掩码文件。"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
# 从 'best.pt' 提取 'best'
|
||||
pt_name_raw = Path(args.pt_name).stem
|
||||
# 构建预期的后缀, e.g., '_best'
|
||||
expected_suffix = f"_{pt_name_raw}"
|
||||
|
||||
# [核心修改] 只有 'best.pt' 会兼容检查无后缀的传统文件
|
||||
check_traditional = (args.pt_name == "best.pt")
|
||||
|
||||
print(f"--- 开始详细检查预测原始掩码 (predicted_raw_masks) ---")
|
||||
if check_traditional:
|
||||
print(f"[i] 本次检查目标: --pt_name='{args.pt_name}' (检查后缀 '{expected_suffix}' 或 无后缀的传统文件)")
|
||||
else:
|
||||
print(f"[i] 本次检查目标: --pt_name='{args.pt_name}' (仅检查后缀 '{expected_suffix}' 的文件)")
|
||||
|
||||
# 1. 获取源标签目录并解析文件信息
|
||||
source_dir = get_source_labels_dir()
|
||||
if not source_dir.exists() or not source_dir.is_dir():
|
||||
print(f"错误:在 '{source_dir}' 找不到源训练标签目录。请检查 dataset.yaml 中的 'path' 设置。")
|
||||
sys.exit(1)
|
||||
|
||||
source_stems, _ = get_file_info(source_dir)
|
||||
source_file_count = len(source_stems)
|
||||
|
||||
if source_file_count == 0:
|
||||
print(f"警告:源标签目录 '{source_dir}' 为空或不包含任何文件。检查中止。")
|
||||
sys.exit(0)
|
||||
|
||||
# 在你的日志中,路径是 '1_CholecSeg8k-13Type-1920x1080',但配置文件是 '4_Dresden-11Type-512x512'
|
||||
# 这里我们以代码逻辑为准,它会动态读取配置文件
|
||||
print(f"源标签集: 在 '{source_dir}' 中找到 {source_file_count} 个标签文件 (.txt)。")
|
||||
print("-" * 60)
|
||||
|
||||
# 2. 获取预测结果的基础目录
|
||||
predictions_base_dir = config.PREDICT_BEST_MODEL_DIR
|
||||
if not predictions_base_dir.exists() or not predictions_base_dir.is_dir():
|
||||
print(f"错误:在 '{predictions_base_dir}' 找不到预测结果的基础目录。")
|
||||
sys.exit(1)
|
||||
|
||||
# 3. 遍历每个模型的运行目录并进行详细检查
|
||||
found_predictions = False
|
||||
global_mismatch_found = False
|
||||
|
||||
# 【已修正】直接查找所有模型的运行目录,例如 'YOLOv8n-seg_2025-09-20_10-00-00'
|
||||
run_dirs = sorted([d for d in predictions_base_dir.iterdir() if d.is_dir()])
|
||||
|
||||
if not run_dirs:
|
||||
print(f"信息:在基础目录 '{predictions_base_dir}' 中没有找到任何模型的预测结果文件夹。")
|
||||
sys.exit(0)
|
||||
|
||||
# 【已修正】移除外层循环,直接遍历 run_dirs
|
||||
for run_dir in run_dirs:
|
||||
target_dir = run_dir / "predicted_raw_masks"
|
||||
|
||||
# 如果目标目录不存在,则跳过
|
||||
if not (target_dir.exists() and target_dir.is_dir()):
|
||||
continue
|
||||
|
||||
print(f"正在检查: '{target_dir}'...")
|
||||
found_predictions = True
|
||||
is_current_dir_ok = True
|
||||
|
||||
# 获取目录中所有的stems
|
||||
all_predicted_stems, _ = get_file_info(target_dir)
|
||||
|
||||
# --- [核心修改] ---
|
||||
# 1. 提取我们关心的stems
|
||||
processed_stems = set() # 存储处理后 (去掉后缀) 的stems, e.g., {'fileA', 'fileB'}
|
||||
stems_we_checked = set() # 存储我们检查过的原始stems, e.g., {'fileA_best', 'fileA'}
|
||||
|
||||
for stem in all_predicted_stems:
|
||||
# 检查 1: 是否为带后缀的新模式 (e.g., "fileA_best" 或 "fileA_epoch100")
|
||||
if stem.endswith(expected_suffix):
|
||||
# 去掉后缀, e.g., 'fileA_best' -> 'fileA'
|
||||
unstripped_stem = stem[:-len(expected_suffix)]
|
||||
|
||||
# 仅当去掉后缀后的部分在源文件中时才添加
|
||||
if unstripped_stem in source_stems:
|
||||
processed_stems.add(unstripped_stem)
|
||||
stems_we_checked.add(stem)
|
||||
|
||||
# 检查 2: [条件] 是否允许检查传统文件 (check_traditional == True)
|
||||
# [条件] 且该文件是否为无后缀的传统文件 (e.g., "fileA")
|
||||
elif check_traditional and (stem in source_stems):
|
||||
# 兼容 'best.pt' 检查传统文件
|
||||
processed_stems.add(stem)
|
||||
stems_we_checked.add(stem)
|
||||
|
||||
# 2. 现在,'processed_stems' 只包含与 'source_stems' 对应的基础文件名
|
||||
predicted_file_count = len(processed_stems)
|
||||
|
||||
# --- 检查 1: 文件数量 ---
|
||||
if source_file_count != predicted_file_count:
|
||||
print(f" [✗ 数量不匹配] 源目录有 {source_file_count} 个文件,但此目录中(匹配检查规则的)有 {predicted_file_count} 个。")
|
||||
is_current_dir_ok = False
|
||||
|
||||
# --- 检查 2: 文件名 (基于处理后的stems) ---
|
||||
missing_files = source_stems - processed_stems
|
||||
if missing_files:
|
||||
examples = list(missing_files)[:3]
|
||||
print(f" [✗ 文件名缺失] 预测结果中缺少 {len(missing_files)} 个文件 (匹配检查规则)。例如: {examples}...")
|
||||
is_current_dir_ok = False
|
||||
|
||||
extra_files = processed_stems - source_stems
|
||||
if extra_files:
|
||||
examples = list(extra_files)[:3]
|
||||
print(f" [✗ 文件名多余] (逻辑异常) 预测结果中多出 {len(extra_files)} 个文件。例如: {examples}...")
|
||||
is_current_dir_ok = False
|
||||
|
||||
# --- [新增] 检查 3: 提示其他被忽略的文件 ---
|
||||
other_files = all_predicted_stems - stems_we_checked
|
||||
if other_files:
|
||||
examples = list(other_files)[:3]
|
||||
print(f" [i 信息] 目录中还包含 {len(other_files)} 个其他文件 (例如: {examples}...)。")
|
||||
if check_traditional:
|
||||
print(f" [i 信息] (这些文件不匹配后缀 '{expected_suffix}' 且不属于无后缀传统文件,已忽略)")
|
||||
else:
|
||||
print(f" [i 信息] (这些文件不匹配后缀 '{expected_suffix}',已忽略)")
|
||||
|
||||
# --- 当前目录的检查总结 ---
|
||||
if is_current_dir_ok:
|
||||
# 再次检查目录为空的特殊情况
|
||||
if predicted_file_count == 0 and source_file_count > 0:
|
||||
print(f" [✗ 目录为空] 预测目录中没有找到任何匹配本次检查规则的文件。")
|
||||
global_mismatch_found = True
|
||||
else:
|
||||
print(f" [✓ OK] 所有检查通过 (文件数量: {predicted_file_count},匹配检查规则)")
|
||||
else:
|
||||
global_mismatch_found = True
|
||||
|
||||
print("-" * 30)
|
||||
|
||||
# 4. 输出最终的全局检查摘要
|
||||
print("\n--- 检查摘要 ---")
|
||||
if not found_predictions:
|
||||
print("结果: 没有找到任何 'predicted_raw_masks' 目录进行检查。")
|
||||
elif global_mismatch_found:
|
||||
print("结论: 检查不通过。发现至少一个预测目录未能与源标签集完全匹配,请查看上方日志。")
|
||||
sys.exit(1)
|
||||
else:
|
||||
print("结论: 检查通过!所有已找到的 'predicted_raw_masks' 目录均通过了与源标签集的文件名和数量一致性检查。")
|
||||
|
||||
print("--- 检查完成 ---")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
538
Seg_All_In_One_YoloModel/yolo_predict_visualize_nn.py
Normal file
538
Seg_All_In_One_YoloModel/yolo_predict_visualize_nn.py
Normal file
@@ -0,0 +1,538 @@
|
||||
import logging
|
||||
import sys
|
||||
import argparse
|
||||
import shutil
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import List, Dict, Any
|
||||
from ultralytics import YOLO
|
||||
|
||||
# --- [!! 关键 !!] 动态添加项目根目录 (来自之前的讨论) ---
|
||||
script_path = Path(__file__).resolve()
|
||||
project_root = script_path.parent.parent
|
||||
sys.path.insert(0, str(project_root))
|
||||
# ----------------------------------------------------
|
||||
import yolo_config as config
|
||||
|
||||
# --- [新增] 导入 V2 所需的库 ---
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
from pytorch_grad_cam import (
|
||||
GradCAM, GradCAMPlusPlus, XGradCAM, EigenCAM,
|
||||
HiResCAM, LayerCAM, RandomCAM, EigenGradCAM
|
||||
)
|
||||
from pytorch_grad_cam.utils.image import show_cam_on_image
|
||||
from pytorch_grad_cam.base_cam import BaseCAM
|
||||
|
||||
# --- 日志设置 ---
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s [%(levelname)s] %(message)s",
|
||||
handlers=[logging.StreamHandler()]
|
||||
)
|
||||
|
||||
# --- [保留] V2 中的 CAM 方法字典 ---
|
||||
CAM_METHODS = {
|
||||
"GradCAM": GradCAM, # ok # 速度快
|
||||
"GradCAMPlusPlus": GradCAMPlusPlus, # ok # 速度快
|
||||
# "XGradCAM": XGradCAM, # ok # 速度快
|
||||
"EigenCAM": EigenCAM,
|
||||
"HiResCAM": HiResCAM, # ok # 速度快
|
||||
# "LayerCAM": LayerCAM, # ok # 速度快
|
||||
"RandomCAM": RandomCAM, # ok # 速度快
|
||||
# "EigenGradCAM": EigenGradCAM, # 这个耗时长
|
||||
}
|
||||
|
||||
# --- [保留] V2 中的 ActivationMaximizationTarget ---
|
||||
class ActivationMaximizationTarget:
|
||||
def __init__(self, channel=0):
|
||||
self.channel = channel
|
||||
def __call__(self, model_output):
|
||||
if model_output.ndim == 4:
|
||||
return model_output[:, self.channel, :, :].mean()
|
||||
elif model_output.ndim == 3:
|
||||
return model_output[self.channel, :, :].mean()
|
||||
else:
|
||||
raise ValueError(f"Unsupported model_output shape: {model_output.shape}")
|
||||
|
||||
# --- [保留] 图像预处理函数 ---
|
||||
def preprocess_image(bgr_img: np.ndarray, imgsz: int, device: str) -> (torch.Tensor, Dict[str, Any]):
|
||||
# (函数内容与原版相同,为节省篇幅已折叠)
|
||||
orig_h, orig_w = bgr_img.shape[:2]
|
||||
rgb_img = bgr_img[:, :, ::-1]
|
||||
scale = min(imgsz / orig_w, imgsz / orig_h)
|
||||
new_w, new_h = int(orig_w * scale), int(orig_h * scale)
|
||||
resized_image = cv2.resize(rgb_img, (new_w, new_h), interpolation=cv2.INTER_LINEAR)
|
||||
pad_w, pad_h = imgsz - new_w, imgsz - new_h
|
||||
pad_left, pad_right = pad_w // 2, pad_w - pad_w // 2
|
||||
pad_top, pad_bottom = pad_h // 2, pad_h - pad_h // 2
|
||||
padded_image = cv2.copyMakeBorder(resized_image, pad_top, pad_bottom, pad_left, pad_right,
|
||||
cv2.BORDER_CONSTANT, value=(114, 114, 114))
|
||||
img_float = np.float32(padded_image) / 255.0
|
||||
input_tensor = torch.from_numpy(img_float).permute(2, 0, 1).unsqueeze(0).to(device)
|
||||
padding_info = {
|
||||
"pad_top": pad_top, "new_h": new_h, "pad_left": pad_left,
|
||||
"new_w": new_w, "orig_h": orig_h, "orig_w": orig_w
|
||||
}
|
||||
return input_tensor, padding_info
|
||||
|
||||
|
||||
# --- [保留] V1 中的关键修复函数 ---
|
||||
def reshape_transform(outputs):
|
||||
if isinstance(outputs, tuple):
|
||||
return outputs[0]
|
||||
return outputs
|
||||
|
||||
# --- [保留] V1 中的模型查找函数 ---
|
||||
def find_trained_models(outputs_dir: Path, model_key: str, pt_name: str) -> List[str]:
|
||||
# (函数内容与原版相同,为节省篇幅已折叠)
|
||||
trained_models = []
|
||||
if not outputs_dir.is_dir():
|
||||
logging.warning(f"输出目录不存在: {outputs_dir}")
|
||||
return []
|
||||
for project_folder in outputs_dir.iterdir():
|
||||
if project_folder.is_dir() and project_folder.name.startswith(model_key + '_'):
|
||||
if (project_folder / 'weights' / pt_name).exists():
|
||||
trained_models.append(project_folder.name)
|
||||
trained_models.sort()
|
||||
return trained_models
|
||||
|
||||
|
||||
# --- [!! 核心重构 !!] 融合 V1 和 V2 的可视化函数 ---
|
||||
def visualize_nn_comprehensive(
|
||||
model_path: str,
|
||||
source_dir: str,
|
||||
base_save_dir: Path,
|
||||
pt_name: str,
|
||||
cam_method_name: str,
|
||||
target_layer_str: str,
|
||||
model_key: str # [!! 新增 !!] 接收模型 Key (e.g., "YOLO11m-seg")
|
||||
):
|
||||
"""
|
||||
使用 V2 的多方法和 V1 的框架进行 CAM 可视化。
|
||||
"""
|
||||
|
||||
# --- 1. 加载模型 (V1 逻辑) ---
|
||||
try:
|
||||
logging.info(f"正在加载模型: {model_path}")
|
||||
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||
model = YOLO(model_path).to(device)
|
||||
model.model.eval()
|
||||
imgsz = model.model.args['imgsz']
|
||||
logging.info(f"模型加载成功。设备: {device}, 图像尺寸: {imgsz}")
|
||||
except Exception as e:
|
||||
logging.error(f"无法加载模型 {model_path}: {e}", exc_info=True)
|
||||
return
|
||||
|
||||
# --- 2. [!! 关键修改 !!] 解析目标层 ---
|
||||
target_layers_to_process = [] # 存储 (name, layer_module) 元组
|
||||
|
||||
if target_layer_str.strip().lower() == "default":
|
||||
logging.info("检测到 'default' 目标层,正在解析...")
|
||||
|
||||
# 对 YOLOv8 系列模型的 'default' 逻辑 # TODO # 对模型进行修改
|
||||
if model_key.lower().startswith("yolov8"):
|
||||
logging.info(f"检测到 YOLOv8 系列模型 ({model_key})。将使用精选的8个默认层。")
|
||||
default_yolo8_layers = [
|
||||
(1, "model.model.model[1]"), # Layer_01_Conv
|
||||
(4, "model.model.model[4]"), # Layer_04_C2f
|
||||
(9, "model.model.model[9]"), # Layer_09_SPPF
|
||||
(11, "model.model.model[11]"), # Layer_11_Concat
|
||||
(14, "model.model.model[14]"), # Layer_14_Concat
|
||||
(19, "model.model.model[19]"), # Layer_19_Conv
|
||||
(20, "model.model.model[20]"), # Layer_20_Concat
|
||||
(21, "model.model.model[21]") # Layer_21_C2f
|
||||
]
|
||||
|
||||
# 如果使用 EigenCAM 或 EigenGradCAM,则只分析最后两个层
|
||||
if cam_method_name in ["EigenCAM", "EigenGradCAM"]:
|
||||
logging.info(f" -> 使用 {cam_method_name}。将自动裁剪为最后2个默认层。")
|
||||
default_yolo8_layers = default_yolo8_layers[-3:] # (只保留最后两个)
|
||||
|
||||
for idx, path_str in default_yolo8_layers:
|
||||
try:
|
||||
layer = model.model.model[idx]
|
||||
target_layers_to_process.append((path_str, layer))
|
||||
logging.info(f" -> 已添加 YOLOv8 默认层: {path_str} (索引 {idx})")
|
||||
except Exception as e:
|
||||
logging.warning(f" -> [!] 无法添加 YOLOv8 默认层 {path_str}: {e}")
|
||||
|
||||
# 对 YOLOv9 系列模型的 'default' 逻辑 # TODO # 对模型进行修改
|
||||
elif model_key.lower().startswith("yolov9"):
|
||||
logging.info(f"检测到 YOLOv9 系列模型 ({model_key})。将使用精选的8个默认层。")
|
||||
# (这是您新指定的列表)
|
||||
default_yolo9_layers = [
|
||||
(2, "model.model.model[2]"), # Layer_02_Conv
|
||||
(7, "model.model.model[7]"), # Layer_07_RepNCSPELAN4
|
||||
(29, "model.model.model[29]"), # Layer_29_SPPELAN
|
||||
(31, "model.model.model[31]"), # Layer_31_Concat
|
||||
(34, "model.model.model[34]"), # Layer_34_Concat
|
||||
(37, "model.model.model[37]"), # Layer_37_Concat
|
||||
(40, "model.model.model[40]"), # Layer_40_Concat
|
||||
(41, "model.model.model[41]") # Layer_41_RepNCSPELAN4
|
||||
]
|
||||
|
||||
# 如果使用 EigenCAM 或 EigenGradCAM,则只分析最后两个层
|
||||
if cam_method_name in ["EigenCAM", "EigenGradCAM"]:
|
||||
logging.info(f" -> 使用 {cam_method_name}。将自动裁剪为最后2个默认层。")
|
||||
default_yolo9_layers = default_yolo9_layers[-3:] # (只保留最后两个)
|
||||
|
||||
for idx, path_str in default_yolo9_layers:
|
||||
try:
|
||||
layer = model.model.model[idx]
|
||||
target_layers_to_process.append((path_str, layer))
|
||||
logging.info(f" -> 已添加 YOLOv9 默认层: {path_str} (索引 {idx})")
|
||||
except Exception as e:
|
||||
logging.warning(f" -> [!] 无法添加 YOLOv9 默认层 {path_str}: {e}")
|
||||
|
||||
# 对 YOLO11 系列模型的 'default' 逻辑 # TODO # 对模型进行修改
|
||||
elif model_key.lower().startswith("yolo11"):
|
||||
logging.info(f"检测到 YOLO11 系列模型 ({model_key})。将使用精选的8个默认层。")
|
||||
# (索引, 路径字符串)
|
||||
default_yolo11_layers = [
|
||||
(0, "model.model.model[0]"), # Layer_00_Conv
|
||||
(4, "model.model.model[4]"), # Layer_04_C3k2
|
||||
(9, "model.model.model[9]"), # Layer_09_SPPF
|
||||
(12, "model.model.model[12]"), # Layer_12_Concat
|
||||
(15, "model.model.model[15]"), # Layer_15_Concat
|
||||
(20, "model.model.model[20]"), # Layer_20_Conv
|
||||
(21, "model.model.model[21]"), # Layer_21_Concat
|
||||
(22, "model.model.model[22]") # Layer_22_C3k2
|
||||
]
|
||||
|
||||
# 如果使用 EigenCAM 或 EigenGradCAM,则只分析最后两个层
|
||||
if cam_method_name in ["EigenCAM", "EigenGradCAM"]:
|
||||
logging.info(f" -> 使用 {cam_method_name}。将自动裁剪为最后2个默认层。")
|
||||
default_yolo11_layers = default_yolo11_layers[-3:] # (只保留最后两个)
|
||||
|
||||
for idx, path_str in default_yolo11_layers:
|
||||
try:
|
||||
# 假设 'model.model.model' 是一个列表
|
||||
layer = model.model.model[idx]
|
||||
target_layers_to_process.append((path_str, layer))
|
||||
logging.info(f" -> 已添加 YOLO11 默认层: {path_str} (索引 {idx})")
|
||||
except Exception as e:
|
||||
logging.warning(f" -> [!] 无法添加 YOLO11 默认层 {path_str}: {e}")
|
||||
|
||||
# 对 YOLO12 系列模型的 'default' 逻辑 # TODO # 对模型进行修改
|
||||
elif model_key.lower().startswith("yolo12"):
|
||||
logging.info(f"检测到 YOLOv12 系列模型 ({model_key})。将使用您精选的8个默认层。")
|
||||
# (这是您新指定的列表)
|
||||
default_yolo12_layers = [
|
||||
(2, "model.model.model[2]"), # Layer_02_C3k2
|
||||
(4, "model.model.model[4]"), # Layer_04_C3k2
|
||||
(8, "model.model.model[8]"), # Layer_08_A2C2f
|
||||
(10, "model.model.model[10]"), # Layer_10_Concat
|
||||
(13, "model.model.model[13]"), # Layer_13_Concat
|
||||
(17, "model.model.model[17]"), # Layer_17_A2C2f
|
||||
(19, "model.model.model[19]"), # Layer_19_Concat
|
||||
(20, "model.model.model[20]") # Layer_20_C3k2
|
||||
]
|
||||
|
||||
# 如果使用 EigenCAM 或 EigenGradCAM,则只分析最后两个层
|
||||
if cam_method_name in ["EigenCAM", "EigenGradCAM"]:
|
||||
logging.info(f" -> 使用 {cam_method_name}。将自动裁剪为最后2个默认层。")
|
||||
default_yolo12_layers = default_yolo12_layers[-3:] # (只保留最后两个)
|
||||
|
||||
for idx, path_str in default_yolo12_layers:
|
||||
try:
|
||||
layer = model.model.model[idx]
|
||||
target_layers_to_process.append((path_str, layer))
|
||||
logging.info(f" -> 已添加 YOLOv12 默认层: {path_str} (索引 {idx})")
|
||||
except Exception as e:
|
||||
logging.warning(f" -> [!] 无法添加 YOLOv12 默认层 {path_str}: {e}")
|
||||
|
||||
# [!! 保留 !!] 非 YOLOv8、YOLOv9、YOLO11、YOLO12 模型的 'default' 逻辑
|
||||
else:
|
||||
logging.info(f"非 YOLO11 模型 ({model_key})。将使用标准 'default' 逻辑。")
|
||||
try:
|
||||
layer = model.model.model[8].proto.cv3 # (来自您 V2 文件的 `[8]`)
|
||||
name = "model.model.model[8].proto.cv3"
|
||||
target_layers_to_process.append((name, layer))
|
||||
logging.info(f" -> 已添加默认层: {name}")
|
||||
except Exception as e:
|
||||
logging.warning(f" -> 无法定位 'model[8].proto.cv3'。回退到 'model[15]'。错误: {e}")
|
||||
try:
|
||||
layer = model.model.model[15]
|
||||
name = "model.model.model[15]"
|
||||
target_layers_to_process.append((name, layer))
|
||||
logging.info(f" -> 已添加回退层: {name}")
|
||||
except Exception as e_fallback:
|
||||
logging.error(f" -> [!] 无法定位 'model[15]'。跳过 'default'。错误: {e_fallback}")
|
||||
|
||||
# [!! 保留 !!] 处理非 'default' 的自定义层路径
|
||||
else:
|
||||
logging.info(f"正在解析自定义目标层: {target_layer_str}")
|
||||
layer_paths = target_layer_str.split(',')
|
||||
for path in layer_paths:
|
||||
path_cleaned = path.strip()
|
||||
if not path_cleaned:
|
||||
continue
|
||||
try:
|
||||
layer = eval(path_cleaned, {"model": model})
|
||||
if isinstance(layer, torch.nn.Module):
|
||||
target_layers_to_process.append((path_cleaned, layer))
|
||||
logging.info(f" -> 已添加指定层: {path_cleaned}")
|
||||
else:
|
||||
logging.warning(f" -> 路径 '{path_cleaned}' 不是一个 nn.Module,已跳过。")
|
||||
except Exception as e:
|
||||
logging.error(f" -> [!] 无法解析目标层路径 '{path_cleaned}': {e}。已跳过。")
|
||||
|
||||
if not target_layers_to_process:
|
||||
logging.error("未找到有效的目标层进行可视化。程序退出。")
|
||||
return
|
||||
|
||||
# --- 3. 查找源图片 (V1 逻辑) ---
|
||||
source_path = Path(source_dir)
|
||||
image_paths = []
|
||||
if source_path.is_file():
|
||||
image_paths = [source_path]
|
||||
elif source_path.is_dir():
|
||||
image_paths = sorted(
|
||||
list(source_path.glob("*.jpg")) +
|
||||
list(source_path.glob("*.jpeg")) +
|
||||
list(source_path.glob("*.png"))
|
||||
)
|
||||
if not image_paths:
|
||||
logging.error(f"在源路径中未找到任何图片: {source_dir}")
|
||||
return
|
||||
|
||||
# --- 4. 获取 CAM 方法 (V2 逻辑) ---
|
||||
cam_class = CAM_METHODS.get(cam_method_name)
|
||||
if not cam_class:
|
||||
logging.error(f"无效的 CAM 方法: {cam_method_name}。")
|
||||
return
|
||||
logging.info(f"--- 将使用 CAM 方法: {cam_method_name} ---")
|
||||
|
||||
# --- 5. [!! 循环顺序保留 !!] ---
|
||||
|
||||
targets = None
|
||||
if cam_method_name != "EigenCAM":
|
||||
targets = [ActivationMaximizationTarget(channel=0)]
|
||||
logging.info("已为非 EigenCAM 方法设置 ActivationMaximizationTarget(channel=0)。")
|
||||
|
||||
pt_name_raw = pt_name.replace('.pt', '') # 移到循环外
|
||||
|
||||
# 外循环:遍历图片
|
||||
for img_path in image_paths:
|
||||
logging.info(f"\n--- 正在处理图片: {img_path.name} ---")
|
||||
|
||||
try:
|
||||
# --- a. 加载和预处理 ---
|
||||
orig_img_bgr = cv2.imread(str(img_path))
|
||||
if orig_img_bgr is None:
|
||||
logging.warning(f" -> 无法读取图片: {img_path.name},已跳过。")
|
||||
continue
|
||||
|
||||
orig_img_rgb_float = np.float32(orig_img_bgr[:, :, ::-1]) / 255
|
||||
input_tensor, pad_info = preprocess_image(orig_img_bgr, imgsz, device)
|
||||
input_tensor.requires_grad_()
|
||||
|
||||
# --- b. [!! 关键修改 !!] 创建新的三层输出目录 ---
|
||||
image_stem = img_path.stem
|
||||
|
||||
# 1. 创建算法根目录 (e.g., .../train6/HeartMap_Visual/GradCAM/)
|
||||
cam_root = base_save_dir / "HeartMap_Visual" / f"{cam_method_name}_{pt_name_raw}"
|
||||
|
||||
# 2. 创建该图片专用的子目录 (e.g., .../HeartMap_Visual/GradCAM/image01_GradCAM/)
|
||||
nn_visual_root = cam_root / f"{image_stem}_{cam_method_name}"
|
||||
nn_visual_root.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 3. 创建该图片专用的 _ORI 子目录 (e.g., .../HeartMap_Visual/GradCAM/image01_GradCAM_ORI/)
|
||||
nn_visual_ori_root = cam_root / f"{image_stem}_{cam_method_name}_ORI"
|
||||
nn_visual_ori_root.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# [!! 关键修改 !!] 更新日志打印路径,使其相对于 base_save_dir
|
||||
logging.info(f" -> 覆盖图将保存到: {nn_visual_root.relative_to(base_save_dir)}")
|
||||
logging.info(f" -> 纯热度图将保存到: {nn_visual_ori_root.relative_to(base_save_dir)}")
|
||||
|
||||
|
||||
# --- c. 内循环:遍历层 ---
|
||||
for layer_name, layer_module in target_layers_to_process:
|
||||
safe_layer_name = layer_name.replace('.', '_').replace('[', '').replace(']', '')
|
||||
logging.info(f" -> 正在处理层: {layer_name}")
|
||||
|
||||
try:
|
||||
# --- d. 初始化 CAM ---
|
||||
with cam_class(model=model.model,
|
||||
target_layers=[layer_module],
|
||||
reshape_transform=reshape_transform) as cam:
|
||||
|
||||
# --- e. 运行 CAM ---
|
||||
cam_output = cam(input_tensor=input_tensor, targets=targets)
|
||||
|
||||
# --- f. CAM 输出后处理 ---
|
||||
# (此部分逻辑与 V2 相同)
|
||||
grayscale_cam = None
|
||||
if isinstance(cam_output, (list, tuple)):
|
||||
grayscale_cam = cam_output[0]
|
||||
else:
|
||||
grayscale_cam = cam_output
|
||||
if isinstance(grayscale_cam, torch.Tensor):
|
||||
grayscale_cam = grayscale_cam.detach().cpu().numpy()
|
||||
if grayscale_cam.ndim == 4:
|
||||
grayscale_cam = grayscale_cam[0].mean(axis=0)
|
||||
elif grayscale_cam.ndim == 3:
|
||||
grayscale_cam = grayscale_cam[0, :]
|
||||
elif grayscale_cam.ndim == 2:
|
||||
pass # OK
|
||||
else:
|
||||
logging.warning(f" -> [!] 层 {layer_name} 的 CAM 结果维度异常: {grayscale_cam.shape},已跳过。")
|
||||
continue
|
||||
|
||||
# --- g. 裁剪 Padding 并 Resize ---
|
||||
cam_cropped = grayscale_cam[pad_info['pad_top'] : pad_info['pad_top'] + pad_info['new_h'],
|
||||
pad_info['pad_left'] : pad_info['pad_left'] + pad_info['new_w']]
|
||||
cam_upscaled = cv2.resize(cam_cropped,
|
||||
(pad_info['orig_w'], pad_info['orig_h']),
|
||||
interpolation=cv2.INTER_LINEAR)
|
||||
cam_upscaled = (cam_upscaled - cam_upscaled.min()) / (cam_upscaled.max() + 1e-8)
|
||||
|
||||
# --- h. 定义文件名和路径 ---
|
||||
save_filename_overlay = f"Layer_{safe_layer_name}_overlay.jpg"
|
||||
save_filename_ori = f"Layer_{safe_layer_name}_heatmap.jpg"
|
||||
|
||||
save_path_overlay = nn_visual_root / save_filename_overlay
|
||||
save_path_ori = nn_visual_ori_root / save_filename_ori
|
||||
|
||||
# --- i. 保存覆盖图 ---
|
||||
visualization = show_cam_on_image(orig_img_rgb_float,
|
||||
cam_upscaled,
|
||||
use_rgb=True,
|
||||
image_weight=0.5)
|
||||
viz_bgr = visualization[:, :, ::-1]
|
||||
cv2.imwrite(str(save_path_overlay), viz_bgr)
|
||||
|
||||
# --- j. 保存纯热度图 ---
|
||||
heatmap_uint8 = np.uint8(255 * cam_upscaled)
|
||||
heatmap_color = cv2.applyColorMap(heatmap_uint8, cv2.COLORMAP_JET)
|
||||
cv2.imwrite(str(save_path_ori), heatmap_color)
|
||||
|
||||
logging.info(f" -> (√) {layer_name} 可视化结果已保存")
|
||||
|
||||
except Exception as e_layer:
|
||||
logging.error(f" -> [!] 处理层 {layer_name} 时发生错误: {e_layer}", exc_info=False)
|
||||
|
||||
except Exception as e_img:
|
||||
logging.error(f" -> [!] 处理图片 {img_path.name} 时发生严重错误: {e_img}", exc_info=True)
|
||||
|
||||
logging.info("\n--- 所有图片已处理完毕 ---")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
config.show_config_summary()
|
||||
|
||||
parser = argparse.ArgumentParser(description="使用 Grad-CAM (EigenCAM) 生成 YOLO 神经网络可视化。")
|
||||
|
||||
parser.add_argument(
|
||||
"--model",
|
||||
type=str,
|
||||
required=True,
|
||||
choices=list(config.MODEL_CONFIGS.keys()),
|
||||
help="选择一个基础模型类型来筛选其训练历史。"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--source",
|
||||
type=str,
|
||||
default=str(config.TEST_IMAGE_DIR),
|
||||
help="图片或图片文件夹的路径。"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pt_name",
|
||||
type=str,
|
||||
default="best.pt",
|
||||
help="要使用的权重文件名 (例如 'best.pt' 或 'epoch100.pt')。"
|
||||
)
|
||||
all_cam_choices = list(CAM_METHODS.keys()) + ["All"]
|
||||
parser.add_argument(
|
||||
"--cam_method",
|
||||
type=str,
|
||||
default="All",
|
||||
choices=all_cam_choices,
|
||||
help=f"选择要使用的 CAM 可视化方法。默认为: All。"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--target_layers",
|
||||
type=str,
|
||||
default="default",
|
||||
help="指定要可视化的目标层。 "
|
||||
"'default': (YOLOv8/YOLO11系列使用精选8层,其他使用标准默认层); "
|
||||
"'model.model.model[15]': (指定单个层); "
|
||||
"'model.model.model[10],model.model.model[15]': (指定多个层)"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
# --- [!! 关键修改 !!] 移除路径校准 ---
|
||||
# 假设 yolo_config.py 已被修复,config.PREDICT_BEST_MODEL_DIR 是绝对路径
|
||||
logging.info(f"正在搜索模型目录: {config.PREDICT_BEST_MODEL_DIR}")
|
||||
available_runs = find_trained_models(config.PREDICT_BEST_MODEL_DIR, args.model, args.pt_name)
|
||||
|
||||
run_to_use = None
|
||||
|
||||
if not available_runs:
|
||||
logging.error(f"错误:在 {config.PREDICT_BEST_MODEL_DIR} 中未找到模型 '{args.model}' (使用 '{args.pt_name}') 的有效训练记录。")
|
||||
sys.exit(1)
|
||||
|
||||
elif len(available_runs) == 1:
|
||||
run_to_use = available_runs[0]
|
||||
logging.info(f"只找到一个训练版本,已自动选择: {run_to_use}")
|
||||
|
||||
else:
|
||||
# (V1 的模型选择菜单)
|
||||
print(f"\n为模型 '{args.model}' (使用 '{args.pt_name}') 找到多个训练版本:")
|
||||
for i, run_name in enumerate(available_runs, 1):
|
||||
print(f" [{i}] {run_name}")
|
||||
while True:
|
||||
try:
|
||||
choice = input(f"请输入您想使用的版本序号 (1-{len(available_runs)}): ")
|
||||
choice_index = int(choice)
|
||||
if 1 <= choice_index <= len(available_runs):
|
||||
run_to_use = available_runs[choice_index - 1]
|
||||
break
|
||||
else:
|
||||
print("错误:输入无效,请输入列表中的序号。")
|
||||
except ValueError:
|
||||
print("错误:请输入一个数字。")
|
||||
except (KeyboardInterrupt, EOFError):
|
||||
print("\n操作已取消。")
|
||||
sys.exit(0)
|
||||
|
||||
# --- 4. 调用 [!! 新的 !!] 可视化函数 ---
|
||||
if run_to_use:
|
||||
model_to_use = config.PREDICT_BEST_MODEL_DIR / run_to_use / 'weights' / args.pt_name
|
||||
project_to_save_in = config.PREDICT_BEST_MODEL_DIR / run_to_use
|
||||
|
||||
if not model_to_use.exists():
|
||||
logging.error(f"严重错误:找不到权重文件 {model_to_use}。")
|
||||
else:
|
||||
methods_to_run = []
|
||||
if args.cam_method == "All":
|
||||
methods_to_run = list(CAM_METHODS.keys()) #
|
||||
logging.info(f"--- [!!] 检测到 'All' 选项,将运行全部 {len(methods_to_run)} 种 CAM 方法 ---")
|
||||
else:
|
||||
methods_to_run = [args.cam_method] # 列表中只有一项
|
||||
|
||||
for method_name in methods_to_run:
|
||||
# (可选) 检查 Eigen* 方法是否与 'default' 一起使用
|
||||
is_slow_method = method_name in ["EigenCAM", "EigenGradCAM"] #
|
||||
is_default_layers = args.target_layers == "default" #
|
||||
|
||||
if is_slow_method and not is_default_layers:
|
||||
logging.warning(f"--- [警告] 您正在对自定义层运行 {method_name}。")
|
||||
logging.warning(f"--- 如果目标层是浅层,可能会非常缓慢或卡住。---")
|
||||
|
||||
logging.info(f"\n--- 正在启动 CAM 方法: [ {method_name} ] ---")
|
||||
|
||||
# 将原始调用放入循环中
|
||||
visualize_nn_comprehensive(
|
||||
model_path=str(model_to_use),
|
||||
source_dir=args.source,
|
||||
base_save_dir=project_to_save_in,
|
||||
pt_name=args.pt_name,
|
||||
cam_method_name=method_name, # [!! 修改 !!] 使用循环中的变量
|
||||
target_layer_str=args.target_layers,
|
||||
model_key=args.model
|
||||
)
|
||||
|
||||
logging.info("\n--- [!!] 所有指定的 CAM 方法均已执行完毕 ---")
|
||||
158
Seg_All_In_One_YoloModel/yolo_train.py
Normal file
158
Seg_All_In_One_YoloModel/yolo_train.py
Normal file
@@ -0,0 +1,158 @@
|
||||
import argparse, shutil, glob
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from ultralytics import YOLO
|
||||
from datetime import datetime
|
||||
import yolo_config as config
|
||||
import gc, torch
|
||||
|
||||
# --- 日志设置 ---
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s [%(levelname)s] %(message)s",
|
||||
handlers=[logging.StreamHandler()]
|
||||
)
|
||||
|
||||
# 寻找最匹配的 last.pt文件 和 其所在文件夹
|
||||
def find_latest_last_pt(model_key: str, outputs_dir: Path):
|
||||
"""根据模型名查找最新的 last.pt 路径,并返回 (last_pt路径, 文件夹名)"""
|
||||
pattern = str(outputs_dir / f"{model_key}_*" / "weights" / "last.pt")
|
||||
candidates = glob.glob(pattern)
|
||||
if not candidates:
|
||||
return None, None
|
||||
# 按修改时间排序,取最新的
|
||||
latest_pt = max(candidates, key=lambda p: Path(p).stat().st_mtime)
|
||||
# 提取文件夹名
|
||||
save_folder_name = Path(latest_pt).parent.parent.name # weights 的上级就是训练目录
|
||||
return latest_pt, save_folder_name
|
||||
|
||||
def train_model(model_key: str):
|
||||
"""
|
||||
根据给定的模型密钥训练YOLO分割模型。
|
||||
|
||||
Args:
|
||||
model_key (str): 在 yolo_config.MODEL_CONFIGS 中定义的模型密钥。
|
||||
"""
|
||||
# 生成时间戳以区分不同训练运行
|
||||
timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
|
||||
project_folder = f"{model_key}_{timestamp}"
|
||||
|
||||
if model_key not in config.MODEL_CONFIGS:
|
||||
logging.error(f"错误:模型 '{model_key}' 不在 yolo_config.py 中定义。")
|
||||
logging.info(f"可用模型: {list(config.MODEL_CONFIGS.keys())}")
|
||||
return
|
||||
|
||||
model_config = config.MODEL_CONFIGS[model_key]
|
||||
model_file = model_config['weights']
|
||||
model_batch_size = model_config['batch_size']
|
||||
model_image_size = model_config['image_size']
|
||||
logging.info(f"开始训练模型: {model_key} ({model_file})")
|
||||
|
||||
try:
|
||||
# # V1:--- 1. 加载模型 ---
|
||||
# model = YOLO(model_file) # 如果没有模型会自动下载
|
||||
# logging.info("模型加载成功。")
|
||||
|
||||
# V2:--- 1. 加载模型(支持断点续训) ---
|
||||
last_pt, save_folder_name = find_latest_last_pt(model_key, config.OUTPUTS_DIR)
|
||||
if last_pt and Path(last_pt).exists():
|
||||
model = YOLO(last_pt)
|
||||
logging.info(f"断点续训: {last_pt}")
|
||||
logging.info(f"续训目录名为: {save_folder_name}")
|
||||
|
||||
model.train(
|
||||
resume=True, # ✅ 关键参数
|
||||
data=str(config.DATASET_YAML_PATH),
|
||||
epochs=config.EPOCHS,
|
||||
imgsz=model_image_size,
|
||||
batch=model_batch_size,
|
||||
optimizer=config.OPTIMIZER,
|
||||
lr0=config.LEARNING_RATE,
|
||||
device=config.DEVICE,
|
||||
project=str(config.OUTPUTS_DIR), # 指定输出的根目录
|
||||
name=save_folder_name, # 指定本次训练的项目名
|
||||
workers=config.WORKERS,
|
||||
exist_ok=True, # 如果项目已存在,则覆盖
|
||||
patience=config.PATIENCE, # 提前停止训练的轮数
|
||||
save_period=config.SAVE_PERIOD # 每隔多少轮保存一次模型
|
||||
)
|
||||
else:
|
||||
model = YOLO(model_file)
|
||||
logging.info(f"从头训练: {model_file}")
|
||||
save_folder_name = project_folder
|
||||
|
||||
# --- 2. 模型训练 ---
|
||||
logging.info(f"数据集配置文件: {config.DATASET_YAML_PATH}")
|
||||
logging.info(f"训练参数: Epochs={config.EPOCHS}, Batch Size={model_batch_size}, Img Size={model_image_size}")
|
||||
|
||||
model.train(
|
||||
data=str(config.DATASET_YAML_PATH),
|
||||
epochs=config.EPOCHS,
|
||||
imgsz=model_image_size,
|
||||
batch=model_batch_size,
|
||||
optimizer=config.OPTIMIZER,
|
||||
lr0=config.LEARNING_RATE,
|
||||
device=config.DEVICE,
|
||||
project=str(config.OUTPUTS_DIR), # 指定输出的根目录
|
||||
name=save_folder_name, # 指定本次训练的项目名
|
||||
workers=config.WORKERS,
|
||||
exist_ok=True, # 如果项目已存在,则覆盖
|
||||
patience=config.PATIENCE, # 提前停止训练的轮数
|
||||
save_period=config.SAVE_PERIOD # 每隔多少轮保存一次模型
|
||||
)
|
||||
|
||||
logging.info("模型训练完成。")
|
||||
logging.info(f"训练结果保存在: {config.OUTPUTS_DIR / save_folder_name}")
|
||||
|
||||
return save_folder_name
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"训练过程中发生错误: {e}")
|
||||
return None
|
||||
|
||||
# --- 3. 复制结果到硬盘并删除原文件夹 (新增逻辑) ---
|
||||
def move_results_to_hardisk(project_folder: str):
|
||||
source_dir = config.OUTPUTS_DIR / project_folder
|
||||
# 确保目标硬盘目录存在
|
||||
outputs_folder_name = config.OUTPUTS_DIR.name
|
||||
destination_dir = config.HARDISK_DIR / outputs_folder_name / project_folder
|
||||
destination_dir.parent.mkdir(parents=True, exist_ok=True)
|
||||
logging.info(f"准备将结果从 {source_dir} 移动到 {destination_dir}...")
|
||||
try:
|
||||
# 步骤 1: 复制文件夹
|
||||
shutil.copytree(source_dir, destination_dir)
|
||||
logging.info(f"成功复制结果到: {destination_dir}")
|
||||
|
||||
# 步骤 2: 复制成功后,删除原文件夹
|
||||
logging.info(f"正在删除原文件夹: {source_dir}")
|
||||
shutil.rmtree(source_dir)
|
||||
logging.info("成功删除原文件夹。")
|
||||
logging.info(f"任务完成,最终结果已保存至: {destination_dir}")
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"移动文件夹时发生错误: {e}")
|
||||
logging.warning(f"原始训练结果仍保留在: {source_dir}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 创建命令行参数解析器
|
||||
parser = argparse.ArgumentParser(description="训练 YOLO 分割模型。")
|
||||
parser.add_argument(
|
||||
"--model",
|
||||
type=str,
|
||||
required=True,
|
||||
choices=list(config.MODEL_CONFIGS.keys()),
|
||||
help=f"选择要训练的模型。可选: {list(config.MODEL_CONFIGS.keys())}"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
# 1. 开始训练
|
||||
save_folder_name = train_model(args.model)
|
||||
# 2. 训练完成后,移动结果到硬盘
|
||||
if save_folder_name:
|
||||
move_results_to_hardisk(save_folder_name)
|
||||
else:
|
||||
logging.error("由于训练失败,未执行结果移动操作。")
|
||||
# 3. 释放资源
|
||||
torch.cuda.empty_cache() # 释放 CUDA 显存
|
||||
gc.collect() # 释放系统内存
|
||||
logging.info("已清理 CUDA 显存与系统内存。")
|
||||
137
Seg_All_In_One_YoloModel/yolo_train.sh
Normal file
137
Seg_All_In_One_YoloModel/yolo_train.sh
Normal file
@@ -0,0 +1,137 @@
|
||||
#!/bin/bash
|
||||
|
||||
# =================================================================
|
||||
# YOLO 模型批量并行训练脚本
|
||||
# =================================================================
|
||||
|
||||
# --- 1. Conda 环境设置 ---
|
||||
CONDA_BASE_PATH="/home/wkmgc/miniconda3" # <--- 在这里修改为您自己的 Conda 路径
|
||||
CONDA_ENV_NAME="${SEG_CONDA_ENV:-seg_smp}" # 可用 SEG_CONDA_ENV=SMP bash yolo_train.sh 临时覆盖
|
||||
|
||||
# 初始化并激活 Conda 环境
|
||||
if [ -f "${CONDA_BASE_PATH}/etc/profile.d/conda.sh" ]; then
|
||||
source "${CONDA_BASE_PATH}/etc/profile.d/conda.sh"
|
||||
conda activate "${CONDA_ENV_NAME}"
|
||||
if [ $? -ne 0 ]; then
|
||||
echo "错误: 激活 Conda 环境 '${CONDA_ENV_NAME}' 失败!"
|
||||
exit 1
|
||||
fi
|
||||
echo "Conda 环境 '${CONDA_ENV_NAME}' 已成功激活。"
|
||||
else
|
||||
echo "错误: 找不到 conda.sh 脚本。请检查您的 CONDA_BASE_PATH 设置是否正确。"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "正在为当前终端会话设置 HTTP/HTTPS 代理..."
|
||||
export https_proxy=http://127.0.0.1:1080
|
||||
export http_proxy=http://127.0.0.1:1080
|
||||
echo "代理已设置为 http://127.0.0.1:1080"
|
||||
echo "如果模型权重不存在,将通过此代理进行下载。"
|
||||
|
||||
# --- 2. 模型与 GPU 配置 ---
|
||||
# 根据您的 GPU 硬件和训练计划进行分组
|
||||
# yolo_train.py 会自动使用分配给它的所有可见 GPU
|
||||
GPUS_GROUP_0="0"
|
||||
GPUS_GROUP_1="1"
|
||||
GPUS_GROUP_2="2"
|
||||
GPUS_GROUP_3="3"
|
||||
# TODO
|
||||
GPUS_GROUP_4="0"
|
||||
GPUS_GROUP_5="1"
|
||||
GPUS_GROUP_6="2"
|
||||
GPUS_GROUP_7="3"
|
||||
|
||||
# GPUS_GROUP_4="4"
|
||||
# GPUS_GROUP_5="5"
|
||||
# GPUS_GROUP_6="6"
|
||||
# GPUS_GROUP_7="7"
|
||||
|
||||
# 从 yolo_config.py 中选择要训练的模型,并分配到不同的组
|
||||
# 建议将计算量/模型大小相似的模型放在同一组
|
||||
GROUP_0_MODELS=("YOLO11l-seg" "YOLOv8m-seg")
|
||||
GROUP_1_MODELS=("YOLO11x-seg")
|
||||
GROUP_2_MODELS=("YOLO11n-seg" "YOLO11s-seg" "YOLO11m-seg")
|
||||
GROUP_3_MODELS=("YOLOv9e-seg")
|
||||
GROUP_4_MODELS=("YOLOv9c-seg")
|
||||
GROUP_5_MODELS=("YOLOv8x-seg")
|
||||
GROUP_6_MODELS=("YOLOv8l-seg" "YOLO12-seg") # YOLOv9e-seg # 模型太大
|
||||
GROUP_7_MODELS=("YOLOv8n-seg" "YOLOv8s-seg" ) # 训练完毕
|
||||
|
||||
# 1. 从 config.py 中读取 OUTPUTS_DIR 的值
|
||||
OUTPUTS_DIR=$(python -c "from yolo_config import OUTPUTS_DIR; print(OUTPUTS_DIR)")
|
||||
# 检查是否成功获取了 OUTPUTS_DIR
|
||||
if [ -z "$OUTPUTS_DIR" ]; then
|
||||
echo "OUTPUTS_DIR: $OUTPUTS_DIR"
|
||||
echo "Error 1: Could not read OUTPUTS_DIR from yolo_config.py. Exiting."
|
||||
echo "Error 2: Or the directory specified by OUTPUTS_DIR does not exist. Please create it first."
|
||||
# exit 1
|
||||
fi
|
||||
# 2. 定义带有时间戳的日志目录名
|
||||
LOG_DIR_NAME="yolo_train_logs_parallel_$(date +%Y-%m-%d_%H-%M-%S)"
|
||||
# 3. 拼接成最终的完整路径
|
||||
LOG_DIR="$OUTPUTS_DIR/$LOG_DIR_NAME"
|
||||
mkdir -p "${LOG_DIR}"
|
||||
echo "所有模型的日志将保存在 ./${LOG_DIR}/ 目录中。"
|
||||
echo "----------------------------------------------------"
|
||||
|
||||
|
||||
# --- 3. 训练执行函数 ---
|
||||
# 定义一个函数来启动一组训练,以避免代码重复
|
||||
start_training_group() {
|
||||
# 使用 nameref (引用) 来传递数组
|
||||
local -n models=$1
|
||||
local gpus=$2
|
||||
local group_name=$3
|
||||
|
||||
echo ">>> 准备启动 ${group_name} 的训练任务 (后台运行)..."
|
||||
# 遍历指定组中的所有模型
|
||||
for model_key in "${models[@]}"; do
|
||||
echo " -> 正在后台启动模型: ${model_key} on GPUs: ${gpus}"
|
||||
# 使用 '&' 将命令放入后台运行
|
||||
# 通过 --model 参数将模型名称传递给 yolo_train.py
|
||||
CUDA_VISIBLE_DEVICES=${gpus} python yolo_train.py --model "${model_key}" > "${LOG_DIR}/${model_key}.log" 2>&1 &
|
||||
echo " - 模型 ${model_key} 已在后台启动。日志文件: ${LOG_DIR}/${model_key}.log"
|
||||
echo " - 等待 30 秒,确保 GPU 资源稳定分配..."
|
||||
sleep 30
|
||||
done
|
||||
echo ">>> ${group_name} 的所有模型均已启动。"
|
||||
echo "----------------------------------------------------"
|
||||
}
|
||||
|
||||
# --- 4. 依次启动所有训练任务 ---
|
||||
# 脚本将快速地按顺序启动每一组任务到后台
|
||||
start_training_group GROUP_0_MODELS "${GPUS_GROUP_0}" "第零组" # TODO
|
||||
start_training_group GROUP_1_MODELS "${GPUS_GROUP_1}" "第一组"
|
||||
start_training_group GROUP_2_MODELS "${GPUS_GROUP_2}" "第二组"
|
||||
start_training_group GROUP_3_MODELS "${GPUS_GROUP_3}" "第三组"
|
||||
|
||||
|
||||
# --- 5. 等待所有后台任务完成 ---
|
||||
echo ""
|
||||
echo "--- 所有模型均已在后台启动。现在等待所有训练任务完成... ---"
|
||||
# 'wait' 命令会暂停脚本,直到所有由此脚本启动的后台子进程全部执行完毕
|
||||
wait
|
||||
echo "--- 所有后台训练任务已全部完成! ---"
|
||||
|
||||
# 适用于4卡版本
|
||||
start_training_group GROUP_4_MODELS "${GPUS_GROUP_4}" "第四组" # TODO
|
||||
start_training_group GROUP_5_MODELS "${GPUS_GROUP_5}" "第五组"
|
||||
start_training_group GROUP_6_MODELS "${GPUS_GROUP_6}" "第六组"
|
||||
start_training_group GROUP_7_MODELS "${GPUS_GROUP_7}" "第七组"
|
||||
# --- 5. 等待所有后台任务完成 ---
|
||||
echo ""
|
||||
echo "--- 所有模型均已在后台启动。现在等待所有训练任务完成... ---"
|
||||
# 'wait' 命令会暂停脚本,直到所有由此脚本启动的后台子进程全部执行完毕
|
||||
wait
|
||||
echo "--- 所有后台训练任务已全部完成! ---"
|
||||
|
||||
# --- 6. 退出脚本 ---
|
||||
# 训练完成,取消激活 Conda 环境
|
||||
echo "训练流程结束。"
|
||||
# --- 取消网络代理设置 ---
|
||||
echo "正在取消当前终端会话的 HTTP/HTTPS 代理设置..."
|
||||
unset https_proxy
|
||||
unset http_proxy
|
||||
echo "代理已取消。"
|
||||
conda deactivate
|
||||
echo "取消激活 Conda 环境。"
|
||||
100
Seg_All_In_One_YoloModel/※2025_9_21_使用手册
Normal file
100
Seg_All_In_One_YoloModel/※2025_9_21_使用手册
Normal file
@@ -0,0 +1,100 @@
|
||||
############## A. 配置环境 ##############
|
||||
pip install ultralytics grad-cam
|
||||
修改 /home/wkmgc/miniconda3/envs/SMP/lib/python3.9/site-packages/pytorch_grad_cam/base_cam.py
|
||||
将
|
||||
if targets is None:
|
||||
target_categories = np.argmax(outputs.cpu().data.numpy(), axis=-1)
|
||||
targets = [ClassifierOutputTarget(category) for category in target_categories]
|
||||
变为
|
||||
if targets is None:
|
||||
# 循环解包,直到找到第一个非元组的元素,这可以处理嵌套元组,例如 ((Tensor, ...), ...)
|
||||
processed_outputs = outputs
|
||||
print(f"DEBUG: 原始 outputs 类型: {type(processed_outputs)}")
|
||||
while isinstance(processed_outputs, tuple):
|
||||
if len(processed_outputs) == 0:
|
||||
# 如果遇到一个空元组,我们就无法继续,打印错误并退出循环
|
||||
print(f"ERROR: 在解包时遇到空元组。原始 outputs: {outputs}")
|
||||
break
|
||||
print(f"DEBUG: 正在解包元组... 选择第 0 个元素。")
|
||||
processed_outputs = processed_outputs[0] # 逐次选择第一个元素
|
||||
print(f"DEBUG: 最终选定的 outputs 类型: {type(processed_outputs)}")
|
||||
|
||||
# 现在,processed_outputs 应该是我们需要的张量 (Tensor)
|
||||
try:
|
||||
target_categories = np.argmax(processed_outputs.cpu().data.numpy(), axis=-1)
|
||||
targets = [ClassifierOutputTarget(category) for category in target_categories]
|
||||
except AttributeError as e:
|
||||
print(f"ERROR: 最终选定的元素 (类型: {type(processed_outputs)}) 无法处理。")
|
||||
print(f" 它没有 .cpu() 属性。原始错误: {e}")
|
||||
print(f" 请检查您的模型输出结构。原始 outputs: {outputs}")
|
||||
raise e # 重新抛出异常,以便程序停止
|
||||
|
||||
############## B. 数据集构建(可选) ##############
|
||||
参考:./Yolo数据集构建
|
||||
|
||||
############## C. Train 训练程序 ##############
|
||||
# A. 第一次训练开一个梯子【需要下载内容】
|
||||
export https_proxy=http://127.0.0.1:1080 http_proxy=http://127.0.0.1:1080
|
||||
CUDA_VISIBLE_DEVICES=0 python train.py --model {YOLOv8n-seg, YOLOv8s-seg, YOLOv8m-seg, YOLOv8l-seg, YOLOv8x-seg, YOLOv9e-seg, YOLOv9c-seg, YOLO11l-seg, YOLO11n-seg, YOLO11s-seg, YOLO11m-seg, YOLO11x-seg, YOLO12-seg, YOLO12-seg}
|
||||
|
||||
# B. 运行单个训练程序
|
||||
1. 在 dataset.yaml 中修改 训练数据集;CUDA_VISIBLE_DEVICES修改使用显卡;--model 修改算法;
|
||||
2. CUDA_VISIBLE_DEVICES=0 python yolo_train.py --model "YOLOv8n-seg"
|
||||
|
||||
# C. 批量运行训练程序
|
||||
1. train.sh 修改想让它使用的显卡、算法;
|
||||
2. bash yolo_train.sh
|
||||
3. 会生成 ./yolo_logs_parallel_DATE 的终端记录文件;结果先存储在 ../DataSet_Public_outputs/DATASET-Yolo 后自动移动到 ../Hardisk/DATASET_outputs-SegModel 中
|
||||
|
||||
############## D. Predict 推理程序 ##############
|
||||
# A. 预先步骤:将模型同步到 Nas_BackUp_Seg 或 ./Hardisk 文件夹中【cd .. && bash ./Back_Up.sh】
|
||||
# B1. 将最优模型文件从 ./Nas_BackUp_Seg 移动到 ./BestMode_Predict_Results_DataSet_Public 指定文件夹中
|
||||
bash ./Tool_Yolo_Copy_Best_Model.sh --pt_name "best.pt" # "epoch100.pt" # 修改里面的路径
|
||||
# B2. 如果需要处理自定义数据集,请将模型文件夹手动复制为 ./BestMode_Predict_Results_DataSet_Public/DATASET-Yolo 中
|
||||
可以先对原有 ./BestMode_Predict_Results_DataSet_Public/ORI_DATASET-Yolo 临时改名
|
||||
运行 bash ./Tool_Yolo_Copy_Best_Model.sh --pt_name "best.pt" # "epoch100.pt" # 修改里面的路径
|
||||
在将生成的 ./BestMode_Predict_Results_DataSet_Public/ORI_DATASET-Yolo 改为 ./BestMode_Predict_Results_DataSet_Public/DATASET-Yolo
|
||||
# C. 运行单个推理程序
|
||||
1. 在 dataset.yaml 中修改 预测数据集 及 val/test; yolo_config.py 中需修改模型保存路径 PREDICT_ALL_BEST_MODELS_DIR/PREDICT_BEST_MODEL_DIR;CUDA_VISIBLE_DEVICES修改使用显卡;--model 修改算法;
|
||||
2. CUDA_VISIBLE_DEVICES=0 python yolo_predict_V1_NoColor.py --model "YOLOv8n-seg"
|
||||
CUDA_VISIBLE_DEVICES=0 python yolo_predict_V2.py --model "YOLOv8n-seg" --conf 0.2 --pt_name "epoch100.pt" # "epoch100.pt" # "best.pt"
|
||||
# D. 批量运行推理程序(分割图可视化 或 热图可视化)
|
||||
1. yolo_predict.sh 修改想让它使用的显卡、算法;
|
||||
2. bash yolo_predict.sh --conf 0.2 --pt_name "epoch100.pt" # "best.pt"(默认)# 分割图可视化
|
||||
3. bash yolo_predict.sh --conf 0.2 --heatmap_method "All" --pt_name "epoch100.pt" # "best.pt"(默认)# 热图可视化
|
||||
4. 会生成 ./yolo_predict_logs_parallel_DATE 的终端记录文件;结果存储在 ../BestMode_Predict_Results_DataSet_Public/DATASET-Yolo 中
|
||||
# E. 横向对比训练结果
|
||||
1. 将模型结果进行横向对比(生成color_mask、Yolo_result的横向对比)
|
||||
python yolo_predict_V2_compare_all.py --pt_name "all" # "epoch100.pt" # "best.pt" # "all"(默认)
|
||||
|
||||
############## E. Predict_raw_img_Check 检测推理输出图片是否齐全 ##############
|
||||
# 检测 dataset.yaml path/"labels"/test.name 和 yolo_config.PREDICT_BEST_MODEL_DIR/****/predicted_raw_masks 中图片是否匹配
|
||||
1. 在 dataset.yaml 中修改 path、test;在yolo_config.PREDICT_ALL_BEST_MODELS_DIR;CUDA_VISIBLE_DEVICES修改使用显卡;
|
||||
2. python yolo_predict_raw_masks_check.py --pt_name "best.pt" # "epoch100.pt" # "best.pt"(默认)
|
||||
|
||||
############## F. 神经网络热图可视化,目前只支持EigenCAM(无指定类别的方法) ##############
|
||||
# A. 预先步骤:cd ./Yolo可视化测试;
|
||||
python yolo_layer_tester.py --model "YOLO12-seg" --cam_method "GradCAM" --pt_name "best.pt" # 测试各个层是否能成功生成热图,选择合适的层
|
||||
vim ../yolo_predict_visualize_nn.py # 修改 visualize_nn_comprehensive 中 model_key.lower().startswith("yoloXXX") 的部分,添加对应模型的默认层
|
||||
# B. 运行单个热图可视化程序
|
||||
# YOLOv8n-seg,YOLOv8s-seg,YOLOv8m-seg,YOLOv8l-seg,YOLOv8x-seg,YOLOv9c-seg,YOLOv9e-seg,YOLO11n-seg,YOLO11s-seg,YOLO11m-seg,YOLO11l-seg,YOLO11x-seg,YOLO12-seg
|
||||
python yolo_predict_visualize_nn.py --model "YOLO11m-seg" --target_layers "default" --cam_method "All" --pt_name "best.pt" # "epoch100.pt" # "best.pt"
|
||||
|
||||
############## G.※ 快速进行分割实验 ※ ##############
|
||||
# 1. 修改 dataset.yaml 中 训练数据集、修改 yolo_config.py 中 EPOCHS、PATIENCE
|
||||
cd ~/Desktop/Seg/Seg_All_In_One_YoloModel
|
||||
# 2. 批量化训练
|
||||
bash yolo_train.sh
|
||||
# 3. 复制最优模型到预测文件夹
|
||||
bash ./Tool_Yolo_Copy_Best_Model.sh --pt_name "best.pt" && bash ./Tool_Yolo_Copy_Best_Model.sh --pt_name "epoch100.pt"
|
||||
# 4. 批量化预测+热度图可视化
|
||||
bash yolo_predict.sh --conf 0.2 --pt_name "epoch100.pt" && bash yolo_predict.sh --conf 0.2
|
||||
bash ./yolo_predict.sh --heatmap_method "All" && bash ./yolo_predict.sh --pt_name "epoch100.pt" --heatmap_method "All"
|
||||
# 5. 横向对比结果
|
||||
python yolo_predict_V2_compare_all.py
|
||||
# 6. 打包预测结果(不包含*.pt模型文件)
|
||||
cd /home/wkmgc/Desktop/Seg/BestMode_Predict_Results_DataSet_Public/
|
||||
zip -r /home/wkmgc/Desktop/5_My_Gastric_2025_10_29_938-Yolo.zip 5_My_Gastric_2025_10_29_938*-Yolo -x "*.pt"
|
||||
# 7. 打包训练结果(只有.png、.jpg、.csv文件)
|
||||
cd /home/wkmgc/Desktop/Seg/Hardisk/
|
||||
zip -r /home/wkmgc/Desktop/5_My_Gastric_2025_10_29_938-Yolo_train.zip 5_My_Gastric_2025_10_29_938-Yolo -i \*.png \*.jpg \*.csv
|
||||
Reference in New Issue
Block a user