190 lines
8.5 KiB
Python
190 lines
8.5 KiB
Python
import os
|
||
import cv2
|
||
import numpy as np
|
||
import torch
|
||
from ultralytics import YOLO
|
||
from pytorch_grad_cam import (
|
||
GradCAM, GradCAMPlusPlus, XGradCAM, EigenCAM,
|
||
HiResCAM, LayerCAM, RandomCAM, EigenGradCAM, KPCA_CAM
|
||
)
|
||
from pytorch_grad_cam.utils.image import show_cam_on_image
|
||
from pytorch_grad_cam.base_cam import BaseCAM
|
||
|
||
|
||
class ActivationMaximizationTarget:
    """Callable CAM target that scores an output by the mean of one channel.

    The scalar it returns can serve as the objective whose gradients drive a
    class-activation-map computation.
    """

    def __init__(self, channel=0):
        """Store the index of the channel whose mean activation is returned."""
        self.channel = channel

    def __call__(self, model_output):
        """Return the mean of the configured channel of ``model_output``.

        A 4-D input is indexed as ``[:, channel, :, :]`` and a 3-D input as
        ``[channel, :, :]``.

        Raises:
            ValueError: if ``model_output`` is neither 3-D nor 4-D.
        """
        if model_output.ndim not in (3, 4):
            raise ValueError(f"Unsupported model_output shape: {model_output.shape}")
        # In both supported layouts the channel axis is third from the end, so
        # a single ellipsis-based index handles either rank.
        return model_output[..., self.channel, :, :].mean()
|
||
|
||
|
||
class YoloFeatureExtractor(torch.nn.Module):
    """Wrap a model so that ``forward`` returns one 4-D feature map.

    The wrapped model is expected to return a list/tuple with at least two
    elements whose second element is itself a list/tuple of candidates. The
    first candidate that is a 4-D ``torch.Tensor`` is returned.
    """

    def __init__(self, model):
        """Keep a reference to the model whose outputs are searched."""
        super().__init__()
        self.model = model

    def forward(self, x):
        """Run the wrapped model on ``x`` and return the first 4-D candidate.

        Gradient tracking is switched on for the returned tensor.

        Raises:
            RuntimeError: if the output has fewer than two elements, its second
                element is not a list/tuple, or no 4-D tensor is found in it.
        """
        out = self.model(x)
        # One check is enough: the output must be a sequence that actually has
        # a second element, which is where the candidates are looked up.
        if isinstance(out, (list, tuple)) and len(out) > 1:
            feat_candidates = out[1]
            if isinstance(feat_candidates, (list, tuple)):
                for feat in feat_candidates:
                    if isinstance(feat, torch.Tensor) and feat.ndim == 4:
                        feat.requires_grad_()
                        return feat
        raise RuntimeError("未找到可用于 CAM 的特征图")
|
||
|
||
|
||
# ------------------------------ Paths -------------------------------
# Paths are assembled with os.path.join so the separators match the host OS.
# A literal backslash path is a single file name on POSIX systems, which would
# also make os.path.basename below return the whole string.
# Example of an alternative checkpoint: model_path = "yolo11n-seg.pt"
model_path = os.path.join("runs", "segment", "train6", "weights", "best.pt")
# Checkpoint file name without its extension; used to label the output folder.
model_name = os.path.splitext(os.path.basename(model_path))[0]
output_root = f"result_CAM_Method_{model_name}"
input_dir = os.path.join("Data", "images", "train")

# ------------------------------ Load model -------------------------------
model = YOLO(model_path)
model.model.eval()
# Input tensors are later moved to the device the model parameters live on.
device = next(model.model.parameters()).device
|
||
|
||
# ------------------------------ Collect all convolution layers -------------------------------
# Each entry pairs a module's position in ``model.model.modules()`` with the
# module itself; only ``torch.nn.Conv2d`` instances are kept.
# Reverse ``conv_layers`` afterwards to visit the layers in the opposite order.
conv_layers = [
    (idx, layer)
    for idx, layer in enumerate(model.model.modules())
    if isinstance(layer, torch.nn.Conv2d)
]
if not conv_layers:
    raise RuntimeError("未找到卷积层")
