Skip to content

Commit

Permalink
Revert "Issue #1247 (#1252)"
Browse files Browse the repository at this point in the history
This reverts commit b829473.
  • Loading branch information
vfdev-5 committed Jan 17, 2022
1 parent a711700 commit 5ab2eb5
Show file tree
Hide file tree
Showing 5 changed files with 304 additions and 1 deletion.
5 changes: 5 additions & 0 deletions ignite/contrib/handlers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@
from ignite.contrib.handlers.clearml_logger import ClearMLLogger
<<<<<<< HEAD
=======
from ignite.contrib.handlers.custom_events import CustomPeriodicEvent
from ignite.contrib.handlers.lr_finder import FastaiLRFinder
>>>>>>> Revert "Issue #1247 (#1252)"
from ignite.contrib.handlers.mlflow_logger import MLflowLogger
from ignite.contrib.handlers.neptune_logger import NeptuneLogger
from ignite.contrib.handlers.polyaxon_logger import PolyaxonLogger
Expand Down
125 changes: 125 additions & 0 deletions ignite/contrib/handlers/custom_events.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
import warnings

from ignite.engine import EventEnum, Events, State


class CustomPeriodicEvent:
"""DEPRECATED. Use filtered events instead.
Handler to define a custom periodic events as a number of elapsed iterations/epochs
for an engine.
When custom periodic event is created and attached to an engine, the following events are fired:
1) K iterations is specified:
- `Events.ITERATIONS_<K>_STARTED`
- `Events.ITERATIONS_<K>_COMPLETED`
1) K epochs is specified:
- `Events.EPOCHS_<K>_STARTED`
- `Events.EPOCHS_<K>_COMPLETED`
Examples:
.. code-block:: python
from ignite.engine import Engine, Events
from ignite.contrib.handlers import CustomPeriodicEvent
# Let's define an event every 1000 iterations
cpe1 = CustomPeriodicEvent(n_iterations=1000)
cpe1.attach(trainer)
# Let's define an event every 10 epochs
cpe2 = CustomPeriodicEvent(n_epochs=10)
cpe2.attach(trainer)
@trainer.on(cpe1.Events.ITERATIONS_1000_COMPLETED)
def on_every_1000_iterations(engine):
# run a computation after 1000 iterations
# ...
print(engine.state.iterations_1000)
@trainer.on(cpe2.Events.EPOCHS_10_STARTED)
def on_every_10_epochs(engine):
# run a computation every 10 epochs
# ...
print(engine.state.epochs_10)
Args:
n_iterations (int, optional): number iterations of the custom periodic event
n_epochs (int, optional): number iterations of the custom periodic event. Argument is optional, but only one,
either n_iterations or n_epochs should defined.
"""

def __init__(self, n_iterations=None, n_epochs=None):

warnings.warn(
"CustomPeriodicEvent is deprecated since 0.4.0 and will be removed in 0.5.0. Use filtered events instead.",
DeprecationWarning,
)

if n_iterations is not None:
if not isinstance(n_iterations, int):
raise TypeError("Argument n_iterations should be an integer")
if n_iterations < 1:
raise ValueError("Argument n_iterations should be positive")

if n_epochs is not None:
if not isinstance(n_epochs, int):
raise TypeError("Argument n_epochs should be an integer")
if n_epochs < 1:
raise ValueError("Argument n_epochs should be positive")

if (n_iterations is None and n_epochs is None) or (n_iterations and n_epochs):
raise ValueError("Either n_iterations or n_epochs should be defined")

if n_iterations:
prefix = "iterations"
self.state_attr = "iteration"
self.period = n_iterations

if n_epochs:
prefix = "epochs"
self.state_attr = "epoch"
self.period = n_epochs

self.custom_state_attr = "{}_{}".format(prefix, self.period)
event_name = "{}_{}".format(prefix.upper(), self.period)
setattr(
self,
"Events",
EventEnum("Events", " ".join(["{}_STARTED".format(event_name), "{}_COMPLETED".format(event_name)])),
)

# Update State.event_to_attr
for e in self.Events:
State.event_to_attr[e] = self.custom_state_attr

# Create aliases
self._periodic_event_started = getattr(self.Events, "{}_STARTED".format(event_name))
self._periodic_event_completed = getattr(self.Events, "{}_COMPLETED".format(event_name))

def _on_started(self, engine):
setattr(engine.state, self.custom_state_attr, 0)

def _on_periodic_event_started(self, engine):
if getattr(engine.state, self.state_attr) % self.period == 1:
setattr(engine.state, self.custom_state_attr, getattr(engine.state, self.custom_state_attr) + 1)
engine.fire_event(self._periodic_event_started)

