[tune] Add ASHA (promotion-based scheduler) #4401
This is already implemented in https://ray.readthedocs.io/en/latest/tune-schedulers.html#asynchronous-hyperband Should we rename it? |
Hmm I think we should have something called ASHA in the docs / doc comments for search purposes. Not sure about naming the class ASHA :) |
Ah awesome, thanks a lot! For me it would definitely have been easier to spot if it had been called AsynchronousSuccessiveHalvingScheduler, but I am not sure if this is the case for everyone else. ;-) |
OK; feel free to push a fix to our documentation or docstrings! Closing this for now. |
Hello,
|
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import logging

import numpy as np

from ray.tune.trial import Trial
from ray.tune.schedulers import (
    FIFOScheduler, TrialScheduler, AsyncHyperBandScheduler)

logger = logging.getLogger(__name__)


class ASHAv2(FIFOScheduler):
    """Implements the Async Successive Halving with better termination."""

    def __init__(self,
                 time_attr="training_iteration",
                 reward_attr=None,
                 metric="episode_reward_mean",
                 mode="max",
                 max_t=100,
                 grace_period=1,
                 reduction_factor=4,
                 brackets=1):
        assert max_t > 0, "Max (time_attr) not valid!"
        assert max_t >= grace_period, "grace_period must be <= max_t!"
        assert grace_period > 0, "grace_period must be positive!"
        assert reduction_factor > 1, "Reduction Factor not valid!"
        assert brackets > 0, "brackets must be positive!"
        assert mode in ["min", "max"], "`mode` must be 'min' or 'max'!"

        if reward_attr is not None:
            mode = "max"
            metric = reward_attr
            logger.warning(
                "`reward_attr` is deprecated and will be removed in a future "
                "version of Tune. "
                "Setting `metric={}` and `mode=max`.".format(reward_attr))

        FIFOScheduler.__init__(self)
        self._reduction_factor = reduction_factor
        self._max_t = max_t

        # Tracks state for new trial add
        self._brackets = [
            _Bracket(grace_period, max_t, reduction_factor, s)
            for s in range(brackets)
        ]
        self._counter = 0  # for
        self._num_stopped = 0
        self._metric = metric
        if mode == "max":
            self._metric_op = 1.
        elif mode == "min":
            self._metric_op = -1.
        self._time_attr = time_attr
        self._num_paused = 0

    def on_trial_result(self, trial_runner, trial, result):
        action = TrialScheduler.CONTINUE
        if self._time_attr not in result or self._metric not in result:
            return action
        if result[self._time_attr] >= self._max_t:
            action = TrialScheduler.STOP
        else:
            bracket = self._brackets[0]
            action = bracket.on_result(trial, result[self._time_attr],
                                       self._metric_op * result[self._metric])
        if action == TrialScheduler.STOP:
            self._num_stopped += 1
        if action == TrialScheduler.PAUSE:
            self._num_paused += 1
        return action

    def on_trial_complete(self, trial_runner, trial, result):
        if self._time_attr not in result or self._metric not in result:
            return
        bracket = self._brackets[0]
        bracket.on_result(trial, result[self._time_attr],
                          self._metric_op * result[self._metric],
                          complete=True)

    def choose_trial_to_run(self, trial_runner):
        for bracket in self._brackets:
            for trial in bracket.promotable_trials():
                if trial and trial_runner.has_resources(trial.resources):
                    assert trial.status == Trial.PAUSED
                    logger.warning(f"Promoting trial [{trial.config}].")
                    bracket.unpause_trial(trial)
                    return trial

        trial = FIFOScheduler.choose_trial_to_run(self, trial_runner)
        if trial:
            self._brackets[0].unpause_trial(trial)
            logger.info(f"Choosing trial {trial.config} to run from trialrunner.")
        return trial

    def debug_string(self):
        out = "Using ASHAv2: num_stopped={}".format(self._num_stopped)
        out += "\n" + "\n".join([b.debug_str() for b in self._brackets])
        return out

class _Bracket():
    """Bookkeeping system to track the cutoffs.

    Rungs are created in reversed order so that we can more easily find
    the correct rung corresponding to the current iteration of the result.

    Example:
        >>> b = _Bracket(1, 10, 2, 3)
        >>> b.on_result(trial1, 1, 2)  # CONTINUE
        >>> b.on_result(trial2, 1, 4)  # CONTINUE
        >>> b.cutoff(b._rungs[-1][1]) == 3.0  # rungs are reversed
        >>> b.on_result(trial3, 1, 1)  # STOP
        >>> b.cutoff(b._rungs[0][1]) == 2.0
    """

    def __init__(self, min_t, max_t, reduction_factor, s):
        self.rf = reduction_factor
        MAX_RUNGS = int(np.log(max_t / min_t) / np.log(self.rf) - s + 1)
        self._rungs = [(min_t * self.rf**(k + s), {}, [])
                       for k in reversed(range(MAX_RUNGS))]

    def cutoff(self, recorded):
        if len(recorded) < self.rf:
            return None
        return np.percentile(list(recorded.values()), (1 - 1 / self.rf) * 100)

    def top_k_ids(self, recorded):
        entries = list(recorded.items())
        k = int(len(entries) / self.rf)
        top_rung = sorted(entries, key=lambda kv: kv[1], reverse=True)[0:k]
        print("TOP RUNG:", top_rung)
        return [tid for tid, value in top_rung]

    def on_result(self, trial, cur_iter, cur_rew, complete=False):
        action = TrialScheduler.CONTINUE
        if cur_rew is None:
            logger.warning("Reward attribute is None! Consider"
                           " reporting using a different field.")
            return action
        for milestone, recorded, paused in self._rungs:
            if cur_iter < milestone or trial.trial_id in recorded:
                continue
            else:
                recorded[trial.trial_id] = cur_rew
                top_k_trial_ids = self.top_k_ids(recorded)
                if complete or trial.status != Trial.RUNNING:
                    break
                if trial.trial_id not in top_k_trial_ids:
                    action = TrialScheduler.PAUSE
                    paused += [trial]
                break
        if action == TrialScheduler.PAUSE:
            print(trial, cur_iter)
        return action

    def debug_str(self):
        iters = " | ".join([
            "Iter {:.3f}: {} [{} paused]".format(
                milestone, self.cutoff(recorded), len(paused))
            for milestone, recorded, paused in self._rungs
        ])
        return "Bracket: " + iters

    def promotable_trials(self):
        for _, recorded, paused in self._rungs:
            for tid in self.top_k_ids(recorded):
                paused_trials = {p.trial_id: p for p in paused}
                if tid in paused_trials:
                    yield paused_trials[tid]

    def unpause_trial(self, trial):
        for _, _, paused in self._rungs:
            if trial in paused:
                paused.pop(paused.index(trial))
            assert trial not in paused

Should be the implementation you're looking for, @mseeger? |
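To make the bracket bookkeeping in the snippet above easier to follow, here is a small standalone sketch (no Ray needed) that mirrors the arithmetic in `_Bracket.__init__` and `top_k_ids`. The helper names `rung_milestones` and `top_k` are made up for illustration, and `math.log` stands in for `np.log`; treat it as a reading aid rather than part of the scheduler.

```python
import math


def rung_milestones(min_t, max_t, reduction_factor, s=0):
    """Milestones in the same descending order that _Bracket builds its rungs."""
    max_rungs = int(
        math.log(max_t / min_t) / math.log(reduction_factor) - s + 1)
    return [min_t * reduction_factor ** (k + s)
            for k in reversed(range(max_rungs))]


def top_k(recorded, reduction_factor):
    """IDs of the best 1/reduction_factor share of a rung (higher is better)."""
    k = int(len(recorded) / reduction_factor)
    ranked = sorted(recorded.items(), key=lambda kv: kv[1], reverse=True)
    return [tid for tid, _ in ranked[:k]]


# Defaults from the ASHAv2 constructor: grace_period=1, max_t=100, factor 4.
milestones = rung_milestones(min_t=1, max_t=100, reduction_factor=4)
print(milestones)  # [64, 16, 4, 1]

# Hypothetical rewards recorded at one rung.
rung = {"a": 0.1, "b": 0.9, "c": 0.4, "d": 0.7}
print(top_k(rung, reduction_factor=4))  # ['b']

# With fewer than `reduction_factor` results the top-k list is empty.
print(top_k({"a": 0.1, "b": 0.9, "c": 0.4}, reduction_factor=4))  # []
```

One consequence worth noting: because `on_result` pauses any trial that is not in the top-k list, the first few trials to reach a rung are paused until enough results have been recorded there.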
Thanks, great.
Maybe some other naming should be found? Both variants may be useful to
have. One could be called promotion based, the other stopping based.
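A rough way to see the difference between the two variants is the standalone sketch below. It assumes a single rung and "higher is better"; `stopping_decision`, `promotion_decision`, and `promotable` are hypothetical names, not Tune APIs. The stopping rule follows the percentile cutoff idea from `_Bracket.cutoff`, and the promotion rule follows the top-k check in the `ASHAv2` snippet.

```python
def percentile(values, q):
    """Linear-interpolation percentile for q in [0, 100]."""
    ordered = sorted(values)
    pos = (len(ordered) - 1) * q / 100
    lo = int(pos)
    hi = min(lo + 1, len(ordered) - 1)
    return ordered[lo] + (ordered[hi] - ordered[lo]) * (pos - lo)


def stopping_decision(recorded, trial_id, score, rf):
    """Stopping based: a trial below the current cutoff is ended for good."""
    cutoff = None
    if len(recorded) >= rf:
        cutoff = percentile(recorded.values(), (1 - 1 / rf) * 100)
    recorded[trial_id] = score
    if cutoff is not None and score < cutoff:
        return "STOP"
    return "CONTINUE"


def top_ids(recorded, rf):
    k = len(recorded) // rf
    return sorted(recorded, key=recorded.get, reverse=True)[:k]


def promotion_decision(recorded, paused, trial_id, score, rf):
    """Promotion based: a trial outside the top 1/rf is paused, not killed."""
    recorded[trial_id] = score
    if trial_id in top_ids(recorded, rf):
        return "CONTINUE"
    paused.add(trial_id)
    return "PAUSE"


def promotable(recorded, paused, rf):
    """Paused trials that have since moved into the top 1/rf of the rung."""
    return [tid for tid in top_ids(recorded, rf) if tid in paused]


# Toy scores arriving at the same rung, in this order.
arrivals = [("t1", 0.5), ("t2", 0.9), ("t3", 0.2), ("t4", 0.3)]

stop_rec = {}
stop_actions = [stopping_decision(stop_rec, t, s, rf=2) for t, s in arrivals]
print(stop_actions)  # ['CONTINUE', 'CONTINUE', 'STOP', 'STOP']

promo_rec, paused = {}, set()
promo_actions = [promotion_decision(promo_rec, paused, t, s, rf=2)
                 for t, s in arrivals]
print(promo_actions)  # ['PAUSE', 'CONTINUE', 'PAUSE', 'PAUSE']
print(promotable(promo_rec, paused, rf=2))  # ['t1']
```

In this toy run the stopping variant lets early trials through before any cutoff exists, while the promotion variant holds `t1` back and only resumes it once enough results show it belongs in the top half.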
Richard Liaw <[email protected]> wrote on Thu, 26 Sep 2019, 00:41:
|
Hi @richardliaw. Did you have some observations that |
It does not perform better, though we can add it to master. I'll reopen this issue in case anyone is interested. |
This probably can be closed. |
Is support for the following items that @mseeger mentioned available in 2.2.0 or any other version?
I am particularly interested in a trial being preempted when it reaches a milestone. This allows the comparison of all trials at the same milestone before making a decision about which configs should be promoted. |
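One possible reading of this request, sketched as standalone Python: every trial is paused at a milestone, and only once all of them have reported there is the top `1 / reduction_factor` share promoted to the next milestone. The names `promote_at_milestone`, `run_synchronised`, and the toy `evaluate` function are assumptions for illustration; this does not describe what any Ray version does.

```python
def promote_at_milestone(scores, reduction_factor):
    """Pick survivors once every trial has reported at the same milestone."""
    keep = max(1, len(scores) // reduction_factor)
    ranked = sorted(scores, key=scores.get, reverse=True)
    return ranked[:keep]


def run_synchronised(trials, milestones, reduction_factor, evaluate):
    """Pause all trials at each milestone, compare them, promote the best."""
    survivors = list(trials)
    history = []
    for milestone in milestones:
        # Every surviving trial is trained up to `milestone`, then preempted.
        scores = {t: evaluate(t, milestone) for t in survivors}
        survivors = promote_at_milestone(scores, reduction_factor)
        history.append((milestone, list(survivors)))
    return survivors, history


# Toy setup: each trial has a hypothetical asymptotic quality, and its score
# approaches that value as the number of training iterations grows.
quality = {"a": 0.1, "b": 0.2, "c": 0.3, "d": 0.4,
           "e": 0.5, "f": 0.6, "g": 0.7, "h": 0.8}


def evaluate(trial_id, iterations):
    return quality[trial_id] * iterations / (iterations + 1)


winners, history = run_synchronised(
    quality, milestones=[1, 4, 16], reduction_factor=2, evaluate=evaluate)
print(history)  # [(1, ['h', 'g', 'f', 'e']), (4, ['h', 'g']), (16, ['h'])]
print(winners)  # ['h']
```

The trade-off compared with the asynchronous snippets earlier in the thread is that workers sit idle until the slowest trial reaches the milestone, in exchange for decisions made on a complete set of results.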
@MahdiNazemi Can you make a new issue? |
I read the title of this paper from CMU and immediately had to think of Ray.
I am not sure how hard this would be for you guys to add to Tune, but at a very high level it seems like a rather simple and intuitive algorithm that is apparently competitive with PBT.
Here’s a blog post explaining the idea behind ASHA (Asynchronous Successive Halving Algorithm):
https://blog.ml.cmu.edu/2018/12/12/massively-parallel-hyperparameter-optimization/
Would be awesome if this could be added to tune. :)