[Tune] Fix PTL tutorial docs #19999

Merged 1 commit on Nov 4, 2021

41 changes: 18 additions & 23 deletions doc/source/tune/_tutorials/tune-pytorch-lightning.rst
@@ -68,15 +68,10 @@ Lastly, we added a new metric, the validation accuracy, to the logs.
And that's it! You can now run ``train_mnist(config)`` to train the classifier, like so:

-.. code-block:: python
-
-    config = {
-        "layer_1_size": 128,
-        "layer_2_size": 256,
-        "lr": 1e-3,
-        "batch_size": 64
-    }
-    train_mnist(config)
+.. literalinclude:: /../../python/ray/tune/examples/mnist_pytorch_lightning.py
+   :language: python
+   :start-after: __no_tune_train_begin__
+   :end-before: __no_tune_train_end__
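
For reference, the ``__no_tune_train_begin__`` block this ``literalinclude`` pulls in is the ``train_mnist_no_tune`` helper added to ``mnist_pytorch_lightning.py`` further down in this PR; it renders as:

.. code-block:: python

    # train_mnist is defined earlier in mnist_pytorch_lightning.py
    def train_mnist_no_tune():
        config = {
            "layer_1_size": 128,
            "layer_2_size": 256,
            "lr": 1e-3,
            "batch_size": 64
        }
        train_mnist(config)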

Tuning the model parameters
---------------------------
@@ -105,13 +100,12 @@ callback for multiple modules.
Ray Tune comes with ready-to-use PyTorch Lightning callbacks. To report metrics
back to Tune after each validation epoch, we will use the ``TuneReportCallback``:

-.. code-block:: python
-
-    from ray.tune.integration.pytorch_lightning import TuneReportCallback
-    callback = TuneReportCallback({
-        "loss": "avg_val_loss",
-        "mean_accuracy": "avg_val_accuracy"
-    }, on="validation_end")
+.. literalinclude:: /../../python/ray/tune/examples/mnist_pytorch_lightning.py
+   :language: python
+   :start-after: __tune_train_begin__
+   :end-before: __tune_train_end__
+   :lines: 12-17
+   :dedent: 12
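
For context, here is a minimal sketch of wiring this callback into a PyTorch Lightning trainer. The metric mapping is taken from the snippet this hunk replaces; the ``pl.Trainer`` arguments are illustrative assumptions, not part of this PR:

.. code-block:: python

    import pytorch_lightning as pl
    from ray.tune.integration.pytorch_lightning import TuneReportCallback

    # Map PyTorch Lightning metric names (values) to the names that
    # Tune will see (keys).
    callback = TuneReportCallback({
        "loss": "avg_val_loss",
        "mean_accuracy": "avg_val_accuracy"
    }, on="validation_end")

    # Assumed trainer setup: the callback fires at the end of every
    # validation epoch and reports both metrics to Tune.
    trainer = pl.Trainer(max_epochs=10, callbacks=[callback])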

This callback will take the ``avg_val_loss`` and ``avg_val_accuracy`` values
from the PyTorch Lightning trainer and report them to Tune as the ``loss``
@@ -135,6 +129,8 @@ TensorBoard, one time for Tune's logs, and another time for PyTorch Lightning's
   :language: python
   :start-after: __tune_train_begin__
   :end-before: __tune_train_end__
+   :lines: 2-8
+   :dedent: 4
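
A sketch of the kind of setup these included lines are likely excerpting, assuming the PyTorch Lightning logger is pointed at Tune's per-trial directory so each run shows up in TensorBoard only once (the exact arguments are assumptions, not quoted from the example file):

.. code-block:: python

    import pytorch_lightning as pl
    from pytorch_lightning.loggers import TensorBoardLogger
    from ray import tune
    from ray.tune.integration.pytorch_lightning import TuneReportCallback

    # Assumption: write Lightning's TensorBoard logs into the directory
    # Tune already uses for this trial, so the run is logged once.
    logger = TensorBoardLogger(
        save_dir=tune.get_trial_dir(), name="", version=".")

    trainer = pl.Trainer(
        max_epochs=10,
        logger=logger,
        callbacks=[TuneReportCallback(
            {"loss": "avg_val_loss", "mean_accuracy": "avg_val_accuracy"},
            on="validation_end")])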


Configuring the search space
@@ -283,13 +279,12 @@ another callback to save model checkpoints. Since Tune requires a call to
``tune.report()`` after creating a new checkpoint to register it, we will use
a combined reporting and checkpointing callback:

-.. code-block:: python
-
-    from ray.tune.integration.pytorch_lightning import TuneReportCheckpointCallback
-    callback = TuneReportCheckpointCallback(
-        metrics={"loss": "val_loss", "mean_accuracy": "val_accuracy"},
-        filename="checkpoint",
-        on="validation_end")
+.. literalinclude:: /../../python/ray/tune/examples/mnist_pytorch_lightning.py
+   :language: python
+   :start-after: __tune_train_checkpoint_begin__
+   :end-before: __tune_train_checkpoint_end__
+   :lines: 15-21
+   :dedent: 12

The ``checkpoint`` value is the name of the checkpoint file within the
checkpoint directory.
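
As with the plain reporting callback, attaching it to the trainer is a one-liner; a short sketch mirroring the snippet this hunk replaces (the trainer arguments are again assumptions):

.. code-block:: python

    import pytorch_lightning as pl
    from ray.tune.integration.pytorch_lightning import (
        TuneReportCheckpointCallback)

    # Report both metrics to Tune and write a checkpoint file named
    # "checkpoint" at the end of each validation epoch.
    callback = TuneReportCheckpointCallback(
        metrics={"loss": "val_loss", "mean_accuracy": "val_accuracy"},
        filename="checkpoint",
        on="validation_end")
    trainer = pl.Trainer(max_epochs=10, callbacks=[callback])
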
12 changes: 11 additions & 1 deletion python/ray/tune/examples/mnist_pytorch_lightning.py
@@ -16,7 +16,6 @@

# __import_tune_begin__
from pytorch_lightning.loggers import TensorBoardLogger
-from pytorch_lightning.utilities.cloud_io import load as pl_load
from ray import tune
from ray.tune import CLIReporter
from ray.tune.schedulers import ASHAScheduler, PopulationBasedTraining
@@ -127,6 +126,17 @@ def train_mnist(config):
    trainer.fit(model)
# __lightning_end__

+# __no_tune_train_begin__
+def train_mnist_no_tune():
+    config = {
+        "layer_1_size": 128,
+        "layer_2_size": 256,
+        "lr": 1e-3,
+        "batch_size": 64
+    }
+    train_mnist(config)
+# __no_tune_train_end__


# __tune_train_begin__
def train_mnist_tune(config, num_epochs=10, num_gpus=0, data_dir="~/data"):