-
Notifications
You must be signed in to change notification settings - Fork 0
/
heat.py
104 lines (88 loc) · 3.18 KB
/
heat.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import os
from PIL import Image
import torch
import numpy as np
import cv2
import matplotlib.pyplot as plt
import argparse
import torch
import torch.nn as nn
import torch.optim as optim
from model import CNN, VGG
import configparser
from dataset import Dataset
import os
from tensorboardX import SummaryWriter
import pickle
from train import test
def draw_CAM(model, img_path, save_path, transform=None, visual_heatmap=False):
'''
绘制 Class Activation Map
:param model: 加载好权重的Pytorch model
:param img_path: 测试图片路径
:param save_path: CAM结果保存路径
:param transform: 输入图像预处理方法
:param visual_heatmap: 是否可视化原始heatmap(调用matplotlib)
:return:
'''
# 图像加载&预处理
img = Image.open(img_path).convert('L')
img = img.resize((224,224))
img = np.array(img)
img = torch.from_numpy(img)
if transform:
img = transform(img)
img = img.unsqueeze(0)
# 获取模型输出的feature/score
model.eval()
features = model.features(img)
output = model.classifier(features)
# 为了能读取到中间梯度定义的辅助函数
def extract(g):
global features_grad
features_grad = g
# 预测得分最高的那一类对应的输出score
pred = torch.argmax(output).item()
pred_class = output[:, pred]
features.register_hook(extract)
pred_class.backward() # 计算梯度
grads = features_grad # 获取梯度
pooled_grads = torch.nn.functional.adaptive_avg_pool2d(grads, (1, 1))
# 此处batch size默认为1,所以去掉了第0维(batch size维)
pooled_grads = pooled_grads[0]
features = features[0]
# 512是最后一层feature的通道数
for i in range(512):
features[i, ...] *= pooled_grads[i, ...]
# 以下部分同Keras版实现
heatmap = features.detach().numpy()
heatmap = np.mean(heatmap, axis=0)
heatmap = np.maximum(heatmap, 0)
heatmap /= np.max(heatmap)
# 可视化原始热力图
if visual_heatmap:
plt.matshow(heatmap)
plt.show()
img = cv2.imread(img_path) # 用cv2加载原始图像
heatmap = cv2.resize(heatmap, (img.shape[1], img.shape[0])) # 将热力图的大小调整为与原始图像相同
heatmap = np.uint8(255 * heatmap) # 将热力图转换为RGB格式
heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET) # 将热力图应用于原始图像
superimposed_img = heatmap * 0.4 + img # 这里的0.4是热力图强度因子
cv2.imwrite(save_path, superimposed_img) # 将图像保存到硬盘
if __name__ == '__main__':
# parser
# setting
torch.manual_seed(11)
device = torch.device(0)
# data loader
data_path = './data'
lr = 0.001
model = VGG().to(device)
ckpt = torch.load('./model/vgg.pt')
model.load_state_dict(ckpt['model_state_dict'])
draw_CAM(model=model, img_path='./data/COVID/Covid (1007).png', save_path='heat.png', transform=None, visual_heatmap=True)
#model.eval()
#optimizer = optim.Adam(model.parameters(), lr=lr)
#metric = nn.CrossEntropyLoss().to(device)
#acc = test(model, device, test_loader, metric)
#print(acc)