Skip to content

Commit

Permalink
fix skip_runners_and_metrics for metrics on generator runs with mutab…
Browse files Browse the repository at this point in the history
…le multi-objective optimization config

Summary: This fixes load_experiment with `skip_runners_and_metricsmetrics` by setting properties SQAMetrics on  generator runs that have multi-objective intent, for experiments with mutable optimization configs

Reviewed By: Balandat

Differential Revision: D55663714
  • Loading branch information
sdaulton authored and facebook-github-bot committed Apr 2, 2024
1 parent 0014c75 commit 3ac429e
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 16 deletions.
1 change: 0 additions & 1 deletion ax/storage/sqa_store/decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -635,7 +635,6 @@ def generator_run_from_sqa(
for arm_sqa in generator_run_sqa.arms:
arms.append(self.arm_from_sqa(arm_sqa=arm_sqa))
weights.append(arm_sqa.weight)

if not reduced_state and not immutable_search_space_and_opt_config:
(
opt_config,
Expand Down
30 changes: 22 additions & 8 deletions ax/storage/sqa_store/load.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
SQAExperiment,
SQAGenerationStrategy,
SQAGeneratorRun,
SQAMetric,
SQATrial,
)
from ax.storage.sqa_store.sqa_config import SQAConfig
Expand Down Expand Up @@ -139,21 +140,19 @@ def _load_experiment(
if skip_runners_and_metrics:
base_metric_type_int = decoder.config.metric_registry[Metric]
for sqa_metric in experiment_sqa.metrics:
sqa_metric.metric_type = base_metric_type_int
# Handle multi-objective metrics that are not directly attached to
# the experiment
if sqa_metric.intent == MetricIntent.MULTI_OBJECTIVE:
if sqa_metric.properties is None:
sqa_metric.properties = {}
sqa_metric.properties["skip_runners_and_metrics"] = True
_set_sqa_metric_to_base_type(
sqa_metric, base_metric_type_int=base_metric_type_int
)

assign_metric_on_gr = not reduced_state and not imm_OC_and_SS
if assign_metric_on_gr:
try:
for sqa_trial in experiment_sqa.trials:
for sqa_generator_run in sqa_trial.generator_runs:
for sqa_metric in sqa_generator_run.metrics:
sqa_metric.metric_type = base_metric_type_int
_set_sqa_metric_to_base_type(
sqa_metric, base_metric_type_int=base_metric_type_int
)
except DetachedInstanceError as e:
raise DetachedInstanceError(
"Unable to retrieve metric from SQA generator run, possibly due "
Expand Down Expand Up @@ -250,6 +249,21 @@ def _get_trials_sqa(
return sqa_trials


def _set_sqa_metric_to_base_type(
sqa_metric: SQAMetric, base_metric_type_int: int
) -> None:
"""Sets metric type to base type, since we don't want to load
the metric class from the DB.
"""
sqa_metric.metric_type = base_metric_type_int
# Handle multi-objective metrics that are not directly attached to
# the experiment
if sqa_metric.intent == MetricIntent.MULTI_OBJECTIVE:
if sqa_metric.properties is None:
sqa_metric.properties = {}
sqa_metric.properties["skip_runners_and_metrics"] = True


def _get_experiment_sqa_reduced_state(
experiment_name: str,
exp_sqa_class: Type[SQAExperiment],
Expand Down
33 changes: 27 additions & 6 deletions ax/storage/sqa_store/tests/test_sqa_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from ax.core.batch_trial import BatchTrial, LifecycleStage
from ax.core.generator_run import GeneratorRun
from ax.core.metric import Metric
from ax.core.objective import Objective
from ax.core.objective import MultiObjective, Objective
from ax.core.outcome_constraint import OutcomeConstraint
from ax.core.parameter import ParameterType, RangeParameter
from ax.core.runner import Runner
Expand Down Expand Up @@ -248,11 +248,18 @@ def test_LoadExperimentSkipMetricsAndRunners(self) -> None:

for immutable in [True, False]:
for multi_objective in [True, False]:
custom_metric_names = ["custom_test_metric"]
experiment = get_experiment_with_custom_runner_and_metric(
constrain_search_space=False,
immutable=immutable,
multi_objective=multi_objective,
)
if multi_objective:
custom_metric_names.extend(["m1", "m3"])
for metric_name in custom_metric_names:
self.assertEqual(
experiment.metrics[metric_name].__class__, CustomTestMetric
)

# Save the experiment to db using the updated registries.
save_experiment(experiment, config=sqa_config)
Expand All @@ -272,13 +279,27 @@ def test_LoadExperimentSkipMetricsAndRunners(self) -> None:
# - the runner is not loaded
# - the metric is loaded as a base Metric class, not CustomTestMetric
self.assertIs(loaded_experiment.runner, None)
self.assertTrue("custom_test_metric" in loaded_experiment.metrics)
self.assertEqual(
loaded_experiment.metrics["custom_test_metric"].__class__, Metric
)

for metric_name in custom_metric_names:
self.assertTrue(metric_name in loaded_experiment.metrics)
self.assertEqual(
loaded_experiment.metrics["custom_test_metric"].__class__,
Metric,
)
self.assertEqual(len(loaded_experiment.trials), 1)
self.assertIs(loaded_experiment.trials[0].runner, None)
trial = loaded_experiment.trials[0]
self.assertIs(trial.runner, None)
delete_experiment(exp_name=experiment.name)
# check generator runs
gr = trial.generator_runs[0]
if multi_objective and not immutable:
objectives = checked_cast(
MultiObjective, not_none(gr.optimization_config).objective
).objectives
for i, objective in enumerate(objectives):
metric = objective.metric
self.assertEqual(metric.name, f"m{1 + 2 * i}")
self.assertEqual(metric.__class__, Metric)

@patch(
f"{Decoder.__module__}.Decoder.generator_run_from_sqa",
Expand Down
7 changes: 6 additions & 1 deletion ax/utils/testing/core_stubs.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,11 +181,16 @@ def get_experiment_with_custom_runner_and_metric(

# Create a trial, set its runner and complete it.
sobol_generator = get_sobol(search_space=experiment.search_space)
sobol_run = sobol_generator.gen(n=1)
sobol_run = sobol_generator.gen(
n=1,
optimization_config=experiment.optimization_config if not immutable else None,
)
trial = experiment.new_trial(generator_run=sobol_run)
trial.runner = experiment.runner
trial.mark_running()
experiment.attach_data(get_data(metric_name="custom_test_metric"))
experiment.attach_data(get_data(metric_name="m1"))
experiment.attach_data(get_data(metric_name="m3"))
trial.mark_completed()

if immutable:
Expand Down

0 comments on commit 3ac429e

Please sign in to comment.