[Tune] [Doc] Tune checkpointing and Tuner restore docfix #29411

Merged · 15 commits · Oct 27, 2022
Changes from 5 commits
24 changes: 24 additions & 0 deletions doc/source/tune/api_docs/checkpointing/class-checkpointing.rst
@@ -0,0 +1,24 @@
Class API Checkpointing
~~~~~~~~~~~~~~~~~~~~~~~

You can also implement checkpoint/restore using the Trainable Class API:

.. code-block:: python

import os

import torch

from ray import air, tune
from ray.tune import Trainable


class MyTrainableClass(Trainable):
    def save_checkpoint(self, tmp_checkpoint_dir):
        checkpoint_path = os.path.join(tmp_checkpoint_dir, "model.pth")
        torch.save(self.model.state_dict(), checkpoint_path)
        return tmp_checkpoint_dir

    def load_checkpoint(self, tmp_checkpoint_dir):
        checkpoint_path = os.path.join(tmp_checkpoint_dir, "model.pth")
        self.model.load_state_dict(torch.load(checkpoint_path))


tuner = tune.Tuner(
    MyTrainableClass,
    run_config=air.RunConfig(
        checkpoint_config=air.CheckpointConfig(checkpoint_frequency=2)
    ),
)
results = tuner.fit()
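
The snippet above assumes ``self.model`` already exists. As a minimal sketch (the model and metric below are placeholders, not part of this PR), it would typically be created in ``setup()``:

.. code-block:: python

    import torch.nn as nn

    class MyTrainableClass(Trainable):
        def setup(self, config):
            # The model referenced by save_checkpoint/load_checkpoint above
            self.model = nn.Linear(in_features=8, out_features=1)

        def step(self):
            return {"score": 0.0}  # dummy metric so the trial reports something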
35 changes: 35 additions & 0 deletions doc/source/tune/api_docs/checkpointing/function-checkpointing.rst
@@ -0,0 +1,35 @@
Function API Checkpointing
~~~~~~~~~~~~~~~~~~~~~~~~~~

Many Tune features rely on checkpointing, including the usage of certain Trial Schedulers and fault tolerance.
You can save and load checkpoints in Ray Tune in the following manner:

.. code-block:: python

import time
from ray import tune
from ray.air import session
from ray.air.checkpoint import Checkpoint

def train_func(config):
    step = 0
    loaded_checkpoint = session.get_checkpoint()
    if loaded_checkpoint:
        last_step = loaded_checkpoint.to_dict()["step"]
        step = last_step + 1

    for step in range(step, 100):
        time.sleep(1)

        checkpoint = Checkpoint.from_dict({"step": step})
        session.report({"message": "Hello world Ray Tune!"}, checkpoint=checkpoint)

tuner = tune.Tuner(train_func)
results = tuner.fit()

.. note:: ``checkpoint_frequency`` and ``checkpoint_at_end`` will not work with Function API checkpointing.

In this example, checkpoints will be saved by training iteration to ``<local_dir>/<exp_name>/<trial_name>/checkpoint_<step>``.

Tune may also copy or move checkpoints during the course of tuning. For this reason,
it is important not to depend on absolute paths in your checkpointing implementation.
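
As a hedged sketch of this recommendation (the helper below is illustrative and not part of this PR), directory-based checkpoints should be written under a directory created relative to the trial's working directory and handed to Tune as a whole, rather than recording absolute paths inside the checkpoint:

.. code-block:: python

    import os

    import torch
    from ray.air.checkpoint import Checkpoint

    def save_torch_checkpoint(model, step):
        ckpt_dir = f"checkpoint_tmp_{step}"  # relative to the trial's working directory
        os.makedirs(ckpt_dir, exist_ok=True)
        torch.save(model.state_dict(), os.path.join(ckpt_dir, "model.pth"))
        # Tune may later copy or move this directory, so only the directory
        # handle is returned, never an absolute path stored inside it.
        return Checkpoint.from_directory(ckpt_dir)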
57 changes: 2 additions & 55 deletions doc/source/tune/api_docs/trainable.rst
@@ -71,44 +71,9 @@ such as ``iterations_since_restore``. See :ref:`tune-autofilled-metrics` for an

print("best config: ", results.get_best_result(metric="score", mode="max").config)


.. _tune-function-checkpointing:

Function API Checkpointing
~~~~~~~~~~~~~~~~~~~~~~~~~~

Many Tune features rely on checkpointing, including the usage of certain Trial Schedulers and fault tolerance.
You can save and load checkpoint in Ray Tune in the following manner:

.. code-block:: python

import time
from ray import tune
from ray.air import session
from ray.air.checkpoint import Checkpoint

def train_func(config):
step = 0
loaded_checkpoint = session.get_checkpoint()
if loaded_checkpoint:
last_step = loaded_checkpoint.to_dict()["step"]
step = last_step + 1

for iter in range(step, 100):
time.sleep(1)

checkpoint = Checkpoint.from_dict({"step": step})
session.report({"message": "Hello world Ray Tune!"}, checkpoint=checkpoint)

tuner = tune.Tuner(train_func)
results = tuner.fit()

.. note:: ``checkpoint_frequency`` and ``checkpoint_at_end`` will not work with Function API checkpointing.

In this example, checkpoints will be saved by training iteration to ``local_dir/exp_name/trial_name/checkpoint_<step>``.

Tune also may copy or move checkpoints during the course of tuning. For this purpose,
it is important not to depend on absolute paths in the implementation of ``save``.
.. include:: checkpointing/function-checkpointing.rst
Contributor:
Why break this out into a separate file?

Contributor Author:
I reuse this section in the Tune "working with checkpoints" user guide, since that's where I would intuitively look for an example of how to actually checkpoint. I've commented where that is below.

Contributor:
Can you instead link it, instead of re-rendering it in two places?

Otherwise for example, the search results are going to get cluttered.

Contributor:
So concretely, don't break it out into a separate file, put it in the "working with checkpoints" part, and use a relative reference here linking to that "working with checkpoints" section


.. _tune-class-api:

@@ -164,25 +129,7 @@ See :ref:`tune-autofilled-metrics` for an explanation/glossary of these values.

.. _tune-trainable-save-restore:

Class API Checkpointing
~~~~~~~~~~~~~~~~~~~~~~~

You can also implement checkpoint/restore using the Trainable Class API:

.. code-block:: python

class MyTrainableClass(Trainable):
def save_checkpoint(self, tmp_checkpoint_dir):
checkpoint_path = os.path.join(tmp_checkpoint_dir, "model.pth")
torch.save(self.model.state_dict(), checkpoint_path)
return tmp_checkpoint_dir

def load_checkpoint(self, tmp_checkpoint_dir):
checkpoint_path = os.path.join(tmp_checkpoint_dir, "model.pth")
self.model.load_state_dict(torch.load(checkpoint_path))

tuner = tune.Tuner(MyTrainableClass, run_config=air.RunConfig(checkpoint_config=air.CheckpointConfig(checkpoint_frequency=2)))
results = tuner.fit()
.. include:: checkpointing/class-checkpointing.rst

You can checkpoint with three different mechanisms: manually, periodically, and at termination.
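
As a hedged sketch (the parameter values are illustrative, not from this PR), the periodic and at-termination mechanisms are configured through ``air.CheckpointConfig``:

.. code-block:: python

    from ray import air

    checkpoint_config = air.CheckpointConfig(
        checkpoint_frequency=10,  # periodically: every 10 training iterations
        checkpoint_at_end=True,   # at termination of the trial
    )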

113 changes: 67 additions & 46 deletions doc/source/tune/tutorials/tune-checkpoints.rst
@@ -31,15 +31,30 @@ Commonly, this includes the model and optimizer states. This is useful mostly fo
the meantime. This only makes sense if the trials can then continue training from the latest state.
- The checkpoint can be later used for other downstream tasks like batch inference.
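
For the downstream-task case, a minimal sketch (assuming the Ray AIR ``Result`` API and a metric named ``score``) of pulling a trial checkpoint out of the results:

.. code-block:: python

    best_result = results.get_best_result(metric="score", mode="max")
    best_checkpoint = best_result.checkpoint  # a ray.air.Checkpoint
    state = best_checkpoint.to_dict()         # e.g., hand this to a batch-inference job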

Everything that is reported by ``session.report()`` is a trial-level checkpoint.
See :ref:`here for more information on saving checkpoints <air-checkpoint-ref>`.
Every checkpoint saved via ``session.report()`` (if using the Function API) or
``Trainable.save_checkpoint`` (if using the Class API) is a **trial-level checkpoint**.
See below for examples of saving and loading trial-level checkpoints.


How do I save and load trial checkpoints?
Contributor Author:
Here are the examples again within the Tune user guide.

-----------------------------------------

.. include:: ../api_docs/checkpointing/function-checkpointing.rst

.. include:: ../api_docs/checkpointing/class-checkpointing.rst

See :ref:`here for more information on creating checkpoints <air-checkpoint-ref>`.
If using framework-specific trainers from Ray AIR, see :ref:`here <air-trainer-ref>` for
references to framework-specific checkpoints such as ``TensorflowCheckpoint``.
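
As a hedged sketch (assuming TensorFlow and the ``ray.train.tensorflow.TensorflowCheckpoint`` helper), a framework-specific checkpoint is created directly from the framework object:

.. code-block:: python

    import tensorflow as tf
    from ray.train.tensorflow import TensorflowCheckpoint

    model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(8,))])
    checkpoint = TensorflowCheckpoint.from_model(model)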

.. _tune-checkpoint-syncing:

Checkpointing and synchronization
---------------------------------

This topic is mostly relevant to Trial checkpoint.
.. note::

This topic is relevant to trial checkpoints.

Tune stores checkpoints on the node where the trials are executed. If you are training on more than one node,
this means that some trial checkpoints may be on the head node and others are not.
@@ -108,7 +123,9 @@ This will automatically store both the experiment state and the trial checkpoint
name="experiment_name",
sync_config=tune.SyncConfig(
upload_dir="s3://bucket-name/sub-path/"
)))
)
)
)
tuner.fit()

We don't have to provide a ``syncer`` here as it will be automatically detected. However, you can provide
@@ -126,7 +143,8 @@ a string if you want to use a custom command:
sync_config=tune.SyncConfig(
upload_dir="s3://bucket-name/sub-path/",
syncer="aws s3 sync {source} {target}", # Custom sync command
)),
)
)
)
tuner.fit()

@@ -191,7 +209,8 @@ Alternatively, a function can be provided with the following signature:
syncer=custom_sync_func,
sync_period=60 # Synchronize more often
)
))
)
)
results = tuner.fit()

When syncing results back to the driver, the source would be a path similar to
@@ -230,11 +249,13 @@ Your ``my_trainable`` is either a:

2. **Custom training function**

* All this means is that your function needs to take care of saving and loading from checkpoint.
For saving, this is done through ``session.report()`` API, which can take in a ``Checkpoint`` object.
For loading, your function can access existing checkpoint through ``Session.get_checkpoint()`` API.
See :doc:`this example </tune/examples/includes/custom_func_checkpointing>`,
it's quite simple to do.
All this means is that your function needs to take care of saving and loading checkpoints.

* For saving, this is done through the :meth:`session.report() <ray.air.session.report>` API, which can take in a ``Checkpoint`` object.

* For loading, your function can access an existing checkpoint through the :meth:`session.get_checkpoint() <ray.air.session.get_checkpoint>` API.

* See :doc:`this example </tune/examples/includes/custom_func_checkpointing>` for reference.

Let's assume for this example you're running this script from your laptop, and connecting to your remote Ray cluster
via ``ray.init()``, making your script on your laptop the "driver".
@@ -247,31 +268,29 @@ via ``ray.init()``, making your script on your laptop the "driver".

ray.init(address="<cluster-IP>:<port>") # set `address=None` to train on laptop

# configure how checkpoints are sync'd to the scheduler/sampler
# we recommend cloud storage checkpointing as it survives the cluster when
# instances are terminated, and has better performance
# Configure how checkpoints are sync'd to the scheduler/sampler
# We recommend cloud storage checkpointing, as checkpoints persist even when
# cluster instances are terminated, and it has better performance
sync_config = tune.SyncConfig(
upload_dir="s3://my-checkpoints-bucket/path/", # requires AWS credentials
)

# this starts the run!
# This starts the run!
tuner = tune.Tuner(
my_trainable,
run_config=air.RunConfig(
# name of your experiment
# if this experiment exists, we will resume from the last run
# as specified by
# Name of your experiment
name="my-tune-exp",
# a directory where results are stored before being
# Directory where each node's results are stored before being
# sync'd to head node/cloud storage
local_dir="/tmp/mypath",
# see above! we will sync our checkpoints to S3 directory
# See above! We will sync our checkpoints to the S3 directory
sync_config=sync_config,
checkpoint_config=air.CheckpointConfig(
# we'll keep the best five checkpoints at all times
# We'll keep the five best checkpoints at all times,
# ranked by the AUC score reported by the trainable (descending)
checkpoint_score_attr="max-auc",
keep_checkpoints_num=5,
checkpoint_score_attribute="auc",
checkpoint_score_order="max",
num_to_keep=5,
),
),
)
@@ -281,13 +300,24 @@ In this example, checkpoints will be saved:

* **Locally**: not saved! Nothing will be sync'd to the driver (your laptop) automatically (because cloud syncing is enabled)
* **S3**: ``s3://my-checkpoints-bucket/path/my-tune-exp/<trial_name>/checkpoint_<step>``
* **On head node**: ``~/ray-results/my-tune-exp/<trial_name>/checkpoint_<step>`` (but only for trials done on that node)
* **On workers nodes**: ``~/ray-results/my-tune-exp/<trial_name>/checkpoint_<step>`` (but only for trials done on that node)
* **On head node**: ``/tmp/mypath/my-tune-exp/<trial_name>/checkpoint_<step>`` (but only for trials done on that node)
* **On worker nodes**: ``/tmp/mypath/my-tune-exp/<trial_name>/checkpoint_<step>`` (but only for trials done on that node)

If this run stopped for any reason (finished, errored, user CTRL+C), you can restart it any time using experiment checkpoints saved in the cloud:

.. code-block:: python

tuner = Tuner.restore(
"s3://my-checkpoints-bucket/path/my-tune-exp",
resume_errored=True
)
tuner.fit()


If your run stopped for any reason (finished, errored, user CTRL+C), you can restart it any time by
``tuner=Tuner.restore(experiment_checkpoint_dir).fit()``.
There are a few options for restoring an experiment:
"resume_unfinished", "resume_errored" and "restart_errored". See ``Tuner.restore()`` for more details.
``resume_unfinished``, ``resume_errored`` and ``restart_errored``.
Please see the documentation of
:meth:`Tuner.restore() <ray.tune.tuner.Tuner.restore>` for more details.
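
As a hedged sketch (the defaults shown are assumptions; check the API reference), these options are passed directly to ``Tuner.restore()``:

.. code-block:: python

    tuner = Tuner.restore(
        "s3://my-checkpoints-bucket/path/my-tune-exp",
        resume_unfinished=True,   # pick up trials that had not finished
        resume_errored=False,     # if True, resume errored trials from their latest checkpoint
        restart_errored=False,    # if True, restart errored trials from scratch
    )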

.. _rsync-checkpointing:

@@ -298,7 +328,7 @@ Local or rsync checkpointing can be a good option if:

1. You want to tune on a single laptop Ray cluster
2. You aren't using Ray on Kubernetes (rsync doesn't work with Ray on Kubernetes)
3. You don't want to use S3
3. You don't want to use cloud storage (e.g., S3)

Let's take a look at an example:

@@ -310,32 +340,23 @@ Let's take a look at an example:

ray.init(address="<cluster-IP>:<port>") # set `address=None` to train on laptop

# configure how checkpoints are sync'd to the scheduler/sampler
sync_config = tune.syncConfig() # the default mode is to use use rsync
# Configure how checkpoints are sync'd to the scheduler/sampler
sync_config = tune.SyncConfig()  # the default mode is to use rsync

# this starts the run!
# This starts the run!
tuner = tune.Tuner(
my_trainable,

run_config=air.RunConfig(
# name of your experiment
# If the experiment with the same name is already run,
# Tuner willl resume from the last run specified by sync_config(if one exists).
# Otherwise, will start a new run.
name="my-tune-exp",
# a directory where results are stored before being
# sync'd to head node/cloud storage
local_dir="/tmp/mypath",
# sync our checkpoints via rsync
# you don't have to pass an empty sync config - but we
# Sync our checkpoints via rsync
# You don't have to pass an empty sync config - but we
# do it here for clarity and comparison
sync_config=sync_config,
checkpoint_config=air.CheckpointConfig(
# We'll keep the five best checkpoints at all times,
# ranked by the AUC score reported by the trainable (descending)
checkpoint_score_attr="max-auc",
keep_checkpoints_num=5,
)
checkpoint_score_attribute="auc",
checkpoint_score_order="max",
num_to_keep=5,
),
)
)

4 changes: 1 addition & 3 deletions doc/source/tune/tutorials/tune-distributed.rst
@@ -225,9 +225,7 @@ If the trial/actor is placed on a different node, Tune will automatically push t
Recovering From Failures
~~~~~~~~~~~~~~~~~~~~~~~~

Tune automatically persists the progress of your entire experiment (a ``Tuner.fit()`` session), so if an experiment crashes or is otherwise cancelled, it can be resumed through ``Tuner.restore()``.
There are a few options for restoring an experiment:
"resume_unfinished", "resume_errored" and "restart_errored". See ``Tuner.restore()`` for more details.
Tune automatically persists the progress of your entire experiment (a ``Tuner.fit()`` session), so if an experiment crashes or is otherwise cancelled, it can be resumed through :meth:`Tuner.restore() <ray.tune.tuner.Tuner.restore>`.
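
A minimal sketch of resuming such a session (the experiment path is illustrative):

.. code-block:: python

    from ray.tune import Tuner

    # Point at the experiment directory: <local_dir>/<name>, or the cloud upload_dir
    tuner = Tuner.restore("~/ray_results/my-tune-exp", resume_errored=True)
    results = tuner.fit()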

.. _tune-distributed-common:

25 changes: 13 additions & 12 deletions doc/source/tune/tutorials/tune-stopping.rst
@@ -20,15 +20,16 @@ If you've stopped a run and want to resume from where you left off,
you can then call ``Tuner.restore()`` like this:

.. code-block:: python
:emphasize-lines: 4

tuner = Tuner.restore(
path="~/ray_results/my_experiment"
)
tuner.fit()

There are a few options for resuming an experiment:
"resume_unfinished", "resume_errored" and "restart_errored". See ``Tuner.restore()`` for more details.
There are a few options for restoring an experiment:
``resume_unfinished``, ``resume_errored`` and ``restart_errored``.
Please see the documentation of
:meth:`Tuner.restore() <ray.tune.tuner.Tuner.restore>` for more details.

``path`` here is determined by the ``air.RunConfig.name`` you supplied to your ``Tuner()``.
If you didn't supply a name to ``Tuner``, it is likely that your ``path`` looks something like:
@@ -48,18 +49,18 @@ of your original tuning run:
Number of trials: 1/1 (1 RUNNING)

What's happening under the hood?
--------------------------------
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

:ref:`Here <tune-two-types-of-ckpt>`, we describe the two types of Tune checkpoints:
experiment-level and trial-level checkpoints.

Upon resuming an interrupted/errored Tune run:

:ref:`Here <tune-two-types-of-ckpt>` we talked about two types of Tune checkpoints.
Both checkpoints come into play when resuming a Tune run.
#. Tune first looks at the experiment-level checkpoint to find the list of trials at the time of the interruption.

When resuming an interrupted/errored Tune run, Tune first looks at the experiment-level checkpoint
to find the list of trials at the time of the interruption. Ray Tune then locates the trial-level
checkpoint of each trial.
#. Tune then locates and restores from the trial-level checkpoint of each trial.

Depending on the specified resume option
("resume_unfinished", "resume_errored", "restart_errored"), Ray Tune then decides whether to
restore a given non-finished trial from its latest available checkpoint or start from scratch.
#. Depending on the specified resume option (``resume_unfinished``, ``resume_errored``, ``restart_errored``), Tune decides whether to restore a given unfinished trial from its latest available checkpoint or to start from scratch.

.. _tune-stopping-ref:
