[tune] ResourceChangingScheduler dynamic resource allocation during tuning (#16787)
Yard1 authored Jul 14, 2021
1 parent cfc5806 commit 6e780eb
Showing 15 changed files with 704 additions and 14 deletions.
7 changes: 6 additions & 1 deletion doc/source/tune/_tutorials/tune-xgboost.rst
@@ -117,7 +117,7 @@ from ``1.0``. Even in this simple example, most runs result
in a good accuracy of over ``0.90``.

Maybe you have noticed the ``config`` parameter we pass to the XGBoost algorithm. This
is a ``dict`` in which you can specify parameters for the XGBoost algorithm. In this
is a :class:`dict` in which you can specify parameters for the XGBoost algorithm. In this
simple example, the only parameters we passed are the ``objective`` and ``eval_metric`` parameters.
The value ``binary:logistic`` tells XGBoost that we aim to train a logistic regression model for
a binary classification task. You can find an overview over all valid objectives
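
For illustration, such a ``config`` could look as follows (only ``objective`` and ``eval_metric`` are taken from the example above; any other XGBoost parameter you add is passed through unchanged):

.. code-block:: python

    # A minimal XGBoost parameter dict of the kind described above.
    config = {
        "objective": "binary:logistic",
        "eval_metric": ["logloss", "error"],
    }
    # It is passed straight through to the training call, e.g.:
    # xgb.train(config, train_set, ...)
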
@@ -441,6 +441,11 @@ The output of our run could look like this:
As you can see, most trials have been stopped only after a few iterations. Only the
two most promising trials were run for the full 10 iterations.

You can also ensure that all available resources are kept in use as the scheduler
terminates trials and frees up their resources. This can be done with the
``ResourceChangingScheduler``. An example of this can be found here:
:doc:`/tune/examples/xgboost_dynamic_resources_example`.
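
A rough sketch of how this could be wired up (``MyTrainable`` is a placeholder for a class-based Trainable that implements ``Trainable.update_resources``; the scheduler arguments are illustrative):

.. code-block:: python

    from ray import tune
    from ray.tune.schedulers import ASHAScheduler, ResourceChangingScheduler

    # Wrap the ASHA scheduler so that resources freed by stopped trials
    # are reassigned to the trials that keep running.
    scheduler = ResourceChangingScheduler(
        base_scheduler=ASHAScheduler(max_t=10, grace_period=1))

    analysis = tune.run(
        MyTrainable,
        metric="eval-logloss",
        mode="min",
        scheduler=scheduler)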

Using fractional GPUs
---------------------
You can often accelerate your training by using GPUs in addition to CPUs. However,
15 changes: 15 additions & 0 deletions doc/source/tune/api_docs/schedulers.rst
@@ -22,6 +22,8 @@ Tune includes distributed implementations of early stopping algorithms such as `

When using schedulers, you may face compatibility issues, as shown in the compatibility matrix below. Certain schedulers cannot be used with Search Algorithms, and certain schedulers require :ref:`checkpointing to be implemented <tune-checkpoint>`.

Schedulers can dynamically change trial resource requirements during tuning. This is currently implemented in ``ResourceChangingScheduler``, which can wrap around any other scheduler.

.. list-table:: TrialScheduler Feature Compatibility Matrix
:header-rows: 1

@@ -230,6 +232,19 @@ An example of this in use can be found here: :doc:`/tune/examples/bohb_example`.

.. autoclass:: ray.tune.schedulers.HyperBandForBOHB

ResourceChangingScheduler
-------------------------

This class is a utility scheduler that allows trial resource requirements to be changed during tuning. It wraps around another scheduler and uses that scheduler's decisions. It currently only supports the Trainable (class) API for tuning. Your Trainable must implement ``Trainable.update_resources``, which lets your model know about the newly assigned resources.

An example of this in use can be found here: :doc:`/tune/examples/xgboost_dynamic_resources_example`.
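
For reference, a minimal sketch of such a hook could look like the following (``MyTrainable`` and ``num_cpus`` are illustrative names; the ``new_resources`` argument may be either a ``PlacementGroupFactory`` or a ``Resources`` object, as in the example linked above):

.. code-block:: python

    from typing import Union

    from ray.tune import Trainable
    from ray.tune.resources import Resources
    from ray.tune.utils.placement_groups import PlacementGroupFactory


    class MyTrainable(Trainable):
        # setup()/step()/save_checkpoint()/load_checkpoint() omitted here.

        def update_resources(
                self, new_resources: Union[PlacementGroupFactory, Resources]):
            # Record the new CPU budget so that step() can make use of it,
            # e.g. as the number of threads for the underlying model.
            if isinstance(new_resources, PlacementGroupFactory):
                self.num_cpus = new_resources.head_cpus
            else:
                self.num_cpus = new_resources.cpu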

.. autoclass:: ray.tune.schedulers.ResourceChangingScheduler

evenly_distribute_cpus_gpus
~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: ray.tune.schedulers.resource_changing_scheduler.evenly_distribute_cpus_gpus
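
For example, it can be passed explicitly as the ``resources_allocation_function`` (it is also the default):

.. code-block:: python

    from ray.tune.schedulers import ASHAScheduler, ResourceChangingScheduler
    from ray.tune.schedulers.resource_changing_scheduler import \
        evenly_distribute_cpus_gpus

    # Redistribute the available CPUs/GPUs evenly across live trials.
    scheduler = ResourceChangingScheduler(
        base_scheduler=ASHAScheduler(),
        resources_allocation_function=evenly_distribute_cpus_gpus)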

FIFOScheduler
-------------
1 change: 1 addition & 0 deletions doc/source/tune/examples/index.rst
@@ -117,6 +117,7 @@ XGBoost, LightGBM

- :ref:`XGBoost tutorial <tune-xgboost>`: A guide to tuning XGBoost parameters with Tune.
- :doc:`/tune/examples/xgboost_example`: Trains a basic XGBoost model with Tune, using the function-based API and an XGBoost callback.
- :doc:`/tune/examples/xgboost_dynamic_resources_example`: Trains a basic XGBoost model with Tune, using the class-based API and a ResourceChangingScheduler, ensuring all resources are used at all times.
- :doc:`/tune/examples/lightgbm_example`: Trains a basic LightGBM model with Tune, using the function-based API and a LightGBM callback.

RLlib
7 changes: 7 additions & 0 deletions doc/source/tune/examples/xgboost_dynamic_resources_example.rst
@@ -0,0 +1,7 @@
:orphan:

xgboost_dynamic_resources_example
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~


.. literalinclude:: /../../python/ray/tune/examples/xgboost_dynamic_resources_example.py
8 changes: 8 additions & 0 deletions python/ray/tune/BUILD
@@ -825,6 +825,14 @@ py_test(
    tags = ["exclusive", "example"]
)

py_test(
    name = "xgboost_dynamic_resources_example",
    size = "medium",
    srcs = ["examples/xgboost_dynamic_resources_example.py"],
    deps = [":tune_lib"],
    tags = ["exclusive", "example"]
)

py_test(
    name = "zoopt_example",
    size = "small",
6 changes: 6 additions & 0 deletions python/ray/tune/examples/README.rst
@@ -47,6 +47,12 @@ XGBoost Example
- `xgboost_example <https://github.com/ray-project/ray/blob/master/python/ray/tune/examples/xgboost_example.py>`__: Trains a basic XGBoost model with Tune, using the function-based API and an XGBoost callback.


XGBoost with Dynamic Resources Example
--------------------------------------

- `xgboost_dynamic_resources_example <https://github.com/ray-project/ray/blob/master/python/ray/tune/examples/xgboost_dynamic_resources_example.py>`__: Trains a basic XGBoost model with Tune, using the class-based API and a ResourceChangingScheduler, ensuring all resources are used at all times.


LightGBM Example
----------------

232 changes: 232 additions & 0 deletions python/ray/tune/examples/xgboost_dynamic_resources_example.py
@@ -0,0 +1,232 @@
from typing import Union, Dict, Any
import sklearn.datasets
import sklearn.metrics
import os
from sklearn.model_selection import train_test_split
import xgboost as xgb
from xgboost.core import Booster
import pickle

import ray
from ray import tune
from ray.tune.schedulers import ResourceChangingScheduler, ASHAScheduler
from ray.tune import Trainable
from ray.tune.resources import Resources
from ray.tune.utils.placement_groups import PlacementGroupFactory
from ray.tune.suggest.basic_variant import BasicVariantGenerator
from ray.tune.trial import Trial
from ray.tune import trial_runner


# Dynamic resource allocation is currently only possible
# with Trainable (class) API
class BreastCancerTrainable(Trainable):
    def setup(self, config):
        self.config = config
        self.nthread = config.pop("nthread", 1)
        self.model: xgb.Booster = None
        # Load dataset
        data, labels = sklearn.datasets.load_breast_cancer(return_X_y=True)
        # Split into train and test set
        train_x, test_x, train_y, test_y = train_test_split(
            data, labels, test_size=0.25)
        # Build input matrices for XGBoost
        self.train_set = xgb.DMatrix(train_x, label=train_y)
        self.test_set = xgb.DMatrix(test_x, label=test_y)

    def step(self):
        results = {}
        config = self.config.copy()
        config["nthread"] = int(self.nthread)
        self.model = xgb.train(
            config,
            self.train_set,
            evals=[(self.test_set, "eval")],
            verbose_eval=False,
            xgb_model=self.model,
            evals_result=results,
            num_boost_round=1)
        print(config, results)
        return {
            "eval-logloss": results["eval"]["logloss"][-1],
            "nthread": self.nthread
        }

    def save_checkpoint(self, checkpoint_dir):
        path = os.path.join(checkpoint_dir, "checkpoint")
        with open(path, "wb") as outputFile:
            pickle.dump((self.config, self.nthread, self.model.save_raw()),
                        outputFile)
        return path

    def load_checkpoint(self, checkpoint_path):
        with open(checkpoint_path, "rb") as inputFile:
            self.config, self.nthread, raw_model = pickle.load(inputFile)
        self.model = Booster()
        self.model.load_model(bytearray(raw_model))
        data, labels = sklearn.datasets.load_breast_cancer(return_X_y=True)
        # Split into train and test set
        train_x, test_x, train_y, test_y = train_test_split(
            data, labels, test_size=0.25)
        # Build input matrices for XGBoost
        self.train_set = xgb.DMatrix(train_x, label=train_y)
        self.test_set = xgb.DMatrix(test_x, label=test_y)

    def update_resources(
            self, new_resources: Union[PlacementGroupFactory, Resources]):
        if isinstance(new_resources, PlacementGroupFactory):
            self.nthread = new_resources.head_cpus
        else:
            self.nthread = new_resources.cpu


def get_best_model_checkpoint(analysis):
    best_bst = xgb.Booster()
    with open(analysis.best_checkpoint, "rb") as inputFile:
        _, _, raw_model = pickle.load(inputFile)
    best_bst.load_model(bytearray(raw_model))
    accuracy = 1. - analysis.best_result["eval-logloss"]
    print(f"Best model parameters: {analysis.best_config}")
    print(f"Best model total accuracy: {accuracy:.4f}")
    return best_bst


def tune_xgboost():
    search_space = {
        # You can mix constants with search space objects.
        "objective": "binary:logistic",
        "eval_metric": ["logloss", "error"],
        "max_depth": 9,
        "learning_rate": 1,
        "min_child_weight": tune.grid_search([2, 3]),
        "subsample": tune.grid_search([0.8, 0.9]),
        "colsample_bynode": tune.grid_search([0.8, 0.9]),
        "random_state": 1,
        "num_parallel_tree": 2000,
    }
    # This will enable aggressive early stopping of bad trials.
    base_scheduler = ASHAScheduler(
        max_t=16,  # 16 training iterations
        grace_period=1,
        reduction_factor=2)

    def example_resources_allocation_function(
            trial_runner: "trial_runner.TrialRunner", trial: Trial,
            result: Dict[str, Any],
            base_trial_resource: Union[PlacementGroupFactory, Resources]
    ) -> Union[None, PlacementGroupFactory, Resources]:
        """This is a basic example of a resource allocating function.
        The function naively balances available CPUs over live trials.
        This function returns a new ``PlacementGroupFactory`` with updated
        resource requirements, or None. If the returned
        ``PlacementGroupFactory`` is equal by value to the one the
        trial has currently, the scheduler will skip the update process
        internally (same with None).
        See :func:`evenly_distribute_cpus_gpus` for a more complex,
        robust approach.
        Args:
            trial_runner (TrialRunner): Trial runner for this Tune run.
                Can be used to obtain information about other trials.
            trial (Trial): The trial to allocate new resources to.
            result (Dict[str, Any]): The latest results of trial.
            base_trial_resource (Union[PlacementGroupFactory, Resources]):
                Base trial resources as defined in
                ``tune.run(resources_per_trial)``
        """

        # Don't bother if this is just the first iteration
        if result["training_iteration"] < 1:
            return None

        # default values if resources_per_trial is unspecified
        if base_trial_resource is None:
            base_trial_resource = PlacementGroupFactory([{"CPU": 1, "GPU": 0}])

        # Assume that the number of CPUs cannot go below what was
        # specified in tune.run
        min_cpu = base_trial_resource.required_resources.get("CPU", 0)

        # Get the number of CPUs available in total (not just free)
        total_available_cpus = (
            trial_runner.trial_executor._avail_resources.cpu)

        # Divide the free CPUs among all live trials
        cpu_to_use = max(
            min_cpu,
            total_available_cpus // len(trial_runner.get_live_trials()))

        # Assign new CPUs to the trial in a PlacementGroupFactory
        return PlacementGroupFactory([{"CPU": cpu_to_use}])

    # You can either define your own resources_allocation_function, or
    # use the default one - evenly_distribute_cpus_gpus

    # from ray.tune.schedulers.resource_changing_scheduler import \
    #     evenly_distribute_cpus_gpus

    scheduler = ResourceChangingScheduler(
        base_scheduler=base_scheduler,
        resources_allocation_function=example_resources_allocation_function
        # resources_allocation_function=evenly_distribute_cpus_gpus  # default
    )

    search = BasicVariantGenerator()

    analysis = tune.run(
        BreastCancerTrainable,
        metric="eval-logloss",
        mode="min",
        resources_per_trial=PlacementGroupFactory([{
            "CPU": 1,
            "GPU": 0
        }]),
        config=search_space,
        search_alg=search,
        num_samples=1,
        checkpoint_at_end=True,
        scheduler=scheduler)

    assert analysis.results_df["training_iteration"].max() == 16
    assert analysis.results_df["nthread"].max() > 1

    return analysis


if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--server-address",
        type=str,
        default=None,
        required=False,
        help="The address of server to connect to if using "
        "Ray Client.")
    args, _ = parser.parse_known_args()

    if args.server_address:
        ray.util.connect(args.server_address)
    else:
        ray.init(num_cpus=8)

    analysis = tune_xgboost()

    # Load the best model checkpoint.
    if args.server_address:
        # If connecting to a remote server with Ray Client, checkpoint loading
        # should be wrapped in a task so it will execute on the server.
        # We have to make sure it gets executed on the same node that
        # ``tune.run`` is called on.
        from ray.tune.utils import force_on_current_node
        remote_fn = force_on_current_node(
            ray.remote(get_best_model_checkpoint))
        best_bst = ray.get(remote_fn.remote(analysis))
    else:
        best_bst = get_best_model_checkpoint(analysis)

    # You could now do further predictions with
    # best_bst.predict(...)
20 changes: 20 additions & 0 deletions python/ray/tune/ray_trial_executor.py
@@ -452,6 +452,7 @@ def _start_trial(self, trial, checkpoint=None, runner=None,
            return False
        trial.set_runner(runner)
        self.restore(trial, checkpoint)
        self._notify_trainable_of_new_resources_if_needed(trial)
        self.set_status(trial, Trial.RUNNING)

        if trial in self._staged_trials:
@@ -466,6 +467,22 @@ def _start_trial(self, trial, checkpoint=None, runner=None,
            self._train(trial)
        return True

    def _notify_trainable_of_new_resources_if_needed(self, trial: Trial):
        if trial.has_new_resources:
            trainable = trial.runner
            trial.has_new_resources = False
            with self._change_working_directory(trial):
                with warn_if_slow("update_resources"):
                    try:
                        ray.get(
                            trainable.update_resources.remote(
                                trial.placement_group_factory if trial.
                                uses_placement_groups else trial.resources),
                            timeout=DEFAULT_GET_TIMEOUT)
                    except GetTimeoutError:
                        logger.exception(
                            "Trial %s: updating resources timed out.", trial)

    def _stop_trial(self,
                    trial,
                    error=False,
@@ -949,6 +966,9 @@ def on_step_end(self, trial_runner):

        self._pg_manager.cleanup()

    def force_reconcilation_on_next_step_end(self):
        self.last_pg_recon = -float("inf")

    def save(self, trial, storage=Checkpoint.PERSISTENT, result=None):
        """Saves the trial's state to a checkpoint asynchronously.