From e2f9ac07eef9acf908b973c4c881556ef749ff5d Mon Sep 17 00:00:00 2001 From: Kazuki Adachi Date: Wed, 4 Sep 2024 04:44:43 +0900 Subject: [PATCH] Add rank correlation metrics (#3276) * add SpearmanRankCorrelation metric * add KendallRankCorrelation metric * add import check of scipy * fix type hints * fix formatting error * minor modification to docstring --------- Co-authored-by: vfdev --- docs/source/metrics.rst | 2 + ignite/metrics/regression/__init__.py | 2 + .../metrics/regression/kendall_correlation.py | 115 ++++++++++ .../regression/spearman_correlation.py | 107 ++++++++++ .../regression/test_kendall_correlation.py | 200 ++++++++++++++++++ .../regression/test_spearman_correlation.py | 191 +++++++++++++++++ 6 files changed, 617 insertions(+) create mode 100644 ignite/metrics/regression/kendall_correlation.py create mode 100644 ignite/metrics/regression/spearman_correlation.py create mode 100644 tests/ignite/metrics/regression/test_kendall_correlation.py create mode 100644 tests/ignite/metrics/regression/test_spearman_correlation.py diff --git a/docs/source/metrics.rst b/docs/source/metrics.rst index 0e4979f82a1..d67c90c2248 100644 --- a/docs/source/metrics.rst +++ b/docs/source/metrics.rst @@ -377,6 +377,8 @@ Complete list of metrics regression.MedianAbsolutePercentageError regression.MedianRelativeAbsoluteError regression.PearsonCorrelation + regression.SpearmanRankCorrelation + regression.KendallRankCorrelation regression.R2Score regression.WaveHedgesDistance diff --git a/ignite/metrics/regression/__init__.py b/ignite/metrics/regression/__init__.py index 7be1f18d0f3..4be1abddb11 100644 --- a/ignite/metrics/regression/__init__.py +++ b/ignite/metrics/regression/__init__.py @@ -3,6 +3,7 @@ from ignite.metrics.regression.fractional_bias import FractionalBias from ignite.metrics.regression.geometric_mean_absolute_error import GeometricMeanAbsoluteError from ignite.metrics.regression.geometric_mean_relative_absolute_error import GeometricMeanRelativeAbsoluteError +from ignite.metrics.regression.kendall_correlation import KendallRankCorrelation from ignite.metrics.regression.manhattan_distance import ManhattanDistance from ignite.metrics.regression.maximum_absolute_error import MaximumAbsoluteError from ignite.metrics.regression.mean_absolute_relative_error import MeanAbsoluteRelativeError @@ -13,4 +14,5 @@ from ignite.metrics.regression.median_relative_absolute_error import MedianRelativeAbsoluteError from ignite.metrics.regression.pearson_correlation import PearsonCorrelation from ignite.metrics.regression.r2_score import R2Score +from ignite.metrics.regression.spearman_correlation import SpearmanRankCorrelation from ignite.metrics.regression.wave_hedges_distance import WaveHedgesDistance diff --git a/ignite/metrics/regression/kendall_correlation.py b/ignite/metrics/regression/kendall_correlation.py new file mode 100644 index 00000000000..826ca350ac2 --- /dev/null +++ b/ignite/metrics/regression/kendall_correlation.py @@ -0,0 +1,115 @@ +from typing import Any, Callable, Tuple, Union + +import torch + +from torch import Tensor + +from ignite.exceptions import NotComputableError +from ignite.metrics.epoch_metric import EpochMetric +from ignite.metrics.regression._base import _check_output_shapes, _check_output_types + + +def _get_kendall_tau(variant: str = "b") -> Callable[[Tensor, Tensor], float]: + from scipy.stats import kendalltau + + if variant not in ("b", "c"): + raise ValueError(f"variant accepts 'b' or 'c', got {variant!r}.") + + def _tau(predictions: Tensor, targets: Tensor) -> 
float:
+        np_preds = predictions.flatten().numpy()
+        np_targets = targets.flatten().numpy()
+        r = kendalltau(np_preds, np_targets, variant=variant).statistic
+        return r
+
+    return _tau
+
+
+class KendallRankCorrelation(EpochMetric):
+    r"""Calculates the
+    `Kendall rank correlation coefficient
+    <https://en.wikipedia.org/wiki/Kendall_rank_correlation_coefficient>`_.
+
+    .. math::
+        \tau = 1-\frac{2(\text{number of discordant pairs})}{\left( \begin{array}{c}n\\2\end{array} \right)}
+
+    Two prediction-target pairs :math:`(P_i, A_i)` and :math:`(P_j, A_j)`, where :math:`i<j`,
+    are said to be concordant in that the order of the prediction matches the order of the target,
+    i.e. :math:`P_i<P_j` and :math:`A_i<A_j`, or :math:`P_i>P_j` and :math:`A_i>A_j`.
+
+    The `number of discordant pairs` counts the number of pairs that are not concordant.
+
+    The computation of this metric is implemented with
+    `scipy.stats.kendalltau <https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.kendalltau.html>`_.
+
+    - ``update`` must receive output of the form ``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y}``.
+    - `y` and `y_pred` must be of same shape `(N, )` or `(N, 1)`.
+
+    Parameters are inherited from ``Metric.__init__``.
+
+    Args:
+        variant: variant of Kendall rank correlation. ``b`` or ``c`` is accepted.
+            Details can be found
+            `here <https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.kendalltau.html>`_.
+            Default: ``b``
+        output_transform: a callable that is used to transform the
+            :class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the
+            form expected by the metric. This can be useful if, for example, you have a multi-output model and
+            you want to compute the metric with respect to one of the outputs.
+            By default, metrics require the output as ``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y}``.
+        device: specifies which device updates are accumulated on. Setting the
+            metric's device to be the same as your ``update`` arguments ensures the ``update`` method is
+            non-blocking. By default, CPU.
+
+    Examples:
+        To use with ``Engine`` and ``process_function``, simply attach the metric instance to the engine.
+        The output of the engine's ``process_function`` needs to be in format of
+        ``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y, ...}``.
+
+        .. include:: defaults.rst
+            :start-after: :orphan:
+
+        .. testcode::
+
+            metric = KendallRankCorrelation()
+            metric.attach(default_evaluator, 'kendall_tau')
+            y_true = torch.tensor([0., 1., 2., 3., 4., 5.])
+            y_pred = torch.tensor([0.5, 2.8, 1.9, 1.3, 6.0, 4.1])
+            state = default_evaluator.run([[y_pred, y_true]])
+            print(state.metrics['kendall_tau'])
+
+        .. testoutput::
+
+            0.4666666666666666
+    """
+
+    def __init__(
+        self,
+        variant: str = "b",
+        output_transform: Callable[..., Any] = lambda x: x,
+        check_compute_fn: bool = True,
+        device: Union[str, torch.device] = torch.device("cpu"),
+        skip_unrolling: bool = False,
+    ) -> None:
+        try:
+            from scipy.stats import kendalltau  # noqa: F401
+        except ImportError:
+            raise ModuleNotFoundError("This module requires scipy to be installed.")
+
+        super().__init__(_get_kendall_tau(variant), output_transform, check_compute_fn, device, skip_unrolling)
+
+    def update(self, output: Tuple[torch.Tensor, torch.Tensor]) -> None:
+        y_pred, y = output[0].detach(), output[1].detach()
+        if y_pred.ndim == 1:
+            y_pred = y_pred.unsqueeze(1)
+        if y.ndim == 1:
+            y = y.unsqueeze(1)
+
+        _check_output_shapes(output)
+        _check_output_types(output)
+
+        super().update(output)
+
+    def compute(self) -> float:
+        if len(self._predictions) < 1 or len(self._targets) < 1:
+            raise NotComputableError("KendallRankCorrelation must have at least one example before it can be computed.")
+
+        return super().compute()
diff --git a/ignite/metrics/regression/spearman_correlation.py b/ignite/metrics/regression/spearman_correlation.py
new file mode 100644
index 00000000000..7c5d586b152
--- /dev/null
+++ b/ignite/metrics/regression/spearman_correlation.py
@@ -0,0 +1,107 @@
+from typing import Any, Callable, Tuple, Union
+
+import torch
+
+from torch import Tensor
+
+from ignite.exceptions import NotComputableError
+from ignite.metrics.epoch_metric import EpochMetric
+from ignite.metrics.regression._base import _check_output_shapes, _check_output_types
+
+
+def _get_spearman_r() -> Callable[[Tensor, Tensor], float]:
+    from scipy.stats import spearmanr
+
+    def _compute_spearman_r(predictions: Tensor, targets: Tensor) -> float:
+        np_preds = predictions.flatten().numpy()
+        np_targets = targets.flatten().numpy()
+        r = spearmanr(np_preds, np_targets).statistic
+        return r
+
+    return _compute_spearman_r
+
+
+class SpearmanRankCorrelation(EpochMetric):
+    r"""Calculates the
+    `Spearman's rank correlation coefficient
+    <https://en.wikipedia.org/wiki/Spearman%27s_rank_correlation_coefficient>`_.
+
+    .. math::
+        r_\text{s} = \text{Corr}[R[P], R[A]] = \frac{\text{Cov}[R[P], R[A]]}{\sigma_{R[P]} \sigma_{R[A]}}
+
+    where :math:`A` and :math:`P` are the ground truth and predicted value,
+    and :math:`R[X]` is the ranking value of :math:`X`.
+
+    The computation of this metric is implemented with
+    `scipy.stats.spearmanr <https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.spearmanr.html>`_.
+
+    - ``update`` must receive output of the form ``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y}``.
+    - `y` and `y_pred` must be of same shape `(N, )` or `(N, 1)`.
+
+    Parameters are inherited from ``Metric.__init__``.
+
+    Args:
+        output_transform: a callable that is used to transform the
+            :class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the
+            form expected by the metric. This can be useful if, for example, you have a multi-output model and
+            you want to compute the metric with respect to one of the outputs.
+            By default, metrics require the output as ``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y}``.
+        device: specifies which device updates are accumulated on. Setting the
+            metric's device to be the same as your ``update`` arguments ensures the ``update`` method is
+            non-blocking. By default, CPU.
+
+    Examples:
+        To use with ``Engine`` and ``process_function``, simply attach the metric instance to the engine.
+        The output of the engine's ``process_function`` needs to be in format of
+        ``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y, ...}``.
+
+        .. 
include:: defaults.rst + :start-after: :orphan: + + .. testcode:: + + metric = SpearmanRankCorrelation() + metric.attach(default_evaluator, 'spearman_corr') + y_true = torch.tensor([0., 1., 2., 3., 4., 5.]) + y_pred = torch.tensor([0.5, 2.8, 1.9, 1.3, 6.0, 4.1]) + state = default_evaluator.run([[y_pred, y_true]]) + print(state.metrics['spearman_corr']) + + .. testoutput:: + + 0.7142857142857143 + """ + + def __init__( + self, + output_transform: Callable[..., Any] = lambda x: x, + check_compute_fn: bool = True, + device: Union[str, torch.device] = torch.device("cpu"), + skip_unrolling: bool = False, + ) -> None: + try: + from scipy.stats import spearmanr # noqa: F401 + except ImportError: + raise ModuleNotFoundError("This module requires scipy to be installed.") + + super().__init__(_get_spearman_r(), output_transform, check_compute_fn, device, skip_unrolling) + + def update(self, output: Tuple[torch.Tensor, torch.Tensor]) -> None: + y_pred, y = output[0].detach(), output[1].detach() + if y_pred.ndim == 1: + y_pred = y_pred.unsqueeze(1) + if y.ndim == 1: + y = y.unsqueeze(1) + + _check_output_shapes(output) + _check_output_types(output) + + super().update(output) + + def compute(self) -> float: + if len(self._predictions) < 1 or len(self._targets) < 1: + raise NotComputableError( + "SpearmanRankCorrelation must have at least one example before it can be computed." + ) + + return super().compute() diff --git a/tests/ignite/metrics/regression/test_kendall_correlation.py b/tests/ignite/metrics/regression/test_kendall_correlation.py new file mode 100644 index 00000000000..5dd55b0691b --- /dev/null +++ b/tests/ignite/metrics/regression/test_kendall_correlation.py @@ -0,0 +1,200 @@ +from typing import Tuple + +import numpy as np +import pytest + +import torch +from scipy.stats import kendalltau +from torch import Tensor + +from ignite import distributed as idist +from ignite.engine import Engine +from ignite.exceptions import NotComputableError +from ignite.metrics.regression import KendallRankCorrelation + + +def test_zero_sample(): + with pytest.raises( + NotComputableError, match="KendallRankCorrelation must have at least one example before it can be computed" + ): + metric = KendallRankCorrelation() + metric.compute() + + +def test_wrong_y_pred_shape(): + with pytest.raises(ValueError, match=r"Input y_pred should have shape \(N,\) or \(N, 1\), but given"): + metric = KendallRankCorrelation() + y_pred = torch.arange(9).reshape(3, 3).float() + y = torch.arange(3).unsqueeze(1).float() + metric.update((y_pred, y)) + + +def test_wrong_y_shape(): + with pytest.raises(ValueError, match=r"Input y should have shape \(N,\) or \(N, 1\), but given"): + metric = KendallRankCorrelation() + y_pred = torch.arange(3).unsqueeze(1).float() + y = torch.arange(9).reshape(3, 3).float() + metric.update((y_pred, y)) + + +def test_wrong_y_pred_dtype(): + with pytest.raises(TypeError, match="Input y_pred dtype should be float 16, 32 or 64, but given"): + metric = KendallRankCorrelation() + y_pred = torch.arange(3).unsqueeze(1).long() + y = torch.arange(3).unsqueeze(1).float() + metric.update((y_pred, y)) + + +def test_wrong_y_dtype(): + with pytest.raises(TypeError, match="Input y dtype should be float 16, 32 or 64, but given"): + metric = KendallRankCorrelation() + y_pred = torch.arange(3).unsqueeze(1).float() + y = torch.arange(3).unsqueeze(1).long() + metric.update((y_pred, y)) + + +def test_wrong_variant(): + with pytest.raises(ValueError, match="variant accepts 'b' or 'c', got"): + 
KendallRankCorrelation(variant="x") + + +@pytest.mark.parametrize("variant", ["b", "c"]) +def test_kendall_correlation(variant: str): + a = np.random.randn(4).astype(np.float32) + b = np.random.randn(4).astype(np.float32) + c = np.random.randn(4).astype(np.float32) + d = np.random.randn(4).astype(np.float32) + ground_truth = np.random.randn(4).astype(np.float32) + + m = KendallRankCorrelation(variant=variant) + + m.update((torch.from_numpy(a), torch.from_numpy(ground_truth))) + np_ans = kendalltau(a, ground_truth, variant=variant).statistic + assert m.compute() == pytest.approx(np_ans, rel=1e-4) + + m.update((torch.from_numpy(b), torch.from_numpy(ground_truth))) + np_ans = kendalltau(np.concatenate([a, b]), np.concatenate([ground_truth] * 2), variant=variant).statistic + assert m.compute() == pytest.approx(np_ans, rel=1e-4) + + m.update((torch.from_numpy(c), torch.from_numpy(ground_truth))) + np_ans = kendalltau(np.concatenate([a, b, c]), np.concatenate([ground_truth] * 3), variant=variant).statistic + assert m.compute() == pytest.approx(np_ans, rel=1e-4) + + m.update((torch.from_numpy(d), torch.from_numpy(ground_truth))) + np_ans = kendalltau(np.concatenate([a, b, c, d]), np.concatenate([ground_truth] * 4), variant=variant).statistic + assert m.compute() == pytest.approx(np_ans, rel=1e-4) + + +@pytest.fixture(params=list(range(2))) +def test_case(request): + # correlated sample + x = torch.randn(size=[50]).float() + y = x + torch.randn_like(x) * 0.1 + + return [ + (x, y, 1), + (torch.rand(size=(50, 1)).float(), torch.rand(size=(50, 1)).float(), 10), + ][request.param] + + +@pytest.mark.parametrize("n_times", range(5)) +@pytest.mark.parametrize("variant", ["b", "c"]) +def test_integration(n_times: int, variant: str, test_case: Tuple[Tensor, Tensor, int]): + y_pred, y, batch_size = test_case + + np_y = y.numpy().ravel() + np_y_pred = y_pred.numpy().ravel() + + def update_fn(engine: Engine, batch): + idx = (engine.state.iteration - 1) * batch_size + y_true_batch = np_y[idx : idx + batch_size] + y_pred_batch = np_y_pred[idx : idx + batch_size] + return torch.from_numpy(y_pred_batch), torch.from_numpy(y_true_batch) + + engine = Engine(update_fn) + + m = KendallRankCorrelation(variant=variant) + m.attach(engine, "kendall_tau") + + data = list(range(y_pred.shape[0] // batch_size)) + corr = engine.run(data, max_epochs=1).metrics["kendall_tau"] + + np_ans = kendalltau(np_y_pred, np_y, variant=variant).statistic + + assert pytest.approx(np_ans, rel=2e-4) == corr + + +@pytest.mark.usefixtures("distributed") +class TestDistributed: + @pytest.mark.parametrize("variant", ["b", "c"]) + def test_compute(self, variant: str): + rank = idist.get_rank() + device = idist.device() + metric_devices = [torch.device("cpu")] + if device.type != "xla": + metric_devices.append(device) + + torch.manual_seed(10 + rank) + for metric_device in metric_devices: + m = KendallRankCorrelation(device=metric_device, variant=variant) + + y_pred = torch.rand(size=[100], device=device) + y = torch.rand(size=[100], device=device) + + m.update((y_pred, y)) + + y_pred = idist.all_gather(y_pred) + y = idist.all_gather(y) + + np_y = y.cpu().numpy() + np_y_pred = y_pred.cpu().numpy() + + np_ans = kendalltau(np_y_pred, np_y, variant=variant).statistic + + assert pytest.approx(np_ans, rel=2e-4) == m.compute() + + @pytest.mark.parametrize("n_epochs", [1, 2]) + @pytest.mark.parametrize("variant", ["b", "c"]) + def test_integration(self, n_epochs: int, variant: str): + tol = 2e-4 + rank = idist.get_rank() + device = idist.device() + 
metric_devices = [torch.device("cpu")] + if device.type != "xla": + metric_devices.append(device) + + n_iters = 80 + batch_size = 16 + + for metric_device in metric_devices: + torch.manual_seed(12 + rank) + + y_true = torch.rand(size=(n_iters * batch_size,)).to(device) + y_preds = torch.rand(size=(n_iters * batch_size,)).to(device) + + engine = Engine( + lambda e, i: ( + y_preds[i * batch_size : (i + 1) * batch_size], + y_true[i * batch_size : (i + 1) * batch_size], + ) + ) + + corr = KendallRankCorrelation(variant=variant, device=metric_device) + corr.attach(engine, "kendall_tau") + + data = list(range(n_iters)) + engine.run(data=data, max_epochs=n_epochs) + + y_preds = idist.all_gather(y_preds) + y_true = idist.all_gather(y_true) + + assert "kendall_tau" in engine.state.metrics + + res = engine.state.metrics["kendall_tau"] + + np_y = y_true.cpu().numpy() + np_y_pred = y_preds.cpu().numpy() + + np_ans = kendalltau(np_y_pred, np_y, variant=variant).statistic + + assert pytest.approx(np_ans, rel=tol) == res diff --git a/tests/ignite/metrics/regression/test_spearman_correlation.py b/tests/ignite/metrics/regression/test_spearman_correlation.py new file mode 100644 index 00000000000..4aac6221f62 --- /dev/null +++ b/tests/ignite/metrics/regression/test_spearman_correlation.py @@ -0,0 +1,191 @@ +from typing import Tuple + +import numpy as np +import pytest + +import torch +from scipy.stats import spearmanr +from torch import Tensor + +from ignite import distributed as idist +from ignite.engine import Engine +from ignite.exceptions import NotComputableError +from ignite.metrics.regression import SpearmanRankCorrelation + + +def test_zero_sample(): + with pytest.raises( + NotComputableError, match="SpearmanRankCorrelation must have at least one example before it can be computed" + ): + metric = SpearmanRankCorrelation() + metric.compute() + + +def test_wrong_y_pred_shape(): + with pytest.raises(ValueError, match=r"Input y_pred should have shape \(N,\) or \(N, 1\), but given"): + metric = SpearmanRankCorrelation() + y_pred = torch.arange(9).reshape(3, 3).float() + y = torch.arange(3).unsqueeze(1).float() + metric.update((y_pred, y)) + + +def test_wrong_y_shape(): + with pytest.raises(ValueError, match=r"Input y should have shape \(N,\) or \(N, 1\), but given"): + metric = SpearmanRankCorrelation() + y_pred = torch.arange(3).unsqueeze(1).float() + y = torch.arange(9).reshape(3, 3).float() + metric.update((y_pred, y)) + + +def test_wrong_y_pred_dtype(): + with pytest.raises(TypeError, match="Input y_pred dtype should be float 16, 32 or 64, but given"): + metric = SpearmanRankCorrelation() + y_pred = torch.arange(3).unsqueeze(1).long() + y = torch.arange(3).unsqueeze(1).float() + metric.update((y_pred, y)) + + +def test_wrong_y_dtype(): + with pytest.raises(TypeError, match="Input y dtype should be float 16, 32 or 64, but given"): + metric = SpearmanRankCorrelation() + y_pred = torch.arange(3).unsqueeze(1).float() + y = torch.arange(3).unsqueeze(1).long() + metric.update((y_pred, y)) + + +def test_spearman_correlation(): + a = np.random.randn(4).astype(np.float32) + b = np.random.randn(4).astype(np.float32) + c = np.random.randn(4).astype(np.float32) + d = np.random.randn(4).astype(np.float32) + ground_truth = np.random.randn(4).astype(np.float32) + + m = SpearmanRankCorrelation() + + m.update((torch.from_numpy(a), torch.from_numpy(ground_truth))) + np_ans = spearmanr(a, ground_truth).statistic + assert m.compute() == pytest.approx(np_ans, rel=1e-4) + + m.update((torch.from_numpy(b), 
torch.from_numpy(ground_truth))) + np_ans = spearmanr(np.concatenate([a, b]), np.concatenate([ground_truth] * 2)).statistic + assert m.compute() == pytest.approx(np_ans, rel=1e-4) + + m.update((torch.from_numpy(c), torch.from_numpy(ground_truth))) + np_ans = spearmanr(np.concatenate([a, b, c]), np.concatenate([ground_truth] * 3)).statistic + assert m.compute() == pytest.approx(np_ans, rel=1e-4) + + m.update((torch.from_numpy(d), torch.from_numpy(ground_truth))) + np_ans = spearmanr(np.concatenate([a, b, c, d]), np.concatenate([ground_truth] * 4)).statistic + assert m.compute() == pytest.approx(np_ans, rel=1e-4) + + +@pytest.fixture(params=list(range(2))) +def test_case(request): + # correlated sample + x = torch.randn(size=[50]).float() + y = x + torch.randn_like(x) * 0.1 + + return [ + (x, y, 1), + (torch.rand(size=(50, 1)).float(), torch.rand(size=(50, 1)).float(), 10), + ][request.param] + + +@pytest.mark.parametrize("n_times", range(5)) +def test_integration(n_times, test_case: Tuple[Tensor, Tensor, int]): + y_pred, y, batch_size = test_case + + np_y = y.numpy().ravel() + np_y_pred = y_pred.numpy().ravel() + + def update_fn(engine: Engine, batch): + idx = (engine.state.iteration - 1) * batch_size + y_true_batch = np_y[idx : idx + batch_size] + y_pred_batch = np_y_pred[idx : idx + batch_size] + return torch.from_numpy(y_pred_batch), torch.from_numpy(y_true_batch) + + engine = Engine(update_fn) + + m = SpearmanRankCorrelation() + m.attach(engine, "spearman_corr") + + data = list(range(y_pred.shape[0] // batch_size)) + corr = engine.run(data, max_epochs=1).metrics["spearman_corr"] + + np_ans = spearmanr(np_y_pred, np_y).statistic + + assert pytest.approx(np_ans, rel=2e-4) == corr + + +@pytest.mark.usefixtures("distributed") +class TestDistributed: + def test_compute(self): + rank = idist.get_rank() + device = idist.device() + metric_devices = [torch.device("cpu")] + if device.type != "xla": + metric_devices.append(device) + + torch.manual_seed(10 + rank) + for metric_device in metric_devices: + m = SpearmanRankCorrelation(device=metric_device) + + y_pred = torch.rand(size=[100], device=device) + y = torch.rand(size=[100], device=device) + + m.update((y_pred, y)) + + y_pred = idist.all_gather(y_pred) + y = idist.all_gather(y) + + np_y = y.cpu().numpy() + np_y_pred = y_pred.cpu().numpy() + + np_ans = spearmanr(np_y_pred, np_y).statistic + + assert pytest.approx(np_ans, rel=2e-4) == m.compute() + + @pytest.mark.parametrize("n_epochs", [1, 2]) + def test_integration(self, n_epochs: int): + tol = 2e-4 + rank = idist.get_rank() + device = idist.device() + metric_devices = [torch.device("cpu")] + if device.type != "xla": + metric_devices.append(device) + + n_iters = 80 + batch_size = 16 + + for metric_device in metric_devices: + torch.manual_seed(12 + rank) + + y_true = torch.rand(size=(n_iters * batch_size,)).to(device) + y_preds = torch.rand(size=(n_iters * batch_size,)).to(device) + + engine = Engine( + lambda e, i: ( + y_preds[i * batch_size : (i + 1) * batch_size], + y_true[i * batch_size : (i + 1) * batch_size], + ) + ) + + corr = SpearmanRankCorrelation(device=metric_device) + corr.attach(engine, "spearman_corr") + + data = list(range(n_iters)) + engine.run(data=data, max_epochs=n_epochs) + + y_preds = idist.all_gather(y_preds) + y_true = idist.all_gather(y_true) + + assert "spearman_corr" in engine.state.metrics + + res = engine.state.metrics["spearman_corr"] + + np_y = y_true.cpu().numpy() + np_y_pred = y_preds.cpu().numpy() + + np_ans = spearmanr(np_y_pred, np_y).statistic + + assert 
pytest.approx(np_ans, rel=tol) == res
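
Beyond the ``Engine``-based flow shown in the docstring examples, both new metrics can also be driven directly through their ``update`` / ``compute`` API. A minimal sketch, assuming ``scipy`` is installed and reusing the tensors from the docstring examples above; the printed values should match the documented ``testoutput`` results::

    import torch
    from ignite.metrics.regression import KendallRankCorrelation, SpearmanRankCorrelation

    y_true = torch.tensor([0.0, 1.0, 2.0, 3.0, 4.0, 5.0])
    y_pred = torch.tensor([0.5, 2.8, 1.9, 1.3, 6.0, 4.1])

    # Both metrics accumulate (y_pred, y) pairs across update() calls and
    # defer the actual scipy computation to compute().
    kendall = KendallRankCorrelation(variant="b")
    kendall.update((y_pred, y_true))
    print(kendall.compute())  # ~0.4667, as in the Kendall docstring example

    spearman = SpearmanRankCorrelation()
    spearman.update((y_pred, y_true))
    print(spearman.compute())  # ~0.7143, as in the Spearman docstring example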