import torch
import numpy as np
import cv2
import matplotlib.pyplot as plt
from torchvision.models import vgg16, VGG16_Weights
from torchvision import transforms
from pytorch_grad_cam import GradCAM, GradCAMPlusPlus, ScoreCAM, HiResCAM, XGradCAM, EigenCAM, AblationCAM, FullGrad
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image
# --- Load the model ---
# Pretrained ImageNet VGG16, switched to eval mode so dropout is disabled
# during the forward pass used for the CAM.
model = vgg16(weights=VGG16_Weights.IMAGENET1K_V1)
model.eval()

# --- Target layer ---
# NOTE: in torchvision's VGG16, `features[-1]` is the final MaxPool2d module
# (the output of the last convolutional stage), not a Conv2d itself. This is
# the layer the pytorch-grad-cam README suggests for VGG models.
target_layer = model.features[-1]

# --- Load the image ---
img_path = "examples/both.png"
bgr_img = cv2.imread(img_path, cv2.IMREAD_COLOR)  # OpenCV loads as BGR
if bgr_img is None:
    # cv2.imread does not raise on a missing/unreadable file; it returns None,
    # which would otherwise surface as an opaque cv2.error in cvtColor below.
    raise FileNotFoundError(f"Could not read image: {img_path}")
rgb_img = cv2.cvtColor(bgr_img, cv2.COLOR_BGR2RGB)  # convert to RGB
# Scale to float32 in [0, 1]; this array is reused for the overlay below.
rgb_img = np.float32(rgb_img) / 255.0

# --- Preprocessing ---
# ToTensor only rescales uint8 input, so the float image above is converted
# HWC -> CHW as-is and then normalized with the ImageNet mean/std.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])
input_tensor = transform(rgb_img).unsqueeze(0)  # add the batch dimension

# --- Target class (e.g. ImageNet class 281 = tabby cat) ---
target_class = 281
targets = [ClassifierOutputTarget(target_class)]

# --- Build Grad-CAM and generate the CAM ---
# The context manager releases the hooks that GradCAM registers on the model
# once the heatmap has been computed.
with GradCAM(model=model, target_layers=[target_layer]) as cam:
    grayscale_cam = cam(input_tensor=input_tensor, targets=targets)
grayscale_cam = grayscale_cam[0, :]  # take the first (only) item of the batch

# --- Overlay the CAM on the RGB image ---
visualization = show_cam_on_image(rgb_img, grayscale_cam, use_rgb=True)

# --- Visualize ---
plt.figure(figsize=(6, 6))
plt.imshow(visualization)
plt.title(f"Grad-CAM Visualization (VGG16, class={target_class})")
plt.axis("off")
plt.show()