Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Tune] Dynamically identify PyTorch Lightning Callback hooks #30045

Merged
merged 6 commits into from
Nov 8, 2022
Merged
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
208 changes: 57 additions & 151 deletions python/ray/tune/integration/pytorch_lightning.py
Original file line number Diff line number Diff line change
@@ -1,173 +1,78 @@
import inspect
import logging
from typing import Dict, List, Optional, Union
from typing import Dict, List, Optional, Type, Union

from pytorch_lightning import Callback, Trainer, LightningModule
from ray import tune
from ray.util import PublicAPI

import os

logger = logging.getLogger(__name__)

# Get all PyTorch Lightning Callback hooks based on whatever PTL version is being used.
_allowed_hooks = {
name
for name, fn in inspect.getmembers(Callback, predicate=inspect.isfunction)
if name.startswith("on_")
}

class TuneCallback(Callback):
"""Base class for Tune's PyTorch Lightning callbacks."""

_allowed = [
"init_start",
"init_end",
"fit_start",
"fit_end",
"sanity_check_start",
"sanity_check_end",
"epoch_start",
"epoch_end",
"batch_start",
"validation_batch_start",
"validation_batch_end",
"test_batch_start",
"test_batch_end",
"batch_end",
"train_start",
"train_end",
"validation_start",
"validation_end",
"test_start",
"test_end",
"keyboard_interrupt",
]

def __init__(self, on: Union[str, List[str]] = "validation_end"):
if not isinstance(on, list):
on = [on]
if any(w not in self._allowed for w in on):
raise ValueError(
"Invalid trigger time selected: {}. Must be one of {}".format(
on, self._allowed
)
)
self._on = on

def _handle(self, trainer: Trainer, pl_module: Optional[LightningModule]):
raise NotImplementedError

def on_init_start(self, trainer: Trainer):
if "init_start" in self._on:
self._handle(trainer, None)
def _override_ptl_hooks(callback_cls: Type["_TuneCallback"]) -> Type["_TuneCallback"]:
"""Overrides all allowed PTL Callback hooks with our custom handle logic."""

def on_init_end(self, trainer: Trainer):
if "init_end" in self._on:
self._handle(trainer, None)
def generate_overridden_hook(fn_name):
def overridden_hook(
self,
trainer: Trainer,
*args,
pl_module: Optional[LightningModule] = None,
**kwargs,
):
if fn_name in self._on:
self._handle(trainer=trainer, pl_module=pl_module)

def on_fit_start(
self, trainer: Trainer, pl_module: Optional[LightningModule] = None
):
if "fit_start" in self._on:
self._handle(trainer, None)

def on_fit_end(self, trainer: Trainer, pl_module: Optional[LightningModule] = None):
if "fit_end" in self._on:
self._handle(trainer, None)

def on_sanity_check_start(self, trainer: Trainer, pl_module: LightningModule):
if "sanity_check_start" in self._on:
self._handle(trainer, pl_module)
return overridden_hook

def on_sanity_check_end(self, trainer: Trainer, pl_module: LightningModule):
if "sanity_check_end" in self._on:
self._handle(trainer, pl_module)
# Set the overridden hook to all the allowed hooks in _TuneCallback.
for fn_name in _allowed_hooks:
setattr(callback_cls, fn_name, generate_overridden_hook(fn_name))

def on_epoch_start(self, trainer: Trainer, pl_module: LightningModule):
if "epoch_start" in self._on:
self._handle(trainer, pl_module)
return callback_cls

def on_epoch_end(self, trainer: Trainer, pl_module: LightningModule):
if "epoch_end" in self._on:
self._handle(trainer, pl_module)

def on_batch_start(self, trainer: Trainer, pl_module: LightningModule):
if "batch_start" in self._on:
self._handle(trainer, pl_module)

def on_validation_batch_start(
self,
trainer: Trainer,
pl_module: LightningModule,
batch,
batch_idx,
dataloader_idx,
):
if "validation_batch_start" in self._on:
self._handle(trainer, pl_module)

def on_validation_batch_end(
self,
trainer: Trainer,
pl_module: LightningModule,
outputs,
batch,
batch_idx,
dataloader_idx,
):
if "validation_batch_end" in self._on:
self._handle(trainer, pl_module)
@_override_ptl_hooks
class _TuneCallback(Callback):
"""Base class for Tune's PyTorch Lightning callbacks.

def on_test_batch_start(
self,
trainer: Trainer,
pl_module: LightningModule,
batch,
batch_idx,
dataloader_idx,
):
if "test_batch_start" in self._on:
self._handle(trainer, pl_module)

def on_test_batch_end(
self,
trainer: Trainer,
pl_module: LightningModule,
outputs,
batch,
batch_idx,
dataloader_idx,
):
if "test_batch_end" in self._on:
self._handle(trainer, pl_module)

def on_batch_end(self, trainer: Trainer, pl_module: LightningModule):
if "batch_end" in self._on:
self._handle(trainer, pl_module)

def on_train_start(self, trainer: Trainer, pl_module: LightningModule):
if "train_start" in self._on:
self._handle(trainer, pl_module)

def on_train_end(self, trainer: Trainer, pl_module: LightningModule):
if "train_end" in self._on:
self._handle(trainer, pl_module)
Args:
    on: When to trigger checkpoint creations. Must be one of
        the PyTorch Lightning event hooks (less the ``on_``), e.g.
        "train_batch_start", or "train_end". Defaults to "validation_end".
"""

def on_validation_start(self, trainer: Trainer, pl_module: LightningModule):
if "validation_start" in self._on:
self._handle(trainer, pl_module)
def __init__(self, on: Union[str, List[str]] = "validation_end"):
if not isinstance(on, list):
on = [on]

def on_validation_end(self, trainer: Trainer, pl_module: LightningModule):
if "validation_end" in self._on:
self._handle(trainer, pl_module)
for hook in on:
if f"on_{hook}" not in _allowed_hooks:
raise ValueError(
f"Invalid hook selected: {hook}. Must be one of "
f"{_allowed_hooks}"
)

def on_test_start(self, trainer: Trainer, pl_module: LightningModule):
if "test_start" in self._on:
self._handle(trainer, pl_module)
# Add back the "on_" prefix for internal consistency.
on = [f"on_{hook}" for hook in on]

def on_test_end(self, trainer: Trainer, pl_module: LightningModule):
if "test_end" in self._on:
self._handle(trainer, pl_module)
self._on = on

def on_keyboard_interrupt(self, trainer: Trainer, pl_module: LightningModule):
if "keyboard_interrupt" in self._on:
self._handle(trainer, pl_module)
def _handle(self, trainer: Trainer, pl_module: Optional[LightningModule]):
raise NotImplementedError


class TuneReportCallback(TuneCallback):
@PublicAPI
class TuneReportCallback(_TuneCallback):
"""PyTorch Lightning to Ray Tune reporting callback

Reports metrics to Ray Tune.
Expand All @@ -180,7 +85,7 @@ class TuneReportCallback(TuneCallback):
value will be the metric key reported to PyTorch Lightning.
on: When to trigger checkpoint creations. Must be one of
the PyTorch Lightning event hooks (less the ``on_``), e.g.
"batch_start", or "train_end". Defaults to "validation_end".
"train_batch_start", or "train_end". Defaults to "validation_end".

Example:

Expand All @@ -205,7 +110,7 @@ def __init__(
metrics: Optional[Union[str, List[str], Dict[str, str]]] = None,
on: Union[str, List[str]] = "validation_end",
):
super(TuneReportCallback, self).__init__(on)
super(TuneReportCallback, self).__init__(on=on)
if isinstance(metrics, str):
metrics = [metrics]
self._metrics = metrics
Expand Down Expand Up @@ -239,7 +144,7 @@ def _handle(self, trainer: Trainer, pl_module: LightningModule):
tune.report(**report_dict)


class _TuneCheckpointCallback(TuneCallback):
class _TuneCheckpointCallback(_TuneCallback):
"""PyTorch Lightning checkpoint callback

Saves checkpoints after each validation step.
Expand All @@ -253,7 +158,7 @@ class _TuneCheckpointCallback(TuneCallback):
directory. Defaults to "checkpoint".
on: When to trigger checkpoint creations. Must be one of
the PyTorch Lightning event hooks (less the ``on_``), e.g.
"batch_start", or "train_end". Defaults to "validation_end".
"train_batch_start", or "train_end". Defaults to "validation_end".


"""
Expand All @@ -272,7 +177,8 @@ def _handle(self, trainer: Trainer, pl_module: LightningModule):
trainer.save_checkpoint(os.path.join(checkpoint_dir, self._filename))


class TuneReportCheckpointCallback(TuneCallback):
@PublicAPI
class TuneReportCheckpointCallback(_TuneCallback):
"""PyTorch Lightning report and checkpoint callback

Saves checkpoints after each validation step. Also reports metrics to Tune,
Expand All @@ -288,7 +194,7 @@ class TuneReportCheckpointCallback(TuneCallback):
directory. Defaults to "checkpoint".
on: When to trigger checkpoint creations. Must be one of
the PyTorch Lightning event hooks (less the ``on_``), e.g.
"batch_start", or "train_end". Defaults to "validation_end".
"train_batch_start", or "train_end". Defaults to "validation_end".


Example:
Expand Down