def _on_periodic_event_completed(self, engine):
if getattr(engine.state, self.state_attr) % self.period == 0:
engine.fire_event(self._periodic_event_completed)

def attach(self, engine):
engine.register_events(*self.Events)

engine.add_event_handler(Events.STARTED, self._on_started)
engine.add_event_handler(
getattr(Events, "{}_STARTED".format(self.state_attr.upper())), self._on_periodic_event_started
)
engine.add_event_handler(
getattr(Events, "{}_COMPLETED".format(self.state_attr.upper())), self._on_periodic_event_completed
)
29 changes: 29 additions & 0 deletions tests/ignite/contrib/handlers/test_base_logger.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from unittest.mock import MagicMock, call

import math
import pytest
import torch

from ignite.contrib.handlers import CustomPeriodicEvent
from ignite.contrib.handlers.base_logger import BaseLogger, BaseOptimizerParamsHandler, BaseOutputHandler
from ignite.engine import Engine, Events, EventsList, State
from tests.ignite.contrib.handlers import MockFP16DeepSpeedZeroOptimizer
Expand Down Expand Up @@ -242,6 +244,33 @@ def update_fn(engine, batch):
mock_log_handler.assert_called_with(trainer, logger, event)
assert mock_log_handler.call_count == n_calls

with pytest.warns(DeprecationWarning, match="CustomPeriodicEvent is deprecated"):
n_iterations = 10
cpe1 = CustomPeriodicEvent(n_iterations=n_iterations)
n = len(data) * n_epochs / n_iterations
nf = math.floor(n)
ns = nf + 1 if nf < n else nf
_test(cpe1.Events.ITERATIONS_10_STARTED, ns, cpe1)
_test(cpe1.Events.ITERATIONS_10_COMPLETED, nf, cpe1)

with pytest.warns(DeprecationWarning, match="CustomPeriodicEvent is deprecated"):
n_iterations = 15
cpe2 = CustomPeriodicEvent(n_iterations=n_iterations)
n = len(data) * n_epochs / n_iterations
nf = math.floor(n)
ns = nf + 1 if nf < n else nf
_test(cpe2.Events.ITERATIONS_15_STARTED, ns, cpe2)
_test(cpe2.Events.ITERATIONS_15_COMPLETED, nf, cpe2)

with pytest.warns(DeprecationWarning, match="CustomPeriodicEvent is deprecated"):
n_custom_epochs = 2
cpe3 = CustomPeriodicEvent(n_epochs=n_custom_epochs)
n = n_epochs / n_custom_epochs
nf = math.floor(n)
ns = nf + 1 if nf < n else nf
_test(cpe3.Events.EPOCHS_2_STARTED, ns, cpe3)
_test(cpe3.Events.EPOCHS_2_COMPLETED, nf, cpe3)


def test_as_context_manager():

Expand Down
133 changes: 133 additions & 0 deletions tests/ignite/contrib/handlers/test_custom_events.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
import math

import pytest

from ignite.contrib.handlers.custom_events import CustomPeriodicEvent
from ignite.engine import Engine


def test_bad_input():

with pytest.warns(DeprecationWarning, match=r"CustomPeriodicEvent is deprecated"):
with pytest.raises(TypeError, match="Argument n_iterations should be an integer"):
CustomPeriodicEvent(n_iterations="a")
with pytest.raises(ValueError, match="Argument n_iterations should be positive"):
CustomPeriodicEvent(n_iterations=0)
with pytest.raises(TypeError, match="Argument n_iterations should be an integer"):
CustomPeriodicEvent(n_iterations=10.0)
with pytest.raises(TypeError, match="Argument n_epochs should be an integer"):
CustomPeriodicEvent(n_epochs="a")
with pytest.raises(ValueError, match="Argument n_epochs should be positive"):
CustomPeriodicEvent(n_epochs=0)
with pytest.raises(TypeError, match="Argument n_epochs should be an integer"):
CustomPeriodicEvent(n_epochs=10.0)
with pytest.raises(ValueError, match="Either n_iterations or n_epochs should be defined"):
CustomPeriodicEvent()
with pytest.raises(ValueError, match="Either n_iterations or n_epochs should be defined"):
CustomPeriodicEvent(n_iterations=1, n_epochs=2)


def test_new_events():
def update(*args, **kwargs):
pass

with pytest.warns(DeprecationWarning, match="CustomPeriodicEvent is deprecated"):
engine = Engine(update)
cpe = CustomPeriodicEvent(n_iterations=5)
cpe.attach(engine)

assert hasattr(cpe, "Events")
assert hasattr(cpe.Events, "ITERATIONS_5_STARTED")
assert hasattr(cpe.Events, "ITERATIONS_5_COMPLETED")

