Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add rank correlation metrics #3276

Merged
merged 7 commits into from
Sep 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/source/metrics.rst
Original file line number Diff line number Diff line change
Expand Up @@ -377,6 +377,8 @@ Complete list of metrics
regression.MedianAbsolutePercentageError
regression.MedianRelativeAbsoluteError
regression.PearsonCorrelation
regression.SpearmanRankCorrelation
regression.KendallRankCorrelation
regression.R2Score
regression.WaveHedgesDistance

Expand Down
2 changes: 2 additions & 0 deletions ignite/metrics/regression/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from ignite.metrics.regression.fractional_bias import FractionalBias
from ignite.metrics.regression.geometric_mean_absolute_error import GeometricMeanAbsoluteError
from ignite.metrics.regression.geometric_mean_relative_absolute_error import GeometricMeanRelativeAbsoluteError
from ignite.metrics.regression.kendall_correlation import KendallRankCorrelation
from ignite.metrics.regression.manhattan_distance import ManhattanDistance
from ignite.metrics.regression.maximum_absolute_error import MaximumAbsoluteError
from ignite.metrics.regression.mean_absolute_relative_error import MeanAbsoluteRelativeError
Expand All @@ -13,4 +14,5 @@
from ignite.metrics.regression.median_relative_absolute_error import MedianRelativeAbsoluteError
from ignite.metrics.regression.pearson_correlation import PearsonCorrelation
from ignite.metrics.regression.r2_score import R2Score
from ignite.metrics.regression.spearman_correlation import SpearmanRankCorrelation
from ignite.metrics.regression.wave_hedges_distance import WaveHedgesDistance
115 changes: 115 additions & 0 deletions ignite/metrics/regression/kendall_correlation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
from typing import Any, Callable, Tuple, Union

import torch

from torch import Tensor

from ignite.exceptions import NotComputableError
from ignite.metrics.epoch_metric import EpochMetric
from ignite.metrics.regression._base import _check_output_shapes, _check_output_types


def _get_kendall_tau(variant: str = "b") -> Callable[[Tensor, Tensor], float]:
from scipy.stats import kendalltau

if variant not in ("b", "c"):
raise ValueError(f"variant accepts 'b' or 'c', got {variant!r}.")

def _tau(predictions: Tensor, targets: Tensor) -> float:
np_preds = predictions.flatten().numpy()
np_targets = targets.flatten().numpy()
r = kendalltau(np_preds, np_targets, variant=variant).statistic
return r

return _tau


class KendallRankCorrelation(EpochMetric):
r"""Calculates the
`Kendall rank correlation coefficient <https://en.wikipedia.org/wiki/Kendall_rank_correlation_coefficient>`_.

.. math::
\tau = 1-\frac{2(\text{number of discordant pairs})}{\left( \begin{array}{c}n\\2\end{array} \right)}

Two prediction-target pairs :math:`(P_i, A_i)` and :math:`(P_j, A_j)`, where :math:`i<j`,
are said to be concordant when both :math:`P_i<P_j` and :math:`A_i<A_j` holds
or both :math:`P_i>P_j` and :math:`A_i>A_j`.

The `number of discordant pairs` counts the number of pairs that are not concordant.

The computation of this metric is implemented with
`scipy.stats.kendalltau <https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.kendalltau.html>`_.

- ``update`` must receive output of the form ``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y}``.
- `y` and `y_pred` must be of same shape `(N, )` or `(N, 1)`.

Parameters are inherited from ``Metric.__init__``.

Args:
variant: variant of kendall rank correlation. ``b`` or ``c`` is accepted.
Details can be found
`here <https://en.wikipedia.org/wiki/Kendall_rank_correlation_coefficient#Accounting_for_ties>`_.
Default: ``b``
output_transform: a callable that is used to transform the
:class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the
form expected by the metric. This can be useful if, for example, you have a multi-output model and
you want to compute the metric with respect to one of the outputs.
By default, metrics require the output as ``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y}``.
device: specifies which device updates are accumulated on. Setting the
metric's device to be the same as your ``update`` arguments ensures the ``update`` method is
non-blocking. By default, CPU.

Examples:
To use with ``Engine`` and ``process_function``, simply attach the metric instance to the engine.
The output of the engine's ``process_function`` needs to be in format of
``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y, ...}``.

.. include:: defaults.rst
:start-after: :orphan:

.. testcode::

metric = KendallRankCorrelation()
metric.attach(default_evaluator, 'kendall_tau')
y_true = torch.tensor([0., 1., 2., 3., 4., 5.])
y_pred = torch.tensor([0.5, 2.8, 1.9, 1.3, 6.0, 4.1])
state = default_evaluator.run([[y_pred, y_true]])
print(state.metrics['kendall_tau'])

.. testoutput::

0.4666666666666666
"""

