From 3ac429ea44e1419a3ed7359e77964fc5416c05b4 Mon Sep 17 00:00:00 2001 From: Sam Daulton Date: Tue, 2 Apr 2024 16:19:45 -0700 Subject: [PATCH] fix skip_runners_and_metrics for metrics on generator runs with mutable multi-objective optimization config Summary: This fixes load_experiment with `skip_runners_and_metricsmetrics` by setting properties SQAMetrics on generator runs that have multi-objective intent, for experiments with mutable optimization configs Reviewed By: Balandat Differential Revision: D55663714 --- ax/storage/sqa_store/decoder.py | 1 - ax/storage/sqa_store/load.py | 30 +++++++++++++----- ax/storage/sqa_store/tests/test_sqa_store.py | 33 ++++++++++++++++---- ax/utils/testing/core_stubs.py | 7 ++++- 4 files changed, 55 insertions(+), 16 deletions(-) diff --git a/ax/storage/sqa_store/decoder.py b/ax/storage/sqa_store/decoder.py index f23e7b3de0a..caaa2dab008 100644 --- a/ax/storage/sqa_store/decoder.py +++ b/ax/storage/sqa_store/decoder.py @@ -635,7 +635,6 @@ def generator_run_from_sqa( for arm_sqa in generator_run_sqa.arms: arms.append(self.arm_from_sqa(arm_sqa=arm_sqa)) weights.append(arm_sqa.weight) - if not reduced_state and not immutable_search_space_and_opt_config: ( opt_config, diff --git a/ax/storage/sqa_store/load.py b/ax/storage/sqa_store/load.py index b203d1f7a1a..b8382c521c6 100644 --- a/ax/storage/sqa_store/load.py +++ b/ax/storage/sqa_store/load.py @@ -25,6 +25,7 @@ SQAExperiment, SQAGenerationStrategy, SQAGeneratorRun, + SQAMetric, SQATrial, ) from ax.storage.sqa_store.sqa_config import SQAConfig @@ -139,13 +140,9 @@ def _load_experiment( if skip_runners_and_metrics: base_metric_type_int = decoder.config.metric_registry[Metric] for sqa_metric in experiment_sqa.metrics: - sqa_metric.metric_type = base_metric_type_int - # Handle multi-objective metrics that are not directly attached to - # the experiment - if sqa_metric.intent == MetricIntent.MULTI_OBJECTIVE: - if sqa_metric.properties is None: - sqa_metric.properties = {} - sqa_metric.properties["skip_runners_and_metrics"] = True + _set_sqa_metric_to_base_type( + sqa_metric, base_metric_type_int=base_metric_type_int + ) assign_metric_on_gr = not reduced_state and not imm_OC_and_SS if assign_metric_on_gr: @@ -153,7 +150,9 @@ def _load_experiment( for sqa_trial in experiment_sqa.trials: for sqa_generator_run in sqa_trial.generator_runs: for sqa_metric in sqa_generator_run.metrics: - sqa_metric.metric_type = base_metric_type_int + _set_sqa_metric_to_base_type( + sqa_metric, base_metric_type_int=base_metric_type_int + ) except DetachedInstanceError as e: raise DetachedInstanceError( "Unable to retrieve metric from SQA generator run, possibly due " @@ -250,6 +249,21 @@ def _get_trials_sqa( return sqa_trials +def _set_sqa_metric_to_base_type( + sqa_metric: SQAMetric, base_metric_type_int: int +) -> None: + """Sets metric type to base type, since we don't want to load + the metric class from the DB. + """ + sqa_metric.metric_type = base_metric_type_int + # Handle multi-objective metrics that are not directly attached to + # the experiment + if sqa_metric.intent == MetricIntent.MULTI_OBJECTIVE: + if sqa_metric.properties is None: + sqa_metric.properties = {} + sqa_metric.properties["skip_runners_and_metrics"] = True + + def _get_experiment_sqa_reduced_state( experiment_name: str, exp_sqa_class: Type[SQAExperiment], diff --git a/ax/storage/sqa_store/tests/test_sqa_store.py b/ax/storage/sqa_store/tests/test_sqa_store.py index d7760e5807b..ad87ab8c9df 100644 --- a/ax/storage/sqa_store/tests/test_sqa_store.py +++ b/ax/storage/sqa_store/tests/test_sqa_store.py @@ -16,7 +16,7 @@ from ax.core.batch_trial import BatchTrial, LifecycleStage from ax.core.generator_run import GeneratorRun from ax.core.metric import Metric -from ax.core.objective import Objective +from ax.core.objective import MultiObjective, Objective from ax.core.outcome_constraint import OutcomeConstraint from ax.core.parameter import ParameterType, RangeParameter from ax.core.runner import Runner @@ -248,11 +248,18 @@ def test_LoadExperimentSkipMetricsAndRunners(self) -> None: for immutable in [True, False]: for multi_objective in [True, False]: + custom_metric_names = ["custom_test_metric"] experiment = get_experiment_with_custom_runner_and_metric( constrain_search_space=False, immutable=immutable, multi_objective=multi_objective, ) + if multi_objective: + custom_metric_names.extend(["m1", "m3"]) + for metric_name in custom_metric_names: + self.assertEqual( + experiment.metrics[metric_name].__class__, CustomTestMetric + ) # Save the experiment to db using the updated registries. save_experiment(experiment, config=sqa_config) @@ -272,13 +279,27 @@ def test_LoadExperimentSkipMetricsAndRunners(self) -> None: # - the runner is not loaded # - the metric is loaded as a base Metric class, not CustomTestMetric self.assertIs(loaded_experiment.runner, None) - self.assertTrue("custom_test_metric" in loaded_experiment.metrics) - self.assertEqual( - loaded_experiment.metrics["custom_test_metric"].__class__, Metric - ) + + for metric_name in custom_metric_names: + self.assertTrue(metric_name in loaded_experiment.metrics) + self.assertEqual( + loaded_experiment.metrics["custom_test_metric"].__class__, + Metric, + ) self.assertEqual(len(loaded_experiment.trials), 1) - self.assertIs(loaded_experiment.trials[0].runner, None) + trial = loaded_experiment.trials[0] + self.assertIs(trial.runner, None) delete_experiment(exp_name=experiment.name) + # check generator runs + gr = trial.generator_runs[0] + if multi_objective and not immutable: + objectives = checked_cast( + MultiObjective, not_none(gr.optimization_config).objective + ).objectives + for i, objective in enumerate(objectives): + metric = objective.metric + self.assertEqual(metric.name, f"m{1 + 2 * i}") + self.assertEqual(metric.__class__, Metric) @patch( f"{Decoder.__module__}.Decoder.generator_run_from_sqa", diff --git a/ax/utils/testing/core_stubs.py b/ax/utils/testing/core_stubs.py index 06a83417338..b6e69af956a 100644 --- a/ax/utils/testing/core_stubs.py +++ b/ax/utils/testing/core_stubs.py @@ -181,11 +181,16 @@ def get_experiment_with_custom_runner_and_metric( # Create a trial, set its runner and complete it. sobol_generator = get_sobol(search_space=experiment.search_space) - sobol_run = sobol_generator.gen(n=1) + sobol_run = sobol_generator.gen( + n=1, + optimization_config=experiment.optimization_config if not immutable else None, + ) trial = experiment.new_trial(generator_run=sobol_run) trial.runner = experiment.runner trial.mark_running() experiment.attach_data(get_data(metric_name="custom_test_metric")) + experiment.attach_data(get_data(metric_name="m1")) + experiment.attach_data(get_data(metric_name="m3")) trial.mark_completed() if immutable: