From 2de056f1a3011e3283fdeb6048c44033eba3c37c Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Mon, 31 May 2021 11:18:56 +0000 Subject: [PATCH] Revert "Adding max_iters as an optional arg in Engine run (#1381)" This reverts commit 307ac11a87cd98b8deeb128e3dea992e4178800b. --- ignite/engine/engine.py | 35 ++++++----------------------------- ignite/engine/events.py | 2 -- 2 files changed, 6 insertions(+), 31 deletions(-) diff --git a/ignite/engine/engine.py b/ignite/engine/engine.py index c9b958e2dd17..6b2d36997abb 100644 --- a/ignite/engine/engine.py +++ b/ignite/engine/engine.py @@ -1,6 +1,5 @@ import functools import logging -import math import time import warnings import weakref @@ -680,7 +679,7 @@ def load_state_dict(self, state_dict: Mapping) -> None: `seed`. If `engine.state_dict_user_keys` contains keys, they should be also present in the state dictionary. Iteration and epoch values are 0-based: the first iteration or epoch is zero. - This method does not remove any custom attributes added by user. + This method does not remove any custom attributs added by user. Args: state_dict: a dict with parameters @@ -725,14 +724,7 @@ def load_state_dict(self, state_dict: Mapping) -> None: @staticmethod def _is_done(state: State) -> bool: - is_done_iters = state.max_iters is not None and state.iteration >= state.max_iters - is_done_count = ( - state.epoch_length is not None - and state.max_epochs is not None - and state.iteration >= state.epoch_length * state.max_epochs - ) - is_done_epochs = state.max_epochs is not None and state.epoch >= state.max_epochs - return is_done_iters or is_done_count or is_done_epochs + return state.iteration == state.epoch_length * state.max_epochs # type: ignore[operator] def set_data(self, data: Union[Iterable, DataLoader]) -> None: """Method to set data. After calling the method the next batch passed to `processing_function` is @@ -774,7 +766,6 @@ def run( self, data: Optional[Iterable] = None, max_epochs: Optional[int] = None, - max_iters: Optional[int] = None, epoch_length: Optional[int] = None, seed: Optional[int] = None, ) -> State: @@ -782,7 +773,7 @@ def run( Engine has a state and the following logic is applied in this function: - - At the first call, new state is defined by `max_epochs`, `max_iters`, `epoch_length`, `seed`, if provided. + - At the first call, new state is defined by `max_epochs`, `epoch_length`, `seed`, if provided. A timer for total and per-epoch time is initialized when Events.STARTED is handled. - If state is already defined such that there are iterations to run until `max_epochs` and no input arguments provided, state is kept and used in the function. @@ -800,8 +791,6 @@ def run( `len(data)`. If `data` is an iterator and `epoch_length` is not set, then it will be automatically determined as the iteration on which data iterator raises `StopIteration`. This argument should not change if run is resuming from a state. - max_iters: Number of iterations to run for. - `max_iters` and `max_epochs` are mutually exclusive; only one of the two arguments should be provided. seed: Deprecated argument. Please, use `torch.manual_seed` or :meth:`~ignite.utils.manual_seed`. Returns: @@ -860,6 +849,8 @@ def switch_batch(engine): if self.state.max_epochs is None or (self._is_done(self.state) and self._internal_run_generator is None): # Create new state + if max_epochs is None: + max_epochs = 1 if epoch_length is None: if data is None: raise ValueError("epoch_length should be provided if data is None") @@ -868,22 +859,9 @@ def switch_batch(engine): if epoch_length is not None and epoch_length < 1: raise ValueError("Input data has zero size. Please provide non-empty data") - if max_iters is None: - if max_epochs is None: - max_epochs = 1 - else: - if max_epochs is not None: - raise ValueError( - "Arguments max_iters and max_epochs are mutually exclusive." - "Please provide only max_epochs or max_iters." - ) - if epoch_length is not None: - max_epochs = math.ceil(max_iters / epoch_length) - self.state.iteration = 0 self.state.epoch = 0 self.state.max_epochs = max_epochs - self.state.max_iters = max_iters self.state.epoch_length = epoch_length # Reset generator if previously used self._internal_run_generator = None @@ -978,6 +956,7 @@ def _internal_run_as_gen(self) -> Generator: self.state.times[Events.EPOCH_COMPLETED.name] = epoch_time_taken handlers_start_time = time.time() + self._fire_event(Events.EPOCH_COMPLETED) epoch_time_taken += time.time() - handlers_start_time # update time wrt handlers @@ -1056,8 +1035,6 @@ def _run_once_on_dataset_as_gen(self) -> Generator[State, None, float]: if self.state.epoch_length is None: # Define epoch length and stop the epoch self.state.epoch_length = iter_counter - if self.state.max_iters is not None: - self.state.max_epochs = math.ceil(self.state.max_iters / self.state.epoch_length) break # Should exit while loop if we can not iterate diff --git a/ignite/engine/events.py b/ignite/engine/events.py index 5d46b38a0dfb..b539c73b4c40 100644 --- a/ignite/engine/events.py +++ b/ignite/engine/events.py @@ -451,7 +451,6 @@ class State: state.dataloader # data passed to engine state.epoch_length # optional length of an epoch state.max_epochs # number of epochs to run - state.max_iters # number of iterations to run state.batch # batch passed to `process_function` state.output # output of `process_function` after a single iteration state.metrics # dictionary with defined metrics if any @@ -478,7 +477,6 @@ def __init__(self, **kwargs: Any) -> None: self.epoch = 0 self.epoch_length: Optional[int] = None self.max_epochs: Optional[int] = None - self.max_iters: Optional[int] = None self.output: Optional[int] = None self.batch: Optional[int] = None self.metrics: Dict[str, Any] = {}