diff --git a/python/ray/ml/config.py b/python/ray/ml/config.py
index 0f863fc91408..7a593d1f5550 100644
--- a/python/ray/ml/config.py
+++ b/python/ray/ml/config.py
@@ -1,13 +1,22 @@
 from dataclasses import dataclass
-from typing import Dict, Any, Optional, List, Mapping, Callable, Union, TYPE_CHECKING
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Callable,
+    Dict,
+    List,
+    Mapping,
+    Optional,
+    Union,
+)
 
 from ray.tune.syncer import SyncConfig
 from ray.util import PublicAPI
 
 if TYPE_CHECKING:
-    from ray.tune.trainable import PlacementGroupFactory
     from ray.tune.callback import Callback
     from ray.tune.stopper import Stopper
+    from ray.tune.trainable import PlacementGroupFactory
 
 
 ScalingConfig = Dict[str, Any]
diff --git a/python/ray/ml/tests/test_api.py b/python/ray/ml/tests/test_api.py
index 9ce9f26ca312..4e6773791b2b 100644
--- a/python/ray/ml/tests/test_api.py
+++ b/python/ray/ml/tests/test_api.py
@@ -2,8 +2,13 @@
 
 import ray
 from ray.ml import Checkpoint
+from ray.ml.config import ScalingConfigDataClass
 from ray.ml.trainer import Trainer
 from ray.ml.preprocessor import Preprocessor
+from ray.ml.utils.config import (
+    ensure_only_allowed_dataclass_keys_updated,
+    ensure_only_allowed_dict_keys_set,
+)
 
 
 class DummyTrainer(Trainer):
@@ -53,6 +58,38 @@ def test_scaling_config():
     DummyTrainer(scaling_config=None)
 
 
+def test_scaling_config_validate_config_valid_class():
+    scaling_config = {"num_workers": 2}
+    ensure_only_allowed_dataclass_keys_updated(
+        ScalingConfigDataClass(**scaling_config), ["num_workers"]
+    )
+
+
+def test_scaling_config_validate_config_valid_dict():
+    scaling_config = {"num_workers": 2}
+    ensure_only_allowed_dict_keys_set(scaling_config, ["num_workers"])
+
+
+def test_scaling_config_validate_config_prohibited_class():
+    # Check for prohibited keys
+    scaling_config = {"num_workers": 2}
+    with pytest.raises(ValueError):
+        ensure_only_allowed_dataclass_keys_updated(
+            ScalingConfigDataClass(**scaling_config),
+            ["trainer_resources"],
+        )
+
+
+def test_scaling_config_validate_config_prohibited_dict():
+    # Check for prohibited keys
+    scaling_config = {"num_workers": 2}
+    with pytest.raises(ValueError):
+        ensure_only_allowed_dict_keys_set(
+            scaling_config,
+            ["trainer_resources"],
+        )
+
+
 def test_datasets():
     with pytest.raises(ValueError):
         DummyTrainer(datasets="invalid")
diff --git a/python/ray/ml/tests/test_trainer.py b/python/ray/ml/tests/test_trainer.py
index 7bbc4dd0d323..3b8897d8f0a1 100644
--- a/python/ray/ml/tests/test_trainer.py
+++ b/python/ray/ml/tests/test_trainer.py
@@ -27,6 +27,15 @@ def transform(self, ds):
 
 
 class DummyTrainer(Trainer):
+    _scaling_config_allowed_keys = [
+        "num_workers",
+        "num_cpus_per_worker",
+        "num_gpus_per_worker",
+        "additional_resources_per_worker",
+        "use_gpu",
+        "trainer_resources",
+    ]
+
     def __init__(self, train_loop, custom_arg=None, **kwargs):
         self.custom_arg = custom_arg
         self.train_loop = train_loop
diff --git a/python/ray/ml/train/data_parallel_trainer.py b/python/ray/ml/train/data_parallel_trainer.py
index 21fcb54552a5..a92215a8126a 100644
--- a/python/ray/ml/train/data_parallel_trainer.py
+++ b/python/ray/ml/train/data_parallel_trainer.py
@@ -7,7 +7,7 @@
 from ray import tune
 from ray.ml.constants import TRAIN_DATASET_KEY, PREPROCESSOR_KEY
 from ray.ml.trainer import Trainer
-from ray.ml.config import ScalingConfig, RunConfig, ScalingConfigDataClass
+from ray.ml.config import ScalingConfig, RunConfig
 from ray.ml.trainer import GenDataset
 from ray.ml.preprocessor import Preprocessor
 from ray.ml.checkpoint import Checkpoint
@@ -181,6 +181,14 @@ def __init__(self, train_loop_per_worker, my_backend_config:
         resume_from_checkpoint: A checkpoint to resume training from.
     """
 
+    _scaling_config_allowed_keys = [
+        "num_workers",
+        "num_cpus_per_worker",
+        "num_gpus_per_worker",
+        "additional_resources_per_worker",
+        "use_gpu",
+    ]
+
     def __init__(
         self,
         *,
@@ -250,7 +258,9 @@ def _validate_train_loop_per_worker(
             )
 
     def training_loop(self) -> None:
-        scaling_config_dataclass = ScalingConfigDataClass(**self.scaling_config)
+        scaling_config_dataclass = self._validate_and_get_scaling_config_data_class(
+            self.scaling_config
+        )
 
         train_loop_per_worker = construct_train_func(
             self.train_loop_per_worker,
diff --git a/python/ray/ml/train/gbdt_trainer.py b/python/ray/ml/train/gbdt_trainer.py
index 8637bb8f994d..d4edd902f946 100644
--- a/python/ray/ml/train/gbdt_trainer.py
+++ b/python/ray/ml/train/gbdt_trainer.py
@@ -65,6 +65,13 @@ class GBDTTrainer(Trainer):
         **train_kwargs: Additional kwargs passed to framework ``train()`` function.
     """
 
+    _scaling_config_allowed_keys = [
+        "num_workers",
+        "num_cpus_per_worker",
+        "num_gpus_per_worker",
+        "additional_resources_per_worker",
+        "use_gpu",
+    ]
 _dmatrix_cls: type
     _ray_params_cls: type
     _tune_callback_cls: type
diff --git a/python/ray/ml/trainer.py b/python/ray/ml/trainer.py
index ccfbf618774c..81bf00c80d5e 100644
--- a/python/ray/ml/trainer.py
+++ b/python/ray/ml/trainer.py
@@ -1,19 +1,22 @@
 import abc
 import inspect
 import logging
-from typing import Dict, Union, Callable, Optional, TYPE_CHECKING, Type
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Type, Union
 
 import ray
-
-from ray.ml.preprocessor import Preprocessor
+from ray.util import PublicAPI
 from ray.ml.checkpoint import Checkpoint
-from ray.ml.result import Result
 from ray.ml.config import RunConfig, ScalingConfig, ScalingConfigDataClass
 from ray.ml.constants import TRAIN_DATASET_KEY
+from ray.ml.preprocessor import Preprocessor
+from ray.ml.result import Result
+from ray.ml.utils.config import (
+    ensure_only_allowed_dataclass_keys_updated,
+    ensure_only_allowed_dict_keys_set,
+)
 from ray.tune import Trainable
 from ray.tune.error import TuneError
 from ray.tune.function_runner import wrap_function
-from ray.util import PublicAPI
 from ray.util.annotations import DeveloperAPI
 from ray.util.ml_utils.dict import merge_dicts
 
@@ -133,6 +136,8 @@ def training_loop(self):
         resume_from_checkpoint: A checkpoint to resume training from.
     """
 
+    _scaling_config_allowed_keys: List[str] = []
+
     def __init__(
         self,
         *,
@@ -210,6 +215,25 @@ def _validate_attributes(self):
                 f"with value `{self.resume_from_checkpoint}`."
             )
 
+    @classmethod
+    def _validate_and_get_scaling_config_data_class(
+        cls, dataclass_or_dict: Union[ScalingConfigDataClass, Dict[str, Any]]
+    ) -> ScalingConfigDataClass:
+        """Return scaling config dataclass after validating updated keys."""
+        if isinstance(dataclass_or_dict, dict):
+            ensure_only_allowed_dict_keys_set(
+                dataclass_or_dict, cls._scaling_config_allowed_keys
+            )
+            scaling_config_dataclass = ScalingConfigDataClass(**dataclass_or_dict)
+
+            return scaling_config_dataclass
+
+        ensure_only_allowed_dataclass_keys_updated(
+            dataclass=dataclass_or_dict,
+            allowed_keys=cls._scaling_config_allowed_keys,
+        )
+        return dataclass_or_dict
+
     def setup(self) -> None:
         """Called during fit() to perform initial setup on the Trainer.
 
@@ -359,8 +383,10 @@ def _trainable_func(self, config, reporter, checkpoint_dir):
             @classmethod
             def default_resource_request(cls, config):
                 updated_scaling_config = config.get("scaling_config", scaling_config)
-                scaling_config_dataclass = ScalingConfigDataClass(
-                    **updated_scaling_config
+                scaling_config_dataclass = (
+                    trainer_cls._validate_and_get_scaling_config_data_class(
+                        updated_scaling_config
+                    )
                 )
                 return scaling_config_dataclass.as_placement_group_factory()
 
diff --git a/python/ray/ml/utils/config.py b/python/ray/ml/utils/config.py
new file mode 100644
index 000000000000..3d09140c627a
--- /dev/null
+++ b/python/ray/ml/utils/config.py
@@ -0,0 +1,56 @@
+import dataclasses
+from typing import Iterable
+
+
+def ensure_only_allowed_dict_keys_set(
+    data: dict,
+    allowed_keys: Iterable[str],
+):
+    """
+    Validate dict by raising an exception if any key not included in
+    ``allowed_keys`` is set.
+
+    Args:
+        data: Dict to check.
+        allowed_keys: Iterable of keys that are allowed to be set in the dict.
+    """
+    allowed_keys_set = set(allowed_keys)
+    bad_keys = [key for key in data.keys() if key not in allowed_keys_set]
+
+    if bad_keys:
+        raise ValueError(
+            f"Key(s) {bad_keys} are not allowed to be set in the current context. "
+            "Remove them from the dict."
+        )
+
+
+def ensure_only_allowed_dataclass_keys_updated(
+    dataclass: dataclasses.dataclass,
+    allowed_keys: Iterable[str],
+):
+    """
+    Validate dataclass by raising an exception if any key not included in
+    ``allowed_keys`` differs from the default value.
+
+    Args:
+        dataclass: Dataclass instance to check.
+        allowed_keys: Dataclass attribute keys that can have a value different
+            from the default one.
+    """
+    default_data = dataclass.__class__()
+
+    allowed_keys = set(allowed_keys)
+
+    # These keys should not have been updated in the `dataclass` object
+    prohibited_keys = set(default_data.__dict__) - allowed_keys
+
+    bad_keys = [
+        key
+        for key in prohibited_keys
+        if dataclass.__dict__[key] != default_data.__dict__[key]
+    ]
+    if bad_keys:
+        raise ValueError(
+            f"Key(s) {bad_keys} are not allowed to be updated in the current context. "
+            "Remove them from the dataclass."
+        )
diff --git a/python/ray/tune/tests/test_tuner.py b/python/ray/tune/tests/test_tuner.py
index 287d38f62cba..b952c62742bf 100644
--- a/python/ray/tune/tests/test_tuner.py
+++ b/python/ray/tune/tests/test_tuner.py
@@ -175,6 +175,15 @@ def on_step_end(self, iteration, trials, **kwargs):
 
     def test_tuner_trainer_fail(self):
         class DummyTrainer(Trainer):
+            _scaling_config_allowed_keys = [
+                "num_workers",
+                "num_cpus_per_worker",
+                "num_gpus_per_worker",
+                "additional_resources_per_worker",
+                "use_gpu",
+                "trainer_resources",
+            ]
+
             def training_loop(self) -> None:
                 raise RuntimeError("There is an error in trainer!")
 
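
For reviewers: a minimal usage sketch of the two helpers added in python/ray/ml/utils/config.py, mirroring the new tests in test_api.py above. This is illustrative only, not part of the patch.

    from ray.ml.config import ScalingConfigDataClass
    from ray.ml.utils.config import (
        ensure_only_allowed_dataclass_keys_updated,
        ensure_only_allowed_dict_keys_set,
    )

    # Dict form: every key that is set must appear in the allow-list.
    ensure_only_allowed_dict_keys_set({"num_workers": 2}, ["num_workers"])  # ok

    # Dataclass form: only allow-listed fields may differ from their defaults.
    ensure_only_allowed_dataclass_keys_updated(
        ScalingConfigDataClass(num_workers=2), ["num_workers"]
    )  # ok

    # A key outside the allow-list raises ValueError in either form.
    try:
        ensure_only_allowed_dict_keys_set({"num_workers": 2}, ["trainer_resources"])
    except ValueError as exc:
        print(exc)  # Key(s) ['num_workers'] are not allowed to be set ...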
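
And how a concrete trainer plugs into this: the subclass declares its allow-list, and the Trainer base class normalizes either input form in one call. A sketch only; MyTrainer is a hypothetical subclass, not part of the patch.

    from ray.ml.trainer import Trainer

    class MyTrainer(Trainer):  # hypothetical subclass, for illustration only
        # Keys users may set in scaling_config for this trainer.
        _scaling_config_allowed_keys = ["num_workers", "use_gpu"]

        def training_loop(self) -> None:
            # Raises ValueError if scaling_config touches any other key;
            # returns a ScalingConfigDataClass whether a dict or dataclass
            # was passed in.
            scaling_config_dataclass = (
                self._validate_and_get_scaling_config_data_class(
                    self.scaling_config
                )
            )
            ...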