Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added debug mode for Engine #2851

Open
wants to merge 47 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
47 commits
Select commit Hold shift + click to select a range
8e73cfd
Update engine.py to resolve #1992
puhuk Feb 6, 2023
b840aa2
autopep8 fix
puhuk Feb 6, 2023
a1ee3ad
Update engine.py
puhuk Feb 6, 2023
4c05843
Merge branch '1992' of https://github.com/puhuk/ignite into 1992
puhuk Feb 7, 2023
32850f3
Correct file and test code
puhuk Feb 7, 2023
7cea859
autopep8 fix
puhuk Feb 7, 2023
36ecbb9
Update test_engine.py
puhuk Feb 13, 2023
b2e5399
Merge branch '1992' of https://github.com/puhuk/ignite into 1992
puhuk Feb 13, 2023
4448d52
Delete test_debug.py
puhuk Feb 13, 2023
2a630db
Update
puhuk Feb 13, 2023
f99fd74
Update engine.py
puhuk Feb 13, 2023
a36842f
Update engine.py
puhuk Feb 13, 2023
ab7d751
update
puhuk Feb 16, 2023
e7754e3
autopep8 fix
puhuk Feb 16, 2023
d685df5
Update engine.py
puhuk Feb 17, 2023
01aba8d
Merge branch '1992' of https://github.com/puhuk/ignite into 1992
puhuk Feb 17, 2023
44d39aa
Update test_engine.py
puhuk Feb 17, 2023
46017f0
Update test_engine.py
puhuk Feb 17, 2023
e942a3c
Update debug mode
puhuk Feb 22, 2023
e77bb0f
Update test_engine.py
puhuk Feb 22, 2023
d1be836
Update
puhuk Feb 22, 2023
b4e2b60
Update events.py
puhuk Feb 22, 2023
eed1689
autopep8 fix
puhuk Feb 22, 2023
e735d07
Update events.py
puhuk Feb 22, 2023
a2155b4
Merge branch '1992' of https://github.com/puhuk/ignite into 1992
puhuk Feb 22, 2023
88a5482
Update events.py
puhuk Feb 22, 2023
7bbedb9
Update events.py
puhuk Feb 22, 2023
fe40ba0
Update engine.py
puhuk Feb 22, 2023
c5f8e07
Update
puhuk Feb 24, 2023
0720841
autopep8 fix
puhuk Feb 24, 2023
fbaf3a4
Update engine.py
puhuk Feb 24, 2023
49f08ed
Merge branch '1992' of https://github.com/puhuk/ignite into 1992
puhuk Feb 24, 2023
a9bc500
Update engine.py
puhuk Feb 24, 2023
0cd306f
Update engine.py
puhuk Feb 24, 2023
3e195f2
Update
puhuk Feb 24, 2023
c704986
autopep8 fix
puhuk Feb 24, 2023
6a5c589
Update engine.py
puhuk Feb 24, 2023
5598a14
Merge branch '1992' of https://github.com/puhuk/ignite into 1992
puhuk Feb 24, 2023
49364d6
Update
puhuk Feb 24, 2023
05668c1
Update
puhuk Feb 24, 2023
c93ba56
Update engine.py
puhuk Feb 24, 2023
eca4f9f
Update engine.py
puhuk Feb 24, 2023
c8bcaa0
autopep8 fix
puhuk Feb 24, 2023
984b3fe
Update engine.py
puhuk Feb 24, 2023
aa902b4
Merge branch '1992' of https://github.com/puhuk/ignite into 1992
puhuk Feb 24, 2023
0f9ee21
Update engine.py
puhuk Feb 24, 2023
90f549f
Update engine.py
puhuk Feb 24, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 41 additions & 0 deletions ignite/engine/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,23 @@ def compute_mean_std(engine, batch):
_state_dict_all_req_keys = ("epoch_length", "max_epochs")
_state_dict_one_of_opt_keys = ("iteration", "epoch")

class debug_mode(EventEnum):
DEBUG_NONE = 0
DEBUG_EVENTS = 1
DEBUG_OUTPUT = 2
DEBUG_GRADS = 4

def __iter__(self) -> Iterator:
return iter(self.name)

def __int__(self) -> str:
return self.value

DEBUG_NONE = debug_mode.DEBUG_NONE
DEBUG_EVENTS = debug_mode.DEBUG_EVENTS
DEBUG_OUTPUT = debug_mode.DEBUG_OUTPUT
DEBUG_GRADS = debug_mode.DEBUG_GRADS

# Flag to disable engine._internal_run as generator feature for BC
interrupt_resume_enabled = True

Expand All @@ -143,6 +160,8 @@ def __init__(self, process_function: Callable[["Engine", Any], Any]):
self._dataloader_iter: Optional[Iterator[Any]] = None
self._init_iter: Optional[int] = None

self.debug_level = 0

self.register_events(*Events)

if self._process_function is None:
Expand Down Expand Up @@ -425,6 +444,28 @@ def _fire_event(self, event_name: Any, *event_args: Any, **event_kwargs: Any) ->
first, others = ((args[0],), args[1:]) if (args and args[0] == self) else ((), args)
func(*first, *(event_args + others), **kwargs)

def debug(self, level: debug_mode = DEBUG_NONE, config: Union[Dict, Any] = None) -> None:
if isinstance(level, int):
raise ValueError(
f"Unknown event name '{level}'. Level should be combinations of Engine.DEBUG_NONE, "
f"Engine.DEBUG_EVENTS, Engine.DEBUG_OUTPUT, Engine.DEBUG_GRADS"
)
self.lr = config["optimizer"].param_groups[0]["lr"]
self.layer = config["layer"]

log = ""
for item in level:
if item == Engine.DEBUG_NONE:
log += ""
elif item == Engine.DEBUG_EVENTS:
log += f"{self.state.epoch} | {self.state.iteration}, Firing handlers for event {self.last_event_name} "
elif item == Engine.DEBUG_OUTPUT:
log += f"Loss : {self.state.output}, LR : {self.lr} "
elif item == Engine.DEBUG_GRADS:
log += f"Gradients : {self.layer.weight.grad} "

self.logger.debug(log)

def fire_event(self, event_name: Any) -> None:
"""Execute all the handlers associated with given event.

Expand Down
1 change: 1 addition & 0 deletions ignite/engine/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,6 +445,7 @@ def __init__(self, **kwargs: Any) -> None:
self.batch: Optional[int] = None
self.metrics: Dict[str, Any] = {}
self.dataloader: Optional[Union[DataLoader, Iterable[Any]]] = None
self.debug_config: Dict[str, Any] = {}
self.seed: Optional[int] = None
self.times: Dict[str, Optional[float]] = {
Events.EPOCH_COMPLETED.name: None,
Expand Down
73 changes: 73 additions & 0 deletions tests/ignite/engine/test_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -1388,3 +1388,76 @@ def check_iter_epoch():
state = engine.run(data, max_epochs=max_epochs)
assert state.iteration == max_epochs * len(data) and state.epoch == max_epochs
assert num_calls_check_iter_epoch == 1


def test_engine_debug():
import torch.nn.functional as F
from torch import nn
from torch.optim import SGD
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision.transforms import Compose, ToTensor

from ignite.engine import create_supervised_trainer

class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
self.conv2_drop = nn.Dropout2d()
self.fc1 = nn.Linear(320, 50)
self.fc2 = nn.Linear(50, 10)

def forward(self, x):
x = F.relu(F.max_pool2d(self.conv1(x), 2))
x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
x = x.view(-1, 320)
x = F.relu(self.fc1(x))
x = F.dropout(x, training=self.training)
x = self.fc2(x)
return F.log_softmax(x, dim=-1)

def _test():
train_loader = DataLoader(
MNIST(download=True, root=".", transform=Compose([ToTensor()]), train=True),
batch_size=64,
shuffle=True,
)

model = Net()
device = "cpu"
log_interval = 10
epochs = 10

if torch.cuda.is_available():
device = "cuda"

model.to(device) # Move model before creating optimizer
optimizer = SGD(model.parameters(), lr=0.01, momentum=0.5)
criterion = nn.NLLLoss()
trainer = create_supervised_trainer(model, optimizer, criterion, device=device)
debug_config = {}
debug_config["optimizer"] = optimizer
debug_config["layer"] = model.fc2

def log_training_debug_events(engine):
trainer.debug(level=Engine.DEBUG_EVENTS, config=debug_config)

def log_training_debug_outputs(engine):
trainer.debug(level=Engine.DEBUG_OUTPUT, config=debug_config)

def log_training_debug_grads(engine):
trainer.debug(level=Engine.DEBUG_GRADS, config=debug_config)

def log_training_debug_int(engine):
with pytest.raises(
ValueError,
match=r"Unknown event name '2'. Level should be combinations of Engine.DEBUG_NONE, "
r"Engine.DEBUG_EVENTS, Engine.DEBUG_OUTPUT, Engine.DEBUG_GRADS",
):
trainer.debug(level=2, config=debug_config)

trainer.run(train_loader, max_epochs=epochs)

_test()