Files
Seg_Data_Server/Tool-可视化/yolov11_heatmap_V1.py
2026-05-20 15:05:35 +08:00

140 lines
5.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import os
import cv2
import numpy as np
import torch
from ultralytics import YOLO
from pytorch_grad_cam import GradCAM, GradCAMPlusPlus, XGradCAM, EigenCAM, HiResCAM, LayerCAM, RandomCAM, EigenGradCAM, KPCA_CAM
from pytorch_grad_cam.utils.image import show_cam_on_image
from pytorch_grad_cam.base_cam import BaseCAM
class ActivationMaximizationTarget:
    """CAM target scoring the mean activation of one channel of a feature map.

    pytorch_grad_cam's ``BaseCAM`` evaluates each target on a single sample
    of the model output (``zip(targets, outputs)``), so the tensor passed in
    is normally ``[C, H, W]``. A batched ``[B, C, H, W]`` tensor is also
    accepted so the target can be called directly on a full batch.

    Args:
        channel: index of the channel whose mean activation is maximised.
    """

    def __init__(self, channel=0):
        self.channel = channel

    def __call__(self, model_output):
        """Return the scalar mean of the selected channel of ``model_output``.

        Raises:
            ValueError: if ``model_output`` is neither 3-D nor 4-D.
        """
        if model_output.ndim == 4:  # [B, C, H, W]: batched call
            return model_output[:, self.channel, :, :].mean()
        if model_output.ndim == 3:  # [C, H, W]: one sample, as BaseCAM passes it
            return model_output[self.channel, :, :].mean()
        raise ValueError(
            f"Expected a 3-D or 4-D feature map, got shape {tuple(model_output.shape)}"
        )
