# Original listing metadata: 140 lines, 5.9 KiB, Python
import os
|
||
import cv2
|
||
import numpy as np
|
||
import torch
|
||
from ultralytics import YOLO
|
||
from pytorch_grad_cam import GradCAM, GradCAMPlusPlus, XGradCAM, EigenCAM, HiResCAM, LayerCAM, RandomCAM, EigenGradCAM, KPCA_CAM
|
||
from pytorch_grad_cam.utils.image import show_cam_on_image
|
||
from pytorch_grad_cam.base_cam import BaseCAM
|
||
|
||
class ActivationMaximizationTarget:
    """CAM target that scores a feature map by the mean activation of one channel.

    Args:
        channel: index of the channel whose mean activation is the CAM objective.
    """

    def __init__(self, channel=0):
        self.channel = channel

    def __call__(self, model_output):
        """Return the scalar mean of channel ``self.channel`` of ``model_output``.

        Accepts a batched ``[B, C, H, W]`` tensor or a single-sample
        ``[C, H, W]`` tensor. pytorch_grad_cam's ``BaseCAM`` pairs each target
        with one element of the batched model output (``zip(targets, outputs)``),
        so targets normally receive the 3-D form. Indexing that form with
        ``[:, channel, :, :]`` raises ``IndexError``.

        Raises:
            ValueError: if ``model_output`` is neither 3-D nor 4-D.
        """
        if model_output.ndim == 4:
            return model_output[:, self.channel].mean()
        if model_output.ndim == 3:
            return model_output[self.channel].mean()
        raise ValueError(
            "ActivationMaximizationTarget expects a 3-D or 4-D tensor, "
            f"got shape {tuple(model_output.shape)}"
        )
# 1. Wrapper module: expose only a single 4-D feature map to the CAM library.
class YoloFeatureExtractor(torch.nn.Module):
    """Wrap a detection/segmentation model so its forward returns one 4-D tensor.

    The wrapped model's output is expected to be a list/tuple whose second
    element is itself a list/tuple of auxiliary outputs. The first 4-D tensor
    found there is returned (presumably ``[B, C, H, W]``; for ultralytics heads
    in eval mode this is believed to be the extra outputs tuple. TODO confirm
    against the installed ultralytics version).
    """

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, x):
        """Run the wrapped model and return the first 4-D auxiliary tensor.

        Raises:
            RuntimeError: if no 4-D tensor is found in ``output[1]``.
        """
        raw = self.model(x)
        extras = raw[1] if isinstance(raw, (list, tuple)) and len(raw) > 1 else None

        if isinstance(extras, (list, tuple)):
            for idx, candidate in enumerate(extras):
                if not (isinstance(candidate, torch.Tensor) and candidate.ndim == 4):
                    print(f"[DEBUG] item[1][{idx}] 类型: {type(candidate)}")
                    continue
                print(f"[DEBUG] 找到特征图: item[1][{idx}] -> shape={candidate.shape}")
                # NOTE(review): if the tensor is a leaf that did not require grad
                # (e.g. frozen parameters), this only marks the tensor itself, so
                # backward() can run. It does not make gradients reach any earlier
                # layer that a CAM hook is attached to.
                candidate.requires_grad_()
                return candidate

        raise RuntimeError("未找到可用于 CAM 的特征图")
# 2. Load the YOLO11 model
# model_path = r"runs\segment\train2\weights\best.pt"
model_path = "yolo11n-seg.pt"  # replace with the path to your model
model = YOLO(model_path)
model.model.eval()
# Gradient-based CAM methods (GradCAM, HiResCAM, LayerCAM, ...) can only hook
# gradients on activations that require grad. The output-level
# `requires_grad_()` workaround in YoloFeatureExtractor indicates the loaded
# parameters do not require grad. That workaround lets backward() run but leaves
# the hooked layer without gradients, so enable them on the model itself.
# eval() above still keeps BatchNorm/Dropout in inference behaviour.
model.model.requires_grad_(True)

# 3. Pick the CAM hook layer: the last torch.nn.Conv2d in module-registration order.
# NOTE(review): the last registered Conv2d may sit in a head branch that does not
# feed the feature map returned by YoloFeatureExtractor. If gradient-based maps
# come out empty, choose a layer on that path instead.
target_layers = [m for m in model.model.modules() if isinstance(m, torch.nn.Conv2d)]
if not target_layers:
    raise RuntimeError("未找到卷积层")
target_layers = [target_layers[-1]]  # keep only the last Conv2d
# 4. Output folders: one sub-folder per CAM method under output_root
output_root = "result_CAM_Method"
os.makedirs(output_root, exist_ok=True)

# Display name -> pytorch_grad_cam class. The keys double as sub-folder names.
cam_methods = dict(
    GradCAM=GradCAM,
    GradCAMPlusPlus=GradCAMPlusPlus,
    XGradCAM=XGradCAM,
    EigenCAM=EigenCAM,
    HiResCAM=HiResCAM,
    LayerCAM=LayerCAM,
    RandomCAM=RandomCAM,
    EigenGradCAM=EigenGradCAM,
    KPCA_CAM=KPCA_CAM,
)

for method_dir in (os.path.join(output_root, key) for key in cam_methods):
    os.makedirs(method_dir, exist_ok=True)
# 5. Iterate over the input images; for every CAM method write an overlay and a heatmap.


