-
Notifications
You must be signed in to change notification settings - Fork 13
/
heatmap.py
94 lines (83 loc) · 3.06 KB
/
heatmap.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import os
import matplotlib
matplotlib.use('agg')
import matplotlib.pyplot as plt
import numpy as np
from model import three_view_net
from utils import load_network
import yaml
import argparse
import torch
import torch.nn as nn
import torchvision
from torchvision import datasets, models, transforms
from PIL import Image
os.environ["CUDA_VISIBLE_DEVICES"] = '2'
parser = argparse.ArgumentParser(description='Training')
parser.add_argument('--data_dir',default='/home/wangtyu/datasets/University-Release/test',type=str, help='./test_data')
parser.add_argument('--name', default='three_view_long_share_d0.75_256_s1_google', type=str, help='save model path')
parser.add_argument('--batchsize', default=1, type=int, help='batchsize')
opt = parser.parse_args()
config_path = os.path.join('./model',opt.name,'opts.yaml')
with open(config_path, 'r') as stream:
config = yaml.load(stream)
opt.fp16 = config['fp16']
opt.PCB = config['PCB']
opt.stride = config['stride']
opt.block = config['block']
opt.views = config['views']
if 'h' in config:
opt.h = config['h']
opt.w = config['w']
if 'nclasses' in config: # tp compatible with old config files
opt.nclasses = config['nclasses']
else:
opt.nclasses = 751
def heatmap2d(img, arr):
# fig = plt.figure()
# ax0 = fig.add_subplot(121, title="Image")
# ax1 = fig.add_subplot(122, title="Heatmap")
# fig, ax = plt.subplots()
# ax[0].imshow(Image.open(img))
plt.figure()
heatmap = plt.imshow(arr, cmap='viridis')
plt.axis('off')
# fig.colorbar(heatmap, fraction=0.046, pad=0.04)
#plt.show()
plt.savefig('heatmap_dbase')
data_transforms = transforms.Compose([
transforms.Resize((opt.h, opt.w), interpolation=3),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
# image_datasets = {x: datasets.ImageFolder( os.path.join(opt.data_dir,x) ,data_transforms) for x in ['satellite']}
# dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=opt.batchsize,
# shuffle=False, num_workers=1) for x in ['satellite']}
# imgpath = image_datasets['satellite'].imgs
# print(imgpath)
imgname = 'gallery_drone/0721/image-28.jpeg'
# imgname = 'query_satellite/0721/0721.jpg'
imgpath = os.path.join(opt.data_dir,imgname)
img = Image.open(imgpath)
img = data_transforms(img)
img = torch.unsqueeze(img,0)
print(img.shape)
model, _, epoch = load_network(opt.name, opt)
model = model.eval().cuda()
# data = next(iter(dataloaders['satellite']))
# img, label = data
with torch.no_grad():
x = model.model_3.model.conv1(img.cuda())
x = model.model_3.model.bn1(x)
x = model.model_3.model.relu(x)
x = model.model_3.model.maxpool(x)
x = model.model_3.model.layer1(x)
x = model.model_3.model.layer2(x)
x = model.model_3.model.layer3(x)
output = model.model_3.model.layer4(x)
print(output.shape)
heatmap = output.squeeze().sum(dim=0).cpu().numpy()
print(heatmap.shape)
#test_array = np.arange(100 * 100).reshape(100, 100)
# Result is saved tas `heatmap.png`
heatmap2d(imgpath,heatmap)