From a9e49ffbe388cea7373f072a8b2bbdca809fea53 Mon Sep 17 00:00:00 2001 From: Desroziers Date: Fri, 22 Jan 2021 09:51:30 +0100 Subject: [PATCH 1/5] fix Frequency metric - introduce FrequencyWise usage --- ignite/metrics/__init__.py | 2 +- ignite/metrics/frequency.py | 41 +++++++++++++------------- ignite/metrics/metric.py | 7 +++-- tests/ignite/metrics/test_frequency.py | 4 +-- 4 files changed, 28 insertions(+), 26 deletions(-) diff --git a/ignite/metrics/__init__.py b/ignite/metrics/__init__.py index de0415bbe6c..96824764e0c 100644 --- a/ignite/metrics/__init__.py +++ b/ignite/metrics/__init__.py @@ -3,7 +3,7 @@ from ignite.metrics.confusion_matrix import ConfusionMatrix, DiceCoefficient, IoU, mIoU from ignite.metrics.epoch_metric import EpochMetric from ignite.metrics.fbeta import Fbeta -from ignite.metrics.frequency import Frequency +from ignite.metrics.frequency import Frequency, FrequencyWise from ignite.metrics.loss import Loss from ignite.metrics.mean_absolute_error import MeanAbsoluteError from ignite.metrics.mean_pairwise_distance import MeanPairwiseDistance diff --git a/ignite/metrics/frequency.py b/ignite/metrics/frequency.py index b34c3371064..0f515ede92e 100644 --- a/ignite/metrics/frequency.py +++ b/ignite/metrics/frequency.py @@ -3,9 +3,19 @@ import torch import ignite.distributed as idist -from ignite.engine import Engine, Events +from ignite.engine import Engine, Events, CallableEventWithFilter from ignite.handlers.timing import Timer -from ignite.metrics.metric import EpochWise, Metric, MetricUsage, reinit__is_reduced, sync_all_reduce +from ignite.metrics.metric import Metric, MetricUsage, reinit__is_reduced, sync_all_reduce + + +class FrequencyWise(MetricUsage): + + def __init__(self, event: CallableEventWithFilter = Events.ITERATION_COMPLETED) -> None: + super(FrequencyWise, self).__init__( + started=Events.EPOCH_STARTED, + completed=event, + iteration_completed=Events.ITERATION_COMPLETED, + ) class Frequency(Metric): @@ -30,7 +40,7 @@ class Frequency(Metric): # Compute number of tokens processed wps_metric = Frequency(output_transform=lambda x: x['ntokens']) - wps_metric.attach(trainer, name='wps', event_name=Events.ITERATION_COMPLETED(every=50)) + wps_metric.attach(trainer, name='wps', usage=FrequencyWise(Events.ITERATION_COMPLETED(every=50))) # Logging with TQDM ProgressBar(persist=True).attach(trainer, metric_names=['wps']) # Progress bar will look like @@ -41,6 +51,10 @@ def __init__( self, output_transform: Callable = lambda x: x, device: Union[str, torch.device] = torch.device("cpu") ) -> None: super(Frequency, self).__init__(output_transform=output_transform, device=device) + self._timer = Timer() + self._acc = 0 + self._n = 0 + self._elapsed = 0.0 @reinit__is_reduced def reset(self) -> None: @@ -48,7 +62,6 @@ def reset(self) -> None: self._acc = 0 self._n = 0 self._elapsed = 0.0 - super(Frequency, self).reset() @reinit__is_reduced def update(self, output: int) -> None: @@ -58,21 +71,9 @@ def update(self, output: int) -> None: @sync_all_reduce("_n", "_elapsed") def compute(self) -> float: - time_divisor = 1.0 - - if idist.get_world_size() > 1: - time_divisor *= idist.get_world_size() - # Returns the average processed objects per second across all workers - return self._n / self._elapsed * time_divisor + return int(self._n / self._elapsed * idist.get_world_size()) - def completed(self, engine: Engine, name: str) -> None: - engine.state.metrics[name] = int(self.compute()) - - # TODO: see issue https://github.com/pytorch/ignite/issues/1405 - def attach( # type: ignore - self, engine: Engine, name: str, event_name: Events = Events.ITERATION_COMPLETED - ) -> None: - engine.add_event_handler(Events.EPOCH_STARTED, self.started) - engine.add_event_handler(Events.ITERATION_COMPLETED, self.iteration_completed) - engine.add_event_handler(event_name, self.completed, name) + # override the method attach() of Metrics to define a different default value for usage + def attach(self, engine: Engine, name: str, usage: Union[str, MetricUsage] = FrequencyWise()) -> None: + super(Frequency, self).attach(engine, name, usage) diff --git a/ignite/metrics/metric.py b/ignite/metrics/metric.py index f36dcc2ebe1..a013d5ef414 100644 --- a/ignite/metrics/metric.py +++ b/ignite/metrics/metric.py @@ -31,17 +31,18 @@ class MetricUsage: :meth:`~ignite.metrics.Metric.iteration_completed`. """ - def __init__(self, started: Events, completed: Events, iteration_completed: CallableEventWithFilter) -> None: + def __init__(self, started: CallableEventWithFilter, completed: CallableEventWithFilter, + iteration_completed: CallableEventWithFilter) -> None: self.__started = started self.__completed = completed self.__iteration_completed = iteration_completed @property - def STARTED(self) -> Events: + def STARTED(self) -> CallableEventWithFilter: return self.__started @property - def COMPLETED(self) -> Events: + def COMPLETED(self) -> CallableEventWithFilter: return self.__completed @property diff --git a/tests/ignite/metrics/test_frequency.py b/tests/ignite/metrics/test_frequency.py index 98286d03216..b32372352b4 100644 --- a/tests/ignite/metrics/test_frequency.py +++ b/tests/ignite/metrics/test_frequency.py @@ -7,7 +7,7 @@ import ignite.distributed as idist from ignite.engine import Engine, Events -from ignite.metrics import Frequency +from ignite.metrics import Frequency, FrequencyWise if sys.platform.startswith("darwin"): pytest.skip("Skip if on MacOS", allow_module_level=True) @@ -45,7 +45,7 @@ def update_fn(engine, batch): engine = Engine(update_fn) wps_metric = Frequency(output_transform=lambda x: x["ntokens"]) event = Events.ITERATION_COMPLETED(every=every) - wps_metric.attach(engine, "wps", event_name=event) + wps_metric.attach(engine, "wps", usage=FrequencyWise(event)) @engine.on(event) def assert_wps(e): From 1157d907f614f04de5518203c82dc33759e2322b Mon Sep 17 00:00:00 2001 From: sdesrozis Date: Fri, 22 Jan 2021 08:53:24 +0000 Subject: [PATCH 2/5] autopep8 fix --- ignite/metrics/frequency.py | 7 ++----- ignite/metrics/metric.py | 8 ++++++-- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/ignite/metrics/frequency.py b/ignite/metrics/frequency.py index 0f515ede92e..6d9aacd40bb 100644 --- a/ignite/metrics/frequency.py +++ b/ignite/metrics/frequency.py @@ -3,18 +3,15 @@ import torch import ignite.distributed as idist -from ignite.engine import Engine, Events, CallableEventWithFilter +from ignite.engine import CallableEventWithFilter, Engine, Events from ignite.handlers.timing import Timer from ignite.metrics.metric import Metric, MetricUsage, reinit__is_reduced, sync_all_reduce class FrequencyWise(MetricUsage): - def __init__(self, event: CallableEventWithFilter = Events.ITERATION_COMPLETED) -> None: super(FrequencyWise, self).__init__( - started=Events.EPOCH_STARTED, - completed=event, - iteration_completed=Events.ITERATION_COMPLETED, + started=Events.EPOCH_STARTED, completed=event, iteration_completed=Events.ITERATION_COMPLETED, ) diff --git a/ignite/metrics/metric.py b/ignite/metrics/metric.py index a013d5ef414..034a290ffbf 100644 --- a/ignite/metrics/metric.py +++ b/ignite/metrics/metric.py @@ -31,8 +31,12 @@ class MetricUsage: :meth:`~ignite.metrics.Metric.iteration_completed`. """ - def __init__(self, started: CallableEventWithFilter, completed: CallableEventWithFilter, - iteration_completed: CallableEventWithFilter) -> None: + def __init__( + self, + started: CallableEventWithFilter, + completed: CallableEventWithFilter, + iteration_completed: CallableEventWithFilter, + ) -> None: self.__started = started self.__completed = completed self.__iteration_completed = iteration_completed From 706c6c80e3f70facf4036df277577e420376d515 Mon Sep 17 00:00:00 2001 From: Desroziers Date: Mon, 25 Jan 2021 11:59:47 +0100 Subject: [PATCH 3/5] move _check_usage to static method --- ignite/metrics/metric.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/ignite/metrics/metric.py b/ignite/metrics/metric.py index 034a290ffbf..503bc2bb5bf 100644 --- a/ignite/metrics/metric.py +++ b/ignite/metrics/metric.py @@ -320,7 +320,8 @@ def completed(self, engine: Engine, name: str) -> None: engine.state.metrics[name] = result - def _check_usage(self, usage: Union[str, MetricUsage]) -> MetricUsage: + @staticmethod + def _check_usage(usage: Union[str, MetricUsage]) -> MetricUsage: if isinstance(usage, str): if usage == EpochWise.usage_name: usage = EpochWise() From 17426a7e324bbfe9297c32ab8f01aab6fa9103ca Mon Sep 17 00:00:00 2001 From: Desroziers Date: Mon, 25 Jan 2021 12:00:13 +0100 Subject: [PATCH 4/5] introduce RunningAverageWise --- ignite/metrics/running_average.py | 25 +++++++++++++++---------- 1 file changed, 15 insertions(+), 10 deletions(-) diff --git a/ignite/metrics/running_average.py b/ignite/metrics/running_average.py index db36b1a9048..18ecee0a497 100644 --- a/ignite/metrics/running_average.py +++ b/ignite/metrics/running_average.py @@ -4,11 +4,20 @@ import ignite.distributed as idist from ignite.engine import Engine, Events -from ignite.metrics.metric import EpochWise, Metric, MetricUsage, reinit__is_reduced, sync_all_reduce +from ignite.metrics.metric import Metric, MetricUsage, reinit__is_reduced, sync_all_reduce __all__ = ["RunningAverage"] +class RunningAverageWise(MetricUsage): + def __init__(self) -> None: + super(RunningAverageWise, self).__init__( + started=Events.EPOCH_STARTED, + completed=Events.ITERATION_COMPLETED, + iteration_completed=Events.ITERATION_COMPLETED, + ) + + class RunningAverage(Metric): """Compute running average of a metric or the output of process function. @@ -82,10 +91,12 @@ def __init__( self.alpha = alpha self.epoch_bound = epoch_bound super(RunningAverage, self).__init__(output_transform=output_transform, device=device) # type: ignore[arg-type] + self._value = None # type: Optional[Union[float, torch.Tensor]] @reinit__is_reduced def reset(self) -> None: - self._value = None # type: Optional[Union[float, torch.Tensor]] + if self.epoch_bound: + self._value = None @reinit__is_reduced def update(self, output: Sequence) -> None: @@ -100,14 +111,8 @@ def compute(self) -> Union[torch.Tensor, float]: return self._value - def attach(self, engine: Engine, name: str, _usage: Union[str, MetricUsage] = EpochWise()) -> None: - if self.epoch_bound: - # restart average every epoch - engine.add_event_handler(Events.EPOCH_STARTED, self.started) - # compute metric - engine.add_event_handler(Events.ITERATION_COMPLETED, self.iteration_completed) - # apply running average - engine.add_event_handler(Events.ITERATION_COMPLETED, self.completed, name) + def attach(self, engine: Engine, name: str, usage: Union[str, MetricUsage] = RunningAverageWise()) -> None: + super(RunningAverage, self).attach(engine, name, usage) def _get_metric_value(self) -> Union[torch.Tensor, float]: return self.src.compute() From 89d4312eb31983aa1fdad7afa5491b01de61d2d4 Mon Sep 17 00:00:00 2001 From: Desroziers Date: Tue, 2 Feb 2021 09:15:10 +0100 Subject: [PATCH 5/5] fix for checkpoint --- ignite/handlers/checkpoint.py | 28 ++++++++++-------------- tests/ignite/handlers/test_checkpoint.py | 3 --- 2 files changed, 12 insertions(+), 19 deletions(-) diff --git a/ignite/handlers/checkpoint.py b/ignite/handlers/checkpoint.py index e95d5f45ddc..016bf943307 100644 --- a/ignite/handlers/checkpoint.py +++ b/ignite/handlers/checkpoint.py @@ -256,7 +256,7 @@ def score_function(engine): def __init__( self, - to_save: Optional[Mapping], + to_save: Mapping, save_handler: Union[Callable, BaseSaveHandler], filename_prefix: str = "", score_function: Optional[Callable] = None, @@ -268,23 +268,19 @@ def __init__( greater_or_equal: bool = False, ) -> None: - if to_save is not None: # for compatibility with ModelCheckpoint - if not isinstance(to_save, collections.Mapping): - raise TypeError(f"Argument `to_save` should be a dictionary, but given {type(to_save)}") + if not isinstance(to_save, collections.Mapping): + raise TypeError(f"Argument `to_save` should be a dictionary, but given {type(to_save)}") - if len(to_save) < 1: - raise ValueError("No objects to checkpoint.") - - self._check_objects(to_save, "state_dict") + self._check_objects(to_save, "state_dict") - if include_self: - if not isinstance(to_save, collections.MutableMapping): - raise TypeError( - f"If `include_self` is True, then `to_save` must be mutable, but given {type(to_save)}." - ) + if include_self: + if not isinstance(to_save, collections.MutableMapping): + raise TypeError( + f"If `include_self` is True, then `to_save` must be mutable, but given {type(to_save)}." + ) - if "checkpointer" in to_save: - raise ValueError(f"Cannot have key 'checkpointer' if `include_self` is True: {to_save}") + if "checkpointer" in to_save: + raise ValueError(f"Cannot have key 'checkpointer' if `include_self` is True: {to_save}") if not (callable(save_handler) or isinstance(save_handler, BaseSaveHandler)): raise TypeError("Argument `save_handler` should be callable or inherit from BaseSaveHandler") @@ -746,7 +742,7 @@ def __init__( disk_saver = DiskSaver(dirname, atomic=atomic, create_dir=create_dir, require_empty=require_empty, **kwargs) super(ModelCheckpoint, self).__init__( - to_save=None, + to_save={}, save_handler=disk_saver, filename_prefix=filename_prefix, score_function=score_function, diff --git a/tests/ignite/handlers/test_checkpoint.py b/tests/ignite/handlers/test_checkpoint.py index e5cd023ed34..bd36f16b753 100644 --- a/tests/ignite/handlers/test_checkpoint.py +++ b/tests/ignite/handlers/test_checkpoint.py @@ -45,9 +45,6 @@ def test_checkpoint_wrong_input(): with pytest.raises(TypeError, match=r"Argument `to_save` should be a dictionary"): Checkpoint([12], lambda x: x, "prefix") - with pytest.raises(ValueError, match=r"No objects to checkpoint."): - Checkpoint({}, lambda x: x, "prefix") - model = DummyModel() to_save = {"model": model}