diff --git a/python/ray/tune/schedulers/async_hyperband.py b/python/ray/tune/schedulers/async_hyperband.py
index 6c056934ac60..cd785af5beb8 100644
--- a/python/ray/tune/schedulers/async_hyperband.py
+++ b/python/ray/tune/schedulers/async_hyperband.py
@@ -29,7 +29,7 @@ class AsyncHyperBandScheduler(FIFOScheduler):
             with `time_attr`, this may refer to any objective value. Stopping
             procedures will use this attribute.
         mode (str): One of {min, max}. Determines whether objective is minimizing
-            maximizing the metric attribute
+            or maximizing the metric attribute.
         max_t (float): max time units per trial. Trials will be stopped after
             max_t time units (determined by time_attr) have passed.
         grace_period (float): Only stop trials at least this old in time.
@@ -42,6 +42,7 @@ class AsyncHyperBandScheduler(FIFOScheduler):

     def __init__(self,
                  time_attr="training_iteration",
+                 reward_attr=None,
                  metric="episode_reward_mean",
                  mode="max",
                  max_t=100,
@@ -54,6 +55,13 @@ def __init__(self,
         assert reduction_factor > 1, "Reduction Factor not valid!"
         assert brackets > 0, "brackets must be positive!"
         assert mode in ["min", "max"], "mode must be 'min' or 'max'!"
+
+        if reward_attr is not None:
+            mode = "max"
+            metric = reward_attr
+            logger.warning("`reward_attr` will be deprecated! "
+                           "Consider using `metric` and `mode`.")
+
         FIFOScheduler.__init__(self)
         self._reduction_factor = reduction_factor
         self._max_t = max_t
diff --git a/python/ray/tune/schedulers/hyperband.py b/python/ray/tune/schedulers/hyperband.py
index 25348828f8fa..7d49142e2d25 100644
--- a/python/ray/tune/schedulers/hyperband.py
+++ b/python/ray/tune/schedulers/hyperband.py
@@ -67,7 +67,7 @@ class HyperBandScheduler(FIFOScheduler):
             with `time_attr`, this may refer to any objective value. Stopping
             procedures will use this attribute.
         mode (str): One of {min, max}. Determines whether objective is minimizing
-            maximizing the metric attribute
+            or maximizing the metric attribute.
         max_t (int): max time units per trial. Trials will be stopped after
             max_t time units (determined by time_attr) have passed.
             The scheduler will terminate trials after this time has passed.
@@ -77,11 +77,19 @@ class HyperBandScheduler(FIFOScheduler):

     def __init__(self,
                  time_attr="training_iteration",
+                 reward_attr=None,
                  metric="episode_reward_mean",
                  mode="max",
                  max_t=81):
         assert max_t > 0, "Max (time_attr) not valid!"
         assert mode in ["min", "max"], "mode must be 'min' or 'max'!"
+
+        if reward_attr is not None:
+            mode = "max"
+            metric = reward_attr
+            logger.warning("`reward_attr` will be deprecated! "
+                           "Consider using `metric` and `mode`.")
+
         FIFOScheduler.__init__(self)
         self._eta = 3
         self._s_max_1 = 5
diff --git a/python/ray/tune/schedulers/median_stopping_rule.py b/python/ray/tune/schedulers/median_stopping_rule.py
index 93cf5de6079c..5a83cf04f12e 100644
--- a/python/ray/tune/schedulers/median_stopping_rule.py
+++ b/python/ray/tune/schedulers/median_stopping_rule.py
@@ -26,7 +26,7 @@ class MedianStoppingRule(FIFOScheduler):
             with `time_attr`, this may refer to any objective value. Stopping
             procedures will use this attribute.
         mode (str): One of {min, max}. Determines whether objective is minimizing
-            maximizing the metric attribute
+            or maximizing the metric attribute.
         grace_period (float): Only stop trials at least this old in time.
             The units are the same as the attribute named by `time_attr`.
         min_samples_required (int): Min samples to compute median over.
@@ -39,12 +39,21 @@ class MedianStoppingRule(FIFOScheduler):

     def __init__(self,
                  time_attr="time_total_s",
+                 reward_attr=None,
                  metric="episode_reward_mean",
                  mode="max",
                  grace_period=60.0,
                  min_samples_required=3,
                  hard_stop=True,
                  verbose=True):
+        assert mode in ["min", "max"], "mode must be 'min' or 'max'!"
+
+        if reward_attr is not None:
+            mode = "max"
+            metric = reward_attr
+            logger.warning("`reward_attr` will be deprecated! "
+                           "Consider using `metric` and `mode`.")
+
         FIFOScheduler.__init__(self)
         self._stopped_trials = set()
         self._completed_trials = set()
diff --git a/python/ray/tune/schedulers/pbt.py b/python/ray/tune/schedulers/pbt.py
index cb094d21d1e2..c91b5e96e877 100644
--- a/python/ray/tune/schedulers/pbt.py
+++ b/python/ray/tune/schedulers/pbt.py
@@ -124,7 +124,7 @@ class PopulationBasedTraining(FIFOScheduler):
             with `time_attr`, this may refer to any objective value. Stopping
             procedures will use this attribute.
         mode (str): One of {min, max}. Determines whether objective is minimizing
-            maximizing the metric attribute
+            or maximizing the metric attribute.
         perturbation_interval (float): Models will be considered for
             perturbation at this interval of `time_attr`. Note that
             perturbation incurs checkpoint overhead, so you shouldn't set this
@@ -168,6 +168,7 @@ class PopulationBasedTraining(FIFOScheduler):

     def __init__(self,
                  time_attr="time_total_s",
+                 reward_attr=None,
                  metric="episode_reward_mean",
                  mode="max",
                  perturbation_interval=60.0,
@@ -179,6 +180,15 @@ def __init__(self,
             raise TuneError(
                 "You must specify at least one of `hyperparam_mutations` or "
                 "`custom_explore_fn` to use PBT.")
+
+        assert mode in ["min", "max"], "mode must be 'min' or 'max'!"
+
+        if reward_attr is not None:
+            mode = "max"
+            metric = reward_attr
+            logger.warning("`reward_attr` will be deprecated! "
+                           "Consider using `metric` and `mode`.")
+
         FIFOScheduler.__init__(self)
         self._metric = metric
         self._time_attr = time_attr
diff --git a/python/ray/tune/tests/test_trial_scheduler.py b/python/ray/tune/tests/test_trial_scheduler.py
index 553477c150b2..929eedf4c6c4 100644
--- a/python/ray/tune/tests/test_trial_scheduler.py
+++ b/python/ray/tune/tests/test_trial_scheduler.py
@@ -163,6 +163,34 @@ def result2(t, rew):
                 rule.on_trial_result(None, t2, result2(6, 0)),
                 TrialScheduler.CONTINUE)

+    def testAlternateMetricsMin(self):
+        def result2(t, rew):
+            return dict(training_iteration=t, mean_loss=-rew)
+
+        rule = MedianStoppingRule(
+            grace_period=0,
+            min_samples_required=1,
+            time_attr="training_iteration",
+            metric="mean_loss",
+            mode="min")
+        t1 = Trial("PPO")  # mean is 450, max 900, t_max=10
+        t2 = Trial("PPO")  # mean is 450, max 450, t_max=5
+        for i in range(10):
+            self.assertEqual(
+                rule.on_trial_result(None, t1, result2(i, i * 100)),
+                TrialScheduler.CONTINUE)
+        for i in range(5):
+            self.assertEqual(
+                rule.on_trial_result(None, t2, result2(i, 450)),
+                TrialScheduler.CONTINUE)
+        rule.on_trial_complete(None, t1, result2(10, 1000))
+        self.assertEqual(
+            rule.on_trial_result(None, t2, result2(5, 450)),
+            TrialScheduler.CONTINUE)
+        self.assertEqual(
+            rule.on_trial_result(None, t2, result2(6, 0)),
+            TrialScheduler.CONTINUE)
+

 class _MockTrialExecutor(TrialExecutor):
     def start_trial(self, trial, checkpoint_obj=None):
@@ -526,6 +554,36 @@ def result2(t, rew):
         self.assertEqual(action, TrialScheduler.CONTINUE)
         self.assertEqual(new_length, self.downscale(current_length, sched))

+    def testAlternateMetricsMin(self):
+        """Checking that alternate metrics will pass."""
+
+        def result2(t, rew):
+            return dict(time_total_s=t, mean_loss=-rew)
+
+        sched = HyperBandScheduler(
+            time_attr="time_total_s", metric="mean_loss", mode="min")
+        stats = self.default_statistics()
+
+        for i in range(stats["max_trials"]):
+            t = Trial("__fake")
+            sched.on_trial_add(None, t)
+        runner = _MockTrialRunner(sched)
+
+        big_bracket = sched._hyperbands[0][-1]
+
+        for trl in big_bracket.current_trials():
+            runner._launch_trial(trl)
+        current_length = len(big_bracket.current_trials())
+
+        # Provides results from 0 to 8 in order, keeping the last one running
+        for i, trl in enumerate(big_bracket.current_trials()):
+            action = sched.on_trial_result(runner, trl, result2(1, i))
+            runner.process_action(trl, action)
+
+        new_length = len(big_bracket.current_trials())
+        self.assertEqual(action, TrialScheduler.CONTINUE)
+        self.assertEqual(new_length, self.downscale(current_length, sched))
+
     def testJumpingTime(self):
         sched, mock_runner = self.schedulerSetup(81)
         big_bracket = sched._hyperbands[0][-1]
@@ -1046,6 +1104,36 @@ def result2(t, rew):
             scheduler.on_trial_result(None, t2, result2(6, 0)),
             TrialScheduler.CONTINUE)

+    def testAlternateMetricsMin(self):
+        def result2(t, rew):
+            return dict(training_iteration=t, mean_loss=-rew)
+
+        scheduler = AsyncHyperBandScheduler(
+            grace_period=1,
+            time_attr="training_iteration",
+            metric="mean_loss",
+            mode="min",
+            brackets=1)
+        t1 = Trial("PPO")  # mean is 450, max 900, t_max=10
+        t2 = Trial("PPO")  # mean is 450, max 450, t_max=5
+        scheduler.on_trial_add(None, t1)
+        scheduler.on_trial_add(None, t2)
+        for i in range(10):
+            self.assertEqual(
+                scheduler.on_trial_result(None, t1, result2(i, i * 100)),
+                TrialScheduler.CONTINUE)
+        for i in range(5):
+            self.assertEqual(
+                scheduler.on_trial_result(None, t2, result2(i, 450)),
+                TrialScheduler.CONTINUE)
+        scheduler.on_trial_complete(None, t1, result2(10, 1000))
+        self.assertEqual(
+            scheduler.on_trial_result(None, t2, result2(5, 450)),
+            TrialScheduler.CONTINUE)
+        self.assertEqual(
+            scheduler.on_trial_result(None, t2, result2(6, 0)),
+            TrialScheduler.CONTINUE)
+

 if __name__ == "__main__":
     unittest.main(verbosity=2)
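
Two usage notes on this change.

First, a minimal sketch of the call-site migration the deprecation shim enables. The scheduler and argument names below are taken from the diff; the metric keys are only illustrative result fields:

    from ray.tune.schedulers import AsyncHyperBandScheduler

    # Old form: still accepted, but now logs a deprecation warning; the
    # shim rewrites it to metric="neg_mean_loss" and forces mode="max".
    legacy = AsyncHyperBandScheduler(reward_attr="neg_mean_loss")

    # Equivalent explicit form using the new arguments.
    maximize = AsyncHyperBandScheduler(metric="neg_mean_loss", mode="max")

    # What `reward_attr` could not express: minimizing a raw loss without
    # negating it inside the reported results.
    minimize = AsyncHyperBandScheduler(metric="mean_loss", mode="min")

Second, the `testAlternateMetricsMin` tests report `mean_loss=-rew`, the negation of the `neg_mean_loss=rew` stream used by the existing max-mode tests. Assuming the schedulers implement `mode` by comparing `sign * result[metric]` with `sign = 1` for max and `sign = -1` for min (a common convention, not shown in this diff), the negation makes every internal comparison in min mode identical to its max-mode counterpart, so the same sequence of CONTINUE decisions must hold in both tests.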