Skip to content

Commit

Permalink
Merge branch 'master' into device-index-warning
Browse files Browse the repository at this point in the history
  • Loading branch information
sdesrozis committed Oct 2, 2020
2 parents 90a4d02 + a6664de commit 2f0428b
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 1 deletion.
17 changes: 16 additions & 1 deletion ignite/engine/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -639,6 +639,18 @@ def run(
def switch_batch(engine):
engine.state.batch = preprocess_batch(engine.state.batch)
Restart the training from the beginning. User can reset `max_epochs = None` or either call
`trainer.state.restart()`:
.. code-block:: python
# ...
trainer.run(train_loader, max_epochs=5)
# Reset model weights etc. and restart the training
trainer.state.restart() # equivalent to trainer.state.max_epochs = None
trainer.run(train_loader, max_epochs=2)
"""
if seed is not None:
warnings.warn(
Expand All @@ -655,7 +667,10 @@ def switch_batch(engine):
if max_epochs < self.state.epoch:
raise ValueError(
"Argument max_epochs should be larger than the start epoch "
"defined in the state: {} vs {}".format(max_epochs, self.state.epoch)
"defined in the state: {} vs {}. Please, call state.restart() "
"before calling engine.run() in order to restart the training from the beginning.".format(
max_epochs, self.state.epoch
)
)
self.state.max_epochs = max_epochs
if epoch_length is not None:
Expand Down
3 changes: 3 additions & 0 deletions ignite/engine/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,6 +341,9 @@ def get_event_attrib_value(self, event_name: Union[CallableEventWithFilter, Enum
raise RuntimeError("Unknown event name '{}'".format(event_name))
return getattr(self, State.event_to_attr[event_name])

def restart(self) -> None:
self.max_epochs = None

def __repr__(self) -> str:
s = "State:\n"
for attr, value in self.__dict__.items():
Expand Down
15 changes: 15 additions & 0 deletions tests/ignite/engine/test_engine_state_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,3 +267,18 @@ def check_custom_attr():

_test()
_test(with_load_state_dict=True)


def test_restart_training():
data = range(10)
engine = Engine(lambda e, b: 1)
state = engine.run(data, max_epochs=5)
with pytest.raises(
ValueError,
match=r"Argument max_epochs should be larger than the start epoch defined in the state: 2 vs 5. "
r"Please, call state.restart\(\) "
r"before calling engine.run\(\) in order to restart the training from the beginning.",
):
state = engine.run(data, max_epochs=2)
state.restart()
engine.run(data, max_epochs=2)

0 comments on commit 2f0428b

Please sign in to comment.