diff --git a/ignite/metrics/accumulation.py b/ignite/metrics/accumulation.py
index 926e7816bae..62708ec53a7 100644
--- a/ignite/metrics/accumulation.py
+++ b/ignite/metrics/accumulation.py
@@ -37,7 +37,7 @@ class VariableAccumulation(Metric):
 
     """
 
-    _required_output_keys = None
+    required_output_keys = None
 
     def __init__(
         self,
diff --git a/ignite/metrics/loss.py b/ignite/metrics/loss.py
index c667cf3ade5..8fc3aaba300 100644
--- a/ignite/metrics/loss.py
+++ b/ignite/metrics/loss.py
@@ -32,7 +32,7 @@ class Loss(Metric):
 
     """
 
-    _required_output_keys = None
+    required_output_keys = None
 
     def __init__(
         self,
diff --git a/ignite/metrics/metric.py b/ignite/metrics/metric.py
index 228a89d2d1b..272c2184e72 100644
--- a/ignite/metrics/metric.py
+++ b/ignite/metrics/metric.py
@@ -125,9 +125,65 @@ class Metric(metaclass=ABCMeta):
         device (str or torch.device): specifies which device updates are accumulated on. Setting the metric's
             device to be the same as your ``update`` arguments ensures the ``update`` method is non-blocking. By
             default, CPU.
+
+    Class Attributes:
+        required_output_keys (tuple): tuple of keys required to be found in ``engine.state.output`` if the
+            latter is a dictionary. This is useful with custom metrics that require arguments other than
+            predictions ``y_pred`` and targets ``y``. See the note below for an example.
+
+    Note:
+
+        .. code-block:: python
+
+            # https://discuss.pytorch.org/t/how-access-inputs-in-custom-ignite-metric/91221/5
+            # Let's implement a custom metric that requires ``y_pred``, ``y`` and ``x``
+
+            import torch
+            import torch.nn as nn
+
+            from ignite.metrics import Metric, Accuracy
+            from ignite.engine import create_supervised_evaluator
+
+            class CustomMetric(Metric):
+
+                required_output_keys = ("y_pred", "y", "x")
+
+                def __init__(self, *args, **kwargs):
+                    super().__init__(*args, **kwargs)
+
+                def update(self, output):
+                    y_pred, y, x = output
+                    # ...
+
+                def reset(self):
+                    # ...
+                    pass
+
+                def compute(self):
+                    # ...
+                    pass
+
+            model = ...
+
+            metrics = {
+                "Accuracy": Accuracy(),
+                "CustomMetric": CustomMetric()
+            }
+
+            evaluator = create_supervised_evaluator(
+                model,
+                metrics=metrics,
+                output_transform=lambda x, y, y_pred: {"x": x, "y": y, "y_pred": y_pred}
+            )
+
+            res = evaluator.run(data)
+
     """
 
-    _required_output_keys = ("y_pred", "y")
+    # public class attribute
+    required_output_keys = ("y_pred", "y")
+    # for backward compatibility
+    _required_output_keys = required_output_keys
 
     def __init__(
         self, output_transform: Callable = lambda x: x, device: Union[str, torch.device] = torch.device("cpu"),
@@ -211,18 +267,18 @@ def iteration_completed(self, engine: Engine) -> None:
         output = self._output_transform(engine.state.output)
         if isinstance(output, Mapping):
-            if self._required_output_keys is None:
+            if self.required_output_keys is None:
                 raise TypeError(
                     "Transformed engine output for {} metric should be a tuple/list, but given {}".format(
                         self.__class__.__name__, type(output)
                     )
                 )
-            if not all([k in output for k in self._required_output_keys]):
+            if not all([k in output for k in self.required_output_keys]):
                 raise ValueError(
                     "When transformed engine's output is a mapping, "
-                    "it should contain {} keys, but given {}".format(self._required_output_keys, list(output.keys()))
+                    "it should contain {} keys, but given {}".format(self.required_output_keys, list(output.keys()))
                 )
-            output = tuple(output[k] for k in self._required_output_keys)
+            output = tuple(output[k] for k in self.required_output_keys)
         self.update(output)
 
     def completed(self, engine: Engine, name: str) -> None:
diff --git a/ignite/metrics/running_average.py b/ignite/metrics/running_average.py
index 0fa1216c794..419743ec15a 100644
--- a/ignite/metrics/running_average.py
+++ b/ignite/metrics/running_average.py
@@ -44,7 +44,7 @@ def log_running_avg_metrics(engine):
 
     """
 
-    _required_output_keys = None
+    required_output_keys = None
 
     def __init__(
         self,
diff --git a/tests/ignite/metrics/test_metric.py b/tests/ignite/metrics/test_metric.py
index 31395b99ec9..99317b4addf 100644
--- a/tests/ignite/metrics/test_metric.py
+++ b/tests/ignite/metrics/test_metric.py
@@ -67,7 +67,7 @@ def test_output_as_mapping_wrong_keys():
 
 def test_output_as_mapping_keys_is_none():
     class DummyMetric(Metric):
-        _required_output_keys = None
+        required_output_keys = None
 
         def reset(self):
             pass
@@ -79,7 +79,7 @@ def update(self, output):
         pass
 
     metric = DummyMetric()
-    assert metric._required_output_keys is None
+    assert metric.required_output_keys is None
 
     state = State(output=({"y1": 0, "y2": 1}))
     engine = MagicMock(state=state)
@@ -318,7 +318,7 @@ def process_function(*args, **kwargs):
 
 def test_detach():
     class DummyMetric(Metric):
-        _required_output_keys = None
+        required_output_keys = None
 
         def reset(self):
             pass
@@ -794,3 +794,48 @@ def _():
         assert bfm[0] == 1
 
     engine.run([0, 1, 2, 3], max_epochs=10)
+
+
+def test_override_required_output_keys():
+    # https://discuss.pytorch.org/t/how-access-inputs-in-custom-ignite-metric/91221/5
+    import torch.nn as nn
+
+    from ignite.engine import create_supervised_evaluator
+
+    counter = [0]
+
+    class CustomMetric(Metric):
+        required_output_keys = ("y_pred", "y", "x")
+
+        def __init__(self, *args, **kwargs):
+            super().__init__(*args, **kwargs)
+
+        def update(self, output):
+            y_pred, y, x = output
+            assert y_pred.shape == (4, 3)
+            assert y.shape == (4,)
+            assert x.shape == (4, 10)
+            assert x.equal(data[counter[0]][0])
+            assert y.equal(data[counter[0]][1])
+            counter[0] += 1
+
+        def reset(self):
+            pass
+
+        def compute(self):
+            pass
+
+    model = nn.Linear(10, 3)
+
+    metrics = {"Precision": Precision(), "CustomMetric": CustomMetric()}
+
+    evaluator = create_supervised_evaluator(
+        model, metrics=metrics, output_transform=lambda x, y, y_pred: {"x": x, "y": y, "y_pred": y_pred}
+    )
+
+    data = [
+        (torch.rand(4, 10), torch.randint(0, 3, size=(4,))),
+        (torch.rand(4, 10), torch.randint(0, 3, size=(4,))),
+        (torch.rand(4, 10), torch.randint(0, 3, size=(4,))),
+    ]
+    evaluator.run(data)
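
As a quick sanity check of the rename, here is a minimal sketch (not part of the patch, assuming the diff above is applied and that built-in metrics such as `Accuracy` do not override the attribute): the new public `required_output_keys` is readable on a metric instance, and the `_required_output_keys` alias kept on the `Metric` base class still resolves for legacy reads.

```python
# Minimal sketch, not part of the patch: exercise the public attribute and the
# backward-compatibility alias added to ignite/metrics/metric.py above.
from ignite.metrics import Accuracy, Metric

acc = Accuracy()  # assumption: Accuracy inherits the default ("y_pred", "y")
print(acc.required_output_keys)   # ('y_pred', 'y')
print(acc._required_output_keys)  # same tuple, via the alias on Metric
assert Metric.required_output_keys is Metric._required_output_keys
```

Note that the alias is assigned once on the base class, so a subclass that overrides only `required_output_keys` (e.g. `Loss` setting it to `None` in this diff) leaves the legacy `_required_output_keys` name pointing at the base default; code that still reads the private name should migrate to the public attribute.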