diff --git a/docs/source/metrics.rst b/docs/source/metrics.rst
index d67c90c2248..29943e98343 100644
--- a/docs/source/metrics.rst
+++ b/docs/source/metrics.rst
@@ -357,6 +357,7 @@ Complete list of metrics
     KLDivergence
     JSDivergence
     MaximumMeanDiscrepancy
+    HSIC
     AveragePrecision
     CohenKappa
     GpuInfo
diff --git a/ignite/metrics/__init__.py b/ignite/metrics/__init__.py
index 142a13e5934..9f2c2303bc8 100644
--- a/ignite/metrics/__init__.py
+++ b/ignite/metrics/__init__.py
@@ -14,6 +14,7 @@
 from ignite.metrics.gan.fid import FID
 from ignite.metrics.gan.inception_score import InceptionScore
 from ignite.metrics.gpu_info import GpuInfo
+from ignite.metrics.hsic import HSIC
 from ignite.metrics.js_divergence import JSDivergence
 from ignite.metrics.kl_divergence import KLDivergence
 from ignite.metrics.loss import Loss
@@ -64,6 +65,7 @@
     "JaccardIndex",
     "JSDivergence",
     "KLDivergence",
+    "HSIC",
     "MaximumMeanDiscrepancy",
     "MultiLabelConfusionMatrix",
     "MutualInformation",
