diff --git a/python/ray/tune/schedulers/async_hyperband.py b/python/ray/tune/schedulers/async_hyperband.py index 375ecf73037f..8272a22a2174 100644 --- a/python/ray/tune/schedulers/async_hyperband.py +++ b/python/ray/tune/schedulers/async_hyperband.py @@ -40,6 +40,8 @@ class AsyncHyperBandScheduler(FIFOScheduler): is simply a unit-less scalar. brackets: Number of brackets. Each bracket has a different halving rate, specified by the reduction factor. + stop_last_trials: Whether to terminate the trials after + reaching max_t. Defaults to True. """ def __init__( @@ -51,6 +53,7 @@ def __init__( grace_period: int = 1, reduction_factor: float = 4, brackets: int = 1, + stop_last_trials: bool = True, ): assert max_t > 0, "Max (time_attr) not valid!" assert max_t >= grace_period, "grace_period must be <= max_t!" @@ -68,7 +71,14 @@ def __init__( # Tracks state for new trial add self._brackets = [ - _Bracket(grace_period, max_t, reduction_factor, s) for s in range(brackets) + _Bracket( + grace_period, + max_t, + reduction_factor, + s, + stop_last_trials=stop_last_trials, + ) + for s in range(brackets) ] self._counter = 0 # for self._num_stopped = 0 @@ -80,6 +90,7 @@ def __init__( elif self._mode == "min": self._metric_op = -1.0 self._time_attr = time_attr + self._stop_last_trials = stop_last_trials def set_search_properties( self, metric: Optional[str], mode: Optional[str], **spec @@ -128,7 +139,7 @@ def on_trial_result( action = TrialScheduler.CONTINUE if self._time_attr not in result or self._metric not in result: return action - if result[self._time_attr] >= self._max_t: + if result[self._time_attr] >= self._max_t and self._stop_last_trials: action = TrialScheduler.STOP else: bracket = self._trial_info[trial.trial_id] @@ -189,12 +200,20 @@ class _Bracket: >>> b.cutoff(b._rungs[3][1]) == 2.0 # doctest: +SKIP """ - def __init__(self, min_t: int, max_t: int, reduction_factor: float, s: int): + def __init__( + self, + min_t: int, + max_t: int, + reduction_factor: float, + s: int, + stop_last_trials: bool = True, + ): self.rf = reduction_factor MAX_RUNGS = int(np.log(max_t / min_t) / np.log(self.rf) - s + 1) self._rungs = [ (min_t * self.rf ** (k + s), {}) for k in reversed(range(MAX_RUNGS)) ] + self._stop_last_trials = stop_last_trials def cutoff(self, recorded) -> Optional[Union[int, float, complex, np.ndarray]]: if not recorded: @@ -204,6 +223,16 @@ def cutoff(self, recorded) -> Optional[Union[int, float, complex, np.ndarray]]: def on_result(self, trial: Trial, cur_iter: int, cur_rew: Optional[float]) -> str: action = TrialScheduler.CONTINUE for milestone, recorded in self._rungs: + if ( + cur_iter >= milestone + and trial.trial_id in recorded + and not self._stop_last_trials + ): + # If our result has been recorded for this trial already, the + # decision to continue training has already been made. Thus we can + # skip new cutoff calculation and just continue training. + # We can also break as milestones are descending. + break if cur_iter < milestone or trial.trial_id in recorded: continue else: diff --git a/python/ray/tune/tests/test_trial_scheduler.py b/python/ray/tune/tests/test_trial_scheduler.py index 5cb7eb94c06b..df3b497626d4 100644 --- a/python/ray/tune/tests/test_trial_scheduler.py +++ b/python/ray/tune/tests/test_trial_scheduler.py @@ -2150,6 +2150,57 @@ def testAsyncHBSaveRestore(self): TrialScheduler.STOP, ) + def testAsyncHBNonStopTrials(self): + trials = [Trial("PPO") for i in range(4)] + scheduler = AsyncHyperBandScheduler( + metric="metric", + mode="max", + grace_period=1, + max_t=3, + reduction_factor=2, + brackets=1, + stop_last_trials=False, + ) + scheduler.on_trial_add(None, trials[0]) + scheduler.on_trial_add(None, trials[1]) + scheduler.on_trial_add(None, trials[2]) + scheduler.on_trial_add(None, trials[3]) + + # Report one result + action = scheduler.on_trial_result( + None, trials[0], {"training_iteration": 2, "metric": 10} + ) + assert action == TrialScheduler.CONTINUE + action = scheduler.on_trial_result( + None, trials[1], {"training_iteration": 2, "metric": 8} + ) + assert action == TrialScheduler.STOP + action = scheduler.on_trial_result( + None, trials[2], {"training_iteration": 2, "metric": 6} + ) + assert action == TrialScheduler.STOP + action = scheduler.on_trial_result( + None, trials[3], {"training_iteration": 2, "metric": 4} + ) + assert action == TrialScheduler.STOP + + # Report more. This will fail if `stop_last_trials=True` + action = scheduler.on_trial_result( + None, trials[0], {"training_iteration": 4, "metric": 10} + ) + assert action == TrialScheduler.CONTINUE + + action = scheduler.on_trial_result( + None, trials[0], {"training_iteration": 8, "metric": 10} + ) + assert action == TrialScheduler.CONTINUE + + # Also continue if we fall below the cutoff eventually + action = scheduler.on_trial_result( + None, trials[0], {"training_iteration": 14, "metric": 1} + ) + assert action == TrialScheduler.CONTINUE + def testMedianStoppingNanInf(self): scheduler = MedianStoppingRule(metric="episode_reward_mean", mode="max")