Skip to content

Commit

Permalink
addressed PR comments (ray-project#4120)
Browse files Browse the repository at this point in the history
  • Loading branch information
hershg committed May 27, 2019
1 parent b0910ee commit df94ade
Show file tree
Hide file tree
Showing 5 changed files with 127 additions and 4 deletions.
10 changes: 9 additions & 1 deletion python/ray/tune/schedulers/async_hyperband.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ class AsyncHyperBandScheduler(FIFOScheduler):
with `time_attr`, this may refer to any objective value. Stopping
procedures will use this attribute.
mode (str): One of {min, max}. Determines whether objective is minimizing
maximizing the metric attribute
or maximizing the metric attribute
max_t (float): max time units per trial. Trials will be stopped after
max_t time units (determined by time_attr) have passed.
grace_period (float): Only stop trials at least this old in time.
Expand All @@ -42,6 +42,7 @@ class AsyncHyperBandScheduler(FIFOScheduler):

def __init__(self,
time_attr="training_iteration",
reward_attr=None,
metric="episode_reward_mean",
mode="max",
max_t=100,
Expand All @@ -54,6 +55,13 @@ def __init__(self,
assert reduction_factor > 1, "Reduction Factor not valid!"
assert brackets > 0, "brackets must be positive!"
assert mode in ["min", "max"], "mode must be 'min' or 'max'!"

if reward_attr is not None:
mode = "max"
metric = reward_attr
logger.warning("`reward_attr` will be depreciated!"
"Consider using `metric` and `mode`.")

FIFOScheduler.__init__(self)
self._reduction_factor = reduction_factor
self._max_t = max_t
Expand Down
10 changes: 9 additions & 1 deletion python/ray/tune/schedulers/hyperband.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ class HyperBandScheduler(FIFOScheduler):
with `time_attr`, this may refer to any objective value. Stopping
procedures will use this attribute.
mode (str): One of {min, max}. Determines whether objective is minimizing
maximizing the metric attribute
or maximizing the metric attribute
max_t (int): max time units per trial. Trials will be stopped after
max_t time units (determined by time_attr) have passed.
The scheduler will terminate trials after this time has passed.
Expand All @@ -77,11 +77,19 @@ class HyperBandScheduler(FIFOScheduler):

def __init__(self,
time_attr="training_iteration",
reward_attr=None,
metric="episode_reward_mean",
mode="max",
max_t=81):
assert max_t > 0, "Max (time_attr) not valid!"
assert mode in ["min", "max"], "mode must be 'min' or 'max'!"

if reward_attr is not None:
mode = "max"
metric = reward_attr
logger.warning("`reward_attr` will be depreciated!"
"Consider using `metric` and `mode`.")

FIFOScheduler.__init__(self)
self._eta = 3
self._s_max_1 = 5
Expand Down
11 changes: 10 additions & 1 deletion python/ray/tune/schedulers/median_stopping_rule.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ class MedianStoppingRule(FIFOScheduler):
with `time_attr`, this may refer to any objective value. Stopping
procedures will use this attribute.
mode (str): One of {min, max}. Determines whether objective is minimizing
maximizing the metric attribute
or maximizing the metric attribute
grace_period (float): Only stop trials at least this old in time.
The units are the same as the attribute named by `time_attr`.
min_samples_required (int): Min samples to compute median over.
Expand All @@ -39,12 +39,21 @@ class MedianStoppingRule(FIFOScheduler):

def __init__(self,
time_attr="time_total_s",
reward_attr=None,
metric="episode_reward_mean",
mode="max",
grace_period=60.0,
min_samples_required=3,
hard_stop=True,
verbose=True):
assert mode in ["min", "max"], "mode must be 'min' or 'max'!"

if reward_attr is not None:
mode = "max"
metric = reward_attr
logger.warning("`reward_attr` will be depreciated!"
"Consider using `metric` and `mode`.")

FIFOScheduler.__init__(self)
self._stopped_trials = set()
self._completed_trials = set()
Expand Down
12 changes: 11 additions & 1 deletion python/ray/tune/schedulers/pbt.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ class PopulationBasedTraining(FIFOScheduler):
with `time_attr`, this may refer to any objective value. Stopping
procedures will use this attribute.
mode (str): One of {min, max}. Determines whether objective is minimizing
maximizing the metric attribute
or maximizing the metric attribute
perturbation_interval (float): Models will be considered for
perturbation at this interval of `time_attr`. Note that
perturbation incurs checkpoint overhead, so you shouldn't set this
Expand Down Expand Up @@ -168,6 +168,7 @@ class PopulationBasedTraining(FIFOScheduler):

def __init__(self,
time_attr="time_total_s",
reward_attr=None,
metric="episode_reward_mean",
mode="max",
perturbation_interval=60.0,
Expand All @@ -179,6 +180,15 @@ def __init__(self,
raise TuneError(
"You must specify at least one of `hyperparam_mutations` or "
"`custom_explore_fn` to use PBT.")

assert mode in ["min", "max"], "mode must be 'min' or 'max'!"

if reward_attr is not None:
mode = "max"
metric = reward_attr
logger.warning("`reward_attr` will be depreciated!"
"Consider using `metric` and `mode`.")

FIFOScheduler.__init__(self)
self._metric = metric
self._time_attr = time_attr
Expand Down
88 changes: 88 additions & 0 deletions python/ray/tune/tests/test_trial_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,34 @@ def result2(t, rew):
rule.on_trial_result(None, t2, result2(6, 0)),
TrialScheduler.CONTINUE)

def testAlternateMetricsMin(self):
    """Min-mode counterpart of testAlternateMetrics.

    Reports ``mean_loss = -reward`` so that minimizing the configured
    metric is equivalent to maximizing the reward, and the same
    CONTINUE expectations as the max-mode test apply.
    """

    def result2(t, rew):
        # Fix: the result dict must contain the configured metric
        # ("mean_loss") — reporting only "neg_mean_loss" would leave the
        # rule unable to find its metric. Negating the reward makes
        # mode="min" behave like the max-mode test.
        return dict(training_iteration=t, mean_loss=-rew)

    rule = MedianStoppingRule(
        grace_period=0,
        min_samples_required=1,
        time_attr="training_iteration",
        metric="mean_loss",
        mode="min")
    t1 = Trial("PPO")  # reward mean is 450, max 900, t_max=10
    t2 = Trial("PPO")  # reward mean is 450, max 450, t_max=5
    for i in range(10):
        self.assertEqual(
            rule.on_trial_result(None, t1, result2(i, i * 100)),
            TrialScheduler.CONTINUE)
    for i in range(5):
        self.assertEqual(
            rule.on_trial_result(None, t2, result2(i, 450)),
            TrialScheduler.CONTINUE)
    rule.on_trial_complete(None, t1, result2(10, 1000))
    self.assertEqual(
        rule.on_trial_result(None, t2, result2(5, 450)),
        TrialScheduler.CONTINUE)
    self.assertEqual(
        rule.on_trial_result(None, t2, result2(6, 0)),
        TrialScheduler.CONTINUE)


class _MockTrialExecutor(TrialExecutor):
def start_trial(self, trial, checkpoint_obj=None):
Expand Down Expand Up @@ -526,6 +554,36 @@ def result2(t, rew):
self.assertEqual(action, TrialScheduler.CONTINUE)
self.assertEqual(new_length, self.downscale(current_length, sched))

def testAlternateMetricsMin(self):
    """Checking that alternate metrics will pass (min mode).

    Reports ``mean_loss = -reward`` so that minimizing the configured
    metric mirrors the max-mode alternate-metrics test.
    """

    def result2(t, rew):
        # Fix: the result dict must contain the configured metric
        # ("mean_loss") — reporting only "neg_mean_loss" would leave the
        # scheduler unable to find its metric. Negating the reward makes
        # mode="min" equivalent to maximizing the reward.
        return dict(time_total_s=t, mean_loss=-rew)

    sched = HyperBandScheduler(
        time_attr="time_total_s", metric="mean_loss", mode="min")
    stats = self.default_statistics()

    for i in range(stats["max_trials"]):
        t = Trial("__fake")
        sched.on_trial_add(None, t)
    runner = _MockTrialRunner(sched)

    big_bracket = sched._hyperbands[0][-1]

    for trl in big_bracket.current_trials():
        runner._launch_trial(trl)
    current_length = len(big_bracket.current_trials())

    # Provides results from 0 to 8 in order, keeping the last one running
    for i, trl in enumerate(big_bracket.current_trials()):
        action = sched.on_trial_result(runner, trl, result2(1, i))
        runner.process_action(trl, action)

    new_length = len(big_bracket.current_trials())
    self.assertEqual(action, TrialScheduler.CONTINUE)
    self.assertEqual(new_length, self.downscale(current_length, sched))

def testJumpingTime(self):
sched, mock_runner = self.schedulerSetup(81)
big_bracket = sched._hyperbands[0][-1]
Expand Down Expand Up @@ -1046,6 +1104,36 @@ def result2(t, rew):
scheduler.on_trial_result(None, t2, result2(6, 0)),
TrialScheduler.CONTINUE)

def testAlternateMetricsMin(self):
    """Min-mode counterpart of the async-hyperband alternate-metrics test.

    Reports ``mean_loss = -reward`` so that minimizing the configured
    metric is equivalent to maximizing the reward, and the same
    CONTINUE expectations as the max-mode test apply.
    """

    def result2(t, rew):
        # Fix: the result dict must contain the configured metric
        # ("mean_loss") — reporting only "neg_mean_loss" would leave the
        # scheduler unable to find its metric. Negating the reward makes
        # mode="min" behave like the max-mode test.
        return dict(training_iteration=t, mean_loss=-rew)

    scheduler = AsyncHyperBandScheduler(
        grace_period=1,
        time_attr="training_iteration",
        metric="mean_loss",
        mode="min",
        brackets=1)
    t1 = Trial("PPO")  # reward mean is 450, max 900, t_max=10
    t2 = Trial("PPO")  # reward mean is 450, max 450, t_max=5
    scheduler.on_trial_add(None, t1)
    scheduler.on_trial_add(None, t2)
    for i in range(10):
        self.assertEqual(
            scheduler.on_trial_result(None, t1, result2(i, i * 100)),
            TrialScheduler.CONTINUE)
    for i in range(5):
        self.assertEqual(
            scheduler.on_trial_result(None, t2, result2(i, 450)),
            TrialScheduler.CONTINUE)
    scheduler.on_trial_complete(None, t1, result2(10, 1000))
    self.assertEqual(
        scheduler.on_trial_result(None, t2, result2(5, 450)),
        TrialScheduler.CONTINUE)
    self.assertEqual(
        scheduler.on_trial_result(None, t2, result2(6, 0)),
        TrialScheduler.CONTINUE)


# Allow running this test module directly (verbose per-test output).
if __name__ == "__main__":
    unittest.main(verbosity=2)

0 comments on commit df94ade

Please sign in to comment.