
Commit 1d16a62

[Tune] [Doc] Tune checkpointing and Tuner restore docfix (ray-project#29411)

Signed-off-by: Weichen Xu <[email protected]>
justinvyu authored and WeichenXu123 committed Dec 19, 2022
1 parent 68bec1b commit 1d16a62
Showing 6 changed files with 306 additions and 182 deletions.
170 changes: 60 additions & 110 deletions doc/source/tune/api_docs/trainable.rst
@@ -7,109 +7,74 @@
Training (tune.Trainable, session.report)
==========================================

Training can be done with either a **Function API** (:ref:`session.report <tune-function-docstring>`) or **Class API** (:ref:`tune.Trainable <tune-trainable-docstring>`).

For the sake of example, let's maximize this objective function:

.. literalinclude:: /tune/doc_code/trainable.py
    :language: python
    :start-after: __example_objective_start__
    :end-before: __example_objective_end__

.. _tune-function-api:

Function API
------------

The Function API allows you to define a custom training function that Tune will run in parallel Ray actor processes,
one for each Tune trial.

With the Function API, you can report intermediate metrics by simply calling ``session.report`` within the function.

.. literalinclude:: /tune/doc_code/trainable.py
    :language: python
    :start-after: __function_api_report_intermediate_metrics_start__
    :end-before: __function_api_report_intermediate_metrics_end__

.. tip:: Do not use ``session.report`` within a ``Trainable`` class.

In the previous example, we reported on every step, but this metric reporting frequency
is configurable. For example, we could also report only a single time at the end with the final score:

.. literalinclude:: /tune/doc_code/trainable.py
    :language: python
    :start-after: __function_api_report_final_metrics_start__
    :end-before: __function_api_report_final_metrics_end__

It's also possible to return a final set of metrics to Tune by returning them from your function:

.. literalinclude:: /tune/doc_code/trainable.py
    :language: python
    :start-after: __function_api_return_final_metrics_start__
    :end-before: __function_api_return_final_metrics_end__

You'll notice that Ray Tune will output extra values in addition to the user-reported metrics,
such as ``iterations_since_restore``. See :ref:`tune-autofilled-metrics` for an explanation/glossary of these values.
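
For instance, these autofilled fields show up alongside your own metrics in the results. The following is a minimal sketch, assuming the ``results`` object from the intermediate-metrics example above; the exact column selection is illustrative:

.. code-block:: python

    # Minimal sketch: the results dataframe contains user-reported metrics
    # ("score") alongside autofilled ones such as "training_iteration"
    # and "iterations_since_restore".
    df = results.get_dataframe()
    print(df[["score", "training_iteration", "iterations_since_restore"]])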

.. _tune-function-checkpointing:

Function API Checkpointing
~~~~~~~~~~~~~~~~~~~~~~~~~~

Many Tune features rely on checkpointing, including the usage of certain Trial Schedulers and fault tolerance.
You can save and load checkpoints in Ray Tune in the following manner:

.. literalinclude:: /tune/doc_code/trainable.py
    :language: python
    :start-after: __function_api_checkpointing_start__
    :end-before: __function_api_checkpointing_end__

.. note:: ``checkpoint_frequency`` and ``checkpoint_at_end`` will not work with Function API checkpointing.

In this example, checkpoints will be saved by training iteration to ``<local_dir>/<exp_name>/<trial_name>/checkpoint_<step>``.

Tune may also copy or move checkpoints during the course of tuning. For this purpose,
it is important not to depend on absolute paths in the implementation of ``save``.
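
For the Function API, a practical consequence is to put the state itself into the ``Checkpoint`` (or paths relative to the checkpoint directory) rather than absolute paths. The following is a minimal sketch in which ``model_state`` and ``score`` are illustrative placeholders:

.. code-block:: python

    # Minimal sketch: store the state itself in the checkpoint, not an
    # absolute path that may become stale if the checkpoint is moved.
    checkpoint = Checkpoint.from_dict({"step": step, "model_state": model_state})
    session.report({"score": score}, checkpoint=checkpoint)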

If using framework-specific trainers from Ray AIR, see :ref:`here <air-trainer-ref>` for
references to framework-specific checkpoints such as `TensorflowCheckpoint`.
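
As a rough sketch of what that can look like, assuming a Keras ``model`` is available; ``TensorflowCheckpoint.from_model`` is used here as an illustration and its import path may differ across Ray versions:

.. code-block:: python

    # Illustrative only: wrap a framework model in a framework-specific
    # checkpoint and report it like any other checkpoint.
    from ray.train.tensorflow import TensorflowCheckpoint

    checkpoint = TensorflowCheckpoint.from_model(model)
    session.report({"score": score}, checkpoint=checkpoint)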

.. _tune-class-api:

Trainable Class API
-------------------

@@ -119,32 +84,10 @@ Trainable Class API

The Trainable **class API** will require users to subclass ``ray.tune.Trainable``. Here's a naive example of this API:

.. literalinclude:: /tune/doc_code/trainable.py
    :language: python
    :start-after: __class_api_example_start__
    :end-before: __class_api_example_end__

As a subclass of ``tune.Trainable``, Tune will create a ``Trainable`` object on a
separate process (using the :ref:`Ray Actor API <actor-guide>`).
@@ -169,20 +112,10 @@ Class API Checkpointing

You can also implement checkpoint/restore using the Trainable Class API:

.. literalinclude:: /tune/doc_code/trainable.py
    :language: python
    :start-after: __class_api_checkpointing_start__
    :end-before: __class_api_checkpointing_end__

You can checkpoint with three different mechanisms: manually, periodically, and at termination.
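
For the class API, the periodic and at-termination variants can be configured through ``air.CheckpointConfig``. A minimal sketch, reusing ``MyTrainableClass`` from the example above:

.. code-block:: python

    # Minimal sketch: checkpoint every 2 training iterations and once more
    # when each trial terminates.
    tuner = tune.Tuner(
        MyTrainableClass,
        run_config=air.RunConfig(
            checkpoint_config=air.CheckpointConfig(
                checkpoint_frequency=2,
                checkpoint_at_end=True,
            ),
        ),
    )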

@@ -278,6 +211,22 @@ It is up to the user to correctly update the hyperparameters of your trainable.
        return True

Comparing the Function API and Class API
----------------------------------------

Here are a few key concepts and what they look like for the Function and Class APIs.

======================= =============================================== ==============================================
Concept Function API Class API
======================= =============================================== ==============================================
Training Iteration Increments on each `session.report` call Increments on each `Trainable.step` call
Report metrics `session.report(metrics)` Return metrics from `Trainable.step`
Saving a checkpoint `session.report(..., checkpoint=checkpoint)` `Trainable.save_checkpoint`
Loading a checkpoint `session.get_checkpoint()` `Trainable.load_checkpoint`
Accessing config Passed as an argument `def train_func(config):` Passed through `Trainable.setup`
======================= =============================================== ==============================================


Advanced Resource Allocation
----------------------------

@@ -330,10 +279,11 @@ session (Function API)
.. autofunction:: ray.air.session.get_trial_dir
    :noindex:

.. _tune-trainable-docstring:

tune.Trainable (Class API)
--------------------------


.. autoclass:: ray.tune.Trainable
    :member-order: groupwise
    :private-members:
150 changes: 150 additions & 0 deletions doc/source/tune/doc_code/trainable.py
@@ -0,0 +1,150 @@
# flake8: noqa

# __class_api_checkpointing_start__
import os
import torch
from torch import nn

from ray import air, tune


class MyTrainableClass(tune.Trainable):
    def setup(self, config):
        self.model = nn.Sequential(
            nn.Linear(config.get("input_size", 32), 32), nn.ReLU(), nn.Linear(32, 10)
        )

    def step(self):
        return {}

    def save_checkpoint(self, tmp_checkpoint_dir):
        checkpoint_path = os.path.join(tmp_checkpoint_dir, "model.pth")
        torch.save(self.model.state_dict(), checkpoint_path)
        return tmp_checkpoint_dir

    def load_checkpoint(self, tmp_checkpoint_dir):
        checkpoint_path = os.path.join(tmp_checkpoint_dir, "model.pth")
        self.model.load_state_dict(torch.load(checkpoint_path))


tuner = tune.Tuner(
    MyTrainableClass,
    param_space={"input_size": 64},
    run_config=air.RunConfig(
        stop={"training_iteration": 2},
        checkpoint_config=air.CheckpointConfig(checkpoint_frequency=2),
    ),
)
tuner.fit()
# __class_api_checkpointing_end__

# __function_api_checkpointing_start__
from ray import tune
from ray.air import session
from ray.air.checkpoint import Checkpoint


def train_func(config):
    epochs = config.get("epochs", 2)
    start = 0
    loaded_checkpoint = session.get_checkpoint()
    if loaded_checkpoint:
        last_step = loaded_checkpoint.to_dict()["step"]
        start = last_step + 1

    for step in range(start, epochs):
        # Model training here
        # ...

        # Report metrics and save a checkpoint
        metrics = {"metric": "my_metric"}
        checkpoint = Checkpoint.from_dict({"step": step})
        session.report(metrics, checkpoint=checkpoint)


tuner = tune.Tuner(train_func)
results = tuner.fit()
# __function_api_checkpointing_end__

# fmt: off
# __example_objective_start__
def objective(x, a, b):
    return a * (x ** 0.5) + b
# __example_objective_end__
# fmt: on

# __function_api_report_intermediate_metrics_start__
from ray import tune
from ray.air import session


def trainable(config: dict):
    intermediate_score = 0
    for x in range(20):
        intermediate_score = objective(x, config["a"], config["b"])
        session.report({"score": intermediate_score})  # This sends the score to Tune.


tuner = tune.Tuner(trainable, param_space={"a": 2, "b": 4})
results = tuner.fit()
# __function_api_report_intermediate_metrics_end__

# __function_api_report_final_metrics_start__
from ray import tune
from ray.air import session


def trainable(config: dict):
    final_score = 0
    for x in range(20):
        final_score = objective(x, config["a"], config["b"])

    session.report({"score": final_score})  # This sends the score to Tune.


tuner = tune.Tuner(trainable, param_space={"a": 2, "b": 4})
results = tuner.fit()
# __function_api_report_final_metrics_end__

# fmt: off
# __function_api_return_final_metrics_start__
def trainable(config: dict):
    final_score = 0
    for x in range(20):
        final_score = objective(x, config["a"], config["b"])

    return {"score": final_score}  # This sends the score to Tune.
# __function_api_return_final_metrics_end__
# fmt: on

# __class_api_example_start__
from ray import air, tune


class Trainable(tune.Trainable):
    def setup(self, config: dict):
        # config (dict): A dict of hyperparameters
        self.x = 0
        self.a = config["a"]
        self.b = config["b"]

    def step(self):  # This is called iteratively.
        score = objective(self.x, self.a, self.b)
        self.x += 1
        return {"score": score}


tuner = tune.Tuner(
    Trainable,
    run_config=air.RunConfig(
        # Train for 20 steps
        stop={"training_iteration": 20},
        checkpoint_config=air.CheckpointConfig(
            # We haven't implemented checkpointing yet. See below!
            checkpoint_at_end=False
        ),
    ),
    param_space={"a": 2, "b": 4},
)
results = tuner.fit()
# __class_api_example_end__