assert engine._allowed_events[-2] == getattr(cpe.Events, "ITERATIONS_5_STARTED")
assert engine._allowed_events[-1] == getattr(cpe.Events, "ITERATIONS_5_COMPLETED")

with pytest.warns(DeprecationWarning, match="CustomPeriodicEvent is deprecated"):
cpe = CustomPeriodicEvent(n_epochs=5)
cpe.attach(engine)

assert hasattr(cpe, "Events")
assert hasattr(cpe.Events, "EPOCHS_5_STARTED")
assert hasattr(cpe.Events, "EPOCHS_5_COMPLETED")

assert engine._allowed_events[-2] == getattr(cpe.Events, "EPOCHS_5_STARTED")
assert engine._allowed_events[-1] == getattr(cpe.Events, "EPOCHS_5_COMPLETED")


def test_integration_iterations():
def _test(n_iterations, max_epochs, n_iters_per_epoch):
def update(*args, **kwargs):
pass

engine = Engine(update)
with pytest.warns(DeprecationWarning, match="CustomPeriodicEvent is deprecated"):
cpe = CustomPeriodicEvent(n_iterations=n_iterations)
cpe.attach(engine)
data = list(range(n_iters_per_epoch))

custom_period = [0]
n_calls_iter_started = [0]
n_calls_iter_completed = [0]

event_started = getattr(cpe.Events, "ITERATIONS_{}_STARTED".format(n_iterations))

@engine.on(event_started)
def on_my_event_started(engine):
assert (engine.state.iteration - 1) % n_iterations == 0
custom_period[0] += 1
custom_iter = getattr(engine.state, "iterations_{}".format(n_iterations))
assert custom_iter == custom_period[0]
n_calls_iter_started[0] += 1

event_completed = getattr(cpe.Events, "ITERATIONS_{}_COMPLETED".format(n_iterations))

@engine.on(event_completed)
def on_my_event_ended(engine):
assert engine.state.iteration % n_iterations == 0
custom_iter = getattr(engine.state, "iterations_{}".format(n_iterations))
assert custom_iter == custom_period[0]
n_calls_iter_completed[0] += 1

engine.run(data, max_epochs=max_epochs)

n = len(data) * max_epochs / n_iterations
nf = math.floor(n)
assert custom_period[0] == n_calls_iter_started[0]
assert n_calls_iter_started[0] == nf + 1 if nf < n else nf
assert n_calls_iter_completed[0] == nf

_test(3, 5, 16)
_test(4, 5, 16)
_test(5, 5, 16)
_test(300, 50, 1000)


def test_integration_epochs():
def update(*args, **kwargs):
pass

engine = Engine(update)

n_epochs = 3
with pytest.warns(DeprecationWarning, match="CustomPeriodicEvent is deprecated"):
cpe = CustomPeriodicEvent(n_epochs=n_epochs)
cpe.attach(engine)
data = list(range(16))

custom_period = [1]

@engine.on(cpe.Events.EPOCHS_3_STARTED)
def on_my_epoch_started(engine):
assert (engine.state.epoch - 1) % n_epochs == 0
assert engine.state.epochs_3 == custom_period[0]

@engine.on(cpe.Events.EPOCHS_3_COMPLETED)
def on_my_epoch_ended(engine):
assert engine.state.epoch % n_epochs == 0
assert engine.state.epochs_3 == custom_period[0]
custom_period[0] += 1

engine.run(data, max_epochs=10)

assert custom_period[0] == 4
13 changes: 12 additions & 1 deletion tests/ignite/contrib/handlers/test_tqdm_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import pytest
import torch

from ignite.contrib.handlers import ProgressBar
from ignite.contrib.handlers import CustomPeriodicEvent, ProgressBar
from ignite.engine import Engine, Events
from ignite.handlers import TerminateOnNan
from ignite.metrics import RunningAverage
Expand Down Expand Up @@ -475,6 +475,17 @@ def test_pbar_wrong_events_order():
pbar.attach(engine, event_name=Events.ITERATION_STARTED, closing_event_name=Events.EPOCH_COMPLETED(every=10))


def test_pbar_on_custom_events(capsys):

engine = Engine(update_fn)
pbar = ProgressBar()
with pytest.warns(DeprecationWarning, match="CustomPeriodicEvent is deprecated"):
cpe = CustomPeriodicEvent(n_iterations=15)

with pytest.raises(ValueError, match=r"not in allowed events for this engine"):
pbar.attach(engine, event_name=cpe.Events.ITERATIONS_15_COMPLETED, closing_event_name=Events.EPOCH_COMPLETED)


def test_pbar_with_nan_input():
def update(engine, batch):
x = batch
Expand Down

0 comments on commit 5ab2eb5

Please sign in to comment.