first commit
This commit is contained in:
2
Seg_All_In_One_YoloModel/Yolo可视化测试/Yolo可视化测试-使用手册
Normal file
2
Seg_All_In_One_YoloModel/Yolo可视化测试/Yolo可视化测试-使用手册
Normal file
@@ -0,0 +1,2 @@
|
||||
# 本手册主要目的在于判断给Yolo模型的各个层是否能成功生成热图
|
||||
使用方法: python yolo_layer_tester.py --model "YOLOv8m-seg" --cam_method "GradCAM" --pt_name "best.pt"
|
||||
359
Seg_All_In_One_YoloModel/Yolo可视化测试/yolo_layer_tester.py
Normal file
359
Seg_All_In_One_YoloModel/Yolo可视化测试/yolo_layer_tester.py
Normal file
@@ -0,0 +1,359 @@
|
||||
import logging
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
from ultralytics import YOLO
|
||||
import random
|
||||
import sys
|
||||
import os
|
||||
from typing import Dict, Any, Tuple, List # [!!] 确保导入 List
|
||||
|
||||
# --- [!! 关键修改 !!] 动态添加项目根目录到 sys.path ---
|
||||
# (来自您之前的目录结构)
|
||||
script_path = Path(__file__).resolve()
|
||||
project_root = script_path.parent.parent
|
||||
sys.path.insert(0, str(project_root))
|
||||
# ----------------------------------------------------
|
||||
|
||||
# 现在下面这行导入就可以正常工作了
|
||||
import yolo_config as config
|
||||
|
||||
# --- [导入 V3 的 CAM 库] ---
|
||||
from pytorch_grad_cam import (
|
||||
GradCAM, GradCAMPlusPlus, XGradCAM, EigenCAM,
|
||||
HiResCAM, LayerCAM, RandomCAM, EigenGradCAM
|
||||
)
|
||||
from pytorch_grad_cam.utils.image import show_cam_on_image
|
||||
from pytorch_grad_cam.base_cam import BaseCAM
|
||||
|
||||
# --- 日志设置 ---
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s [%(levelname)s] %(message)s",
|
||||
handlers=[logging.StreamHandler()]
|
||||
)
|
||||
logging.getLogger("pytorch_grad_cam").setLevel(logging.WARNING)
|
||||
|
||||
# --- [!! 新增 !!] 从 V1 移植的函数 ---
|
||||
def find_trained_models(outputs_dir: Path, model_key: str, pt_name: str) -> List[str]:
|
||||
"""
|
||||
扫描输出目录,查找特定基础模型的所有有效的、已完成的训练项目。
|
||||
(函数来自 yolo_predict_visualize_nn_V1.py)
|
||||
"""
|
||||
trained_models = []
|
||||
if not outputs_dir.is_dir():
|
||||
logging.warning(f"输出目录不存在: {outputs_dir}")
|
||||
return []
|
||||
|
||||
for project_folder in outputs_dir.iterdir():
|
||||
if project_folder.is_dir() and project_folder.name.startswith(model_key + '_'):
|
||||
if (project_folder / 'weights' / pt_name).exists():
|
||||
trained_models.append(project_folder.name)
|
||||
|
||||
trained_models.sort()
|
||||
return trained_models
|
||||
# -------------------------------------
|
||||
|
||||
# --- [保留] V3 的 CAM 方法字典 ---
|
||||
CAM_METHODS = {
|
||||
"GradCAM": GradCAM,
|
||||
"GradCAMPlusPlus": GradCAMPlusPlus,
|
||||
"XGradCAM": XGradCAM,
|
||||
"EigenCAM": EigenCAM,
|
||||
"HiResCAM": HiResCAM,
|
||||
"LayerCAM": LayerCAM,
|
||||
"RandomCAM": RandomCAM,
|
||||
"EigenGradCAM": EigenGradCAM,
|
||||
}
|
||||
|
||||
# --- [保留] V3 的 Target ---
|
||||
class ActivationMaximizationTarget:
|
||||
def __init__(self, channel=0):
|
||||
self.channel = channel
|
||||
def __call__(self, model_output):
|
||||
if model_output.ndim == 4:
|
||||
return model_output[:, self.channel, :, :].mean()
|
||||
elif model_output.ndim == 3:
|
||||
return model_output[self.channel, :, :].mean()
|
||||
else:
|
||||
raise ValueError(f"Unsupported model_output shape: {model_output.shape}")
|
||||
|
||||
# --- [保留] V3 的预处理 ---
|
||||
def preprocess_image(bgr_img: np.ndarray, imgsz: int, device: str) -> (torch.Tensor, Dict[str, Any]):
|
||||
# ... (此函数内容与上一版完全相同,为节省篇幅已折叠) ...
|
||||
orig_h, orig_w = bgr_img.shape[:2]
|
||||
rgb_img = bgr_img[:, :, ::-1]
|
||||
scale = min(imgsz / orig_w, imgsz / orig_h)
|
||||
new_w, new_h = int(orig_w * scale), int(orig_h * scale)
|
||||
resized_image = cv2.resize(rgb_img, (new_w, new_h), interpolation=cv2.INTER_LINEAR)
|
||||
pad_w, pad_h = imgsz - new_w, imgsz - new_h
|
||||
pad_left, pad_right = pad_w // 2, pad_w - pad_w // 2
|
||||
pad_top, pad_bottom = pad_h // 2, pad_h - pad_h // 2
|
||||
padded_image = cv2.copyMakeBorder(resized_image, pad_top, pad_bottom, pad_left, pad_right,
|
||||
cv2.BORDER_CONSTANT, value=(114, 114, 114))
|
||||
img_float = np.float32(padded_image) / 255.0
|
||||
input_tensor = torch.from_numpy(img_float).permute(2, 0, 1).unsqueeze(0).to(device)
|
||||
padding_info = {
|
||||
"pad_top": pad_top, "new_h": new_h,
|
||||
"pad_left": pad_left, "new_w": new_w,
|
||||
"orig_h": orig_h, "orig_w": orig_w
|
||||
}
|
||||
return input_tensor, padding_info
|
||||
|
||||
|
||||
# --- [保留] V1/V3 的关键修复 ---
|
||||
def reshape_transform(outputs):
|
||||
if isinstance(outputs, tuple):
|
||||
return outputs[0]
|
||||
return outputs
|
||||
|
||||
# --- [!! 核心功能 !!] (此函数本身无需修改) ---
|
||||
def generate_heatmaps_for_all_layers(model_path: str, image_path: str, cam_method_name: str, model_name: str):
|
||||
"""
|
||||
加载模型和图像,遍历所有层,并为每个成功的层生成热图。
|
||||
(此函数内容与上一版完全相同,为节省篇幅已折叠)
|
||||
"""
|
||||
|
||||
# --- 1. 获取 CAM 方法 ---
|
||||
cam_class = CAM_METHODS.get(cam_method_name)
|
||||
if not cam_class:
|
||||
logging.error(f"无效的 CAM 方法: {cam_method_name}")
|
||||
return
|
||||
# --- 2. 定义 Targets ---
|
||||
targets = None
|
||||
if cam_method_name != "EigenCAM":
|
||||
targets = [ActivationMaximizationTarget(channel=0)]
|
||||
logging.info(f"已为 {cam_method_name} 设置 ActivationMaximizationTarget(channel=0)。")
|
||||
else:
|
||||
logging.info(f"将使用 {cam_method_name} (无需 Targets)。")
|
||||
# --- 3. 加载模型 ---
|
||||
try:
|
||||
logging.info(f"正在加载模型: {model_path}")
|
||||
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||
model = YOLO(model_path).to(device)
|
||||
model.model.eval()
|
||||
imgsz = model.model.args['imgsz']
|
||||
logging.info(f"模型加载成功。设备: {device}, 图像尺寸: {imgsz}")
|
||||
except Exception as e:
|
||||
logging.error(f"无法加载模型 {model_path}: {e}", exc_info=True)
|
||||
return
|
||||
# --- 4. 加载和预处理图像 ---
|
||||
try:
|
||||
logging.info(f"正在加载测试图像: {image_path}")
|
||||
orig_img_bgr = cv2.imread(image_path)
|
||||
if orig_img_bgr is None:
|
||||
logging.error(f"无法读取图像: {image_path}")
|
||||
return
|
||||
orig_img_rgb_float = np.float32(orig_img_bgr[:, :, ::-1]) / 255.0
|
||||
input_tensor, pad_info = preprocess_image(orig_img_bgr, imgsz, device)
|
||||
input_tensor.requires_grad_()
|
||||
logging.info(f"图像加载并预处理完成。Tensor shape: {input_tensor.shape}")
|
||||
except Exception as e:
|
||||
logging.error(f"处理图像时出错: {e}", exc_info=True)
|
||||
return
|
||||
# --- 5. 创建输出目录 (您的自定义逻辑) ---
|
||||
image_stem = Path(image_path).stem
|
||||
output_dir_name = f"./{cam_method_name}_{model_name}_{image_stem}_All_HeartMap"
|
||||
output_dir = Path(output_dir_name)
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
logging.info(f"所有热图将保存到: {output_dir.resolve()}")
|
||||
|
||||
# --- 6. 遍历并测试所有层 ---
|
||||
base_model_layers = model.model.model
|
||||
successful_layers, failed_layers = [], []
|
||||
logging.info(f"\n" + "="*50)
|
||||
logging.info(f"开始遍历 {len(base_model_layers)} 个层 (使用 {cam_method_name})")
|
||||
logging.info("\n" + "="*50 + "\n")
|
||||
|
||||
for i, layer in enumerate(base_model_layers):
|
||||
layer_name = f"model.model.model[{i}]"
|
||||
layer_type = layer.__class__.__name__
|
||||
print(f"--- [ {i:02d} / {len(base_model_layers)-1} ] 尝试层: {layer_name:<25} (类型: {layer_type}) ---")
|
||||
try:
|
||||
with cam_class(model=model.model, target_layers=[layer], reshape_transform=reshape_transform) as cam:
|
||||
grayscale_cam = cam(input_tensor=input_tensor, targets=targets)
|
||||
if grayscale_cam is None: raise ValueError("CAM-Method 返回了 None")
|
||||
|
||||
# ... (结果提取、裁剪、保存逻辑) ...
|
||||
if isinstance(grayscale_cam, (list, tuple)): result = grayscale_cam[0]
|
||||
else: result = grayscale_cam
|
||||
if isinstance(result, torch.Tensor): result = result.detach().cpu().numpy()
|
||||
if result.ndim == 4: result = result[0].mean(axis=0)
|
||||
elif result.ndim == 3: result = result[0, :]
|
||||
elif result.ndim != 2: raise TypeError(f"CAM 输出维度异常 {result.shape}")
|
||||
|
||||
cam_cropped = result[pad_info['pad_top'] : pad_info['pad_top'] + pad_info['new_h'],
|
||||
pad_info['pad_left'] : pad_info['pad_left'] + pad_info['new_w']]
|
||||
cam_upscaled = cv2.resize(cam_cropped, (pad_info['orig_w'], pad_info['orig_h']), interpolation=cv2.INTER_LINEAR)
|
||||
cam_normalized = (cam_upscaled - cam_upscaled.min()) / (cam_upscaled.max() + 1e-8)
|
||||
|
||||
visualization = show_cam_on_image(orig_img_rgb_float, cam_normalized, use_rgb=True, image_weight=0.5)
|
||||
viz_bgr = visualization[:, :, ::-1]
|
||||
heatmap_uint8 = np.uint8(255 * cam_normalized)
|
||||
heatmap_color = cv2.applyColorMap(heatmap_uint8, cv2.COLORMAP_JET)
|
||||
|
||||
base_filename = f"Layer_{i:02d}_{layer_type}"
|
||||
cv2.imwrite(str(output_dir / f"{base_filename}_overlay.jpg"), viz_bgr)
|
||||
# cv2.imwrite(str(output_dir / f"{base_filename}_heatmap.jpg"), heatmap_color)
|
||||
|
||||
print(f" [√] 成功! 已保存 {base_filename}_overlay.jpg 和 _heatmap.jpg")
|
||||
successful_layers.append((i, layer_name, layer_type))
|
||||
|
||||
except Exception as e:
|
||||
error_msg = str(e).split('\n')[0]
|
||||
print(f" [X] 失败. 错误: {error_msg}")
|
||||
failed_layers.append((i, layer_name, layer_type, error_msg))
|
||||
|
||||
# --- 7. 打印总结报告 ---
|
||||
# ... (报告逻辑与上一版相同) ...
|
||||
print("\n" + "="*60)
|
||||
print(" 全层热图生成报告")
|
||||
print("="*60)
|
||||
print(f"模型: {Path(model_path).name} (来自: {Path(model_path).parent.parent})")
|
||||
print(f"测试图像: {Path(image_path).name}")
|
||||
print(f"测试方法: {cam_method_name}")
|
||||
print(f"输出目录: {output_dir.resolve()}")
|
||||
print(f"总计: {len(successful_layers)} 个成功 (已生成), {len(failed_layers)} 个失败/跳过")
|
||||
print("\n" + "--- [√] 成功生成的层 ---")
|
||||
if not successful_layers: print(" (无)")
|
||||
else:
|
||||
for i, name, type in successful_layers: print(f" - [{i:02d}] {name:<25} (类型: {type})")
|
||||
print("\n" + "--- [X] 失败或跳过的层 ---")
|
||||
if not failed_layers: print(" (无)")
|
||||
else:
|
||||
for i, name, type, error in failed_layers:
|
||||
print(f" - [{i:02d}] {name:<25} (类型: {type})")
|
||||
print(f" └> 失败原因: {error[:100]}...")
|
||||
print("\n" + "="*60)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
# --- [!! 重构 !!] (以下逻辑来自 yolo_predict_visualize_nn_V1.py) ---
|
||||
|
||||
config.show_config_summary()
|
||||
|
||||
parser = argparse.ArgumentParser(description="全层热图生成器 (V1 模型选择模式)")
|
||||
|
||||
# 1. (来自 V1) --model 参数
|
||||
parser.add_argument(
|
||||
"--model",
|
||||
type=str,
|
||||
required=True,
|
||||
choices=list(config.MODEL_CONFIGS.keys()),
|
||||
help="选择一个基础模型类型来筛选其训练历史。"
|
||||
)
|
||||
# 2. (来自 V1) --pt_name 参数
|
||||
parser.add_argument(
|
||||
"--pt_name",
|
||||
type=str,
|
||||
default="best.pt",
|
||||
help="要使用的权重文件名 (例如 'best.pt' 或 'epoch100.pt')。"
|
||||
)
|
||||
# 3. (来自您) --source 参数
|
||||
parser.add_argument(
|
||||
"--source",
|
||||
type=str,
|
||||
default="random",
|
||||
help="测试图像源: 'random' (从 config.TEST_IMAGE_DIR 随机选), "
|
||||
"'path/to/image.jpg' (指定文件), "
|
||||
"或 'path/to/dir/' (从该目录随机选)"
|
||||
)
|
||||
# 4. (来自您) --cam_method 参数
|
||||
parser.add_argument(
|
||||
"--cam_method",
|
||||
type=str,
|
||||
default="GradCAM",
|
||||
choices=list(CAM_METHODS.keys()),
|
||||
help=f"选择要使用的 CAM 可视化方法。默认为: GradCAM。"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
# --- (来自 V1) 模型搜索和选择逻辑 ---
|
||||
available_runs = find_trained_models(config.PREDICT_BEST_MODEL_DIR, args.model, args.pt_name)
|
||||
|
||||
run_to_use = None
|
||||
|
||||
if not available_runs:
|
||||
logging.error(f"错误:在 {config.PREDICT_BEST_MODEL_DIR} 中未找到模型 '{args.model}' (使用 '{args.pt_name}') 的有效训练记录。")
|
||||
sys.exit(1)
|
||||
|
||||
elif len(available_runs) == 1:
|
||||
run_to_use = available_runs[0]
|
||||
logging.info(f"只找到一个训练版本,已自动选择: {run_to_use}")
|
||||
|
||||
else:
|
||||
print(f"\n为模型 '{args.model}' (使用 '{args.pt_name}') 找到多个训练版本:")
|
||||
for i, run_name in enumerate(available_runs, 1):
|
||||
print(f" [{i}] {run_name}")
|
||||
|
||||
while True:
|
||||
try:
|
||||
choice = input(f"请输入您想使用的版本序号 (1-{len(available_runs)}): ")
|
||||
choice_index = int(choice)
|
||||
if 1 <= choice_index <= len(available_runs):
|
||||
run_to_use = available_runs[choice_index - 1]
|
||||
break
|
||||
else:
|
||||
print("错误:输入无效,请输入列表中的序号。")
|
||||
except ValueError:
|
||||
print("错误:请输入一个数字。")
|
||||
except (KeyboardInterrupt, EOFError):
|
||||
print("\n操作已取消。")
|
||||
sys.exit(0)
|
||||
|
||||
# --- (!! 融合 !!) 执行逻辑 ---
|
||||
if run_to_use:
|
||||
# 1. (来自 V1) 确定最终模型路径
|
||||
model_to_use_path = config.PREDICT_BEST_MODEL_DIR / run_to_use / 'weights' / args.pt_name
|
||||
|
||||
if not model_to_use_path.exists():
|
||||
logging.error(f"严重错误:找不到权重文件 {model_to_use_path}。")
|
||||
sys.exit(1)
|
||||
|
||||
# 2. (来自您) 确定最终图像路径
|
||||
source_input = args.source
|
||||
image_to_test = None
|
||||
source_dir_for_random = None
|
||||
|
||||
if source_input.lower() == "random":
|
||||
source_dir_for_random = config.TEST_IMAGE_DIR
|
||||
logging.info(f"Source='random'。将从 config.TEST_IMAGE_DIR 随机选择图像: {source_dir_for_random}")
|
||||
elif Path(source_input).is_file():
|
||||
image_to_test = source_input
|
||||
logging.info(f"Source 是一个特定文件: {image_to_test}")
|
||||
elif Path(source_input).is_dir():
|
||||
source_dir_for_random = Path(source_input)
|
||||
logging.info(f"Source 是一个目录。将从该目录随机选择图像: {source_dir_for_random}")
|
||||
else:
|
||||
logging.error(f"错误: 指定的源路径无效: {source_input}")
|
||||
sys.exit(1)
|
||||
|
||||
# 如果需要随机选择...
|
||||
if image_to_test is None:
|
||||
if not source_dir_for_random or not source_dir_for_random.is_dir():
|
||||
logging.error(f"错误: 用于随机选择的目录无效: {source_dir_for_random}")
|
||||
sys.exit(1)
|
||||
|
||||
image_paths = sorted(
|
||||
list(source_dir_for_random.glob("*.jpg")) +
|
||||
list(source_dir_for_random.glob("*.jpeg")) +
|
||||
list(source_dir_for_random.glob("*.png"))
|
||||
)
|
||||
if not image_paths:
|
||||
logging.error(f"错误: 在目录 {source_dir_for_random} 中未找到任何图像 (.jpg, .jpeg, .png)。")
|
||||
sys.exit(1)
|
||||
|
||||
image_to_test = str(random.choice(image_paths))
|
||||
logging.info(f"已随机选择图像: {Path(image_to_test).name}")
|
||||
|
||||
# 3. (融合) 调用核心函数
|
||||
if image_to_test:
|
||||
generate_heatmaps_for_all_layers(
|
||||
model_path=str(model_to_use_path), # <-- 使用 V1 找到的路径
|
||||
image_path=image_to_test, # <-- 使用您的逻辑找到的路径
|
||||
cam_method_name=args.cam_method, # <-- 使用您的参数
|
||||
model_name = args.model
|
||||
)
|
||||
Reference in New Issue
Block a user