def _letterbox(image, target_size=640, pad_color=(114, 114, 114)):
    """Letterbox *image* into a ``target_size`` x ``target_size`` square.

    The image is resized with its aspect ratio preserved. The remaining area is
    filled with ``pad_color``, split as evenly as possible between opposite sides.

    Returns:
        ``(padded_image, new_w, new_h, pad_left, pad_top)``. The last four values
        locate the resized content inside ``padded_image``.
    """
    orig_h, orig_w = image.shape[:2]
    scale = min(target_size / orig_w, target_size / orig_h)
    # Clamp to >= 1 px. A very thin image would otherwise truncate one side to 0,
    # which cv2.resize rejects.
    new_w = max(1, int(orig_w * scale))
    new_h = max(1, int(orig_h * scale))
    resized = cv2.resize(image, (new_w, new_h), interpolation=cv2.INTER_LINEAR)

    pad_w = target_size - new_w
    pad_h = target_size - new_h
    pad_left = pad_w // 2
    pad_top = pad_h // 2
    padded = cv2.copyMakeBorder(
        resized,
        pad_top, pad_h - pad_top,
        pad_left, pad_w - pad_left,
        cv2.BORDER_CONSTANT,
        value=pad_color,
    )
    return padded, new_w, new_h, pad_left, pad_top


def _to_2d_cam(cam_result):
    """Coerce one CAM output to a 2-D numpy array, averaging over axis 0 if it is 3-D.

    Raises:
        ValueError: if the result is neither 2-D nor 3-D.
    """
    if isinstance(cam_result, torch.Tensor):
        cam_result = cam_result.detach().cpu().numpy()
    cam_result = np.asarray(cam_result)
    if cam_result.ndim == 3:
        return cam_result.mean(axis=0)
    if cam_result.ndim != 2:
        raise ValueError(f"[CAM ERROR] Unexpected CAM shape: {cam_result.shape}")
    return cam_result


def _normalize01(cam):
    """Shift and scale *cam* into [0, 1]. A constant map becomes all zeros."""
    cam = cam - cam.min()
    return cam / (cam.max() + 1e-8)


input_dir = "Data/img_dir"
image_extensions = (".jpg", ".jpeg", ".png", ".bmp")

# Loop invariants, built once instead of once per image / per method.
device = next(model.model.parameters()).device
# 6. Wrap the model so the CAM library sees a single feature-map output.
wrapped_model = YoloFeatureExtractor(model.model)
targets = [ActivationMaximizationTarget(channel=0)]

# os.listdir order is arbitrary; sorting makes runs reproducible.
for img_name in sorted(os.listdir(input_dir)):
    if not img_name.lower().endswith(image_extensions):
        continue
    img_path = os.path.join(input_dir, img_name)
    orig_image = cv2.imread(img_path)
    if orig_image is None:  # unreadable or non-image file
        continue
    orig_image = cv2.cvtColor(orig_image, cv2.COLOR_BGR2RGB)
    orig_h, orig_w = orig_image.shape[:2]
    orig_image_float = orig_image.astype(np.float32) / 255.0
    base_name = os.path.splitext(img_name)[0]

    # Letterbox to 640x640, scale to [0, 1], then HWC -> 1xCxHxW tensor.
    padded_image, new_w, new_h, pad_left, pad_top = _letterbox(orig_image)
    padded_image_float = padded_image.astype(np.float32) / 255.0
    input_tensor = torch.from_numpy(padded_image_float.transpose(2, 0, 1))[None].to(device)

    for method_name, cam_class in cam_methods.items():
        # Reset per method so a failed run can never reuse the previous map.
        grayscale_cam = None
        try:
            with cam_class(model=wrapped_model, target_layers=target_layers) as cam:
                grayscale_cam = _to_2d_cam(cam(input_tensor=input_tensor, targets=targets)[0])
        except Exception as exc:  # batch boundary: one failing method must not abort the rest
            print(f"[CAM ERROR] {method_name} failed on {img_name}: {exc!r}")
            continue
        if grayscale_cam is None:
            # NOTE(review): pytorch_grad_cam's context manager appears to swallow
            # some exceptions (e.g. IndexError) in __exit__ after printing them.
            # Skip here rather than save a stale or undefined map.
            continue

        # Remove the letterbox padding, then map back to the original resolution.
        cam_cropped = grayscale_cam[pad_top:pad_top + new_h, pad_left:pad_left + new_w]
        if cam_cropped.size == 0:
            continue
        cam_resized = _normalize01(
            cv2.resize(cam_cropped, (orig_w, orig_h), interpolation=cv2.INTER_LINEAR)
        )

        overlay_rgb = show_cam_on_image(orig_image_float, cam_resized, use_rgb=True)
        # applyColorMap returns BGR, which is what cv2.imwrite expects.
        heatmap_bgr = cv2.applyColorMap(np.uint8(255 * cam_resized), cv2.COLORMAP_JET)

        method_dir = os.path.join(output_root, method_name)
        outputs = (
            (f"{base_name}_overlay.jpg", cv2.cvtColor(overlay_rgb, cv2.COLOR_RGB2BGR)),
            (f"{base_name}_heatmap.jpg", heatmap_bgr),
        )
        for file_name, image_bgr in outputs:
            out_path = os.path.join(method_dir, file_name)
            # cv2.imwrite reports failure via its return value, not an exception.
            if not cv2.imwrite(out_path, image_bgr):
                print(f"[WARN] failed to write {out_path}")