def __init__(
self,
variant: str = "b",
output_transform: Callable[..., Any] = lambda x: x,
check_compute_fn: bool = True,
device: Union[str, torch.device] = torch.device("cpu"),
skip_unrolling: bool = False,
) -> None:
try:
from scipy.stats import kendalltau # noqa: F401
except ImportError:
raise ModuleNotFoundError("This module requires scipy to be installed.")

super().__init__(_get_kendall_tau(variant), output_transform, check_compute_fn, device, skip_unrolling)

def update(self, output: Tuple[torch.Tensor, torch.Tensor]) -> None:
y_pred, y = output[0].detach(), output[1].detach()
if y_pred.ndim == 1:
y_pred = y_pred.unsqueeze(1)
if y.ndim == 1:
y = y.unsqueeze(1)

_check_output_shapes(output)
_check_output_types(output)

super().update(output)

def compute(self) -> float:
if len(self._predictions) < 1 or len(self._targets) < 1:
raise NotComputableError("KendallRankCorrelation must have at least one example before it can be computed.")

return super().compute()
107 changes: 107 additions & 0 deletions ignite/metrics/regression/spearman_correlation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
from typing import Any, Callable, Tuple, Union

import torch

from torch import Tensor

from ignite.exceptions import NotComputableError
from ignite.metrics.epoch_metric import EpochMetric
from ignite.metrics.regression._base import _check_output_shapes, _check_output_types


def _get_spearman_r() -> Callable[[Tensor, Tensor], float]:
from scipy.stats import spearmanr

def _compute_spearman_r(predictions: Tensor, targets: Tensor) -> float:
np_preds = predictions.flatten().numpy()
np_targets = targets.flatten().numpy()
r = spearmanr(np_preds, np_targets).statistic
return r

return _compute_spearman_r


class SpearmanRankCorrelation(EpochMetric):
r"""Calculates the
`Spearman's rank correlation coefficient
<https://en.wikipedia.org/wiki/Spearman%27s_rank_correlation_coefficient>`_.

.. math::
r_\text{s} = \text{Corr}[R[P], R[A]] = \frac{\text{Cov}[R[P], R[A]]}{\sigma_{R[P]} \sigma_{R[A]}}

where :math:`A` and :math:`P` are the ground truth and predicted value,
and :math:`R[X]` is the ranking value of :math:`X`.

The computation of this metric is implemented with
`scipy.stats.spearmanr <https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.spearmanr.html>`_.

- ``update`` must receive output of the form ``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y}``.
- `y` and `y_pred` must be of same shape `(N, )` or `(N, 1)`.

Parameters are inherited from ``Metric.__init__``.

Args:
output_transform: a callable that is used to transform the
:class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the
form expected by the metric. This can be useful if, for example, you have a multi-output model and
you want to compute the metric with respect to one of the outputs.
By default, metrics require the output as ``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y}``.
device: specifies which device updates are accumulated on. Setting the
metric's device to be the same as your ``update`` arguments ensures the ``update`` method is
non-blocking. By default, CPU.

Examples:
To use with ``Engine`` and ``process_function``, simply attach the metric instance to the engine.
The output of the engine's ``process_function`` needs to be in format of
``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y, ...}``.

.. include:: defaults.rst
:start-after: :orphan:

.. testcode::

metric = SpearmanRankCorrelation()
metric.attach(default_evaluator, 'spearman_corr')
y_true = torch.tensor([0., 1., 2., 3., 4., 5.])
y_pred = torch.tensor([0.5, 2.8, 1.9, 1.3, 6.0, 4.1])
state = default_evaluator.run([[y_pred, y_true]])
print(state.metrics['spearman_corr'])

.. testoutput::

0.7142857142857143
"""

def __init__(
self,
output_transform: Callable[..., Any] = lambda x: x,
check_compute_fn: bool = True,
device: Union[str, torch.device] = torch.device("cpu"),
skip_unrolling: bool = False,
) -> None:
try:
from scipy.stats import spearmanr # noqa: F401
except ImportError:
raise ModuleNotFoundError("This module requires scipy to be installed.")

super().__init__(_get_spearman_r(), output_transform, check_compute_fn, device, skip_unrolling)

def update(self, output: Tuple[torch.Tensor, torch.Tensor]) -> None:
y_pred, y = output[0].detach(), output[1].detach()
if y_pred.ndim == 1:
y_pred = y_pred.unsqueeze(1)
if y.ndim == 1:
y = y.unsqueeze(1)

_check_output_shapes(output)
_check_output_types(output)

super().update(output)

def compute(self) -> float:
if len(self._predictions) < 1 or len(self._targets) < 1:
raise NotComputableError(
"SpearmanRankCorrelation must have at least one example before it can be computed."
)

return super().compute()
Loading
Loading