Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding max_iters as an optional arg in Engine run #1381

Merged
merged 34 commits into from
Nov 9, 2020
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
f661556
initial draft, adding max_iters as optional args in run
thescripted Oct 11, 2020
0919761
fixed typo
thescripted Oct 11, 2020
0389374
minor bug fixes
thescripted Oct 11, 2020
8abbc63
resolving failing tests
thescripted Oct 11, 2020
fdd01a1
fixed out-of-place conditional
thescripted Oct 13, 2020
0db9761
typo fix
thescripted Oct 13, 2020
4bf0a1d
updated docstring for 'run'
thescripted Oct 18, 2020
8ba97cb
added initial tests
thescripted Oct 18, 2020
b52ec66
Merge branch 'master' of https://github.com/pytorch/ignite into featu…
thescripted Oct 22, 2020
da92e4b
(WIP) restructured creating a new state with max_iters
thescripted Oct 23, 2020
b40a893
updated tests & docstrings
thescripted Oct 23, 2020
4bba8cb
initial draft, adding max_iters as optional args in run
thescripted Oct 11, 2020
0d6a4fb
fixed typo
thescripted Oct 11, 2020
f353b05
minor bug fixes
thescripted Oct 11, 2020
548b334
resolving failing tests
thescripted Oct 11, 2020
2cb2393
fixed out-of-place conditional
thescripted Oct 13, 2020
4152796
typo fix
thescripted Oct 13, 2020
afb58fb
updated docstring for 'run'
thescripted Oct 18, 2020
c9eab3f
added initial tests
thescripted Oct 18, 2020
9f4fc06
(WIP) restructured creating a new state with max_iters
thescripted Oct 23, 2020
f29267b
updated tests & docstrings
thescripted Oct 23, 2020
4896079
added test to check _is_done
thescripted Oct 26, 2020
ea96a65
updating engine loop condition
thescripted Oct 26, 2020
d49cd50
Merge branch 'feature.max_iters' of https://github.com/thescripted/ig…
thescripted Oct 26, 2020
87fc6dd
autopep8 fix
thescripted Oct 26, 2020
c17ab6c
linting issues
thescripted Oct 26, 2020
7ea52e8
Merge branch 'feature.max_iters' of https://github.com/thescripted/ig…
thescripted Oct 26, 2020
2a4efa9
Merge branch 'master' into feature.max_iters
vfdev-5 Oct 31, 2020
a2796cb
Merge branch 'master' into feature.max_iters
thescripted Nov 3, 2020
1d1416b
fixed mypy errors
thescripted Nov 3, 2020
e019c0f
fixed formatting
thescripted Nov 3, 2020
7444964
minor fix & add test for larger max_iters
thescripted Nov 9, 2020
25c4aa1
Merge branch 'master' into feature.max_iters
thescripted Nov 9, 2020
c179b9d
removed unused typechecking
thescripted Nov 9, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 33 additions & 6 deletions ignite/engine/engine.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import functools
import logging
import math
import time
import warnings
import weakref
Expand Down Expand Up @@ -508,7 +509,7 @@ def load_state_dict(self, state_dict: Mapping) -> None:
`seed`. If `engine.state_dict_user_keys` contains keys, they should be also present in the state dictionary.
Iteration and epoch values are 0-based: the first iteration or epoch is zero.

This method does not remove any custom attributs added by user.
This method does not remove any custom attributes added by user.

