[Tune] Fix CheckpointConfig validation for function trainables (#31255)
This fixes an issue where a ValueError wasn't being properly raised when passing in a function trainable and setting `checkpoint_at_end=True` or `checkpoint_frequency > 0`. Previously, the error was only raised for function trainables of the form `def train_func(config, checkpoint_dir):`, which is the old checkpoint dir function API.

Signed-off-by: Justin Yu <[email protected]>
justinvyu authored and AmeerHajAli committed Jan 12, 2023
1 parent 5def089 commit 0d34108
Showing 7 changed files with 121 additions and 38 deletions.
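For context, a minimal sketch of the checkpointing pattern the new error messages point users to (assuming the Ray AIR session API of this release; metric names and checkpoint contents are illustrative). A function trainable reports its own checkpoints rather than relying on `checkpoint_at_end` or `checkpoint_frequency`:

from ray import tune
from ray.air import session
from ray.air.checkpoint import Checkpoint

def train_func(config):
    for step in range(10):
        # Report a checkpoint alongside metrics each iteration; the last
        # report before returning covers the `checkpoint_at_end` behavior.
        session.report(
            {"loss": 1.0 / (step + 1)},
            checkpoint=Checkpoint.from_dict({"step": step}),
        )

tuner = tune.Tuner(train_func)
results = tuner.fit()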
27 changes: 12 additions & 15 deletions python/ray/tune/experiment/experiment.py
@@ -2,7 +2,6 @@
 import datetime
 from functools import partial
 import grpc
-import inspect
 import logging
 import os
 from pathlib import Path
@@ -24,11 +23,11 @@
 from ray.air import CheckpointConfig
 from ray.tune.error import TuneError
-from ray.tune.registry import register_trainable
+from ray.tune.registry import register_trainable, is_function_trainable
 from ray.tune.result import DEFAULT_RESULTS_DIR
 from ray.tune.stopper import CombinedStopper, FunctionStopper, Stopper, TimeoutStopper
 from ray.tune.syncer import SyncConfig
-from ray.tune.utils import date_str, _detect_checkpoint_function
+from ray.tune.utils import date_str
 
 from ray.util.annotations import DeveloperAPI
@@ -169,23 +168,21 @@ def __init__(
         else:
             checkpoint_config = checkpoint_config or CheckpointConfig()
 
-        if (
-            callable(run)
-            and not inspect.isclass(run)
-            and _detect_checkpoint_function(run)
-        ):
+        if is_function_trainable(run):
             if checkpoint_config.checkpoint_at_end:
                 raise ValueError(
-                    "'checkpoint_at_end' cannot be used with a "
-                    "checkpointable function. You can specify "
-                    "and register checkpoints within "
-                    "your trainable function."
+                    "'checkpoint_at_end' cannot be used with a function trainable. "
+                    "You should include one last call to "
+                    "`ray.air.session.report(metrics=..., checkpoint=...)` at the end "
+                    "of your training loop to get this behavior."
                 )
             if checkpoint_config.checkpoint_frequency:
                 raise ValueError(
-                    "'checkpoint_freq' cannot be used with a "
-                    "checkpointable function. You can specify checkpoints "
-                    "within your trainable function."
+                    "'checkpoint_frequency' cannot be set for a function trainable. "
+                    "You will need to report a checkpoint every "
+                    "`checkpoint_frequency` iterations within your training loop using "
+                    "`ray.air.session.report(metrics=..., checkpoint=...)` "
+                    "to get this behavior."
                 )
         try:
            self._run_identifier = Experiment.register_if_needed(run)
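Illustration of the broadened validation (a hypothetical experiment; per the new `is_function_trainable` check, any function trainable now fails fast, not just the legacy `def train_func(config, checkpoint_dir)` form):

from ray.air import CheckpointConfig
from ray.tune.experiment import Experiment

def train_func(config):
    pass

# Raises ValueError: 'checkpoint_at_end' cannot be used with a function trainable.
Experiment(
    name="example",
    run=train_func,
    checkpoint_config=CheckpointConfig(checkpoint_at_end=True),
)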
27 changes: 16 additions & 11 deletions python/ray/tune/impl/tuner_internal.py
@@ -460,36 +460,41 @@ def _get_tune_run_arguments(self, trainable: TrainableType) -> Dict[str, Any]:
                 # If we specifically know this trainable doesn't support the
                 # argument, raise an error
                 raise ValueError(
-                    f"You passed `checkpoint_freq={checkpoint_freq}` to your "
-                    f"CheckpointConfig, but this trainer does not support "
-                    f"this argument. If the trainer takes in a training loop, "
-                    f"you will need to trigger checkpointing yourself using "
-                    f"`ray.air.session.report(metrics=..., checkpoint=...)`."
+                    f"You passed `checkpoint_frequency={checkpoint_freq}` to your "
+                    "CheckpointConfig, but this trainer does not support "
+                    "this argument. If you passed in an AIR trainer that takes in a "
+                    "custom training loop, you will need to "
+                    "report a checkpoint every `checkpoint_frequency` iterations "
+                    "within your training loop using "
+                    "`ray.air.session.report(metrics=..., checkpoint=...)` "
+                    "to get this behavior."
                 )
             elif handle_checkpoint_freq is True:
                 # If we specifically support it, it's handled in the training loop,
                 # so we disable tune's bookkeeping.
                 checkpoint_freq = 0
-            # Otherwise, this is a non-trainer trainable and we just keep the
+            # Otherwise, the trainable is not an AIR trainer and we just keep the
             # user-supplied value.
 
+            # Function trainables will raise a runtime error later if set > 0
         if checkpoint_at_end is not None:
             # Again, function trainables usually don't handle this argument.
             handle_cp_at_end = getattr(trainable, "_handles_checkpoint_at_end", None)
             if handle_cp_at_end is False:
                 # If we specifically know we don't support it, raise an error.
                 raise ValueError(
                     f"You passed `checkpoint_at_end={checkpoint_at_end}` to your "
-                    f"CheckpointConfig, but this trainer does not support "
-                    f"this argument. If the trainer takes in a training loop, "
-                    f"you will need to trigger checkpointing yourself using "
-                    f"`ray.air.session.report(metrics=..., checkpoint=...)`. "
+                    "CheckpointConfig, but this trainer does not support "
+                    "this argument. If you passed in an AIR trainer that takes in a "
+                    "custom training loop, you should include one last call to "
+                    "`ray.air.session.report(metrics=..., checkpoint=...)` "
+                    "at the end of your training loop to get this behavior."
                 )
             elif handle_cp_at_end is True:
                 # If we specifically support it, it's handled in the training loop,
                 # so we disable tune's internal bookkeeping.
                 checkpoint_at_end = False
-            # If this is a user-defined trainable, just keep the value
+            # Function trainables will raise a runtime error later if set to True
         else:
             # Set default to False for function trainables and True for everything else
             if is_function_trainable(trainable):
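For the AIR-trainer case the new messages describe, a minimal sketch (assuming `DataParallelTrainer` with its default scaling; the loop body and metrics are illustrative). A custom training loop reports its own checkpoint each iteration, which is the behavior `checkpoint_frequency=1` would otherwise request:

from ray.air import session
from ray.air.checkpoint import Checkpoint
from ray.train.data_parallel_trainer import DataParallelTrainer

def train_loop_per_worker(config):
    # The custom loop checkpoints itself: reporting every iteration plays
    # the role of `checkpoint_frequency=1`, and the final report covers
    # `checkpoint_at_end=True`.
    for epoch in range(3):
        session.report(
            {"epoch": epoch},
            checkpoint=Checkpoint.from_dict({"epoch": epoch}),
        )

trainer = DataParallelTrainer(train_loop_per_worker=train_loop_per_worker)
result = trainer.fit()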
20 changes: 15 additions & 5 deletions python/ray/tune/registry.py
@@ -59,14 +59,24 @@ def validate_trainable(trainable_name):
 
 @DeveloperAPI
 def is_function_trainable(trainable: Union[str, Callable, Type]) -> bool:
-    """Check if a given trainable is a function trainable."""
+    """Check if a given trainable is a function trainable.
+
+    Either the trainable has been wrapped as a FunctionTrainable class already,
+    or it's still a FunctionType/partial/callable."""
+    from ray.tune.trainable import FunctionTrainable
+
     if isinstance(trainable, str):
         trainable = get_trainable_cls(trainable)
 
-    return not isinstance(trainable, type) and (
-        isinstance(trainable, FunctionType)
-        or isinstance(trainable, partial)
-        or callable(trainable)
+    is_wrapped_func = isinstance(trainable, type) and issubclass(
+        trainable, FunctionTrainable
+    )
+    return is_wrapped_func or (
+        not isinstance(trainable, type)
+        and (
+            isinstance(trainable, FunctionType)
+            or isinstance(trainable, partial)
+            or callable(trainable)
+        )
     )
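The intended semantics of the new check, sketched as assertions (hypothetical names, and it assumes a live Ray session so the registry lookup resolves; `register_trainable` wraps a plain function into a `FunctionTrainable` subclass, which is the registered-function case the old check missed):

import ray
from functools import partial

from ray.tune import Trainable, register_trainable
from ray.tune.registry import is_function_trainable

ray.init()

def train_func(config):
    pass

class ClassTrainable(Trainable):
    pass

# Registration wraps `train_func` into a FunctionTrainable subclass.
register_trainable("train_func", train_func)

assert is_function_trainable(train_func)           # plain function
assert is_function_trainable(partial(train_func))  # functools.partial
assert is_function_trainable("train_func")         # wrapped class: fixed by this PR
assert not is_function_trainable(ClassTrainable)   # class trainable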
27 changes: 24 additions & 3 deletions python/ray/tune/tests/test_experiment.py
@@ -1,8 +1,8 @@
-import unittest
 import threading
+import unittest
 
 import ray
-from ray.rllib import _register_all
+from ray.air import CheckpointConfig
 from ray.tune import register_trainable, SyncConfig
 from ray.tune.experiment import Experiment, _convert_to_experiment_list
 from ray.tune.error import TuneError
@@ -21,7 +21,6 @@ def test_remote_checkpoint_dir_with_query_string():
 class ExperimentTest(unittest.TestCase):
     def tearDown(self):
         ray.shutdown()
-        _register_all()  # re-register the evicted objects
 
     def setUp(self):
         def train(config, reporter):
@@ -63,6 +62,28 @@ def testConvertExperimentJSON(self):
     def testConvertExperimentIncorrect(self):
         self.assertRaises(TuneError, lambda: _convert_to_experiment_list("hi"))
 
+    def testFuncTrainableCheckpointConfigValidation(self):
+        """Raise an error when trying to specify checkpoint_at_end/checkpoint_frequency
+        with a function trainable."""
+        with self.assertRaises(ValueError):
+            Experiment(
+                name="foo",
+                run="f1",  # Will point to a wrapped function trainable
+                checkpoint_config=CheckpointConfig(checkpoint_at_end=True),
+            )
+        with self.assertRaises(ValueError):
+            Experiment(
+                name="foo",
+                run="f1",
+                checkpoint_config=CheckpointConfig(checkpoint_frequency=1),
+            )
+        with self.assertRaises(ValueError):
+            Experiment(
+                name="foo",
+                run=lambda config: 1,
+                checkpoint_config=CheckpointConfig(checkpoint_at_end=True),
+            )
+
 
 class ValidateUtilTest(unittest.TestCase):
     def testDiagnoseSerialization(self):
1 change: 0 additions & 1 deletion python/ray/tune/tests/test_experiment_analysis.py
@@ -55,7 +55,6 @@ def nan_test_exp(self):
             name="testing_nan",
             local_dir=self.test_dir,
             stop={"training_iteration": 1},
-            checkpoint_freq=1,
             num_samples=self.num_samples,
             config={
                 "width": tune.sample_from(lambda spec: 10 + int(90 * random.random())),
53 changes: 51 additions & 2 deletions python/ray/tune/tests/test_tuner.py
@@ -19,6 +19,7 @@
 )
 from ray.data import Dataset, Datasource, ReadTask, from_pandas, read_datasource
 from ray.data.block import BlockMetadata
+from ray.train.data_parallel_trainer import DataParallelTrainer
 from ray.train.torch import TorchTrainer
 from ray.train.trainer import BaseTrainer
 from ray.train.xgboost import XGBoostTrainer
@@ -351,16 +352,64 @@ def catch_kwargs(**kwargs):
     assert assertion(caught_kwargs)
 
 
-def test_tuner_fn_trainable_checkpoint_at_end_true(shutdown_only):
+def test_tuner_fn_trainable_invalid_checkpoint_config(shutdown_only):
     tuner = Tuner(
-        lambda config, checkpoint_dir: 1,
+        lambda config: 1,
         run_config=ray.air.RunConfig(
             checkpoint_config=ray.air.CheckpointConfig(checkpoint_at_end=True)
         ),
     )
     with pytest.raises(ValueError):
         tuner.fit()
 
+    tuner = Tuner(
+        lambda config: 1,
+        run_config=ray.air.RunConfig(
+            checkpoint_config=ray.air.CheckpointConfig(checkpoint_frequency=1)
+        ),
+    )
+    with pytest.raises(ValueError):
+        tuner.fit()
+
+
+def test_tuner_trainer_checkpoint_config(shutdown_only):
+    custom_training_loop_trainer = DataParallelTrainer(
+        train_loop_per_worker=lambda config: 1
+    )
+    tuner = Tuner(
+        custom_training_loop_trainer,
+        run_config=ray.air.RunConfig(
+            checkpoint_config=ray.air.CheckpointConfig(checkpoint_at_end=True)
+        ),
+    )
+    with pytest.raises(ValueError):
+        tuner.fit()
+
+    tuner = Tuner(
+        custom_training_loop_trainer,
+        run_config=ray.air.RunConfig(
+            checkpoint_config=ray.air.CheckpointConfig(checkpoint_frequency=1)
+        ),
+    )
+    with pytest.raises(ValueError):
+        tuner.fit()
+
+    handles_checkpoints_trainer = XGBoostTrainer(
+        label_column="target",
+        params={},
+        datasets={"train": ray.data.from_items(list(range(5)))},
+    )
+    tuner = Tuner(
+        handles_checkpoints_trainer,
+        run_config=ray.air.RunConfig(
+            checkpoint_config=ray.air.CheckpointConfig(
+                checkpoint_at_end=True, checkpoint_frequency=1
+            )
+        ),
+    )._local_tuner
+    # Check that validation passes for a Trainer that does handle checkpointing
+    tuner._get_tune_run_arguments(tuner.converted_trainable)
+
 
 def test_tuner_fn_trainable_checkpoint_at_end_false(shutdown_only):
     tuner = Tuner(
4 changes: 3 additions & 1 deletion python/ray/tune/tests/test_tuner_restore.py
@@ -626,7 +626,9 @@ def create_trainable_with_params():
             local_dir=str(tmp_path),
             stop={"training_iteration": 3},
             failure_config=FailureConfig(max_failures=0),
-            checkpoint_config=CheckpointConfig(checkpoint_frequency=1),
+            checkpoint_config=CheckpointConfig(
+                checkpoint_frequency=0 if use_function_trainable else 1
+            ),
         ),
         param_space={"fail_marker": fail_marker},
     )
