From 9649a1472e289b4100f97659a9e0e1a5f1c04532 Mon Sep 17 00:00:00 2001 From: Davina Zaman Date: Wed, 10 Aug 2022 16:17:20 -0700 Subject: [PATCH 1/2] Adjust the TuneCallback class for Pytorch-lightning v1.6+ (removing or modifying deprecated functions), reorganized the order to match the rest of the functions, along with the integration tests. Signed-off-by: Davina Zaman --- .../ray/tune/integration/pytorch_lightning.py | 62 ++++++++++++------- .../test_integration_pytorch_lightning.py | 2 +- 2 files changed, 41 insertions(+), 23 deletions(-) diff --git a/python/ray/tune/integration/pytorch_lightning.py b/python/ray/tune/integration/pytorch_lightning.py index 82e5cd3e427b..c164b5bfc171 100644 --- a/python/ray/tune/integration/pytorch_lightning.py +++ b/python/ray/tune/integration/pytorch_lightning.py @@ -13,27 +13,29 @@ class TuneCallback(Callback): """Base class for Tune's PyTorch Lightning callbacks.""" _allowed = [ - "init_start", - "init_end", "fit_start", "fit_end", "sanity_check_start", "sanity_check_end", - "epoch_start", - "epoch_end", - "batch_start", + "train_epoch_start", + "train_epoch_end", + "validation_epoch_start", + "validation_epoch_end", + "test_epoch_start", + "test_epoch_end", + "train_batch_start", + "train_batch_end", "validation_batch_start", "validation_batch_end", "test_batch_start", "test_batch_end", - "batch_end", "train_start", "train_end", "validation_start", "validation_end", "test_start", "test_end", - "keyboard_interrupt", + "exception", ] def __init__(self, on: Union[str, List[str]] = "validation_end"): @@ -76,16 +78,36 @@ def on_sanity_check_end(self, trainer: Trainer, pl_module: LightningModule): if "sanity_check_end" in self._on: self._handle(trainer, pl_module) - def on_epoch_start(self, trainer: Trainer, pl_module: LightningModule): - if "epoch_start" in self._on: + def on_train_epoch_start(self, trainer: Trainer, pl_module: LightningModule): + if "train_epoch_start" in self._on: self._handle(trainer, pl_module) - def on_epoch_end(self, trainer: Trainer, pl_module: LightningModule): - if "epoch_end" in self._on: + def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule): + if "train_epoch_end" in self._on: self._handle(trainer, pl_module) - def on_batch_start(self, trainer: Trainer, pl_module: LightningModule): - if "batch_start" in self._on: + def on_validation_epoch_start(self, trainer: Trainer, pl_module: LightningModule): + if "validation_epoch_start" in self._on: + self._handle(trainer, pl_module) + + def on_validation_epoch_end(self, trainer: Trainer, pl_module: LightningModule): + if "validation_epoch_end" in self._on: + self._handle(trainer, pl_module) + + def on_test_epoch_start(self, trainer: Trainer, pl_module: LightningModule): + if "test_epoch_start" in self._on: + self._handle(trainer, pl_module) + + def on_test_epoch_end(self, trainer: Trainer, pl_module: LightningModule): + if "test_epoch_end" in self._on: + self._handle(trainer, pl_module) + + def on_train_batch_start(self, trainer: Trainer, pl_module: LightningModule): + if "train_batch_start" in self._on: + self._handle(trainer, pl_module) + + def on_train_batch_end(self, trainer: Trainer, pl_module: LightningModule): + if "train_batch_end" in self._on: self._handle(trainer, pl_module) def on_validation_batch_start( @@ -134,10 +156,6 @@ def on_test_batch_end( if "test_batch_end" in self._on: self._handle(trainer, pl_module) - def on_batch_end(self, trainer: Trainer, pl_module: LightningModule): - if "batch_end" in self._on: - self._handle(trainer, pl_module) - def on_train_start(self, trainer: Trainer, pl_module: LightningModule): if "train_start" in self._on: self._handle(trainer, pl_module) @@ -162,8 +180,8 @@ def on_test_end(self, trainer: Trainer, pl_module: LightningModule): if "test_end" in self._on: self._handle(trainer, pl_module) - def on_keyboard_interrupt(self, trainer: Trainer, pl_module: LightningModule): - if "keyboard_interrupt" in self._on: + def on_exception(self, trainer: Trainer, pl_module: LightningModule): + if "exception" in self._on: self._handle(trainer, pl_module) @@ -180,7 +198,7 @@ class TuneReportCallback(TuneCallback): value will be the metric key reported to PyTorch Lightning. on: When to trigger checkpoint creations. Must be one of the PyTorch Lightning event hooks (less the ``on_``), e.g. - "batch_start", or "train_end". Defaults to "validation_end". + "train_batch_start", or "train_end". Defaults to "validation_end". Example: @@ -253,7 +271,7 @@ class _TuneCheckpointCallback(TuneCallback): directory. Defaults to "checkpoint". on: When to trigger checkpoint creations. Must be one of the PyTorch Lightning event hooks (less the ``on_``), e.g. - "batch_start", or "train_end". Defaults to "validation_end". + "train_batch_start", or "train_end". Defaults to "validation_end". """ @@ -288,7 +306,7 @@ class TuneReportCheckpointCallback(TuneCallback): directory. Defaults to "checkpoint". on: When to trigger checkpoint creations. Must be one of the PyTorch Lightning event hooks (less the ``on_``), e.g. - "batch_start", or "train_end". Defaults to "validation_end". + "train_batch_start", or "train_end". Defaults to "validation_end". Example: diff --git a/python/ray/tune/tests/test_integration_pytorch_lightning.py b/python/ray/tune/tests/test_integration_pytorch_lightning.py index a9da351d087f..af78e5dc92a6 100644 --- a/python/ray/tune/tests/test_integration_pytorch_lightning.py +++ b/python/ray/tune/tests/test_integration_pytorch_lightning.py @@ -108,7 +108,7 @@ def train(config): max_epochs=1, callbacks=[ _TuneCheckpointCallback( - "trainer.ckpt", on=["batch_end", "train_end"] + "trainer.ckpt", on=["train_batch_end", "train_end"] ) ], ) From fd655bc97e0a6aa6e05511756f3e68667f974b88 Mon Sep 17 00:00:00 2001 From: Davina Zaman Date: Fri, 19 Aug 2022 15:31:59 -0700 Subject: [PATCH 2/2] add backwards compatible pytorch-lightning tunecallback class and adjust tests Signed-off-by: Davina Zaman --- .../ray/tune/integration/pytorch_lightning.py | 174 +++++++++++++++++- .../test_integration_pytorch_lightning.py | 20 +- 2 files changed, 186 insertions(+), 8 deletions(-) diff --git a/python/ray/tune/integration/pytorch_lightning.py b/python/ray/tune/integration/pytorch_lightning.py index c164b5bfc171..8583eccd194f 100644 --- a/python/ray/tune/integration/pytorch_lightning.py +++ b/python/ray/tune/integration/pytorch_lightning.py @@ -5,9 +5,13 @@ from ray import tune import os +from importlib_metadata import version +from packaging.version import parse as v_parse logger = logging.getLogger(__name__) +ray_pl_use_master = v_parse(version("pytorch_lightning")) >= v_parse("1.6") + class TuneCallback(Callback): """Base class for Tune's PyTorch Lightning callbacks.""" @@ -185,7 +189,171 @@ def on_exception(self, trainer: Trainer, pl_module: LightningModule): self._handle(trainer, pl_module) -class TuneReportCallback(TuneCallback): +class TuneCallbackBackwardCompat(Callback): + """Base class for Tune's PyTorch Lightning callbacks.""" + + _allowed = [ + "init_start", + "init_end", + "fit_start", + "fit_end", + "sanity_check_start", + "sanity_check_end", + "epoch_start", + "epoch_end", + "batch_start", + "validation_batch_start", + "validation_batch_end", + "test_batch_start", + "test_batch_end", + "batch_end", + "train_start", + "train_end", + "validation_start", + "validation_end", + "test_start", + "test_end", + "keyboard_interrupt", + ] + + def __init__(self, on: Union[str, List[str]] = "validation_end"): + if not isinstance(on, list): + on = [on] + if any(w not in self._allowed for w in on): + raise ValueError( + "Invalid trigger time selected: {}. Must be one of {}".format( + on, self._allowed + ) + ) + self._on = on + + def _handle(self, trainer: Trainer, pl_module: Optional[LightningModule]): + raise NotImplementedError + + def on_init_start(self, trainer: Trainer): + if "init_start" in self._on: + self._handle(trainer, None) + + def on_init_end(self, trainer: Trainer): + if "init_end" in self._on: + self._handle(trainer, None) + + def on_fit_start( + self, trainer: Trainer, pl_module: Optional[LightningModule] = None + ): + if "fit_start" in self._on: + self._handle(trainer, None) + + def on_fit_end(self, trainer: Trainer, pl_module: Optional[LightningModule] = None): + if "fit_end" in self._on: + self._handle(trainer, None) + + def on_sanity_check_start(self, trainer: Trainer, pl_module: LightningModule): + if "sanity_check_start" in self._on: + self._handle(trainer, pl_module) + + def on_sanity_check_end(self, trainer: Trainer, pl_module: LightningModule): + if "sanity_check_end" in self._on: + self._handle(trainer, pl_module) + + def on_epoch_start(self, trainer: Trainer, pl_module: LightningModule): + if "epoch_start" in self._on: + self._handle(trainer, pl_module) + + def on_epoch_end(self, trainer: Trainer, pl_module: LightningModule): + if "epoch_end" in self._on: + self._handle(trainer, pl_module) + + def on_batch_start(self, trainer: Trainer, pl_module: LightningModule): + if "batch_start" in self._on: + self._handle(trainer, pl_module) + + def on_validation_batch_start( + self, + trainer: Trainer, + pl_module: LightningModule, + batch, + batch_idx, + dataloader_idx, + ): + if "validation_batch_start" in self._on: + self._handle(trainer, pl_module) + + def on_validation_batch_end( + self, + trainer: Trainer, + pl_module: LightningModule, + outputs, + batch, + batch_idx, + dataloader_idx, + ): + if "validation_batch_end" in self._on: + self._handle(trainer, pl_module) + + def on_test_batch_start( + self, + trainer: Trainer, + pl_module: LightningModule, + batch, + batch_idx, + dataloader_idx, + ): + if "test_batch_start" in self._on: + self._handle(trainer, pl_module) + + def on_test_batch_end( + self, + trainer: Trainer, + pl_module: LightningModule, + outputs, + batch, + batch_idx, + dataloader_idx, + ): + if "test_batch_end" in self._on: + self._handle(trainer, pl_module) + + def on_batch_end(self, trainer: Trainer, pl_module: LightningModule): + if "batch_end" in self._on: + self._handle(trainer, pl_module) + + def on_train_start(self, trainer: Trainer, pl_module: LightningModule): + if "train_start" in self._on: + self._handle(trainer, pl_module) + + def on_train_end(self, trainer: Trainer, pl_module: LightningModule): + if "train_end" in self._on: + self._handle(trainer, pl_module) + + def on_validation_start(self, trainer: Trainer, pl_module: LightningModule): + if "validation_start" in self._on: + self._handle(trainer, pl_module) + + def on_validation_end(self, trainer: Trainer, pl_module: LightningModule): + if "validation_end" in self._on: + self._handle(trainer, pl_module) + + def on_test_start(self, trainer: Trainer, pl_module: LightningModule): + if "test_start" in self._on: + self._handle(trainer, pl_module) + + def on_test_end(self, trainer: Trainer, pl_module: LightningModule): + if "test_end" in self._on: + self._handle(trainer, pl_module) + + def on_keyboard_interrupt(self, trainer: Trainer, pl_module: LightningModule): + if "keyboard_interrupt" in self._on: + self._handle(trainer, pl_module) + + +if ray_pl_use_master: # pytorch-lightning >= 1.6 + BaseTuneCallback = TuneCallback +else: + BaseTuneCallback = TuneCallbackBackwardCompat + + +class TuneReportCallback(BaseTuneCallback): """PyTorch Lightning to Ray Tune reporting callback Reports metrics to Ray Tune. @@ -257,7 +425,7 @@ def _handle(self, trainer: Trainer, pl_module: LightningModule): tune.report(**report_dict) -class _TuneCheckpointCallback(TuneCallback): +class _TuneCheckpointCallback(BaseTuneCallback): """PyTorch Lightning checkpoint callback Saves checkpoints after each validation step. @@ -290,7 +458,7 @@ def _handle(self, trainer: Trainer, pl_module: LightningModule): trainer.save_checkpoint(os.path.join(checkpoint_dir, self._filename)) -class TuneReportCheckpointCallback(TuneCallback): +class TuneReportCheckpointCallback(BaseTuneCallback): """PyTorch Lightning report and checkpoint callback Saves checkpoints after each validation step. Also reports metrics to Tune, diff --git a/python/ray/tune/tests/test_integration_pytorch_lightning.py b/python/ray/tune/tests/test_integration_pytorch_lightning.py index af78e5dc92a6..e5d78a6a0dfe 100644 --- a/python/ray/tune/tests/test_integration_pytorch_lightning.py +++ b/python/ray/tune/tests/test_integration_pytorch_lightning.py @@ -8,6 +8,9 @@ from torch.utils.data import DataLoader, Dataset +from importlib_metadata import version +from packaging.version import parse as v_parse + from ray import tune from ray.tune.integration.pytorch_lightning import ( TuneReportCallback, @@ -15,6 +18,8 @@ _TuneCheckpointCallback, ) +ray_pl_use_master = v_parse(version("pytorch_lightning")) >= v_parse("1.6") + class _MockDataset(Dataset): def __init__(self, values): @@ -104,14 +109,19 @@ def testCheckpointCallback(self): def train(config): module = _MockModule(10.0, 20.0) - trainer = pl.Trainer( - max_epochs=1, - callbacks=[ + if ray_pl_use_master: + callbacks = [ _TuneCheckpointCallback( "trainer.ckpt", on=["train_batch_end", "train_end"] ) - ], - ) + ] + else: + callbacks = [ + _TuneCheckpointCallback( + "trainer.ckpt", on=["batch_end", "train_end"] + ) + ] + trainer = pl.Trainer(max_epochs=1, callbacks=callbacks) trainer.fit(module) analysis = tune.run(