Args:
state_dict (Mapping): a dict with parameters
Expand Down Expand Up @@ -597,15 +598,16 @@ def run(
self,
data: Iterable,
max_epochs: Optional[int] = None,
max_iters: Optional[int] = None,
vfdev-5 marked this conversation as resolved.
Show resolved Hide resolved
epoch_length: Optional[int] = None,
seed: Optional[int] = None,
) -> State:
"""Runs the `process_function` over the passed data.

Engine has a state and the following logic is applied in this function:

- At the first call, new state is defined by `max_epochs`, `epoch_length`, `seed` if provided. A timer for
total and per-epoch time is initialized when Events.STARTED is handled.
- At the first call, new state is defined by `max_epochs`, `max_iters`, `epoch_length`, `seed` if provided.
thescripted marked this conversation as resolved.
Show resolved Hide resolved
A timer for total and per-epoch time is initialized when Events.STARTED is handled.
- If state is already defined such that there are iterations to run until `max_epochs` and no input arguments
provided, state is kept and used in the function.
- If state is defined and engine is "done" (no iterations to run until `max_epochs`), a new state is defined.
Expand All @@ -621,6 +623,9 @@ def run(
`len(data)`. If `data` is an iterator and `epoch_length` is not set, then it will be automatically
determined as the iteration on which data iterator raises `StopIteration`.
This argument should not change if run is resuming from a state.
max_iters (int, optional): Max Number of iterations to run for. By default, it will be calculated as
max_epochs * epoch_length. If `data` is an iterator and `max_iters` is not set, then it will be
automatically determined along with `epoch_length` on which data iterator raises `StopIteration`.
thescripted marked this conversation as resolved.
Show resolved Hide resolved
seed (int, optional): Deprecated argument. Please, use `torch.manual_seed` or
:meth:`~ignite.utils.manual_seed`.

Expand Down Expand Up @@ -673,6 +678,7 @@ def switch_batch(engine):
)
)
self.state.max_epochs = max_epochs
self.state.max_iters = max_epochs * self.state.epoch_length
thescripted marked this conversation as resolved.
Show resolved Hide resolved
if epoch_length is not None:
if epoch_length != self.state.epoch_length:
raise ValueError(
Expand All @@ -683,16 +689,28 @@ def switch_batch(engine):

if self.state.max_epochs is None or self._is_done(self.state):
# Create new state
if max_epochs is None:
max_epochs = 1
if epoch_length is None:
epoch_length = self._get_data_length(data)
if epoch_length is not None and epoch_length < 1:
raise ValueError("Input data has zero size. Please provide non-empty data")

if epoch_length is not None:
thescripted marked this conversation as resolved.
Show resolved Hide resolved
if max_iters is None:
if max_epochs is None:
max_epochs = 1
max_iters = max_epochs * epoch_length
else:
if max_epochs is None:
max_epochs = math.ceil(max_iters / epoch_length)

else:
if max_epochs is None:
max_epochs = 1

self.state.iteration = 0
self.state.epoch = 0
self.state.max_epochs = max_epochs
self.state.max_iters = max_iters
self.state.epoch_length = epoch_length
self.logger.info("Engine run starting with max_epochs={}.".format(max_epochs))
else:
Expand Down Expand Up @@ -748,8 +766,11 @@ def _internal_run(self) -> State:
handlers_start_time = time.time()
if self.should_terminate:
self._fire_event(Events.TERMINATE)
else:
elif self.state.iteration % self.state.epoch_length == 0:
thescripted marked this conversation as resolved.
Show resolved Hide resolved
self._fire_event(Events.EPOCH_COMPLETED)
elif self.state.max_iters == self.state.iteration:
self._fire_event(Events.TERMINATE)

time_taken += time.time() - handlers_start_time
# update time wrt handlers
self.state.times[Events.EPOCH_COMPLETED.name] = time_taken
Expand Down Expand Up @@ -800,6 +821,8 @@ def _run_once_on_dataset(self) -> float:
if self.state.epoch_length is None:
# Define epoch length and stop the epoch
self.state.epoch_length = iter_counter
if self.state.max_iters is None:
thescripted marked this conversation as resolved.
Show resolved Hide resolved
self.state.max_iters = self.state.epoch_length * self.state.max_epochs
break

# Should exit while loop if we can not iterate
Expand Down Expand Up @@ -838,6 +861,10 @@ def _run_once_on_dataset(self) -> float:
if self.state.epoch_length is not None and iter_counter == self.state.epoch_length:
break

if self.state.max_iters is not None and self.state.iteration == self.state.max_iters:
self.should_terminate = True
break

except Exception as e:
self.logger.error("Current run is terminating due to exception: %s.", str(e))
self._handle_exception(e)
Expand Down
2 changes: 2 additions & 0 deletions ignite/engine/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,6 +350,7 @@ class State:
state.dataloader # data passed to engine
state.epoch_length # optional length of an epoch
state.max_epochs # number of epochs to run
state.max_iter # number of iterations to run
state.batch # batch passed to `process_function`
state.output # output of `process_function` after a single iteration
state.metrics # dictionary with defined metrics if any
Expand All @@ -374,6 +375,7 @@ def __init__(self, **kwargs):
self.epoch = 0
self.epoch_length = None
self.max_epochs = None
self.max_iters = None
self.output = None
self.batch = None
self.metrics = {}
Expand Down
28 changes: 28 additions & 0 deletions tests/ignite/engine/test_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -894,3 +894,31 @@ def switch_dataloader():
trainer.set_data(data2)

trainer.run(data1, max_epochs=10)


def test_run_with_max_iters():
max_iters = 8
thescripted marked this conversation as resolved.
Show resolved Hide resolved
engine = Engine(lambda e, b: 1)
engine.run([0] * 20, max_iters=max_iters)
assert engine.state.iteration == max_iters
assert engine.state.max_iters == max_iters


def test_run_with_max_iters_and_max_epoch():
max_iters = 12
max_epochs = 2
engine = Engine(lambda e, b: 1)
engine.run([0] * 20, max_iters=max_iters, max_epochs=max_epochs)
assert engine.state.iteration == max_iters
assert engine.state.max_epochs == max_epochs


def test_epoch_events_fired():
max_iters = 32
engine = Engine(lambda e, b: 1)

@engine.on(Events.EPOCH_COMPLETED)
def fired_event(engine):
assert engine.state.iteration % engine.state.epoch_length == 0

engine.run([0] * 10, max_iters=max_iters)