Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[AIR] Add Scaling Config validation #23889

Merged
merged 20 commits into from
Apr 19, 2022
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions python/ray/ml/config.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,22 @@
from dataclasses import dataclass
from typing import Dict, Any, Optional, List, Mapping, Callable, Union, TYPE_CHECKING
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
List,
Mapping,
Optional,
Union,
krfricke marked this conversation as resolved.
Show resolved Hide resolved
)

from ray.tune.syncer import SyncConfig
from ray.util import PublicAPI

if TYPE_CHECKING:
from ray.tune.trainable import PlacementGroupFactory
from ray.tune.callback import Callback
from ray.tune.stopper import Stopper
from ray.tune.trainable import PlacementGroupFactory

ScalingConfig = Dict[str, Any]

Expand Down
37 changes: 37 additions & 0 deletions python/ray/ml/tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,13 @@

import ray
from ray.ml import Checkpoint
from ray.ml.config import ScalingConfigDataClass
from ray.ml.trainer import Trainer
from ray.ml.preprocessor import Preprocessor
from ray.ml.utils.config import (
ensure_only_allowed_dataclass_keys_updated,
ensure_only_allowed_dict_keys_set,
)


class DummyTrainer(Trainer):
Expand Down Expand Up @@ -53,6 +58,38 @@ def test_scaling_config():
DummyTrainer(scaling_config=None)


def test_scaling_config_validate_config_valid_class():
scaling_config = {"num_workers": 2}
ensure_only_allowed_dataclass_keys_updated(
ScalingConfigDataClass(**scaling_config), ["num_workers"]
)


def test_scaling_config_validate_config_valid_dict():
scaling_config = {"num_workers": 2}
ensure_only_allowed_dict_keys_set(scaling_config, ["num_workers"])


def test_scaling_config_validate_config_prohibited_class():
# Check for prohibited keys
scaling_config = {"num_workers": 2}
with pytest.raises(ValueError):
ensure_only_allowed_dataclass_keys_updated(
ScalingConfigDataClass(**scaling_config),
["trainer_resources"],
)


def test_scaling_config_validate_config_prohibited_dict():
# Check for prohibited keys
scaling_config = {"num_workers": 2}
with pytest.raises(ValueError):
ensure_only_allowed_dict_keys_set(
scaling_config,
["trainer_resources"],
)


def test_datasets():
with pytest.raises(ValueError):
DummyTrainer(datasets="invalid")
Expand Down
33 changes: 30 additions & 3 deletions python/ray/ml/trainer.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import abc
import inspect
import logging
from typing import Dict, Union, Callable, Optional, TYPE_CHECKING, Type
from typing import Dict, Union, Callable, Optional, TYPE_CHECKING, Type, Any

import ray

Expand All @@ -10,6 +10,10 @@
from ray.ml.result import Result
from ray.ml.config import RunConfig, ScalingConfig, ScalingConfigDataClass
from ray.ml.constants import TRAIN_DATASET_KEY
from ray.ml.utils.config import (
ensure_only_allowed_dataclass_keys_updated,
ensure_only_allowed_dict_keys_set,
)
from ray.tune import Trainable
from ray.tune.error import TuneError
from ray.tune.function_runner import wrap_function
Expand Down Expand Up @@ -133,6 +137,8 @@ def training_loop(self):
resume_from_checkpoint: A checkpoint to resume training from.
"""

_scaling_config_allowed_keys = []
Yard1 marked this conversation as resolved.
Show resolved Hide resolved

def __init__(
self,
*,
Expand Down Expand Up @@ -210,6 +216,25 @@ def _validate_attributes(self):
f"with value `{self.resume_from_checkpoint}`."
)

@classmethod
def _validate_and_get_scaling_config_data_class(
cls, dataclass_or_dict: Union[ScalingConfigDataClass, Dict[str, Any]]
) -> ScalingConfigDataClass:
"""Return scaling config dataclass after validating updated keys."""
if isinstance(dataclass_or_dict, dict):
ensure_only_allowed_dict_keys_set(
dataclass_or_dict, cls._scaling_config_allowed_keys
)
scaling_config_dataclass = ScalingConfigDataClass(**dataclass_or_dict)

return scaling_config_dataclass

ensure_only_allowed_dataclass_keys_updated(
dataclass=dataclass_or_dict,
allowed_keys=cls._scaling_config_allowed_keys,
)
return dataclass_or_dict

def setup(self) -> None:
"""Called during fit() to perform initial setup on the Trainer.

Expand Down Expand Up @@ -359,8 +384,10 @@ def _trainable_func(self, config, reporter, checkpoint_dir):
@classmethod
def default_resource_request(cls, config):
updated_scaling_config = config.get("scaling_config", scaling_config)
scaling_config_dataclass = ScalingConfigDataClass(
**updated_scaling_config
scaling_config_dataclass = (
trainer_cls._validate_and_get_scaling_config_data_class(
updated_scaling_config
)
)
return scaling_config_dataclass.as_placement_group_factory()

Expand Down
56 changes: 56 additions & 0 deletions python/ray/ml/utils/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
import dataclasses
from typing import Iterable


def ensure_only_allowed_dict_keys_set(
data: dict,
allowed_keys: Iterable[str],
):
"""
Validate dict by raising an exception if any key not included in
``allowed_keys`` is set.

Args:
data: Dict to check.
allowed_keys: Iterable of keys that can be contained in dict keys.
"""
allowed_keys_set = set(allowed_keys)
bad_keys = [key for key in data.keys() if key not in allowed_keys_set]

if bad_keys:
raise ValueError(
f"Key(s) {bad_keys} are not allowed to be set in the current context. "
"Remove them from the dict."
)


def ensure_only_allowed_dataclass_keys_updated(
dataclass: dataclasses.dataclass,
allowed_keys: Iterable[str],
):
"""
Validate dataclass by raising an exception if any key not included in
``allowed_keys`` differs from the default value.

Args:
dataclass: Dict or dataclass to check.
allowed_keys: dataclass attribute keys that can have a value different than
the default one.
"""
default_data = dataclass.__class__()

allowed_keys = set(allowed_keys)

# These keys should not have been updated in the `dataclass` object
prohibited_keys = set(default_data.__dict__) - allowed_keys

bad_keys = [
key
for key in prohibited_keys
if dataclass.__dict__[key] != default_data.__dict__[key]
]
if bad_keys:
raise ValueError(
f"Key(s) {bad_keys} are not allowed to be updated in the current context. "
"Remove them from the dataclass."
)