diff --git a/ax/storage/sqa_store/load.py b/ax/storage/sqa_store/load.py index d6cd5b92469..b4461e562d5 100644 --- a/ax/storage/sqa_store/load.py +++ b/ax/storage/sqa_store/load.py @@ -28,6 +28,7 @@ from ax.utils.common.constants import Keys from ax.utils.common.typeutils import not_none from sqlalchemy.orm import defaultload, lazyload, noload +from sqlalchemy.orm.exc import DetachedInstanceError # ---------------------------- Loading `Experiment`. --------------------------- @@ -135,10 +136,13 @@ def _load_experiment( for sqa_metric in experiment_sqa.metrics: sqa_metric.metric_type = 0 - for sqa_trial in experiment_sqa.trials: - for sqa_generator_run in sqa_trial.generator_runs: - for sqa_metric in sqa_generator_run.metrics: - sqa_metric.metric_type = 0 + try: + for sqa_trial in experiment_sqa.trials: + for sqa_generator_run in sqa_trial.generator_runs: + for sqa_metric in sqa_generator_run.metrics: + sqa_metric.metric_type = 0 + except DetachedInstanceError: + pass return decoder.experiment_from_sqa( experiment_sqa=experiment_sqa, diff --git a/ax/storage/sqa_store/tests/test_sqa_store.py b/ax/storage/sqa_store/tests/test_sqa_store.py index 2cbbeda376c..4f448c6767f 100644 --- a/ax/storage/sqa_store/tests/test_sqa_store.py +++ b/ax/storage/sqa_store/tests/test_sqa_store.py @@ -237,30 +237,40 @@ def testLoadExperimentSkipMetricsAndRunners(self) -> None: runner_registry=runner_registry, ) - # Save the experiment to db using the updated registries. - save_experiment(experiment, config=sqa_config) - - # At this point try to load the experiment back without specifying - # updated registries. Confirm that this attempt fails. - with self.assertRaises(SQADecodeError): - loaded_experiment = load_experiment(self.experiment.name) - - # Now load it with the skip_runners_and_metrics argument set. - # The experiment should load (i.e. no exceptions raised) - loaded_experiment = load_experiment( - self.experiment.name, skip_runners_and_metrics=True - ) + for immutable in [True, False]: + experiment = get_experiment_with_custom_runner_and_metric( + constrain_search_space=False + ) + if immutable: + experiment._properties = { + Keys.IMMUTABLE_SEARCH_SPACE_AND_OPT_CONF: True + } + + # Save the experiment to db using the updated registries. + save_experiment(experiment, config=sqa_config) + + # At this point try to load the experiment back without specifying + # updated registries. Confirm that this attempt fails. + with self.assertRaises(SQADecodeError): + loaded_experiment = load_experiment(experiment.name) + + # Now load it with the skip_runners_and_metrics argument set. + # The experiment should load (i.e. no exceptions raised) + loaded_experiment = load_experiment( + experiment.name, skip_runners_and_metrics=True + ) - # Validate that: - # - the runner is not loaded - # - the metric is loaded as a base Metric class (i.e. not CustomTestMetric) - self.assertIs(loaded_experiment.runner, None) - self.assertTrue("custom_test_metric" in loaded_experiment.metrics) - self.assertEqual( - loaded_experiment.metrics["custom_test_metric"].__class__, Metric - ) - self.assertEqual(len(loaded_experiment.trials), 1) - self.assertIs(loaded_experiment.trials[0].runner, None) + # Validate that: + # - the runner is not loaded + # - the metric is loaded as a base Metric class, not CustomTestMetric + self.assertIs(loaded_experiment.runner, None) + self.assertTrue("custom_test_metric" in loaded_experiment.metrics) + self.assertEqual( + loaded_experiment.metrics["custom_test_metric"].__class__, Metric + ) + self.assertEqual(len(loaded_experiment.trials), 1) + self.assertIs(loaded_experiment.trials[0].runner, None) + delete_experiment(exp_name=experiment.name) @patch( f"{Decoder.__module__}.Decoder.generator_run_from_sqa",