Adjust the TuneCallback class for Pytorch-lightning v1.6+ (removing o… #27770

Closed
wants to merge 2 commits
198 changes: 192 additions & 6 deletions python/ray/tune/integration/pytorch_lightning.py
@@ -5,13 +5,193 @@
from ray import tune

import os
from importlib_metadata import version
from packaging.version import parse as v_parse

logger = logging.getLogger(__name__)

ray_pl_use_master = v_parse(version("pytorch_lightning")) >= v_parse("1.6")
Contributor:
nit: Would rename this to something more descriptive since master may change in the future (e.g. use_ptl_1_6_api)
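For illustration only, a minimal sketch of the suggested rename; use_ptl_1_6_api is the reviewer's proposed name, not what the PR currently uses:

from importlib_metadata import version
from packaging.version import parse as v_parse

# Suggested, more descriptive flag name: True when the installed
# pytorch_lightning exposes the 1.6+ callback API.
use_ptl_1_6_api = v_parse(version("pytorch_lightning")) >= v_parse("1.6")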



class TuneCallback(Callback):
"""Base class for Tune's PyTorch Lightning callbacks."""

_allowed = [
"fit_start",
"fit_end",
"sanity_check_start",
"sanity_check_end",
"train_epoch_start",
"train_epoch_end",
"validation_epoch_start",
"validation_epoch_end",
"test_epoch_start",
"test_epoch_end",
"train_batch_start",
"train_batch_end",
"validation_batch_start",
"validation_batch_end",
"test_batch_start",
"test_batch_end",
"train_start",
"train_end",
"validation_start",
"validation_end",
"test_start",
"test_end",
"exception",
]

def __init__(self, on: Union[str, List[str]] = "validation_end"):
if not isinstance(on, list):
on = [on]
if any(w not in self._allowed for w in on):
raise ValueError(
"Invalid trigger time selected: {}. Must be one of {}".format(
on, self._allowed
)
)
self._on = on

def _handle(self, trainer: Trainer, pl_module: Optional[LightningModule]):
raise NotImplementedError

def on_init_start(self, trainer: Trainer):
if "init_start" in self._on:
self._handle(trainer, None)

def on_init_end(self, trainer: Trainer):
if "init_end" in self._on:
self._handle(trainer, None)

def on_fit_start(
self, trainer: Trainer, pl_module: Optional[LightningModule] = None
):
if "fit_start" in self._on:
self._handle(trainer, None)

def on_fit_end(self, trainer: Trainer, pl_module: Optional[LightningModule] = None):
if "fit_end" in self._on:
self._handle(trainer, None)

def on_sanity_check_start(self, trainer: Trainer, pl_module: LightningModule):
if "sanity_check_start" in self._on:
self._handle(trainer, pl_module)

def on_sanity_check_end(self, trainer: Trainer, pl_module: LightningModule):
if "sanity_check_end" in self._on:
self._handle(trainer, pl_module)

def on_train_epoch_start(self, trainer: Trainer, pl_module: LightningModule):
if "train_epoch_start" in self._on:
self._handle(trainer, pl_module)
Contributor:
@richardliaw should we ensure backwards compatible support for PTL < 1.6?

If so, we can do some versioning checks of pytorch_lightning similar to https://github.com/ray-project/ray/pull/27395/files

Contributor Author:
So I looked into implementing this. My proposed solution is to add a separate class, something like TuneCallbackBackwardCompat, that keeps the old implementation, and then have the remaining classes (TuneReportCallback, _TuneCheckpointCallback, TuneReportCheckpointCallback) inherit from either TuneCallback or the backward-compatible version, depending on the installed pytorch-lightning version.

Additionally, in the tests we'd need to test both, right?
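A rough sketch of that proposal, assuming the naming discussed here (the exact names are up for review):

from importlib_metadata import version
from packaging.version import parse as v_parse
from pytorch_lightning import Callback

ray_pl_use_master = v_parse(version("pytorch_lightning")) >= v_parse("1.6")


class TuneCallback(Callback):
    """Base class using the PTL 1.6+ hook names (e.g. on_train_batch_end)."""
    ...


class TuneCallbackBackwardCompat(Callback):
    """Old base class, kept as-is for pytorch_lightning < 1.6 (e.g. on_batch_end)."""
    ...


# Pick the base class once, at import time, so the public callbacks work
# with whichever pytorch_lightning version is installed.
if ray_pl_use_master:  # pytorch-lightning >= 1.6
    BaseTuneCallback = TuneCallback
else:
    BaseTuneCallback = TuneCallbackBackwardCompat


class TuneReportCallback(BaseTuneCallback):
    ...

Testing both branches would then mean gating the test setup on the same version check, as the test-file changes further down do.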

Contributor:
Yeah we should ensure backwards compat support. Having a backwards compat class seems good to me, thanks!

Contributor Author (@davzaman, Aug 19, 2022):
@richardliaw Awesome, I've added those changes in the newest commit


def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule):
if "train_epoch_end" in self._on:
self._handle(trainer, pl_module)

def on_validation_epoch_start(self, trainer: Trainer, pl_module: LightningModule):
if "validation_epoch_start" in self._on:
self._handle(trainer, pl_module)

def on_validation_epoch_end(self, trainer: Trainer, pl_module: LightningModule):
if "validation_epoch_end" in self._on:
self._handle(trainer, pl_module)

def on_test_epoch_start(self, trainer: Trainer, pl_module: LightningModule):
if "test_epoch_start" in self._on:
self._handle(trainer, pl_module)

def on_test_epoch_end(self, trainer: Trainer, pl_module: LightningModule):
if "test_epoch_end" in self._on:
self._handle(trainer, pl_module)

def on_train_batch_start(self, trainer: Trainer, pl_module: LightningModule):
if "train_batch_start" in self._on:
self._handle(trainer, pl_module)

def on_train_batch_end(self, trainer: Trainer, pl_module: LightningModule):
if "train_batch_end" in self._on:
self._handle(trainer, pl_module)

def on_validation_batch_start(
self,
trainer: Trainer,
pl_module: LightningModule,
batch,
batch_idx,
dataloader_idx,
):
if "validation_batch_start" in self._on:
self._handle(trainer, pl_module)

def on_validation_batch_end(
self,
trainer: Trainer,
pl_module: LightningModule,
outputs,
batch,
batch_idx,
dataloader_idx,
):
if "validation_batch_end" in self._on:
self._handle(trainer, pl_module)

def on_test_batch_start(
self,
trainer: Trainer,
pl_module: LightningModule,
batch,
batch_idx,
dataloader_idx,
):
if "test_batch_start" in self._on:
self._handle(trainer, pl_module)

def on_test_batch_end(
self,
trainer: Trainer,
pl_module: LightningModule,
outputs,
batch,
batch_idx,
dataloader_idx,
):
if "test_batch_end" in self._on:
self._handle(trainer, pl_module)

def on_train_start(self, trainer: Trainer, pl_module: LightningModule):
if "train_start" in self._on:
self._handle(trainer, pl_module)

def on_train_end(self, trainer: Trainer, pl_module: LightningModule):
if "train_end" in self._on:
self._handle(trainer, pl_module)

def on_validation_start(self, trainer: Trainer, pl_module: LightningModule):
if "validation_start" in self._on:
self._handle(trainer, pl_module)

def on_validation_end(self, trainer: Trainer, pl_module: LightningModule):
if "validation_end" in self._on:
self._handle(trainer, pl_module)

def on_test_start(self, trainer: Trainer, pl_module: LightningModule):
if "test_start" in self._on:
self._handle(trainer, pl_module)

def on_test_end(self, trainer: Trainer, pl_module: LightningModule):
if "test_end" in self._on:
self._handle(trainer, pl_module)

def on_exception(self, trainer: Trainer, pl_module: LightningModule):
if "exception" in self._on:
self._handle(trainer, pl_module)


class TuneCallbackBackwardCompat(Callback):
"""Base class for Tune's PyTorch Lightning callbacks."""

_allowed = [
"init_start",
"init_end",
@@ -167,7 +347,13 @@ def on_keyboard_interrupt(self, trainer: Trainer, pl_module: LightningModule):
self._handle(trainer, pl_module)


class TuneReportCallback(TuneCallback):
if ray_pl_use_master: # pytorch-lightning >= 1.6
BaseTuneCallback = TuneCallback
else:
BaseTuneCallback = TuneCallbackBackwardCompat


class TuneReportCallback(BaseTuneCallback):
"""PyTorch Lightning to Ray Tune reporting callback

Reports metrics to Ray Tune.
@@ -180,7 +366,7 @@ class TuneReportCallback(TuneCallback):
value will be the metric key reported to PyTorch Lightning.
on: When to trigger checkpoint creations. Must be one of
the PyTorch Lightning event hooks (less the ``on_``), e.g.
"batch_start", or "train_end". Defaults to "validation_end".
"train_batch_start", or "train_end". Defaults to "validation_end".

Example:

@@ -239,7 +425,7 @@ def _handle(self, trainer: Trainer, pl_module: LightningModule):
tune.report(**report_dict)


class _TuneCheckpointCallback(TuneCallback):
class _TuneCheckpointCallback(BaseTuneCallback):
"""PyTorch Lightning checkpoint callback

Saves checkpoints after each validation step.
@@ -253,7 +439,7 @@ class _TuneCheckpointCallback(TuneCallback):
directory. Defaults to "checkpoint".
on: When to trigger checkpoint creations. Must be one of
the PyTorch Lightning event hooks (less the ``on_``), e.g.
"batch_start", or "train_end". Defaults to "validation_end".
"train_batch_start", or "train_end". Defaults to "validation_end".


"""
@@ -272,7 +458,7 @@ def _handle(self, trainer: Trainer, pl_module: LightningModule):
trainer.save_checkpoint(os.path.join(checkpoint_dir, self._filename))


class TuneReportCheckpointCallback(TuneCallback):
class TuneReportCheckpointCallback(BaseTuneCallback):
"""PyTorch Lightning report and checkpoint callback

Saves checkpoints after each validation step. Also reports metrics to Tune,
@@ -288,7 +474,7 @@ class TuneReportCheckpointCallback(TuneCallback):
directory. Defaults to "checkpoint".
on: When to trigger checkpoint creations. Must be one of
the PyTorch Lightning event hooks (less the ``on_``), e.g.
"batch_start", or "train_end". Defaults to "validation_end".
"train_batch_start", or "train_end". Defaults to "validation_end".


Example:
20 changes: 15 additions & 5 deletions python/ray/tune/tests/test_integration_pytorch_lightning.py
@@ -8,13 +8,18 @@

from torch.utils.data import DataLoader, Dataset

from importlib_metadata import version
from packaging.version import parse as v_parse

from ray import tune
from ray.tune.integration.pytorch_lightning import (
TuneReportCallback,
TuneReportCheckpointCallback,
_TuneCheckpointCallback,
)

ray_pl_use_master = v_parse(version("pytorch_lightning")) >= v_parse("1.6")

class _MockDataset(Dataset):
def __init__(self, values):
@@ -104,14 +109,19 @@ def testCheckpointCallback(self):

         def train(config):
             module = _MockModule(10.0, 20.0)
-            trainer = pl.Trainer(
-                max_epochs=1,
-                callbacks=[
+            if ray_pl_use_master:
+                callbacks = [
+                    _TuneCheckpointCallback(
+                        "trainer.ckpt", on=["train_batch_end", "train_end"]
+                    )
+                ]
+            else:
+                callbacks = [
                     _TuneCheckpointCallback(
                         "trainer.ckpt", on=["batch_end", "train_end"]
                     )
-                ],
-            )
+                ]
+            trainer = pl.Trainer(max_epochs=1, callbacks=callbacks)
             trainer.fit(module)

analysis = tune.run(
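For reference, a minimal usage sketch of the public callback under the 1.6+ hook names; MyLightningModule, the metric names, and the search space here are placeholders, not part of this PR:

import pytorch_lightning as pl
from ray import tune
from ray.tune.integration.pytorch_lightning import TuneReportCheckpointCallback


def train_fn(config):
    model = MyLightningModule(lr=config["lr"])  # placeholder LightningModule
    trainer = pl.Trainer(
        max_epochs=3,
        callbacks=[
            TuneReportCheckpointCallback(
                # keys are the names reported to Tune, values are the metrics
                # logged by PyTorch Lightning
                metrics={"loss": "val_loss"},
                filename="checkpoint",
                # hook name without the "on_" prefix; with the 1.6+ API the
                # batch-level hooks are "train_batch_end" etc. rather than "batch_end"
                on="validation_end",
            )
        ],
    )
    trainer.fit(model)


analysis = tune.run(
    train_fn,
    config={"lr": tune.loguniform(1e-4, 1e-1)},
    num_samples=4,
)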