[AIR] Refactor ScalingConfig key validation #25549

Merged · 8 commits · Jun 13, 2022
35 changes: 13 additions & 22 deletions python/ray/air/_internal/config.py
@@ -2,28 +2,6 @@
from typing import Iterable


def ensure_only_allowed_dict_keys_set(
data: dict,
allowed_keys: Iterable[str],
):
"""
Validate dict by raising an exception if any key not included in
``allowed_keys`` is set.

Args:
data: Dict to check.
allowed_keys: Iterable of keys that can be contained in dict keys.
"""
allowed_keys_set = set(allowed_keys)
bad_keys = [key for key in data.keys() if key not in allowed_keys_set]

if bad_keys:
raise ValueError(
f"Key(s) {bad_keys} are not allowed to be set in the current context. "
"Remove them from the dict."
)


def ensure_only_allowed_dataclass_keys_updated(
dataclass: dataclasses.dataclass,
allowed_keys: Iterable[str],
@@ -32,6 +10,9 @@ def ensure_only_allowed_dataclass_keys_updated(
Validate dataclass by raising an exception if any key not included in
``allowed_keys`` differs from the default value.

A ``ValueError`` will also be raised if any of the ``allowed_keys``
is not present in ``dataclass.__dict__``.

Args:
dataclass: Dict or dataclass to check.
allowed_keys: dataclass attribute keys that can have a value different than
@@ -41,6 +22,16 @@

allowed_keys = set(allowed_keys)

# TODO: split keys_not_in_dict validation to a separate function.
keys_not_in_dict = [key for key in allowed_keys if key not in default_data.__dict__]
if keys_not_in_dict:
raise ValueError(
f"Key(s) {keys_not_in_dict} are not present in "
f"{dataclass.__class__.__name__}. "
"Remove them from `allowed_keys`. "
f"Valid keys: {list(default_data.__dict__.keys())}"
)
Comment on lines +26 to +33

Contributor:

Somehow this feels weird to me, as this is meant to be for the developer rather than the end user. I.e., if we check in a Trainer that fails here, this isn't really actionable for the user.

Is this more suitable to be added in tests instead?

Yard1 (Member), Jun 9, 2022:

Well, this will fail in the tests. There is no chance a user will see this error, unless they develop their own trainer (in which case that's actually good IMO; it makes it harder to shoot yourself in the foot).

Contributor:

Yeah, sounds reasonable. Can we update the method name/docs?

Longer term, maybe we should split this logic out so that we have one validation step for the Trainer definition (developer), and one for the instantiation (end user).

Yard1 (Member):

@matthewdeng done, ptal

# These keys should not have been updated in the `dataclass` object
prohibited_keys = set(default_data.__dict__) - allowed_keys

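The longer-term split proposed in the review thread above (one validation step for the Trainer definition, one for the instantiation), which the new `TODO` also hints at, could look roughly like this. A hypothetical sketch, not part of this PR; both function names and signatures are illustrative, and it assumes the config dataclass can be constructed with no arguments:

```python
import dataclasses
from typing import Iterable, Type


def validate_trainer_definition(config_cls: Type, allowed_keys: Iterable[str]) -> None:
    """Developer-facing check: every allowed key must be a field of the dataclass."""
    field_names = {f.name for f in dataclasses.fields(config_cls)}
    missing = [key for key in allowed_keys if key not in field_names]
    if missing:
        raise ValueError(
            f"Key(s) {missing} are not present in {config_cls.__name__}."
        )


def validate_instantiation(config, allowed_keys: Iterable[str]) -> None:
    """User-facing check: only allowed keys may differ from their defaults."""
    # Assumes a no-argument constructor yields the default configuration.
    default = type(config)()
    allowed = set(allowed_keys)
    changed = [
        f.name
        for f in dataclasses.fields(config)
        if f.name not in allowed
        and getattr(config, f.name) != getattr(default, f.name)
    ]
    if changed:
        raise ValueError(f"Key(s) {changed} are not allowed to be updated.")
```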
28 changes: 12 additions & 16 deletions python/ray/air/tests/test_api.py
@@ -5,10 +5,7 @@
from ray.air.config import ScalingConfigDataClass
from ray.train import BaseTrainer
from ray.air.preprocessor import Preprocessor
from ray.air._internal.config import (
ensure_only_allowed_dataclass_keys_updated,
ensure_only_allowed_dict_keys_set,
)
from ray.air._internal.config import ensure_only_allowed_dataclass_keys_updated


class DummyTrainer(BaseTrainer):
@@ -65,29 +62,28 @@ def test_scaling_config_validate_config_valid_class():
)


def test_scaling_config_validate_config_valid_dict():
scaling_config = {"num_workers": 2}
ensure_only_allowed_dict_keys_set(scaling_config, ["num_workers"])


def test_scaling_config_validate_config_prohibited_class():
# Check for prohibited keys
scaling_config = {"num_workers": 2}
with pytest.raises(ValueError):
with pytest.raises(ValueError) as exc_info:
ensure_only_allowed_dataclass_keys_updated(
ScalingConfigDataClass(**scaling_config),
["trainer_resources"],
)
assert "num_workers" in str(exc_info.value)
assert "to be updated" in str(exc_info.value)


def test_scaling_config_validate_config_prohibited_dict():
# Check for prohibited keys
def test_scaling_config_validate_config_bad_allowed_keys():
# Check for keys not present in dict
scaling_config = {"num_workers": 2}
with pytest.raises(ValueError):
ensure_only_allowed_dict_keys_set(
scaling_config,
["trainer_resources"],
with pytest.raises(ValueError) as exc_info:
ensure_only_allowed_dataclass_keys_updated(
ScalingConfigDataClass(**scaling_config),
["BAD_KEY"],
)
assert "BAD_KEY" in str(exc_info.value)
assert "are not present in" in str(exc_info.value)


def test_datasets():
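Outside the test suite, the same two failure modes can be reproduced in a few lines. A minimal sketch, assuming the private helper keeps its current module path and that it compares against a default-constructed instance of the same dataclass (the toy dataclass here is illustrative):

```python
from dataclasses import dataclass, field

from ray.air._internal.config import ensure_only_allowed_dataclass_keys_updated


@dataclass
class ToyScalingConfig:
    num_workers: int = 1
    trainer_resources: dict = field(default_factory=dict)


# Developer error: `allowed_keys` names an attribute the dataclass lacks.
try:
    ensure_only_allowed_dataclass_keys_updated(
        ToyScalingConfig(num_workers=2), allowed_keys=["BAD_KEY"]
    )
except ValueError as exc:
    print(exc)  # Key(s) ['BAD_KEY'] are not present in ToyScalingConfig. ...

# End-user error: a real attribute was changed but is not in `allowed_keys`.
try:
    ensure_only_allowed_dataclass_keys_updated(
        ToyScalingConfig(num_workers=2), allowed_keys=["trainer_resources"]
    )
except ValueError as exc:
    print(exc)  # Key(s) ['num_workers'] are not allowed to be updated ...
```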
12 changes: 2 additions & 10 deletions python/ray/train/base_trainer.py
@@ -13,10 +13,7 @@
ScalingConfigDataClass,
)
from ray.air.result import Result
from ray.air._internal.config import (
ensure_only_allowed_dataclass_keys_updated,
ensure_only_allowed_dict_keys_set,
)
from ray.air._internal.config import ensure_only_allowed_dataclass_keys_updated
from ray.tune import Trainable
from ray.tune.error import TuneError
from ray.tune.function_runner import wrap_function
@@ -225,12 +222,7 @@ def _validate_and_get_scaling_config_data_class(
) -> ScalingConfigDataClass:
"""Return scaling config dataclass after validating updated keys."""
if isinstance(dataclass_or_dict, dict):
ensure_only_allowed_dict_keys_set(
dataclass_or_dict, cls._scaling_config_allowed_keys
)
scaling_config_dataclass = ScalingConfigDataClass(**dataclass_or_dict)

return scaling_config_dataclass
dataclass_or_dict = ScalingConfigDataClass(**dataclass_or_dict)

ensure_only_allowed_dataclass_keys_updated(
dataclass=dataclass_or_dict,
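Assembled from the hunks above, the simplified method now reads approximately as follows. The parameter list and the trailing `return` are collapsed in the diff, so those details are a best-effort reconstruction; the class body is reduced to a stub for context:

```python
from ray.air.config import ScalingConfigDataClass
from ray.air._internal.config import ensure_only_allowed_dataclass_keys_updated


class BaseTrainer:  # context stub; the real class defines much more
    _scaling_config_allowed_keys: list = []

    @classmethod
    def _validate_and_get_scaling_config_data_class(
        cls, dataclass_or_dict
    ) -> ScalingConfigDataClass:
        """Return scaling config dataclass after validating updated keys."""
        # Normalize dicts up front so one validation path serves both forms.
        if isinstance(dataclass_or_dict, dict):
            dataclass_or_dict = ScalingConfigDataClass(**dataclass_or_dict)

        ensure_only_allowed_dataclass_keys_updated(
            dataclass=dataclass_or_dict,
            allowed_keys=cls._scaling_config_allowed_keys,
        )
        return dataclass_or_dict
```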
4 changes: 1 addition & 3 deletions python/ray/train/data_parallel_trainer.py
@@ -228,11 +228,9 @@ def __init__(self, train_loop_per_worker, my_backend_config:

_scaling_config_allowed_keys = BaseTrainer._scaling_config_allowed_keys + [
"num_workers",
"num_cpus_per_worker",
"num_gpus_per_worker",
"resources_per_worker",
"additional_resources_per_worker",
"use_gpu",
"placement_strategy",
]

_dataset_config = {
28 changes: 17 additions & 11 deletions python/ray/train/gbdt_trainer.py
@@ -17,16 +17,15 @@


def _convert_scaling_config_to_ray_params(
scaling_config: ScalingConfig,
scaling_config: ScalingConfigDataClass,
ray_params_cls: Type["xgboost_ray.RayParams"],
default_ray_params: Optional[Dict[str, Any]] = None,
) -> "xgboost_ray.RayParams":
default_ray_params = default_ray_params or {}
scaling_config_dataclass = ScalingConfigDataClass(**scaling_config)
resources_per_worker = scaling_config_dataclass.additional_resources_per_worker
num_workers = scaling_config_dataclass.num_workers
cpus_per_worker = scaling_config_dataclass.num_cpus_per_worker
gpus_per_worker = scaling_config_dataclass.num_gpus_per_worker
resources_per_worker = scaling_config.additional_resources_per_worker
num_workers = scaling_config.num_workers
cpus_per_worker = scaling_config.num_cpus_per_worker
gpus_per_worker = scaling_config.num_gpus_per_worker

ray_params = ray_params_cls(
num_actors=int(num_workers),
@@ -67,11 +66,9 @@ class GBDTTrainer(BaseTrainer):

_scaling_config_allowed_keys = BaseTrainer._scaling_config_allowed_keys + [
"num_workers",
"num_cpus_per_worker",
"num_gpus_per_worker",
"resources_per_worker",
"additional_resources_per_worker",
"use_gpu",
"placement_strategy",
]
_dmatrix_cls: type
_ray_params_cls: type
@@ -143,8 +140,11 @@ def _train(self, **kwargs):

@property
def _ray_params(self) -> "xgboost_ray.RayParams":
scaling_config_dataclass = self._validate_and_get_scaling_config_data_class(
self.scaling_config
)
return _convert_scaling_config_to_ray_params(
self.scaling_config, self._ray_params_cls, self._default_ray_params
scaling_config_dataclass, self._ray_params_cls, self._default_ray_params
)

def preprocess_datasets(self) -> None:
@@ -197,6 +197,7 @@ def training_loop(self) -> None:

def as_trainable(self) -> Type[Trainable]:
trainable_cls = super().as_trainable()
trainer_cls = self.__class__
scaling_config = self.scaling_config
ray_params_cls = self._ray_params_cls
default_ray_params = self._default_ray_params
@@ -214,8 +215,13 @@ def save_checkpoint(self, tmp_checkpoint_dir: str = ""):
@classmethod
def default_resource_request(cls, config):
updated_scaling_config = config.get("scaling_config", scaling_config)
scaling_config_dataclass = (
trainer_cls._validate_and_get_scaling_config_data_class(
updated_scaling_config
)
)
return _convert_scaling_config_to_ray_params(
updated_scaling_config, ray_params_cls, default_ray_params
scaling_config_dataclass, ray_params_cls, default_ray_params
).get_tune_resources()

return GBDTTrainable
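A call-site sketch of the new contract: callers validate first and pass the resulting dataclass, instead of handing `_convert_scaling_config_to_ray_params` a raw dict. This assumes xgboost-ray is installed; the helper is private to `ray.train.gbdt_trainer` and may change between releases:

```python
import xgboost_ray

from ray.air.config import ScalingConfigDataClass
from ray.train.gbdt_trainer import _convert_scaling_config_to_ray_params

# Validation would normally run via _validate_and_get_scaling_config_data_class;
# here we construct the dataclass directly for brevity.
scaling_config_dataclass = ScalingConfigDataClass(num_workers=2)
ray_params = _convert_scaling_config_to_ray_params(
    scaling_config_dataclass, xgboost_ray.RayParams
)
print(ray_params.num_actors)  # 2
```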
7 changes: 3 additions & 4 deletions python/ray/train/tests/test_base_trainer.py
@@ -29,12 +29,11 @@ def transform(self, ds):

class DummyTrainer(BaseTrainer):
_scaling_config_allowed_keys = [
"trainer_resources",
"num_workers",
"num_cpus_per_worker",
"num_gpus_per_worker",
"additional_resources_per_worker",
"use_gpu",
"trainer_resources",
"resources_per_worker",
"placement_strategy",
]

def __init__(self, train_loop, custom_arg=None, **kwargs):
2 changes: 1 addition & 1 deletion python/ray/train/tests/test_sklearn_trainer.py
@@ -172,7 +172,7 @@ def test_validation(ray_start_4_cpus):
label_column="target",
datasets={TRAIN_DATASET_KEY: train_dataset, "cv": valid_dataset},
)
with pytest.raises(ValueError, match="are not allowed to be set"):
with pytest.raises(ValueError, match="are not allowed to be updated"):
SklearnTrainer(
estimator=RandomForestClassifier(),
scaling_config={"num_workers": 2},
7 changes: 3 additions & 4 deletions python/ray/tune/tests/test_tuner.py
@@ -25,12 +25,11 @@

class DummyTrainer(BaseTrainer):
_scaling_config_allowed_keys = [
"trainer_resources",
"num_workers",
"num_cpus_per_worker",
"num_gpus_per_worker",
"additional_resources_per_worker",
"use_gpu",
"trainer_resources",
"resources_per_worker",
"placement_strategy",
]

def training_loop(self) -> None: