Merge pull request #9 from benoitmartin88/rc
Hotfix sigint (#8)
benoitmartin88 authored Oct 10, 2019
2 parents 2a4d0a3 + 765fa05 commit 7d73acd
Showing 5 changed files with 40 additions and 3 deletions.
2 changes: 2 additions & 0 deletions .gitignore
@@ -1,3 +1,5 @@
+.DS_Store
+
 # Byte-compiled / optimized / DLL files
 __pycache__/
 *.py[cod]
7 changes: 6 additions & 1 deletion CHANGELOG.md
@@ -8,7 +8,12 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
 # Unreleased
 
 
-# # [0.2.0] - 2019-09-20
+# [0.2.1] - 2019-10-10
+## Bugfix
+- Fix ignored SIGINT
+
+
+# [0.2.0] - 2019-09-20
 ## New
 - Add `ModuleTrainer.evaluate` method
 - Add CsvWriter to evaluate method
2 changes: 1 addition & 1 deletion pytorchtrainer/__init__.py
@@ -1,5 +1,5 @@
 
-__version__ = '0.2.0'
+__version__ = '0.2.1'
 
 
 from .trainer import create_default_trainer, ModuleTrainer, State
4 changes: 3 additions & 1 deletion pytorchtrainer/trainer.py
@@ -80,6 +80,7 @@ def __init__(self, model: nn.Module, optimizer: optim.Optimizer, train_function,
         signal.signal(signal.SIGINT, self.__graceful_exit)
 
         self.state = State()
+        self.stop_condition = None
         self.model = model
         self.optimizer = optimizer
         self.train_function = train_function
@@ -112,10 +113,11 @@ def train(self, train_dataloader: torch.utils.data.DataLoader, max_epochs=100, s
         if not callable(stop_condition):
             raise TypeError("Argument stop_condition should be a function.")
 
+        self.stop_condition = stop_condition
         self.model.train()  # set the module to training mode
 
         train_start = time()
-        while self.state.current_epoch < max_epochs and not stop_condition(self.state):
+        while self.state.current_epoch < max_epochs and not self.stop_condition(self.state):
            self.model.zero_grad()
 
            for self.state.current_iteration, batch in enumerate(train_dataloader):
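A note on how the fix works: `__graceful_exit` is registered as the SIGINT handler in `__init__`, but its body sits outside this hunk. Before the change, `stop_condition` existed only as a local name inside `train()`, out of the handler's reach, which is consistent with the "Fix ignored SIGINT" changelog entry; storing it on `self` lets the handler replace it with a condition that always fires. The sketch below is a hypothetical reconstruction of that pattern under assumed names (`MiniTrainer`, `_graceful_exit`), not the library's actual code:

import signal
import time


class MiniTrainer:
    """Minimal stand-in for ModuleTrainer; only the
    stop-condition-on-self pattern mirrors the diff."""

    def __init__(self):
        signal.signal(signal.SIGINT, self._graceful_exit)
        self.stop_condition = None
        self.current_epoch = 0

    def _graceful_exit(self, signum, frame):
        # Assumed handler body: swap in a condition that is always
        # true, so train() returns at its next loop check.
        self.stop_condition = lambda trainer: True

    def train(self, max_epochs=100, stop_condition=lambda t: False):
        # The fix: keep the condition on self, where the signal
        # handler can reach it; a local variable could not be swapped.
        self.stop_condition = stop_condition
        while self.current_epoch < max_epochs and not self.stop_condition(self):
            time.sleep(0.1)  # stand-in for one training epoch
            self.current_epoch += 1


if __name__ == '__main__':
    t = MiniTrainer()
    t.train()  # Ctrl+C now stops training cleanly
    print('stopped at epoch', t.current_epoch)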
28 changes: 28 additions & 0 deletions test/test_trainer.py
@@ -188,3 +188,31 @@ def test_add_progressbar_metric_errors(self):
         self.assertRaises(TypeError, trainer.add_progressbar_metric, None, [])
         self.assertRaises(TypeError, trainer.add_progressbar_metric, "", [None])
         self.assertRaises(RuntimeError, trainer.add_progressbar_metric, "", [CsvWriter()])
+
+    def test_sigint(self):
+        max_epoch = 100
+
+        def _run(q):
+            print("_run from sub-process")
+            trainer = create_default_trainer(self.model, self.optimizer, self.criterion)
+            trainer.train(self.train_loader, max_epochs=max_epoch)
+            q.put(trainer.state.current_epoch)
+
+        import multiprocessing as mp
+        import os
+        import signal
+        import time
+
+        q = mp.Queue()
+        p = mp.Process(target=_run, args=(q,))
+        p.start()
+
+        # wait 1 second before sending sigint
+        time.sleep(1)
+        os.kill(p.pid, signal.SIGINT)
+
+        last_epoch = q.get()
+
+        p.join()
+
+        self.assertLess(last_epoch, max_epoch)
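Two details in this test are easy to trip over when adapting it: `q.get()` runs before `p.join()` because a child that still has data buffered in a `multiprocessing.Queue` may not terminate, so joining before draining the queue can deadlock; and `os.kill(p.pid, signal.SIGINT)` is POSIX-specific (on Windows, `os.kill` can only deliver CTRL_C_EVENT or CTRL_BREAK_EVENT to console processes). A standalone sketch of the safe get-then-join ordering, independent of the diff:

import multiprocessing as mp


def _worker(q):
    q.put('result')  # sits in the queue's pipe buffer until read


if __name__ == '__main__':
    q = mp.Queue()
    p = mp.Process(target=_worker, args=(q,))
    p.start()
    print(q.get())  # drain the queue first...
    p.join()        # ...then join; the reverse order can deadlock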
