diff --git a/ignite/engine/engine.py b/ignite/engine/engine.py index 59b7d9747ee..4b343465a7b 100644 --- a/ignite/engine/engine.py +++ b/ignite/engine/engine.py @@ -1,5 +1,6 @@ import functools import logging +import math import time import warnings import weakref @@ -510,7 +511,7 @@ def load_state_dict(self, state_dict: Mapping) -> None: If `engine.state_dict_user_keys` contains keys, they should be also present in the state dictionary. Iteration and epoch values are 0-based: the first iteration or epoch is zero. - This method does not remove any custom attributs added by user. + This method does not remove any custom attributes added by user. Args: state_dict (Mapping): a dict with parameters @@ -557,7 +558,13 @@ def load_state_dict(self, state_dict: Mapping) -> None: @staticmethod def _is_done(state: State) -> bool: - return state.iteration == state.epoch_length * state.max_epochs # type: ignore[operator] + is_done_iters = state.max_iters is not None and state.iteration >= state.max_iters + is_done_count = ( + state.epoch_length is not None + and state.iteration >= state.epoch_length * state.max_epochs # type: ignore[operator] + ) + is_done_epochs = state.max_epochs is not None and state.epoch >= state.max_epochs + return is_done_iters or is_done_count or is_done_epochs def set_data(self, data: Union[Iterable, DataLoader]) -> None: """Method to set data. After calling the method the next batch passed to `processing_function` is @@ -595,13 +602,19 @@ def switch_dataloader(): self.state.dataloader = data self._dataloader_iter = iter(self.state.dataloader) - def run(self, data: Iterable, max_epochs: Optional[int] = None, epoch_length: Optional[int] = None,) -> State: + def run( + self, + data: Iterable, + max_epochs: Optional[int] = None, + max_iters: Optional[int] = None, + epoch_length: Optional[int] = None, + ) -> State: """Runs the `process_function` over the passed data. Engine has a state and the following logic is applied in this function: - - At the first call, new state is defined by `max_epochs`, `epoch_length` if provided. A timer for - total and per-epoch time is initialized when Events.STARTED is handled. + - At the first call, new state is defined by `max_epochs`, `max_iters`, `epoch_length`, if provided. + A timer for total and per-epoch time is initialized when Events.STARTED is handled. - If state is already defined such that there are iterations to run until `max_epochs` and no input arguments provided, state is kept and used in the function. - If state is defined and engine is "done" (no iterations to run until `max_epochs`), a new state is defined. @@ -617,6 +630,8 @@ def run(self, data: Iterable, max_epochs: Optional[int] = None, epoch_length: Op `len(data)`. If `data` is an iterator and `epoch_length` is not set, then it will be automatically determined as the iteration on which data iterator raises `StopIteration`. This argument should not change if run is resuming from a state. + max_iters (int, optional): Number of iterations to run for. + `max_iters` and `max_epochs` are mutually exclusive; only one of the two arguments should be provided. Returns: State: output state. @@ -670,16 +685,27 @@ def switch_batch(engine): if self.state.max_epochs is None or self._is_done(self.state): # Create new state - if max_epochs is None: - max_epochs = 1 if epoch_length is None: epoch_length = self._get_data_length(data) if epoch_length is not None and epoch_length < 1: raise ValueError("Input data has zero size. Please provide non-empty data") + if max_iters is None: + if max_epochs is None: + max_epochs = 1 + else: + if max_epochs is not None: + raise ValueError( + "Arguments max_iters and max_epochs are mutually exclusive." + "Please provide only max_epochs or max_iters." + ) + if epoch_length is not None: + max_epochs = math.ceil(max_iters / epoch_length) + self.state.iteration = 0 self.state.epoch = 0 self.state.max_epochs = max_epochs + self.state.max_iters = max_iters self.state.epoch_length = epoch_length self.logger.info("Engine run starting with max_epochs={}.".format(max_epochs)) else: @@ -726,7 +752,7 @@ def _internal_run(self) -> State: try: start_time = time.time() self._fire_event(Events.STARTED) - while self.state.epoch < self.state.max_epochs and not self.should_terminate: # type: ignore[operator] + while not self._is_done(self.state) and not self.should_terminate: self.state.epoch += 1 self._fire_event(Events.EPOCH_STARTED) @@ -800,6 +826,8 @@ def _run_once_on_dataset(self) -> float: if self.state.epoch_length is None: # Define epoch length and stop the epoch self.state.epoch_length = iter_counter + if self.state.max_iters is not None: + self.state.max_epochs = math.ceil(self.state.max_iters / self.state.epoch_length) break # Should exit while loop if we can not iterate @@ -839,6 +867,10 @@ def _run_once_on_dataset(self) -> float: if self.state.epoch_length is not None and iter_counter == self.state.epoch_length: break + if self.state.max_iters is not None and self.state.iteration == self.state.max_iters: + self.should_terminate = True + break + except Exception as e: self.logger.error("Current run is terminating due to exception: %s.", str(e)) self._handle_exception(e) diff --git a/ignite/engine/events.py b/ignite/engine/events.py index abb0e8e6404..cd8f14da7d6 100644 --- a/ignite/engine/events.py +++ b/ignite/engine/events.py @@ -344,6 +344,7 @@ class State: state.dataloader # data passed to engine state.epoch_length # optional length of an epoch state.max_epochs # number of epochs to run + state.max_iter # number of iterations to run state.batch # batch passed to `process_function` state.output # output of `process_function` after a single iteration state.metrics # dictionary with defined metrics if any @@ -368,6 +369,7 @@ def __init__(self, **kwargs: Any) -> None: self.epoch = 0 self.epoch_length = None # type: Optional[int] self.max_epochs = None # type: Optional[int] + self.max_iters = None # type: Optional[int] self.output = None # type: Optional[int] self.batch = None # type: Optional[int] self.metrics = {} # type: Dict[str, Any] diff --git a/tests/ignite/engine/test_engine.py b/tests/ignite/engine/test_engine.py index 43701b4b1f4..d00755a91c8 100644 --- a/tests/ignite/engine/test_engine.py +++ b/tests/ignite/engine/test_engine.py @@ -891,3 +891,49 @@ def switch_dataloader(): trainer.set_data(data2) trainer.run(data1, max_epochs=10) + + +def test_run_with_max_iters(): + max_iters = 8 + engine = Engine(lambda e, b: 1) + engine.run([0] * 20, max_iters=max_iters) + assert engine.state.iteration == max_iters + assert engine.state.max_iters == max_iters + + +def test_run_with_max_iters_greater_than_epoch_length(): + max_iters = 73 + engine = Engine(lambda e, b: 1) + engine.run([0] * 20, max_iters=max_iters) + assert engine.state.iteration == max_iters + + +def test_run_with_invalid_max_iters_and_max_epoch(): + max_iters = 12 + max_epochs = 2 + engine = Engine(lambda e, b: 1) + with pytest.raises( + ValueError, + match=r"Arguments max_iters and max_epochs are mutually exclusive." + "Please provide only max_epochs or max_iters.", + ): + engine.run([0] * 20, max_iters=max_iters, max_epochs=max_epochs) + + +def test_epoch_events_fired(): + max_iters = 32 + engine = Engine(lambda e, b: 1) + + @engine.on(Events.EPOCH_COMPLETED) + def fired_event(engine): + assert engine.state.iteration % engine.state.epoch_length == 0 + + engine.run([0] * 10, max_iters=max_iters) + + +def test_is_done_with_max_iters(): + state = State(iteration=100, epoch=1, max_epochs=3, epoch_length=100, max_iters=250) + assert not Engine._is_done(state) + + state = State(iteration=250, epoch=1, max_epochs=3, epoch_length=100, max_iters=250) + assert Engine._is_done(state)