# 1. Wrapper module: run the YOLO model and hand CAM a single [B, C, H, W] feature map.
class YoloFeatureExtractor(torch.nn.Module):
    """Adapt a YOLO ``nn.Module`` so that ``forward`` returns one 4-D tensor.

    The wrapped model is called on the input. If its result is a list/tuple
    with at least two entries and ``out[1]`` is itself a list/tuple, the
    first entry of ``out[1]`` that is a 4-D ``torch.Tensor`` is returned.
    Each inspected entry is reported with a debug ``print``.

    Raises:
        RuntimeError: if no 4-D tensor is found in ``out[1]``.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, x):
        out = self.model(x)
        has_second = isinstance(out, (list, tuple)) and len(out) > 1
        candidates = out[1] if has_second else None
        if isinstance(candidates, (list, tuple)):
            for idx, item in enumerate(candidates):
                if not (isinstance(item, torch.Tensor) and item.ndim == 4):
                    print(f"[DEBUG] item[1][{idx}] 类型: {type(item)}")
                    continue
                print(f"[DEBUG] 找到特征图: item[1][{idx}] -> shape={item.shape}")
                # No-op if the tensor already requires grad; otherwise it becomes a
                # grad-requiring leaf. NOTE(review): in the latter case gradients
                # would not reach the hooked layer — TODO confirm this is ever needed.
                item.requires_grad_()
                return item
        raise RuntimeError("未找到可用于 CAM 的特征图")
# 2. Load the YOLOv11 model.
# model_path = r"runs\segment\train2\weights\best.pt"
model_path = "yolo11n-seg.pt"  # replace with the path to your own weights
model = YOLO(model_path)
model.model.eval()

# 3. Choose the CAM hook layer: the last nn.Conv2d in module-registration order.
# NOTE(review): in Ultralytics models the last registered Conv2d presumably sits
# inside the Detect/Segment head and may not lie on the path to the feature map
# that YoloFeatureExtractor returns. If gradient-based methods yield empty or
# flat maps, hook a backbone/neck layer instead — TODO confirm.
target_layers = [
    layer for layer in model.model.modules() if isinstance(layer, torch.nn.Conv2d)
]
if not target_layers:
    raise RuntimeError("未找到卷积层")
target_layers = target_layers[-1:]  # keep only the last Conv2d, still wrapped in a list

# 4. Output folders: output_root plus one sub-directory per CAM method.
output_root = "result_CAM_Method"
cam_methods = dict(
    GradCAM=GradCAM,
    GradCAMPlusPlus=GradCAMPlusPlus,
    XGradCAM=XGradCAM,
    EigenCAM=EigenCAM,
    HiResCAM=HiResCAM,
    LayerCAM=LayerCAM,
    RandomCAM=RandomCAM,
    EigenGradCAM=EigenGradCAM,
    KPCA_CAM=KPCA_CAM,
)
for out_dir in [output_root] + [os.path.join(output_root, name) for name in cam_methods]:
    os.makedirs(out_dir, exist_ok=True)
# 5. Iterate over the input images; for every CAM method write an overlay and a
#    raw heatmap into output_root/<method_name>/.
input_dir = "Data/img_dir"
IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png", ".bmp")
target_size = 640  # side length of the square letterboxed model input
pad_color = (114, 114, 114)  # letterbox padding value for every channel


def _imread_rgb(path):
    """Read *path* as an RGB uint8 image; return None if it cannot be read or decoded.

    ``np.fromfile`` + ``cv2.imdecode`` is used instead of ``cv2.imread``
    because ``cv2.imread`` cannot open non-ASCII paths on Windows and just
    returns None.
    """
    try:
        raw = np.fromfile(path, dtype=np.uint8)
    except OSError:
        return None
    if raw.size == 0:  # cv2.imdecode rejects an empty buffer
        return None
    image_bgr = cv2.imdecode(raw, cv2.IMREAD_COLOR)
    if image_bgr is None:
        return None
    return cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)


def _imwrite_bgr(path, image_bgr):
    """Encode *image_bgr* using *path*'s extension and write it; return True on success.

    Counterpart of ``_imread_rgb``: ``cv2.imencode`` + ``ndarray.tofile``
    handles non-ASCII paths, and failures are reported to the caller instead
    of being ignored.
    """
    ok, encoded = cv2.imencode(os.path.splitext(path)[1], image_bgr)
    if not ok:
        return False
    try:
        encoded.tofile(path)
    except OSError:
        return False
    return True


def _letterbox(image, size, color):
    """Resize *image* keeping its aspect ratio, then pad it to ``size x size``.

    Returns:
        (padded_image, new_w, new_h, pad_left, pad_top). The last four values
        locate the resized image inside the padded canvas.
    """
    src_h, src_w = image.shape[:2]
    scale = min(size / src_w, size / src_h)
    # round() (as Ultralytics' LetterBox does) rather than int() truncation,
    # which can lose a pixel to floating-point error. The clamp keeps each
    # side in [1, size], so cv2.resize never receives a zero-sized target.
    new_w = min(size, max(1, round(src_w * scale)))
    new_h = min(size, max(1, round(src_h * scale)))
    resized = cv2.resize(image, (new_w, new_h), interpolation=cv2.INTER_LINEAR)
    pad_left = (size - new_w) // 2
    pad_top = (size - new_h) // 2
    padded = cv2.copyMakeBorder(
        resized,
        pad_top,
        size - new_h - pad_top,
        pad_left,
        size - new_w - pad_left,
        cv2.BORDER_CONSTANT,
        value=color,
    )
    return padded, new_w, new_h, pad_left, pad_top


# Loop-invariant work, done once instead of once per image.
device = next(model.model.parameters()).device
# 6. Wrap the model so its forward returns a single feature map, as CAM expects.
wrapped_model = YoloFeatureExtractor(model.model)

for img_name in sorted(os.listdir(input_dir)):  # sorted: deterministic processing order
    if not img_name.lower().endswith(IMAGE_EXTENSIONS):
        continue
    img_path = os.path.join(input_dir, img_name)
    orig_image = _imread_rgb(img_path)
    if orig_image is None:
        print(f"[WARN] Skipping unreadable image: {img_path}")
        continue
    orig_h, orig_w = orig_image.shape[:2]
    orig_image_float = orig_image.astype(np.float32) / 255.0  # [0, 1] for show_cam_on_image
    base_name = os.path.splitext(img_name)[0]

    padded_image, new_w, new_h, pad_left, pad_top = _letterbox(orig_image, target_size, pad_color)
    # HWC uint8 -> 1xCxHxW float32 in [0, 1] on the model's device.
    padded_image_float = padded_image.astype(np.float32) / 255.0
    input_tensor = torch.from_numpy(padded_image_float.transpose(2, 0, 1))[None].to(device)

    for method_name, cam_class in cam_methods.items():
        with cam_class(model=wrapped_model, target_layers=target_layers) as cam:
            targets = [ActivationMaximizationTarget(channel=0)]
            cam_result = cam(input_tensor=input_tensor, targets=targets)[0]

        # Defensive: accept a tensor or a 3-D map, and reduce it to a 2-D [H, W] array.
        if isinstance(cam_result, torch.Tensor):
            cam_result = cam_result.detach().cpu().numpy()
        if cam_result.ndim == 3:
            cam_result = cam_result.mean(axis=0)
        elif cam_result.ndim != 2:
            raise ValueError(f"[CAM ERROR] Unexpected CAM shape: {cam_result.shape}")

        # Undo the letterbox: crop away the padding, then resize back to the original size.
        cam_cropped = cam_result[pad_top:pad_top + new_h, pad_left:pad_left + new_w]
        if cam_cropped.size == 0:
            continue
        cam_resized = cv2.resize(cam_cropped, (orig_w, orig_h), interpolation=cv2.INTER_LINEAR)
        cam_resized = cam_resized - cam_resized.min()
        cam_resized = cam_resized / (cam_resized.max() + 1e-8)  # min-max normalise to [0, 1]

        overlay_rgb = show_cam_on_image(orig_image_float, cam_resized, use_rgb=True)
        # applyColorMap already yields BGR, which is what the encoder expects.
        heatmap_bgr = cv2.applyColorMap(np.uint8(255 * cam_resized), cv2.COLORMAP_JET)

        out_dir = os.path.join(output_root, method_name)
        images_to_write = (
            (f"{base_name}_overlay.jpg", cv2.cvtColor(overlay_rgb, cv2.COLOR_RGB2BGR)),
            (f"{base_name}_heatmap.jpg", heatmap_bgr),
        )
        for file_name, image_bgr in images_to_write:
            out_path = os.path.join(out_dir, file_name)
            if not _imwrite_bgr(out_path, image_bgr):
                print(f"[WARN] Failed to write: {out_path}")