[Tune] Rename overwrite_trainable argument in Tuner restore to trainable #32059

Merged
2 changes: 1 addition & 1 deletion python/ray/train/tests/test_tune.py
@@ -273,7 +273,7 @@ def train_func(config):
         resume_errored=True,
     )
     # Should warn about the RunConfig being ignored
-    assert "RunConfig" in str(warn_record[0].message)
+    assert any("RunConfig" in str(record.message) for record in warn_record)
     assert "The trainable will be overwritten" in caplog.text
 
     results = tuner.fit()
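The test fix above swaps a positional index for an order-independent scan: if `Tuner.restore` emits more than one warning, the `RunConfig` message is not guaranteed to land at `warn_record[0]`. A minimal sketch of the pattern, assuming pytest's warning capture; `emit_warnings` is a hypothetical stand-in for the restore call:

```python
import warnings

import pytest


def emit_warnings():
    # Hypothetical stand-in: an unrelated warning may be recorded first.
    warnings.warn("some unrelated deprecation")
    warnings.warn("The RunConfig is ignored on restore")


def test_runconfig_warning():
    with pytest.warns(UserWarning) as warn_record:
        emit_warnings()
    # Fragile: assumes the target warning is recorded first.
    #     assert "RunConfig" in str(warn_record[0].message)
    # Robust: order-independent membership check.
    assert any("RunConfig" in str(record.message) for record in warn_record)
```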
7 changes: 3 additions & 4 deletions python/ray/tune/impl/tuner_internal.py
@@ -251,17 +251,16 @@ def _validate_overwrite_trainable(
             "# Reconstruct the trainable with the same parameters\n"
             "trainable_with_params = tune.with_parameters(trainable, ...)\n"
             "tuner = tune.Tuner.restore(\n"
-            "    ..., overwrite_trainable=trainable_with_params\n"
+            "    ..., trainable=trainable_with_params\n"
             ")\n\nSee https://docs.ray.io/en/master/tune/api_docs/trainable.html"
             "#tune-with-parameters for more details."
         )
         if not overwrite_trainable:
             return
 
         error_message = (
-            "Usage of `overwrite_trainable` is limited to re-specifying the "
-            "same trainable that was passed to `Tuner`, in the case "
-            "that the trainable is not serializable (e.g. it holds object references)."
+            "Invalid trainable input. To avoid errors, pass in the same trainable "
+            "that was used to initialize the Tuner."
         )
 
         if type(original_trainable) != type(overwrite_trainable):
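For context, the error message above is raised by a type comparison between the original trainable and the one re-specified at restore time. A rough sketch of that check, assuming a free-standing helper; this is illustrative, not Tune's exact code:

```python
def _check_trainable(original_trainable, trainable):
    # Illustrative reimplementation of the validation idea: restoring with a
    # trainable of a different type than the original one is rejected.
    error_message = (
        "Invalid trainable input. To avoid errors, pass in the same trainable "
        "that was used to initialize the Tuner."
    )
    if type(original_trainable) != type(trainable):
        raise ValueError(error_message)
```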
7 changes: 4 additions & 3 deletions python/ray/tune/trainable/util.py
@@ -325,13 +325,13 @@ def step(self):
 
     1. ``tune.with_parameters`` stores parameters in the object store and
        attaches object references to the trainable, but the objects they point to
-       may not exist anymore upon restore.
+       may not exist anymore upon restoring in a new Ray cluster.
 
     2. The attached objects could be arbitrarily large, so Tune does not save the
        object data along with the trainable.
 
     To restore, Tune allows the trainable to be re-specified in
-    :meth:`Tuner.restore(overwrite_trainable=...) <ray.tune.tuner.Tuner.restore>`.
+    :meth:`Tuner.restore(path, trainable=...) <ray.tune.tuner.Tuner.restore>`.
     Continuing from the previous examples, here's an example of restoration:
 
     .. code-block:: python

@@ -342,7 +342,8 @@ def step(self):
 
         tuner = Tuner.restore(
             "/path/to/experiment/",
-            overwrite_trainable=tune.with_parameters(MyTrainable, data=data)
+            trainable=tune.with_parameters(MyTrainable, data=data),
+            # ...
         )
 
     """
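To make the docstring example above concrete, here is a hedged end-to-end sketch of the round trip it describes. `MyTrainable`, `load_large_dataset`, and the stopping details are hypothetical stand-ins; the key point is that the object references created by `tune.with_parameters` must be recreated at restore time:

```python
from ray import tune
from ray.tune import Tuner


class MyTrainable(tune.Trainable):
    def setup(self, config, data=None):
        # `data` arrives via tune.with_parameters, not via `config`.
        self.data = data

    def step(self):
        return {"num_items": len(self.data)}


def load_large_dataset():
    # Hypothetical stand-in for an expensive data load.
    return list(range(1_000_000))


data = load_large_dataset()

# Original run: `data` goes into the object store, and the trainable keeps
# only an object reference to it.
tuner = Tuner(tune.with_parameters(MyTrainable, data=data))
tuner.fit()

# Later, possibly on a fresh Ray cluster where that reference is dead:
# recreate the parameters and re-specify the trainable on restore.
data = load_large_dataset()
tuner = Tuner.restore(
    "/path/to/experiment/",
    trainable=tune.with_parameters(MyTrainable, data=data),
)
tuner.fit()
```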
38 changes: 29 additions & 9 deletions python/ray/tune/tuner.py
@@ -1,4 +1,5 @@
 from typing import Any, Callable, Dict, Optional, Type, Union, TYPE_CHECKING
+import warnings
 
 import ray
 
@@ -159,9 +160,13 @@ def __init__(
     def restore(
         cls,
         path: str,
+        trainable: Optional[
+            Union[str, Callable, Type[Trainable], "BaseTrainer"]
+        ] = None,
         resume_unfinished: bool = True,
         resume_errored: bool = False,
         restart_errored: bool = False,
+        # Deprecated
         overwrite_trainable: Optional[
             Union[str, Callable, Type[Trainable], "BaseTrainer"]
         ] = None,
@@ -191,24 +196,39 @@ def restore(
                 console output of previous run.
                 Note: depending on whether ray client mode is used or not,
                 this path may or may not exist on your local machine.
+            trainable: The trainable to use upon resuming the experiment.
+                This should be the same trainable that was used to initialize
+                the original Tuner.
+                NOTE: Starting in 2.5, this will be a required parameter.
             resume_unfinished: If True, will continue to run unfinished trials.
             resume_errored: If True, will re-schedule errored trials and try to
                 restore from their latest checkpoints.
             restart_errored: If True, will re-schedule errored trials but force
                 restarting them from scratch (no checkpoint will be loaded).
-            overwrite_trainable: A newly specified trainable that will overwrite
-                the trainable that was originally saved by Tune. This should
-                only be used to resume an experiment where the original trainable
-                is not fully serializable (e.g. when the trainable has object
-                references attached to it via ``tune.with_parameters``, the objects
-                they point to may not exist if restoring from a new Ray cluster).
-                NOTE: This API is experimental and should be used with caution.
+            overwrite_trainable: Deprecated. Use the `trainable` argument instead.
         """
         # TODO(xwjiang): Add some comments to clarify the config behavior across
         # restored runs.
         # For example, are callbacks supposed to be automatically applied
         # when a Tuner is restored and fit again?
+
+        if overwrite_trainable:
+            if not trainable:
+                trainable = overwrite_trainable
+            warning_message = (
+                "`overwrite_trainable` has been renamed to `trainable`. "
+                "The old argument will be removed starting from version 2.5."
+            )
+            warnings.warn(warning_message, DeprecationWarning)
+
+        if not trainable:
+            warning_message = (
+                "The experiment's `trainable` will be a required argument "
+                "to `Tuner.restore` starting from version 2.5. "
+                "Please specify the trainable to avoid this warning."
+            )
+            warnings.warn(warning_message)

         resume_config = _ResumeConfig(
             resume_unfinished=resume_unfinished,
             resume_errored=resume_errored,
@@ -219,7 +239,7 @@ def restore(
             tuner_internal = TunerInternal(
                 restore_path=path,
                 resume_config=resume_config,
-                trainable=overwrite_trainable,
+                trainable=trainable,
             )
             return Tuner(_tuner_internal=tuner_internal)
         else:
@@ ... @@
             ).remote(
                 restore_path=path,
                 resume_config=resume_config,
-                trainable=overwrite_trainable,
+                trainable=trainable,
             )
             return Tuner(_tuner_internal=tuner_internal)

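Taken together, the shim keeps old call sites working through the rename while nudging callers toward the new keyword. A sketch of the observable behavior, assuming a previously saved experiment at `/path/to/experiment/` and a `my_trainable` defined as in the original run (both are assumptions, not part of this diff):

```python
import warnings

from ray.tune import Tuner

# New-style call: the trainable is re-specified under its new keyword.
tuner = Tuner.restore("/path/to/experiment/", trainable=my_trainable)

# Old-style call: still accepted, but mapped onto `trainable` and flagged.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    tuner = Tuner.restore(
        "/path/to/experiment/", overwrite_trainable=my_trainable
    )
assert any(issubclass(w.category, DeprecationWarning) for w in caught)
```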