From ea11fd512f40120e95a08349d2ff129492779185 Mon Sep 17 00:00:00 2001 From: Justin Yu Date: Wed, 21 Dec 2022 21:51:29 -0800 Subject: [PATCH] [Tune] Fix AxSearch save and nan/inf result handling (#31147) This PR fixes AxSearch saving and handles trials that produce nan/inf metrics properly. Signed-off-by: Justin Yu --- python/ray/tune/search/ax/ax_search.py | 24 +- python/ray/tune/tests/test_searchers.py | 296 +++++++++++++----------- 2 files changed, 179 insertions(+), 141 deletions(-) diff --git a/python/ray/tune/search/ax/ax_search.py b/python/ray/tune/search/ax/ax_search.py index 6661b13a4dce..ce64619c4951 100644 --- a/python/ray/tune/search/ax/ax_search.py +++ b/python/ray/tune/search/ax/ax_search.py @@ -1,7 +1,8 @@ import copy -import pickle +import numpy as np from typing import Dict, List, Optional, Union +from ray import cloudpickle from ray.tune.result import DEFAULT_METRIC from ray.tune.search.sample import ( Categorical, @@ -151,7 +152,7 @@ def __init__( parameter_constraints: Optional[List] = None, outcome_constraints: Optional[List] = None, ax_client: Optional[AxClient] = None, - **ax_kwargs + **ax_kwargs, ): assert ( ax is not None @@ -324,12 +325,21 @@ def on_trial_complete(self, trial_id, result=None, error=False): def _process_result(self, trial_id, result): ax_trial_index = self._live_trial_mapping[trial_id] - metric_dict = {self._metric: (result[self._metric], None)} - outcome_names = [ + metrics_to_include = [self._metric] + [ oc.metric.name for oc in self._ax.experiment.optimization_config.outcome_constraints ] - metric_dict.update({on: (result[on], None) for on in outcome_names}) + metric_dict = {} + for key in metrics_to_include: + val = result[key] + if np.isnan(val) or np.isinf(val): + # Don't report trials with NaN metrics to Ax + self._ax.abandon_trial( + trial_index=ax_trial_index, + reason=f"nan/inf metrics reported by {trial_id}", + ) + return + metric_dict[key] = (val, None) self._ax.complete_trial(trial_index=ax_trial_index, raw_data=metric_dict) @staticmethod @@ -415,9 +425,9 @@ def resolve_value(par, domain): def save(self, checkpoint_path: str): save_object = self.__dict__ with open(checkpoint_path, "wb") as outputFile: - pickle.dump(save_object, outputFile) + cloudpickle.dump(save_object, outputFile) def restore(self, checkpoint_path: str): with open(checkpoint_path, "rb") as inputFile: - save_object = pickle.load(inputFile) + save_object = cloudpickle.load(inputFile) self.__dict__.update(save_object) diff --git a/python/ray/tune/tests/test_searchers.py b/python/ray/tune/tests/test_searchers.py index 74f0808c69a7..790abc889245 100644 --- a/python/ray/tune/tests/test_searchers.py +++ b/python/ray/tune/tests/test_searchers.py @@ -1,14 +1,16 @@ -import unittest -import tempfile -import shutil -import os +import contextlib from copy import deepcopy - import numpy as np +import os +import shutil +import tempfile +import unittest +from unittest.mock import patch import ray from ray import tune from ray.tune.result import TRAINING_ITERATION +from ray.tune.search import ConcurrencyLimiter def _invalid_objective(config): @@ -38,6 +40,8 @@ class InvalidValuesTest(unittest.TestCase): Test searcher handling of invalid values (NaN, -inf, inf). Implicitly tests automatic config conversion and default (anonymous) mode handling. + Also tests that searcher save doesn't throw any errors during + experiment checkpointing. 
""" def setUp(self): @@ -62,6 +66,19 @@ def assertCorrectExperimentOutput(self, analysis): self.assertIn(best_trial.config["list"], ([1, 2, 3], (1, 2, 3))) self.assertEqual(best_trial.config["num"], 4) + @contextlib.contextmanager + def check_searcher_checkpoint_errors_scope(self): + buffer = [] + from ray.tune.execution.trial_runner import logger + + with patch.object(logger, "warning", lambda x: buffer.append(x)): + yield + + assert not any( + "Trial Runner checkpointing failed: Can't pickle local object" in x + for x in buffer + ), "Searcher checkpointing failed (unable to serialize)." + def testAxManualSetup(self): from ray.tune.search.ax import AxSearch from ax.service.ax_client import AxClient @@ -93,67 +110,74 @@ def testAxManualSetup(self): def testAx(self): from ray.tune.search.ax import AxSearch - searcher = AxSearch(random_seed=4321) + searcher = ConcurrencyLimiter(AxSearch(random_seed=4321), max_concurrent=2) + + with self.check_searcher_checkpoint_errors_scope(): + # Make sure enough samples are used so that Ax actually fits a model + # for config suggestion + out = tune.run( + _invalid_objective, + search_alg=searcher, + metric="_metric", + mode="max", + num_samples=16, + reuse_actors=False, + config=self.config, + ) - out = tune.run( - _invalid_objective, - search_alg=searcher, - metric="_metric", - mode="max", - num_samples=4, - reuse_actors=False, - config=self.config, - ) self.assertCorrectExperimentOutput(out) def testBayesOpt(self): from ray.tune.search.bayesopt import BayesOptSearch - out = tune.run( - _invalid_objective, - # At least one nan, inf, -inf and float - search_alg=BayesOptSearch(random_state=1234), - config=self.config, - metric="_metric", - mode="max", - num_samples=8, - reuse_actors=False, - ) + with self.check_searcher_checkpoint_errors_scope(): + out = tune.run( + _invalid_objective, + # At least one nan, inf, -inf and float + search_alg=BayesOptSearch(random_state=1234), + config=self.config, + metric="_metric", + mode="max", + num_samples=8, + reuse_actors=False, + ) self.assertCorrectExperimentOutput(out) def testBlendSearch(self): from ray.tune.search.flaml import BlendSearch - out = tune.run( - _invalid_objective, - search_alg=BlendSearch( - points_to_evaluate=[ - {"report": 1.0}, - {"report": 2.1}, - {"report": 3.1}, - {"report": 4.1}, - ] - ), - config=self.config, - metric="_metric", - mode="max", - num_samples=16, - reuse_actors=False, - ) + with self.check_searcher_checkpoint_errors_scope(): + out = tune.run( + _invalid_objective, + search_alg=BlendSearch( + points_to_evaluate=[ + {"report": 1.0}, + {"report": 2.1}, + {"report": 3.1}, + {"report": 4.1}, + ] + ), + config=self.config, + metric="_metric", + mode="max", + num_samples=16, + reuse_actors=False, + ) self.assertCorrectExperimentOutput(out) def testBOHB(self): from ray.tune.search.bohb import TuneBOHB - out = tune.run( - _invalid_objective, - search_alg=TuneBOHB(seed=1000), - config=self.config, - metric="_metric", - mode="max", - num_samples=8, - reuse_actors=False, - ) + with self.check_searcher_checkpoint_errors_scope(): + out = tune.run( + _invalid_objective, + search_alg=TuneBOHB(seed=1000), + config=self.config, + metric="_metric", + mode="max", + num_samples=8, + reuse_actors=False, + ) self.assertCorrectExperimentOutput(out) def testCFO(self): @@ -163,22 +187,23 @@ def testCFO(self): ) from ray.tune.search.flaml import CFO - out = tune.run( - _invalid_objective, - search_alg=CFO( - points_to_evaluate=[ - {"report": 1.0}, - {"report": 2.1}, - {"report": 3.1}, - {"report": 4.1}, 
- ] - ), - config=self.config, - metric="_metric", - mode="max", - num_samples=16, - reuse_actors=False, - ) + with self.check_searcher_checkpoint_errors_scope(): + out = tune.run( + _invalid_objective, + search_alg=CFO( + points_to_evaluate=[ + {"report": 1.0}, + {"report": 2.1}, + {"report": 3.1}, + {"report": 4.1}, + ] + ), + config=self.config, + metric="_metric", + mode="max", + num_samples=16, + reuse_actors=False, + ) self.assertCorrectExperimentOutput(out) def testDragonfly(self): @@ -186,45 +211,48 @@ def testDragonfly(self): np.random.seed(1000) # At least one nan, inf, -inf and float - out = tune.run( - _invalid_objective, - search_alg=DragonflySearch(domain="euclidean", optimizer="random"), - config=self.config, - metric="_metric", - mode="max", - num_samples=8, - reuse_actors=False, - ) + with self.check_searcher_checkpoint_errors_scope(): + out = tune.run( + _invalid_objective, + search_alg=DragonflySearch(domain="euclidean", optimizer="random"), + config=self.config, + metric="_metric", + mode="max", + num_samples=8, + reuse_actors=False, + ) self.assertCorrectExperimentOutput(out) def testHEBO(self): from ray.tune.search.hebo import HEBOSearch - out = tune.run( - _invalid_objective, - # At least one nan, inf, -inf and float - search_alg=HEBOSearch(random_state_seed=123), - config=self.config, - metric="_metric", - mode="max", - num_samples=8, - reuse_actors=False, - ) + with self.check_searcher_checkpoint_errors_scope(): + out = tune.run( + _invalid_objective, + # At least one nan, inf, -inf and float + search_alg=HEBOSearch(random_state_seed=123), + config=self.config, + metric="_metric", + mode="max", + num_samples=8, + reuse_actors=False, + ) self.assertCorrectExperimentOutput(out) def testHyperopt(self): from ray.tune.search.hyperopt import HyperOptSearch - out = tune.run( - _invalid_objective, - # At least one nan, inf, -inf and float - search_alg=HyperOptSearch(random_state_seed=1234), - config=self.config, - metric="_metric", - mode="max", - num_samples=8, - reuse_actors=False, - ) + with self.check_searcher_checkpoint_errors_scope(): + out = tune.run( + _invalid_objective, + # At least one nan, inf, -inf and float + search_alg=HyperOptSearch(random_state_seed=1234), + config=self.config, + metric="_metric", + mode="max", + num_samples=8, + reuse_actors=False, + ) self.assertCorrectExperimentOutput(out) def testNevergrad(self): @@ -233,14 +261,15 @@ def testNevergrad(self): np.random.seed(2020) # At least one nan, inf, -inf and float - out = tune.run( - _invalid_objective, - search_alg=NevergradSearch(optimizer=ng.optimizers.RandomSearch), - config=self.config, - mode="max", - num_samples=16, - reuse_actors=False, - ) + with self.check_searcher_checkpoint_errors_scope(): + out = tune.run( + _invalid_objective, + search_alg=NevergradSearch(optimizer=ng.optimizers.RandomSearch), + config=self.config, + mode="max", + num_samples=16, + reuse_actors=False, + ) self.assertCorrectExperimentOutput(out) def testOptuna(self): @@ -249,15 +278,16 @@ def testOptuna(self): np.random.seed(1000) # At least one nan, inf, -inf and float - out = tune.run( - _invalid_objective, - search_alg=OptunaSearch(sampler=RandomSampler(seed=1234)), - config=self.config, - metric="_metric", - mode="max", - num_samples=8, - reuse_actors=False, - ) + with self.check_searcher_checkpoint_errors_scope(): + out = tune.run( + _invalid_objective, + search_alg=OptunaSearch(sampler=RandomSampler(seed=1234)), + config=self.config, + metric="_metric", + mode="max", + num_samples=8, + reuse_actors=False, + ) 
self.assertCorrectExperimentOutput(out) def testOptunaReportTooOften(self): @@ -284,35 +314,33 @@ def testSkopt(self): np.random.seed(1234) # At least one nan, inf, -inf and float - out = tune.run( - _invalid_objective, - search_alg=SkOptSearch(), - config=self.config, - metric="_metric", - mode="max", - num_samples=8, - reuse_actors=False, - ) + with self.check_searcher_checkpoint_errors_scope(): + out = tune.run( + _invalid_objective, + search_alg=SkOptSearch(), + config=self.config, + metric="_metric", + mode="max", + num_samples=8, + reuse_actors=False, + ) self.assertCorrectExperimentOutput(out) def testZOOpt(self): - self.skipTest( - "Recent ZOOpt versions fail handling invalid values gracefully. " - "Skipping until we or they found a workaround. " - ) from ray.tune.search.zoopt import ZOOptSearch np.random.seed(1000) # At least one nan, inf, -inf and float - out = tune.run( - _invalid_objective, - search_alg=ZOOptSearch(budget=100, parallel_num=4), - config=self.config, - metric="_metric", - mode="max", - num_samples=8, - reuse_actors=False, - ) + with self.check_searcher_checkpoint_errors_scope(): + out = tune.run( + _invalid_objective, + search_alg=ZOOptSearch(budget=100, parallel_num=4), + config=self.config, + metric="_metric", + mode="max", + num_samples=8, + reuse_actors=False, + ) self.assertCorrectExperimentOutput(out)
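
Note (illustrative, not part of the patch): the save/restore change swaps the standard
pickle module for ray.cloudpickle because the searcher state can hold locally defined
callables that pickle refuses to serialize, which is exactly the
"Trial Runner checkpointing failed: Can't pickle local object" warning the new
check_searcher_checkpoint_errors_scope asserts against. A minimal sketch of the
difference, using a hypothetical local function (make_closure / local_fn are
placeholder names, not from the patch):

    import pickle
    from ray import cloudpickle

    def make_closure():
        def local_fn(x):  # locally defined, like callables held in searcher state
            return x + 1
        return local_fn

    fn = make_closure()
    # pickle.dumps(fn) raises AttributeError: Can't pickle local object
    data = cloudpickle.dumps(fn)   # cloudpickle handles local functions/closures
    assert cloudpickle.loads(data)(1) == 2

The nan/inf handling can be exercised with the same pattern the updated tests use: an
objective that sometimes reports a non-finite metric now causes the corresponding Ax
trial to be abandoned instead of being passed to complete_trial. A sketch under assumed
names (the objective, search space, and seed below are illustrative, not taken from the
patch):

    from ray import tune
    from ray.tune.search import ConcurrencyLimiter
    from ray.tune.search.ax import AxSearch

    def objective(config):
        x = config["x"]
        # Report nan for part of the space; with this patch, Ax abandons that trial
        # rather than receiving a nan/inf observation.
        tune.report(_metric=float("nan") if x < 0.1 else -((x - 0.7) ** 2))

    searcher = ConcurrencyLimiter(AxSearch(random_seed=1234), max_concurrent=2)
    tune.run(
        objective,
        search_alg=searcher,
        config={"x": tune.uniform(0.0, 1.0)},
        metric="_metric",
        mode="max",
        num_samples=16,
    )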