Skip to content

Commit

Permalink
Load Experiment without runners and metrics in the case where search …
Browse files Browse the repository at this point in the history
…space and optimization config are immutable (facebook#1656)

Summary:
Pull Request resolved: facebook#1656

`load_experiment` was previously failing when `skip_runners_and_metrics=True` for experiments with immutable search space and optimization config. See [Lena's comment](https://www.internalfb.com/diff/D46595953?dst_version_fbid=263061986396665&transaction_fbid=639853311380999) for more detail.

Added in V4: uses the metric type corresponding to the decoder's entry for Ax base Metric, instead of 0, which is not always the correct entry.

Differential Revision: D46595953

fbshipit-source-id: 1f9e24735c57580ff85a8d500dde027640daad65
  • Loading branch information
Bernie Beckerman authored and facebook-github-bot committed Jun 21, 2023
1 parent 1330f0a commit 49ac816
Show file tree
Hide file tree
Showing 3 changed files with 136 additions and 94 deletions.
22 changes: 16 additions & 6 deletions ax/storage/sqa_store/load.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

from ax.core.experiment import Experiment
from ax.core.generator_run import GeneratorRun
from ax.core.metric import Metric
from ax.core.trial import Trial
from ax.exceptions.core import ExperimentNotFoundError, ObjectNotFoundError
from ax.modelbridge.generation_strategy import GenerationStrategy
Expand All @@ -28,6 +29,7 @@
from ax.utils.common.constants import Keys
from ax.utils.common.typeutils import not_none
from sqlalchemy.orm import defaultload, lazyload, noload
from sqlalchemy.orm.exc import DetachedInstanceError


# ---------------------------- Loading `Experiment`. ---------------------------
Expand Down Expand Up @@ -132,13 +134,21 @@ def _load_experiment(
# "lower_is_better" or any other attribute we want to include. This can be
# implemented in the future if we need to.
if skip_runners_and_metrics:
base_metric_type_int = decoder.config.metric_registry[Metric]
for sqa_metric in experiment_sqa.metrics:
sqa_metric.metric_type = 0

for sqa_trial in experiment_sqa.trials:
for sqa_generator_run in sqa_trial.generator_runs:
for sqa_metric in sqa_generator_run.metrics:
sqa_metric.metric_type = 0
sqa_metric.metric_type = base_metric_type_int
imm_OC_and_SS = _get_experiment_immutable_opt_config_and_search_space(
experiment_name=experiment_name, exp_sqa_class=exp_sqa_class
)
assign_metric_on_gr = not reduced_state and not imm_OC_and_SS
if assign_metric_on_gr:
try:
for sqa_trial in experiment_sqa.trials:
for sqa_generator_run in sqa_trial.generator_runs:
for sqa_metric in sqa_generator_run.metrics:
sqa_metric.metric_type = base_metric_type_int
except DetachedInstanceError as e:
raise e

return decoder.experiment_from_sqa(
experiment_sqa=experiment_sqa,
Expand Down
203 changes: 115 additions & 88 deletions ax/storage/sqa_store/tests/test_sqa_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,30 +237,37 @@ def testLoadExperimentSkipMetricsAndRunners(self) -> None:
runner_registry=runner_registry,
)

# Save the experiment to db using the updated registries.
save_experiment(experiment, config=sqa_config)
for immutable in [True, False]:
experiment = get_experiment_with_custom_runner_and_metric(
constrain_search_space=False,
immutable=immutable,
)

# At this point try to load the experiment back without specifying
# updated registries. Confirm that this attempt fails.
with self.assertRaises(SQADecodeError):
loaded_experiment = load_experiment(self.experiment.name)
# Save the experiment to db using the updated registries.
save_experiment(experiment, config=sqa_config)

# Now load it with the skip_runners_and_metrics argument set.
# The experiment should load (i.e. no exceptions raised)
loaded_experiment = load_experiment(
self.experiment.name, skip_runners_and_metrics=True
)
# At this point try to load the experiment back without specifying
# updated registries. Confirm that this attempt fails.
with self.assertRaises(SQADecodeError):
loaded_experiment = load_experiment(experiment.name)

# Validate that:
# - the runner is not loaded
# - the metric is loaded as a base Metric class (i.e. not CustomTestMetric)
self.assertIs(loaded_experiment.runner, None)
self.assertTrue("custom_test_metric" in loaded_experiment.metrics)
self.assertEqual(
loaded_experiment.metrics["custom_test_metric"].__class__, Metric
)
self.assertEqual(len(loaded_experiment.trials), 1)
self.assertIs(loaded_experiment.trials[0].runner, None)
# Now load it with the skip_runners_and_metrics argument set.
# The experiment should load (i.e. no exceptions raised)
loaded_experiment = load_experiment(
experiment.name, skip_runners_and_metrics=True
)

# Validate that:
# - the runner is not loaded
# - the metric is loaded as a base Metric class, not CustomTestMetric
self.assertIs(loaded_experiment.runner, None)
self.assertTrue("custom_test_metric" in loaded_experiment.metrics)
self.assertEqual(
loaded_experiment.metrics["custom_test_metric"].__class__, Metric
)
self.assertEqual(len(loaded_experiment.trials), 1)
self.assertIs(loaded_experiment.trials[0].runner, None)
delete_experiment(exp_name=experiment.name)

@patch(
f"{Decoder.__module__}.Decoder.generator_run_from_sqa",
Expand All @@ -278,74 +285,94 @@ def testLoadExperimentSkipMetricsAndRunners(self) -> None:
def test_ExperimentSaveAndLoadReducedState(
self, _mock_exp_from_sqa, _mock_trial_from_sqa, _mock_gr_from_sqa
) -> None:
# 1. No abandoned arms + no trials case, reduced state should be the
# same as non-reduced state.
exp = get_experiment_with_multi_objective()
save_experiment(exp)
loaded_experiment = load_experiment(exp.name, reduced_state=True)
self.assertEqual(loaded_experiment, exp)
# Make sure decoder function was called with `reduced_state=True`.
self.assertTrue(_mock_exp_from_sqa.call_args[1].get("reduced_state"))
_mock_exp_from_sqa.reset_mock()

# 2. Try case with abandoned arms.
exp = get_experiment_with_batch_trial(constrain_search_space=False)
save_experiment(exp)
loaded_experiment = load_experiment(exp.name, reduced_state=True)
# Experiments are not the same, because one has abandoned arms info.
self.assertNotEqual(loaded_experiment, exp)
# Remove all abandoned arms and check that all else is equal as expected.
t = checked_cast(BatchTrial, exp.trials[0])
t._abandoned_arms_metadata = {}
self.assertEqual(loaded_experiment, exp)
# Make sure that all relevant decoding functions were called with
# `reduced_state=True` and correct number of times.
self.assertTrue(_mock_exp_from_sqa.call_args[1].get("reduced_state"))
self.assertTrue(_mock_trial_from_sqa.call_args[1].get("reduced_state"))
# 2 generator runs + regular and status quo.
self.assertTrue(_mock_gr_from_sqa.call_args[1].get("reduced_state"))
_mock_exp_from_sqa.reset_mock()
_mock_trial_from_sqa.reset_mock()
_mock_gr_from_sqa.reset_mock()
for skip_runners_and_metrics in [False, True]:
# 1. No abandoned arms + no trials case, reduced state should be the
# same as non-reduced state.
exp = get_experiment_with_multi_objective()
save_experiment(exp)
loaded_experiment = load_experiment(
exp.name,
reduced_state=True,
skip_runners_and_metrics=skip_runners_and_metrics,
)
loaded_experiment.runner = exp.runner
self.assertEqual(loaded_experiment, exp)
# Make sure decoder function was called with `reduced_state=True`.
self.assertTrue(_mock_exp_from_sqa.call_args[1].get("reduced_state"))
_mock_exp_from_sqa.reset_mock()
delete_experiment(exp_name=exp.name)

# 3. Try case with model state and search space + opt.config on a
# generator run in the experiment.
gr = Models.SOBOL(experiment=exp).gen(1)
# Expecting model kwargs to have 6 fields (seed, deduplicate, init_position,
# scramble, generated_points, fallback_to_sample_polytope)
# and the rest of model-state info on generator run to have values too.
mkw = gr._model_kwargs
self.assertIsNotNone(mkw)
self.assertEqual(len(mkw), 6)
bkw = gr._bridge_kwargs
self.assertIsNotNone(bkw)
self.assertEqual(len(bkw), 8)
ms = gr._model_state_after_gen
self.assertIsNotNone(ms)
self.assertEqual(len(ms), 2)
gm = gr._gen_metadata
self.assertIsNotNone(gm)
self.assertEqual(len(gm), 0)
self.assertIsNotNone(gr._search_space, gr.optimization_config)
exp.new_trial(generator_run=gr)
save_experiment(exp)
# Make sure that all relevant decoding functions were called with
# `reduced_state=True` and correct number of times.
loaded_experiment = load_experiment(exp.name, reduced_state=True)
self.assertTrue(_mock_exp_from_sqa.call_args[1].get("reduced_state"))
self.assertTrue(_mock_trial_from_sqa.call_args[1].get("reduced_state"))
# 2 generator runs from trial #0 + 1 from trial #1.
self.assertTrue(_mock_gr_from_sqa.call_args[1].get("reduced_state"))
self.assertNotEqual(loaded_experiment, exp)
# Remove all fields that are not part of the reduced state and
# check that everything else is equal as expected.
exp.trials.get(1).generator_run._model_kwargs = None
exp.trials.get(1).generator_run._bridge_kwargs = None
exp.trials.get(1).generator_run._gen_metadata = None
exp.trials.get(1).generator_run._model_state_after_gen = None
exp.trials.get(1).generator_run._search_space = None
exp.trials.get(1).generator_run._optimization_config = None
self.assertEqual(loaded_experiment, exp)
# 2. Try case with abandoned arms.
exp = get_experiment_with_batch_trial(constrain_search_space=False)
save_experiment(exp)
loaded_experiment = load_experiment(
exp.name,
reduced_state=True,
skip_runners_and_metrics=skip_runners_and_metrics,
)
# Experiments are not the same, because one has abandoned arms info.
self.assertNotEqual(loaded_experiment, exp)
# Remove all abandoned arms and check that all else is equal as expected.
t = checked_cast(BatchTrial, exp.trials[0])
t._abandoned_arms_metadata = {}
loaded_experiment.runner = exp.runner
loaded_experiment._trials[0]._runner = exp._trials[0]._runner
self.assertEqual(loaded_experiment, exp)
# Make sure that all relevant decoding functions were called with
# `reduced_state=True` and correct number of times.
self.assertTrue(_mock_exp_from_sqa.call_args[1].get("reduced_state"))
self.assertTrue(_mock_trial_from_sqa.call_args[1].get("reduced_state"))
# 2 generator runs + regular and status quo.
self.assertTrue(_mock_gr_from_sqa.call_args[1].get("reduced_state"))
_mock_exp_from_sqa.reset_mock()
_mock_trial_from_sqa.reset_mock()
_mock_gr_from_sqa.reset_mock()

# 3. Try case with model state and search space + opt.config on a
# generator run in the experiment.
gr = Models.SOBOL(experiment=exp).gen(1)
# Expecting model kwargs to have 6 fields (seed, deduplicate, init_position,
# scramble, generated_points, fallback_to_sample_polytope)
# and the rest of model-state info on generator run to have values too.
mkw = gr._model_kwargs
self.assertIsNotNone(mkw)
self.assertEqual(len(mkw), 6)
bkw = gr._bridge_kwargs
self.assertIsNotNone(bkw)
self.assertEqual(len(bkw), 8)
ms = gr._model_state_after_gen
self.assertIsNotNone(ms)
self.assertEqual(len(ms), 2)
gm = gr._gen_metadata
self.assertIsNotNone(gm)
self.assertEqual(len(gm), 0)
self.assertIsNotNone(gr._search_space, gr.optimization_config)
exp.new_trial(generator_run=gr)
save_experiment(exp)
# Make sure that all relevant decoding functions were called with
# `reduced_state=True` and correct number of times.
loaded_experiment = load_experiment(
exp.name,
reduced_state=True,
skip_runners_and_metrics=skip_runners_and_metrics,
)
loaded_experiment.runner = exp.runner
loaded_experiment._trials[0]._runner = exp._trials[0]._runner
self.assertTrue(_mock_exp_from_sqa.call_args[1].get("reduced_state"))
self.assertTrue(_mock_trial_from_sqa.call_args[1].get("reduced_state"))
# 2 generator runs from trial #0 + 1 from trial #1.
self.assertTrue(_mock_gr_from_sqa.call_args[1].get("reduced_state"))
self.assertNotEqual(loaded_experiment, exp)
# Remove all fields that are not part of the reduced state and
# check that everything else is equal as expected.
exp.trials.get(1).generator_run._model_kwargs = None
exp.trials.get(1).generator_run._bridge_kwargs = None
exp.trials.get(1).generator_run._gen_metadata = None
exp.trials.get(1).generator_run._model_state_after_gen = None
exp.trials.get(1).generator_run._search_space = None
exp.trials.get(1).generator_run._optimization_config = None
self.assertEqual(loaded_experiment, exp)
delete_experiment(exp_name=exp.name)

def testMTExperimentSaveAndLoad(self) -> None:
experiment = get_multi_type_experiment(add_trials=True)
Expand Down
5 changes: 5 additions & 0 deletions ax/utils/testing/core_stubs.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@
from ax.models.winsorization_config import WinsorizationConfig
from ax.runners.synthetic import SyntheticRunner
from ax.service.utils.scheduler_options import SchedulerOptions, TrialType
from ax.utils.common.constants import Keys
from ax.utils.common.logger import get_logger
from ax.utils.common.typeutils import checked_cast, not_none
from ax.utils.measurement.synthetic_functions import branin
Expand Down Expand Up @@ -146,6 +147,7 @@ def get_experiment_with_map_data_type() -> Experiment:

def get_experiment_with_custom_runner_and_metric(
constrain_search_space: bool = True,
immutable: bool = False,
) -> Experiment:

# Create experiment with custom runner and metric
Expand All @@ -172,6 +174,9 @@ def get_experiment_with_custom_runner_and_metric(
experiment.attach_data(get_data(metric_name="custom_test_metric"))
trial.mark_completed()

if immutable:
experiment._properties = {Keys.IMMUTABLE_SEARCH_SPACE_AND_OPT_CONF: True}

return experiment


Expand Down

0 comments on commit 49ac816

Please sign in to comment.