From 8ec312c466c0f7456ac18746859d25d6232c12f1 Mon Sep 17 00:00:00 2001 From: Marc Bresson <50196352+MarcBresson@users.noreply.github.com> Date: Wed, 13 Sep 2023 09:46:39 +0200 Subject: [PATCH] Added compatibility with uint8 to SSIM metric (#3045) * feat: add compatibility with uint8 * style: format using the run_code_style script * refactor: delete warning and independantly convert y and y_pred to fp * feat: remove uint8 warning test --------- Co-authored-by: vfdev --- ignite/metrics/ssim.py | 7 +++++++ tests/ignite/metrics/test_ssim.py | 31 +++++++++++++++++++++++++++++++ 2 files changed, 38 insertions(+) diff --git a/ignite/metrics/ssim.py b/ignite/metrics/ssim.py index 25757a94f7c..6824c0b3f37 100644 --- a/ignite/metrics/ssim.py +++ b/ignite/metrics/ssim.py @@ -99,6 +99,7 @@ def __init__( super(SSIM, self).__init__(output_transform=output_transform, device=device) self.gaussian = gaussian + self.data_range = data_range self.c1 = (k1 * data_range) ** 2 self.c2 = (k2 * data_range) ** 2 self.pad_h = (self.kernel_size[0] - 1) // 2 @@ -157,6 +158,12 @@ def update(self, output: Sequence[torch.Tensor]) -> None: f"Expected y_pred and y to have BxCxHxW shape. Got y_pred: {y_pred.shape} and y: {y.shape}." ) + # converts potential integer tensor to fp + if not y.is_floating_point(): + y = y.float() + if not y_pred.is_floating_point(): + y_pred = y_pred.float() + nb_channel = y_pred.size(1) if self._kernel is None or self._kernel.shape[0] != nb_channel: self._kernel = self._kernel_2d.expand(nb_channel, 1, -1, -1) diff --git a/tests/ignite/metrics/test_ssim.py b/tests/ignite/metrics/test_ssim.py index cc43addb0d2..e81d9abf962 100644 --- a/tests/ignite/metrics/test_ssim.py +++ b/tests/ignite/metrics/test_ssim.py @@ -222,6 +222,37 @@ def test_cuda_ssim_dtypes(available_device, dtype, precision): compare_ssim_ignite_skiimg(y_pred, y, available_device, precision) +@pytest.mark.parametrize( + "shape, kernel_size, gaussian, use_sample_covariance", + [[(8, 3, 224, 224), 7, False, True], [(12, 3, 28, 28), 11, True, False]], +) +def test_ssim_uint8(available_device, shape, kernel_size, gaussian, use_sample_covariance): + y_pred = torch.randint(0, 255, shape, device=available_device, dtype=torch.uint8) + y = (y_pred * 0.8).to(dtype=torch.uint8) + + sigma = 1.5 + data_range = 255 + ssim = SSIM(data_range=data_range, sigma=sigma, device=available_device) + ssim.update((y_pred, y)) + ignite_ssim = ssim.compute() + + skimg_pred = y_pred.cpu().numpy() + skimg_y = (skimg_pred * 0.8).astype(np.uint8) + skimg_ssim = ski_ssim( + skimg_pred, + skimg_y, + win_size=kernel_size, + sigma=sigma, + channel_axis=1, + gaussian_weights=gaussian, + data_range=data_range, + use_sample_covariance=use_sample_covariance, + ) + + assert isinstance(ignite_ssim, float) + assert np.allclose(ignite_ssim, skimg_ssim, atol=1e-5) + + @pytest.mark.parametrize("metric_device", ["cpu", "process_device"]) def test_distrib_integration(distributed, metric_device): from ignite.engine import Engine