|
||
|
||
# ------------------------------ CAM methods to run -------------------------------
# Maps a method name (also used as the output sub-directory name) to the CAM
# class implementing it. EigenCAM and KPCA_CAM are intentionally not listed:
# they were flagged as high-risk because they tend to exhaust memory.
cam_methods = dict(
    GradCAM=GradCAM,
    GradCAMPlusPlus=GradCAMPlusPlus,
    XGradCAM=XGradCAM,
    HiResCAM=HiResCAM,
    LayerCAM=LayerCAM,
    RandomCAM=RandomCAM,
    EigenGradCAM=EigenGradCAM,
)

# Create one output directory per method.
for method in cam_methods:
    method_dir = os.path.join(output_root, method)
    os.makedirs(method_dir, exist_ok=True)
|
||
|
||
# ------------------------------ Process images -------------------------------
TARGET_SIZE = 640  # side length of the square, letterboxed network input
PAD_COLOR = (114, 114, 114)  # constant border colour used for the padding
IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png", ".bmp")
# CAM classes that are only run when the target layer's activation is small
# enough (see MAX_FEATURE_ELEMENTS).
MEMORY_HEAVY_CAMS = (EigenCAM, KPCA_CAM)
MAX_FEATURE_ELEMENTS = 4096 ** 2  # roughly 16M elements


def _letterbox(image, target_size):
    """Resize ``image`` to fit a ``target_size`` square and pad the remainder.

    The aspect ratio is kept and the padding is split between both sides.

    Returns:
        ``(padded, (pad_top, pad_left, new_h, new_w))`` where the tuple says
        where the resized picture sits inside the padded canvas.
    """
    height, width = image.shape[:2]
    scale = min(target_size / width, target_size / height)
    # Clamp to one pixel so extreme aspect ratios never produce an empty resize.
    new_w = max(1, int(width * scale))
    new_h = max(1, int(height * scale))
    resized = cv2.resize(image, (new_w, new_h), interpolation=cv2.INTER_LINEAR)
    pad_w, pad_h = target_size - new_w, target_size - new_h
    pad_left, pad_top = pad_w // 2, pad_h // 2
    padded = cv2.copyMakeBorder(
        resized, pad_top, pad_h - pad_top, pad_left, pad_w - pad_left,
        cv2.BORDER_CONSTANT, value=PAD_COLOR,
    )
    return padded, (pad_top, pad_left, new_h, new_w)


def _layer_output_shape(net, layer, x):
    """Return the shape ``layer`` outputs while ``net`` processes ``x``.

    A forward hook records the shape during one gradient-free pass through the
    whole network, so the layer receives its real input rather than the raw
    image (which only matches layers expecting a 3-channel input).

    Raises:
        RuntimeError: if ``layer`` is never executed during the pass.
    """
    shapes = []
    handle = layer.register_forward_hook(
        # list.append returns None, so the hook leaves the output untouched.
        lambda _module, _inputs, output: shapes.append(output.shape)
    )
    try:
        with torch.no_grad():
            net(x)
    finally:
        handle.remove()
    if not shapes:
        raise RuntimeError("layer was not executed during the forward pass")
    return shapes[0]


def _reduce_cam(cam_output):
    """Convert a CAM result to numpy and collapse 3-D/4-D results to 2-D.

    A list/tuple contributes its first element, tensors are detached and moved
    to the CPU, a 4-D result uses its first item averaged over axis 0 and a
    3-D result is averaged over axis 0. Other ranks are returned unchanged so
    the caller can report them.
    """
    cam_result = cam_output[0] if isinstance(cam_output, (list, tuple)) else cam_output
    if isinstance(cam_result, torch.Tensor):
        cam_result = cam_result.detach().cpu().numpy()
    if cam_result.ndim == 4:
        return cam_result[0].mean(axis=0)
    if cam_result.ndim == 3:
        return cam_result.mean(axis=0)
    return cam_result


# Loop-invariant wrapper; only needed by the alternative "Way 1" set-up
# mentioned inside the loop.
wrapped_model = YoloFeatureExtractor(model.model)

