From 49ac816ee8ac156670392d79bff45134602ecab6 Mon Sep 17 00:00:00 2001 From: Bernie Beckerman Date: Tue, 20 Jun 2023 17:20:31 -0700 Subject: [PATCH] Load Experiment without runners and metrics in the case where search space and optimization config are immutable (#1656) Summary: Pull Request resolved: https://github.com/facebook/Ax/pull/1656 `load_experiment` was previously failing when `skip_runners_and_metrics=True` for experiments with immutable search space and optimization config. See [Lena's comment](https://www.internalfb.com/diff/D46595953?dst_version_fbid=263061986396665&transaction_fbid=639853311380999) for more detail. Added in V4: uses the metric type corresponding to the decoder's entry for Ax base Metric, instead of 0, which is not always the correct entry. Differential Revision: D46595953 fbshipit-source-id: 1f9e24735c57580ff85a8d500dde027640daad65 --- ax/storage/sqa_store/load.py | 22 +- ax/storage/sqa_store/tests/test_sqa_store.py | 203 +++++++++++-------- ax/utils/testing/core_stubs.py | 5 + 3 files changed, 136 insertions(+), 94 deletions(-) diff --git a/ax/storage/sqa_store/load.py b/ax/storage/sqa_store/load.py index d6cd5b92469..17e108d6184 100644 --- a/ax/storage/sqa_store/load.py +++ b/ax/storage/sqa_store/load.py @@ -9,6 +9,7 @@ from ax.core.experiment import Experiment from ax.core.generator_run import GeneratorRun +from ax.core.metric import Metric from ax.core.trial import Trial from ax.exceptions.core import ExperimentNotFoundError, ObjectNotFoundError from ax.modelbridge.generation_strategy import GenerationStrategy @@ -28,6 +29,7 @@ from ax.utils.common.constants import Keys from ax.utils.common.typeutils import not_none from sqlalchemy.orm import defaultload, lazyload, noload +from sqlalchemy.orm.exc import DetachedInstanceError # ---------------------------- Loading `Experiment`. --------------------------- @@ -132,13 +134,21 @@ def _load_experiment( # "lower_is_better" or any other attribute we want to include. This can be # implemented in the future if we need to. if skip_runners_and_metrics: + base_metric_type_int = decoder.config.metric_registry[Metric] for sqa_metric in experiment_sqa.metrics: - sqa_metric.metric_type = 0 - - for sqa_trial in experiment_sqa.trials: - for sqa_generator_run in sqa_trial.generator_runs: - for sqa_metric in sqa_generator_run.metrics: - sqa_metric.metric_type = 0 + sqa_metric.metric_type = base_metric_type_int + imm_OC_and_SS = _get_experiment_immutable_opt_config_and_search_space( + experiment_name=experiment_name, exp_sqa_class=exp_sqa_class + ) + assign_metric_on_gr = not reduced_state and not imm_OC_and_SS + if assign_metric_on_gr: + try: + for sqa_trial in experiment_sqa.trials: + for sqa_generator_run in sqa_trial.generator_runs: + for sqa_metric in sqa_generator_run.metrics: + sqa_metric.metric_type = base_metric_type_int + except DetachedInstanceError as e: + raise e return decoder.experiment_from_sqa( experiment_sqa=experiment_sqa, diff --git a/ax/storage/sqa_store/tests/test_sqa_store.py b/ax/storage/sqa_store/tests/test_sqa_store.py index 3ad9235ede9..bb93f804a67 100644 --- a/ax/storage/sqa_store/tests/test_sqa_store.py +++ b/ax/storage/sqa_store/tests/test_sqa_store.py @@ -237,30 +237,37 @@ def testLoadExperimentSkipMetricsAndRunners(self) -> None: runner_registry=runner_registry, ) - # Save the experiment to db using the updated registries. - save_experiment(experiment, config=sqa_config) + for immutable in [True, False]: + experiment = get_experiment_with_custom_runner_and_metric( + constrain_search_space=False, + immutable=immutable, + ) - # At this point try to load the experiment back without specifying - # updated registries. Confirm that this attempt fails. - with self.assertRaises(SQADecodeError): - loaded_experiment = load_experiment(self.experiment.name) + # Save the experiment to db using the updated registries. + save_experiment(experiment, config=sqa_config) - # Now load it with the skip_runners_and_metrics argument set. - # The experiment should load (i.e. no exceptions raised) - loaded_experiment = load_experiment( - self.experiment.name, skip_runners_and_metrics=True - ) + # At this point try to load the experiment back without specifying + # updated registries. Confirm that this attempt fails. + with self.assertRaises(SQADecodeError): + loaded_experiment = load_experiment(experiment.name) - # Validate that: - # - the runner is not loaded - # - the metric is loaded as a base Metric class (i.e. not CustomTestMetric) - self.assertIs(loaded_experiment.runner, None) - self.assertTrue("custom_test_metric" in loaded_experiment.metrics) - self.assertEqual( - loaded_experiment.metrics["custom_test_metric"].__class__, Metric - ) - self.assertEqual(len(loaded_experiment.trials), 1) - self.assertIs(loaded_experiment.trials[0].runner, None) + # Now load it with the skip_runners_and_metrics argument set. + # The experiment should load (i.e. no exceptions raised) + loaded_experiment = load_experiment( + experiment.name, skip_runners_and_metrics=True + ) + + # Validate that: + # - the runner is not loaded + # - the metric is loaded as a base Metric class, not CustomTestMetric + self.assertIs(loaded_experiment.runner, None) + self.assertTrue("custom_test_metric" in loaded_experiment.metrics) + self.assertEqual( + loaded_experiment.metrics["custom_test_metric"].__class__, Metric + ) + self.assertEqual(len(loaded_experiment.trials), 1) + self.assertIs(loaded_experiment.trials[0].runner, None) + delete_experiment(exp_name=experiment.name) @patch( f"{Decoder.__module__}.Decoder.generator_run_from_sqa", @@ -278,74 +285,94 @@ def testLoadExperimentSkipMetricsAndRunners(self) -> None: def test_ExperimentSaveAndLoadReducedState( self, _mock_exp_from_sqa, _mock_trial_from_sqa, _mock_gr_from_sqa ) -> None: - # 1. No abandoned arms + no trials case, reduced state should be the - # same as non-reduced state. - exp = get_experiment_with_multi_objective() - save_experiment(exp) - loaded_experiment = load_experiment(exp.name, reduced_state=True) - self.assertEqual(loaded_experiment, exp) - # Make sure decoder function was called with `reduced_state=True`. - self.assertTrue(_mock_exp_from_sqa.call_args[1].get("reduced_state")) - _mock_exp_from_sqa.reset_mock() - - # 2. Try case with abandoned arms. - exp = get_experiment_with_batch_trial(constrain_search_space=False) - save_experiment(exp) - loaded_experiment = load_experiment(exp.name, reduced_state=True) - # Experiments are not the same, because one has abandoned arms info. - self.assertNotEqual(loaded_experiment, exp) - # Remove all abandoned arms and check that all else is equal as expected. - t = checked_cast(BatchTrial, exp.trials[0]) - t._abandoned_arms_metadata = {} - self.assertEqual(loaded_experiment, exp) - # Make sure that all relevant decoding functions were called with - # `reduced_state=True` and correct number of times. - self.assertTrue(_mock_exp_from_sqa.call_args[1].get("reduced_state")) - self.assertTrue(_mock_trial_from_sqa.call_args[1].get("reduced_state")) - # 2 generator runs + regular and status quo. - self.assertTrue(_mock_gr_from_sqa.call_args[1].get("reduced_state")) - _mock_exp_from_sqa.reset_mock() - _mock_trial_from_sqa.reset_mock() - _mock_gr_from_sqa.reset_mock() + for skip_runners_and_metrics in [False, True]: + # 1. No abandoned arms + no trials case, reduced state should be the + # same as non-reduced state. + exp = get_experiment_with_multi_objective() + save_experiment(exp) + loaded_experiment = load_experiment( + exp.name, + reduced_state=True, + skip_runners_and_metrics=skip_runners_and_metrics, + ) + loaded_experiment.runner = exp.runner + self.assertEqual(loaded_experiment, exp) + # Make sure decoder function was called with `reduced_state=True`. + self.assertTrue(_mock_exp_from_sqa.call_args[1].get("reduced_state")) + _mock_exp_from_sqa.reset_mock() + delete_experiment(exp_name=exp.name) - # 3. Try case with model state and search space + opt.config on a - # generator run in the experiment. - gr = Models.SOBOL(experiment=exp).gen(1) - # Expecting model kwargs to have 6 fields (seed, deduplicate, init_position, - # scramble, generated_points, fallback_to_sample_polytope) - # and the rest of model-state info on generator run to have values too. - mkw = gr._model_kwargs - self.assertIsNotNone(mkw) - self.assertEqual(len(mkw), 6) - bkw = gr._bridge_kwargs - self.assertIsNotNone(bkw) - self.assertEqual(len(bkw), 8) - ms = gr._model_state_after_gen - self.assertIsNotNone(ms) - self.assertEqual(len(ms), 2) - gm = gr._gen_metadata - self.assertIsNotNone(gm) - self.assertEqual(len(gm), 0) - self.assertIsNotNone(gr._search_space, gr.optimization_config) - exp.new_trial(generator_run=gr) - save_experiment(exp) - # Make sure that all relevant decoding functions were called with - # `reduced_state=True` and correct number of times. - loaded_experiment = load_experiment(exp.name, reduced_state=True) - self.assertTrue(_mock_exp_from_sqa.call_args[1].get("reduced_state")) - self.assertTrue(_mock_trial_from_sqa.call_args[1].get("reduced_state")) - # 2 generator runs from trial #0 + 1 from trial #1. - self.assertTrue(_mock_gr_from_sqa.call_args[1].get("reduced_state")) - self.assertNotEqual(loaded_experiment, exp) - # Remove all fields that are not part of the reduced state and - # check that everything else is equal as expected. - exp.trials.get(1).generator_run._model_kwargs = None - exp.trials.get(1).generator_run._bridge_kwargs = None - exp.trials.get(1).generator_run._gen_metadata = None - exp.trials.get(1).generator_run._model_state_after_gen = None - exp.trials.get(1).generator_run._search_space = None - exp.trials.get(1).generator_run._optimization_config = None - self.assertEqual(loaded_experiment, exp) + # 2. Try case with abandoned arms. + exp = get_experiment_with_batch_trial(constrain_search_space=False) + save_experiment(exp) + loaded_experiment = load_experiment( + exp.name, + reduced_state=True, + skip_runners_and_metrics=skip_runners_and_metrics, + ) + # Experiments are not the same, because one has abandoned arms info. + self.assertNotEqual(loaded_experiment, exp) + # Remove all abandoned arms and check that all else is equal as expected. + t = checked_cast(BatchTrial, exp.trials[0]) + t._abandoned_arms_metadata = {} + loaded_experiment.runner = exp.runner + loaded_experiment._trials[0]._runner = exp._trials[0]._runner + self.assertEqual(loaded_experiment, exp) + # Make sure that all relevant decoding functions were called with + # `reduced_state=True` and correct number of times. + self.assertTrue(_mock_exp_from_sqa.call_args[1].get("reduced_state")) + self.assertTrue(_mock_trial_from_sqa.call_args[1].get("reduced_state")) + # 2 generator runs + regular and status quo. + self.assertTrue(_mock_gr_from_sqa.call_args[1].get("reduced_state")) + _mock_exp_from_sqa.reset_mock() + _mock_trial_from_sqa.reset_mock() + _mock_gr_from_sqa.reset_mock() + + # 3. Try case with model state and search space + opt.config on a + # generator run in the experiment. + gr = Models.SOBOL(experiment=exp).gen(1) + # Expecting model kwargs to have 6 fields (seed, deduplicate, init_position, + # scramble, generated_points, fallback_to_sample_polytope) + # and the rest of model-state info on generator run to have values too. + mkw = gr._model_kwargs + self.assertIsNotNone(mkw) + self.assertEqual(len(mkw), 6) + bkw = gr._bridge_kwargs + self.assertIsNotNone(bkw) + self.assertEqual(len(bkw), 8) + ms = gr._model_state_after_gen + self.assertIsNotNone(ms) + self.assertEqual(len(ms), 2) + gm = gr._gen_metadata + self.assertIsNotNone(gm) + self.assertEqual(len(gm), 0) + self.assertIsNotNone(gr._search_space, gr.optimization_config) + exp.new_trial(generator_run=gr) + save_experiment(exp) + # Make sure that all relevant decoding functions were called with + # `reduced_state=True` and correct number of times. + loaded_experiment = load_experiment( + exp.name, + reduced_state=True, + skip_runners_and_metrics=skip_runners_and_metrics, + ) + loaded_experiment.runner = exp.runner + loaded_experiment._trials[0]._runner = exp._trials[0]._runner + self.assertTrue(_mock_exp_from_sqa.call_args[1].get("reduced_state")) + self.assertTrue(_mock_trial_from_sqa.call_args[1].get("reduced_state")) + # 2 generator runs from trial #0 + 1 from trial #1. + self.assertTrue(_mock_gr_from_sqa.call_args[1].get("reduced_state")) + self.assertNotEqual(loaded_experiment, exp) + # Remove all fields that are not part of the reduced state and + # check that everything else is equal as expected. + exp.trials.get(1).generator_run._model_kwargs = None + exp.trials.get(1).generator_run._bridge_kwargs = None + exp.trials.get(1).generator_run._gen_metadata = None + exp.trials.get(1).generator_run._model_state_after_gen = None + exp.trials.get(1).generator_run._search_space = None + exp.trials.get(1).generator_run._optimization_config = None + self.assertEqual(loaded_experiment, exp) + delete_experiment(exp_name=exp.name) def testMTExperimentSaveAndLoad(self) -> None: experiment = get_multi_type_experiment(add_trials=True) diff --git a/ax/utils/testing/core_stubs.py b/ax/utils/testing/core_stubs.py index 093ecf8a74a..1ae869287f6 100644 --- a/ax/utils/testing/core_stubs.py +++ b/ax/utils/testing/core_stubs.py @@ -95,6 +95,7 @@ from ax.models.winsorization_config import WinsorizationConfig from ax.runners.synthetic import SyntheticRunner from ax.service.utils.scheduler_options import SchedulerOptions, TrialType +from ax.utils.common.constants import Keys from ax.utils.common.logger import get_logger from ax.utils.common.typeutils import checked_cast, not_none from ax.utils.measurement.synthetic_functions import branin @@ -146,6 +147,7 @@ def get_experiment_with_map_data_type() -> Experiment: def get_experiment_with_custom_runner_and_metric( constrain_search_space: bool = True, + immutable: bool = False, ) -> Experiment: # Create experiment with custom runner and metric @@ -172,6 +174,9 @@ def get_experiment_with_custom_runner_and_metric( experiment.attach_data(get_data(metric_name="custom_test_metric")) trial.mark_completed() + if immutable: + experiment._properties = {Keys.IMMUTABLE_SEARCH_SPACE_AND_OPT_CONF: True} + return experiment