Skip to content

Commit

Permalink
Merge a6d9d59 into aada4cf
Browse files Browse the repository at this point in the history
  • Loading branch information
bernardbeckerman authored Nov 27, 2023
2 parents aada4cf + a6d9d59 commit 7f24c14
Show file tree
Hide file tree
Showing 3 changed files with 184 additions and 112 deletions.
30 changes: 20 additions & 10 deletions ax/storage/sqa_store/load.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

from ax.core.experiment import Experiment
from ax.core.generator_run import GeneratorRun
from ax.core.metric import Metric
from ax.core.trial import Trial
from ax.exceptions.core import ExperimentNotFoundError, ObjectNotFoundError
from ax.modelbridge.generation_strategy import GenerationStrategy
Expand All @@ -28,6 +29,7 @@
from ax.utils.common.constants import Keys
from ax.utils.common.typeutils import not_none
from sqlalchemy.orm import defaultload, lazyload, noload
from sqlalchemy.orm.exc import DetachedInstanceError


# ---------------------------- Loading `Experiment`. ---------------------------
Expand Down Expand Up @@ -99,15 +101,14 @@ def _load_experiment(
# pyre-ignore Incompatible variable type [9]: trial_sqa_class is decl. to have type
# `Type[SQATrial]` but is used as type `Type[ax.storage.sqa_store.db.SQABase]`
trial_sqa_class: Type[SQATrial] = decoder.config.class_to_sqa_class[Trial]
imm_OC_and_SS = _get_experiment_immutable_opt_config_and_search_space(
experiment_name=experiment_name, exp_sqa_class=exp_sqa_class
)

if reduced_state:
_get_experiment_sqa_func = _get_experiment_sqa_reduced_state

else:
imm_OC_and_SS = _get_experiment_immutable_opt_config_and_search_space(
experiment_name=experiment_name, exp_sqa_class=exp_sqa_class
)

_get_experiment_sqa_func = (
_get_experiment_sqa_immutable_opt_config_and_search_space
if imm_OC_and_SS
Expand All @@ -132,13 +133,22 @@ def _load_experiment(
# "lower_is_better" or any other attribute we want to include. This can be
# implemented in the future if we need to.
if skip_runners_and_metrics:
base_metric_type_int = decoder.config.metric_registry[Metric]
for sqa_metric in experiment_sqa.metrics:
sqa_metric.metric_type = 0

for sqa_trial in experiment_sqa.trials:
for sqa_generator_run in sqa_trial.generator_runs:
for sqa_metric in sqa_generator_run.metrics:
sqa_metric.metric_type = 0
sqa_metric.metric_type = base_metric_type_int
assign_metric_on_gr = not reduced_state and not imm_OC_and_SS
if assign_metric_on_gr:
try:
for sqa_trial in experiment_sqa.trials:
for sqa_generator_run in sqa_trial.generator_runs:
for sqa_metric in sqa_generator_run.metrics:
sqa_metric.metric_type = base_metric_type_int
except DetachedInstanceError as e:
raise DetachedInstanceError(
"Unable to retrieve metric from SQA generator run, possibly due "
"to parts of the experiment being lazy-loaded. This is not "
f"expected state, please contact Ax support. Original error: {e}"
)

return decoder.experiment_from_sqa(
experiment_sqa=experiment_sqa,
Expand Down
222 changes: 133 additions & 89 deletions ax/storage/sqa_store/tests/test_sqa_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from unittest.mock import MagicMock, Mock, patch

from ax.core.arm import Arm
from ax.core.batch_trial import LifecycleStage
from ax.core.batch_trial import BatchTrial, LifecycleStage
from ax.core.generator_run import GeneratorRun
from ax.core.metric import Metric
from ax.core.objective import Objective
Expand Down Expand Up @@ -76,7 +76,7 @@
from ax.utils.common.logger import get_logger
from ax.utils.common.serialization import serialize_init_args
from ax.utils.common.testutils import TestCase
from ax.utils.common.typeutils import not_none
from ax.utils.common.typeutils import checked_cast, not_none
from ax.utils.testing.core_stubs import (
CustomTestMetric,
CustomTestRunner,
Expand Down Expand Up @@ -210,7 +210,9 @@ def test_LoadExperimentTrialsInBatches(self) -> None:
# load experiments with custom runners and metrics without a decoder.
def test_LoadExperimentSkipMetricsAndRunners(self) -> None:
# Create a test experiment with a custom metric and runner.
experiment = get_experiment_with_custom_runner_and_metric()
experiment = get_experiment_with_custom_runner_and_metric(
constrain_search_space=False
)

# Note that the experiment is created outside of the test code.
# Confirm that it uses the custom runner and metric
Expand Down Expand Up @@ -241,30 +243,37 @@ def test_LoadExperimentSkipMetricsAndRunners(self) -> None:
runner_registry=runner_registry,
)

# Save the experiment to db using the updated registries.
save_experiment(experiment, config=sqa_config)
for immutable in [True, False]:
experiment = get_experiment_with_custom_runner_and_metric(
constrain_search_space=False,
immutable=immutable,
)

# At this point try to load the experiment back without specifying
# updated registries. Confirm that this attempt fails.
with self.assertRaises(SQADecodeError):
loaded_experiment = load_experiment(self.experiment.name)
# Save the experiment to db using the updated registries.
save_experiment(experiment, config=sqa_config)

# Now load it with the skip_runners_and_metrics argument set.
# The experiment should load (i.e. no exceptions raised)
loaded_experiment = load_experiment(
self.experiment.name, skip_runners_and_metrics=True
)
# At this point try to load the experiment back without specifying
# updated registries. Confirm that this attempt fails.
with self.assertRaises(SQADecodeError):
loaded_experiment = load_experiment(experiment.name)

# Validate that:
# - the runner is not loaded
# - the metric is loaded as a base Metric class (i.e. not CustomTestMetric)
self.assertIs(loaded_experiment.runner, None)
self.assertTrue("custom_test_metric" in loaded_experiment.metrics)
self.assertEqual(
loaded_experiment.metrics["custom_test_metric"].__class__, Metric
)
self.assertEqual(len(loaded_experiment.trials), 1)
self.assertIs(loaded_experiment.trials[0].runner, None)
# Now load it with the skip_runners_and_metrics argument set.
# The experiment should load (i.e. no exceptions raised)
loaded_experiment = load_experiment(
experiment.name, skip_runners_and_metrics=True
)

# Validate that:
# - the runner is not loaded
# - the metric is loaded as a base Metric class, not CustomTestMetric
self.assertIs(loaded_experiment.runner, None)
self.assertTrue("custom_test_metric" in loaded_experiment.metrics)
self.assertEqual(
loaded_experiment.metrics["custom_test_metric"].__class__, Metric
)
self.assertEqual(len(loaded_experiment.trials), 1)
self.assertIs(loaded_experiment.trials[0].runner, None)
delete_experiment(exp_name=experiment.name)

@patch(
f"{Decoder.__module__}.Decoder.generator_run_from_sqa",
Expand All @@ -282,73 +291,108 @@ def test_LoadExperimentSkipMetricsAndRunners(self) -> None:
def test_ExperimentSaveAndLoadReducedState(
self, _mock_exp_from_sqa, _mock_trial_from_sqa, _mock_gr_from_sqa
) -> None:
# 1. No abandoned arms + no trials case, reduced state should be the
# same as non-reduced state.
exp = get_experiment_with_multi_objective()
save_experiment(exp)
loaded_experiment = load_experiment(exp.name, reduced_state=True)
self.assertEqual(loaded_experiment, exp)
# Make sure decoder function was called with `reduced_state=True`.
self.assertTrue(_mock_exp_from_sqa.call_args[1].get("reduced_state"))
_mock_exp_from_sqa.reset_mock()

# 2. Try case with abandoned arms.
exp = self.experiment
save_experiment(exp)
loaded_experiment = load_experiment(exp.name, reduced_state=True)
# Experiments are not the same, because one has abandoned arms info.
self.assertNotEqual(loaded_experiment, exp)
# Remove all abandoned arms and check that all else is equal as expected.
exp.trials.get(0)._abandoned_arms_metadata = {}
self.assertEqual(loaded_experiment, exp)
# Make sure that all relevant decoding functions were called with
# `reduced_state=True` and correct number of times.
self.assertTrue(_mock_exp_from_sqa.call_args[1].get("reduced_state"))
self.assertTrue(_mock_trial_from_sqa.call_args[1].get("reduced_state"))
# 2 generator runs + regular and status quo.
self.assertTrue(_mock_gr_from_sqa.call_args[1].get("reduced_state"))
_mock_exp_from_sqa.reset_mock()
_mock_trial_from_sqa.reset_mock()
_mock_gr_from_sqa.reset_mock()
for skip_runners_and_metrics in [False, True]:
# 1. No abandoned arms + no trials case, reduced state should be the
# same as non-reduced state.
exp = get_experiment_with_multi_objective()
save_experiment(exp)
loaded_experiment = load_experiment(
exp.name,
reduced_state=True,
skip_runners_and_metrics=skip_runners_and_metrics,
)
loaded_experiment.runner = exp.runner
self.assertEqual(loaded_experiment, exp)
# Make sure decoder function was called with `reduced_state=True`.
self.assertTrue(_mock_exp_from_sqa.call_args[1].get("reduced_state"))
_mock_exp_from_sqa.reset_mock()
delete_experiment(exp_name=exp.name)

# 3. Try case with model state and search space + opt.config on a
# generator run in the experiment.
gr = Models.SOBOL(experiment=exp).gen(1)
# Expecting model kwargs to have 6 fields (seed, deduplicate, init_position,
# scramble, generated_points, fallback_to_sample_polytope)
# and the rest of model-state info on generator run to have values too.
mkw = gr._model_kwargs
self.assertIsNotNone(mkw)
self.assertEqual(len(mkw), 6)
bkw = gr._bridge_kwargs
self.assertIsNotNone(bkw)
self.assertEqual(len(bkw), 9)
ms = gr._model_state_after_gen
self.assertIsNotNone(ms)
self.assertEqual(len(ms), 2)
gm = gr._gen_metadata
self.assertIsNotNone(gm)
self.assertEqual(len(gm), 0)
self.assertIsNotNone(gr._search_space, gr.optimization_config)
# 2. Try case with abandoned arms.
exp = get_experiment_with_batch_trial(constrain_search_space=False)
save_experiment(exp)
loaded_experiment = load_experiment(
exp.name,
reduced_state=True,
skip_runners_and_metrics=skip_runners_and_metrics,
)
# Experiments are not the same, because one has abandoned arms info.
self.assertNotEqual(loaded_experiment, exp)
# Remove all abandoned arms and check that all else is equal as expected.
t = checked_cast(BatchTrial, exp.trials[0])
t._abandoned_arms_metadata = {}
loaded_experiment.runner = exp.runner
loaded_experiment._trials[0]._runner = exp._trials[0]._runner
self.assertEqual(loaded_experiment, exp)
# Make sure that all relevant decoding functions were called with
# `reduced_state=True` and correct number of times.
self.assertTrue(_mock_exp_from_sqa.call_args[1].get("reduced_state"))
self.assertTrue(_mock_trial_from_sqa.call_args[1].get("reduced_state"))
# 2 generator runs + regular and status quo.
self.assertTrue(_mock_gr_from_sqa.call_args[1].get("reduced_state"))
_mock_exp_from_sqa.reset_mock()
_mock_trial_from_sqa.reset_mock()
_mock_gr_from_sqa.reset_mock()

# 3. Try case with model state and search space + opt.config on a
# generator run in the experiment.
gr = Models.SOBOL(experiment=exp).gen(1)
# Expecting model kwargs to have 6 fields (seed, deduplicate, init_position,
# scramble, generated_points, fallback_to_sample_polytope)
# and the rest of model-state info on generator run to have values too.
mkw = gr._model_kwargs
self.assertIsNotNone(mkw)
self.assertEqual(len(mkw), 6)
bkw = gr._bridge_kwargs
self.assertIsNotNone(bkw)
self.assertEqual(len(bkw), 9)
ms = gr._model_state_after_gen
self.assertIsNotNone(ms)
self.assertEqual(len(ms), 2)
gm = gr._gen_metadata
self.assertIsNotNone(gm)
self.assertEqual(len(gm), 0)
self.assertIsNotNone(gr._search_space, gr.optimization_config)
exp.new_trial(generator_run=gr)
save_experiment(exp)
# Make sure that all relevant decoding functions were called with
# `reduced_state=True` and correct number of times.
loaded_experiment = load_experiment(
exp.name,
reduced_state=True,
skip_runners_and_metrics=skip_runners_and_metrics,
)
loaded_experiment.runner = exp.runner
loaded_experiment._trials[0]._runner = exp._trials[0]._runner
self.assertTrue(_mock_exp_from_sqa.call_args[1].get("reduced_state"))
self.assertTrue(_mock_trial_from_sqa.call_args[1].get("reduced_state"))
# 2 generator runs from trial #0 + 1 from trial #1.
self.assertTrue(_mock_gr_from_sqa.call_args[1].get("reduced_state"))
self.assertNotEqual(loaded_experiment, exp)
# Remove all fields that are not part of the reduced state and
# check that everything else is equal as expected.
exp.trials.get(1).generator_run._model_kwargs = None
exp.trials.get(1).generator_run._bridge_kwargs = None
exp.trials.get(1).generator_run._gen_metadata = None
exp.trials.get(1).generator_run._model_state_after_gen = None
exp.trials.get(1).generator_run._search_space = None
exp.trials.get(1).generator_run._optimization_config = None
self.assertEqual(loaded_experiment, exp)
delete_experiment(exp_name=exp.name)

def test_ExperimentSaveAndLoadGRWithOptConfig(self) -> None:
exp = get_experiment_with_batch_trial(constrain_search_space=False)
gr = Models.SOBOL(experiment=exp).gen(
n=1, optimization_config=exp.optimization_config
)
exp.new_trial(generator_run=gr)
save_experiment(exp)
# Make sure that all relevant decoding functions were called with
# `reduced_state=True` and correct number of times.
loaded_experiment = load_experiment(exp.name, reduced_state=True)
self.assertTrue(_mock_exp_from_sqa.call_args[1].get("reduced_state"))
self.assertTrue(_mock_trial_from_sqa.call_args[1].get("reduced_state"))
# 2 generator runs from trial #0 + 1 from trial #1.
self.assertTrue(_mock_gr_from_sqa.call_args[1].get("reduced_state"))
self.assertNotEqual(loaded_experiment, exp)
# Remove all fields that are not part of the reduced state and
# check that everything else is equal as expected.
exp.trials.get(1).generator_run._model_kwargs = None
exp.trials.get(1).generator_run._bridge_kwargs = None
exp.trials.get(1).generator_run._gen_metadata = None
exp.trials.get(1).generator_run._model_state_after_gen = None
exp.trials.get(1).generator_run._search_space = None
exp.trials.get(1).generator_run._optimization_config = None
self.assertEqual(loaded_experiment, exp)
loaded_experiment = load_experiment(
exp.name,
reduced_state=False,
skip_runners_and_metrics=True,
)
self.assertEqual(loaded_experiment.trials[1], exp.trials[1])

def test_MTExperimentSaveAndLoad(self) -> None:
experiment = get_multi_type_experiment(add_trials=True)
Expand Down Expand Up @@ -1615,7 +1659,7 @@ def test_ImmutableSearchSpaceAndOptConfigLoading(
_mock_get_gs_sqa_imm_oc_ss,
_mock_gr_from_sqa,
) -> None:
experiment = get_experiment_with_batch_trial()
experiment = get_experiment_with_batch_trial(constrain_search_space=False)
experiment._properties = {Keys.IMMUTABLE_SEARCH_SPACE_AND_OPT_CONF: True}
save_experiment(experiment)

Expand Down
Loading

0 comments on commit 7f24c14

Please sign in to comment.