Adjust the TuneCallback class for Pytorch-lightning v1.6+ (removing o… #27770
First file in the diff (the ray.tune.integration.pytorch_lightning module):
@@ -5,13 +5,193 @@

from ray import tune

import os
from importlib_metadata import version
from packaging.version import parse as v_parse

logger = logging.getLogger(__name__)

ray_pl_use_master = v_parse(version("pytorch_lightning")) >= v_parse("1.6")
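An aside on the version gate above: packaging.version.parse compares release segments numerically, which a plain string comparison would not. A quick illustration (not part of the diff):

    from packaging.version import parse as v_parse

    # Version-aware comparison: 1.10 is a later release than 1.6 ...
    assert v_parse("1.10.0") >= v_parse("1.6")
    # ... while a raw string comparison would order "1.10.0" before "1.6".
    assert not "1.10.0" >= "1.6"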
class TuneCallback(Callback):
    """Base class for Tune's PyTorch Lightning callbacks."""

    _allowed = [
        "fit_start",
        "fit_end",
        "sanity_check_start",
        "sanity_check_end",
        "train_epoch_start",
        "train_epoch_end",
        "validation_epoch_start",
        "validation_epoch_end",
        "test_epoch_start",
        "test_epoch_end",
        "train_batch_start",
        "train_batch_end",
        "validation_batch_start",
        "validation_batch_end",
        "test_batch_start",
        "test_batch_end",
        "train_start",
        "train_end",
        "validation_start",
        "validation_end",
        "test_start",
        "test_end",
        "exception",
    ]

    def __init__(self, on: Union[str, List[str]] = "validation_end"):
        if not isinstance(on, list):
            on = [on]
        if any(w not in self._allowed for w in on):
            raise ValueError(
                "Invalid trigger time selected: {}. Must be one of {}".format(
                    on, self._allowed
                )
            )
        self._on = on

    def _handle(self, trainer: Trainer, pl_module: Optional[LightningModule]):
        raise NotImplementedError

    # "init_start"/"init_end" are not part of the _allowed list above, so these
    # two handlers can never fire for this class; __init__ rejects those names.
    def on_init_start(self, trainer: Trainer):
        if "init_start" in self._on:
            self._handle(trainer, None)

    def on_init_end(self, trainer: Trainer):
        if "init_end" in self._on:
            self._handle(trainer, None)

    def on_fit_start(
        self, trainer: Trainer, pl_module: Optional[LightningModule] = None
    ):
        if "fit_start" in self._on:
            self._handle(trainer, None)

    def on_fit_end(self, trainer: Trainer, pl_module: Optional[LightningModule] = None):
        if "fit_end" in self._on:
            self._handle(trainer, None)

    def on_sanity_check_start(self, trainer: Trainer, pl_module: LightningModule):
        if "sanity_check_start" in self._on:
            self._handle(trainer, pl_module)

    def on_sanity_check_end(self, trainer: Trainer, pl_module: LightningModule):
        if "sanity_check_end" in self._on:
            self._handle(trainer, pl_module)

    def on_train_epoch_start(self, trainer: Trainer, pl_module: LightningModule):
        if "train_epoch_start" in self._on:
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @richardliaw should we ensure backwards compatible support for PTL < 1.6? If so, we can do some versioning checks of There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. So I looked into implementing this. My proposed solution would be to have an additional class, something like Additionally in the tests, we'd need to test both, right? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah we should ensure backwards compat support. Having a backwards compat class seems good to me, thanks! There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @richardliaw Awesome, I've added those changes in the newest commit |
||
    def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule):
        if "train_epoch_end" in self._on:
            self._handle(trainer, pl_module)

    def on_validation_epoch_start(self, trainer: Trainer, pl_module: LightningModule):
        if "validation_epoch_start" in self._on:
            self._handle(trainer, pl_module)

    def on_validation_epoch_end(self, trainer: Trainer, pl_module: LightningModule):
        if "validation_epoch_end" in self._on:
            self._handle(trainer, pl_module)

    def on_test_epoch_start(self, trainer: Trainer, pl_module: LightningModule):
        if "test_epoch_start" in self._on:
            self._handle(trainer, pl_module)

    def on_test_epoch_end(self, trainer: Trainer, pl_module: LightningModule):
        if "test_epoch_end" in self._on:
            self._handle(trainer, pl_module)

    def on_train_batch_start(self, trainer: Trainer, pl_module: LightningModule):
        if "train_batch_start" in self._on:
            self._handle(trainer, pl_module)

    def on_train_batch_end(self, trainer: Trainer, pl_module: LightningModule):
        if "train_batch_end" in self._on:
            self._handle(trainer, pl_module)

    def on_validation_batch_start(
        self,
        trainer: Trainer,
        pl_module: LightningModule,
        batch,
        batch_idx,
        dataloader_idx,
    ):
        if "validation_batch_start" in self._on:
            self._handle(trainer, pl_module)

    def on_validation_batch_end(
        self,
        trainer: Trainer,
        pl_module: LightningModule,
        outputs,
        batch,
        batch_idx,
        dataloader_idx,
    ):
        if "validation_batch_end" in self._on:
            self._handle(trainer, pl_module)

    def on_test_batch_start(
        self,
        trainer: Trainer,
        pl_module: LightningModule,
        batch,
        batch_idx,
        dataloader_idx,
    ):
        if "test_batch_start" in self._on:
            self._handle(trainer, pl_module)

    def on_test_batch_end(
        self,
        trainer: Trainer,
        pl_module: LightningModule,
        outputs,
        batch,
        batch_idx,
        dataloader_idx,
    ):
        if "test_batch_end" in self._on:
            self._handle(trainer, pl_module)

    def on_train_start(self, trainer: Trainer, pl_module: LightningModule):
        if "train_start" in self._on:
            self._handle(trainer, pl_module)

    def on_train_end(self, trainer: Trainer, pl_module: LightningModule):
        if "train_end" in self._on:
            self._handle(trainer, pl_module)

    def on_validation_start(self, trainer: Trainer, pl_module: LightningModule):
        if "validation_start" in self._on:
            self._handle(trainer, pl_module)

    def on_validation_end(self, trainer: Trainer, pl_module: LightningModule):
        if "validation_end" in self._on:
            self._handle(trainer, pl_module)

    def on_test_start(self, trainer: Trainer, pl_module: LightningModule):
        if "test_start" in self._on:
            self._handle(trainer, pl_module)

    def on_test_end(self, trainer: Trainer, pl_module: LightningModule):
        if "test_end" in self._on:
            self._handle(trainer, pl_module)

    def on_exception(self, trainer: Trainer, pl_module: LightningModule):
        if "exception" in self._on:
            self._handle(trainer, pl_module)
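One consequence of the renamed hooks worth calling out: a trigger name from the pre-1.6 hook set now fails fast against the _allowed list. A minimal sketch, assuming PTL >= 1.6 is installed so that TuneCallback above is the active base class ("avg_val_loss" is just a placeholder for whatever metric the LightningModule logs):

    from ray.tune.integration.pytorch_lightning import TuneReportCallback

    # "batch_end" was valid for PTL < 1.6; the renamed hook is "train_batch_end".
    try:
        TuneReportCallback({"loss": "avg_val_loss"}, on="batch_end")
    except ValueError as err:
        print(err)  # Invalid trigger time selected: ['batch_end']. Must be one of [...]

    # The renamed trigger constructs fine.
    callback = TuneReportCallback({"loss": "avg_val_loss"}, on="train_batch_end")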
class TuneCallbackBackwardCompat(Callback):
    """Base class for Tune's PyTorch Lightning callbacks."""

    _allowed = [
        "init_start",
        "init_end",
@@ -167,7 +347,13 @@ def on_keyboard_interrupt(self, trainer: Trainer, pl_module: LightningModule):
             self._handle(trainer, pl_module)
-class TuneReportCallback(TuneCallback):
+if ray_pl_use_master:  # pytorch-lightning >= 1.6
+    BaseTuneCallback = TuneCallback
+else:
+    BaseTuneCallback = TuneCallbackBackwardCompat
+
+
+class TuneReportCallback(BaseTuneCallback):
     """PyTorch Lightning to Ray Tune reporting callback

     Reports metrics to Ray Tune.
@@ -180,7 +366,7 @@ class TuneReportCallback(TuneCallback):
             value will be the metric key reported to PyTorch Lightning.
         on: When to trigger checkpoint creations. Must be one of
             the PyTorch Lightning event hooks (less the ``on_``), e.g.
-            "batch_start", or "train_end". Defaults to "validation_end".
+            "train_batch_start", or "train_end". Defaults to "validation_end".

     Example:
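The docstring's example block is collapsed in this diff view; typical usage, per the arguments described above, would look roughly like the following sketch ("avg_val_loss"/"avg_val_acc" are assumed metric names logged by the user's module):

    import pytorch_lightning as pl
    from ray.tune.integration.pytorch_lightning import TuneReportCallback

    # Report the module's logged validation metrics to Tune as "loss"/"acc"
    # at the end of every validation run.
    trainer = pl.Trainer(
        max_epochs=10,
        callbacks=[
            TuneReportCallback(
                {"loss": "avg_val_loss", "acc": "avg_val_acc"},
                on="validation_end",
            )
        ],
    )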
@@ -239,7 +425,7 @@ def _handle(self, trainer: Trainer, pl_module: LightningModule):
         tune.report(**report_dict)
-class _TuneCheckpointCallback(TuneCallback):
+class _TuneCheckpointCallback(BaseTuneCallback):
     """PyTorch Lightning checkpoint callback

     Saves checkpoints after each validation step.
@@ -253,7 +439,7 @@ class _TuneCheckpointCallback(TuneCallback):
             directory. Defaults to "checkpoint".
         on: When to trigger checkpoint creations. Must be one of
             the PyTorch Lightning event hooks (less the ``on_``), e.g.
-            "batch_start", or "train_end". Defaults to "validation_end".
+            "train_batch_start", or "train_end". Defaults to "validation_end".

     """
@@ -272,7 +458,7 @@ def _handle(self, trainer: Trainer, pl_module: LightningModule):
             trainer.save_checkpoint(os.path.join(checkpoint_dir, self._filename))
-class TuneReportCheckpointCallback(TuneCallback):
+class TuneReportCheckpointCallback(BaseTuneCallback):
     """PyTorch Lightning report and checkpoint callback

     Saves checkpoints after each validation step. Also reports metrics to Tune,
@@ -288,7 +474,7 @@ class TuneReportCheckpointCallback(TuneCallback):
             directory. Defaults to "checkpoint".
         on: When to trigger checkpoint creations. Must be one of
             the PyTorch Lightning event hooks (less the ``on_``), e.g.
-            "batch_start", or "train_end". Defaults to "validation_end".
+            "train_batch_start", or "train_end". Defaults to "validation_end".

     Example:
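As above, the docstring's example block is collapsed; a sketch of typical usage (metric name again assumed):

    import pytorch_lightning as pl
    from ray.tune.integration.pytorch_lightning import TuneReportCheckpointCallback

    # Save a checkpoint file named "checkpoint" and report "avg_val_loss"
    # as "loss" after every validation run.
    trainer = pl.Trainer(
        max_epochs=10,
        callbacks=[
            TuneReportCheckpointCallback(
                metrics={"loss": "avg_val_loss"},
                filename="checkpoint",
                on="validation_end",
            )
        ],
    )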
Second file in the diff (the PyTorch Lightning integration tests):
@@ -8,13 +8,18 @@

from torch.utils.data import DataLoader, Dataset

from importlib_metadata import version
from packaging.version import parse as v_parse

from ray import tune
from ray.tune.integration.pytorch_lightning import (
    TuneReportCallback,
    TuneReportCheckpointCallback,
    _TuneCheckpointCallback,
)

ray_pl_use_master = v_parse(version("pytorch_lightning")) >= v_parse("1.6")
Review comment on this hunk:

- "Looks like we may need the same changes in https://github.com/ray-project/ray/blob/master/python/ray/tests/ray_lightning/simple_tune.py"


class _MockDataset(Dataset):
    def __init__(self, values):
@@ -104,14 +109,19 @@ def testCheckpointCallback(self):

 def train(config):
     module = _MockModule(10.0, 20.0)
-    trainer = pl.Trainer(
-        max_epochs=1,
-        callbacks=[
-            _TuneCheckpointCallback(
-                "trainer.ckpt", on=["batch_end", "train_end"]
-            )
-        ],
-    )
+    if ray_pl_use_master:
+        callbacks = [
+            _TuneCheckpointCallback(
+                "trainer.ckpt", on=["train_batch_end", "train_end"]
+            )
+        ]
+    else:
+        callbacks = [
+            _TuneCheckpointCallback(
+                "trainer.ckpt", on=["batch_end", "train_end"]
+            )
+        ]
+    trainer = pl.Trainer(max_epochs=1, callbacks=callbacks)
     trainer.fit(module)

     analysis = tune.run(
Review comment:

- "nit: Would rename this (the ray_pl_use_master flag) to something more descriptive since master may change in the future (e.g. use_ptl_1_6_api)"
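The suggested rename would be a one-line change to the flag defined in both files; a sketch of the reviewer's proposal:

    # Same check as ray_pl_use_master, under the name suggested in review:
    use_ptl_1_6_api = v_parse(version("pytorch_lightning")) >= v_parse("1.6")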