[BC-breaking] Make Metrics accumulate values on device specified by user (pytorch#1232) (pytorch#1238)

* Make Metrics accumulate values on device specified by user (pytorch#1232)

* update accuracy to accumulate _num_correct in a tensor on the right device

* update loss metric to accumulate _sum in a tensor on the right device

* update mae metric to accumulate in a tensor on the right device

* update mpd metric to accumulate in a tensor on the right device

* update mse metric to accumulate in a tensor on the right device

* update top k accuracy metric to accumulate in a tensor on the right device

* update precision and recall metrics to accumulate in tensors on the right device

* .....

* black formatting

* reverted run*.sh

* change all metrics default device to cpu except running_average

* Update ignite/metrics/precision.py

Co-authored-by: vfdev <[email protected]>

* remove Optional type from metric devices since default is cpu

* add comment explaining lack of detach in accuracy metrics

Co-authored-by: vfdev <[email protected]>

* Improved and fixed accuracy tests

* autopep8 fix

* update docs and docstrings for updated metrics (pytorch#1239)

* update docstrings and docs

* Update ignite/metrics/accumulation.py

Co-authored-by: vfdev <[email protected]>

* Update ignite/metrics/accumulation.py

Co-authored-by: vfdev <[email protected]>

* Update ignite/metrics/accumulation.py

Co-authored-by: vfdev <[email protected]>

* Update ignite/metrics/accuracy.py

Co-authored-by: vfdev <[email protected]>

* Update ignite/metrics/fbeta.py

Co-authored-by: vfdev <[email protected]>

* Update ignite/metrics/loss.py

Co-authored-by: vfdev <[email protected]>

* Update ignite/metrics/metric.py

Co-authored-by: vfdev <[email protected]>

* Update ignite/metrics/precision.py

Co-authored-by: vfdev <[email protected]>

* Update ignite/metrics/recall.py

Co-authored-by: vfdev <[email protected]>

* add comment explaining lack of detach in metrics docs

* support device argument for running_average

* update support for device argument for accumulation

* fix and improve device tests for metrics

* fix and improve device tests for metrics

* fix TPU tests

* Apply suggestions from code review

* Apply suggestions from code review

Co-authored-by: vfdev <[email protected]>

* Updates to metrics_impl (pytorch#1266)

* detach tensors earlier in update

* remove redundant to() call

* ensure metrics aren't created on XLA devices

* Fixed isort

* move xla check to Metric.__init__ instead of individual metrics

* update xla tests

* replace deleted callable check

* remove redundant precision and recall __init__

* replace precision/recall __init__ for docs rendering

* add support for metrics_lambda with components on diff devices

Co-authored-by: vfdev <[email protected]>
Co-authored-by: n2cholas <[email protected]>

* Update metrics.rst

* Update metrics.rst

* Fix TPU tests for metrics_impl branch (pytorch#1277)

* fix epoch_metric xla test

Co-authored-by: vfdev <[email protected]>
Co-authored-by: n2cholas <[email protected]>

* metrics_impl fix 2 gpu hvd tests and ensure consistent detaching (pytorch#1280)

* detach output consistently for all metrics

* fix horovod two gpu tests

* make confusion matrix detach like other metrics

Co-authored-by: vfdev <[email protected]>
Co-authored-by: n2cholas <[email protected]>

* Fixes failing test on TPUs

Co-authored-by: Nicholas Vadivelu <[email protected]>
Co-authored-by: AutoPEP8 <>
Co-authored-by: Sylvain Desroziers <[email protected]>
Co-authored-by: n2cholas <[email protected]>
4 people authored Sep 11, 2020
1 parent d92f1c6 commit 002b595
Showing 33 changed files with 1,134 additions and 458 deletions.
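
The headline, BC-breaking change: every built-in metric now accepts a ``device`` argument (default: CPU) and accumulates its internal state in tensors on that device rather than in Python scalars. As a quick orientation before the per-file diffs, here is a minimal usage sketch of the updated API (values are illustrative; it falls back to CPU when CUDA is absent):

import torch

from ignite.metrics import Accuracy

# Keeping the metric's device equal to the device of the tensors passed to
# update() keeps update() non-blocking (per the docstrings in this commit).
device = "cuda" if torch.cuda.is_available() else "cpu"

acc = Accuracy(device=device)
acc.reset()

y_pred = torch.rand(8, 4, device=device)      # scores for 4 classes
y = torch.randint(0, 4, (8,), device=device)  # integer class targets
acc.update((y_pred, y))

print(acc.compute())  # a plain Python float: compute() calls .item() internally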
16 changes: 10 additions & 6 deletions docs/source/metrics.rst
@@ -120,21 +120,21 @@ specific condition (e.g. ignore user-defined classes):
class CustomAccuracy(Metric):
def __init__(self, ignored_class, output_transform=lambda x: x):
def __init__(self, ignored_class, output_transform=lambda x: x, device="cpu"):
self.ignored_class = ignored_class
self._num_correct = None
self._num_examples = None
super(CustomAccuracy, self).__init__(output_transform=output_transform)
super(CustomAccuracy, self).__init__(output_transform=output_transform, device=device)
@reinit__is_reduced
def reset(self):
self._num_correct = 0
self._num_correct = torch.tensor(0, device=self._device)
self._num_examples = 0
super(CustomAccuracy, self).reset()
@reinit__is_reduced
def update(self, output):
y_pred, y = output
y_pred, y = output[0].detach(), output[1].detach()
indices = torch.argmax(y_pred, dim=1)
@@ -144,21 +144,25 @@ specific condition (e.g. ignore user-defined classes):
indices = indices[mask]
correct = torch.eq(indices, y).view(-1)
self._num_correct += torch.sum(correct).item()
self._num_correct += torch.sum(correct).to(self._device)
self._num_examples += correct.shape[0]
@sync_all_reduce("_num_examples", "_num_correct")
def compute(self):
if self._num_examples == 0:
raise NotComputableError('CustomAccuracy must have at least one example before it can be computed.')
return self._num_correct / self._num_examples
return self._num_correct.item() / self._num_examples
We imported the necessary classes, :class:`~ignite.metrics.Metric` and :class:`~ignite.exceptions.NotComputableError`, along with
the decorators that adapt the metric to the distributed setting. In the ``reset`` method, we reset the internal variables ``_num_correct``
and ``_num_examples``, which are used to compute the custom metric. In the ``update`` method, we define how to update
the internal variables. Finally, in the ``compute`` method, we compute the metric value.

Notice that ``_num_correct`` is a tensor, since in ``update`` we accumulate tensor values. ``_num_examples`` is a Python
scalar, since we accumulate plain integers. For differentiable metrics, you must detach the accumulated values before
adding them to the internal variables.

We can check this implementation in a simple case:

.. code-block:: python
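    # The original example body is collapsed on this page; the following is an
    # illustrative sketch only (made-up values), using the CustomAccuracy above.
    import torch

    metric = CustomAccuracy(ignored_class=0)
    metric.reset()

    y_pred = torch.tensor([[2.0, 1.0], [1.0, 3.0]])  # scores for 2 classes
    y = torch.tensor([0, 1])                         # class 0 is ignored

    metric.update((y_pred, y))
    print(metric.compute())  # 1.0 -- only the sample of class 1 is counted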
35 changes: 24 additions & 11 deletions ignite/metrics/accumulation.py
@@ -1,5 +1,5 @@
import numbers
from typing import Any, Callable, Optional, Union
from typing import Any, Callable, Union

import torch

@@ -31,14 +31,19 @@ class VariableAccumulation(Metric):
:class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the
form expected by the metric. This can be useful if, for example, you have a multi-output model and
you want to compute the metric with respect to one of the outputs.
device (str of torch.device, optional): optional device specification for internal storage.
device (str or torch.device): specifies which device updates are accumulated on. Setting the metric's
device to be the same as your ``update`` arguments ensures the ``update`` method is non-blocking. By
default, CPU.
"""

_required_output_keys = None

def __init__(
self, op: Callable, output_transform: Callable = lambda x: x, device: Optional[Union[str, torch.device]] = None
self,
op: Callable,
output_transform: Callable = lambda x: x,
device: Union[str, torch.device] = torch.device("cpu"),
):
if not callable(op):
raise TypeError("Argument op should be a callable, but given {}".format(type(op)))
@@ -61,12 +61,13 @@ def _check_output_type(self, output: Union[Any, torch.Tensor, numbers.Number]) -> None:
def update(self, output: Union[Any, torch.Tensor, numbers.Number]) -> None:
self._check_output_type(output)

if self._device is not None:
# Put output to the metric's device
if isinstance(output, torch.Tensor) and (output.device != self._device):
if isinstance(output, torch.Tensor):
output = output.detach()
if output.device != self._device:
output = output.to(self._device)

self.accumulator = self._op(self.accumulator, output)

if hasattr(output, "shape"):
self.num_examples += output.shape[0] if len(output.shape) > 1 else 1
else:
@@ -111,11 +117,14 @@ class Average(VariableAccumulation):
:class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the
form expected by the metric. This can be useful if, for example, you have a multi-output model and
you want to compute the metric with respect to one of the outputs.
device (str of torch.device, optional): optional device specification for internal storage.
device (str or torch.device): specifies which device updates are accumulated on. Setting the metric's
device to be the same as your ``update`` arguments ensures the ``update`` method is non-blocking. By
default, CPU.
"""

def __init__(self, output_transform: Callable = lambda x: x, device: Optional[Union[str, torch.device]] = None):
def __init__(
self, output_transform: Callable = lambda x: x, device: Union[str, torch.device] = torch.device("cpu")
):
def _mean_op(a, x):
if isinstance(x, torch.Tensor) and x.ndim > 1:
x = x.sum(dim=0)
@@ -155,11 +164,15 @@ class GeometricAverage(VariableAccumulation):
:class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the
form expected by the metric. This can be useful if, for example, you have a multi-output model and
you want to compute the metric with respect to one of the outputs.
device (str of torch.device, optional): optional device specification for internal storage.
device (str or torch.device): specifies which device updates are accumulated on. Setting the metric's
device to be the same as your ``update`` arguments ensures the ``update`` method is non-blocking. By
default, CPU.
"""

def __init__(self, output_transform: Callable = lambda x: x, device: Optional[Union[str, torch.device]] = None):
def __init__(
self, output_transform: Callable = lambda x: x, device: Union[str, torch.device] = torch.device("cpu")
):
def _geom_op(a: torch.Tensor, x: Union[Any, numbers.Number, torch.Tensor]) -> torch.Tensor:
if not isinstance(x, torch.Tensor):
x = torch.tensor(x)
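
For context, a hedged sketch of the accumulation metrics after this change (values illustrative): ``Average`` sums the updates into a tensor on the metric's device and divides by the number of updates.

import torch

from ignite.metrics import Average

avg = Average(device="cpu")  # the new default, shown explicitly for emphasis
avg.reset()

for x in [torch.tensor(0.5), torch.tensor(0.25)]:
    avg.update(x)  # each 0-d tensor update counts as one example

print(avg.compute())  # 0.375, returned as a 0-d tensor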
22 changes: 12 additions & 10 deletions ignite/metrics/accuracy.py
@@ -1,4 +1,4 @@
from typing import Callable, Optional, Sequence, Union
from typing import Callable, Sequence, Union

import torch

@@ -13,7 +13,7 @@ def __init__(
self,
output_transform: Callable = lambda x: x,
is_multilabel: bool = False,
device: Optional[Union[str, torch.device]] = None,
device: Union[str, torch.device] = torch.device("cpu"),
):
self._is_multilabel = is_multilabel
self._type = None
@@ -122,31 +122,33 @@ def thresholded_output_transform(output):
form expected by the metric. This can be useful if, for example, you have a multi-output model and
you want to compute the metric with respect to one of the outputs.
is_multilabel (bool, optional): flag to use in multilabel case. By default, False.
device (str of torch.device, optional): unused argument.
device (str or torch.device): specifies which device updates are accumulated on. Setting the metric's
device to be the same as your ``update`` arguments ensures the ``update`` method is non-blocking. By
default, CPU.
"""

def __init__(
self,
output_transform: Callable = lambda x: x,
is_multilabel: bool = False,
device: Optional[Union[str, torch.device]] = None,
device: Union[str, torch.device] = torch.device("cpu"),
):
self._num_correct = None
self._num_examples = None
super(Accuracy, self).__init__(output_transform=output_transform, is_multilabel=is_multilabel, device=device)

@reinit__is_reduced
def reset(self) -> None:
self._num_correct = 0
self._num_correct = torch.tensor(0, device=self._device)
self._num_examples = 0
super(Accuracy, self).reset()

@reinit__is_reduced
def update(self, output: Sequence[torch.Tensor]) -> None:
y_pred, y = output
self._check_shape((y_pred, y))
self._check_type((y_pred, y))
self._check_shape(output)
self._check_type(output)
y_pred, y = output[0].detach(), output[1].detach()

if self._type == "binary":
correct = torch.eq(y_pred.view(-1).to(y), y.view(-1))
Expand All @@ -161,11 +163,11 @@ def update(self, output: Sequence[torch.Tensor]) -> None:
y = torch.transpose(y, 1, last_dim - 1).reshape(-1, num_classes)
correct = torch.all(y == y_pred.type_as(y), dim=-1)

self._num_correct += torch.sum(correct).item()
self._num_correct += torch.sum(correct).to(self._device)
self._num_examples += correct.shape[0]

@sync_all_reduce("_num_examples", "_num_correct")
def compute(self) -> torch.Tensor:
if self._num_examples == 0:
raise NotComputableError("Accuracy must have at least one example before it can be computed.")
return self._num_correct / self._num_examples
return self._num_correct.item() / self._num_examples
10 changes: 6 additions & 4 deletions ignite/metrics/confusion_matrix.py
@@ -30,7 +30,9 @@ class ConfusionMatrix(Metric):
:class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the
form expected by the metric. This can be useful if, for example, you have a multi-output model and
you want to compute the metric with respect to one of the outputs.
device (str of torch.device, optional): optional device specification for internal storage.
device (str or torch.device): specifies which device updates are accumulated on. Setting the metric's
device to be the same as your ``update`` arguments ensures the ``update`` method is non-blocking. By
default, CPU.
Note:
In case of the targets `y` in `(batch_size, ...)` format, target indices between 0 and `num_classes` only
@@ -44,7 +46,7 @@ def __init__(
num_classes: int,
average: Optional[str] = None,
output_transform: Callable = lambda x: x,
device: Optional[Union[str, torch.device]] = None,
device: Union[str, torch.device] = torch.device("cpu"),
):
if average is not None and average not in ("samples", "recall", "precision"):
raise ValueError("Argument average can None or one of 'samples', 'recall', 'precision'")
@@ -61,7 +63,7 @@ def reset(self) -> None:
self._num_examples = 0

def _check_shape(self, output: Sequence[torch.Tensor]) -> None:
y_pred, y = output
y_pred, y = output[0].detach(), output[1].detach()

if y_pred.ndimension() < 2:
raise ValueError(
@@ -92,7 +94,7 @@ def _check_shape(self, output: Sequence[torch.Tensor]) -> None:
@reinit__is_reduced
def update(self, output: Sequence[torch.Tensor]) -> None:
self._check_shape(output)
y_pred, y = output
y_pred, y = output[0].detach(), output[1].detach()

self._num_examples += y_pred.shape[0]

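
A hedged construction sketch for the updated ``ConfusionMatrix`` (shapes and values illustrative):

import torch

from ignite.metrics import ConfusionMatrix

cm = ConfusionMatrix(num_classes=3, device="cpu")
cm.reset()

y_pred = torch.rand(4, 3)      # (batch_size, num_classes) scores
y = torch.randint(0, 3, (4,))  # integer targets in [0, num_classes)

cm.update((y_pred, y))
print(cm.compute())  # a 3x3 tensor of per-class prediction counts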
6 changes: 4 additions & 2 deletions ignite/metrics/fbeta.py
@@ -15,7 +15,7 @@ def Fbeta(
precision: Optional[Precision] = None,
recall: Optional[Recall] = None,
output_transform: Optional[Callable] = None,
device: Optional[Union[str, torch.device]] = None,
device: Union[str, torch.device] = torch.device("cpu"),
) -> MetricsLambda:
"""Calculates F-beta score
@@ -28,7 +28,9 @@ def Fbeta(
output_transform (callable, optional): a callable that is used to transform the
:class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the
form expected by the metric. It is used only if precision or recall are not provided.
device (str of torch.device, optional): optional device specification for internal storage.
device (str or torch.device): specifies which device updates are accumulated on. Setting the metric's
device to be the same as your ``update`` arguments ensures the ``update`` method is non-blocking. By
default, CPU.
Returns:
MetricsLambda, F-beta metric
6 changes: 5 additions & 1 deletion ignite/metrics/frequency.py
@@ -1,3 +1,5 @@
from typing import Callable, Optional, Union

import torch

import ignite.distributed as idist
@@ -35,7 +37,9 @@ class Frequency(Metric):
# Epoch [2/10]: [50/100] 50%|█████ , wps=400 [00:17<00:35]
"""

def __init__(self, output_transform=lambda x: x, device=None):
def __init__(
self, output_transform: Callable = lambda x: x, device: Union[str, torch.device] = torch.device("cpu")
):
self._timer = None
self._acc = None
self._n = None
16 changes: 9 additions & 7 deletions ignite/metrics/loss.py
@@ -1,4 +1,4 @@
from typing import Callable, Optional, Sequence, Union
from typing import Callable, Sequence, Union

import torch

@@ -26,7 +26,9 @@ class Loss(Metric):
keywords arguments. If extra keywords arguments are provided they are passed to `loss_fn`.
batch_size (callable): a callable taking a target tensor that returns the
first dimension size (usually the batch size).
device (str of torch.device, optional): unused argument.
device (str or torch.device): specifies which device updates are accumulated on. Setting the
metric's device to be the same as your ``update`` arguments ensures the ``update`` method is
non-blocking. By default, CPU.
"""

@@ -37,15 +39,15 @@ def __init__(
loss_fn: Callable,
output_transform: Callable = lambda x: x,
batch_size: Callable = lambda x: len(x),
device: Optional[Union[str, torch.device]] = None,
device: Union[str, torch.device] = torch.device("cpu"),
):
super(Loss, self).__init__(output_transform, device=device)
self._loss_fn = loss_fn
self._batch_size = batch_size

@reinit__is_reduced
def reset(self) -> None:
self._sum = 0
self._sum = torch.tensor(0.0, device=self._device)
self._num_examples = 0

@reinit__is_reduced
@@ -55,17 +57,17 @@ def update(self, output: Sequence[Union[torch.Tensor, dict]]) -> None:
kwargs = {}
else:
y_pred, y, kwargs = output
average_loss = self._loss_fn(y_pred, y, **kwargs)
average_loss = self._loss_fn(y_pred.detach(), y.detach(), **kwargs)

if len(average_loss.shape) != 0:
raise ValueError("loss_fn did not return the average loss.")

n = self._batch_size(y)
self._sum += average_loss.item() * n
self._sum += average_loss.to(self._device) * n
self._num_examples += n

@sync_all_reduce("_sum", "_num_examples")
def compute(self) -> None:
if self._num_examples == 0:
raise NotComputableError("Loss must have at least one example before it can be computed.")
return self._sum / self._num_examples
return self._sum.item() / self._num_examples
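
A hedged sketch of the updated ``Loss`` metric (the loss function and values are illustrative):

import torch
import torch.nn.functional as F

from ignite.metrics import Loss

loss_metric = Loss(F.nll_loss, device="cpu")
loss_metric.reset()

y_pred = torch.log_softmax(torch.rand(8, 4), dim=1)  # log-probabilities
y = torch.randint(0, 4, (8,))

loss_metric.update((y_pred, y))
print(loss_metric.compute())  # mean loss over all seen examples, as a float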
8 changes: 4 additions & 4 deletions ignite/metrics/mean_absolute_error.py
@@ -17,18 +17,18 @@ class MeanAbsoluteError(Metric):

@reinit__is_reduced
def reset(self) -> None:
self._sum_of_absolute_errors = 0.0
self._sum_of_absolute_errors = torch.tensor(0.0, device=self._device)
self._num_examples = 0

@reinit__is_reduced
def update(self, output: Sequence[torch.Tensor]) -> None:
y_pred, y = output
y_pred, y = output[0].detach(), output[1].detach()
absolute_errors = torch.abs(y_pred - y.view_as(y_pred))
self._sum_of_absolute_errors += torch.sum(absolute_errors).item()
self._sum_of_absolute_errors += torch.sum(absolute_errors).to(self._device)
self._num_examples += y.shape[0]

@sync_all_reduce("_sum_of_absolute_errors", "_num_examples")
def compute(self) -> Union[float, torch.Tensor]:
if self._num_examples == 0:
raise NotComputableError("MeanAbsoluteError must have at least one example before it can be computed.")
return self._sum_of_absolute_errors / self._num_examples
return self._sum_of_absolute_errors.item() / self._num_examples
12 changes: 6 additions & 6 deletions ignite/metrics/mean_pairwise_distance.py
@@ -1,4 +1,4 @@
from typing import Callable, Optional, Sequence, Union
from typing import Callable, Sequence, Union

import torch
from torch.nn.functional import pairwise_distance
@@ -21,26 +21,26 @@ def __init__(
p: int = 2,
eps: float = 1e-6,
output_transform: Callable = lambda x: x,
device: Optional[Union[str, torch.device]] = None,
device: Union[str, torch.device] = torch.device("cpu"),
):
super(MeanPairwiseDistance, self).__init__(output_transform, device=device)
self._p = p
self._eps = eps

@reinit__is_reduced
def reset(self):
self._sum_of_distances = 0.0
self._sum_of_distances = torch.tensor(0.0, device=self._device)
self._num_examples = 0

@reinit__is_reduced
def update(self, output: Sequence[torch.Tensor]) -> None:
y_pred, y = output
y_pred, y = output[0].detach(), output[1].detach()
distances = pairwise_distance(y_pred, y, p=self._p, eps=self._eps)
self._sum_of_distances += torch.sum(distances).item()
self._sum_of_distances += torch.sum(distances).to(self._device)
self._num_examples += y.shape[0]

@sync_all_reduce("_sum_of_distances", "_num_examples")
def compute(self) -> Union[float, torch.Tensor]:
if self._num_examples == 0:
raise NotComputableError("MeanAbsoluteError must have at least one example before it can be computed.")
return self._sum_of_distances / self._num_examples
return self._sum_of_distances.item() / self._num_examples
