Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[tune] Enable AsyncHyperband to continue training for last trials after max_t #24222

Merged
merged 2 commits into from
Apr 27, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 31 additions & 3 deletions python/ray/tune/schedulers/async_hyperband.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ class AsyncHyperBandScheduler(FIFOScheduler):
is simply a unit-less scalar.
brackets: Number of brackets. Each bracket has a different
halving rate, specified by the reduction factor.
stop_last_trials: Whether to terminate the trials after
reaching max_t. Defaults to True.
"""

def __init__(
Expand All @@ -51,6 +53,7 @@ def __init__(
grace_period: int = 1,
reduction_factor: float = 4,
brackets: int = 1,
stop_last_trials: bool = True,
):
assert max_t > 0, "Max (time_attr) not valid!"
assert max_t >= grace_period, "grace_period must be <= max_t!"
Expand All @@ -68,7 +71,14 @@ def __init__(

# Tracks state for new trial add
self._brackets = [
_Bracket(grace_period, max_t, reduction_factor, s) for s in range(brackets)
_Bracket(
grace_period,
max_t,
reduction_factor,
s,
stop_last_trials=stop_last_trials,
)
for s in range(brackets)
]
self._counter = 0 # for
self._num_stopped = 0
Expand All @@ -80,6 +90,7 @@ def __init__(
elif self._mode == "min":
self._metric_op = -1.0
self._time_attr = time_attr
self._stop_last_trials = stop_last_trials

def set_search_properties(
self, metric: Optional[str], mode: Optional[str], **spec
Expand Down Expand Up @@ -128,7 +139,7 @@ def on_trial_result(
action = TrialScheduler.CONTINUE
if self._time_attr not in result or self._metric not in result:
return action
if result[self._time_attr] >= self._max_t:
if result[self._time_attr] >= self._max_t and self._stop_last_trials:
action = TrialScheduler.STOP
else:
bracket = self._trial_info[trial.trial_id]
Expand Down Expand Up @@ -189,12 +200,20 @@ class _Bracket:
>>> b.cutoff(b._rungs[3][1]) == 2.0 # doctest: +SKIP
"""

def __init__(self, min_t: int, max_t: int, reduction_factor: float, s: int):
def __init__(
    self,
    min_t: int,
    max_t: int,
    reduction_factor: float,
    s: int,
    stop_last_trials: bool = True,
):
    """Set up the rung milestones for one ASHA bracket.

    Args:
        min_t: Smallest milestone (grace period) of the bracket.
        max_t: Largest milestone a trial can reach.
        reduction_factor: Halving rate between successive rungs.
        s: Bracket index; shifts every rung milestone upward.
        stop_last_trials: Whether trials should be terminated once they
            pass the final milestone.
    """
    self.rf = reduction_factor
    # How many rungs fit between min_t and max_t at this halving rate,
    # shifted by the bracket index s.
    num_rungs = int(np.log(max_t / min_t) / np.log(self.rf) - s + 1)
    # Rungs are kept from the largest milestone down to the smallest;
    # each rung pairs its milestone with a dict of recorded results.
    self._rungs = []
    for k in range(num_rungs - 1, -1, -1):
        self._rungs.append((min_t * self.rf ** (k + s), {}))
    self._stop_last_trials = stop_last_trials

def cutoff(self, recorded) -> Optional[Union[int, float, complex, np.ndarray]]:
if not recorded:
Expand All @@ -204,6 +223,15 @@ def cutoff(self, recorded) -> Optional[Union[int, float, complex, np.ndarray]]:
def on_result(self, trial: Trial, cur_iter: int, cur_rew: Optional[float]) -> str:
action = TrialScheduler.CONTINUE
for milestone, recorded in self._rungs:
if (
cur_iter >= milestone
and trial.trial_id in recorded
and not self._stop_last_trials
):
# If our result has been recorded for this trial already, the
# decision to continue has already been made. Thus we can skip new
# cutoff calculation and just continue.
break
Yard1 marked this conversation as resolved.
Show resolved Hide resolved
if cur_iter < milestone or trial.trial_id in recorded:
continue
else:
Expand Down
51 changes: 51 additions & 0 deletions python/ray/tune/tests/test_trial_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -2150,6 +2150,57 @@ def testAsyncHBSaveRestore(self):
TrialScheduler.STOP,
)

def testAsyncHBNonStopTrials(self):
    """Trials keep running past ``max_t`` when ``stop_last_trials=False``.

    Four trials report at iteration 2; only the best one survives the
    cutoff. The survivor then reports well past ``max_t`` (3) and must
    still receive CONTINUE instead of being stopped.
    """
    trials = [Trial("PPO") for _ in range(4)]
    scheduler = AsyncHyperBandScheduler(
        metric="metric",
        mode="max",
        grace_period=1,
        max_t=3,
        reduction_factor=2,
        brackets=1,
        stop_last_trials=False,
    )
    for trial in trials:
        scheduler.on_trial_add(None, trial)

    # Report one result per trial: the best trial continues, the rest
    # fall below the cutoff and are stopped.
    first_round = [
        (trials[0], 10, TrialScheduler.CONTINUE),
        (trials[1], 8, TrialScheduler.STOP),
        (trials[2], 6, TrialScheduler.STOP),
        (trials[3], 4, TrialScheduler.STOP),
    ]
    for trial, metric, expected in first_round:
        action = scheduler.on_trial_result(
            None, trial, {"training_iteration": 2, "metric": metric}
        )
        assert action == expected

    # Report past max_t. This would fail if `stop_last_trials=True`.
    # The final report also continues even though the metric falls
    # below the cutoff eventually.
    for iteration, metric in ((4, 10), (8, 10), (14, 1)):
        action = scheduler.on_trial_result(
            None, trials[0], {"training_iteration": iteration, "metric": metric}
        )
        assert action == TrialScheduler.CONTINUE

def testMedianStoppingNanInf(self):
scheduler = MedianStoppingRule(metric="episode_reward_mean", mode="max")

Expand Down