Skip to content

Commit

Permalink
Removed state.restart method (#1385)
Browse files Browse the repository at this point in the history
  • Loading branch information
vfdev-5 authored Oct 14, 2020
1 parent 97ee8f5 commit 260017d
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 9 deletions.
7 changes: 3 additions & 4 deletions ignite/engine/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -639,16 +639,15 @@ def run(
def switch_batch(engine):
engine.state.batch = preprocess_batch(engine.state.batch)
Restart the training from the beginning. User can reset `max_epochs = None` or either call
`trainer.state.restart()`:
Restart the training from the beginning. User can reset `max_epochs = None`:
.. code-block:: python
# ...
trainer.run(train_loader, max_epochs=5)
# Reset model weights etc. and restart the training
trainer.state.restart() # equivalent to trainer.state.max_epochs = None
trainer.state.max_epochs = None
trainer.run(train_loader, max_epochs=2)
"""
Expand All @@ -667,7 +666,7 @@ def switch_batch(engine):
if max_epochs < self.state.epoch:
raise ValueError(
"Argument max_epochs should be larger than the start epoch "
"defined in the state: {} vs {}. Please, call state.restart() "
"defined in the state: {} vs {}. Please, set engine.state.max_epochs = None "
"before calling engine.run() in order to restart the training from the beginning.".format(
max_epochs, self.state.epoch
)
Expand Down
3 changes: 0 additions & 3 deletions ignite/engine/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -396,9 +396,6 @@ def get_event_attrib_value(self, event_name: Union[CallableEventWithFilter, Enum
raise RuntimeError("Unknown event name '{}'".format(event_name))
return getattr(self, State.event_to_attr[event_name])

def restart(self) -> None:
self.max_epochs = None

def __repr__(self) -> str:
s = "State:\n"
for attr, value in self.__dict__.items():
Expand Down
4 changes: 2 additions & 2 deletions tests/ignite/engine/test_engine_state_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,9 +276,9 @@ def test_restart_training():
with pytest.raises(
ValueError,
match=r"Argument max_epochs should be larger than the start epoch defined in the state: 2 vs 5. "
r"Please, call state.restart\(\) "
r"Please, .+ "
r"before calling engine.run\(\) in order to restart the training from the beginning.",
):
state = engine.run(data, max_epochs=2)
state.restart()
state.max_epochs = None
engine.run(data, max_epochs=2)

0 comments on commit 260017d

Please sign in to comment.