Files
Seg_Data_Server/Tool-可视化/yolov11_heatmap_V2.py
2026-05-20 15:05:35 +08:00

190 lines
8.5 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import os
import cv2
import numpy as np
import torch
from ultralytics import YOLO
from pytorch_grad_cam import (
GradCAM, GradCAMPlusPlus, XGradCAM, EigenCAM,
HiResCAM, LayerCAM, RandomCAM, EigenGradCAM, KPCA_CAM
)
from pytorch_grad_cam.utils.image import show_cam_on_image
from pytorch_grad_cam.base_cam import BaseCAM
class ActivationMaximizationTarget:
    """CAM target whose score is the mean activation of one channel.

    Instances are passed to pytorch_grad_cam via ``targets=[...]``; the scalar
    returned by ``__call__`` is what the CAM method attributes.

    Args:
        channel: index selected on the channel axis (default 0).

    NOTE(review): for a 3-D input the index is applied to axis 0, which is the
    channel axis only if the tensor is laid out [C, H, W]. With YOLO models the
    tensor handed in may be a head/prediction output rather than a feature map;
    confirm the layout for your model.
    """

    def __init__(self, channel=0):
        self.channel = channel

    def __call__(self, model_output):
        """Return the mean of ``model_output`` restricted to ``self.channel``.

        Raises:
            ValueError: if ``model_output`` is neither 3-D nor 4-D.
        """
        rank = model_output.ndim
        if rank == 4:
            # [B, C, H, W]: channel is axis 1; the mean also spans the batch.
            selected = model_output[:, self.channel, :, :]
        elif rank == 3:
            # [C, H, W]: channel is axis 0.
            selected = model_output[self.channel, :, :]
        else:
            raise ValueError(f"Unsupported model_output shape: {model_output.shape}")
        return selected.mean()
