Skip to content

Commit

Permalink
[metrics] speed up SSIM tests (#1467)
Browse files Browse the repository at this point in the history
* Update setup.cfg

* [metrics] update ssim

* use np.allclose instead of torch.allclose

* Apply suggestions from code review

* extract into _test_ssim
  • Loading branch information
Jeff Yang authored Nov 19, 2020
1 parent 6b2f235 commit 3bde732
Showing 1 changed file with 71 additions and 26 deletions.
97 changes: 71 additions & 26 deletions tests/ignite/metrics/test_ssim.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os

import numpy as np
import pytest
import torch

Expand All @@ -20,14 +21,14 @@ def test_zero_div():


def test_invalid_ssim():
y_pred = torch.rand(16, 1, 32, 32)
y_pred = torch.rand(1, 1, 4, 4)
y = y_pred + 0.125
with pytest.raises(ValueError, match=r"Expected kernel_size to have odd positive number. Got 10."):
ssim = SSIM(data_range=1.0, kernel_size=10)
with pytest.raises(ValueError, match=r"Expected kernel_size to have odd positive number."):
ssim = SSIM(data_range=1.0, kernel_size=2)
ssim.update((y_pred, y))
ssim.compute()

with pytest.raises(ValueError, match=r"Expected kernel_size to have odd positive number. Got -1."):
with pytest.raises(ValueError, match=r"Expected kernel_size to have odd positive number."):
ssim = SSIM(data_range=1.0, kernel_size=-1)
ssim.update((y_pred, y))
ssim.compute()
Expand All @@ -42,38 +43,73 @@ def test_invalid_ssim():
ssim.update((y_pred, y))
ssim.compute()

with pytest.raises(ValueError, match=r"Expected sigma to have positive number."):
ssim = SSIM(data_range=1.0, sigma=(-1, -1))
ssim.update((y_pred, y))
ssim.compute()

with pytest.raises(ValueError, match=r"Argument sigma should be either float or a sequence of float."):
ssim = SSIM(data_range=1.0, sigma=1)
ssim.update((y_pred, y))
ssim.compute()

with pytest.raises(ValueError, match=r"Expected y_pred and y to have the same shape."):
y = y.squeeze(dim=0)
ssim = SSIM(data_range=1.0)
ssim.update((y_pred, y))
ssim.compute()

def test_ssim():
device = "cuda" if torch.cuda.is_available() else "cpu"
ssim = SSIM(data_range=1.0, device=device)
y_pred = torch.rand(16, 3, 64, 64, device=device)
y = y_pred * 0.65
ssim.update((y_pred, y))
with pytest.raises(ValueError, match=r"Expected y_pred and y to have BxCxHxW shape."):
y = y.squeeze(dim=0)
ssim = SSIM(data_range=1.0)
ssim.update((y, y))
ssim.compute()

np_pred = y_pred.permute(0, 2, 3, 1).cpu().numpy()
np_y = np_pred * 0.65
np_ssim = ski_ssim(np_pred, np_y, win_size=11, multichannel=True, gaussian_weights=True, data_range=1.0)
with pytest.raises(TypeError, match=r"Expected y_pred and y to have the same data type."):
y = y.double()
ssim = SSIM(data_range=1.0)
ssim.update((y_pred, y))
ssim.compute()

assert isinstance(ssim.compute(), torch.Tensor)
assert torch.allclose(ssim.compute(), torch.tensor(np_ssim, dtype=torch.float64, device=device), atol=1e-4)

device = "cuda" if torch.cuda.is_available() else "cpu"
ssim = SSIM(data_range=1.0, gaussian=False, kernel_size=7, device=device)
y_pred = torch.rand(16, 3, 227, 227, device=device)
y = y_pred * 0.65
def _test_ssim(y_pred, y, data_range, kernel_size, sigma, gaussian, use_sample_covariance, device):
    """Compare ignite's SSIM metric against scikit-image's implementation.

    Both implementations are configured with the *same* window parameters
    (``kernel_size``, ``sigma``, ``gaussian``) so the comparison is meaningful.

    Args:
        y_pred, y: BxCxHxW float tensors on ``device``.
        data_range: value range of the inputs, forwarded to both libraries.
        kernel_size: window size (odd int), forwarded to both libraries.
        sigma: gaussian sigma, forwarded to both libraries.
        gaussian: whether to use gaussian weighting of the window.
        use_sample_covariance: scikit-image-only normalization flag.
        device: "cpu" or "cuda" device string the metric should run on.
    """
    atol = 7e-5
    # Forward kernel_size and gaussian as well — previously only sigma was
    # passed, so ignite silently used its defaults (gaussian 11x11) while
    # scikit-image used the requested window, comparing different quantities.
    ssim = SSIM(data_range=data_range, kernel_size=kernel_size, sigma=sigma, gaussian=gaussian, device=device)
    ssim.update((y_pred, y))
    ignite_ssim = ssim.compute()

    # scikit-image expects channels-last numpy arrays; use the y that was
    # actually passed in instead of re-deriving it with a hard-coded factor.
    skimg_pred = y_pred.permute(0, 2, 3, 1).cpu().numpy()
    skimg_y = y.permute(0, 2, 3, 1).cpu().numpy()
    skimg_ssim = ski_ssim(
        skimg_pred,
        skimg_y,
        win_size=kernel_size,
        sigma=sigma,
        multichannel=True,
        gaussian_weights=gaussian,
        data_range=data_range,
        use_sample_covariance=use_sample_covariance,
    )

    assert isinstance(ignite_ssim, torch.Tensor)
    assert ignite_ssim.dtype == torch.float64
    assert ignite_ssim.device == torch.device(device)
    # Move to CPU before converting: Tensor.numpy() raises on CUDA tensors.
    assert np.allclose(ignite_ssim.cpu().numpy(), skimg_ssim, atol=atol)

np_pred = y_pred.permute(0, 2, 3, 1).cpu().numpy()
np_y = np_pred * 0.65
np_ssim = ski_ssim(np_pred, np_y, win_size=7, multichannel=True, gaussian_weights=False, data_range=1.0)

assert isinstance(ssim.compute(), torch.Tensor)
assert torch.allclose(ssim.compute(), torch.tensor(np_ssim, dtype=torch.float64, device=device), atol=1e-4)
def test_ssim():
    """Check SSIM against scikit-image for two window configurations."""
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # (input shape, kernel_size, gaussian, use_sample_covariance)
    configs = [
        ((8, 3, 224, 224), 7, False, True),
        ((12, 3, 28, 28), 11, True, False),
    ]
    for shape, ksize, gaussian, sample_cov in configs:
        y_pred = torch.rand(shape, device=device)
        y = y_pred * 0.8
        _test_ssim(
            y_pred,
            y,
            data_range=1.0,
            kernel_size=ksize,
            sigma=1.5,
            gaussian=gaussian,
            use_sample_covariance=sample_cov,
            device=device,
        )


def _test_distrib_integration(device, tol=1e-4):
Expand Down Expand Up @@ -105,7 +141,16 @@ def update(engine, i):

np_pred = y_pred.permute(0, 2, 3, 1).cpu().numpy()
np_true = np_pred * 0.65
true_res = ski_ssim(np_pred, np_true, win_size=11, multichannel=True, gaussian_weights=True, data_range=1.0)
true_res = ski_ssim(
np_pred,
np_true,
win_size=11,
sigma=1.5,
multichannel=True,
gaussian_weights=True,
data_range=1.0,
use_sample_covariance=False,
)

assert pytest.approx(res, abs=tol) == true_res

Expand Down Expand Up @@ -142,7 +187,7 @@ def _test_distrib_accumulator_device(device):
type(ssim._kernel.device), ssim._kernel.device, type(metric_device), metric_device
)

y_pred = torch.rand(4, 3, 28, 28, dtype=torch.float, device=device)
y_pred = torch.rand(2, 3, 28, 28, dtype=torch.float, device=device)
y = y_pred * 0.65
ssim.update((y_pred, y))

Expand Down

0 comments on commit 3bde732

Please sign in to comment.