Fixes pytorch#1289
- Promoted _required_output_keys to a public attribute, required_output_keys, so that users can override it.
vfdev-5 committed Sep 14, 2020
1 parent 002b595 commit 7e70949
Showing 5 changed files with 112 additions and 11 deletions.
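
For quick orientation before the diffs: with this change, a custom metric can declare which keys it needs from a dict-valued ``engine.state.output`` by overriding the now-public ``required_output_keys`` class attribute. The short sketch below only illustrates that pattern; the class name, the metric logic, and the sample tensors are illustrative stand-ins and are not part of this commit.

import torch
from ignite.metrics import Metric


class InputAwareAccuracy(Metric):
    # Sketch: request the inputs "x" in addition to predictions and targets.
    required_output_keys = ("y_pred", "y", "x")

    def reset(self):
        self._correct = 0
        self._total = 0

    def update(self, output):
        # Metric.iteration_completed passes a tuple ordered like required_output_keys.
        y_pred, y, x = output
        assert x.shape[0] == y.shape[0]  # inputs are available for input-dependent metrics
        self._correct += (y_pred.argmax(dim=1) == y).sum().item()
        self._total += y.shape[0]

    def compute(self):
        return self._correct / max(self._total, 1)


# Feeding the metric directly, without an engine, in the declared key order:
metric = InputAwareAccuracy()
metric.update((torch.rand(4, 3), torch.randint(0, 3, size=(4,)), torch.rand(4, 10)))
print(metric.compute())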
2 changes: 1 addition & 1 deletion ignite/metrics/accumulation.py
@@ -37,7 +37,7 @@ class VariableAccumulation(Metric):
"""

_required_output_keys = None
required_output_keys = None

def __init__(
self,
2 changes: 1 addition & 1 deletion ignite/metrics/loss.py
@@ -32,7 +32,7 @@ class Loss(Metric):
"""

_required_output_keys = None
required_output_keys = None

def __init__(
self,
66 changes: 61 additions & 5 deletions ignite/metrics/metric.py
@@ -125,9 +125,65 @@ class Metric(metaclass=ABCMeta):
         device (str or torch.device): specifies which device updates are accumulated on. Setting the
             metric's device to be the same as your ``update`` arguments ensures the ``update`` method is
             non-blocking. By default, CPU.
+
+    Class Attributes:
+        required_output_keys (dict): dictionary defines required keys to be found in ``engine.state.output`` if the
+            latter is a dictionary. This is useful with custom metrics that can require other arguments than
+            predictions ``y_pred`` and targets ``y``. See notes below for an example.
+
+    Note:
+
+        .. code-block:: python
+
+            # https://discuss.pytorch.org/t/how-access-inputs-in-custom-ignite-metric/91221/5
+            # Let's implement a custom metric that requires ``y_pred``, ``y`` and ``x``
+            import torch
+            import torch.nn as nn
+            from ignite.metrics import Metric, Accuracy
+            from ignite.engine import create_supervised_evaluator
+
+            class CustomMetric(Metric):
+
+                required_output_keys = ("y_pred", "y", "x")
+
+                def __init__(self, *args, **kwargs):
+                    super().__init__(*args, **kwargs)
+
+                def update(self, output):
+                    y_pred, y, x = output
+                    # ...
+
+                def reset(self):
+                    # ...
+                    pass
+
+                def compute(self):
+                    # ...
+                    pass
+
+            model = ...
+
+            metrics = {
+                "Accuracy": Accuracy(),
+                "CustomMetric": CustomMetric()
+            }
+
+            evaluator = create_supervised_evaluator(
+                model,
+                metrics=metrics,
+                output_transform=lambda x, y, y_pred: {"x": x, "y": y, "y_pred": y_pred}
+            )
+
+            res = evaluator.run(data)
     """
 
-    _required_output_keys = ("y_pred", "y")
+    # public class attribute
+    required_output_keys = ("y_pred", "y")
+    # for backward compatibility
+    _required_output_keys = required_output_keys
 
     def __init__(
         self, output_transform: Callable = lambda x: x, device: Union[str, torch.device] = torch.device("cpu"),
@@ -211,18 +267,18 @@ def iteration_completed(self, engine: Engine) -> None:

         output = self._output_transform(engine.state.output)
         if isinstance(output, Mapping):
-            if self._required_output_keys is None:
+            if self.required_output_keys is None:
                 raise TypeError(
                     "Transformed engine output for {} metric should be a tuple/list, but given {}".format(
                         self.__class__.__name__, type(output)
                     )
                 )
-            if not all([k in output for k in self._required_output_keys]):
+            if not all([k in output for k in self.required_output_keys]):
                 raise ValueError(
                     "When transformed engine's output is a mapping, "
-                    "it should contain {} keys, but given {}".format(self._required_output_keys, list(output.keys()))
+                    "it should contain {} keys, but given {}".format(self.required_output_keys, list(output.keys()))
                 )
-            output = tuple(output[k] for k in self._required_output_keys)
+            output = tuple(output[k] for k in self.required_output_keys)
         self.update(output)
 
     def completed(self, engine: Engine, name: str) -> None:
2 changes: 1 addition & 1 deletion ignite/metrics/running_average.py
@@ -44,7 +44,7 @@ def log_running_avg_metrics(engine):
"""

_required_output_keys = None
required_output_keys = None

def __init__(
self,
51 changes: 48 additions & 3 deletions tests/ignite/metrics/test_metric.py
@@ -67,7 +67,7 @@ def test_output_as_mapping_wrong_keys():

 def test_output_as_mapping_keys_is_none():
     class DummyMetric(Metric):
-        _required_output_keys = None
+        required_output_keys = None
 
         def reset(self):
             pass
@@ -79,7 +79,7 @@ def update(self, output):
             pass
 
     metric = DummyMetric()
-    assert metric._required_output_keys is None
+    assert metric.required_output_keys is None
     state = State(output=({"y1": 0, "y2": 1}))
     engine = MagicMock(state=state)

@@ -318,7 +318,7 @@ def process_function(*args, **kwargs):

 def test_detach():
     class DummyMetric(Metric):
-        _required_output_keys = None
+        required_output_keys = None
 
         def reset(self):
             pass
@@ -794,3 +794,48 @@ def _():
         assert bfm[0] == 1
 
     engine.run([0, 1, 2, 3], max_epochs=10)
+
+
+def test_override_required_output_keys():
+    # https://discuss.pytorch.org/t/how-access-inputs-in-custom-ignite-metric/91221/5
+    import torch.nn as nn
+
+    from ignite.engine import create_supervised_evaluator
+
+    counter = [0]
+
+    class CustomMetric(Metric):
+        required_output_keys = ("y_pred", "y", "x")
+
+        def __init__(self, *args, **kwargs):
+            super().__init__(*args, **kwargs)
+
+        def update(self, output):
+            y_pred, y, x = output
+            assert y_pred.shape == (4, 3)
+            assert y.shape == (4,)
+            assert x.shape == (4, 10)
+            assert x.equal(data[counter[0]][0])
+            assert y.equal(data[counter[0]][1])
+            counter[0] += 1
+
+        def reset(self):
+            pass
+
+        def compute(self):
+            pass
+
+    model = nn.Linear(10, 3)
+
+    metrics = {"Precision": Precision(), "CustomMetric": CustomMetric()}
+
+    evaluator = create_supervised_evaluator(
+        model, metrics=metrics, output_transform=lambda x, y, y_pred: {"x": x, "y": y, "y_pred": y_pred}
+    )
+
+    data = [
+        (torch.rand(4, 10), torch.randint(0, 3, size=(4,))),
+        (torch.rand(4, 10), torch.randint(0, 3, size=(4,))),
+        (torch.rand(4, 10), torch.randint(0, 3, size=(4,))),
+    ]
+    evaluator.run(data)
