[tune] Introduce tune.with_resources() to specify function trainable resources (ray-project#26830)

Previously, there was no way to specify resource requirements with the Tuner() API. This PR introduces tune.with_resources() to attach a resource request to class and function trainables. For class trainables, it overrides any existing default resource request.

Signed-off-by: Kai Fricke <[email protected]>
Signed-off-by: Rohan138 <[email protected]>
krfricke authored and Rohan138 committed Jul 28, 2022
1 parent 51e8429 commit 54239c5
Showing 4 changed files with 193 additions and 7 deletions.
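Taken together, the change can be exercised as in the sketch below. This is a minimal example based on the tests in this commit; the Tuner/param_space wiring and the 2-GPU cluster are assumptions, and only tune.with_resources() and its three accepted resource forms come from this diff.

# Minimal sketch, assuming a Ray cluster with 2 GPUs.
import ray
from ray import tune
from ray.tune.tuner import Tuner
from ray.tune.execution.placement_groups import PlacementGroupFactory


def train_fn(config):
    # Report how many GPUs the trial was actually granted.
    return len(ray.get_gpu_ids())


# 1. Plain resource dict.
trainable = tune.with_resources(train_fn, resources={"gpu": 2})

# 2. Explicit placement group factory.
trainable = tune.with_resources(
    train_fn, resources=PlacementGroupFactory([{"GPU": 2}])
)

# 3. Callable mapping the trial config to a placement group factory.
trainable = tune.with_resources(
    train_fn,
    resources=lambda config: PlacementGroupFactory([{"GPU": config["use_gpus"]}]),
)

tuner = Tuner(trainable, param_space={"use_gpus": 2})
results = tuner.fit()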
3 changes: 2 additions & 1 deletion python/ray/tune/__init__.py
@@ -39,7 +39,7 @@
from ray.tune.search import create_searcher
from ray.tune.schedulers import create_scheduler
from ray.tune.execution.placement_groups import PlacementGroupFactory
from ray.tune.trainable.util import with_parameters
from ray.tune.trainable.util import with_parameters, with_resources

from ray._private.usage import usage_lib

@@ -55,6 +55,7 @@
    "run",
    "run_experiments",
    "with_parameters",
    "with_resources",
    "Stopper",
    "Experiment",
    "sample_from",
85 changes: 83 additions & 2 deletions python/ray/tune/tests/test_api.py
@@ -12,6 +12,7 @@

import gym
import numpy as np
import pytest
import ray
from ray import tune
from ray.air._internal.remote_storage import _ensure_directory
@@ -1409,6 +1410,88 @@ def step(self):
    assert sys.getsizeof(dumped) < 100 * 1024


@pytest.fixture
def ray_start_2_cpus_2_gpus():
    address_info = ray.init(num_cpus=2, num_gpus=2)
    yield address_info
    # The code after the yield will run as teardown code.
    ray.shutdown()


@pytest.mark.parametrize("num_gpus", [1, 2])
def test_with_resources_dict(ray_start_2_cpus_2_gpus, num_gpus):
    def train_fn(config):
        return len(ray.get_gpu_ids())

    [trial] = tune.run(
        tune.with_resources(train_fn, resources={"gpu": num_gpus})
    ).trials

    assert trial.last_result["_metric"] == num_gpus


@pytest.mark.parametrize("num_gpus", [1, 2])
def test_with_resources_pgf(ray_start_2_cpus_2_gpus, num_gpus):
    def train_fn(config):
        return len(ray.get_gpu_ids())

    [trial] = tune.run(
        tune.with_resources(
            train_fn, resources=PlacementGroupFactory([{"GPU": num_gpus}])
        )
    ).trials

    assert trial.last_result["_metric"] == num_gpus


@pytest.mark.parametrize("num_gpus", [1, 2])
def test_with_resources_fn(ray_start_2_cpus_2_gpus, num_gpus):
    def train_fn(config):
        return len(ray.get_gpu_ids())

    [trial] = tune.run(
        tune.with_resources(
            train_fn,
            resources=lambda config: PlacementGroupFactory(
                [{"GPU": config["use_gpus"]}]
            ),
        ),
        config={"use_gpus": num_gpus},
    ).trials

    assert trial.last_result["_metric"] == num_gpus


@pytest.mark.parametrize("num_gpus", [1, 2])
def test_with_resources_class_fn(ray_start_2_cpus_2_gpus, num_gpus):
    class MyTrainable(tune.Trainable):
        def step(self):
            return {"_metric": len(ray.get_gpu_ids()), "done": True}

        def save_checkpoint(self, checkpoint_dir: str):
            pass

        def load_checkpoint(self, checkpoint):
            pass

        @classmethod
        def default_resource_request(cls, config):
            # This default will be overridden by tune.with_resources().
            return PlacementGroupFactory([{"CPU": 2, "GPU": 0}])

    [trial] = tune.run(
        tune.with_resources(
            MyTrainable,
            resources=lambda config: PlacementGroupFactory(
                [{"GPU": config["use_gpus"]}]
            ),
        ),
        config={"use_gpus": num_gpus},
    ).trials

    assert trial.last_result["_metric"] == num_gpus
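As an aside, the assertions above rely on Tune's convention that a bare (non-dict) return value from a function trainable is recorded under the default metric key "_metric". A minimal sketch of that convention, assuming a local Ray instance:

import ray
from ray import tune


def train_fn(config):
    return 42  # Recorded in the trial result as {"_metric": 42, ...}


ray.init(num_cpus=1)
[trial] = tune.run(train_fn).trials
assert trial.last_result["_metric"] == 42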


class SerializabilityTest(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
@@ -1821,6 +1904,4 @@ def __init__(


if __name__ == "__main__":
    import pytest

    sys.exit(pytest.main(["-v", __file__]))
14 changes: 13 additions & 1 deletion python/ray/tune/trainable/function_trainable.py
@@ -9,12 +9,14 @@
import warnings
from functools import partial
from numbers import Number
from typing import Any, Callable, Dict, Optional, Type
from typing import Any, Callable, Dict, Optional, Type, Union

from ray.tune.resources import Resources
from six.moves import queue

from ray.air.checkpoint import Checkpoint
from ray.tune import TuneError
from ray.tune.execution.placement_groups import PlacementGroupFactory
from ray.tune.trainable import session
from ray.tune.result import (
    DEFAULT_METRIC,
@@ -641,6 +643,8 @@ def wrap_function(
            DeprecationWarning,
        )

    resources = getattr(train_func, "_resources", None)

    class ImplicitFunc(*inherit_from):
        _name = name or (
            train_func.__name__ if hasattr(train_func, "__name__") else "func"
@@ -685,4 +689,12 @@ def handle_output(output):
            reporter(**{RESULT_DUPLICATE: True})
            return output

        @classmethod
        def default_resource_request(
            cls, config: Dict[str, Any]
        ) -> Optional[Union[Resources, PlacementGroupFactory]]:
            if not isinstance(resources, PlacementGroupFactory) and callable(resources):
                return resources(config)
            return resources

    return ImplicitFunc
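To make the new dispatch concrete: wrap_function() picks up the _resources attribute that with_resources() attached to the function, and the generated default_resource_request() resolves a callable against the trial config at request time, while a ready-made PlacementGroupFactory passes through unchanged. The isinstance guard is needed because PlacementGroupFactory instances are themselves callable. A standalone restatement of that branch (resolve_resources is an illustrative name, not part of the diff):

from typing import Any, Callable, Dict, Optional, Union

from ray.tune.execution.placement_groups import PlacementGroupFactory


def resolve_resources(
    resources: Optional[
        Union[PlacementGroupFactory, Callable[[dict], PlacementGroupFactory]]
    ],
    config: Dict[str, Any],
) -> Optional[PlacementGroupFactory]:
    # Same branch as default_resource_request() above: invoke a callable
    # with the trial config; pass a PlacementGroupFactory (or None) through.
    if not isinstance(resources, PlacementGroupFactory) and callable(resources):
        return resources(config)
    return resources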
98 changes: 95 additions & 3 deletions python/ray/tune/trainable/util.py
@@ -3,16 +3,24 @@
import logging
import os
import shutil
from typing import Any, Dict, Optional, Union
from typing import Any, Callable, Dict, Optional, Type, Union, TYPE_CHECKING

import pandas as pd

import ray
import ray.cloudpickle as pickle
from ray.tune.execution.placement_groups import (
    PlacementGroupFactory,
    resource_dict_to_pg_factory,
)
from ray.tune.registry import _ParameterRegistry
from ray.tune.resources import Resources
from ray.tune.utils import detect_checkpoint_function
from ray.util import placement_group
from ray.util.annotations import DeveloperAPI
from ray.util.annotations import DeveloperAPI, PublicAPI

if TYPE_CHECKING:
    from ray.tune.trainable import Trainable

logger = logging.getLogger(__name__)

@@ -226,7 +234,8 @@ def get_remote_worker_options(
        return options, pg


def with_parameters(trainable, **kwargs):
@PublicAPI(stability="beta")
def with_parameters(trainable: Union[Type["Trainable"], Callable], **kwargs):
"""Wrapper for trainables to pass arbitrary large data objects.
This wrapper function will store all passed parameters in the Ray
@@ -367,3 +376,86 @@ def _inner(config):
        inner.__mixins__ = trainable.__mixins__

    return inner


@PublicAPI(stability="beta")
def with_resources(
    trainable: Union[Type["Trainable"], Callable],
    resources: Union[
        Dict[str, float], PlacementGroupFactory, Callable[[dict], PlacementGroupFactory]
    ],
):
    """Wrapper for trainables to specify resource requests.

    This wrapper allows specification of resource requirements for a specific
    trainable. It will override potential existing resource requests (use
    with caution!).

    The main use case is to request resources for function trainables when used
    with the Tuner() API.

    Class trainables should usually just implement the ``default_resource_request()``
    method.

    Args:
        trainable: Trainable to wrap.
        resources: Resource dict, placement group factory, or callable that takes
            in a config dict and returns a placement group factory.

    Example:

    .. code-block:: python

        import ray
        from ray import tune
        from ray.tune.tuner import Tuner

        def train(config):
            return len(ray.get_gpu_ids())  # Returns 2

        tuner = Tuner(
            tune.with_resources(train, resources={"gpu": 2}),
            # ...
        )
        results = tuner.fit()

    """
    from ray.tune.trainable import Trainable

    if not callable(trainable) or (
        inspect.isclass(trainable) and not issubclass(trainable, Trainable)
    ):
        raise ValueError(
            f"`tune.with_resources()` only works with function trainables "
            f"or classes that inherit from `tune.Trainable`. Got type: "
            f"{type(trainable)}."
        )

    if isinstance(resources, PlacementGroupFactory):
        pgf = resources
    elif isinstance(resources, dict):
        pgf = resource_dict_to_pg_factory(resources)
    elif callable(resources):
        pgf = resources
    else:
        raise ValueError(
            f"Invalid resource type for `with_resources()`: {type(resources)}"
        )

    if not inspect.isclass(trainable):
        # Function trainable: just set an attribute. This is resolved later
        # in `wrap_function()`.
        trainable._resources = pgf
    else:

        class ResourceTrainable(trainable):
            @classmethod
            def default_resource_request(
                cls, config: Dict[str, Any]
            ) -> Optional[Union[Resources, PlacementGroupFactory]]:
                if not isinstance(pgf, PlacementGroupFactory) and callable(pgf):
                    return pgf(config)
                return pgf

        ResourceTrainable.__name__ = trainable.__name__
        trainable = ResourceTrainable

    return trainable
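For the class-trainable branch, wrapping replaces the trainable's own default_resource_request(), which is exactly the override behavior the commit message warns about. A minimal sketch, assuming a 1-GPU cluster and an illustrative trainable:

import ray
from ray import tune
from ray.tune.execution.placement_groups import PlacementGroupFactory


class MyTrainable(tune.Trainable):
    def step(self):
        return {"num_gpus": len(ray.get_gpu_ids()), "done": True}

    def save_checkpoint(self, checkpoint_dir: str):
        pass

    def load_checkpoint(self, checkpoint):
        pass

    @classmethod
    def default_resource_request(cls, config):
        # Ignored once wrapped: with_resources() overrides this request.
        return PlacementGroupFactory([{"CPU": 2}])


wrapped = tune.with_resources(
    MyTrainable, resources=PlacementGroupFactory([{"CPU": 1, "GPU": 1}])
)
tune.run(wrapped)  # Each trial now requests a bundle with 1 CPU and 1 GPU.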
