-
-
Notifications
You must be signed in to change notification settings - Fork 617
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
This reverts commit b829473.
- Loading branch information
Showing
5 changed files
with
304 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,125 @@ | ||
import warnings | ||
|
||
from ignite.engine import EventEnum, Events, State | ||
|
||
|
||
class CustomPeriodicEvent: | ||
"""DEPRECATED. Use filtered events instead. | ||
Handler to define a custom periodic events as a number of elapsed iterations/epochs | ||
for an engine. | ||
When custom periodic event is created and attached to an engine, the following events are fired: | ||
1) K iterations is specified: | ||
- `Events.ITERATIONS_<K>_STARTED` | ||
- `Events.ITERATIONS_<K>_COMPLETED` | ||
1) K epochs is specified: | ||
- `Events.EPOCHS_<K>_STARTED` | ||
- `Events.EPOCHS_<K>_COMPLETED` | ||
Examples: | ||
.. code-block:: python | ||
from ignite.engine import Engine, Events | ||
from ignite.contrib.handlers import CustomPeriodicEvent | ||
# Let's define an event every 1000 iterations | ||
cpe1 = CustomPeriodicEvent(n_iterations=1000) | ||
cpe1.attach(trainer) | ||
# Let's define an event every 10 epochs | ||
cpe2 = CustomPeriodicEvent(n_epochs=10) | ||
cpe2.attach(trainer) | ||
@trainer.on(cpe1.Events.ITERATIONS_1000_COMPLETED) | ||
def on_every_1000_iterations(engine): | ||
# run a computation after 1000 iterations | ||
# ... | ||
print(engine.state.iterations_1000) | ||
@trainer.on(cpe2.Events.EPOCHS_10_STARTED) | ||
def on_every_10_epochs(engine): | ||
# run a computation every 10 epochs | ||
# ... | ||
print(engine.state.epochs_10) | ||
Args: | ||
n_iterations (int, optional): number iterations of the custom periodic event | ||
n_epochs (int, optional): number iterations of the custom periodic event. Argument is optional, but only one, | ||
either n_iterations or n_epochs should defined. | ||
""" | ||
|
||
def __init__(self, n_iterations=None, n_epochs=None): | ||
|
||
warnings.warn( | ||
"CustomPeriodicEvent is deprecated since 0.4.0 and will be removed in 0.5.0. Use filtered events instead.", | ||
DeprecationWarning, | ||
) | ||
|
||
if n_iterations is not None: | ||
if not isinstance(n_iterations, int): | ||
raise TypeError("Argument n_iterations should be an integer") | ||
if n_iterations < 1: | ||
raise ValueError("Argument n_iterations should be positive") | ||
|
||
if n_epochs is not None: | ||
if not isinstance(n_epochs, int): | ||
raise TypeError("Argument n_epochs should be an integer") | ||
if n_epochs < 1: | ||
raise ValueError("Argument n_epochs should be positive") | ||
|
||
if (n_iterations is None and n_epochs is None) or (n_iterations and n_epochs): | ||
raise ValueError("Either n_iterations or n_epochs should be defined") | ||
|
||
if n_iterations: | ||
prefix = "iterations" | ||
self.state_attr = "iteration" | ||
self.period = n_iterations | ||
|
||
if n_epochs: | ||
prefix = "epochs" | ||
self.state_attr = "epoch" | ||
self.period = n_epochs | ||
|
||
self.custom_state_attr = "{}_{}".format(prefix, self.period) | ||
event_name = "{}_{}".format(prefix.upper(), self.period) | ||
setattr( | ||
self, | ||
"Events", | ||
EventEnum("Events", " ".join(["{}_STARTED".format(event_name), "{}_COMPLETED".format(event_name)])), | ||
) | ||
|
||
# Update State.event_to_attr | ||
for e in self.Events: | ||
State.event_to_attr[e] = self.custom_state_attr | ||
|
||
# Create aliases | ||
self._periodic_event_started = getattr(self.Events, "{}_STARTED".format(event_name)) | ||
self._periodic_event_completed = getattr(self.Events, "{}_COMPLETED".format(event_name)) | ||
|
||
def _on_started(self, engine): | ||
setattr(engine.state, self.custom_state_attr, 0) | ||
|
||
def _on_periodic_event_started(self, engine): | ||
if getattr(engine.state, self.state_attr) % self.period == 1: | ||
setattr(engine.state, self.custom_state_attr, getattr(engine.state, self.custom_state_attr) + 1) | ||
engine.fire_event(self._periodic_event_started) | ||
|
||
def _on_periodic_event_completed(self, engine): | ||
if getattr(engine.state, self.state_attr) % self.period == 0: | ||
engine.fire_event(self._periodic_event_completed) | ||
|
||
def attach(self, engine): | ||
engine.register_events(*self.Events) | ||
|
||
engine.add_event_handler(Events.STARTED, self._on_started) | ||
engine.add_event_handler( | ||
getattr(Events, "{}_STARTED".format(self.state_attr.upper())), self._on_periodic_event_started | ||
) | ||
engine.add_event_handler( | ||
getattr(Events, "{}_COMPLETED".format(self.state_attr.upper())), self._on_periodic_event_completed | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,133 @@ | ||
import math | ||
|
||
import pytest | ||
|
||
from ignite.contrib.handlers.custom_events import CustomPeriodicEvent | ||
from ignite.engine import Engine | ||
|
||
|
||
def test_bad_input(): | ||
|
||
with pytest.warns(DeprecationWarning, match=r"CustomPeriodicEvent is deprecated"): | ||
with pytest.raises(TypeError, match="Argument n_iterations should be an integer"): | ||
CustomPeriodicEvent(n_iterations="a") | ||
with pytest.raises(ValueError, match="Argument n_iterations should be positive"): | ||
CustomPeriodicEvent(n_iterations=0) | ||
with pytest.raises(TypeError, match="Argument n_iterations should be an integer"): | ||
CustomPeriodicEvent(n_iterations=10.0) | ||
with pytest.raises(TypeError, match="Argument n_epochs should be an integer"): | ||
CustomPeriodicEvent(n_epochs="a") | ||
with pytest.raises(ValueError, match="Argument n_epochs should be positive"): | ||
CustomPeriodicEvent(n_epochs=0) | ||
with pytest.raises(TypeError, match="Argument n_epochs should be an integer"): | ||
CustomPeriodicEvent(n_epochs=10.0) | ||
with pytest.raises(ValueError, match="Either n_iterations or n_epochs should be defined"): | ||
CustomPeriodicEvent() | ||
with pytest.raises(ValueError, match="Either n_iterations or n_epochs should be defined"): | ||
CustomPeriodicEvent(n_iterations=1, n_epochs=2) | ||
|
||
|
||
def test_new_events(): | ||
def update(*args, **kwargs): | ||
pass | ||
|
||
with pytest.warns(DeprecationWarning, match="CustomPeriodicEvent is deprecated"): | ||
engine = Engine(update) | ||
cpe = CustomPeriodicEvent(n_iterations=5) | ||
cpe.attach(engine) | ||
|
||
assert hasattr(cpe, "Events") | ||
assert hasattr(cpe.Events, "ITERATIONS_5_STARTED") | ||
assert hasattr(cpe.Events, "ITERATIONS_5_COMPLETED") | ||
|
||
assert engine._allowed_events[-2] == getattr(cpe.Events, "ITERATIONS_5_STARTED") | ||
assert engine._allowed_events[-1] == getattr(cpe.Events, "ITERATIONS_5_COMPLETED") | ||
|
||
with pytest.warns(DeprecationWarning, match="CustomPeriodicEvent is deprecated"): | ||
cpe = CustomPeriodicEvent(n_epochs=5) | ||
cpe.attach(engine) | ||
|
||
assert hasattr(cpe, "Events") | ||
assert hasattr(cpe.Events, "EPOCHS_5_STARTED") | ||
assert hasattr(cpe.Events, "EPOCHS_5_COMPLETED") | ||
|
||
assert engine._allowed_events[-2] == getattr(cpe.Events, "EPOCHS_5_STARTED") | ||
assert engine._allowed_events[-1] == getattr(cpe.Events, "EPOCHS_5_COMPLETED") | ||
|
||
|
||
def test_integration_iterations(): | ||
def _test(n_iterations, max_epochs, n_iters_per_epoch): | ||
def update(*args, **kwargs): | ||
pass | ||
|
||
engine = Engine(update) | ||
with pytest.warns(DeprecationWarning, match="CustomPeriodicEvent is deprecated"): | ||
cpe = CustomPeriodicEvent(n_iterations=n_iterations) | ||
cpe.attach(engine) | ||
data = list(range(n_iters_per_epoch)) | ||
|
||
custom_period = [0] | ||
n_calls_iter_started = [0] | ||
n_calls_iter_completed = [0] | ||
|
||
event_started = getattr(cpe.Events, "ITERATIONS_{}_STARTED".format(n_iterations)) | ||
|
||
@engine.on(event_started) | ||
def on_my_event_started(engine): | ||
assert (engine.state.iteration - 1) % n_iterations == 0 | ||
custom_period[0] += 1 | ||
custom_iter = getattr(engine.state, "iterations_{}".format(n_iterations)) | ||
assert custom_iter == custom_period[0] | ||
n_calls_iter_started[0] += 1 | ||
|
||
event_completed = getattr(cpe.Events, "ITERATIONS_{}_COMPLETED".format(n_iterations)) | ||
|
||
@engine.on(event_completed) | ||
def on_my_event_ended(engine): | ||
assert engine.state.iteration % n_iterations == 0 | ||
custom_iter = getattr(engine.state, "iterations_{}".format(n_iterations)) | ||
assert custom_iter == custom_period[0] | ||
n_calls_iter_completed[0] += 1 | ||
|
||
engine.run(data, max_epochs=max_epochs) | ||
|
||
n = len(data) * max_epochs / n_iterations | ||
nf = math.floor(n) | ||
assert custom_period[0] == n_calls_iter_started[0] | ||
assert n_calls_iter_started[0] == nf + 1 if nf < n else nf | ||
assert n_calls_iter_completed[0] == nf | ||
|
||
_test(3, 5, 16) | ||
_test(4, 5, 16) | ||
_test(5, 5, 16) | ||
_test(300, 50, 1000) | ||
|
||
|
||
def test_integration_epochs(): | ||
def update(*args, **kwargs): | ||
pass | ||
|
||
engine = Engine(update) | ||
|
||
n_epochs = 3 | ||
with pytest.warns(DeprecationWarning, match="CustomPeriodicEvent is deprecated"): | ||
cpe = CustomPeriodicEvent(n_epochs=n_epochs) | ||
cpe.attach(engine) | ||
data = list(range(16)) | ||
|
||
custom_period = [1] | ||
|
||
@engine.on(cpe.Events.EPOCHS_3_STARTED) | ||
def on_my_epoch_started(engine): | ||
assert (engine.state.epoch - 1) % n_epochs == 0 | ||
assert engine.state.epochs_3 == custom_period[0] | ||
|
||
@engine.on(cpe.Events.EPOCHS_3_COMPLETED) | ||
def on_my_epoch_ended(engine): | ||
assert engine.state.epoch % n_epochs == 0 | ||
assert engine.state.epochs_3 == custom_period[0] | ||
custom_period[0] += 1 | ||
|
||
engine.run(data, max_epochs=10) | ||
|
||
assert custom_period[0] == 4 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters