Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix incorrect design of handlers (Frequency, RunningAverage, etc.) #1567

Closed
wants to merge 10 commits into from
2 changes: 1 addition & 1 deletion ignite/metrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from ignite.metrics.confusion_matrix import ConfusionMatrix, DiceCoefficient, IoU, mIoU
from ignite.metrics.epoch_metric import EpochMetric
from ignite.metrics.fbeta import Fbeta
from ignite.metrics.frequency import Frequency
from ignite.metrics.frequency import Frequency, FrequencyWise
from ignite.metrics.loss import Loss
from ignite.metrics.mean_absolute_error import MeanAbsoluteError
from ignite.metrics.mean_pairwise_distance import MeanPairwiseDistance
Expand Down
38 changes: 18 additions & 20 deletions ignite/metrics/frequency.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,16 @@
import torch

import ignite.distributed as idist
from ignite.engine import Engine, Events
from ignite.engine import CallableEventWithFilter, Engine, Events
from ignite.handlers.timing import Timer
from ignite.metrics.metric import EpochWise, Metric, MetricUsage, reinit__is_reduced, sync_all_reduce
from ignite.metrics.metric import Metric, MetricUsage, reinit__is_reduced, sync_all_reduce


class FrequencyWise(MetricUsage):
def __init__(self, event: CallableEventWithFilter = Events.ITERATION_COMPLETED) -> None:
super(FrequencyWise, self).__init__(
started=Events.EPOCH_STARTED, completed=event, iteration_completed=Events.ITERATION_COMPLETED,
)


class Frequency(Metric):
Expand All @@ -30,7 +37,7 @@ class Frequency(Metric):

# Compute number of tokens processed
wps_metric = Frequency(output_transform=lambda x: x['ntokens'])
wps_metric.attach(trainer, name='wps', event_name=Events.ITERATION_COMPLETED(every=50))
wps_metric.attach(trainer, name='wps', usage=FrequencyWise(Events.ITERATION_COMPLETED(every=50)))
# Logging with TQDM
ProgressBar(persist=True).attach(trainer, metric_names=['wps'])
# Progress bar will look like
Expand All @@ -41,14 +48,17 @@ def __init__(
self, output_transform: Callable = lambda x: x, device: Union[str, torch.device] = torch.device("cpu")
) -> None:
super(Frequency, self).__init__(output_transform=output_transform, device=device)
self._timer = Timer()
self._acc = 0
self._n = 0
self._elapsed = 0.0

@reinit__is_reduced
def reset(self) -> None:
self._timer = Timer()
self._acc = 0
self._n = 0
self._elapsed = 0.0
super(Frequency, self).reset()

@reinit__is_reduced
def update(self, output: int) -> None:
Expand All @@ -58,21 +68,9 @@ def update(self, output: int) -> None:

@sync_all_reduce("_n", "_elapsed")
def compute(self) -> float:
time_divisor = 1.0

if idist.get_world_size() > 1:
time_divisor *= idist.get_world_size()

# Returns the average processed objects per second across all workers
return self._n / self._elapsed * time_divisor
return int(self._n / self._elapsed * idist.get_world_size())

def completed(self, engine: Engine, name: str) -> None:
engine.state.metrics[name] = int(self.compute())

# TODO: see issue https://github.com/pytorch/ignite/issues/1405
def attach( # type: ignore
self, engine: Engine, name: str, event_name: Events = Events.ITERATION_COMPLETED
) -> None:
engine.add_event_handler(Events.EPOCH_STARTED, self.started)
engine.add_event_handler(Events.ITERATION_COMPLETED, self.iteration_completed)
engine.add_event_handler(event_name, self.completed, name)
# override the method attach() of Metrics to define a different default value for usage
def attach(self, engine: Engine, name: str, usage: Union[str, MetricUsage] = FrequencyWise()) -> None:
sdesrozis marked this conversation as resolved.
Show resolved Hide resolved
super(Frequency, self).attach(engine, name, usage)
11 changes: 8 additions & 3 deletions ignite/metrics/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,17 +31,22 @@ class MetricUsage:
:meth:`~ignite.metrics.Metric.iteration_completed`.
"""

def __init__(self, started: Events, completed: Events, iteration_completed: CallableEventWithFilter) -> None:
def __init__(
self,
started: CallableEventWithFilter,
completed: CallableEventWithFilter,
iteration_completed: CallableEventWithFilter,
) -> None:
self.__started = started
self.__completed = completed
self.__iteration_completed = iteration_completed

@property
def STARTED(self) -> Events:
def STARTED(self) -> CallableEventWithFilter:
return self.__started

@property
def COMPLETED(self) -> Events:
def COMPLETED(self) -> CallableEventWithFilter:
return self.__completed

@property
Expand Down
4 changes: 2 additions & 2 deletions tests/ignite/metrics/test_frequency.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import ignite.distributed as idist
from ignite.engine import Engine, Events
from ignite.metrics import Frequency
from ignite.metrics import Frequency, FrequencyWise

if sys.platform.startswith("darwin"):
pytest.skip("Skip if on MacOS", allow_module_level=True)
Expand Down Expand Up @@ -45,7 +45,7 @@ def update_fn(engine, batch):
engine = Engine(update_fn)
wps_metric = Frequency(output_transform=lambda x: x["ntokens"])
event = Events.ITERATION_COMPLETED(every=every)
wps_metric.attach(engine, "wps", event_name=event)
wps_metric.attach(engine, "wps", usage=FrequencyWise(event))

@engine.on(event)
def assert_wps(e):
Expand Down