Skip to content

Commit

Permalink
Adding max_iters as an optional arg in Engine run (#1381)
Browse files Browse the repository at this point in the history
* initial draft, adding max_iters as optional args in run

* fixed typo

* minor bug fixes

* resolving failing tests

* fixed out-of-place conditional

* typo fix

* updated docstring for 'run'

* added initial tests

* (WIP) restructured creating a new state with max_iters

* updated tests & docstrings

* initial draft, adding max_iters as optional args in run

* fixed typo

* minor bug fixes

* resolving failing tests

* fixed out-of-place conditional

* typo fix

* updated docstring for 'run'

* added initial tests

* (WIP) restructured creating a new state with max_iters

* updated tests & docstrings

* added test to check _is_done

* updating engine loop condition

* autopep8 fix

* linting issues

* fixed mypy errors

* fixed formatting

* minor fix & add test for larger max_iters

* removed unused typechecking

Co-authored-by: thescripted <[email protected]>
Co-authored-by: vfdev <[email protected]>
  • Loading branch information
3 people authored Nov 9, 2020
1 parent 1296c74 commit 307ac11
Show file tree
Hide file tree
Showing 3 changed files with 88 additions and 8 deletions.
48 changes: 40 additions & 8 deletions ignite/engine/engine.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import functools
import logging
import math
import time
import warnings
import weakref
Expand Down Expand Up @@ -510,7 +511,7 @@ def load_state_dict(self, state_dict: Mapping) -> None:
If `engine.state_dict_user_keys` contains keys, they should be also present in the state dictionary.
Iteration and epoch values are 0-based: the first iteration or epoch is zero.
This method does not remove any custom attributs added by user.
This method does not remove any custom attributes added by user.
Args:
state_dict (Mapping): a dict with parameters
Expand Down Expand Up @@ -557,7 +558,13 @@ def load_state_dict(self, state_dict: Mapping) -> None:

@staticmethod
def _is_done(state: State) -> bool:
    """Tell whether the run described by ``state`` has nothing left to execute.

    The run is considered done as soon as any one of these limits is reached:

    - ``state.max_iters`` is set and ``state.iteration`` has reached it;
    - both ``state.epoch_length`` and ``state.max_epochs`` are set and
      ``state.iteration`` has reached ``epoch_length * max_epochs``;
    - ``state.max_epochs`` is set and ``state.epoch`` has reached it.

    Args:
        state (State): engine state to inspect.

    Returns:
        bool: True if at least one of the limits above is reached, False otherwise.
    """
    is_done_iters = state.max_iters is not None and state.iteration >= state.max_iters
    # Both factors can be None (epoch length not yet known, or a run bounded only by
    # max_iters), so guard each of them before multiplying.
    is_done_count = (
        state.epoch_length is not None
        and state.max_epochs is not None
        and state.iteration >= state.epoch_length * state.max_epochs
    )
    is_done_epochs = state.max_epochs is not None and state.epoch >= state.max_epochs
    return is_done_iters or is_done_count or is_done_epochs

def set_data(self, data: Union[Iterable, DataLoader]) -> None:
"""Method to set data. After calling the method the next batch passed to `processing_function` is
Expand Down Expand Up @@ -595,13 +602,19 @@ def switch_dataloader():
self.state.dataloader = data
self._dataloader_iter = iter(self.state.dataloader)

def run(self, data: Iterable, max_epochs: Optional[int] = None, epoch_length: Optional[int] = None,) -> State:
def run(
self,
data: Iterable,
max_epochs: Optional[int] = None,
max_iters: Optional[int] = None,
epoch_length: Optional[int] = None,
) -> State:
"""Runs the `process_function` over the passed data.
Engine has a state and the following logic is applied in this function:
- At the first call, new state is defined by `max_epochs`, `epoch_length` if provided. A timer for
total and per-epoch time is initialized when Events.STARTED is handled.
- At the first call, new state is defined by `max_epochs`, `max_iters`, `epoch_length`, if provided.
A timer for total and per-epoch time is initialized when Events.STARTED is handled.
- If state is already defined such that there are iterations to run until `max_epochs` and no input arguments
provided, state is kept and used in the function.
- If state is defined and engine is "done" (no iterations to run until `max_epochs`), a new state is defined.
Expand All @@ -617,6 +630,8 @@ def run(self, data: Iterable, max_epochs: Optional[int] = None, epoch_length: Op
`len(data)`. If `data` is an iterator and `epoch_length` is not set, then it will be automatically
determined as the iteration on which data iterator raises `StopIteration`.
This argument should not change if run is resuming from a state.
max_iters (int, optional): Number of iterations to run for.
`max_iters` and `max_epochs` are mutually exclusive; only one of the two arguments should be provided.
Returns:
State: output state.
Expand Down Expand Up @@ -670,16 +685,27 @@ def switch_batch(engine):

if self.state.max_epochs is None or self._is_done(self.state):
# Create new state
if max_epochs is None:
max_epochs = 1
if epoch_length is None:
epoch_length = self._get_data_length(data)
if epoch_length is not None and epoch_length < 1:
raise ValueError("Input data has zero size. Please provide non-empty data")

if max_iters is None:
if max_epochs is None:
max_epochs = 1
else:
if max_epochs is not None:
raise ValueError(
"Arguments max_iters and max_epochs are mutually exclusive."
"Please provide only max_epochs or max_iters."
)
if epoch_length is not None:
max_epochs = math.ceil(max_iters / epoch_length)

self.state.iteration = 0
self.state.epoch = 0
self.state.max_epochs = max_epochs
self.state.max_iters = max_iters
self.state.epoch_length = epoch_length
self.logger.info("Engine run starting with max_epochs={}.".format(max_epochs))
else:
Expand Down Expand Up @@ -726,7 +752,7 @@ def _internal_run(self) -> State:
try:
start_time = time.time()
self._fire_event(Events.STARTED)
while self.state.epoch < self.state.max_epochs and not self.should_terminate: # type: ignore[operator]
while not self._is_done(self.state) and not self.should_terminate:
self.state.epoch += 1
self._fire_event(Events.EPOCH_STARTED)

Expand Down Expand Up @@ -800,6 +826,8 @@ def _run_once_on_dataset(self) -> float:
if self.state.epoch_length is None:
# Define epoch length and stop the epoch
self.state.epoch_length = iter_counter
if self.state.max_iters is not None:
self.state.max_epochs = math.ceil(self.state.max_iters / self.state.epoch_length)
break

# Should exit while loop if we can not iterate
Expand Down Expand Up @@ -839,6 +867,10 @@ def _run_once_on_dataset(self) -> float:
if self.state.epoch_length is not None and iter_counter == self.state.epoch_length:
break

if self.state.max_iters is not None and self.state.iteration == self.state.max_iters:
self.should_terminate = True
break

except Exception as e:
self.logger.error("Current run is terminating due to exception: %s.", str(e))
self._handle_exception(e)
Expand Down
2 changes: 2 additions & 0 deletions ignite/engine/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,6 +344,7 @@ class State:
state.dataloader # data passed to engine
state.epoch_length # optional length of an epoch
state.max_epochs # number of epochs to run
state.max_iters # number of iterations to run
state.batch # batch passed to `process_function`
state.output # output of `process_function` after a single iteration
state.metrics # dictionary with defined metrics if any
Expand All @@ -368,6 +369,7 @@ def __init__(self, **kwargs: Any) -> None:
self.epoch = 0
self.epoch_length = None # type: Optional[int]
self.max_epochs = None # type: Optional[int]
self.max_iters = None # type: Optional[int]
self.output = None # type: Optional[int]
self.batch = None # type: Optional[int]
self.metrics = {} # type: Dict[str, Any]
Expand Down
46 changes: 46 additions & 0 deletions tests/ignite/engine/test_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -891,3 +891,49 @@ def switch_dataloader():
trainer.set_data(data2)

trainer.run(data1, max_epochs=10)


def test_run_with_max_iters():
    """A run bounded only by ``max_iters`` ends at that iteration and records the limit."""
    iteration_budget = 8
    samples = [0] * 20

    trainer = Engine(lambda e, b: 1)
    trainer.run(samples, max_iters=iteration_budget)

    final_state = trainer.state
    assert final_state.iteration == iteration_budget
    assert final_state.max_iters == iteration_budget


def test_run_with_max_iters_greater_than_epoch_length():
    """``max_iters`` larger than the data length (73 vs. 20 items) is still reached exactly."""
    iteration_budget = 73
    samples = [0] * 20

    trainer = Engine(lambda e, b: 1)
    trainer.run(samples, max_iters=iteration_budget)

    assert trainer.state.iteration == iteration_budget


def test_run_with_invalid_max_iters_and_max_epoch():
    """Passing both ``max_iters`` and ``max_epochs`` to ``run`` raises ``ValueError``."""
    # Regex checked against the raised error's message.
    expected_message = (
        r"Arguments max_iters and max_epochs are mutually exclusive."
        "Please provide only max_epochs or max_iters."
    )
    samples = [0] * 20

    trainer = Engine(lambda e, b: 1)
    with pytest.raises(ValueError, match=expected_message):
        trainer.run(samples, max_iters=12, max_epochs=2)


def test_epoch_events_fired():
    """Whenever EPOCH_COMPLETED is handled, the iteration count is a multiple of the epoch length."""
    trainer = Engine(lambda e, b: 1)

    @trainer.on(Events.EPOCH_COMPLETED)
    def check_epoch_boundary(engine):
        current = engine.state
        assert current.iteration % current.epoch_length == 0

    # 32 iterations over 10 items per pass: the limit falls in the middle of an epoch.
    trainer.run([0] * 10, max_iters=32)


def test_is_done_with_max_iters():
    """``Engine._is_done`` is False below ``max_iters`` and True once the iteration reaches it."""
    shared_fields = dict(epoch=1, max_epochs=3, epoch_length=100, max_iters=250)

    in_progress = State(iteration=100, **shared_fields)
    assert not Engine._is_done(in_progress)

    finished = State(iteration=250, **shared_fields)
    assert Engine._is_done(finished)

0 comments on commit 307ac11

Please sign in to comment.