Skip to content

Commit

Permalink
Revert "Adding max_iters as an optional arg in Engine run (#1381)"
Browse files Browse the repository at this point in the history
This reverts commit 307ac11.
  • Loading branch information
vfdev-5 committed Feb 17, 2023
1 parent 1626989 commit 2de056f
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 31 deletions.
35 changes: 6 additions & 29 deletions ignite/engine/engine.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import functools
import logging
import math
import time
import warnings
import weakref
Expand Down Expand Up @@ -680,7 +679,7 @@ def load_state_dict(self, state_dict: Mapping) -> None:
`seed`. If `engine.state_dict_user_keys` contains keys, they should be also present in the state dictionary.
Iteration and epoch values are 0-based: the first iteration or epoch is zero.
This method does not remove any custom attributes added by user.
This method does not remove any custom attributs added by user.
Args:
state_dict: a dict with parameters
Expand Down Expand Up @@ -725,14 +724,7 @@ def load_state_dict(self, state_dict: Mapping) -> None:

@staticmethod
def _is_done(state: State) -> bool:
is_done_iters = state.max_iters is not None and state.iteration >= state.max_iters
is_done_count = (
state.epoch_length is not None
and state.max_epochs is not None
and state.iteration >= state.epoch_length * state.max_epochs
)
is_done_epochs = state.max_epochs is not None and state.epoch >= state.max_epochs
return is_done_iters or is_done_count or is_done_epochs
return state.iteration == state.epoch_length * state.max_epochs # type: ignore[operator]

def set_data(self, data: Union[Iterable, DataLoader]) -> None:
"""Method to set data. After calling the method the next batch passed to `processing_function` is
Expand Down Expand Up @@ -774,15 +766,14 @@ def run(
self,
data: Optional[Iterable] = None,
max_epochs: Optional[int] = None,
max_iters: Optional[int] = None,
epoch_length: Optional[int] = None,
seed: Optional[int] = None,
) -> State:
"""Runs the ``process_function`` over the passed data.
Engine has a state and the following logic is applied in this function:
- At the first call, new state is defined by `max_epochs`, `max_iters`, `epoch_length`, `seed`, if provided.
- At the first call, new state is defined by `max_epochs`, `epoch_length`, `seed`, if provided.
A timer for total and per-epoch time is initialized when Events.STARTED is handled.
- If state is already defined such that there are iterations to run until `max_epochs` and no input arguments
provided, state is kept and used in the function.
Expand All @@ -800,8 +791,6 @@ def run(
`len(data)`. If `data` is an iterator and `epoch_length` is not set, then it will be automatically
determined as the iteration on which data iterator raises `StopIteration`.
This argument should not change if run is resuming from a state.
max_iters: Number of iterations to run for.
`max_iters` and `max_epochs` are mutually exclusive; only one of the two arguments should be provided.
seed: Deprecated argument. Please, use `torch.manual_seed` or :meth:`~ignite.utils.manual_seed`.
Returns:
Expand Down Expand Up @@ -860,6 +849,8 @@ def switch_batch(engine):

if self.state.max_epochs is None or (self._is_done(self.state) and self._internal_run_generator is None):
# Create new state
if max_epochs is None:
max_epochs = 1
if epoch_length is None:
if data is None:
raise ValueError("epoch_length should be provided if data is None")
Expand All @@ -868,22 +859,9 @@ def switch_batch(engine):
if epoch_length is not None and epoch_length < 1:
raise ValueError("Input data has zero size. Please provide non-empty data")

if max_iters is None:
if max_epochs is None:
max_epochs = 1
else:
if max_epochs is not None:
raise ValueError(
"Arguments max_iters and max_epochs are mutually exclusive."
"Please provide only max_epochs or max_iters."
)
if epoch_length is not None:
max_epochs = math.ceil(max_iters / epoch_length)

self.state.iteration = 0
self.state.epoch = 0
self.state.max_epochs = max_epochs
self.state.max_iters = max_iters
self.state.epoch_length = epoch_length
# Reset generator if previously used
self._internal_run_generator = None
Expand Down Expand Up @@ -978,6 +956,7 @@ def _internal_run_as_gen(self) -> Generator:
self.state.times[Events.EPOCH_COMPLETED.name] = epoch_time_taken

handlers_start_time = time.time()

self._fire_event(Events.EPOCH_COMPLETED)
epoch_time_taken += time.time() - handlers_start_time
# update time wrt handlers
Expand Down Expand Up @@ -1056,8 +1035,6 @@ def _run_once_on_dataset_as_gen(self) -> Generator[State, None, float]:
if self.state.epoch_length is None:
# Define epoch length and stop the epoch
self.state.epoch_length = iter_counter
if self.state.max_iters is not None:
self.state.max_epochs = math.ceil(self.state.max_iters / self.state.epoch_length)
break

# Should exit while loop if we can not iterate
Expand Down
2 changes: 0 additions & 2 deletions ignite/engine/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -451,7 +451,6 @@ class State:
state.dataloader # data passed to engine
state.epoch_length # optional length of an epoch
state.max_epochs # number of epochs to run
state.max_iters # number of iterations to run
state.batch # batch passed to `process_function`
state.output # output of `process_function` after a single iteration
state.metrics # dictionary with defined metrics if any
Expand All @@ -478,7 +477,6 @@ def __init__(self, **kwargs: Any) -> None:
self.epoch = 0
self.epoch_length: Optional[int] = None
self.max_epochs: Optional[int] = None
self.max_iters: Optional[int] = None
self.output: Optional[int] = None
self.batch: Optional[int] = None
self.metrics: Dict[str, Any] = {}
Expand Down

0 comments on commit 2de056f

Please sign in to comment.