diff --git a/ignite/metrics/hsic.py b/ignite/metrics/hsic.py
new file mode 100644
index 00000000000..a35d47f258b
--- /dev/null
+++ b/ignite/metrics/hsic.py
@@ -0,0 +1,174 @@
+from typing import Callable, Sequence, Union
+
+import torch
+from torch import Tensor
+
+from ignite.exceptions import NotComputableError
+from ignite.metrics.metric import Metric, reinit__is_reduced, sync_all_reduce
+
+__all__ = ["HSIC"]
+
+
+class HSIC(Metric):
+    r"""Calculates the `Hilbert-Schmidt Independence Criterion (HSIC)
+    `_.
+
+    .. math::
+        \text{HSIC}(X,Y) = \frac{1}{B(B-3)}\left[ \text{tr}(\tilde{\mathbf{K}}\tilde{\mathbf{L}})
+        + \frac{\mathbf{1}^\top \tilde{\mathbf{K}} \mathbf{11}^\top \tilde{\mathbf{L}} \mathbf{1}}{(B-1)(B-2)}
+        -\frac{2}{B-2}\mathbf{1}^\top \tilde{\mathbf{K}}\tilde{\mathbf{L}} \mathbf{1} \right]
+
+    where :math:`B` is the batch size, and :math:`\tilde{\mathbf{K}}`
+    and :math:`\tilde{\mathbf{L}}` are the Gram matrices of
+    the Gaussian RBF kernel with their diagonal entries being set to zero.
+
+    HSIC measures non-linear statistical independence between features :math:`X` and :math:`Y`.
+    HSIC becomes zero if and only if :math:`X` and :math:`Y` are independent.
+
+    This metric computes the unbiased estimator of HSIC proposed in
+    `Song et al. (2012) `_.
+    The HSIC is estimated using Eq. (5) of the paper for each batch and the average is accumulated.
+
+    Each batch must contain at least four samples.
+
+    - ``update`` must receive output of the form ``(y_pred, y)``.
+
+    Args:
+        sigma_x: bandwidth of the kernel for :math:`X`.
+            If negative, a heuristic value determined by the median of the distances between
+            the samples is used. Default: -1
+        sigma_y: bandwidth of the kernel for :math:`Y`.
+            If negative, a heuristic value determined by the median of the distances
+            between the samples is used. Default: -1
+        ignore_invalid_batch: If ``True``, computation for a batch with fewer than four samples is skipped.
+            If ``False``, ``ValueError`` is raised when such a batch is received. Default: True
+        output_transform: a callable that is used to transform the
+            :class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the
+            form expected by the metric. This can be useful if, for example, you have a multi-output model and
+            you want to compute the metric with respect to one of the outputs.
+            By default, metrics require the output as ``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y}``.
+        device: specifies which device updates are accumulated on. Setting the
+            metric's device to be the same as your ``update`` arguments ensures the ``update`` method is
+            non-blocking. By default, CPU.
+        skip_unrolling: specifies whether output should be unrolled before being fed to update method. Should be
+            true for multi-output model, for example, if ``y_pred`` contains multi-output as ``(y_pred_a, y_pred_b)``
+            Alternatively, ``output_transform`` can be used to handle this.
+
+    Examples:
+        To use with ``Engine`` and ``process_function``, simply attach the metric instance to the engine.
+        The output of the engine's ``process_function`` needs to be in the format of
+        ``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y, ...}``. If not, ``output_transform`` can be added
+        to the metric to transform the output into the form expected by the metric.
+
+        ``y_pred`` and ``y`` should have the same shape.
+
+        For more information on how metric works with :class:`~ignite.engine.engine.Engine`, visit :ref:`attach-engine`.
+
+    .. include:: defaults.rst
+        :start-after: :orphan:
+
+    .. testcode::
+
+        metric = HSIC()
+        metric.attach(default_evaluator, "hsic")
+        X = torch.tensor([[0., 1., 2., 3., 4.],
+                          [5., 6., 7., 8., 9.],
+                          [10., 11., 12., 13., 14.],
+                          [15., 16., 17., 18., 19.],
+                          [20., 21., 22., 23., 24.],
+                          [25., 26., 27., 28., 29.],
+                          [30., 31., 32., 33., 34.],
+                          [35., 36., 37., 38., 39.],
+                          [40., 41., 42., 43., 44.],
+                          [45., 46., 47., 48., 49.]])
+        Y = torch.sin(X * torch.pi * 2 / 50)
+        state = default_evaluator.run([[X, Y]])
+        print(state.metrics["hsic"])
+
+    .. testoutput::
+
+        0.09226646274328232
+
+    .. versionadded:: 0.5.2
+    """
+
+    def __init__(
+        self,
+        sigma_x: float = -1,
+        sigma_y: float = -1,
+        ignore_invalid_batch: bool = True,
+        output_transform: Callable = lambda x: x,
+        device: Union[str, torch.device] = torch.device("cpu"),
+        skip_unrolling: bool = False,
+    ):
+        super().__init__(output_transform, device, skip_unrolling=skip_unrolling)
+
+        self.sigma_x = sigma_x
+        self.sigma_y = sigma_y
+        self.ignore_invalid_batch = ignore_invalid_batch
+
+    _state_dict_all_req_keys = ("_sum_of_hsic", "_num_batches")
+
+    @reinit__is_reduced
+    def reset(self) -> None:
+        self._sum_of_hsic = torch.tensor(0.0, device=self._device)
+        self._num_batches = 0
+
+    @reinit__is_reduced
+    def update(self, output: Sequence[Tensor]) -> None:
+        X = output[0].detach().flatten(start_dim=1)
+        Y = output[1].detach().flatten(start_dim=1)
+        b = X.shape[0]
+
+        if b <= 3:
+            if self.ignore_invalid_batch:
+                return
+            else:
+                raise ValueError(f"A batch must contain at least four samples, got only {b} samples.")
+
+        # off-diagonal mask: the estimator uses Gram matrices with zeroed diagonals
+        mask = 1.0 - torch.eye(b, device=X.device)
+
+        # squared pairwise distances |x_i|^2 + |x_j|^2 - 2<x_i, x_j>, clamped against negative rounding noise
+        xx = X @ X.T
+        rx = xx.diag().unsqueeze(0).expand_as(xx)
+        dxx = (rx.T + rx - xx * 2).clamp(min=0.0)
+
+        vx: Union[Tensor, float]
+        if self.sigma_x < 0:
+            # median heuristic, floored so identical samples (zero median) cannot produce 0/0 = NaN
+            vx = torch.quantile(dxx, 0.5).clamp(min=torch.finfo(dxx.dtype).tiny)
+        else:
+            vx = self.sigma_x**2
+        K = torch.exp(-0.5 * dxx / vx) * mask
+
+        yy = Y @ Y.T
+        ry = yy.diag().unsqueeze(0).expand_as(yy)
+        dyy = (ry.T + ry - yy * 2).clamp(min=0.0)
+
+        vy: Union[Tensor, float]
+        if self.sigma_y < 0:
+            # same floored median heuristic as for X
+            vy = torch.quantile(dyy, 0.5).clamp(min=torch.finfo(dyy.dtype).tiny)
+        else:
+            vy = self.sigma_y**2
+        L = torch.exp(-0.5 * dyy / vy) * mask
+
+        KL = K @ L
+        trace = KL.trace()
+        second_term = K.sum() * L.sum() / ((b - 1) * (b - 2))
+        third_term = KL.sum() / (b - 2)
+
+        hsic = trace + second_term - third_term * 2.0
+        hsic /= b * (b - 3)
+        hsic = torch.clamp(hsic, min=0.0)  # HSIC must not be negative
+        self._sum_of_hsic += hsic.to(self._device)
+
+        self._num_batches += 1
+
+    @sync_all_reduce("_sum_of_hsic", "_num_batches")
+    def compute(self) -> float:
+        if self._num_batches == 0:
+            raise NotComputableError("HSIC must have at least one batch before it can be computed.")
+
+        return self._sum_of_hsic.item() / self._num_batches
diff --git a/tests/ignite/metrics/test_hsic.py b/tests/ignite/metrics/test_hsic.py
new file mode 100644
index 00000000000..6ee4237b96e
--- /dev/null
+++ b/tests/ignite/metrics/test_hsic.py
@@ -0,0 +1,197 @@
+from typing import Tuple
+
+import numpy as np
+import pytest
+
+import torch
+from torch import nn, Tensor
+
+import ignite.distributed as idist
+from ignite.engine import Engine
+from ignite.exceptions import NotComputableError
+from ignite.metrics import HSIC
+
+
+def np_hsic(x: Tensor, y: Tensor, sigma_x: float = -1, sigma_y: float = -1) -> float:
+    x_np = x.detach().cpu().numpy()
+    y_np = y.detach().cpu().numpy()
+    b = x_np.shape[0]
+
+    ii, jj = np.meshgrid(np.arange(b), np.arange(b), indexing="ij")
+    mask = 1.0 - np.eye(b)
+
+    dxx = np.square(x_np[ii] - x_np[jj]).sum(axis=2)
+    if sigma_x < 0:
+        vx = np.median(dxx)
+    else:
+        vx = sigma_x * sigma_x
+    K = np.exp(-0.5 * dxx / vx) * mask
+
+    dyy = np.square(y_np[ii] - y_np[jj]).sum(axis=2)
+    if sigma_y < 0:
+        vy = np.median(dyy)
+    else:
+        vy = sigma_y * sigma_y
+    L = np.exp(-0.5 * dyy / vy) * mask
+
+    KL = K @ L
+    ones = np.ones(b)
+    hsic = np.trace(KL) + (ones @ K @ ones) * (ones @ L @ ones) / ((b - 1) * (b - 2)) - ones @ KL @ ones * 2 / (b - 2)
+    hsic /= b * (b - 3)
+    hsic = np.clip(hsic, 0.0, None)
+    return hsic
+
+
+def test_zero_batch():
+    hsic = HSIC()
+    with pytest.raises(NotComputableError, match=r"HSIC must have at least one batch before it can be computed"):
+        hsic.compute()
+
+
+def test_invalid_batch():
+    hsic = HSIC(ignore_invalid_batch=False)
+    X = torch.tensor([[1, 2, 3]]).float()
+    Y = torch.tensor([[4, 5, 6]]).float()
+    with pytest.raises(ValueError, match=r"A batch must contain at least four samples, got only"):
+        hsic.update((X, Y))
+
+
+@pytest.fixture(params=[0, 1, 2])
+def test_case(request) -> Tuple[Tensor, Tensor, int]:
+    if request.param == 0:
+        # independent
+        N = 100
+        b = 10
+        x, y = torch.randn((N, 50)), torch.randn((N, 30))
+    elif request.param == 1:
+        # linearly dependent
+        N = 100
+        b = 10
+        x = torch.normal(1.0, 2.0, size=(N, 10))
+        y = x @ torch.rand(10, 15) * 3 + torch.randn(N, 15) * 1e-4
+    else:
+        # non-linearly dependent
+        N = 200
+        b = 20
+        x = torch.randn(N, 5)
+        y = x @ torch.normal(0.0, torch.pi, size=(5, 3))
+        y = (
+            torch.stack([torch.sin(y[:, 0]), torch.cos(y[:, 1]), torch.exp(y[:, 2])], dim=1)
+            + torch.randn_like(y) * 1e-4
+        )
+
+    return x, y, b
+
+
+@pytest.mark.parametrize("n_times", range(3))
+@pytest.mark.parametrize("sigma_x", [-1.0, 1.0])
+@pytest.mark.parametrize("sigma_y", [-1.0, 1.0])
+def test_compute(n_times, sigma_x: float, sigma_y: float, test_case: Tuple[Tensor, Tensor, int]):
+    x, y, batch_size = test_case
+
+    hsic = HSIC(sigma_x=sigma_x, sigma_y=sigma_y)
+
+    hsic.reset()
+
+    np_hsic_sum = 0.0
+    n_iters = y.shape[0] // batch_size
+    for i in range(n_iters):
+        idx = i * batch_size
+        x_batch = x[idx : idx + batch_size]
+        y_batch = y[idx : idx + batch_size]
+
+        hsic.update((x_batch, y_batch))
+        np_hsic_sum += np_hsic(x_batch, y_batch, sigma_x, sigma_y)
+    expected_hsic = np_hsic_sum / n_iters
+
+    assert isinstance(hsic.compute(), float)
+    assert pytest.approx(expected_hsic, abs=2e-5) == hsic.compute()
+
+
+def test_accumulator_detached():
+    hsic = HSIC()
+
+    x = torch.rand(10, 10, dtype=torch.float)
+    y = torch.rand(10, 10, dtype=torch.float)
+    hsic.update((x, y))
+
+    assert not hsic._sum_of_hsic.requires_grad
+
+
+def test_constant_input():
+    # identical samples give a zero median distance; the bandwidth floor must keep the result finite
+    hsic = HSIC()
+    x = torch.zeros(10, 10)
+    y = torch.ones(10, 10)
+    hsic.update((x, y))
+    assert hsic.compute() == pytest.approx(0.0)
+
+
+@pytest.mark.usefixtures("distributed")
+class TestDistributed:
+    @pytest.mark.parametrize("sigma_x", [-1.0, 1.0])
+    @pytest.mark.parametrize("sigma_y", [-1.0, 1.0])
+    def test_integration(self, sigma_x: float, sigma_y: float):
+        tol = 2e-5
+        n_iters = 100
+        batch_size = 20
+        n_dims_x = 100
+        n_dims_y = 50
+
+        rank = idist.get_rank()
+        torch.manual_seed(12 + rank)
+
+        device = idist.device()
+        metric_devices = [torch.device("cpu")]
+        if device.type != "xla":
+            metric_devices.append(device)
+
+        for metric_device in metric_devices:
+            x = torch.randn((n_iters * batch_size, n_dims_x)).float().to(device)
+
+            lin = nn.Linear(n_dims_x, n_dims_y).to(device)
+            y = torch.sin(lin(x) * 100) + torch.randn(n_iters * batch_size, n_dims_y, device=device) * 1e-4
+
+            def data_loader(i, input_x, input_y):
+                return input_x[i * batch_size : (i + 1) * batch_size], input_y[i * batch_size : (i + 1) * batch_size]
+
+            engine = Engine(lambda e, i: data_loader(i, x, y))
+
+            m = HSIC(sigma_x=sigma_x, sigma_y=sigma_y, device=metric_device)
+            m.attach(engine, "hsic")
+
+            data = list(range(n_iters))
+            engine.run(data=data, max_epochs=1)
+
+            assert "hsic" in engine.state.metrics
+            res = engine.state.metrics["hsic"]
+
+            x = idist.all_gather(x)
+            y = idist.all_gather(y)
+            total_n_iters = idist.all_reduce(n_iters)
+
+            np_res = 0.0
+            for i in range(total_n_iters):
+                x_batch, y_batch = data_loader(i, x, y)
+                np_res += np_hsic(x_batch, y_batch, sigma_x, sigma_y)
+
+            expected_hsic = np_res / total_n_iters
+            assert pytest.approx(expected_hsic, abs=tol) == res
+
+    def test_accumulator_device(self):
+        device = idist.device()
+        metric_devices = [torch.device("cpu")]
+        if device.type != "xla":
+            metric_devices.append(device)
+        for metric_device in metric_devices:
+            hsic = HSIC(device=metric_device)
+
+            for dev in (hsic._device, hsic._sum_of_hsic.device):
+                assert dev == metric_device, f"{type(dev)}:{dev} vs {type(metric_device)}:{metric_device}"
+
+            x = torch.zeros(10, 10).float()
+            y = torch.ones(10, 10).float()
+            hsic.update((x, y))
+
+            for dev in (hsic._device, hsic._sum_of_hsic.device):
+                assert dev == metric_device, f"{type(dev)}:{dev} vs {type(metric_device)}:{metric_device}"