From 7bd06e29505c53d7c9c9ce3a67995d2de0eb527c Mon Sep 17 00:00:00 2001 From: Jimmy Yao Date: Tue, 7 Jun 2022 15:03:52 +0000 Subject: [PATCH 1/7] restruct #25350 --- python/ray/air/train/data_parallel_trainer.py | 1 - python/ray/air/train/gbdt_trainer.py | 1 - python/ray/air/trainer.py | 9 ++------- 3 files changed, 2 insertions(+), 9 deletions(-) diff --git a/python/ray/air/train/data_parallel_trainer.py b/python/ray/air/train/data_parallel_trainer.py index 89fe7ed014d3..22021189734f 100644 --- a/python/ray/air/train/data_parallel_trainer.py +++ b/python/ray/air/train/data_parallel_trainer.py @@ -229,7 +229,6 @@ def __init__(self, train_loop_per_worker, my_backend_config: "num_workers", "num_cpus_per_worker", "num_gpus_per_worker", - "resources_per_worker", "additional_resources_per_worker", "use_gpu", ] diff --git a/python/ray/air/train/gbdt_trainer.py b/python/ray/air/train/gbdt_trainer.py index 579098d9d000..ef38fb632902 100644 --- a/python/ray/air/train/gbdt_trainer.py +++ b/python/ray/air/train/gbdt_trainer.py @@ -69,7 +69,6 @@ class GBDTTrainer(Trainer): "num_workers", "num_cpus_per_worker", "num_gpus_per_worker", - "resources_per_worker", "additional_resources_per_worker", "use_gpu", ] diff --git a/python/ray/air/trainer.py b/python/ray/air/trainer.py index c4809da090c0..c8428de35f84 100644 --- a/python/ray/air/trainer.py +++ b/python/ray/air/trainer.py @@ -225,13 +225,8 @@ def _validate_and_get_scaling_config_data_class( ) -> ScalingConfigDataClass: """Return scaling config dataclass after validating updated keys.""" if isinstance(dataclass_or_dict, dict): - ensure_only_allowed_dict_keys_set( - dataclass_or_dict, cls._scaling_config_allowed_keys - ) - scaling_config_dataclass = ScalingConfigDataClass(**dataclass_or_dict) - - return scaling_config_dataclass - + dataclass_or_dict = ScalingConfigDataClass(**dataclass_or_dict) + ensure_only_allowed_dataclass_keys_updated( dataclass=dataclass_or_dict, allowed_keys=cls._scaling_config_allowed_keys, From 1bcd2867991fcfc423c9a2528b64181eaae6760c Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Tue, 7 Jun 2022 15:49:33 +0000 Subject: [PATCH 2/7] Add additional validation, remove unused func --- python/ray/air/_internal/config.py | 31 +++++++++--------------------- python/ray/air/tests/test_api.py | 26 ++++++++++--------------- python/ray/air/trainer.py | 7 ++----- 3 files changed, 21 insertions(+), 43 deletions(-) diff --git a/python/ray/air/_internal/config.py b/python/ray/air/_internal/config.py index 3d09140c627a..63410d3f1b66 100644 --- a/python/ray/air/_internal/config.py +++ b/python/ray/air/_internal/config.py @@ -2,28 +2,6 @@ from typing import Iterable -def ensure_only_allowed_dict_keys_set( - data: dict, - allowed_keys: Iterable[str], -): - """ - Validate dict by raising an exception if any key not included in - ``allowed_keys`` is set. - - Args: - data: Dict to check. - allowed_keys: Iterable of keys that can be contained in dict keys. - """ - allowed_keys_set = set(allowed_keys) - bad_keys = [key for key in data.keys() if key not in allowed_keys_set] - - if bad_keys: - raise ValueError( - f"Key(s) {bad_keys} are not allowed to be set in the current context. " - "Remove them from the dict." 
- ) - - def ensure_only_allowed_dataclass_keys_updated( dataclass: dataclasses.dataclass, allowed_keys: Iterable[str], @@ -41,6 +19,15 @@ def ensure_only_allowed_dataclass_keys_updated( allowed_keys = set(allowed_keys) + keys_not_in_dict = [key for key in allowed_keys if key not in default_data.__dict__] + if keys_not_in_dict: + raise ValueError( + f"Key(s) {keys_not_in_dict} are not present in " + f"{dataclass.__class__.__name__}. " + "Remove them from `allowed_keys`. " + f"Valid keys: {list(default_data.__dict__.keys())}" + ) + # These keys should not have been updated in the `dataclass` object prohibited_keys = set(default_data.__dict__) - allowed_keys diff --git a/python/ray/air/tests/test_api.py b/python/ray/air/tests/test_api.py index 94c190a7dc26..b91e5ed9b6ba 100644 --- a/python/ray/air/tests/test_api.py +++ b/python/ray/air/tests/test_api.py @@ -5,10 +5,7 @@ from ray.air.config import ScalingConfigDataClass from ray.air.trainer import Trainer from ray.air.preprocessor import Preprocessor -from ray.air._internal.config import ( - ensure_only_allowed_dataclass_keys_updated, - ensure_only_allowed_dict_keys_set, -) +from ray.air._internal.config import ensure_only_allowed_dataclass_keys_updated class DummyTrainer(Trainer): @@ -65,29 +62,26 @@ def test_scaling_config_validate_config_valid_class(): ) -def test_scaling_config_validate_config_valid_dict(): - scaling_config = {"num_workers": 2} - ensure_only_allowed_dict_keys_set(scaling_config, ["num_workers"]) - - def test_scaling_config_validate_config_prohibited_class(): # Check for prohibited keys scaling_config = {"num_workers": 2} - with pytest.raises(ValueError): + with pytest.raises(ValueError) as exc_info: ensure_only_allowed_dataclass_keys_updated( ScalingConfigDataClass(**scaling_config), ["trainer_resources"], ) + assert "trainer_resources" in str(exc_info.value) -def test_scaling_config_validate_config_prohibited_dict(): - # Check for prohibited keys +def test_scaling_config_validate_config_bad_allowed_keys(): + # Check for keys not present in dict scaling_config = {"num_workers": 2} - with pytest.raises(ValueError): - ensure_only_allowed_dict_keys_set( - scaling_config, - ["trainer_resources"], + with pytest.raises(ValueError) as exc_info: + ensure_only_allowed_dataclass_keys_updated( + ScalingConfigDataClass(**scaling_config), + ["BAD_KEY"], ) + assert "BAD_KEY" in str(exc_info.value) def test_datasets(): diff --git a/python/ray/air/trainer.py b/python/ray/air/trainer.py index c8428de35f84..2f1ae291a846 100644 --- a/python/ray/air/trainer.py +++ b/python/ray/air/trainer.py @@ -14,10 +14,7 @@ ) from ray.air.preprocessor import Preprocessor from ray.air.result import Result -from ray.air._internal.config import ( - ensure_only_allowed_dataclass_keys_updated, - ensure_only_allowed_dict_keys_set, -) +from ray.air._internal.config import ensure_only_allowed_dataclass_keys_updated from ray.tune import Trainable from ray.tune.error import TuneError from ray.tune.function_runner import wrap_function @@ -226,7 +223,7 @@ def _validate_and_get_scaling_config_data_class( """Return scaling config dataclass after validating updated keys.""" if isinstance(dataclass_or_dict, dict): dataclass_or_dict = ScalingConfigDataClass(**dataclass_or_dict) - + ensure_only_allowed_dataclass_keys_updated( dataclass=dataclass_or_dict, allowed_keys=cls._scaling_config_allowed_keys, From a6f97f46c01cbd416f44f5118882f54410cd83c6 Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Tue, 7 Jun 2022 15:59:03 +0000 Subject: [PATCH 3/7] Fix GBTD trainer --- 
python/ray/air/train/data_parallel_trainer.py | 5 ++-- python/ray/air/train/gbdt_trainer.py | 29 ++++++++++++------- python/ray/tune/tests/test_tuner.py | 7 ++--- 3 files changed, 23 insertions(+), 18 deletions(-) diff --git a/python/ray/air/train/data_parallel_trainer.py b/python/ray/air/train/data_parallel_trainer.py index 22021189734f..cd5ccb1a944a 100644 --- a/python/ray/air/train/data_parallel_trainer.py +++ b/python/ray/air/train/data_parallel_trainer.py @@ -227,10 +227,9 @@ def __init__(self, train_loop_per_worker, my_backend_config: _scaling_config_allowed_keys = Trainer._scaling_config_allowed_keys + [ "num_workers", - "num_cpus_per_worker", - "num_gpus_per_worker", - "additional_resources_per_worker", + "resources_per_worker", "use_gpu", + "placement_strategy", ] _dataset_config = { diff --git a/python/ray/air/train/gbdt_trainer.py b/python/ray/air/train/gbdt_trainer.py index ef38fb632902..591af4ee92a1 100644 --- a/python/ray/air/train/gbdt_trainer.py +++ b/python/ray/air/train/gbdt_trainer.py @@ -17,16 +17,15 @@ def _convert_scaling_config_to_ray_params( - scaling_config: ScalingConfig, + scaling_config: ScalingConfigDataClass, ray_params_cls: Type["xgboost_ray.RayParams"], default_ray_params: Optional[Dict[str, Any]] = None, ) -> "xgboost_ray.RayParams": default_ray_params = default_ray_params or {} - scaling_config_dataclass = ScalingConfigDataClass(**scaling_config) - resources_per_worker = scaling_config_dataclass.additional_resources_per_worker - num_workers = scaling_config_dataclass.num_workers - cpus_per_worker = scaling_config_dataclass.num_cpus_per_worker - gpus_per_worker = scaling_config_dataclass.num_gpus_per_worker + resources_per_worker = scaling_config.additional_resources_per_worker + num_workers = scaling_config.num_workers + cpus_per_worker = scaling_config.num_cpus_per_worker + gpus_per_worker = scaling_config.num_gpus_per_worker ray_params = ray_params_cls( num_actors=int(num_workers), @@ -67,10 +66,9 @@ class GBDTTrainer(Trainer): _scaling_config_allowed_keys = Trainer._scaling_config_allowed_keys + [ "num_workers", - "num_cpus_per_worker", - "num_gpus_per_worker", - "additional_resources_per_worker", + "resources_per_worker", "use_gpu", + "placement_strategy", ] _dmatrix_cls: type _ray_params_cls: type @@ -142,8 +140,11 @@ def _train(self, **kwargs): @property def _ray_params(self) -> "xgboost_ray.RayParams": + scaling_config_dataclass = self._validate_and_get_scaling_config_data_class( + self.scaling_config + ) return _convert_scaling_config_to_ray_params( - self.scaling_config, self._ray_params_cls, self._default_ray_params + scaling_config_dataclass, self._ray_params_cls, self._default_ray_params ) def preprocess_datasets(self) -> None: @@ -196,6 +197,7 @@ def training_loop(self) -> None: def as_trainable(self) -> Type[Trainable]: trainable_cls = super().as_trainable() + trainer_cls = self.__class__ scaling_config = self.scaling_config ray_params_cls = self._ray_params_cls default_ray_params = self._default_ray_params @@ -213,8 +215,13 @@ def save_checkpoint(self, tmp_checkpoint_dir: str = ""): @classmethod def default_resource_request(cls, config): updated_scaling_config = config.get("scaling_config", scaling_config) + scaling_config_dataclass = ( + trainer_cls._validate_and_get_scaling_config_data_class( + updated_scaling_config + ) + ) return _convert_scaling_config_to_ray_params( - updated_scaling_config, ray_params_cls, default_ray_params + scaling_config_dataclass, ray_params_cls, default_ray_params ).get_tune_resources() return GBDTTrainable diff 
--git a/python/ray/tune/tests/test_tuner.py b/python/ray/tune/tests/test_tuner.py index d24550f09678..4cda2a87c04f 100644 --- a/python/ray/tune/tests/test_tuner.py +++ b/python/ray/tune/tests/test_tuner.py @@ -25,12 +25,11 @@ class DummyTrainer(Trainer): _scaling_config_allowed_keys = [ + "trainer_resources", "num_workers", - "num_cpus_per_worker", - "num_gpus_per_worker", - "additional_resources_per_worker", "use_gpu", - "trainer_resources", + "resources_per_worker", + "placement_strategy", ] def training_loop(self) -> None: From b3331124cd40dbe7e0dec7f08a7a48a348285553 Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Wed, 8 Jun 2022 20:12:17 +0000 Subject: [PATCH 4/7] Fix --- python/ray/air/tests/test_trainer.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/python/ray/air/tests/test_trainer.py b/python/ray/air/tests/test_trainer.py index ea0b8b455535..d21f1357167b 100644 --- a/python/ray/air/tests/test_trainer.py +++ b/python/ray/air/tests/test_trainer.py @@ -29,12 +29,11 @@ def transform(self, ds): class DummyTrainer(Trainer): _scaling_config_allowed_keys = [ + "trainer_resources", "num_workers", - "num_cpus_per_worker", - "num_gpus_per_worker", - "additional_resources_per_worker", "use_gpu", - "trainer_resources", + "resources_per_worker", + "placement_strategy", ] def __init__(self, train_loop, custom_arg=None, **kwargs): From be6437f65624114fdbd027150fe9082940c7c2cf Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Wed, 8 Jun 2022 21:42:06 +0000 Subject: [PATCH 5/7] Fix CI --- python/ray/air/tests/test_api.py | 4 +++- python/ray/air/tests/test_sklearn_trainer.py | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/python/ray/air/tests/test_api.py b/python/ray/air/tests/test_api.py index b91e5ed9b6ba..db46644a020f 100644 --- a/python/ray/air/tests/test_api.py +++ b/python/ray/air/tests/test_api.py @@ -70,7 +70,8 @@ def test_scaling_config_validate_config_prohibited_class(): ScalingConfigDataClass(**scaling_config), ["trainer_resources"], ) - assert "trainer_resources" in str(exc_info.value) + assert "num_workers" in str(exc_info.value) + assert "to be updated" in str(exc_info.value) def test_scaling_config_validate_config_bad_allowed_keys(): @@ -82,6 +83,7 @@ def test_scaling_config_validate_config_bad_allowed_keys(): ["BAD_KEY"], ) assert "BAD_KEY" in str(exc_info.value) + assert "are not present in" in str(exc_info.value) def test_datasets(): diff --git a/python/ray/air/tests/test_sklearn_trainer.py b/python/ray/air/tests/test_sklearn_trainer.py index b8b8804dc9f0..7e73a3a1d285 100644 --- a/python/ray/air/tests/test_sklearn_trainer.py +++ b/python/ray/air/tests/test_sklearn_trainer.py @@ -172,7 +172,7 @@ def test_validation(ray_start_4_cpus): label_column="target", datasets={TRAIN_DATASET_KEY: train_dataset, "cv": valid_dataset}, ) - with pytest.raises(ValueError, match="are not allowed to be set"): + with pytest.raises(ValueError, match="are not allowed to be updated"): SklearnTrainer( estimator=RandomForestClassifier(), scaling_config={"num_workers": 2}, From 1dfebef6a91b50726bcee2d248cf482f1ad746e8 Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Thu, 9 Jun 2022 21:02:13 +0000 Subject: [PATCH 6/7] Add comments --- python/ray/air/_internal/config.py | 4 ++++ python/ray/air/data_batch_type.py | 10 ++++++++++ 2 files changed, 14 insertions(+) create mode 100644 python/ray/air/data_batch_type.py diff --git a/python/ray/air/_internal/config.py b/python/ray/air/_internal/config.py index 63410d3f1b66..ce1df8f77acc 100644 --- 
a/python/ray/air/_internal/config.py +++ b/python/ray/air/_internal/config.py @@ -10,6 +10,9 @@ def ensure_only_allowed_dataclass_keys_updated( Validate dataclass by raising an exception if any key not included in ``allowed_keys`` differs from the default value. + A ``ValueError`` will also be raised if any of the ``allowed_keys`` + is not present in ``dataclass.__dict__``. + Args: dataclass: Dict or dataclass to check. allowed_keys: dataclass attribute keys that can have a value different than @@ -19,6 +22,7 @@ def ensure_only_allowed_dataclass_keys_updated( allowed_keys = set(allowed_keys) + # TODO: split keys_not_in_dict validation to a separate function. keys_not_in_dict = [key for key in allowed_keys if key not in default_data.__dict__] if keys_not_in_dict: raise ValueError( diff --git a/python/ray/air/data_batch_type.py b/python/ray/air/data_batch_type.py new file mode 100644 index 000000000000..1af1fa60cf30 --- /dev/null +++ b/python/ray/air/data_batch_type.py @@ -0,0 +1,10 @@ +from typing import Dict, Union, TYPE_CHECKING + +if TYPE_CHECKING: + import numpy + import pandas + import pyarrow + +DataBatchType = Union[ + "numpy.ndarray", "pandas.DataFrame", "pyarrow.Table", Dict[str, "numpy.ndarray"] +] \ No newline at end of file From 1c5e8a7e5dd10c49e2ba482cfe2ed0a9da4b0330 Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Thu, 9 Jun 2022 23:09:05 +0200 Subject: [PATCH 7/7] Delete data_batch_type.py --- python/ray/air/data_batch_type.py | 10 ---------- 1 file changed, 10 deletions(-) delete mode 100644 python/ray/air/data_batch_type.py diff --git a/python/ray/air/data_batch_type.py b/python/ray/air/data_batch_type.py deleted file mode 100644 index 1af1fa60cf30..000000000000 --- a/python/ray/air/data_batch_type.py +++ /dev/null @@ -1,10 +0,0 @@ -from typing import Dict, Union, TYPE_CHECKING - -if TYPE_CHECKING: - import numpy - import pandas - import pyarrow - -DataBatchType = Union[ - "numpy.ndarray", "pandas.DataFrame", "pyarrow.Table", Dict[str, "numpy.ndarray"] -] \ No newline at end of file
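
For context, the validation flow these patches converge on can be sketched as follows. This is an illustrative usage example, not part of any patch above; it only uses names that appear in the diffs (ScalingConfigDataClass, ensure_only_allowed_dataclass_keys_updated, and the allowed-key lists on the trainers), so treat it as a sketch of the intended behavior rather than a definitive snippet from the tree.

# Illustrative sketch of the consolidated validation path (patches 1-3).
from ray.air.config import ScalingConfigDataClass
from ray.air._internal.config import ensure_only_allowed_dataclass_keys_updated

# A dict-style scaling config is first turned into the dataclass; this is
# what Trainer._validate_and_get_scaling_config_data_class now does before
# validating, instead of checking raw dict keys.
scaling_config = ScalingConfigDataClass(**{"num_workers": 2})

# The single remaining validator then checks which fields differ from their
# defaults. This call passes, since only "num_workers" was updated and it is
# listed in allowed_keys:
ensure_only_allowed_dataclass_keys_updated(
    dataclass=scaling_config,
    allowed_keys=["trainer_resources", "num_workers"],
)

# Two failure modes, both ValueError:
#   * updating a field that is not in allowed_keys
#     -> "... are not allowed to be updated ..." (see test_sklearn_trainer.py)
#   * listing an allowed key the dataclass does not define, e.g. "BAD_KEY"
#     -> "... are not present in ..." (the new check added in patch 2)

Consolidating on this single dataclass-based check is what lets ensure_only_allowed_dict_keys_set be deleted in patch 2: dict scaling configs now go through exactly the same validation as dataclass ones, and mistakes in a trainer's _scaling_config_allowed_keys list surface immediately via the new "are not present in" error.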