Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Tune] Dynamically identify PyTorch Lightning Callback hooks #30045

Merged
merged 6 commits into from
Nov 8, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
202 changes: 54 additions & 148 deletions python/ray/tune/integration/pytorch_lightning.py
Original file line number Diff line number Diff line change
@@ -1,172 +1,77 @@
import inspect
import logging
from typing import Dict, List, Optional, Union
from typing import Dict, List, Optional, Type, Union

from pytorch_lightning import Callback, Trainer, LightningModule
from ray import tune
from ray.util import PublicAPI

import os

logger = logging.getLogger(__name__)

# Get all PyTorch Lightning Callback hooks based on whatever PTL version is being used.
# Hook names (e.g. "on_validation_end") are discovered dynamically so this stays in
# sync with the installed PTL version; TuneCallback only accepts triggers from this set.
_allowed_hooks = {
    name
    for name, _ in inspect.getmembers(Callback, predicate=inspect.isfunction)
    if name.startswith("on_")
}

class TuneCallback(Callback):
"""Base class for Tune's PyTorch Lightning callbacks."""

_allowed = [
"init_start",
"init_end",
"fit_start",
"fit_end",
"sanity_check_start",
"sanity_check_end",
"epoch_start",
"epoch_end",
"batch_start",
"validation_batch_start",
"validation_batch_end",
"test_batch_start",
"test_batch_end",
"batch_end",
"train_start",
"train_end",
"validation_start",
"validation_end",
"test_start",
"test_end",
"keyboard_interrupt",
]

def __init__(self, on: Union[str, List[str]] = "validation_end"):
if not isinstance(on, list):
on = [on]
if any(w not in self._allowed for w in on):
raise ValueError(
"Invalid trigger time selected: {}. Must be one of {}".format(
on, self._allowed
)
)
self._on = on

def _handle(self, trainer: Trainer, pl_module: Optional[LightningModule]):
raise NotImplementedError
def _override_ptl_hooks(callback_cls: Type["TuneCallback"]) -> Type["TuneCallback"]:
"""Overrides all allowed PTL Callback hooks with our custom handle logic."""

def on_init_start(self, trainer: Trainer):
if "init_start" in self._on:
self._handle(trainer, None)
def generate_overridden_hook(fn_name):
def overridden_hook(
self,
trainer: Trainer,
*args,
pl_module: Optional[LightningModule] = None,
**kwargs,
):
if fn_name in self._on:
self._handle(trainer=trainer, pl_module=pl_module)

def on_init_end(self, trainer: Trainer):
if "init_end" in self._on:
self._handle(trainer, None)
return overridden_hook

def on_fit_start(
self, trainer: Trainer, pl_module: Optional[LightningModule] = None
):
if "fit_start" in self._on:
self._handle(trainer, None)
# Set the overridden hook to all the allowed hooks in TuneCallback.
for fn_name in _allowed_hooks:
setattr(callback_cls, fn_name, generate_overridden_hook(fn_name))

def on_fit_end(self, trainer: Trainer, pl_module: Optional[LightningModule] = None):
if "fit_end" in self._on:
self._handle(trainer, None)
return callback_cls

def on_sanity_check_start(self, trainer: Trainer, pl_module: LightningModule):
if "sanity_check_start" in self._on:
self._handle(trainer, pl_module)

def on_sanity_check_end(self, trainer: Trainer, pl_module: LightningModule):
if "sanity_check_end" in self._on:
self._handle(trainer, pl_module)

def on_epoch_start(self, trainer: Trainer, pl_module: LightningModule):
if "epoch_start" in self._on:
self._handle(trainer, pl_module)

def on_epoch_end(self, trainer: Trainer, pl_module: LightningModule):
if "epoch_end" in self._on:
self._handle(trainer, pl_module)

def on_batch_start(self, trainer: Trainer, pl_module: LightningModule):
if "batch_start" in self._on:
self._handle(trainer, pl_module)

def on_validation_batch_start(
self,
trainer: Trainer,
pl_module: LightningModule,
batch,
batch_idx,
dataloader_idx,
):
if "validation_batch_start" in self._on:
self._handle(trainer, pl_module)

def on_validation_batch_end(
self,
trainer: Trainer,
pl_module: LightningModule,
outputs,
batch,
batch_idx,
dataloader_idx,
):
if "validation_batch_end" in self._on:
self._handle(trainer, pl_module)

