-
Notifications
You must be signed in to change notification settings - Fork 0
/
OSMNet_main.py
165 lines (122 loc) · 5.12 KB
/
OSMNet_main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
#!/usr/bin/python2 -utt
# -*- coding: utf-8 -*-
"""
If you use this code, please cite
@article{
author = {Han Zhang, Lin Lei, Weiping Ni, Tao Tang, Junzheng Wu, Deliang Xiang, Gangyao Kuang},
title = "{Explore Better Network Framework for High Resolution Optical and SAR Image Matching}",
year = 2021}
(c) 2021 by Han Zhang
"""
from __future__ import division, print_function
import argparse
import torch
import torch.nn.init
import torchvision.transforms as transforms
import torch.backends.cudnn as cudnn
import os
import numpy as np
import cv2
import matplotlib.pyplot as plt
# Training settings
def _str2bool(value):
    """Convert a command-line token ('true', '0', 'no', ...) to a bool.

    ``type=bool`` must not be used with argparse: ``bool('False')`` is True
    because every non-empty string is truthy.

    Raises:
        argparse.ArgumentTypeError: if the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    if token in ('0', 'false', 'f', 'no', 'n', 'off', ''):
        return False
    raise argparse.ArgumentTypeError(
        'expected a boolean value, got {!r}'.format(value))


parser = argparse.ArgumentParser(description='PyTorch HardNet_Dense')
# Model options
# parser.add_argument('--use-attention', type=bool, default=False,
#                     help='use the MultFca Attention layer or not')
parser.add_argument('--resume', default='/models/checkpoint_osmnet_wo.pth', type=str, metavar='PATH',
                    help='path to trained model without the attention module(default: none)')
parser.add_argument('--resume-att', default='/models/checkpoint_osmnet.pth', type=str, metavar='PATH',
                    help='path to trained model with MultFca attention module (default: none)')
parser.add_argument('--extent-pos', default=1, type=int,
                    help='Extent of positive samples on the ground truth map')
parser.add_argument('--search-rad', default=32, type=int,
                    help='Search radius for fft match')
parser.add_argument('--num-workers', default=0, type=int,
                    help='Number of workers to be created')
# Parsed with _str2bool so that "--pin-memory False" really yields False.
parser.add_argument('--pin-memory', type=_str2bool, default=True,
                    help='')
parser.add_argument('--mean-image', type=float, default=0.4309,
                    help='mean of train dataset for normalization')
parser.add_argument('--std-image', type=float, default=0.2236,
                    help='std of train dataset for normalization')
# Device options
parser.add_argument('--use-cuda', action='store_true', default=True,
                    help='enables CUDA training')
# --use-cuda defaults to True and can only set True, so an explicit
# opt-out flag writing to the same destination is needed to force CPU.
parser.add_argument('--no-cuda', dest='use_cuda', action='store_false',
                    help='disables CUDA even when it is available')
parser.add_argument('--gpu-id', default='0', type=str,
                    help='id(s) for CUDA_VISIBLE_DEVICES')
args = parser.parse_args()

# set the device to use by setting CUDA_VISIBLE_DEVICES env variable in
# order to prevent any memory allocation on unused GPUs
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_id

# args.cuda is the effective switch: requested by the user AND available.
args.cuda = args.use_cuda and torch.cuda.is_available()
print(("NOT " if not args.cuda else "") + "Using cuda")
if args.cuda:
    cudnn.benchmark = True
def _circular_xcorr(signal, kernel):
    """Channel-summed circular cross-correlation of two 4-D tensors via FFT."""
    spectrum = torch.fft.fft2(signal) * torch.conj(torch.fft.fft2(kernel))
    return torch.real(torch.fft.ifft2(torch.sum(spectrum, 1)))


def fft_match_batch(feature_sar, feature_opt, search_rad):
    """Compute a batched SSD-style matching map between two feature maps.

    The central crop of ``feature_opt`` (``search_rad`` pixels removed on
    every side) serves as the template.  It is compared against
    ``feature_sar`` at every integer shift in ``[0, 2 * search_rad]`` along
    both spatial axes.  For each shift the value is

        sum(sar_window ** 2) - 2 * sum(sar_window * template)

    summed over channels and template pixels, then divided by the full
    spatial size.  The ``sum(template ** 2)`` term of a complete SSD is not
    included; it does not depend on the shift.

    Args:
        feature_sar: 4-D tensor indexed (batch, channel, dim2, dim3).
        feature_opt: 4-D tensor of the same shape as ``feature_sar``.
        search_rad: search radius in pixels; converted with ``int()``.

    Returns:
        Tensor of shape ``(batch, 2 * search_rad + 1, 2 * search_rad + 1)``
        on the same device as the inputs.

    Raises:
        ValueError: if ``search_rad`` is negative or leaves no template
            pixels (``2 * search_rad`` >= a spatial dimension).
    """
    nt = int(search_rad)
    _, _, w, h = feature_sar.shape
    tpl_w, tpl_h = w - 2 * nt, h - 2 * nt
    if nt < 0 or tpl_w <= 0 or tpl_h <= 0:
        raise ValueError(
            'search_rad={} is invalid for feature maps of spatial size '
            '{}x{}'.format(nt, w, h))
    span = 2 * nt + 1

    # Helper tensors follow the inputs' device/dtype instead of a global CUDA
    # flag, so the function works for both CPU and GPU tensors.
    # The mask uses the same (dim2, dim3) extent as the template below; this
    # matters for non-square feature maps.
    mask = torch.zeros_like(feature_sar)
    mask[:, :, :tpl_w, :tpl_h] = 1
    # Sum of squared SAR features under the template window for each shift.
    energy = _circular_xcorr(feature_sar ** 2, mask)[:, :span, :span]

    # Zero-padded template anchored at the origin; shifts up to 2 * nt never
    # wrap around, so the circular correlation equals the linear one there.
    template = torch.zeros_like(feature_opt)
    template[:, :, :tpl_w, :tpl_h] = feature_opt[:, :, nt:w - nt, nt:h - nt]
    cross = _circular_xcorr(feature_sar, template)[:, :span, :span]

    return (energy - 2 * cross) / w / h
if __name__ == '__main__':
    # Demo: match one SAR/optical test pair and display the matching map.
    # from ResNetDenseP import ResNetDenseP
    # model = ResNetDenseP()
    from OSMNet import SSLCNetPseudo  # SSLCNetPseudo_Att

    model = SSLCNetPseudo()
    path_model = args.resume
    if os.path.isfile(path_model):
        print('=> loading checkpoint {}'.format(path_model))
        # Load once, onto the CPU, so a checkpoint saved on a GPU can also be
        # restored on a CPU-only host; the model is moved to CUDA below.
        checkpoint = torch.load(path_model, map_location='cpu')
        args.start_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])
    else:
        # Best effort: continue with the freshly constructed model.
        print('=> no checkpoint found at {}'.format(args.resume))

    # switch to evaluate mode
    model.eval()

    path_sar = '/DataS/zhanghan_data/OSdataset/256/test/sar50.png'
    path_opt = '/DataS/zhanghan_data/OSdataset/256/test/opt50.png'
    img_sar = cv2.imread(path_sar, 0)  # flag 0: read as single-channel gray
    img_opt = cv2.imread(path_opt, 0)
    # cv2.imread returns None instead of raising when a file cannot be read.
    for path_img, img in ((path_sar, img_sar), (path_opt, img_opt)):
        if img is None:
            raise IOError('could not read image {}'.format(path_img))

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((args.mean_image,), (args.std_image,))])
    # Add the batch axis in front of the transformed image, so the image size
    # is taken from the file rather than hard-coded.
    data_sar = transform(img_sar).unsqueeze(0)
    data_opt = transform(img_opt).unsqueeze(0)

    if args.cuda:
        data_sar, data_opt = data_sar.cuda(), data_opt.cuda()
        model.cuda()

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        out_sar, out_opt = model(data_sar, data_opt)
        out = fft_match_batch(out_sar, out_opt, args.search_rad)

    out_s = torch.squeeze(out).cpu().detach().numpy()

    plt.figure()
    plt.subplot(2, 2, 1)
    plt.imshow(img_opt)
    plt.subplot(2, 2, 2)
    plt.imshow(img_sar)
    plt.subplot(2, 2, 3)
    plt.imshow(out_s)
    plt.show()