[tune] Introduce tune.with_resources() to specify function trainable resources #26830

Merged · 6 commits · Jul 22, 2022
Changes from all commits
3 changes: 2 additions & 1 deletion python/ray/tune/__init__.py
@@ -39,7 +39,7 @@
from ray.tune.search import create_searcher
from ray.tune.schedulers import create_scheduler
from ray.tune.execution.placement_groups import PlacementGroupFactory
from ray.tune.trainable.util import with_parameters
from ray.tune.trainable.util import with_parameters, with_resources

from ray._private.usage import usage_lib

@@ -57,6 +57,7 @@
"run",
"run_experiments",
"with_parameters",
"with_resources",
"Stopper",
"Experiment",
"function",
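With `with_resources` imported next to `with_parameters` and added to `__all__`, the wrapper becomes reachable from the top-level `ray.tune` namespace. A minimal sketch of what the new export allows; the trainable and the resource values here are illustrative, not part of this diff:

from ray import tune

def train_fn(config):
    ...

# The wrapper is now part of the public tune API, alongside tune.with_parameters.
trainable = tune.with_resources(train_fn, resources={"cpu": 1})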
85 changes: 83 additions & 2 deletions python/ray/tune/tests/test_api.py
@@ -12,6 +12,7 @@

import gym
import numpy as np
import pytest
import ray
from ray import tune
from ray.air._internal.remote_storage import _ensure_directory
@@ -1409,6 +1410,88 @@ def step(self):
assert sys.getsizeof(dumped) < 100 * 1024


@pytest.fixture
def ray_start_2_cpus_2_gpus():
address_info = ray.init(num_cpus=2, num_gpus=2)
yield address_info
# The code after the yield will run as teardown code.
ray.shutdown()


@pytest.mark.parametrize("num_gpus", [1, 2])
def test_with_resources_dict(ray_start_2_cpus_2_gpus, num_gpus):
def train_fn(config):
return len(ray.get_gpu_ids())

[trial] = tune.run(
tune.with_resources(train_fn, resources={"gpu": num_gpus})
).trials

assert trial.last_result["_metric"] == num_gpus


@pytest.mark.parametrize("num_gpus", [1, 2])
def test_with_resources_pgf(ray_start_2_cpus_2_gpus, num_gpus):
def train_fn(config):
return len(ray.get_gpu_ids())

[trial] = tune.run(
tune.with_resources(
train_fn, resources=PlacementGroupFactory([{"GPU": num_gpus}])
)
).trials

assert trial.last_result["_metric"] == num_gpus


@pytest.mark.parametrize("num_gpus", [1, 2])
def test_with_resources_fn(ray_start_2_cpus_2_gpus, num_gpus):
def train_fn(config):
return len(ray.get_gpu_ids())

[trial] = tune.run(
tune.with_resources(
train_fn,
resources=lambda config: PlacementGroupFactory(
[{"GPU": config["use_gpus"]}]
),
),
config={"use_gpus": num_gpus},
).trials

assert trial.last_result["_metric"] == num_gpus


@pytest.mark.parametrize("num_gpus", [1, 2])
def test_with_resources_class_fn(ray_start_2_cpus_2_gpus, num_gpus):
class MyTrainable(tune.Trainable):
def step(self):
return {"_metric": len(ray.get_gpu_ids()), "done": True}

def save_checkpoint(self, checkpoint_dir: str):
pass

def load_checkpoint(self, checkpoint):
pass

@classmethod
def default_resource_request(cls, config):
# This will be overwritten by tune.with_resources()
return PlacementGroupFactory([{"CPU": 2, "GPU": 0}])

[trial] = tune.run(
tune.with_resources(
MyTrainable,
resources=lambda config: PlacementGroupFactory(
[{"GPU": config["use_gpus"]}]
),
),
config={"use_gpus": num_gpus},
).trials

assert trial.last_result["_metric"] == num_gpus


class SerializabilityTest(unittest.TestCase):
@classmethod
def setUpClass(cls):
@@ -1821,6 +1904,4 @@ def __init__(


if __name__ == "__main__":
import pytest

sys.exit(pytest.main(["-v", __file__]))
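The assertions in these tests rely on a Tune convention that is easy to miss: when a function trainable simply returns a value instead of reporting metrics explicitly, Tune records that value under the `_metric` key of the trial's last result. A minimal sketch of that convention in isolation, with an illustrative return value and a CPU-only setup:

import ray
from ray import tune

ray.init(num_cpus=1)

def train_fn(config):
    # A bare return value is reported as {"_metric": <value>} for the trial.
    return 42

[trial] = tune.run(train_fn).trials
assert trial.last_result["_metric"] == 42

ray.shutdown()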
14 changes: 13 additions & 1 deletion python/ray/tune/trainable/function_trainable.py
@@ -9,12 +9,14 @@
import warnings
from functools import partial
from numbers import Number
from typing import Any, Callable, Dict, Optional, Type
from typing import Any, Callable, Dict, Optional, Type, Union

from ray.tune.resources import Resources
from six.moves import queue

from ray.air.checkpoint import Checkpoint
from ray.tune import TuneError
from ray.tune.execution.placement_groups import PlacementGroupFactory
from ray.tune.trainable import session
from ray.tune.result import (
DEFAULT_METRIC,
@@ -641,6 +643,8 @@ def wrap_function(
DeprecationWarning,
)

resources = getattr(train_func, "_resources", None)

class ImplicitFunc(*inherit_from):
_name = name or (
train_func.__name__ if hasattr(train_func, "__name__") else "func"
@@ -685,4 +689,12 @@ def handle_output(output):
reporter(**{RESULT_DUPLICATE: True})
return output

@classmethod
def default_resource_request(
cls, config: Dict[str, Any]
) -> Optional[Union[Resources, PlacementGroupFactory]]:
if not isinstance(resources, PlacementGroupFactory) and callable(resources):
return resources(config)
return resources

return ImplicitFunc
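Together with `with_resources` below, this hunk is what makes resource requests work for function trainables: the wrapper stashes the spec on the function as a `_resources` attribute, and `wrap_function` surfaces it through `default_resource_request`, resolving callables against the trial config. A rough sketch of that flow; calling the internal `wrap_function` directly is purely for illustration (Tune does this itself when trials are created), and the config key is made up:

from ray import tune
from ray.tune.execution.placement_groups import PlacementGroupFactory
from ray.tune.trainable.function_trainable import wrap_function

def train_fn(config):
    ...

# with_resources() stores the callable on the function as `_resources` ...
wrapped = tune.with_resources(
    train_fn,
    resources=lambda config: PlacementGroupFactory([{"GPU": config["use_gpus"]}]),
)

# ... and wrap_function() exposes it via default_resource_request(),
# resolving the callable against the concrete trial config.
trainable_cls = wrap_function(wrapped)
pgf = trainable_cls.default_resource_request({"use_gpus": 2})
# `pgf` is a PlacementGroupFactory requesting two GPUs for this config.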
98 changes: 95 additions & 3 deletions python/ray/tune/trainable/util.py
@@ -3,16 +3,24 @@
import logging
import os
import shutil
from typing import Any, Dict, Optional, Union
from typing import Any, Callable, Dict, Optional, Type, Union, TYPE_CHECKING

import pandas as pd

import ray
import ray.cloudpickle as pickle
from ray.tune.execution.placement_groups import (
PlacementGroupFactory,
resource_dict_to_pg_factory,
)
from ray.tune.registry import _ParameterRegistry
from ray.tune.resources import Resources
from ray.tune.utils import detect_checkpoint_function
from ray.util import placement_group
from ray.util.annotations import DeveloperAPI
from ray.util.annotations import DeveloperAPI, PublicAPI

if TYPE_CHECKING:
from ray.tune.trainable import Trainable

logger = logging.getLogger(__name__)

@@ -226,7 +234,8 @@ def get_remote_worker_options(
return options, pg


def with_parameters(trainable, **kwargs):
@PublicAPI(stability="beta")
def with_parameters(trainable: Union[Type["Trainable"], Callable], **kwargs):
"""Wrapper for trainables to pass arbitrary large data objects.

This wrapper function will store all passed parameters in the Ray
@@ -367,3 +376,86 @@ def _inner(config):
inner.__mixins__ = trainable.__mixins__

return inner


@PublicAPI(stability="beta")
def with_resources(
trainable: Union[Type["Trainable"], Callable],
resources: Union[
Dict[str, float], PlacementGroupFactory, Callable[[dict], PlacementGroupFactory]
Review thread on the `resources` argument:

Contributor: I am a little hesitant about adding PlacementGroupFactory actually. We need a better API. How about just a plain resources dict for now, until someone yells at us? It covers most simple cases.

Contributor (author): We use it for resources_per_trial, so it would be a bit odd to leave it out imo. We could add ScalingConfig here... :-D

Contributor (author): I think we should keep the placement group factory for now - it's a Tune concept, it's advanced, but it's the same as in resources_per_trial. I'm happy to deprecate this once we have a good alternative, but then we can deprecate it in all places. Does that sound good?

Contributor: ok sg!
],
):
"""Wrapper for trainables to specify resource requests.

This wrapper allows specification of resource requirements for a specific
trainable. It will override any existing resource requests (use
with caution!).

The main use case is to request resources for function trainables when used
with the Tuner() API.

Class trainables should usually just implement the ``default_resource_request()``
method.

Args:
trainable: Trainable to wrap.
resources: Resource dict, placement group factory, or callable that takes
in a config dict and returns a placement group factory.

Example:

.. code-block:: python

from ray import tune
from ray.tune.tuner import Tuner

def train(config):
return len(ray.get_gpu_ids()) # Returns 2

tuner = Tuner(
tune.with_resources(train, resources={"gpu": 2}),
# ...
)
results = tuner.fit()

"""
from ray.tune.trainable import Trainable

if not callable(trainable) or (
inspect.isclass(trainable) and not issubclass(trainable, Trainable)
):
raise ValueError(
f"`tune.with_parameters() only works with function trainables "
f"or classes that inherit from `tune.Trainable()`. Got type: "
f"{type(trainable)}."
)

if isinstance(resources, PlacementGroupFactory):
pgf = resources
elif isinstance(resources, dict):
pgf = resource_dict_to_pg_factory(resources)
elif callable(resources):
pgf = resources
else:
raise ValueError(
f"Invalid resource type for `with_resources()`: {type(resources)}"
)

if not inspect.isclass(trainable):
# Just set an attribute. This will be resolved later in `wrap_function()`.
trainable._resources = pgf
else:

class ResourceTrainable(trainable):
@classmethod
def default_resource_request(
cls, config: Dict[str, Any]
) -> Optional[Union[Resources, PlacementGroupFactory]]:
if not isinstance(pgf, PlacementGroupFactory) and callable(pgf):
return pgf(config)
return pgf

ResourceTrainable.__name__ = trainable.__name__
trainable = ResourceTrainable

return trainable
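As the review thread above settles, `with_resources` accepts the same kinds of specifications as `resources_per_trial`: a plain resource dict, an explicit `PlacementGroupFactory`, or a callable resolved per trial from the config. A short sketch of the three forms with the Tuner API; the trainable, config values, and Tuner arguments are illustrative, and running it as-is assumes a GPU is available in the Ray cluster:

from ray import tune
from ray.tune.execution.placement_groups import PlacementGroupFactory
from ray.tune.tuner import Tuner

def train_fn(config):
    ...

# Any of the three specification styles can be passed as `resources`.
# A given function trainable should be wrapped once, since the spec is
# stored on the function itself.
#
# 1) Plain resource dict (converted internally via resource_dict_to_pg_factory()):
#        resources={"cpu": 1, "gpu": 1}
# 2) Explicit PlacementGroupFactory, same as in resources_per_trial:
#        resources=PlacementGroupFactory([{"CPU": 1, "GPU": 1}])
# 3) Callable resolved per trial from the trial config, as used here:
trainable = tune.with_resources(
    train_fn,
    resources=lambda config: PlacementGroupFactory([{"GPU": config["use_gpus"]}]),
)

tuner = Tuner(trainable, param_space={"use_gpus": 1})
results = tuner.fit()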