diff --git a/ci/jenkins_tests/run_tune_tests.sh b/ci/jenkins_tests/run_tune_tests.sh index 86b3e6a500d6..1b1cd5a8a8b9 100755 --- a/ci/jenkins_tests/run_tune_tests.sh +++ b/ci/jenkins_tests/run_tune_tests.sh @@ -109,3 +109,8 @@ $SUPPRESS_OUTPUT docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $SUPPRESS_OUTPUT docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \ python /ray/python/ray/tune/examples/skopt_example.py \ --smoke-test + +# uncomment once statsmodels is updated. +# $SUPPRESS_OUTPUT docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \ +# python /ray/python/ray/tune/examples/bohb_example.py \ +# --smoke-test diff --git a/doc/source/tune-schedulers.rst b/doc/source/tune-schedulers.rst index 12b802d06e8d..62938c09628f 100644 --- a/doc/source/tune-schedulers.rst +++ b/doc/source/tune-schedulers.rst @@ -131,13 +131,28 @@ On the other hand, holding ``R`` constant at ``R = 300`` and varying ``eta`` als The implementation takes the same configuration as the example given in the paper and exposes ``max_t``, which is not a parameter in the paper. -2. The example in the `post `_ to calculate ``n_0`` is actually a little different than the algorithm given in the paper. In this implementation, we implement ``n_0`` according to the paper (which is `n` in the below example): +2. The example in the `post `_ to calculate ``n_0`` is actually a little different than the algorithm given in the paper. In this implementation, we implement ``n_0`` according to the paper (which is `n` in the below example): .. image:: images/hyperband_allocation.png 3. There are also implementation specific details like how trials are placed into brackets which are not covered in the paper. This implementation places trials within brackets according to smaller bracket first - meaning that with low number of trials, there will be less early stopping. +HyperBand (BOHB) +---------------- + +.. tip:: This implementation is still experimental. Please report issues on https://github.com/ray-project/ray/issues/. Thanks! + +This class is a variant of HyperBand that enables the BOHB Algorithm. This implementation is true to the original HyperBand implementation and does not implement pipelining nor straggler mitigation. + +This is to be used in conjunction with the Tune BOHB search algorithm. See `TuneBOHB `_ for package requirements, examples, and details. + +An example of this in use can be found in `bohb_example.py `_. + +.. autoclass:: ray.tune.schedulers.HyperBandForBOHB + :noindex: + + Median Stopping Rule -------------------- diff --git a/doc/source/tune-searchalg.rst b/doc/source/tune-searchalg.rst index 2dae8eaf4abe..fc452630761d 100644 --- a/doc/source/tune-searchalg.rst +++ b/doc/source/tune-searchalg.rst @@ -18,6 +18,7 @@ Currently, Tune offers the following search algorithms (and library integrations - `Nevergrad `__ - `Scikit-Optimize `__ - `Ax `__ +- `BOHB `__ Variant Generation (Grid Search/Random Search) @@ -181,6 +182,53 @@ An example of this can be found in `ax_example.py `__ to perform sequential model-based hyperparameter optimization in conjunction with HyperBand. Note that this class does not extend ``ray.tune.suggest.BasicVariantGenerator``, so you will not be able to use Tune's default variant generation/search space declaration when using BOHB. + +Importantly, BOHB is intended to be paired with a specific scheduler class: `HyperBandForBOHB `__. + +This algorithm requires using the `ConfigSpace search space specification `_. 
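+
+Beyond the uniform float parameters used in the example below, ``ConfigSpace``
+also supports integer and categorical hyperparameters. A minimal sketch (the
+hyperparameter names here are purely illustrative and not part of any Tune
+example):
+
+.. code-block:: python
+
+    import ConfigSpace as CS
+
+    config_space = CS.ConfigurationSpace()
+    # Continuous parameter sampled uniformly from [1e-4, 1e-1].
+    config_space.add_hyperparameter(
+        CS.UniformFloatHyperparameter("lr", lower=1e-4, upper=1e-1))
+    # Integer-valued parameter.
+    config_space.add_hyperparameter(
+        CS.UniformIntegerHyperparameter("batch_size", lower=16, upper=256))
+    # Categorical choice.
+    config_space.add_hyperparameter(
+        CS.CategoricalHyperparameter("activation", choices=["relu", "tanh"]))
+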
In order to use this search algorithm, you will need to install ``HpBandSter`` and ``ConfigSpace``: + +.. code-block:: bash + + $ pip install hpbandster ConfigSpace + + +You can use ``TuneBOHB`` in conjunction with ``HyperBandForBOHB`` as follows: + +.. code-block:: python + + # BOHB uses ConfigSpace for their hyperparameter search space + import ConfigSpace as CS + + config_space = CS.ConfigurationSpace() + config_space.add_hyperparameter( + CS.UniformFloatHyperparameter("height", lower=10, upper=100)) + config_space.add_hyperparameter( + CS.UniformFloatHyperparameter("width", lower=0, upper=100)) + + experiment_metrics = dict(metric="episode_reward_mean", mode="min") + bohb_hyperband = HyperBandForBOHB( + time_attr="training_iteration", max_t=100, **experiment_metrics) + bohb_search = TuneBOHB( + config_space, max_concurrent=4, **experiment_metrics) + + tune.run(MyTrainableClass, + name="bohb_test", + scheduler=bohb_hyperband, + search_alg=bohb_search, + num_samples=5) + +Take a look at `an example here `_. See the `BOHB paper `_ for more details. + +.. autoclass:: ray.tune.suggest.bohb.TuneBOHB + :show-inheritance: + :noindex: + Contributing a New Algorithm ---------------------------- diff --git a/docker/examples/Dockerfile b/docker/examples/Dockerfile index baa126cc962f..9cbc3ae78b09 100644 --- a/docker/examples/Dockerfile +++ b/docker/examples/Dockerfile @@ -11,8 +11,6 @@ RUN pip install gym[atari] opencv-python-headless tensorflow lz4 keras pytest-ti RUN pip install -U h5py # Mutes FutureWarnings RUN pip install --upgrade bayesian-optimization RUN pip install --upgrade git+git://github.com/hyperopt/hyperopt.git -RUN pip install --upgrade sigopt -RUN pip install --upgrade nevergrad -RUN pip install --upgrade scikit-optimize +RUN pip install --upgrade sigopt nevergrad scikit-optimize hpbandster ConfigSpace RUN pip install -U pytest-remotedata>=0.3.1 RUN conda install pytorch-cpu torchvision-cpu -c pytorch diff --git a/docker/tune_test/Dockerfile b/docker/tune_test/Dockerfile index 75ae4e8d9025..f30e156a1cd2 100644 --- a/docker/tune_test/Dockerfile +++ b/docker/tune_test/Dockerfile @@ -13,9 +13,7 @@ RUN conda remove -y --force wrapt RUN pip install gym[atari]==0.10.11 opencv-python-headless tensorflow lz4 keras pytest-timeout smart_open RUN pip install --upgrade bayesian-optimization RUN pip install --upgrade git+git://github.com/hyperopt/hyperopt.git -RUN pip install --upgrade sigopt -RUN pip install --upgrade nevergrad -RUN pip install --upgrade scikit-optimize +RUN pip install --upgrade sigopt nevergrad scikit-optimize hpbandster ConfigSpace RUN pip install -U pytest-remotedata>=0.3.1 RUN conda install pytorch-cpu torchvision-cpu -c pytorch diff --git a/python/ray/tune/examples/bohb_example.py b/python/ray/tune/examples/bohb_example.py new file mode 100644 index 000000000000..11cac9790a0f --- /dev/null +++ b/python/ray/tune/examples/bohb_example.py @@ -0,0 +1,82 @@ +#!/usr/bin/env python + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import argparse +import json +import os + +import numpy as np + +import ray +from ray.tune import Trainable, run +from ray.tune.schedulers.hb_bohb import HyperBandForBOHB +from ray.tune.suggest.bohb import TuneBOHB + +parser = argparse.ArgumentParser() +parser.add_argument( + "--smoke-test", action="store_true", help="Finish quickly for testing") +parser.add_argument( + "--ray-redis-address", + help="Address of Ray cluster for seamless distributed execution.") +args, _ = 
parser.parse_known_args() + + +class MyTrainableClass(Trainable): + """Example agent whose learning curve is a random sigmoid. + + The dummy hyperparameters "width" and "height" determine the slope and + maximum reward value reached. + """ + + def _setup(self, config): + self.timestep = 0 + + def _train(self): + self.timestep += 1 + v = np.tanh(float(self.timestep) / self.config.get("width", 1)) + v *= self.config.get("height", 1) + + # Here we use `episode_reward_mean`, but you can also report other + # objectives such as loss or accuracy. + return {"episode_reward_mean": v} + + def _save(self, checkpoint_dir): + path = os.path.join(checkpoint_dir, "checkpoint") + with open(path, "w") as f: + f.write(json.dumps({"timestep": self.timestep})) + return path + + def _restore(self, checkpoint_path): + with open(checkpoint_path) as f: + self.timestep = json.loads(f.read())["timestep"] + + +if __name__ == "__main__": + import ConfigSpace as CS + ray.init(redis_address=args.ray_redis_address) + + # BOHB uses ConfigSpace for their hyperparameter search space + config_space = CS.ConfigurationSpace() + config_space.add_hyperparameter( + CS.UniformFloatHyperparameter("height", lower=10, upper=100)) + config_space.add_hyperparameter( + CS.UniformFloatHyperparameter("width", lower=0, upper=100)) + + experiment_metrics = dict(metric="episode_reward_mean", mode="min") + bohb_hyperband = HyperBandForBOHB( + time_attr="training_iteration", + max_t=100, + reduction_factor=4, + **experiment_metrics) + bohb_search = TuneBOHB( + config_space, max_concurrent=4, **experiment_metrics) + + run(MyTrainableClass, + name="bohb_test", + scheduler=bohb_hyperband, + search_alg=bohb_search, + num_samples=10, + stop={"training_iteration": 10 if args.smoke_test else 100}) diff --git a/python/ray/tune/schedulers/__init__.py b/python/ray/tune/schedulers/__init__.py index 34655372f40a..3731724c92ea 100644 --- a/python/ray/tune/schedulers/__init__.py +++ b/python/ray/tune/schedulers/__init__.py @@ -4,6 +4,7 @@ from ray.tune.schedulers.trial_scheduler import TrialScheduler, FIFOScheduler from ray.tune.schedulers.hyperband import HyperBandScheduler +from ray.tune.schedulers.hb_bohb import HyperBandForBOHB from ray.tune.schedulers.async_hyperband import (AsyncHyperBandScheduler, ASHAScheduler) from ray.tune.schedulers.median_stopping_rule import MedianStoppingRule @@ -12,5 +13,5 @@ __all__ = [ "TrialScheduler", "HyperBandScheduler", "AsyncHyperBandScheduler", "ASHAScheduler", "MedianStoppingRule", "FIFOScheduler", - "PopulationBasedTraining" + "PopulationBasedTraining", "HyperBandForBOHB" ] diff --git a/python/ray/tune/schedulers/hb_bohb.py b/python/ray/tune/schedulers/hb_bohb.py new file mode 100644 index 000000000000..428ffa6f7b08 --- /dev/null +++ b/python/ray/tune/schedulers/hb_bohb.py @@ -0,0 +1,128 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import logging + +from ray.tune.schedulers.trial_scheduler import TrialScheduler +from ray.tune.schedulers.hyperband import HyperBandScheduler, Bracket +from ray.tune.trial import Trial + +logger = logging.getLogger(__name__) + + +class HyperBandForBOHB(HyperBandScheduler): + """Extends HyperBand early stopping algorithm for BOHB. + + This implementation removes the ``HyperBandScheduler`` pipelining. This + class introduces key changes: + + 1. Trials are now placed so that the bracket with the largest size is + filled first. + + 2. Trials will be paused even if the bracket is not filled. 
This allows + BOHB to insert new trials into the training. + + See ray.tune.schedulers.HyperBandScheduler for parameter docstring. + """ + + def on_trial_add(self, trial_runner, trial): + """Adds new trial. + + On a new trial add, if current bracket is not filled, add to current + bracket. Else, if current band is not filled, create new bracket, add + to current bracket. Else, create new iteration, create new bracket, + add to bracket. + """ + + cur_bracket = self._state["bracket"] + cur_band = self._hyperbands[self._state["band_idx"]] + if cur_bracket is None or cur_bracket.filled(): + retry = True + while retry: + # if current iteration is filled, create new iteration + if self._cur_band_filled(): + cur_band = [] + self._hyperbands.append(cur_band) + self._state["band_idx"] += 1 + + # MAIN CHANGE HERE - largest bracket first! + # cur_band will always be less than s_max_1 or else filled + s = self._s_max_1 - len(cur_band) - 1 + assert s >= 0, "Current band is filled!" + if self._get_r0(s) == 0: + logger.debug("BOHB: Bracket too small - Retrying...") + cur_bracket = None + else: + retry = False + cur_bracket = Bracket(self._time_attr, self._get_n0(s), + self._get_r0(s), self._max_t_attr, + self._eta, s) + cur_band.append(cur_bracket) + self._state["bracket"] = cur_bracket + + self._state["bracket"].add_trial(trial) + self._trial_info[trial] = cur_bracket, self._state["band_idx"] + + def on_trial_result(self, trial_runner, trial, result): + """If bracket is finished, all trials will be stopped. + + If a given trial finishes and bracket iteration is not done, + the trial will be paused and resources will be given up. + + This scheduler will not start trials but will stop trials. + The current running trial will not be handled, + as the trialrunner will be given control to handle it.""" + + result["hyperband_info"] = {} + bracket, _ = self._trial_info[trial] + bracket.update_trial_stats(trial, result) + + if bracket.continue_trial(trial): + return TrialScheduler.CONTINUE + + result["hyperband_info"]["budget"] = bracket._cumul_r + + # MAIN CHANGE HERE! + statuses = [(t, t.status) for t in bracket._live_trials] + if not bracket.filled() or any(status != Trial.PAUSED + for t, status in statuses + if t is not trial): + trial_runner._search_alg.on_pause(trial.trial_id) + return TrialScheduler.PAUSE + action = self._process_bracket(trial_runner, bracket) + return action + + def _unpause_trial(self, trial_runner, trial): + trial_runner.trial_executor.unpause_trial(trial) + trial_runner._search_alg.on_unpause(trial.trial_id) + + def choose_trial_to_run(self, trial_runner): + """Fair scheduling within iteration by completion percentage. + + List of trials not used since all trials are tracked as state + of scheduler. If iteration is occupied (ie, no trials to run), + then look into next iteration. + """ + + for hyperband in self._hyperbands: + # band will have None entries if no resources + # are to be allocated to that bracket. + scrubbed = [b for b in hyperband if b is not None] + for bracket in scrubbed: + for trial in bracket.current_trials(): + if (trial.status == Trial.PENDING + and trial_runner.has_resources(trial.resources)): + return trial + # MAIN CHANGE HERE! + if not any(t.status == Trial.RUNNING + for t in trial_runner.get_trials()): + for hyperband in self._hyperbands: + for bracket in hyperband: + if bracket and any(trial.status == Trial.PAUSED + for trial in bracket.current_trials()): + # This will change the trial state and let the + # trial runner retry. 
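+                        # _process_bracket() applies successive halving to this
+                        # bracket: poorly performing trials are stopped and the
+                        # promoted ones are unpaused (which also fires the search
+                        # algorithm's on_unpause hook), so the runner can resume
+                        # them on its next scheduling pass.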
+ self._process_bracket(trial_runner, bracket) + # MAIN CHANGE HERE! + return None diff --git a/python/ray/tune/schedulers/hyperband.py b/python/ray/tune/schedulers/hyperband.py index b1ca1deec11b..064ae09aa8bd 100644 --- a/python/ray/tune/schedulers/hyperband.py +++ b/python/ray/tune/schedulers/hyperband.py @@ -8,6 +8,7 @@ from ray.tune.schedulers.trial_scheduler import FIFOScheduler, TrialScheduler from ray.tune.trial import Trial +from ray.tune.error import TuneError logger = logging.getLogger(__name__) @@ -72,6 +73,8 @@ class HyperBandScheduler(FIFOScheduler): The scheduler will terminate trials after this time has passed. Note that this is different from the semantics of `max_t` as mentioned in the original HyperBand paper. + reduction_factor (float): Same as `eta`. Determines how sharp + the difference is between bracket space-time allocation ratios. """ def __init__(self, @@ -79,7 +82,8 @@ def __init__(self, reward_attr=None, metric="episode_reward_mean", mode="max", - max_t=81): + max_t=81, + reduction_factor=3): assert max_t > 0, "Max (time_attr) not valid!" assert mode in ["min", "max"], "`mode` must be 'min' or 'max'!" @@ -92,8 +96,9 @@ def __init__(self, "Setting `metric={}` and `mode=max`.".format(reward_attr)) FIFOScheduler.__init__(self) - self._eta = 3 - self._s_max_1 = 5 + self._eta = reduction_factor + self._s_max_1 = int( + np.round(np.log(max_t) / np.log(reduction_factor))) + 1 self._max_t_attr = max_t # bracket max trials self._get_n0 = lambda s: int( @@ -173,10 +178,10 @@ def on_trial_result(self, trial_runner, trial, result): if bracket.continue_trial(trial): return TrialScheduler.CONTINUE - action = self._process_bracket(trial_runner, bracket, trial) + action = self._process_bracket(trial_runner, bracket) return action - def _process_bracket(self, trial_runner, bracket, trial): + def _process_bracket(self, trial_runner, bracket): """This is called whenever a trial makes progress. When all live trials in the bracket have no more iterations left, @@ -202,15 +207,15 @@ def _process_bracket(self, trial_runner, bracket, trial): bracket.cleanup_trial(t) action = TrialScheduler.STOP else: - raise Exception("Trial with unexpected status encountered") + raise TuneError("Trial with unexpected status encountered") # ready the good trials - if trial is too far ahead, don't continue for t in good: if t.status not in [Trial.PAUSED, Trial.RUNNING]: - raise Exception("Trial with unexpected status encountered") + raise TuneError("Trial with unexpected status encountered") if bracket.continue_trial(t): if t.status == Trial.PAUSED: - trial_runner.trial_executor.unpause_trial(t) + self._unpause_trial(trial_runner, t) elif t.status == Trial.RUNNING: action = TrialScheduler.CONTINUE return action @@ -223,7 +228,7 @@ def on_trial_remove(self, trial_runner, trial): bracket, _ = self._trial_info[trial] bracket.cleanup_trial(trial) if not bracket.finished(): - self._process_bracket(trial_runner, bracket, trial) + self._process_bracket(trial_runner, bracket) def on_trial_complete(self, trial_runner, trial, result): """Cleans up trial info from bracket if trial completed early.""" @@ -279,6 +284,15 @@ def debug_string(self): out += "\n {}".format(bracket) return out + def state(self): + return { + "num_brackets": sum(len(band) for band in self._hyperbands), + "num_stopped": self._num_stopped + } + + def _unpause_trial(self, trial_runner, trial): + trial_runner.trial_executor.unpause_trial(trial) + class Bracket(): """Logical object for tracking Hyperband bracket progress. 
Keeps track @@ -349,7 +363,7 @@ def successive_halving(self, metric, metric_op): self._r *= self._eta self._r = int(min(self._r, self._max_t_attr - self._cumul_r)) - self._cumul_r += self._r + self._cumul_r = self._r sorted_trials = sorted( self._live_trials, key=lambda t: metric_op * self._live_trials[t][metric]) diff --git a/python/ray/tune/suggest/__init__.py b/python/ray/tune/suggest/__init__.py index 69f2897207ae..a182f6b1ae43 100644 --- a/python/ray/tune/suggest/__init__.py +++ b/python/ray/tune/suggest/__init__.py @@ -2,10 +2,11 @@ from ray.tune.suggest.basic_variant import BasicVariantGenerator from ray.tune.suggest.suggestion import SuggestionAlgorithm from ray.tune.suggest.variant_generator import grid_search +from ray.tune.suggest.bohb import TuneBOHB __all__ = [ "SearchAlgorithm", "BasicVariantGenerator", "SuggestionAlgorithm", - "grid_search" + "grid_search", "TuneBOHB" ] diff --git a/python/ray/tune/suggest/bohb.py b/python/ray/tune/suggest/bohb.py new file mode 100644 index 000000000000..3fe5d5877f29 --- /dev/null +++ b/python/ray/tune/suggest/bohb.py @@ -0,0 +1,128 @@ +"""BOHB (Bayesian Optimization with HyperBand)""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import copy +import logging + +from ray.tune.suggest import SuggestionAlgorithm + +logger = logging.getLogger(__name__) + + +class _BOHBJobWrapper(): + """Mock object for HpBandSter to process.""" + + def __init__(self, loss, budget, config): + self.result = {"loss": loss} + self.kwargs = {"budget": budget, "config": config.copy()} + self.exception = None + + +class TuneBOHB(SuggestionAlgorithm): + """BOHB suggestion component. + + + Requires HpBandSter and ConfigSpace to be installed. You can install + HpBandSter and ConfigSpace with: `pip install hpbandster ConfigSpace`. + + This should be used in conjunction with HyperBandForBOHB. + + Args: + space (ConfigurationSpace): Continuous ConfigSpace search space. + Parameters will be sampled from this space which will be used + to run trials. + bohb_config (dict): configuration for HpBandSter BOHB algorithm + max_concurrent (int): Number of maximum concurrent trials. Defaults + to 10. + metric (str): The training result objective value attribute. + mode (str): One of {min, max}. Determines whether objective is + minimizing or maximizing the metric attribute. + + Example: + >>> import ConfigSpace as CS + >>> config_space = CS.ConfigurationSpace() + >>> config_space.add_hyperparameter( + CS.UniformFloatHyperparameter('width', lower=0, upper=20)) + >>> config_space.add_hyperparameter( + CS.UniformFloatHyperparameter('height', lower=-100, upper=100)) + >>> config_space.add_hyperparameter( + CS.CategoricalHyperparameter( + name='activation', choices=['relu', 'tanh'])) + >>> algo = TuneBOHB( + config_space, max_concurrent=4, metric='mean_loss', mode='min') + >>> bohb = HyperBandForBOHB( + time_attr='training_iteration', + metric='mean_loss', + mode='min', + max_t=100) + >>> run(MyTrainableClass, scheduler=bohb, search_alg=algo) + + """ + + def __init__(self, + space, + bohb_config=None, + max_concurrent=10, + metric="neg_mean_loss", + mode="max"): + from hpbandster.optimizers.config_generators.bohb import BOHB + assert BOHB is not None, "HpBandSter must be installed!" + assert mode in ["min", "max"], "`mode` must be 'min' or 'max'!" 
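+        # Note: HpBandSter's BOHB generator treats the reported objective as a
+        # "loss" to be minimized, so when mode == "max" the metric is negated
+        # (via _metric_op below and to_wrapper) before it is handed to the
+        # underlying optimizer.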
+ self._max_concurrent = max_concurrent + self.trial_to_params = {} + self.running = set() + self.paused = set() + self.metric = metric + if mode == "max": + self._metric_op = -1. + elif mode == "min": + self._metric_op = 1. + bohb_config = bohb_config or {} + self.bohber = BOHB(space, **bohb_config) + super(TuneBOHB, self).__init__() + + def _suggest(self, trial_id): + if len(self.running) < self._max_concurrent: + # This parameter is not used in hpbandster implementation. + config, info = self.bohber.get_config(None) + self.trial_to_params[trial_id] = copy.deepcopy(config) + self.running.add(trial_id) + return config + return None + + def on_trial_result(self, trial_id, result): + if trial_id not in self.paused: + self.running.add(trial_id) + if "hyperband_info" not in result: + logger.warning("BOHB Info not detected in result. Are you using " + "HyperBandForBOHB as a scheduler?") + elif "budget" in result.get("hyperband_info", {}): + hbs_wrapper = self.to_wrapper(trial_id, result) + self.bohber.new_result(hbs_wrapper) + + def on_trial_complete(self, + trial_id, + result=None, + error=False, + early_terminated=False): + del self.trial_to_params[trial_id] + if trial_id in self.paused: + self.paused.remove(trial_id) + if trial_id in self.running: + self.running.remove(trial_id) + + def to_wrapper(self, trial_id, result): + return _BOHBJobWrapper(self._metric_op * result[self.metric], + result["hyperband_info"]["budget"], + self.trial_to_params[trial_id]) + + def on_pause(self, trial_id): + self.paused.add(trial_id) + self.running.remove(trial_id) + + def on_unpause(self, trial_id): + self.paused.remove(trial_id) + self.running.add(trial_id) diff --git a/python/ray/tune/tests/test_trial_scheduler.py b/python/ray/tune/tests/test_trial_scheduler.py index 2a0282c5a975..4d9eb8d072fa 100644 --- a/python/ray/tune/tests/test_trial_scheduler.py +++ b/python/ray/tune/tests/test_trial_scheduler.py @@ -15,7 +15,7 @@ from ray.tune.result import TRAINING_ITERATION from ray.tune.schedulers import (HyperBandScheduler, AsyncHyperBandScheduler, PopulationBasedTraining, MedianStoppingRule, - TrialScheduler) + TrialScheduler, HyperBandForBOHB) from ray.tune.schedulers.pbt import explore from ray.tune.trial import Trial, Checkpoint @@ -236,7 +236,7 @@ def _launch_trial(self, trial): class HyperbandSuite(unittest.TestCase): def setUp(self): - ray.init() + ray.init(object_store_memory=int(1e8)) def tearDown(self): ray.shutdown() @@ -319,17 +319,19 @@ def testConfigSameEta(self): self.assertEqual(sched._hyperbands[0][-1]._n, 81) self.assertEqual(sched._hyperbands[0][-1]._r, 1) - sched = HyperBandScheduler(max_t=810) + reduction_factor = 10 + sched = HyperBandScheduler( + max_t=1000, reduction_factor=reduction_factor) i = 0 while not sched._cur_band_filled(): t = Trial("__fake") sched.on_trial_add(None, t) i += 1 - self.assertEqual(len(sched._hyperbands[0]), 5) - self.assertEqual(sched._hyperbands[0][0]._n, 5) - self.assertEqual(sched._hyperbands[0][0]._r, 810) - self.assertEqual(sched._hyperbands[0][-1]._n, 81) - self.assertEqual(sched._hyperbands[0][-1]._r, 10) + self.assertEqual(len(sched._hyperbands[0]), 4) + self.assertEqual(sched._hyperbands[0][0]._n, 4) + self.assertEqual(sched._hyperbands[0][0]._r, 1000) + self.assertEqual(sched._hyperbands[0][-1]._n, 1000) + self.assertEqual(sched._hyperbands[0][-1]._r, 1) def testConfigSameEtaSmall(self): sched = HyperBandScheduler(max_t=1) @@ -338,8 +340,7 @@ def testConfigSameEtaSmall(self): t = Trial("__fake") sched.on_trial_add(None, t) i += 1 - 
self.assertEqual(len(sched._hyperbands[0]), 5) - self.assertTrue(all(v is None for v in sched._hyperbands[0][1:])) + self.assertEqual(len(sched._hyperbands[0]), 1) def testSuccessiveHalving(self): """Setup full band, then iterate through last bracket (n=81) @@ -367,7 +368,7 @@ def testSuccessiveHalving(self): self.assertEqual(action, TrialScheduler.CONTINUE) new_length = len(big_bracket.current_trials()) self.assertEqual(new_length, self.downscale(current_length, sched)) - cur_units += int(cur_units * sched._eta) + cur_units = int(cur_units * sched._eta) self.assertEqual(len(big_bracket.current_trials()), 1) def testHalvingStop(self): @@ -603,6 +604,76 @@ def testFilterNoneBracket(self): self.assertIsNotNone(trial) +class BOHBSuite(unittest.TestCase): + def setUp(self): + ray.init(object_store_memory=int(1e8)) + + def tearDown(self): + ray.shutdown() + _register_all() # re-register the evicted objects + + def testLargestBracketFirst(self): + sched = HyperBandForBOHB(max_t=3, reduction_factor=3) + runner = _MockTrialRunner(sched) + for i in range(3): + t = Trial("__fake") + sched.on_trial_add(runner, t) + runner._launch_trial(t) + + self.assertEqual(sched.state()["num_brackets"], 1) + sched.on_trial_add(runner, Trial("__fake")) + self.assertEqual(sched.state()["num_brackets"], 2) + + def testCheckTrialInfoUpdate(self): + def result(score, ts): + return {"episode_reward_mean": score, TRAINING_ITERATION: ts} + + sched = HyperBandForBOHB(max_t=3, reduction_factor=3) + runner = _MockTrialRunner(sched) + runner._search_alg = MagicMock() + trials = [Trial("__fake") for i in range(3)] + for t in trials: + runner.add_trial(t) + runner._launch_trial(t) + + for trial, trial_result in zip(trials, [result(1, 1), result(2, 1)]): + decision = sched.on_trial_result(runner, trial, trial_result) + self.assertEqual(decision, TrialScheduler.PAUSE) + runner._pause_trial(trial) + spy_result = result(0, 1) + decision = sched.on_trial_result(runner, trials[-1], spy_result) + self.assertEqual(decision, TrialScheduler.STOP) + sched.choose_trial_to_run(runner) + self.assertEqual(runner._search_alg.on_pause.call_count, 2) + self.assertEqual(runner._search_alg.on_unpause.call_count, 1) + self.assertTrue("hyperband_info" in spy_result) + self.assertEquals(spy_result["hyperband_info"]["budget"], 1) + + def testCheckTrialInfoUpdateMin(self): + def result(score, ts): + return {"episode_reward_mean": score, TRAINING_ITERATION: ts} + + sched = HyperBandForBOHB(max_t=3, reduction_factor=3, mode="min") + runner = _MockTrialRunner(sched) + runner._search_alg = MagicMock() + trials = [Trial("__fake") for i in range(3)] + for t in trials: + runner.add_trial(t) + runner._launch_trial(t) + + for trial, trial_result in zip(trials, [result(1, 1), result(2, 1)]): + decision = sched.on_trial_result(runner, trial, trial_result) + self.assertEqual(decision, TrialScheduler.PAUSE) + runner._pause_trial(trial) + spy_result = result(0, 1) + decision = sched.on_trial_result(runner, trials[-1], spy_result) + self.assertEqual(decision, TrialScheduler.CONTINUE) + sched.choose_trial_to_run(runner) + self.assertEqual(runner._search_alg.on_pause.call_count, 2) + self.assertTrue("hyperband_info" in spy_result) + self.assertEquals(spy_result["hyperband_info"]["budget"], 1) + + class _MockTrial(Trial): def __init__(self, i, config): self.trainable_name = "trial_{}".format(i)
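
For reference, the new ``reduction_factor`` argument controls both how many
brackets a HyperBand iteration contains and each bracket's initial allocation.
The sketch below reproduces the values asserted in ``testConfigSameEta`` above.
The exact ``_get_n0``/``_get_r0`` expressions are assumed from Ray's
``HyperBandScheduler`` (they are truncated in this diff), so treat this as an
illustration rather than a definitive reference:

.. code-block:: python

    import numpy as np

    def bracket_layout(max_t, eta):
        # Number of brackets per HyperBand iteration, as in hyperband.py:
        # s_max_1 = round(log(max_t) / log(reduction_factor)) + 1.
        s_max_1 = int(np.round(np.log(max_t) / np.log(eta))) + 1
        layout = []
        for s in range(s_max_1):
            # Bracket s starts with n0 trials, each allotted r0 units of
            # time_attr before the first halving.
            n0 = int(np.ceil(s_max_1 / (s + 1) * eta**s))
            r0 = int(max_t * eta**(-s))
            layout.append((s, n0, r0))
        return layout

    # max_t=81, eta=3  -> 5 brackets, from (n=5, r=81) down to (n=81, r=1).
    print(bracket_layout(81, 3))
    # max_t=1000, eta=10 -> 4 brackets, from (n=4, r=1000) down to (n=1000, r=1).
    print(bracket_layout(1000, 10))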