class YoloFeatureExtractor(torch.nn.Module):
    """Wrap a model so its forward pass returns a single 4-D feature tensor.

    Presumably intended as an alternative ``model=`` argument for a
    pytorch_grad_cam CAM, so the CAM sees a feature map instead of raw
    YOLO head outputs.

    Args:
        model: module/callable whose output is a list or tuple. Its second
            element (or its first, when it has only one element) should be a
            list/tuple of tensors; the first 4-D tensor found there is returned.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, x):
        """Run the wrapped model on ``x`` and return its first 4-D feature tensor.

        Raises:
            RuntimeError: if the output is not a non-empty list/tuple or the
                candidate group contains no 4-D tensor.
        """
        out = self.model(x)
        if isinstance(out, (list, tuple)) and out:
            # Prefer the second output group; fall back to the first when the
            # output has a single element.
            feat_candidates = out[1] if len(out) > 1 else out[0]
            if isinstance(feat_candidates, (list, tuple)):
                for feat in feat_candidates:
                    if isinstance(feat, torch.Tensor) and feat.ndim == 4:
                        # Make sure the returned tensor participates in autograd
                        # (needed when it is a leaf, e.g. produced without grad).
                        feat.requires_grad_()
                        return feat
        raise RuntimeError("未找到可用于 CAM 的特征图")
# ------------------------------ Paths -------------------------------
# Paths are built with os.path.join instead of "\"-separated literals: on POSIX
# a backslash is not a path separator, so basename() would return the whole
# string and listdir() would look for a directory literally named
# "Data\images\train". On Windows the result is identical to the old literals.
# e.g. model_path = "yolo11n-seg.pt"  # <- replace with your own model path
model_path = os.path.join("runs", "segment", "train6", "weights", "best.pt")
model_name = os.path.splitext(os.path.basename(model_path))[0]
# All CAM outputs go under this directory, one sub-directory per method.
output_root = f"result_CAM_Method_{model_name}"
input_dir = os.path.join("Data", "images", "train")
# ------------------------------ Load model -------------------------------
model = YOLO(model_path)
# Put the underlying torch module in inference mode (affects layers such as
# BatchNorm/Dropout).
model.model.eval()
# Device holding the weights; input tensors are moved here before inference.
device = next(iter(model.model.parameters())).device
# ------------------------------ Collect every Conv2d layer -------------------------------
# Each entry is (position in the model.model.modules() traversal, module).
# The position is used in output file names to tell layers apart.
conv_layers = [
    (idx, module)
    for idx, module in enumerate(model.model.modules())
    if isinstance(module, torch.nn.Conv2d)
]
if len(conv_layers) == 0:
    raise RuntimeError("未找到卷积层")
# ------------------------------ CAM methods -------------------------------
# Display name -> pytorch_grad_cam class. The name is also the output
# sub-directory under output_root.
cam_methods = dict(
    GradCAM=GradCAM,
    GradCAMPlusPlus=GradCAMPlusPlus,
    XGradCAM=XGradCAM,
    # EigenCAM=EigenCAM,  # high risk: tends to exhaust memory
    HiResCAM=HiResCAM,
    LayerCAM=LayerCAM,
    RandomCAM=RandomCAM,
    EigenGradCAM=EigenGradCAM,
    # KPCA_CAM=KPCA_CAM,  # high risk: tends to exhaust memory
)

# Create one output directory per enabled method.
for method_dir in (os.path.join(output_root, name) for name in cam_methods):
    os.makedirs(method_dir, exist_ok=True)
# ------------------------------ Process images -------------------------------
# For every image in input_dir:
#   1. letterbox it to target_size x target_size;
#   2. run every enabled CAM method on every Conv2d layer, with the
#      channel-0 ActivationMaximizationTarget;
#   3. crop the padding off the CAM, resize it to the original resolution,
#      min-max normalize it;
#   4. write <output_root>/<method>/<layer_idx>_<LayerClass>_<image>_{overlay,heatmap}.jpg.
# Failures of a single (layer, method) pair are reported and skipped.
#
# Alternative: pass model=YoloFeatureExtractor(model.model) to the CAM class
# instead of model.model.

target_size = 640              # side length of the square network input
max_feature_numel = 4096 ** 2  # EigenCAM/KPCA_CAM are skipped above ~16M activation elements


def _letterbox(image, size):
    """Resize ``image`` to fit a ``size`` x ``size`` square (aspect kept) and pad with gray 114.

    Returns:
        (padded_image, (new_w, new_h), (pad_left, pad_top)) where new_w/new_h is
        the size of the resized content inside the padded square.
    """
    h, w = image.shape[:2]
    scale = min(size / w, size / h)
    # max(1, ...) keeps cv2.resize valid for extremely thin images.
    new_w, new_h = max(1, int(w * scale)), max(1, int(h * scale))
    resized = cv2.resize(image, (new_w, new_h), interpolation=cv2.INTER_LINEAR)
    pad_w, pad_h = size - new_w, size - new_h
    pad_left, pad_top = pad_w // 2, pad_h // 2
    padded = cv2.copyMakeBorder(resized, pad_top, pad_h - pad_top, pad_left, pad_w - pad_left,
                                cv2.BORDER_CONSTANT, value=(114, 114, 114))
    return padded, (new_w, new_h), (pad_left, pad_top)


def _probe_activation(net, layer, x):
    """Return (shape, numel) of ``layer``'s output during a no-grad forward pass of ``net`` on ``x``.

    A forward hook measures the activation the layer really produces inside the
    network. (Calling the layer directly on the 3-channel image fails for any
    layer that does not take 3 input channels.)

    Raises:
        RuntimeError: if ``layer`` is not executed by ``net(x)``.
    """
    seen = []
    handle = layer.register_forward_hook(
        lambda _module, _inputs, output: seen.append((output.shape, output.numel()))
    )
    try:
        with torch.no_grad():
            net(x)
    finally:
        handle.remove()
    if not seen:
        raise RuntimeError("layer was not executed during the forward pass")
    return seen[0]


def _cam_to_2d(cam_output):
    """Reduce a CAM result to a 2-D numpy map.

    Takes the first element of a list/tuple, converts tensors to numpy, then:
    4-D -> first item averaged over axis 0; 3-D -> averaged over axis 0;
    2-D -> unchanged. Other ranks are returned unchanged so the caller can report them.
    """
    cam = cam_output[0] if isinstance(cam_output, (list, tuple)) else cam_output
    if isinstance(cam, torch.Tensor):
        cam = cam.detach().cpu().numpy()
    if cam.ndim == 4:
        return cam[0].mean(axis=0)
    if cam.ndim == 3:
        return cam.mean(axis=0)
    return cam


def _normalize_to_unit(cam):
    """Min-max scale ``cam`` into [0, 1); a constant map becomes all zeros."""
    lo, hi = float(cam.min()), float(cam.max())
    return (cam - lo) / (hi - lo + 1e-8)


for img_name in sorted(os.listdir(input_dir)):
    if not img_name.lower().endswith((".jpg", ".jpeg", ".png", ".bmp")):
        continue
    img_path = os.path.join(input_dir, img_name)
    orig_image = cv2.imread(img_path)
    if orig_image is None:
        continue
    orig_image = cv2.cvtColor(orig_image, cv2.COLOR_BGR2RGB)
    orig_h, orig_w = orig_image.shape[:2]
    orig_image_float = orig_image.astype(np.float32) / 255.0
    base = os.path.splitext(img_name)[0]

    padded_image, (new_w, new_h), (pad_left, pad_top) = _letterbox(orig_image, target_size)
    padded_image_float = padded_image.astype(np.float32) / 255.0
    input_tensor = torch.from_numpy(padded_image_float.transpose(2, 0, 1))[None].to(device)
    # Make the input require grad so gradient-based CAMs get a backward graph.
    # NOTE(review): presumably needed because the loaded model's parameters do
    # not require grad — confirm for your ultralytics version.
    input_tensor.requires_grad_()

    # -------------------- every layer x every method ---------------------
    for layer_idx, layer in conv_layers:
        layer_name = layer.__class__.__name__
        print(f"\nProcessing Layer {layer_idx}: {layer_name}")
        for method_name, cam_class in cam_methods.items():
            try:
                # Memory-heavy methods: skip layers whose activation is too large.
                if cam_class in (EigenCAM, KPCA_CAM):
                    try:
                        feat_shape, numel = _probe_activation(model.model, layer, input_tensor)
                    except Exception as e:
                        print(f"[SKIP] {method_name} on Layer {layer_idx}: 特征图检查失败: {e}")
                        continue
                    if numel > max_feature_numel:
                        print(f"[SKIP] {method_name} on Layer {layer_idx}: 特征图过大 shape={feat_shape}")
                        continue

                print(f" Using {method_name}...")
                with cam_class(model=model.model, target_layers=[layer]) as cam:
                    targets = [ActivationMaximizationTarget(channel=0)]
                    cam_output = cam(input_tensor=input_tensor, targets=targets)

                cam_result = _cam_to_2d(cam_output)
                if cam_result.ndim != 2:
                    print(f"[SKIP] {method_name} on Layer {layer_idx}: CAM 结果维度异常 {cam_result.shape}")
                    continue

                # Drop the letterbox padding, then map back to the original resolution.
                # Assumes the CAM is target_size x target_size (pytorch_grad_cam
                # rescales CAMs to the input size) — verify for your version.
                cam_cropped = cam_result[pad_top:pad_top + new_h, pad_left:pad_left + new_w]
                cam_resized = cv2.resize(cam_cropped, (orig_w, orig_h), interpolation=cv2.INTER_LINEAR)
                cam_unit = _normalize_to_unit(cam_resized)

                overlay_bgr = cv2.cvtColor(
                    show_cam_on_image(orig_image_float, cam_unit, use_rgb=True),
                    cv2.COLOR_RGB2BGR,
                )
                # applyColorMap already returns BGR, which is what imwrite expects.
                heatmap_bgr = cv2.applyColorMap(np.uint8(255 * cam_unit), cv2.COLORMAP_JET)

                fname = f"{layer_idx}_{layer_name}_{base}"
                method_dir = os.path.join(output_root, method_name)
                for out_path, out_image in (
                    (os.path.join(method_dir, f"{fname}_overlay.jpg"), overlay_bgr),
                    (os.path.join(method_dir, f"{fname}_heatmap.jpg"), heatmap_bgr),
                ):
                    # imwrite signals failure by returning False rather than raising.
                    if not cv2.imwrite(out_path, out_image):
                        print(f"[WARN] failed to write {out_path}")
            except Exception as e:
                # Best effort: one failing (layer, method) pair must not stop the run.
                print(f"[ERROR] {method_name} on Layer {layer_idx} failed: {e}")