def on_test_batch_start(
self,
trainer: Trainer,
pl_module: LightningModule,
batch,
batch_idx,
dataloader_idx,
):
if "test_batch_start" in self._on:
self._handle(trainer, pl_module)

def on_test_batch_end(
self,
trainer: Trainer,
pl_module: LightningModule,
outputs,
batch,
batch_idx,
dataloader_idx,
):
if "test_batch_end" in self._on:
self._handle(trainer, pl_module)

def on_batch_end(self, trainer: Trainer, pl_module: LightningModule):
if "batch_end" in self._on:
self._handle(trainer, pl_module)

def on_train_start(self, trainer: Trainer, pl_module: LightningModule):
if "train_start" in self._on:
self._handle(trainer, pl_module)
@_override_ptl_hooks
class TuneCallback(Callback):
"""Base class for Tune's PyTorch Lightning callbacks.

def on_train_end(self, trainer: Trainer, pl_module: LightningModule):
if "train_end" in self._on:
self._handle(trainer, pl_module)
Args:
on: When to trigger checkpoint creations. Must be one of
the PyTorch Lightning event hooks (less the ``on_``), e.g.
"train_batch_start", or "train_end". Defaults to "validation_end".
"""

def on_validation_start(self, trainer: Trainer, pl_module: LightningModule):
if "validation_start" in self._on:
self._handle(trainer, pl_module)
def __init__(self, on: Union[str, List[str]] = "validation_end"):
if not isinstance(on, list):
on = [on]

def on_validation_end(self, trainer: Trainer, pl_module: LightningModule):
if "validation_end" in self._on:
self._handle(trainer, pl_module)
for hook in on:
if f"on_{hook}" not in _allowed_hooks:
raise ValueError(
f"Invalid hook selected: {hook}. Must be one of "
f"{_allowed_hooks}"
)

def on_test_start(self, trainer: Trainer, pl_module: LightningModule):
if "test_start" in self._on:
self._handle(trainer, pl_module)
# Add back the "on_" prefix for internal consistency.
on = [f"on_{hook}" for hook in on]

def on_test_end(self, trainer: Trainer, pl_module: LightningModule):
if "test_end" in self._on:
self._handle(trainer, pl_module)
self._on = on

def on_keyboard_interrupt(self, trainer: Trainer, pl_module: LightningModule):
if "keyboard_interrupt" in self._on:
self._handle(trainer, pl_module)
def _handle(self, trainer: Trainer, pl_module: Optional[LightningModule]):
raise NotImplementedError


@PublicAPI
class TuneReportCallback(TuneCallback):
"""PyTorch Lightning to Ray Tune reporting callback

Expand All @@ -180,7 +85,7 @@ class TuneReportCallback(TuneCallback):
value will be the metric key reported to PyTorch Lightning.
on: When to trigger checkpoint creations. Must be one of
the PyTorch Lightning event hooks (less the ``on_``), e.g.
"batch_start", or "train_end". Defaults to "validation_end".
"train_batch_start", or "train_end". Defaults to "validation_end".

Example:

Expand All @@ -205,7 +110,7 @@ def __init__(
metrics: Optional[Union[str, List[str], Dict[str, str]]] = None,
on: Union[str, List[str]] = "validation_end",
):
super(TuneReportCallback, self).__init__(on)
super(TuneReportCallback, self).__init__(on=on)
if isinstance(metrics, str):
metrics = [metrics]
self._metrics = metrics
Expand Down Expand Up @@ -253,7 +158,7 @@ class _TuneCheckpointCallback(TuneCallback):
directory. Defaults to "checkpoint".
on: When to trigger checkpoint creations. Must be one of
the PyTorch Lightning event hooks (less the ``on_``), e.g.
"batch_start", or "train_end". Defaults to "validation_end".
"train_batch_start", or "train_end". Defaults to "validation_end".


"""
Expand All @@ -272,6 +177,7 @@ def _handle(self, trainer: Trainer, pl_module: LightningModule):
trainer.save_checkpoint(os.path.join(checkpoint_dir, self._filename))


@PublicAPI
class TuneReportCheckpointCallback(TuneCallback):
"""PyTorch Lightning report and checkpoint callback

Expand All @@ -288,7 +194,7 @@ class TuneReportCheckpointCallback(TuneCallback):
directory. Defaults to "checkpoint".
on: When to trigger checkpoint creations. Must be one of
the PyTorch Lightning event hooks (less the ``on_``), e.g.
"batch_start", or "train_end". Defaults to "validation_end".
"train_batch_start", or "train_end". Defaults to "validation_end".


Example:
Expand Down