for img_name in os.listdir(input_dir):
    if not img_name.lower().endswith(IMAGE_EXTENSIONS):
        continue

    img_path = os.path.join(input_dir, img_name)
    orig_image = cv2.imread(img_path)
    if orig_image is None:  # cv2 could not decode the file
        continue
    orig_image = cv2.cvtColor(orig_image, cv2.COLOR_BGR2RGB)
    orig_h, orig_w = orig_image.shape[:2]
    # Float copy in [0, 1] for the overlay; identical for every layer/method.
    orig_image_float = orig_image.astype(np.float32) / 255.0
    base = os.path.splitext(img_name)[0]

    # Resize + padding
    padded_image, (pad_top, pad_left, new_h, new_w) = _letterbox(orig_image, TARGET_SIZE)
    padded_image_float = padded_image.astype(np.float32) / 255.0
    # HWC -> CHW, add a batch axis, move to the model's device.
    input_tensor = torch.from_numpy(padded_image_float.transpose(2, 0, 1))[None].to(device)
    input_tensor.requires_grad_()

    # -------------------- Every layer x every method ---------------------
    for layer_idx, layer in conv_layers:
        layer_name = layer.__class__.__name__
        print(f"\nProcessing Layer {layer_idx}: {layer_name}")
        for method_name, cam_class in cam_methods.items():
            # Best-effort boundary: one failing layer/method combination is
            # reported and must not abort the rest of the run.
            try:
                if cam_class in MEMORY_HEAVY_CAMS:
                    # Skip early when the target activation is too large.
                    try:
                        feat_shape = _layer_output_shape(model.model, layer, input_tensor)
                    except Exception as e:
                        print(f"[SKIP] {method_name} on Layer {layer_idx}: 特征图检查失败: {e}")
                        continue
                    if feat_shape.numel() > MAX_FEATURE_ELEMENTS:
                        print(f"[SKIP] {method_name} on Layer {layer_idx}: 特征图过大 shape={feat_shape}")
                        continue
                print(f"  Using {method_name}...")

                # "Way 2": run the CAM on the raw network. The alternative
                # ("Way 1") is to pass ``wrapped_model`` as ``model`` instead.
                # NOTE(review): the CAM context manager may suppress an
                # exception raised inside it -- confirm against the library;
                # the None sentinel keeps a stale result from being reused.
                cam_output = None
                with cam_class(model=model.model, target_layers=[layer]) as cam:
                    targets = [ActivationMaximizationTarget(channel=0)]
                    cam_output = cam(input_tensor=input_tensor, targets=targets)
                if cam_output is None:
                    continue

                cam_result = _reduce_cam(cam_output)
                if cam_result.ndim != 2:
                    print(f"[SKIP] {method_name} on Layer {layer_idx}:CAM 结果维度异常 {cam_result.shape}")
                    continue

                # Undo the letterbox: drop the padding, then scale back to the
                # original image size.
                cam_cropped = cam_result[pad_top:pad_top + new_h, pad_left:pad_left + new_w]
                cam_resized = cv2.resize(cam_cropped, (orig_w, orig_h), interpolation=cv2.INTER_LINEAR)
                # Min-max normalisation: divide by the value range (max - min),
                # not by max alone, so the map spans [0, 1] even when the
                # cropped minimum is not zero.
                cam_min = cam_resized.min()
                cam_resized = (cam_resized - cam_min) / (cam_resized.max() - cam_min + 1e-8)

                overlay_image = show_cam_on_image(orig_image_float, cam_resized, use_rgb=True)
                # applyColorMap already yields BGR, which is what imwrite expects.
                heatmap_bgr = cv2.applyColorMap(np.uint8(255 * cam_resized), cv2.COLORMAP_JET)
                overlay_bgr = cv2.cvtColor(overlay_image, cv2.COLOR_RGB2BGR)

                fname = f"{layer_idx}_{layer_name}_{base}"
                overlay_path = os.path.join(output_root, method_name, f"{fname}_overlay.jpg")
                heatmap_path = os.path.join(output_root, method_name, f"{fname}_heatmap.jpg")
                for out_path, out_image in ((overlay_path, overlay_bgr), (heatmap_path, heatmap_bgr)):
                    # imwrite reports failure through its return value only.
                    if not cv2.imwrite(out_path, out_image):
                        raise OSError(f"cv2.imwrite could not write {out_path}")

            except Exception as e:
                print(f"[ERROR] {method_name} on Layer {layer_idx} failed: {e}")
|