From bdc8ea71801359a815d28f60d7b62afb19c998ff Mon Sep 17 00:00:00 2001 From: Justin Yu Date: Tue, 27 Sep 2022 14:33:11 -0700 Subject: [PATCH 01/11] Clarify Tuner() vs. Tuner.restore behavior and provide more API ref links Signed-off-by: Justin Yu --- .../tune/tutorials/tune-checkpoints.rst | 82 ++++++++++--------- .../tune/tutorials/tune-distributed.rst | 4 +- doc/source/tune/tutorials/tune-stopping.rst | 25 +++--- 3 files changed, 58 insertions(+), 53 deletions(-) diff --git a/doc/source/tune/tutorials/tune-checkpoints.rst b/doc/source/tune/tutorials/tune-checkpoints.rst index 305da90637fe..faa9ac1401b0 100644 --- a/doc/source/tune/tutorials/tune-checkpoints.rst +++ b/doc/source/tune/tutorials/tune-checkpoints.rst @@ -108,7 +108,9 @@ This will automatically store both the experiment state and the trial checkpoint name="experiment_name", sync_config=tune.SyncConfig( upload_dir="s3://bucket-name/sub-path/" - ))) + ) + ) + ) tuner.fit() We don't have to provide a ``syncer`` here as it will be automatically detected. However, you can provide @@ -126,7 +128,8 @@ a string if you want to use a custom command: sync_config=tune.SyncConfig( upload_dir="s3://bucket-name/sub-path/", syncer="aws s3 sync {source} {target}", # Custom sync command - )), + ) + ) ) tuner.fit() @@ -191,7 +194,8 @@ Alternatively, a function can be provided with the following signature: syncer=custom_sync_func, sync_period=60 # Synchronize more often ) - )) + ) + ) results = tuner.fit() When syncing results back to the driver, the source would be a path similar to @@ -230,11 +234,13 @@ Your ``my_trainable`` is either a: 2. **Custom training function** - * All this means is that your function needs to take care of saving and loading from checkpoint. - For saving, this is done through ``session.report()`` API, which can take in a ``Checkpoint`` object. - For loading, your function can access existing checkpoint through ``Session.get_checkpoint()`` API. - See :doc:`this example `, - it's quite simple to do. + All this means is that your function needs to take care of saving and loading from checkpoint. + + * For saving, this is done through :meth:`session.report() ` API, which can take in a ``Checkpoint`` object. + + * For loading, your function can access an existing checkpoint through the :meth:`session.get_checkpoint() ` API. + + * See :doc:`this example ` for reference. Let's assume for this example you're running this script from your laptop, and connecting to your remote Ray cluster via ``ray.init()``, making your script on your laptop the "driver". @@ -247,28 +253,26 @@ via ``ray.init()``, making your script on your laptop the "driver". ray.init(address=":") # set `address=None` to train on laptop - # configure how checkpoints are sync'd to the scheduler/sampler - # we recommend cloud storage checkpointing as it survives the cluster when - # instances are terminated, and has better performance + # Configure how checkpoints are sync'd to the scheduler/sampler + # We recommend cloud storage checkpointing as it survives the cluster when + # instances are terminated and has better performance sync_config = tune.SyncConfig( upload_dir="s3://my-checkpoints-bucket/path/", # requires AWS credentials ) - # this starts the run! + # This starts the run! tuner = tune.Tuner( my_trainable, run_config=air.RunConfig( - # name of your experiment - # if this experiment exists, we will resume from the last run - # as specified by + # Name of your experiment name="my-tune-exp", - # a directory where results are stored before being + # Directory where each node's results are stored before being # sync'd to head node/cloud storage local_dir="/tmp/mypath", - # see above! we will sync our checkpoints to S3 directory + # See above! we will sync our checkpoints to S3 directory sync_config=sync_config, checkpoint_config=air.CheckpointConfig( - # we'll keep the best five checkpoints at all times + # We'll keep the best five checkpoints at all times # checkpoints (by AUC score, reported by the trainable, descending) checkpoint_score_attr="max-auc", keep_checkpoints_num=5, @@ -281,13 +285,24 @@ In this example, checkpoints will be saved: * **Locally**: not saved! Nothing will be sync'd to the driver (your laptop) automatically (because cloud syncing is enabled) * **S3**: ``s3://my-checkpoints-bucket/path/my-tune-exp//checkpoint_`` -* **On head node**: ``~/ray-results/my-tune-exp//checkpoint_`` (but only for trials done on that node) -* **On workers nodes**: ``~/ray-results/my-tune-exp//checkpoint_`` (but only for trials done on that node) +* **On head node**: ``/tmp/mypath/my-tune-exp//checkpoint_`` (but only for trials done on that node) +* **On workers nodes**: ``/tmp/mypath/my-tune-exp//checkpoint_`` (but only for trials done on that node) + +If this run stopped for any reason (finished, errored, user CTRL+C), you can restart it any time using experiment checkpoints saved in the cloud: + +.. code-block:: python + + tuner = Tuner.restore( + "s3://my-checkpoints-bucket/path/my-tune-exp", + resume_errored=True + ) + tuner.fit() + -If your run stopped for any reason (finished, errored, user CTRL+C), you can restart it any time by -``tuner=Tuner.restore(experiment_checkpoint_dir).fit()``. There are a few options for restoring an experiment: -"resume_unfinished", "resume_errored" and "restart_errored". See ``Tuner.restore()`` for more details. +``resume_unfinished``, ``resume_errored`` and ``restart_errored``. +Please see the documentation of +:meth:`Tuner.restore() ` for more details. .. _rsync-checkpointing: @@ -298,7 +313,7 @@ Local or rsync checkpointing can be a good option if: 1. You want to tune on a single laptop Ray cluster 2. You aren't using Ray on Kubernetes (rsync doesn't work with Ray on Kubernetes) -3. You don't want to use S3 +3. You don't want to cloud storage (i.e. S3) Let's take a look at an example: @@ -310,29 +325,20 @@ Let's take a look at an example: ray.init(address=":") # set `address=None` to train on laptop - # configure how checkpoints are sync'd to the scheduler/sampler - sync_config = tune.syncConfig() # the default mode is to use use rsync + # Configure how checkpoints are sync'd to the scheduler/sampler + sync_config = tune.SyncConfig() # the default mode is to use use rsync - # this starts the run! + # This starts the run! tuner = tune.Tuner( my_trainable, - run_config=air.RunConfig( - # name of your experiment - # If the experiment with the same name is already run, - # Tuner willl resume from the last run specified by sync_config(if one exists). - # Otherwise, will start a new run. name="my-tune-exp", - # a directory where results are stored before being - # sync'd to head node/cloud storage local_dir="/tmp/mypath", - # sync our checkpoints via rsync - # you don't have to pass an empty sync config - but we + # Sync our checkpoints via rsync + # You don't have to pass an empty sync config - but we # do it here for clarity and comparison sync_config=sync_config, checkpoint_config=air.CheckpointConfig( - # we'll keep the best five checkpoints at all times - # checkpoints (by AUC score, reported by the trainable, descending) checkpoint_score_attr="max-auc", keep_checkpoints_num=5, ) diff --git a/doc/source/tune/tutorials/tune-distributed.rst b/doc/source/tune/tutorials/tune-distributed.rst index 7d1412f563cd..716f63de42da 100644 --- a/doc/source/tune/tutorials/tune-distributed.rst +++ b/doc/source/tune/tutorials/tune-distributed.rst @@ -225,9 +225,7 @@ If the trial/actor is placed on a different node, Tune will automatically push t Recovering From Failures ~~~~~~~~~~~~~~~~~~~~~~~~ -Tune automatically persists the progress of your entire experiment (a ``Tuner.fit()`` session), so if an experiment crashes or is otherwise cancelled, it can be resumed through ``Tuner.restore()``. -There are a few options for restoring an experiment: -"resume_unfinished", "resume_errored" and "restart_errored". See ``Tuner.restore()`` for more details. +Tune automatically persists the progress of your entire experiment (a ``Tuner.fit()`` session), so if an experiment crashes or is otherwise cancelled, it can be resumed through :meth:`Tuner.restore() `. .. _tune-distributed-common: diff --git a/doc/source/tune/tutorials/tune-stopping.rst b/doc/source/tune/tutorials/tune-stopping.rst index e23e3801501e..8e1ee1cdd4cd 100644 --- a/doc/source/tune/tutorials/tune-stopping.rst +++ b/doc/source/tune/tutorials/tune-stopping.rst @@ -20,15 +20,16 @@ If you've stopped a run and and want to resume from where you left off, you can then call ``Tuner.restore()`` like this: .. code-block:: python - :emphasize-lines: 4 tuner = Tuner.restore( path="~/ray_results/my_experiment" ) tuner.fit() -There are a few options for resuming an experiment: -"resume_unfinished", "resume_errored" and "restart_errored". See ``Tuner.restore()`` for more details. +There are a few options for restoring an experiment: +``resume_unfinished``, ``resume_errored`` and ``restart_errored``. +Please see the documentation of +:meth:`Tuner.restore() ` for more details. ``path`` here is determined by the ``air.RunConfig.name`` you supplied to your ``Tuner()``. If you didn't supply name to ``Tuner``, it is likely that your ``path`` looks something like: @@ -48,18 +49,18 @@ of your original tuning run: Number of trials: 1/1 (1 RUNNING) What's happening under the hood? --------------------------------- +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +:ref:`Here `, we describe the two types of Tune checkpoints: +experiment-level and trial-level checkpoints. + +Upon resuming an interrupted/errored Tune run: -:ref:`Here ` we talked about two types of Tune checkpoints. -Both checkpoints come into play when resuming a Tune run. +#. Tune first looks at the experiment-level checkpoint to find the list of trials at the time of the interruption. -When resuming an interrupted/errored Tune run, Tune first looks at the experiment-level checkpoint -to find the list of trials at the time of the interruption. Ray Tune then locates the trial-level -checkpoint of each trial. +#. Tune then locates and restores from the trial-level checkpoint of each trial. -Depending on the specified resume option -("resume_unfinished", "resume_errored", "restart_errored"), Ray Tune then decides whether to -restore a given non-finished trial from its latest available checkpoint or start from scratch. +#. Depending on the specified resume option (``resume_unfinished``, ``resume_errored``, ``restart_errored``), Tune decides whether to restore a given unfinished trial from its latest available checkpoint or to start from scratch. .. _tune-stopping-ref: From fc4dab7c809365aa09aa9c19fad99e67d8885834 Mon Sep 17 00:00:00 2001 From: Justin Yu Date: Mon, 17 Oct 2022 09:29:48 -0700 Subject: [PATCH 02/11] Fix checkpoint config arg names Signed-off-by: Justin Yu --- doc/source/tune/tutorials/tune-checkpoints.rst | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/doc/source/tune/tutorials/tune-checkpoints.rst b/doc/source/tune/tutorials/tune-checkpoints.rst index faa9ac1401b0..21c983a254ba 100644 --- a/doc/source/tune/tutorials/tune-checkpoints.rst +++ b/doc/source/tune/tutorials/tune-checkpoints.rst @@ -274,8 +274,8 @@ via ``ray.init()``, making your script on your laptop the "driver". checkpoint_config=air.CheckpointConfig( # We'll keep the best five checkpoints at all times # checkpoints (by AUC score, reported by the trainable, descending) - checkpoint_score_attr="max-auc", - keep_checkpoints_num=5, + checkpoint_score_attribute="max-auc", + num_to_keep=5, ), ), ) @@ -339,9 +339,9 @@ Let's take a look at an example: # do it here for clarity and comparison sync_config=sync_config, checkpoint_config=air.CheckpointConfig( - checkpoint_score_attr="max-auc", - keep_checkpoints_num=5, - ) + checkpoint_score_attribute="max-auc", + num_to_keep=5, + ), ) ) From 9ea4298f6925b7c2ec807a6fdfde8c0f5585e350 Mon Sep 17 00:00:00 2001 From: Justin Yu Date: Mon, 17 Oct 2022 13:19:17 -0700 Subject: [PATCH 03/11] Fix Trainable class docstrings Signed-off-by: Justin Yu --- doc/source/tune/api_docs/trainable.rst | 7 +++++- python/ray/tune/trainable/trainable.py | 34 +++++++++++++++++--------- 2 files changed, 29 insertions(+), 12 deletions(-) diff --git a/doc/source/tune/api_docs/trainable.rst b/doc/source/tune/api_docs/trainable.rst index ef3b04549085..4a11be8d63f0 100644 --- a/doc/source/tune/api_docs/trainable.rst +++ b/doc/source/tune/api_docs/trainable.rst @@ -181,7 +181,12 @@ You can also implement checkpoint/restore using the Trainable Class API: checkpoint_path = os.path.join(tmp_checkpoint_dir, "model.pth") self.model.load_state_dict(torch.load(checkpoint_path)) - tuner = tune.Tuner(MyTrainableClass, run_config=air.RunConfig(checkpoint_config=air.CheckpointConfig(checkpoint_frequency=2))) + tuner = tune.Tuner( + MyTrainableClass, + run_config=air.RunConfig( + checkpoint_config=air.CheckpointConfig(checkpoint_frequency=2) + ) + ) results = tuner.fit() You can checkpoint with three different mechanisms: manually, periodically, and at termination. diff --git a/python/ray/tune/trainable/trainable.py b/python/ray/tune/trainable/trainable.py index e95a2abcc68f..ed82fdf62091 100644 --- a/python/ray/tune/trainable/trainable.py +++ b/python/ray/tune/trainable/trainable.py @@ -1115,9 +1115,10 @@ def save_checkpoint(self, checkpoint_dir: str) -> Optional[Union[str, Dict]]: Returns: A dict or string. If string, the return value is expected to be - prefixed by `tmp_checkpoint_dir`. If dict, the return value will - be automatically serialized by Tune and - passed to ``Trainable.load_checkpoint()``. + prefixed by `checkpoint_dir`. If dict, the return value will + be automatically serialized by Tune. In both cases, the return value + is exactly what will be passed to ``Trainable.load_checkpoint()`` + upon restore. Example: >>> trainable, trainable1, trainable2 = ... # doctest: +SKIP @@ -1146,23 +1147,34 @@ def load_checkpoint(self, checkpoint: Union[Dict, str]): The directory structure under the checkpoint_dir provided to ``Trainable.save_checkpoint`` is preserved. - See the example below. + See the examples below. Example: >>> from ray.tune.trainable import Trainable >>> class Example(Trainable): ... def save_checkpoint(self, checkpoint_path): - ... print(checkpoint_path) - ... return os.path.join(checkpoint_path, "my/check/point") - ... def load_checkpoint(self, checkpoint): - ... print(checkpoint) + ... my_checkpoint_path = os.path.join(checkpoint_path, "my/path") + ... return my_checkpoint_path + ... def load_checkpoint(self, my_checkpoint_path): + ... print(my_checkpoint_path) >>> trainer = Example() >>> # This is used when PAUSED. >>> obj = trainer.save_to_object() # doctest: +SKIP - /tmpc8k_c_6hsave_to_object/checkpoint_0/my/check/point + /tmpc8k_c_6hsave_to_object/checkpoint_0/my/path >>> # Note the different prefix. >>> trainer.restore_from_object(obj) # doctest: +SKIP - /tmpb87b5axfrestore_from_object/checkpoint_0/my/check/point + /tmpb87b5axfrestore_from_object/checkpoint_0/my/path + + If `Trainable.save_checkpoint` returned a dict, then Tune will directly pass + the dict data as the argument to this method. + + Example: + >>> from ray.tune.trainable import Trainable + >>> class Example(Trainable): + ... def save_checkpoint(self, checkpoint_path): + ... return {"my_data": 1} + ... def load_checkpoint(self, checkpoint_dict): + ... print(checkpoint_dict["my_data"]) .. versionadded:: 0.8.7 @@ -1171,7 +1183,7 @@ def load_checkpoint(self, checkpoint: Union[Dict, str]): returned by `save_checkpoint`. If a string, then it is a checkpoint path that may have a different prefix than that returned by `save_checkpoint`. The directory structure - underneath the `checkpoint_dir` `save_checkpoint` is preserved. + underneath the `checkpoint_dir` from `save_checkpoint` is preserved. """ raise NotImplementedError From 7157941cb99452db6e8323165890cea632f8b423 Mon Sep 17 00:00:00 2001 From: Justin Yu Date: Mon, 17 Oct 2022 13:53:28 -0700 Subject: [PATCH 04/11] Add checkpointing examples in Tune checkpointing guide Signed-off-by: Justin Yu --- .../checkpointing/class-checkpointing.rst | 24 +++++++ .../checkpointing/function-checkpointing.rst | 35 +++++++++++ doc/source/tune/api_docs/trainable.rst | 62 +------------------ .../tune/tutorials/tune-checkpoints.rst | 21 ++++++- 4 files changed, 79 insertions(+), 63 deletions(-) create mode 100644 doc/source/tune/api_docs/checkpointing/class-checkpointing.rst create mode 100644 doc/source/tune/api_docs/checkpointing/function-checkpointing.rst diff --git a/doc/source/tune/api_docs/checkpointing/class-checkpointing.rst b/doc/source/tune/api_docs/checkpointing/class-checkpointing.rst new file mode 100644 index 000000000000..4cf4c90046b7 --- /dev/null +++ b/doc/source/tune/api_docs/checkpointing/class-checkpointing.rst @@ -0,0 +1,24 @@ +Class API Checkpointing +~~~~~~~~~~~~~~~~~~~~~~~ + +You can also implement checkpoint/restore using the Trainable Class API: + +.. code-block:: python + + class MyTrainableClass(Trainable): + def save_checkpoint(self, tmp_checkpoint_dir): + checkpoint_path = os.path.join(tmp_checkpoint_dir, "model.pth") + torch.save(self.model.state_dict(), checkpoint_path) + return tmp_checkpoint_dir + + def load_checkpoint(self, tmp_checkpoint_dir): + checkpoint_path = os.path.join(tmp_checkpoint_dir, "model.pth") + self.model.load_state_dict(torch.load(checkpoint_path)) + + tuner = tune.Tuner( + MyTrainableClass, + run_config=air.RunConfig( + checkpoint_config=air.CheckpointConfig(checkpoint_frequency=2) + ) + ) + results = tuner.fit() diff --git a/doc/source/tune/api_docs/checkpointing/function-checkpointing.rst b/doc/source/tune/api_docs/checkpointing/function-checkpointing.rst new file mode 100644 index 000000000000..1ef0212624af --- /dev/null +++ b/doc/source/tune/api_docs/checkpointing/function-checkpointing.rst @@ -0,0 +1,35 @@ +Function API Checkpointing +~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Many Tune features rely on checkpointing, including the usage of certain Trial Schedulers and fault tolerance. +You can save and load checkpoint in Ray Tune in the following manner: + +.. code-block:: python + + import time + from ray import tune + from ray.air import session + from ray.air.checkpoint import Checkpoint + + def train_func(config): + step = 0 + loaded_checkpoint = session.get_checkpoint() + if loaded_checkpoint: + last_step = loaded_checkpoint.to_dict()["step"] + step = last_step + 1 + + for iter in range(step, 100): + time.sleep(1) + + checkpoint = Checkpoint.from_dict({"step": step}) + session.report({"message": "Hello world Ray Tune!"}, checkpoint=checkpoint) + + tuner = tune.Tuner(train_func) + results = tuner.fit() + +.. note:: ``checkpoint_frequency`` and ``checkpoint_at_end`` will not work with Function API checkpointing. + +In this example, checkpoints will be saved by training iteration to ``//trial_name/checkpoint_``. + +Tune also may copy or move checkpoints during the course of tuning. For this purpose, +it is important not to depend on absolute paths in the implementation of ``save``. diff --git a/doc/source/tune/api_docs/trainable.rst b/doc/source/tune/api_docs/trainable.rst index 4a11be8d63f0..711bd9e07240 100644 --- a/doc/source/tune/api_docs/trainable.rst +++ b/doc/source/tune/api_docs/trainable.rst @@ -71,44 +71,9 @@ such as ``iterations_since_restore``. See :ref:`tune-autofilled-metrics` for an print("best config: ", results.get_best_result(metric="score", mode="max").config) - .. _tune-function-checkpointing: -Function API Checkpointing -~~~~~~~~~~~~~~~~~~~~~~~~~~ - -Many Tune features rely on checkpointing, including the usage of certain Trial Schedulers and fault tolerance. -You can save and load checkpoint in Ray Tune in the following manner: - -.. code-block:: python - - import time - from ray import tune - from ray.air import session - from ray.air.checkpoint import Checkpoint - - def train_func(config): - step = 0 - loaded_checkpoint = session.get_checkpoint() - if loaded_checkpoint: - last_step = loaded_checkpoint.to_dict()["step"] - step = last_step + 1 - - for iter in range(step, 100): - time.sleep(1) - - checkpoint = Checkpoint.from_dict({"step": step}) - session.report({"message": "Hello world Ray Tune!"}, checkpoint=checkpoint) - - tuner = tune.Tuner(train_func) - results = tuner.fit() - -.. note:: ``checkpoint_frequency`` and ``checkpoint_at_end`` will not work with Function API checkpointing. - -In this example, checkpoints will be saved by training iteration to ``local_dir/exp_name/trial_name/checkpoint_``. - -Tune also may copy or move checkpoints during the course of tuning. For this purpose, -it is important not to depend on absolute paths in the implementation of ``save``. +.. include:: checkpointing/function-checkpointing.rst .. _tune-class-api: @@ -164,30 +129,7 @@ See :ref:`tune-autofilled-metrics` for an explanation/glossary of these values. .. _tune-trainable-save-restore: -Class API Checkpointing -~~~~~~~~~~~~~~~~~~~~~~~ - -You can also implement checkpoint/restore using the Trainable Class API: - -.. code-block:: python - - class MyTrainableClass(Trainable): - def save_checkpoint(self, tmp_checkpoint_dir): - checkpoint_path = os.path.join(tmp_checkpoint_dir, "model.pth") - torch.save(self.model.state_dict(), checkpoint_path) - return tmp_checkpoint_dir - - def load_checkpoint(self, tmp_checkpoint_dir): - checkpoint_path = os.path.join(tmp_checkpoint_dir, "model.pth") - self.model.load_state_dict(torch.load(checkpoint_path)) - - tuner = tune.Tuner( - MyTrainableClass, - run_config=air.RunConfig( - checkpoint_config=air.CheckpointConfig(checkpoint_frequency=2) - ) - ) - results = tuner.fit() +.. include:: checkpointing/class-checkpointing.rst You can checkpoint with three different mechanisms: manually, periodically, and at termination. diff --git a/doc/source/tune/tutorials/tune-checkpoints.rst b/doc/source/tune/tutorials/tune-checkpoints.rst index 21c983a254ba..28157408c76a 100644 --- a/doc/source/tune/tutorials/tune-checkpoints.rst +++ b/doc/source/tune/tutorials/tune-checkpoints.rst @@ -31,15 +31,30 @@ Commonly, this includes the model and optimizer states. This is useful mostly fo the meantime. This only makes sense if the trials can then continue training from the latest state. - The checkpoint can be later used for other downstream tasks like batch inference. -Everything that is reported by ``session.report()`` is a trial-level checkpoint. -See :ref:`here for more information on saving checkpoints `. +Everything that is saved by ``session.report()`` (if using the Function API) or +``Trainable.save_checkpoint`` (if using the Class API) is a **trial-level checkpoint.** +See below for examples of saving and loading trial-level checkpoints. + + +How do I save and load trial checkpoints? +----------------------------------------- + +.. include:: ../api_docs/checkpointing/function-checkpointing.rst + +.. include:: ../api_docs/checkpointing/class-checkpointing.rst + +See :ref:`here for more information on creating checkpoints `. +If using framework-specific trainers from Ray AIR, see :ref:`here ` for +references to framework-specific checkpoints such as `TensorflowCheckpoint`. .. _tune-checkpoint-syncing: Checkpointing and synchronization --------------------------------- -This topic is mostly relevant to Trial checkpoint. +.. note:: + + This topic is relevant to trial checkpoints. Tune stores checkpoints on the node where the trials are executed. If you are training on more than one node, this means that some trial checkpoints may be on the head node and others are not. From 2056287906993a1a5fad4faa4cc8266418dd5fe3 Mon Sep 17 00:00:00 2001 From: Justin Yu Date: Tue, 18 Oct 2022 11:52:58 -0700 Subject: [PATCH 05/11] Move code to a runnable file in doc_code, address comments Signed-off-by: Justin Yu --- .../checkpointing/class-checkpointing.rst | 23 +-- .../checkpointing/function-checkpointing.rst | 28 +--- doc/source/tune/api_docs/trainable.rst | 91 +++--------- doc/source/tune/doc_code/trainable.py | 138 ++++++++++++++++++ .../tune/tutorials/tune-checkpoints.rst | 5 +- python/ray/tune/trainable/trainable.py | 1 + 6 files changed, 172 insertions(+), 114 deletions(-) create mode 100644 doc/source/tune/doc_code/trainable.py diff --git a/doc/source/tune/api_docs/checkpointing/class-checkpointing.rst b/doc/source/tune/api_docs/checkpointing/class-checkpointing.rst index 4cf4c90046b7..a5397f713fd2 100644 --- a/doc/source/tune/api_docs/checkpointing/class-checkpointing.rst +++ b/doc/source/tune/api_docs/checkpointing/class-checkpointing.rst @@ -3,22 +3,7 @@ Class API Checkpointing You can also implement checkpoint/restore using the Trainable Class API: -.. code-block:: python - - class MyTrainableClass(Trainable): - def save_checkpoint(self, tmp_checkpoint_dir): - checkpoint_path = os.path.join(tmp_checkpoint_dir, "model.pth") - torch.save(self.model.state_dict(), checkpoint_path) - return tmp_checkpoint_dir - - def load_checkpoint(self, tmp_checkpoint_dir): - checkpoint_path = os.path.join(tmp_checkpoint_dir, "model.pth") - self.model.load_state_dict(torch.load(checkpoint_path)) - - tuner = tune.Tuner( - MyTrainableClass, - run_config=air.RunConfig( - checkpoint_config=air.CheckpointConfig(checkpoint_frequency=2) - ) - ) - results = tuner.fit() +.. literalinclude:: /tune/doc_code/trainable.py + :language: python + :start-after: __class_api_checkpointing_start__ + :end-before: __class_api_checkpointing_end__ diff --git a/doc/source/tune/api_docs/checkpointing/function-checkpointing.rst b/doc/source/tune/api_docs/checkpointing/function-checkpointing.rst index 1ef0212624af..35ad74f24ac8 100644 --- a/doc/source/tune/api_docs/checkpointing/function-checkpointing.rst +++ b/doc/source/tune/api_docs/checkpointing/function-checkpointing.rst @@ -2,30 +2,12 @@ Function API Checkpointing ~~~~~~~~~~~~~~~~~~~~~~~~~~ Many Tune features rely on checkpointing, including the usage of certain Trial Schedulers and fault tolerance. -You can save and load checkpoint in Ray Tune in the following manner: +You can save and load checkpoints in Ray Tune in the following manner: -.. code-block:: python - - import time - from ray import tune - from ray.air import session - from ray.air.checkpoint import Checkpoint - - def train_func(config): - step = 0 - loaded_checkpoint = session.get_checkpoint() - if loaded_checkpoint: - last_step = loaded_checkpoint.to_dict()["step"] - step = last_step + 1 - - for iter in range(step, 100): - time.sleep(1) - - checkpoint = Checkpoint.from_dict({"step": step}) - session.report({"message": "Hello world Ray Tune!"}, checkpoint=checkpoint) - - tuner = tune.Tuner(train_func) - results = tuner.fit() +.. literalinclude:: /tune/doc_code/trainable.py + :language: python + :start-after: __function_api_checkpointing_start__ + :end-before: __function_api_checkpointing_end__ .. note:: ``checkpoint_frequency`` and ``checkpoint_at_end`` will not work with Function API checkpointing. diff --git a/doc/source/tune/api_docs/trainable.rst b/doc/source/tune/api_docs/trainable.rst index 711bd9e07240..f5ccecf69482 100644 --- a/doc/source/tune/api_docs/trainable.rst +++ b/doc/source/tune/api_docs/trainable.rst @@ -7,14 +7,14 @@ Training (tune.Trainable, session.report) ========================================== -Training can be done with either a **Class API** (``tune.Trainable``) or **function API** (``session.report``). +Training can be done with either a **Class API** (:ref:`tune.Trainable `) or **function API** (:ref:`session.report `). For the sake of example, let's maximize this objective function: -.. code-block:: python - - def objective(x, a, b): - return a * (x ** 0.5) + b +.. literalinclude:: /tune/doc_code/trainable.py + :language: python + :start-after: __example_objective_start__ + :end-before: __example_objective_end__ .. _tune-function-api: @@ -23,27 +23,10 @@ Function API With the Function API, you can report intermediate metrics by simply calling ``session.report`` within the provided function. - -.. code-block:: python - - from ray import tune - from ray.air import session - - def trainable(config): - # config (dict): A dict of hyperparameters. - - for x in range(20): - intermediate_score = objective(x, config["a"], config["b"]) - - session.report({"score": intermediate_score}) # This sends the score to Tune. - - tuner = tune.Tuner( - trainable, - param_space={"a": 2, "b": 4} - ) - results = tuner.fit() - - print("best config: ", results.get_best_result(metric="score", mode="max").config) +.. literalinclude:: /tune/doc_code/trainable.py + :language: python + :start-after: __function_api_report_intermediate_metrics_start__ + :end-before: __function_api_report_intermediate_metrics_end__ .. tip:: Do not use ``session.report`` within a ``Trainable`` class. @@ -52,24 +35,13 @@ Tune will run this function on a separate thread in a Ray actor process. You'll notice that Ray Tune will output extra values in addition to the user reported metrics, such as ``iterations_since_restore``. See :ref:`tune-autofilled-metrics` for an explanation/glossary of these values. -.. code-block:: python - - def trainable(config): - # config (dict): A dict of hyperparameters. - - final_score = 0 - for x in range(20): - final_score = objective(x, config["a"], config["b"]) +In the previous example, we reported on every step, but this metric reporting frequency +is configurable. For example, we could also report only a single time at the end with the final score: - return {"score": final_score} # This sends the score to Tune. - - tuner = tune.Tuner( - trainable, - param_space={"a": 2, "b": 4} - ) - results = tuner.fit() - - print("best config: ", results.get_best_result(metric="score", mode="max").config) +.. literalinclude:: /tune/doc_code/trainable.py + :language: python + :start-after: __function_api_report_final_metrics_start__ + :end-before: __function_api_report_final_metrics_end__ .. _tune-function-checkpointing: @@ -84,32 +56,10 @@ Trainable Class API The Trainable **class API** will require users to subclass ``ray.tune.Trainable``. Here's a naive example of this API: -.. code-block:: python - - from ray import tune - - class Trainable(tune.Trainable): - def setup(self, config): - # config (dict): A dict of hyperparameters - self.x = 0 - self.a = config["a"] - self.b = config["b"] - - def step(self): # This is called iteratively. - score = objective(self.x, self.a, self.b) - self.x += 1 - return {"score": score} - - tuner = tune.Tuner( - Trainable, - tune_config=air.RunConfig(stop={"training_iteration": 20}), - param_space={ - "a": 2, - "b": 4 - }) - results = tuner.fit() - - print('best config: ', results.get_best_result(metric="score", mode="max").config) +.. literalinclude:: /tune/doc_code/trainable.py + :language: python + :start-after: __class_api_example_start__ + :end-before: __class_api_example_end__ As a subclass of ``tune.Trainable``, Tune will create a ``Trainable`` object on a separate process (using the :ref:`Ray Actor API `). @@ -274,10 +224,11 @@ session (Function API) .. autofunction:: ray.air.session.get_trial_resources :noindex: +.. _tune-trainable-docstring: + tune.Trainable (Class API) -------------------------- - .. autoclass:: ray.tune.Trainable :member-order: groupwise :private-members: diff --git a/doc/source/tune/doc_code/trainable.py b/doc/source/tune/doc_code/trainable.py new file mode 100644 index 000000000000..d76adb0cc6d9 --- /dev/null +++ b/doc/source/tune/doc_code/trainable.py @@ -0,0 +1,138 @@ +# __class_api_checkpointing_start__ +import os +import torch +from torch import nn + +from ray import air, tune + + +class MyTrainableClass(tune.Trainable): + def setup(self, config): + self.model = nn.Sequential( + nn.Linear(config.get("input_size", 32), 32), nn.ReLU(), nn.Linear(32, 10) + ) + + def step(self): + return {} + + def save_checkpoint(self, tmp_checkpoint_dir): + checkpoint_path = os.path.join(tmp_checkpoint_dir, "model.pth") + torch.save(self.model.state_dict(), checkpoint_path) + return tmp_checkpoint_dir + + def load_checkpoint(self, tmp_checkpoint_dir): + checkpoint_path = os.path.join(tmp_checkpoint_dir, "model.pth") + self.model.load_state_dict(torch.load(checkpoint_path)) + + +tuner = tune.Tuner( + MyTrainableClass, + param_space={"input_size": 64}, + run_config=air.RunConfig( + stop={"training_iteration": 2}, + checkpoint_config=air.CheckpointConfig(checkpoint_frequency=2), + ), +) +tuner.fit() +# __class_api_checkpointing_end__ + +# __function_api_checkpointing_start__ +from ray import tune +from ray.air import session +from ray.air.checkpoint import Checkpoint + + +def train_func(config): + epochs = config.get("epochs", 2) + start = 0 + loaded_checkpoint = session.get_checkpoint() + if loaded_checkpoint: + last_step = loaded_checkpoint.to_dict()["step"] + start = last_step + 1 + + for step in range(start, epochs): + # Model training here + # ... + + # Report metrics and save a checkpoint + metrics = {"metric": "my_metric"} + checkpoint = Checkpoint.from_dict({"step": step}) + session.report(metrics, checkpoint=checkpoint) + + +tuner = tune.Tuner(train_func) +results = tuner.fit() +# __function_api_checkpointing_end__ + + +# __example_objective_start__ +def objective(x, a, b): + return a * (x ** 0.5) + b + + +# __example_objective_end__ + +# __function_api_report_intermediate_metrics_start__ +from ray import tune +from ray.air import session + + +def trainable(config: dict): + intermediate_score = 0 + for x in range(20): + intermediate_score = objective(x, config["a"], config["b"]) + session.report({"score": intermediate_score}) # This sends the score to Tune. + + +tuner = tune.Tuner(trainable, param_space={"a": 2, "b": 4}) +results = tuner.fit() +# __function_api_report_intermediate_metrics_end__ + +# __function_api_report_final_metrics_start__ +from ray import tune +from ray.air import session + + +def trainable(config: dict): + final_score = 0 + for x in range(20): + final_score = objective(x, config["a"], config["b"]) + + session.report({"score": final_score}) # This sends the score to Tune. + + +tuner = tune.Tuner(trainable, param_space={"a": 2, "b": 4}) +results = tuner.fit() +# __function_api_report_final_metrics_end__ + +# __class_api_example_start__ +from ray import air, tune + + +class Trainable(tune.Trainable): + def setup(self, config: dict): + # config (dict): A dict of hyperparameters + self.x = 0 + self.a = config["a"] + self.b = config["b"] + + def step(self): # This is called iteratively. + score = objective(self.x, self.a, self.b) + self.x += 1 + return {"score": score} + + +tuner = tune.Tuner( + Trainable, + run_config=air.RunConfig( + # Train for 20 steps + stop={"training_iteration": 20}, + checkpoint_config=air.CheckpointConfig( + # We haven't implemented checkpointing yet. See below! + checkpoint_at_end=False + ), + ), + param_space={"a": 2, "b": 4}, +) +results = tuner.fit() +# __class_api_example_end__ diff --git a/doc/source/tune/tutorials/tune-checkpoints.rst b/doc/source/tune/tutorials/tune-checkpoints.rst index 28157408c76a..04f4b75dd503 100644 --- a/doc/source/tune/tutorials/tune-checkpoints.rst +++ b/doc/source/tune/tutorials/tune-checkpoints.rst @@ -307,7 +307,8 @@ If this run stopped for any reason (finished, errored, user CTRL+C), you can res .. code-block:: python - tuner = Tuner.restore( + from ray import tune + tuner = tune.Tuner.restore( "s3://my-checkpoints-bucket/path/my-tune-exp", resume_errored=True ) @@ -328,7 +329,7 @@ Local or rsync checkpointing can be a good option if: 1. You want to tune on a single laptop Ray cluster 2. You aren't using Ray on Kubernetes (rsync doesn't work with Ray on Kubernetes) -3. You don't want to cloud storage (i.e. S3) +3. You don't want to cloud storage (e.g. S3) Let's take a look at an example: diff --git a/python/ray/tune/trainable/trainable.py b/python/ray/tune/trainable/trainable.py index ed82fdf62091..58f33fa87881 100644 --- a/python/ray/tune/trainable/trainable.py +++ b/python/ray/tune/trainable/trainable.py @@ -1150,6 +1150,7 @@ def load_checkpoint(self, checkpoint: Union[Dict, str]): See the examples below. Example: + >>> import os >>> from ray.tune.trainable import Trainable >>> class Example(Trainable): ... def save_checkpoint(self, checkpoint_path): From 3e569236c3404141dfae04211f6b1c098746e692 Mon Sep 17 00:00:00 2001 From: Justin Yu Date: Fri, 21 Oct 2022 13:42:18 -0700 Subject: [PATCH 06/11] Exclude doc code from linting to allow imports in each code block Signed-off-by: Justin Yu --- doc/source/tune/doc_code/trainable.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/doc/source/tune/doc_code/trainable.py b/doc/source/tune/doc_code/trainable.py index d76adb0cc6d9..22a6c21812d4 100644 --- a/doc/source/tune/doc_code/trainable.py +++ b/doc/source/tune/doc_code/trainable.py @@ -1,3 +1,5 @@ +# flake8: noqa + # __class_api_checkpointing_start__ import os import torch From 26ac2853a81758878a19ac5133d1b972310e5ea5 Mon Sep 17 00:00:00 2001 From: Justin Yu Date: Fri, 21 Oct 2022 14:01:40 -0700 Subject: [PATCH 07/11] Remove trailing space added by autoformatting Signed-off-by: Justin Yu --- doc/source/tune/doc_code/trainable.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/doc/source/tune/doc_code/trainable.py b/doc/source/tune/doc_code/trainable.py index 22a6c21812d4..b4405df1b3e7 100644 --- a/doc/source/tune/doc_code/trainable.py +++ b/doc/source/tune/doc_code/trainable.py @@ -66,13 +66,12 @@ def train_func(config): results = tuner.fit() # __function_api_checkpointing_end__ - +# fmt: off # __example_objective_start__ def objective(x, a, b): return a * (x ** 0.5) + b - - # __example_objective_end__ +# fmt: on # __function_api_report_intermediate_metrics_start__ from ray import tune From 6927410f5c4cba811f71b47e5dc50abe53346fcd Mon Sep 17 00:00:00 2001 From: Justin Yu Date: Wed, 26 Oct 2022 16:06:17 -0700 Subject: [PATCH 08/11] Reference Trainable checkpointing examples instead of duplicating Signed-off-by: Justin Yu --- .../checkpointing/class-checkpointing.rst | 9 ------ .../checkpointing/function-checkpointing.rst | 17 ---------- doc/source/tune/api_docs/trainable.rst | 32 +++++++++++++++++-- .../tune/tutorials/tune-checkpoints.rst | 16 ++-------- 4 files changed, 33 insertions(+), 41 deletions(-) delete mode 100644 doc/source/tune/api_docs/checkpointing/class-checkpointing.rst delete mode 100644 doc/source/tune/api_docs/checkpointing/function-checkpointing.rst diff --git a/doc/source/tune/api_docs/checkpointing/class-checkpointing.rst b/doc/source/tune/api_docs/checkpointing/class-checkpointing.rst deleted file mode 100644 index a5397f713fd2..000000000000 --- a/doc/source/tune/api_docs/checkpointing/class-checkpointing.rst +++ /dev/null @@ -1,9 +0,0 @@ -Class API Checkpointing -~~~~~~~~~~~~~~~~~~~~~~~ - -You can also implement checkpoint/restore using the Trainable Class API: - -.. literalinclude:: /tune/doc_code/trainable.py - :language: python - :start-after: __class_api_checkpointing_start__ - :end-before: __class_api_checkpointing_end__ diff --git a/doc/source/tune/api_docs/checkpointing/function-checkpointing.rst b/doc/source/tune/api_docs/checkpointing/function-checkpointing.rst deleted file mode 100644 index 35ad74f24ac8..000000000000 --- a/doc/source/tune/api_docs/checkpointing/function-checkpointing.rst +++ /dev/null @@ -1,17 +0,0 @@ -Function API Checkpointing -~~~~~~~~~~~~~~~~~~~~~~~~~~ - -Many Tune features rely on checkpointing, including the usage of certain Trial Schedulers and fault tolerance. -You can save and load checkpoints in Ray Tune in the following manner: - -.. literalinclude:: /tune/doc_code/trainable.py - :language: python - :start-after: __function_api_checkpointing_start__ - :end-before: __function_api_checkpointing_end__ - -.. note:: ``checkpoint_frequency`` and ``checkpoint_at_end`` will not work with Function API checkpointing. - -In this example, checkpoints will be saved by training iteration to ``//trial_name/checkpoint_``. - -Tune also may copy or move checkpoints during the course of tuning. For this purpose, -it is important not to depend on absolute paths in the implementation of ``save``. diff --git a/doc/source/tune/api_docs/trainable.rst b/doc/source/tune/api_docs/trainable.rst index d66cb38dba50..d9548a4be8f2 100644 --- a/doc/source/tune/api_docs/trainable.rst +++ b/doc/source/tune/api_docs/trainable.rst @@ -45,7 +45,27 @@ is configurable. For example, we could also report only a single time at the end .. _tune-function-checkpointing: -.. include:: checkpointing/function-checkpointing.rst +Function API Checkpointing +~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Many Tune features rely on checkpointing, including the usage of certain Trial Schedulers and fault tolerance. +You can save and load checkpoints in Ray Tune in the following manner: + +.. literalinclude:: /tune/doc_code/trainable.py + :language: python + :start-after: __function_api_checkpointing_start__ + :end-before: __function_api_checkpointing_end__ + +.. note:: ``checkpoint_frequency`` and ``checkpoint_at_end`` will not work with Function API checkpointing. + +In this example, checkpoints will be saved by training iteration to ``//trial_name/checkpoint_``. + +Tune also may copy or move checkpoints during the course of tuning. For this purpose, +it is important not to depend on absolute paths in the implementation of ``save``. + +See :ref:`here for more information on creating checkpoints `. +If using framework-specific trainers from Ray AIR, see :ref:`here ` for +references to framework-specific checkpoints such as `TensorflowCheckpoint`. .. _tune-class-api: @@ -79,7 +99,15 @@ See :ref:`tune-autofilled-metrics` for an explanation/glossary of these values. .. _tune-trainable-save-restore: -.. include:: checkpointing/class-checkpointing.rst +Class API Checkpointing +~~~~~~~~~~~~~~~~~~~~~~~ + +You can also implement checkpoint/restore using the Trainable Class API: + +.. literalinclude:: /tune/doc_code/trainable.py + :language: python + :start-after: __class_api_checkpointing_start__ + :end-before: __class_api_checkpointing_end__ You can checkpoint with three different mechanisms: manually, periodically, and at termination. diff --git a/doc/source/tune/tutorials/tune-checkpoints.rst b/doc/source/tune/tutorials/tune-checkpoints.rst index 04f4b75dd503..7b6a75760742 100644 --- a/doc/source/tune/tutorials/tune-checkpoints.rst +++ b/doc/source/tune/tutorials/tune-checkpoints.rst @@ -33,19 +33,9 @@ Commonly, this includes the model and optimizer states. This is useful mostly fo Everything that is saved by ``session.report()`` (if using the Function API) or ``Trainable.save_checkpoint`` (if using the Class API) is a **trial-level checkpoint.** -See below for examples of saving and loading trial-level checkpoints. - - -How do I save and load trial checkpoints? ------------------------------------------ - -.. include:: ../api_docs/checkpointing/function-checkpointing.rst - -.. include:: ../api_docs/checkpointing/class-checkpointing.rst - -See :ref:`here for more information on creating checkpoints `. -If using framework-specific trainers from Ray AIR, see :ref:`here ` for -references to framework-specific checkpoints such as `TensorflowCheckpoint`. +See :ref:`checkpointing with the Function API ` and +:ref:`checkpointing with the Class API ` +for examples of saving and loading trial-level checkpoints. .. _tune-checkpoint-syncing: From 444b8e652b4703fb07f7320ed72203e451fd2688 Mon Sep 17 00:00:00 2001 From: Justin Yu Date: Thu, 27 Oct 2022 09:53:25 -0700 Subject: [PATCH 09/11] Add back final metric return docs for function api Signed-off-by: Justin Yu --- doc/source/tune/api_docs/trainable.rst | 20 ++++++++++++++------ doc/source/tune/doc_code/trainable.py | 9 +++++++++ 2 files changed, 23 insertions(+), 6 deletions(-) diff --git a/doc/source/tune/api_docs/trainable.rst b/doc/source/tune/api_docs/trainable.rst index d9548a4be8f2..f5122bcc163c 100644 --- a/doc/source/tune/api_docs/trainable.rst +++ b/doc/source/tune/api_docs/trainable.rst @@ -21,7 +21,10 @@ For the sake of example, let's maximize this objective function: Function API ------------ -With the Function API, you can report intermediate metrics by simply calling ``session.report`` within the provided function. +The Function API allows you to define a custom training function that Tune will run in parallel Ray actor processes, +one for each Tune trial. + +With the Function API, you can report intermediate metrics by simply calling ``session.report`` within the function. .. literalinclude:: /tune/doc_code/trainable.py :language: python @@ -30,11 +33,6 @@ With the Function API, you can report intermediate metrics by simply calling ``s .. tip:: Do not use ``session.report`` within a ``Trainable`` class. -Tune will run this function on a separate thread in a Ray actor process. - -You'll notice that Ray Tune will output extra values in addition to the user reported metrics, -such as ``iterations_since_restore``. See :ref:`tune-autofilled-metrics` for an explanation/glossary of these values. - In the previous example, we reported on every step, but this metric reporting frequency is configurable. For example, we could also report only a single time at the end with the final score: @@ -43,6 +41,16 @@ is configurable. For example, we could also report only a single time at the end :start-after: __function_api_report_final_metrics_start__ :end-before: __function_api_report_final_metrics_end__ +It's also possible to return a final set of metrics to Tune by returning them from your function: + +.. literalinclude:: /tune/doc_code/trainable.py + :language: python + :start-after: __function_api_return_final_metrics_start__ + :end-before: __function_api_return_final_metrics_end__ + +You'll notice that Ray Tune will output extra values in addition to the user reported metrics, +such as ``iterations_since_restore``. See :ref:`tune-autofilled-metrics` for an explanation/glossary of these values. + .. _tune-function-checkpointing: Function API Checkpointing diff --git a/doc/source/tune/doc_code/trainable.py b/doc/source/tune/doc_code/trainable.py index b4405df1b3e7..d10d2f83a571 100644 --- a/doc/source/tune/doc_code/trainable.py +++ b/doc/source/tune/doc_code/trainable.py @@ -106,6 +106,15 @@ def trainable(config: dict): results = tuner.fit() # __function_api_report_final_metrics_end__ +# __function_api_return_final_metrics_start__ +def trainable(config: dict): + final_score = 0 + for x in range(20): + final_score = objective(x, config["a"], config["b"]) + + return {"score": final_score} # This sends the score to Tune. +# __function_api_return_final_metrics_end__ + # __class_api_example_start__ from ray import air, tune From e41f2bf3929c7aaf0151629ddfe9d5e768b20f5a Mon Sep 17 00:00:00 2001 From: Justin Yu Date: Thu, 27 Oct 2022 09:58:48 -0700 Subject: [PATCH 10/11] Add table comparison of Function/Class APIs Signed-off-by: Justin Yu --- doc/source/tune/api_docs/trainable.rst | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/doc/source/tune/api_docs/trainable.rst b/doc/source/tune/api_docs/trainable.rst index f5122bcc163c..a532553d82fe 100644 --- a/doc/source/tune/api_docs/trainable.rst +++ b/doc/source/tune/api_docs/trainable.rst @@ -211,6 +211,22 @@ It is up to the user to correctly update the hyperparameters of your trainable. return True +Comparing the Function API and Class API +---------------------------------------- + +Here are a few key concepts and what they look like for the Function and Class API's. + +======================= =============================================== ============================================== +Concept Function API Class API +======================= =============================================== ============================================== +Training Iteration Increments on each `session.report` call Increments on each `Trainable.step` call +Report metrics `session.report(metrics)` Return metrics from `Trainable.step` +Saving a checkpoint `session.report(..., checkpoint=checkpoint)` `Trainable.save_checkpoint` +Loading a checkpoint `session.get_checkpoint()` `Trainable.load_checkpoint` +Accessing config Passed as an argument `def train_func(config):` Passed through `Trainable.setup` +======================= =============================================== ============================================== + + Advanced Resource Allocation ---------------------------- From 7f6876f868f77a3bfd69ae14def0c9d8c1ba6769 Mon Sep 17 00:00:00 2001 From: Justin Yu Date: Thu, 27 Oct 2022 10:01:42 -0700 Subject: [PATCH 11/11] Disable formatting for single function to avoid extra lines Signed-off-by: Justin Yu --- doc/source/tune/doc_code/trainable.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/doc/source/tune/doc_code/trainable.py b/doc/source/tune/doc_code/trainable.py index d10d2f83a571..856d33c91663 100644 --- a/doc/source/tune/doc_code/trainable.py +++ b/doc/source/tune/doc_code/trainable.py @@ -106,6 +106,7 @@ def trainable(config: dict): results = tuner.fit() # __function_api_report_final_metrics_end__ +# fmt: off # __function_api_return_final_metrics_start__ def trainable(config: dict): final_score = 0 @@ -114,6 +115,7 @@ def trainable(config: dict): return {"score": final_score} # This sends the score to Tune. # __function_api_return_final_metrics_end__ +# fmt: on # __class_api_example_start__ from ray import air, tune