[Tune] [Doc] Tune checkpointing and Tuner restore docfix #29411

Merged
15 commits merged on Oct 27, 2022
Changes from 6 commits
9 changes: 9 additions & 0 deletions doc/source/tune/api_docs/checkpointing/class-checkpointing.rst
@@ -0,0 +1,9 @@
Class API Checkpointing
~~~~~~~~~~~~~~~~~~~~~~~

You can also implement checkpoint/restore using the Trainable Class API:

.. literalinclude:: /tune/doc_code/trainable.py
    :language: python
    :start-after: __class_api_checkpointing_start__
    :end-before: __class_api_checkpointing_end__
17 changes: 17 additions & 0 deletions doc/source/tune/api_docs/checkpointing/function-checkpointing.rst
@@ -0,0 +1,17 @@
Function API Checkpointing
~~~~~~~~~~~~~~~~~~~~~~~~~~

Many Tune features rely on checkpointing, including the usage of certain Trial Schedulers and fault tolerance.
You can save and load checkpoints in Ray Tune in the following manner:

.. literalinclude:: /tune/doc_code/trainable.py
    :language: python
    :start-after: __function_api_checkpointing_start__
    :end-before: __function_api_checkpointing_end__

.. note:: ``checkpoint_frequency`` and ``checkpoint_at_end`` will not work with Function API checkpointing.

In this example, checkpoints will be saved by training iteration to ``<local_dir>/<exp_name>/<trial_name>/checkpoint_<step>``.

Tune may also copy or move checkpoints during the course of tuning, so it is important
not to depend on absolute paths in the implementation of ``save``.
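
For illustration only (a minimal sketch, not part of this change), a function trainable can write
its checkpoint files into a scratch directory and hand that directory to Tune via
``Checkpoint.from_directory``, rather than recording absolute paths inside the checkpoint.
The ``train_func``, metric, and file name below are made up for this example:

.. code-block:: python

    import os
    import tempfile

    from ray import tune
    from ray.air import session
    from ray.air.checkpoint import Checkpoint


    def train_func(config):
        for step in range(3):
            # Write checkpoint files into a scratch directory and let Tune decide
            # where the checkpoint ultimately lives; do not record absolute paths
            # inside the checkpoint contents themselves.
            scratch_dir = tempfile.mkdtemp()
            with open(os.path.join(scratch_dir, "step.txt"), "w") as f:
                f.write(str(step))
            session.report(
                {"score": step},
                checkpoint=Checkpoint.from_directory(scratch_dir),
            )


    tuner = tune.Tuner(train_func)
    results = tuner.fit()

Tune then places the reported checkpoint under the trial directory shown above and may move it
later, so restoring does not depend on where the files were originally written.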
148 changes: 23 additions & 125 deletions doc/source/tune/api_docs/trainable.rst
@@ -7,14 +7,14 @@
Training (tune.Trainable, session.report)
==========================================

Training can be done with either a **Class API** (``tune.Trainable``) or **function API** (``session.report``).
Training can be done with either a **Class API** (:ref:`tune.Trainable <tune-trainable-docstring>`) or **function API** (:ref:`session.report <tune-function-docstring>`).

For the sake of example, let's maximize this objective function:

.. code-block:: python

    def objective(x, a, b):
        return a * (x ** 0.5) + b
.. literalinclude:: /tune/doc_code/trainable.py
    :language: python
    :start-after: __example_objective_start__
    :end-before: __example_objective_end__

.. _tune-function-api:

@@ -23,27 +23,10 @@ Function API

With the Function API, you can report intermediate metrics by simply calling ``session.report`` within the provided function.


.. code-block:: python

    from ray import tune
    from ray.air import session

    def trainable(config):
        # config (dict): A dict of hyperparameters.

        for x in range(20):
            intermediate_score = objective(x, config["a"], config["b"])

            session.report({"score": intermediate_score})  # This sends the score to Tune.

    tuner = tune.Tuner(
        trainable,
        param_space={"a": 2, "b": 4}
    )
    results = tuner.fit()

    print("best config: ", results.get_best_result(metric="score", mode="max").config)
.. literalinclude:: /tune/doc_code/trainable.py
    :language: python
    :start-after: __function_api_report_intermediate_metrics_start__
    :end-before: __function_api_report_intermediate_metrics_end__

.. tip:: Do not use ``session.report`` within a ``Trainable`` class.

@@ -52,63 +35,17 @@ Tune will run this function on a separate thread in a Ray actor process.
You'll notice that Ray Tune will output extra values in addition to the user reported metrics,
such as ``iterations_since_restore``. See :ref:`tune-autofilled-metrics` for an explanation/glossary of these values.

.. code-block:: python

    def trainable(config):
        # config (dict): A dict of hyperparameters.

        final_score = 0
        for x in range(20):
            final_score = objective(x, config["a"], config["b"])

        return {"score": final_score}  # This sends the score to Tune.

    tuner = tune.Tuner(
        trainable,
        param_space={"a": 2, "b": 4}
    )
    results = tuner.fit()

    print("best config: ", results.get_best_result(metric="score", mode="max").config)
In the previous example, we reported on every step, but this metric reporting frequency
is configurable. For example, we could also report only a single time at the end with the final score:

.. literalinclude:: /tune/doc_code/trainable.py
    :language: python
    :start-after: __function_api_report_final_metrics_start__
    :end-before: __function_api_report_final_metrics_end__

.. _tune-function-checkpointing:

Function API Checkpointing
~~~~~~~~~~~~~~~~~~~~~~~~~~

Many Tune features rely on checkpointing, including the usage of certain Trial Schedulers and fault tolerance.
You can save and load checkpoints in Ray Tune in the following manner:

.. code-block:: python

    import time
    from ray import tune
    from ray.air import session
    from ray.air.checkpoint import Checkpoint

    def train_func(config):
        step = 0
        loaded_checkpoint = session.get_checkpoint()
        if loaded_checkpoint:
            last_step = loaded_checkpoint.to_dict()["step"]
            step = last_step + 1

        for iter in range(step, 100):
            time.sleep(1)

            checkpoint = Checkpoint.from_dict({"step": step})
            session.report({"message": "Hello world Ray Tune!"}, checkpoint=checkpoint)

    tuner = tune.Tuner(train_func)
    results = tuner.fit()

.. note:: ``checkpoint_frequency`` and ``checkpoint_at_end`` will not work with Function API checkpointing.

In this example, checkpoints will be saved by training iteration to ``local_dir/exp_name/trial_name/checkpoint_<step>``.

Tune also may copy or move checkpoints during the course of tuning. For this purpose,
it is important not to depend on absolute paths in the implementation of ``save``.
.. include:: checkpointing/function-checkpointing.rst
Contributor:

Why break this out into a separate file?

Contributor Author:

I reuse this section in the Tune "working with checkpoints" user guide, since that's where I would intuitively look for an example of how to actually checkpoint. I've commented where that is below.

Contributor:

Can you instead link it, instead of re-rendering it in two places?

Otherwise for example, the search results are going to get cluttered.

Contributor:

So concretely, don't break it out into a separate file, put it in the "working with checkpoints" part, and use a relative reference here linking to that "working with checkpoints" section.


.. _tune-class-api:

@@ -119,32 +56,10 @@ Trainable Class API

The Trainable **class API** will require users to subclass ``ray.tune.Trainable``. Here's a naive example of this API:

.. code-block:: python

    from ray import air, tune

    class Trainable(tune.Trainable):
        def setup(self, config):
            # config (dict): A dict of hyperparameters
            self.x = 0
            self.a = config["a"]
            self.b = config["b"]

        def step(self):  # This is called iteratively.
            score = objective(self.x, self.a, self.b)
            self.x += 1
            return {"score": score}

    tuner = tune.Tuner(
        Trainable,
        run_config=air.RunConfig(stop={"training_iteration": 20}),
        param_space={
            "a": 2,
            "b": 4
        })
    results = tuner.fit()

    print("best config: ", results.get_best_result(metric="score", mode="max").config)
.. literalinclude:: /tune/doc_code/trainable.py
    :language: python
    :start-after: __class_api_example_start__
    :end-before: __class_api_example_end__

As a subclass of ``tune.Trainable``, Tune will create a ``Trainable`` object on a
separate process (using the :ref:`Ray Actor API <actor-guide>`).
@@ -164,25 +79,7 @@ See :ref:`tune-autofilled-metrics` for an explanation/glossary of these values.

.. _tune-trainable-save-restore:

Class API Checkpointing
~~~~~~~~~~~~~~~~~~~~~~~

You can also implement checkpoint/restore using the Trainable Class API:

.. code-block:: python

    class MyTrainableClass(Trainable):
        def save_checkpoint(self, tmp_checkpoint_dir):
            checkpoint_path = os.path.join(tmp_checkpoint_dir, "model.pth")
            torch.save(self.model.state_dict(), checkpoint_path)
            return tmp_checkpoint_dir

        def load_checkpoint(self, tmp_checkpoint_dir):
            checkpoint_path = os.path.join(tmp_checkpoint_dir, "model.pth")
            self.model.load_state_dict(torch.load(checkpoint_path))

    tuner = tune.Tuner(MyTrainableClass, run_config=air.RunConfig(checkpoint_config=air.CheckpointConfig(checkpoint_frequency=2)))
    results = tuner.fit()
.. include:: checkpointing/class-checkpointing.rst

You can checkpoint with three different mechanisms: manually, periodically, and at termination.
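
For instance, periodic checkpointing and checkpointing at termination can be requested through
``air.CheckpointConfig`` (a minimal sketch, reusing the ``MyTrainableClass`` defined above; the
frequency value is arbitrary):

.. code-block:: python

    from ray import air, tune

    tuner = tune.Tuner(
        MyTrainableClass,
        run_config=air.RunConfig(
            stop={"training_iteration": 20},
            checkpoint_config=air.CheckpointConfig(
                checkpoint_frequency=5,  # checkpoint every 5 training iterations
                checkpoint_at_end=True,  # and once more when the trial stops
            ),
        ),
    )
    results = tuner.fit()

Manual checkpointing with the Function API happens whenever ``session.report`` is called with a
``checkpoint``, as shown earlier.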

@@ -327,10 +224,11 @@ session (Function API)
.. autofunction:: ray.air.session.get_trial_resources
    :noindex:

.. _tune-trainable-docstring:

tune.Trainable (Class API)
--------------------------


.. autoclass:: ray.tune.Trainable
    :member-order: groupwise
    :private-members:
138 changes: 138 additions & 0 deletions doc/source/tune/doc_code/trainable.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
# __class_api_checkpointing_start__
import os
import torch
from torch import nn

from ray import air, tune


class MyTrainableClass(tune.Trainable):
    def setup(self, config):
        self.model = nn.Sequential(
            nn.Linear(config.get("input_size", 32), 32), nn.ReLU(), nn.Linear(32, 10)
        )

    def step(self):
        return {}

    def save_checkpoint(self, tmp_checkpoint_dir):
        checkpoint_path = os.path.join(tmp_checkpoint_dir, "model.pth")
        torch.save(self.model.state_dict(), checkpoint_path)
        return tmp_checkpoint_dir

    def load_checkpoint(self, tmp_checkpoint_dir):
        checkpoint_path = os.path.join(tmp_checkpoint_dir, "model.pth")
        self.model.load_state_dict(torch.load(checkpoint_path))


tuner = tune.Tuner(
    MyTrainableClass,
    param_space={"input_size": 64},
    run_config=air.RunConfig(
        stop={"training_iteration": 2},
        checkpoint_config=air.CheckpointConfig(checkpoint_frequency=2),
    ),
)
tuner.fit()
# __class_api_checkpointing_end__

# __function_api_checkpointing_start__
from ray import tune
from ray.air import session
from ray.air.checkpoint import Checkpoint


def train_func(config):
    epochs = config.get("epochs", 2)
    start = 0
    loaded_checkpoint = session.get_checkpoint()
    if loaded_checkpoint:
        last_step = loaded_checkpoint.to_dict()["step"]
        start = last_step + 1

    for step in range(start, epochs):
        # Model training here
        # ...

        # Report metrics and save a checkpoint
        metrics = {"metric": "my_metric"}
        checkpoint = Checkpoint.from_dict({"step": step})
        session.report(metrics, checkpoint=checkpoint)


tuner = tune.Tuner(train_func)
results = tuner.fit()
# __function_api_checkpointing_end__


# __example_objective_start__
def objective(x, a, b):
    return a * (x ** 0.5) + b


# __example_objective_end__

# __function_api_report_intermediate_metrics_start__
from ray import tune
from ray.air import session


def trainable(config: dict):
    intermediate_score = 0
    for x in range(20):
        intermediate_score = objective(x, config["a"], config["b"])
        session.report({"score": intermediate_score})  # This sends the score to Tune.


tuner = tune.Tuner(trainable, param_space={"a": 2, "b": 4})
results = tuner.fit()
# __function_api_report_intermediate_metrics_end__

# __function_api_report_final_metrics_start__
from ray import tune
from ray.air import session


def trainable(config: dict):
    final_score = 0
    for x in range(20):
        final_score = objective(x, config["a"], config["b"])

    session.report({"score": final_score})  # This sends the score to Tune.


tuner = tune.Tuner(trainable, param_space={"a": 2, "b": 4})
results = tuner.fit()
# __function_api_report_final_metrics_end__

# __class_api_example_start__
from ray import air, tune


class Trainable(tune.Trainable):
    def setup(self, config: dict):
        # config (dict): A dict of hyperparameters
        self.x = 0
        self.a = config["a"]
        self.b = config["b"]

    def step(self):  # This is called iteratively.
        score = objective(self.x, self.a, self.b)
        self.x += 1
        return {"score": score}


tuner = tune.Tuner(
    Trainable,
    run_config=air.RunConfig(
        # Train for 20 steps
        stop={"training_iteration": 20},
        checkpoint_config=air.CheckpointConfig(
            # We haven't implemented checkpointing yet. See below!
            checkpoint_at_end=False
        ),
    ),
    param_space={"a": 2, "b": 4},
)
results = tuner.fit()
# __class_api_example_end__