metrics: add SSIM (#2671)
* metrics: add SSIM

* Update CHANGELOG.md

fix codefactor issue

fix doctest

fix doctest

fix test

* added test for raise Error
Jeff Yang authored Jul 23, 2020
1 parent d0b8e85 commit bda7cf1
Showing 8 changed files with 255 additions and 3 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -9,6 +9,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added SSIM metrics ([#2671](https://github.com/PyTorchLightning/pytorch-lightning/pull/2671))
- Added BLEU metrics ([#2535](https://github.com/PyTorchLightning/pytorch-lightning/pull/2535))

- Added support for `Trainer(num_sanity_val_steps=-1)` to check all validation data before training ([#2246](https://github.com/PyTorchLightning/pytorch-lightning/pull/2246))
12 changes: 12 additions & 0 deletions docs/source/metrics.rst
@@ -234,6 +234,12 @@ RMSLE
.. autoclass:: pytorch_lightning.metrics.regression.RMSLE
    :noindex:

SSIM
^^^^

.. autoclass:: pytorch_lightning.metrics.regression.SSIM
    :noindex:

----------------

Functional Metrics
@@ -403,6 +409,12 @@ psnr (F)
.. autofunction:: pytorch_lightning.metrics.functional.psnr
    :noindex:

ssim (F)
^^^^^^^^

.. autofunction:: pytorch_lightning.metrics.functional.ssim
    :noindex:

stat_scores_multiple_classes (F)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

2 changes: 2 additions & 0 deletions pytorch_lightning/metrics/__init__.py
@@ -6,6 +6,7 @@
    PSNR,
    RMSE,
    RMSLE,
    SSIM
)
from pytorch_lightning.metrics.classification import (
    Accuracy,
@@ -54,6 +55,7 @@
"PSNR",
"RMSE",
"RMSLE",
"SSIM"
]
__sequence_metrics = ["BLEUScore"]
__all__ = __regression_metrics + __classification_metrics + ["SklearnMetric"] + __sequence_metrics
1 change: 1 addition & 0 deletions pytorch_lightning/metrics/functional/__init__.py
@@ -26,5 +26,6 @@
    psnr,
    rmse,
    rmsle,
    ssim
)
from pytorch_lightning.metrics.functional.nlp import bleu_score
115 changes: 115 additions & 0 deletions pytorch_lightning/metrics/functional/regression.py
@@ -1,3 +1,5 @@
from typing import Sequence

import torch
from torch.nn import functional as F

@@ -182,3 +184,116 @@ def psnr(
    psnr_base_e = 2 * torch.log(data_range) - torch.log(mse_score)
    psnr = psnr_base_e * (10 / torch.log(torch.tensor(base)))
    return psnr


def _gaussian_kernel(channel, kernel_size, sigma, device):
    def gaussian(kernel_size, sigma, device):
        gauss = torch.arange(
            start=(1 - kernel_size) / 2, end=(1 + kernel_size) / 2, step=1, dtype=torch.float32, device=device
        )
        gauss = torch.exp(-gauss.pow(2) / (2 * pow(sigma, 2)))
        return (gauss / gauss.sum()).unsqueeze(dim=0)  # (1, kernel_size)

    gaussian_kernel_x = gaussian(kernel_size[0], sigma[0], device)
    gaussian_kernel_y = gaussian(kernel_size[1], sigma[1], device)
    kernel = torch.matmul(gaussian_kernel_x.t(), gaussian_kernel_y)  # (kernel_size, 1) * (1, kernel_size)

    return kernel.expand(channel, 1, kernel_size[0], kernel_size[1])


def ssim(
    pred: torch.Tensor,
    target: torch.Tensor,
    kernel_size: Sequence[int] = (11, 11),
    sigma: Sequence[float] = (1.5, 1.5),
    reduction: str = "elementwise_mean",
    data_range: float = None,
    k1: float = 0.01,
    k2: float = 0.03
) -> torch.Tensor:
    """
    Computes Structural Similarity Index Measure

    Args:
        pred: Estimated image
        target: Ground truth image
        kernel_size: Size of the Gaussian kernel. Default: (11, 11)
        sigma: Standard deviation of the Gaussian kernel. Default: (1.5, 1.5)
        reduction: A method for reducing SSIM over all elements in the ``pred`` tensor. Default: ``elementwise_mean``

            Available reduction methods:
            - elementwise_mean: takes the mean
            - none: no reduction is applied
            - sum: sums the elements

        data_range: Range of the image. If ``None``, it is determined from the images (max - min)
        k1: Parameter of SSIM. Default: 0.01
        k2: Parameter of SSIM. Default: 0.03

    Returns:
        A Tensor with the SSIM score

    Example:

        >>> pred = torch.rand([16, 1, 16, 16])
        >>> target = pred * 1.25
        >>> ssim(pred, target)
        tensor(0.9520)
    """

    if pred.dtype != target.dtype:
        raise TypeError(
            "Expected `pred` and `target` to have the same data type."
            f" Got pred: {pred.dtype} and target: {target.dtype}."
        )

    if pred.shape != target.shape:
        raise ValueError(
            "Expected `pred` and `target` to have the same shape."
            f" Got pred: {pred.shape} and target: {target.shape}."
        )

    if len(pred.shape) != 4 or len(target.shape) != 4:
        raise ValueError(
            "Expected `pred` and `target` to have BxCxHxW shape."
            f" Got pred: {pred.shape} and target: {target.shape}."
        )

    if len(kernel_size) != 2 or len(sigma) != 2:
        raise ValueError(
            "Expected `kernel_size` and `sigma` to have length 2."
            f" Got kernel_size: {len(kernel_size)} and sigma: {len(sigma)}."
        )

    if any(x % 2 == 0 or x <= 0 for x in kernel_size):
        raise ValueError(f"Expected `kernel_size` to consist of odd positive numbers. Got {kernel_size}.")

    if any(y <= 0 for y in sigma):
        raise ValueError(f"Expected `sigma` to consist of positive numbers. Got {sigma}.")

    if data_range is None:
        data_range = max(pred.max() - pred.min(), target.max() - target.min())

    C1 = pow(k1 * data_range, 2)
    C2 = pow(k2 * data_range, 2)
    device = pred.device

    channel = pred.size(1)
    kernel = _gaussian_kernel(channel, kernel_size, sigma, device)
    mu_pred = F.conv2d(pred, kernel, groups=channel)
    mu_target = F.conv2d(target, kernel, groups=channel)

    mu_pred_sq = mu_pred.pow(2)
    mu_target_sq = mu_target.pow(2)
    mu_pred_target = mu_pred * mu_target

    sigma_pred_sq = F.conv2d(pred * pred, kernel, groups=channel) - mu_pred_sq
    sigma_target_sq = F.conv2d(target * target, kernel, groups=channel) - mu_target_sq
    sigma_pred_target = F.conv2d(pred * target, kernel, groups=channel) - mu_pred_target

    UPPER = 2 * sigma_pred_target + C2
    LOWER = sigma_pred_sq + sigma_target_sq + C2

    ssim_idx = ((2 * mu_pred_target + C1) * UPPER) / ((mu_pred_sq + mu_target_sq + C1) * LOWER)

    return reduce(ssim_idx, reduction)
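
The per-window quantity computed above is the standard SSIM index (Wang et al., 2004),

    SSIM(x, y) = ((2 * mu_x * mu_y + C1) * (2 * sigma_xy + C2)) / ((mu_x^2 + mu_y^2 + C1) * (sigma_x^2 + sigma_y^2 + C2))

with C1 = (k1 * data_range)^2 and C2 = (k2 * data_range)^2, where the local means and (co)variances are taken under the separable Gaussian window built by `_gaussian_kernel`. A minimal usage sketch of the new functional metric (assuming an install that includes this commit; the tensors and printed value are illustrative, not taken from the source):

import torch
from pytorch_lightning.metrics.functional import ssim

pred = torch.rand(4, 3, 64, 64)                     # BxCxHxW input
target = (pred + 0.05).clamp(0, 1)                  # slightly perturbed copy of pred
score = ssim(pred, target, kernel_size=(11, 11), sigma=(1.5, 1.5))
print(score)                                        # scalar tensor close to 1.0
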
64 changes: 63 additions & 1 deletion pytorch_lightning/metrics/regression.py
@@ -1,11 +1,14 @@
from typing import Sequence

import torch

from pytorch_lightning.metrics.functional.regression import (
    mae,
    mse,
    psnr,
    rmse,
    rmsle,
    ssim
)
from pytorch_lightning.metrics.metric import Metric

@@ -229,3 +232,62 @@ def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        A Tensor with psnr score.
        """
        return psnr(pred, target, self.data_range, self.base, self.reduction)


class SSIM(Metric):
    """
    Computes Structural Similarity Index Measure

    Example:

        >>> pred = torch.rand([16, 1, 16, 16])
        >>> target = pred * 1.25
        >>> metric = SSIM()
        >>> metric(pred, target)
        tensor(0.9520)
    """

    def __init__(
        self,
        kernel_size: Sequence[int] = (11, 11),
        sigma: Sequence[float] = (1.5, 1.5),
        reduction: str = "elementwise_mean",
        data_range: float = None,
        k1: float = 0.01,
        k2: float = 0.03
    ):
        """
        Args:
            kernel_size: Size of the Gaussian kernel. Default: (11, 11)
            sigma: Standard deviation of the Gaussian kernel. Default: (1.5, 1.5)
            reduction: A method for reducing SSIM. Default: ``elementwise_mean``

                Available reduction methods:
                - elementwise_mean: takes the mean
                - none: no reduction is applied
                - sum: sums the elements

            data_range: Range of the image. If ``None``, it is determined from the images (max - min)
            k1: Parameter of SSIM. Default: 0.01
            k2: Parameter of SSIM. Default: 0.03
        """
        super().__init__(name="ssim")
        self.kernel_size = kernel_size
        self.sigma = sigma
        self.reduction = reduction
        self.data_range = data_range
        self.k1 = k1
        self.k2 = k2

    def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """
        Actual metric computation

        Args:
            pred: Estimated image
            target: Ground truth image

        Return:
            torch.Tensor: SSIM score
        """
        return ssim(pred, target, self.kernel_size, self.sigma, self.reduction, self.data_range, self.k1, self.k2)
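
A minimal sketch of using the module-style metric added above (assuming an install that includes this commit; it takes the same parameters as the functional form):

import torch
from pytorch_lightning.metrics import SSIM

metric = SSIM(kernel_size=(11, 11), sigma=(1.5, 1.5), reduction="elementwise_mean")
pred = torch.rand(2, 1, 32, 32)   # BxCxHxW
target = pred * 0.9               # scaled copy of pred
print(metric(pred, target))       # a single reduced SSIM tensor

With the default ``elementwise_mean`` reduction the call returns one scalar tensor; ``reduction="none"`` should instead keep the unreduced per-window SSIM map.
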
51 changes: 50 additions & 1 deletion tests/metrics/functional/test_regression.py
@@ -2,13 +2,15 @@
import torch
import numpy as np
from skimage.metrics import peak_signal_noise_ratio as ski_psnr
from skimage.metrics import structural_similarity as ski_ssim

from pytorch_lightning.metrics.functional import (
    mae,
    mse,
    psnr,
    rmse,
    rmsle,
    ssim
)


@@ -93,3 +95,50 @@ def test_psnr_against_sklearn(sklearn_metric, torch_metric):
    sk_score = torch.tensor(sk_score, dtype=torch.float, device=device)
    pl_score = torch_metric(pred, target, data_range=n_cls_target)
    assert torch.allclose(sk_score, pl_score)


@pytest.mark.parametrize(['size', 'channel', 'plus', 'multichannel'], [
    pytest.param(16, 1, 0.125, False),
    pytest.param(32, 1, 0.25, False),
    pytest.param(48, 3, 0.5, True),
    pytest.param(64, 4, 0.75, True),
    pytest.param(128, 5, 1, True)
])
def test_ssim(size, channel, plus, multichannel):
    device = "cuda" if torch.cuda.is_available() else "cpu"
    pred = torch.rand(1, channel, size, size, device=device)
    target = pred + plus
    ssim_idx = ssim(pred, target)
    np_pred = np.random.rand(size, size, channel)
    if multichannel is False:
        np_pred = np_pred[:, :, 0]
    np_target = np.add(np_pred, plus)
    sk_ssim_idx = ski_ssim(np_pred, np_target, win_size=11, multichannel=multichannel, gaussian_weights=True)
    assert torch.allclose(ssim_idx, torch.tensor(sk_ssim_idx, dtype=torch.float, device=device), atol=1e-2, rtol=1e-2)

    ssim_idx = ssim(pred, pred)
    assert torch.allclose(ssim_idx, torch.tensor(1.0, device=device))


@pytest.mark.parametrize(['pred', 'target', 'kernel', 'sigma'], [
    pytest.param([1, 1, 16, 16], [1, 16, 16], [11, 11], [1.5, 1.5]),  # shape
    pytest.param([1, 16, 16], [1, 16, 16], [11, 11], [1.5, 1.5]),  # len(shape)
    pytest.param([1, 1, 16, 16], [1, 1, 16, 16], [11, 11], [1.5]),  # len(kernel), len(sigma)
    pytest.param([1, 1, 16, 16], [1, 1, 16, 16], [11], [1.5, 1.5]),  # len(kernel), len(sigma)
    pytest.param([1, 1, 16, 16], [1, 1, 16, 16], [11], [1.5]),  # len(kernel), len(sigma)
    pytest.param([1, 1, 16, 16], [1, 1, 16, 16], [11, 0], [1.5, 1.5]),  # invalid kernel input
    pytest.param([1, 1, 16, 16], [1, 1, 16, 16], [11, 10], [1.5, 1.5]),  # invalid kernel input
    pytest.param([1, 1, 16, 16], [1, 1, 16, 16], [11, -11], [1.5, 1.5]),  # invalid kernel input
    pytest.param([1, 1, 16, 16], [1, 1, 16, 16], [11, 11], [1.5, 0]),  # invalid sigma input
    pytest.param([1, 1, 16, 16], [1, 1, 16, 16], [11, 0], [1.5, -1.5]),  # invalid sigma input
])
def test_ssim_invalid_inputs(pred, target, kernel, sigma):
    pred_t = torch.rand(pred)
    target_t = torch.rand(target, dtype=torch.float64)
    with pytest.raises(TypeError):
        ssim(pred_t, target_t)

    pred = torch.rand(pred)
    target = torch.rand(target)
    with pytest.raises(ValueError):
        ssim(pred, target, kernel, sigma)
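
The parametrised test above cross-checks the functional metric against scikit-image. A minimal standalone sketch of that comparison on a single image, making the shape mapping explicit (HxWxC for skimage, BxCxHxW for torch; assumes skimage >= 0.16, where ``structural_similarity`` lives under ``skimage.metrics``):

import numpy as np
import torch
from skimage.metrics import structural_similarity as ski_ssim
from pytorch_lightning.metrics.functional import ssim

np_pred = np.random.rand(48, 48, 3)                  # HxWxC image for skimage
np_target = np_pred + 0.25
sk_score = ski_ssim(np_pred, np_target, win_size=11, multichannel=True,
                    gaussian_weights=True, data_range=1.0)

pred = torch.tensor(np_pred.transpose(2, 0, 1)[None], dtype=torch.float)    # 1xCxHxW for torch
target = torch.tensor(np_target.transpose(2, 0, 1)[None], dtype=torch.float)
pl_score = ssim(pred, target, data_range=1.0)
print(sk_score, pl_score.item())                     # expected to agree to roughly 1e-2
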
12 changes: 11 additions & 1 deletion tests/metrics/test_regression.py
@@ -6,7 +6,7 @@
from skimage.metrics import peak_signal_noise_ratio as ski_psnr

from pytorch_lightning.metrics.regression import (
    MAE, MSE, RMSE, RMSLE, PSNR, SSIM
)


@@ -58,3 +58,13 @@ def test_psnr():
    target = torch.tensor([0., 1, 2, 2])
    score = psnr(pred, target)
    assert isinstance(score, torch.Tensor)


def test_ssim():
    ssim = SSIM()
    assert ssim.name == 'ssim'

    pred = torch.rand([16, 1, 16, 16])
    target = pred * 1.25
    score = ssim(pred, target)
    assert isinstance(score, torch.Tensor)
