# predict_with_metrics.py
import os
os.environ['CUDA_VISIBLE_DEVICES'] = "0"
import argparse
import json
import torch
from tqdm import tqdm
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.nn.functional as F
from utils import write_img, chw_to_hwc
from datasets.Rain_Dataloader import TestData_for_Rain100L_pad, TestData_for_Rain100H_pad, TestData_for_Rain100L, TestData_for_Rain100H
from datasets.Internet_Dataloader import TestData_for_Internet
from datasets.DDN_Dataloader import TestData_for_DDN, TestData_for_DDN_pad
from datasets.DID_Dataloader import TestData_for_DID, TestData_for_DID_pad
from datasets.SPA_Dataloader import TestData_for_SPA_pad, TestData_for_SPA
from pytorch_msssim import ssim
from utils import *
from utils.utils import *
from skimage.metrics import structural_similarity as compare_ssim
from numpy import *
from models import *
parser = argparse.ArgumentParser()
parser.add_argument('--model', default='FADformer', type=str, help='model name')
parser.add_argument('--num_workers', default=8, type=int, help='number of workers')
parser.add_argument('--data_dir', default='./data/', type=str, help='path to dataset')
parser.add_argument('--save_dir', default='./saved_models/', type=str, help='path to models saving')
parser.add_argument('--result_dir', default='./results/', type=str, help='path to results saving')
parser.add_argument('--exp', default='rain200', type=str, help='experiment setting')
args = parser.parse_args()


def test(val_loader_full, network, result_dir):
    PSNR_full = AverageMeter()
    SSIM_full = AverageMeter()

    torch.cuda.empty_cache()
    network.eval()

    os.makedirs(result_dir, exist_ok=True)

    for batch in val_loader_full:
        source_img = batch['source'].cuda()
        target_img = batch['target'].cuda()
        file_name = batch['filename'][0]
        h, w = source_img.shape[2], source_img.shape[3]

        # Pad the input if its height/width is not a multiple of 4
        img_multiple_of = 4
        height, width = source_img.shape[2], source_img.shape[3]
        H, W = ((height + img_multiple_of) // img_multiple_of) * img_multiple_of, (
            (width + img_multiple_of) // img_multiple_of) * img_multiple_of
        padh = H - height if height % img_multiple_of != 0 else 0
        padw = W - width if width % img_multiple_of != 0 else 0
        source_img = F.pad(source_img, (0, padw, 0, padh), mode='reflect')

        with torch.no_grad():
            output = network(source_img).clamp_(0, 1)

        # Unpad the output back to the original resolution
        output = output[:, :, :height, :width]

        psnr_full, sim = calculate_psnr_torch(target_img, output)
        PSNR_full.update(psnr_full.item(), source_img.size(0))
        ssim_full = sim
        SSIM_full.update(ssim_full.item(), source_img.size(0))

        # If you don't need to save the output images, comment out the two lines below
        out_img = chw_to_hwc(output.detach().cpu().squeeze(0).numpy())
        write_img(os.path.join(result_dir, file_name.split('.')[0] + '.png'), out_img)
        # os.rename(os.path.join(result_dir, file_name), os.path.join(result_dir, file_name).split('.')[0] + '_' + str(float(psnr_full)) + '_' + str(float(ssim_full)) + '.png')

    return PSNR_full.avg, SSIM_full.avg
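
# Worked example of the padding arithmetic above (illustration only, not part of the
# original logic): for a 321x481 input with img_multiple_of = 4,
#   H = ((321 + 4) // 4) * 4 = 324  ->  padh = 324 - 321 = 3
#   W = ((481 + 4) // 4) * 4 = 484  ->  padw = 484 - 481 = 3
# A side that is already a multiple of 4 (e.g. 320) gets zero padding, and the
# prediction is cropped back to the original size by output[:, :, :height, :width].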


def test_merge(val_loader_full, network, result_dir):
    PSNR_full = AverageMeter()
    SSIM_full = AverageMeter()

    torch.cuda.empty_cache()
    network.eval()

    os.makedirs(result_dir, exist_ok=True)

    for batch in val_loader_full:
        source_img = batch['source'].cuda()
        target_img = batch['target'].cuda()
        file_name = batch['filename'][0]
        B, C, H, W = source_img.shape
        # print(H, W)

        # Crop four overlapping patches whose sides are multiples of 4,
        # anchored at the four corners of the full image
        crop_H, crop_W = H - H % 4, W - W % 4
        source1 = source_img[:, :, 0:crop_H, 0:crop_W]
        source2 = source_img[:, :, H - crop_H:H, 0:crop_W]
        source3 = source_img[:, :, H - crop_H:H, W - crop_W:W]
        source4 = source_img[:, :, 0:crop_H, W - crop_W:W]

        # Per-pixel coverage count, used to average the overlapping regions
        map1 = torch.zeros([B, C, H, W]).cuda()
        map2 = torch.zeros([B, C, H, W]).cuda()
        map3 = torch.zeros([B, C, H, W]).cuda()
        map4 = torch.zeros([B, C, H, W]).cuda()
        map1[:, :, 0:crop_H, 0:crop_W] = 1.
        map2[:, :, H - crop_H:H, 0:crop_W] = 1.
        map3[:, :, H - crop_H:H, W - crop_W:W] = 1.
        map4[:, :, 0:crop_H, W - crop_W:W] = 1.
        map = map1 + map2 + map3 + map4

        with torch.no_grad():
            output1 = network(source1).clamp_(0, 1)
            output2 = network(source2).clamp_(0, 1)
            output3 = network(source3).clamp_(0, 1)
            output4 = network(source4).clamp_(0, 1)

        # Merge the four patch predictions and average where they overlap
        output = torch.zeros([B, C, H, W]).cuda()
        output[:, :, 0:crop_H, 0:crop_W] += output1
        output[:, :, H - crop_H:H, 0:crop_W] += output2
        output[:, :, H - crop_H:H, W - crop_W:W] += output3
        output[:, :, 0:crop_H, W - crop_W:W] += output4
        output = (output / map).clamp_(0, 1)

        psnr_full, sim = calculate_psnr_torch(target_img, output)
        PSNR_full.update(psnr_full.item(), source_img.size(0))
        ssim_full = sim
        SSIM_full.update(ssim_full.item(), source_img.size(0))

        # If you don't need to save the output images, comment out the two lines below
        out_img = chw_to_hwc(output.detach().cpu().squeeze(0).numpy())
        write_img(os.path.join(result_dir, file_name.split('.')[0] + '.png'), out_img)
        # os.rename(os.path.join(result_dir, file_name), os.path.join(result_dir, file_name).split('.')[0] + '_' + str(float(psnr_full)) + '_' + str(float(ssim_full)) + '.jpg')

    return PSNR_full.avg, SSIM_full.avg
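
# Illustration of the overlap averaging above (not part of the original logic):
# with H = W = 6 and crop_H = crop_W = 4, the four corner crops cover rows/cols
# [0:4] and [2:6], so a pixel in the central band (indices 2-3) is predicted by up
# to four crops. `map` holds the per-pixel coverage count (1, 2 or 4 here), and
# `output / map` averages the overlapping predictions instead of double-counting them.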


if __name__ == '__main__':
    device_index = [0]

    network = eval(args.model)()
    network = nn.DataParallel(network, device_ids=device_index).cuda()
    network.load_state_dict(torch.load('./pretrain_weights/ddn/FADformer_Rain200H.pth')['state_dict'])

    # Rain200H, Rain200L and SPA-Data can use this script to save images and evaluate PSNR and SSIM.
    # For DID and DDN, PSNR and SSIM should be computed with the MATLAB code, so use this script only
    # to output the images and then evaluate them in MATLAB.
    test_dir = '/home/jxy/projects_dir/datasets/Rain100/rain_heavy_test'
    # './datasets/Rain100/rain_data_test_Light'
    # './datasets/Rain100/rain_heavy_test'
    result_dir = './saved_images/Rain200H'

    # option 1: test with cropped images for faster prediction speed
    '''
    test_dataset = TestData_for_Rain100H(4, test_dir)
    test_loader = DataLoader(test_dataset,
                             batch_size=1,
                             shuffle=True,
                             num_workers=args.num_workers,
                             pin_memory=True)
    psnr, ssim = test(test_loader, network, result_dir)
    print(psnr, ssim)
    '''

    # option 2: test with full images, merged from overlapping crops, to reproduce the performance table in our paper
    test_dataset = TestData_for_Rain100H_pad(test_dir)
    test_loader = DataLoader(test_dataset,
                             batch_size=1,
                             shuffle=True,
                             num_workers=args.num_workers,
                             pin_memory=True)
    psnr, ssim = test_merge(test_loader, network, result_dir)
    print(psnr, ssim)
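
# Example invocation (a sketch; the checkpoint and dataset paths hard-coded above are
# assumed to exist on the local machine):
#   python predict_with_metrics.py --model FADformer --num_workers 8
# Note that --data_dir, --save_dir, --result_dir and --exp are parsed but not used here;
# the test directory, result directory and checkpoint path are set in the __main__ block.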