diff --git a/python/ray/tune/__init__.py b/python/ray/tune/__init__.py
index f9613dce787d..14997b6489d7 100644
--- a/python/ray/tune/__init__.py
+++ b/python/ray/tune/__init__.py
@@ -39,7 +39,7 @@
 from ray.tune.search import create_searcher
 from ray.tune.schedulers import create_scheduler
 from ray.tune.execution.placement_groups import PlacementGroupFactory
-from ray.tune.trainable.util import with_parameters
+from ray.tune.trainable.util import with_parameters, with_resources
 from ray._private.usage import usage_lib
@@ -57,6 +57,7 @@
     "run",
     "run_experiments",
     "with_parameters",
+    "with_resources",
     "Stopper",
     "Experiment",
     "function",
diff --git a/python/ray/tune/tests/test_api.py b/python/ray/tune/tests/test_api.py
index 2f7dd4c7dd9e..7c7001ade46c 100644
--- a/python/ray/tune/tests/test_api.py
+++ b/python/ray/tune/tests/test_api.py
@@ -12,6 +12,7 @@
 import gym
 import numpy as np
+import pytest
 import ray
 from ray import tune
 from ray.air._internal.remote_storage import _ensure_directory
@@ -1409,6 +1410,88 @@ def step(self):
         assert sys.getsizeof(dumped) < 100 * 1024


+
+@pytest.fixture
+def ray_start_2_cpus_2_gpus():
+    address_info = ray.init(num_cpus=2, num_gpus=2)
+    yield address_info
+    # The code after the yield will run as teardown code.
+    ray.shutdown()
+
+
+@pytest.mark.parametrize("num_gpus", [1, 2])
+def test_with_resources_dict(ray_start_2_cpus_2_gpus, num_gpus):
+    def train_fn(config):
+        return len(ray.get_gpu_ids())
+
+    [trial] = tune.run(
+        tune.with_resources(train_fn, resources={"gpu": num_gpus})
+    ).trials
+
+    assert trial.last_result["_metric"] == num_gpus
+
+
+@pytest.mark.parametrize("num_gpus", [1, 2])
+def test_with_resources_pgf(ray_start_2_cpus_2_gpus, num_gpus):
+    def train_fn(config):
+        return len(ray.get_gpu_ids())
+
+    [trial] = tune.run(
+        tune.with_resources(
+            train_fn, resources=PlacementGroupFactory([{"GPU": num_gpus}])
+        )
+    ).trials
+
+    assert trial.last_result["_metric"] == num_gpus
+
+
+@pytest.mark.parametrize("num_gpus", [1, 2])
+def test_with_resources_fn(ray_start_2_cpus_2_gpus, num_gpus):
+    def train_fn(config):
+        return len(ray.get_gpu_ids())
+
+    [trial] = tune.run(
+        tune.with_resources(
+            train_fn,
+            resources=lambda config: PlacementGroupFactory(
+                [{"GPU": config["use_gpus"]}]
+            ),
+        ),
+        config={"use_gpus": num_gpus},
+    ).trials
+
+    assert trial.last_result["_metric"] == num_gpus
+
+
+@pytest.mark.parametrize("num_gpus", [1, 2])
+def test_with_resources_class_fn(ray_start_2_cpus_2_gpus, num_gpus):
+    class MyTrainable(tune.Trainable):
+        def step(self):
+            return {"_metric": len(ray.get_gpu_ids()), "done": True}
+
+        def save_checkpoint(self, checkpoint_dir: str):
+            pass
+
+        def load_checkpoint(self, checkpoint):
+            pass
+
+        @classmethod
+        def default_resource_request(cls, config):
+            # This will be overridden by tune.with_resources()
+            return PlacementGroupFactory([{"CPU": 2, "GPU": 0}])
+
+    [trial] = tune.run(
+        tune.with_resources(
+            MyTrainable,
+            resources=lambda config: PlacementGroupFactory(
+                [{"GPU": config["use_gpus"]}]
+            ),
+        ),
+        config={"use_gpus": num_gpus},
+    ).trials
+
+    assert trial.last_result["_metric"] == num_gpus
+
+
 class SerializabilityTest(unittest.TestCase):
     @classmethod
     def setUpClass(cls):
@@ -1821,6 +1904,4 @@ def __init__(


 if __name__ == "__main__":
-    import pytest
-
     sys.exit(pytest.main(["-v", __file__]))
diff --git a/python/ray/tune/trainable/function_trainable.py b/python/ray/tune/trainable/function_trainable.py
index b52a0b0ecfa4..dc6d1d4c7270 100644
--- a/python/ray/tune/trainable/function_trainable.py
+++ b/python/ray/tune/trainable/function_trainable.py
@@ -9,12 +9,14 @@
 import warnings
 from functools import partial
 from numbers import Number
-from typing import Any, Callable, Dict, Optional, Type
+from typing import Any, Callable, Dict, Optional, Type, Union

+from ray.tune.resources import Resources
 from six.moves import queue

 from ray.air.checkpoint import Checkpoint
 from ray.tune import TuneError
+from ray.tune.execution.placement_groups import PlacementGroupFactory
 from ray.tune.trainable import session
 from ray.tune.result import (
     DEFAULT_METRIC,
@@ -641,6 +643,8 @@
             DeprecationWarning,
         )

+    resources = getattr(train_func, "_resources", None)
+
     class ImplicitFunc(*inherit_from):
         _name = name or (
             train_func.__name__ if hasattr(train_func, "__name__") else "func"
         )
@@ -685,4 +689,12 @@ def handle_output(output):
             reporter(**{RESULT_DUPLICATE: True})
             return output

+        @classmethod
+        def default_resource_request(
+            cls, config: Dict[str, Any]
+        ) -> Optional[Union[Resources, PlacementGroupFactory]]:
+            if not isinstance(resources, PlacementGroupFactory) and callable(resources):
+                return resources(config)
+            return resources
+
     return ImplicitFunc
diff --git a/python/ray/tune/trainable/util.py b/python/ray/tune/trainable/util.py
index 0cf24774ec2d..c78bfb10820a 100644
--- a/python/ray/tune/trainable/util.py
+++ b/python/ray/tune/trainable/util.py
@@ -3,16 +3,24 @@
 import logging
 import os
 import shutil
-from typing import Any, Dict, Optional, Union
+from typing import Any, Callable, Dict, Optional, Type, Union, TYPE_CHECKING

 import pandas as pd

 import ray
 import ray.cloudpickle as pickle
+from ray.tune.execution.placement_groups import (
+    PlacementGroupFactory,
+    resource_dict_to_pg_factory,
+)
 from ray.tune.registry import _ParameterRegistry
+from ray.tune.resources import Resources
 from ray.tune.utils import detect_checkpoint_function
 from ray.util import placement_group
-from ray.util.annotations import DeveloperAPI
+from ray.util.annotations import DeveloperAPI, PublicAPI
+
+if TYPE_CHECKING:
+    from ray.tune.trainable import Trainable

 logger = logging.getLogger(__name__)
@@ -226,7 +234,8 @@ def get_remote_worker_options(
     return options, pg


-def with_parameters(trainable, **kwargs):
+@PublicAPI(stability="beta")
+def with_parameters(trainable: Union[Type["Trainable"], Callable], **kwargs):
     """Wrapper for trainables to pass arbitrary large data objects.

     This wrapper function will store all passed parameters in the Ray
@@ -367,3 +376,86 @@ def _inner(config):
         inner.__mixins__ = trainable.__mixins__

     return inner
+
+
+@PublicAPI(stability="beta")
+def with_resources(
+    trainable: Union[Type["Trainable"], Callable],
+    resources: Union[
+        Dict[str, float],
+        PlacementGroupFactory,
+        Callable[[dict], PlacementGroupFactory],
+    ],
+):
+    """Wrapper for trainables to specify resource requests.
+
+    This wrapper allows specification of resource requirements for a specific
+    trainable. It will override potential existing resource requests (use
+    with caution!).
+
+    The main use case is to request resources for function trainables when used
+    with the Tuner() API.
+
+    Class trainables should usually just implement the
+    ``default_resource_request()`` method.
+
+    Args:
+        trainable: Trainable to wrap.
+        resources: Resource dict, placement group factory, or callable that takes
+            in a config dict and returns a placement group factory.
+
+    Example:
+
+    .. code-block:: python
+
+        import ray
+        from ray import tune
+        from ray.tune.tuner import Tuner
+
+        def train(config):
+            return len(ray.get_gpu_ids())  # Returns 2
+
+        tuner = Tuner(
+            tune.with_resources(train, resources={"gpu": 2}),
+            # ...
+        )
+        results = tuner.fit()
+
+    """
+    from ray.tune.trainable import Trainable
+
+    if not callable(trainable) or (
+        inspect.isclass(trainable) and not issubclass(trainable, Trainable)
+    ):
+        raise ValueError(
+            f"`tune.with_resources()` only works with function trainables "
+            f"or classes that inherit from `tune.Trainable()`. Got type: "
+            f"{type(trainable)}."
+        )
+
+    if isinstance(resources, PlacementGroupFactory):
+        pgf = resources
+    elif isinstance(resources, dict):
+        pgf = resource_dict_to_pg_factory(resources)
+    elif callable(resources):
+        pgf = resources
+    else:
+        raise ValueError(
+            f"Invalid resource type for `with_resources()`: {type(resources)}"
+        )
+
+    if not inspect.isclass(trainable):
+        # Just set an attribute. This will be resolved later in `wrap_function()`.
+        trainable._resources = pgf
+    else:
+
+        class ResourceTrainable(trainable):
+            @classmethod
+            def default_resource_request(
+                cls, config: Dict[str, Any]
+            ) -> Optional[Union[Resources, PlacementGroupFactory]]:
+                if not isinstance(pgf, PlacementGroupFactory) and callable(pgf):
+                    return pgf(config)
+                return pgf
+
+        ResourceTrainable.__name__ = trainable.__name__
+        trainable = ResourceTrainable
+
+    return trainable
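
Taken together, the new API can be exercised end to end roughly as follows. This is a minimal sketch mirroring the tests added above, assuming a Ray build that includes this change; `train_fn` and the `use_gpus` config key are illustrative names only, and `ray.init(num_gpus=2)` merely registers logical GPUs, so no physical GPU is needed to try it:

    # Sketch: the three resource formats accepted by tune.with_resources(),
    # mirroring the tests above. Assumes a Ray build that includes this change.
    import ray
    from ray import tune
    from ray.tune.execution.placement_groups import PlacementGroupFactory

    # Register two logical GPUs for resource bookkeeping; no physical GPU required.
    ray.init(num_cpus=2, num_gpus=2)


    def train_fn(config):
        # The return value is recorded as `_metric` in the trial result.
        return len(ray.get_gpu_ids())


    # 1. Plain resource dict, converted internally via resource_dict_to_pg_factory().
    tune.run(tune.with_resources(train_fn, resources={"gpu": 1}))

    # 2. Explicit placement group factory.
    tune.run(
        tune.with_resources(train_fn, resources=PlacementGroupFactory([{"GPU": 1}]))
    )

    # 3. Callable that derives the request from the trial config
    #    ("use_gpus" is an illustrative key).
    tune.run(
        tune.with_resources(
            train_fn,
            resources=lambda config: PlacementGroupFactory(
                [{"GPU": config["use_gpus"]}]
            ),
        ),
        config={"use_gpus": 2},
    )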