-
Notifications
You must be signed in to change notification settings - Fork 0
/
eval.py
124 lines (102 loc) · 4.14 KB
/
eval.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import numpy as np
import lpips
import json
import cv2
import torch
import scipy
import scipy.signal
from glob import glob
from os.path import join as pjoin
from skimage import io as imageio
from skimage.metrics import peak_signal_noise_ratio
import click
def glob_images(image_dir):
    """Return the sorted, de-duplicated image paths directly inside ``image_dir``.

    Only ``.jpg``/``.JPG``/``.png``/``.PNG`` files are matched (non-recursive).

    Args:
        image_dir: directory to scan.

    Returns:
        A sorted list of unique file paths (empty if nothing matches).
    """
    found = set()
    for pattern in ('*.jpg', '*.JPG', '*.png', '*.PNG'):
        # When the glob matcher is case-insensitive (e.g. on Windows), the
        # lower- and upper-case patterns hit the same files; collecting into a
        # set keeps a single copy so pred/target lists pair up correctly.
        found.update(glob(pjoin(image_dir, pattern)))
    return sorted(found)
def to_torch_image(img, device='cuda'):
    """Convert an H x W x C image array into a 1 x C x H x W float32 tensor.

    Values are mapped linearly from [0, 255] to [-1, 1] (assumes 8-bit-range
    input — TODO confirm for 16-bit sources).

    Args:
        img: H x W x C numpy array.
        device: target torch device (str or ``torch.device``); defaults to
            ``'cuda'`` to match the original behavior.

    Returns:
        A float32 tensor of shape (1, C, H, W) on ``device``.
    """
    scaled = (torch.from_numpy(img).to(torch.float32) / 255.0) * 2. - 1.
    # Channels-last -> channels-first, then add a leading batch dimension.
    return scaled.to(device).permute(2, 0, 1)[None]
def rgb_ssim(img0, img1, max_val,
             filter_size=11,
             filter_sigma=1.5,
             k1=0.01,
             k2=0.03,
             return_map=False):
    """Structural similarity between two 3-channel images.

    Local statistics are gathered with a separable, normalized Gaussian window
    applied per channel using 'valid' convolution, so the SSIM map has shape
    (H - filter_size + 1, W - filter_size + 1, 3) for an odd ``filter_size``.

    Args:
        img0, img1: arrays of identical shape (H, W, 3).
        max_val: dynamic range of the pixel values (e.g. 1 for [0, 1] images).
        filter_size: Gaussian window width in pixels.
        filter_sigma: Gaussian standard deviation in pixels.
        k1, k2: SSIM stabilization constants.
        return_map: if True return the per-pixel SSIM map instead of its mean.

    Returns:
        The mean SSIM, or the SSIM map when ``return_map`` is True.
    """
    # Modified from https://github.com/google/mipnerf/blob/16e73dfdb52044dcceb47cda5243a686391a6e0f/internal/math.py#L58
    assert len(img0.shape) == 3
    assert img0.shape[-1] == 3
    assert img0.shape == img1.shape

    # 1D Gaussian taps, normalized to sum to one.
    half = filter_size // 2
    offset = (2 * half - filter_size + 1) / 2
    taps = np.exp(-0.5 * (((np.arange(filter_size) - half + offset) / filter_sigma)**2))
    taps /= np.sum(taps)
    column = taps[:, None]
    row = taps[None, :]

    def blur(stack):
        # Two 1D passes per channel are cheaper than one 2D convolution.
        channels = []
        for c in range(stack.shape[-1]):
            vertical = scipy.signal.convolve2d(stack[..., c], column, mode='valid')
            channels.append(scipy.signal.convolve2d(vertical, row, mode='valid'))
        return np.stack(channels, -1)

    mean0 = blur(img0)
    mean1 = blur(img1)
    mean0_sq = mean0 * mean0
    mean1_sq = mean1 * mean1
    mean01 = mean0 * mean1

    # Variances are clamped to be non-negative; the covariance is clamped so
    # its magnitude never exceeds sqrt(var0 * var1).
    var0 = np.maximum(0., blur(img0**2) - mean0_sq)
    var1 = np.maximum(0., blur(img1**2) - mean1_sq)
    cov = blur(img0 * img1) - mean01
    cov = np.sign(cov) * np.minimum(np.sqrt(var0 * var1), np.abs(cov))

    c1 = (k1 * max_val)**2
    c2 = (k2 * max_val)**2
    luminance_contrast = (2 * mean01 + c1) * (2 * cov + c2)
    normalizer = (mean0_sq + mean1_sq + c1) * (var0 + var1 + c2)
    ssim_map = luminance_contrast / normalizer

    if return_map:
        return ssim_map
    return np.mean(ssim_map)
def _save_debug_image(index, pd_image, tg_image):
    """Write debug_{index}.png: prediction, target and |target - prediction| stacked vertically."""
    # float16 represents every 8-bit value exactly, so the difference is exact.
    diff = np.abs(tg_image.astype(np.float16) - pd_image.astype(np.float16))
    stacked = np.concatenate([pd_image, tg_image, diff.astype(np.uint8)], axis=0)
    # The pipeline treats images as RGB; OpenCV writes BGR.
    cv2.imwrite(f"debug_{index}.png", cv2.cvtColor(stacked, cv2.COLOR_RGB2BGR))


def run_eval(pred_dir, target_dir, info_json, vis=False):
    """Score predicted images against targets with PSNR, SSIM and LPIPS (VGG).

    Images are paired by sorted path order, so both directories must hold the
    same number of images whose sorted orders correspond. Per-image scores
    (keyed by the pair index as a string) and their means (key ``'mean'``) are
    written as JSON to ``info_json``.

    Args:
        pred_dir: directory of predicted images.
        target_dir: directory of ground-truth images.
        info_json: output path for the JSON report.
        vis: if True, also write ``debug_{i}.png`` comparison images to the
            current working directory.

    Raises:
        ValueError: if no target images are found, the image counts differ,
            or a paired prediction/target differ in shape.
    """
    # glob_images already returns sorted paths.
    tg_image_paths = glob_images(target_dir)
    pd_image_paths = glob_images(pred_dir)
    # Validate before loading the VGG network; an empty set would otherwise
    # end in a ZeroDivisionError when computing the means.
    if not tg_image_paths:
        raise ValueError(f"no images found in target directory {target_dir!r}")
    if len(tg_image_paths) != len(pd_image_paths):
        raise ValueError(
            f"image count mismatch: {len(tg_image_paths)} in {target_dir!r} "
            f"vs {len(pd_image_paths)} in {pred_dir!r}")

    loss_fn_vgg = lpips.LPIPS(net='vgg').to(torch.device('cuda'))
    psnr_tot, ssim_tot, lpips_tot = 0., 0., 0.
    info_data = {'psnr': dict(), 'ssim': dict(), 'lpips': dict()}
    for i, (tg_path, pd_path) in enumerate(zip(tg_image_paths, pd_image_paths)):
        # Keep only the first three channels (drops alpha on RGBA files).
        tg_image = imageio.imread(tg_path)[:, :, :3]
        pd_image = imageio.imread(pd_path)[:, :, :3]
        if tg_image.shape != pd_image.shape:
            raise ValueError(
                f"shape mismatch: {tg_path!r} {tg_image.shape} "
                f"vs {pd_path!r} {pd_image.shape}")
        if vis:
            _save_debug_image(i, pd_image, tg_image)
        psnr = peak_signal_noise_ratio(tg_image, pd_image)
        ssim = rgb_ssim(tg_image / 255., pd_image / 255., max_val=1)
        # Inference only: no autograd bookkeeping needed.
        with torch.no_grad():
            lpip = loss_fn_vgg(to_torch_image(tg_image), to_torch_image(pd_image)).cpu().item()
        psnr_tot += psnr
        ssim_tot += ssim
        lpips_tot += lpip
        info_data['psnr'][str(i)] = psnr
        info_data['ssim'][str(i)] = ssim
        info_data['lpips'][str(i)] = lpip

    n_images = len(tg_image_paths)
    info_data['psnr']['mean'] = psnr_tot / n_images
    info_data['ssim']['mean'] = ssim_tot / n_images
    info_data['lpips']['mean'] = lpips_tot / n_images
    with open(info_json, 'w') as f:
        json.dump(info_data, f, indent=2)
@click.command()
@click.option('--scan', type=str, required=True,
              help='Scan name; images are read from data/eval/<scan>/pred and data/eval/<scan>/target.')
@click.option('--vis', type=bool, default=False,
              help='Also write debug_<i>.png comparison images.')
def eval_main(scan, vis=False):
    """CLI entry point: evaluate one scan and write data/eval/<scan>-ret.json.

    ``--scan`` is required: without it the paths would silently become
    ``data/eval/None/...``.
    """
    pred_dir = f"data/eval/{scan}/pred"
    target_dir = f"data/eval/{scan}/target"
    info_json = f"data/eval/{scan}-ret.json"
    run_eval(pred_dir, target_dir, info_json, vis)
if __name__ == "__main__":
    # eval_main is a click command; calling it with no arguments lets click
    # take the options from the command line.
    eval_main()