Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Load Experiment without runners and metrics in the case where search space and optimization config are immutable #1656

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 20 additions & 10 deletions ax/storage/sqa_store/load.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

from ax.core.experiment import Experiment
from ax.core.generator_run import GeneratorRun
from ax.core.metric import Metric
from ax.core.trial import Trial
from ax.exceptions.core import ExperimentNotFoundError, ObjectNotFoundError
from ax.modelbridge.generation_strategy import GenerationStrategy
Expand All @@ -28,6 +29,7 @@
from ax.utils.common.constants import Keys
from ax.utils.common.typeutils import not_none
from sqlalchemy.orm import defaultload, lazyload, noload
from sqlalchemy.orm.exc import DetachedInstanceError


# ---------------------------- Loading `Experiment`. ---------------------------
Expand Down Expand Up @@ -99,15 +101,14 @@ def _load_experiment(
# pyre-ignore Incompatible variable type [9]: trial_sqa_class is decl. to have type
# `Type[SQATrial]` but is used as type `Type[ax.storage.sqa_store.db.SQABase]`
trial_sqa_class: Type[SQATrial] = decoder.config.class_to_sqa_class[Trial]
imm_OC_and_SS = _get_experiment_immutable_opt_config_and_search_space(
experiment_name=experiment_name, exp_sqa_class=exp_sqa_class
)

if reduced_state:
_get_experiment_sqa_func = _get_experiment_sqa_reduced_state

else:
imm_OC_and_SS = _get_experiment_immutable_opt_config_and_search_space(
experiment_name=experiment_name, exp_sqa_class=exp_sqa_class
)

_get_experiment_sqa_func = (
_get_experiment_sqa_immutable_opt_config_and_search_space
if imm_OC_and_SS
Expand All @@ -132,13 +133,22 @@ def _load_experiment(
# "lower_is_better" or any other attribute we want to include. This can be
# implemented in the future if we need to.
if skip_runners_and_metrics:
base_metric_type_int = decoder.config.metric_registry[Metric]
for sqa_metric in experiment_sqa.metrics:
sqa_metric.metric_type = 0

for sqa_trial in experiment_sqa.trials:
for sqa_generator_run in sqa_trial.generator_runs:
for sqa_metric in sqa_generator_run.metrics:
sqa_metric.metric_type = 0
sqa_metric.metric_type = base_metric_type_int
assign_metric_on_gr = not reduced_state and not imm_OC_and_SS
if assign_metric_on_gr:
try:
for sqa_trial in experiment_sqa.trials:
for sqa_generator_run in sqa_trial.generator_runs:
for sqa_metric in sqa_generator_run.metrics:
sqa_metric.metric_type = base_metric_type_int
except DetachedInstanceError as e:
raise DetachedInstanceError(
"Unable to retrieve metric from SQA generator run, possibly due "
"to parts of the experiment being lazy-loaded. This is not "
f"expected state, please contact Ax support. Original error: {e}"
)

return decoder.experiment_from_sqa(
experiment_sqa=experiment_sqa,
Expand Down
222 changes: 133 additions & 89 deletions ax/storage/sqa_store/tests/test_sqa_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from unittest.mock import MagicMock, Mock, patch

from ax.core.arm import Arm
from ax.core.batch_trial import LifecycleStage
from ax.core.batch_trial import BatchTrial, LifecycleStage
from ax.core.generator_run import GeneratorRun
from ax.core.metric import Metric
from ax.core.objective import Objective
Expand Down Expand Up @@ -76,7 +76,7 @@
from ax.utils.common.logger import get_logger
from ax.utils.common.serialization import serialize_init_args
from ax.utils.common.testutils import TestCase
from ax.utils.common.typeutils import not_none
from ax.utils.common.typeutils import checked_cast, not_none
from ax.utils.testing.core_stubs import (
CustomTestMetric,
CustomTestRunner,
Expand Down Expand Up @@ -210,7 +210,9 @@ def test_LoadExperimentTrialsInBatches(self) -> None:
# load experiments with custom runners and metrics without a decoder.
def test_LoadExperimentSkipMetricsAndRunners(self) -> None:
# Create a test experiment with a custom metric and runner.
experiment = get_experiment_with_custom_runner_and_metric()
experiment = get_experiment_with_custom_runner_and_metric(
constrain_search_space=False
)

# Note that the experiment is created outside of the test code.
# Confirm that it uses the custom runner and metric
Expand Down Expand Up @@ -241,30 +243,37 @@ def test_LoadExperimentSkipMetricsAndRunners(self) -> None:
runner_registry=runner_registry,
)

# Save the experiment to db using the updated registries.
save_experiment(experiment, config=sqa_config)
for immutable in [True, False]:
experiment = get_experiment_with_custom_runner_and_metric(
constrain_search_space=False,
immutable=immutable,
)

# At this point try to load the experiment back without specifying
# updated registries. Confirm that this attempt fails.
with self.assertRaises(SQADecodeError):
loaded_experiment = load_experiment(self.experiment.name)
# Save the experiment to db using the updated registries.
save_experiment(experiment, config=sqa_config)

# Now load it with the skip_runners_and_metrics argument set.
# The experiment should load (i.e. no exceptions raised)
loaded_experiment = load_experiment(
self.experiment.name, skip_runners_and_metrics=True
)
# At this point try to load the experiment back without specifying
# updated registries. Confirm that this attempt fails.
with self.assertRaises(SQADecodeError):
loaded_experiment = load_experiment(experiment.name)

# Validate that:
# - the runner is not loaded
# - the metric is loaded as a base Metric class (i.e. not CustomTestMetric)
self.assertIs(loaded_experiment.runner, None)
self.assertTrue("custom_test_metric" in loaded_experiment.metrics)
self.assertEqual(
loaded_experiment.metrics["custom_test_metric"].__class__, Metric
)
self.assertEqual(len(loaded_experiment.trials), 1)
self.assertIs(loaded_experiment.trials[0].runner, None)
# Now load it with the skip_runners_and_metrics argument set.
# The experiment should load (i.e. no exceptions raised)
loaded_experiment = load_experiment(
experiment.name, skip_runners_and_metrics=True
)

# Validate that:
# - the runner is not loaded
# - the metric is loaded as a base Metric class, not CustomTestMetric
self.assertIs(loaded_experiment.runner, None)
self.assertTrue("custom_test_metric" in loaded_experiment.metrics)
self.assertEqual(
loaded_experiment.metrics["custom_test_metric"].__class__, Metric
)
self.assertEqual(len(loaded_experiment.trials), 1)
self.assertIs(loaded_experiment.trials[0].runner, None)
delete_experiment(exp_name=experiment.name)

@patch(
f"{Decoder.__module__}.Decoder.generator_run_from_sqa",
Expand All @@ -282,73 +291,108 @@ def test_LoadExperimentSkipMetricsAndRunners(self) -> None:
def test_ExperimentSaveAndLoadReducedState(
self, _mock_exp_from_sqa, _mock_trial_from_sqa, _mock_gr_from_sqa
) -> None:
# 1. No abandoned arms + no trials case, reduced state should be the
# same as non-reduced state.
exp = get_experiment_with_multi_objective()
save_experiment(exp)
loaded_experiment = load_experiment(exp.name, reduced_state=True)
self.assertEqual(loaded_experiment, exp)
# Make sure decoder function was called with `reduced_state=True`.
self.assertTrue(_mock_exp_from_sqa.call_args[1].get("reduced_state"))
_mock_exp_from_sqa.reset_mock()

# 2. Try case with abandoned arms.
exp = self.experiment
save_experiment(exp)
loaded_experiment = load_experiment(exp.name, reduced_state=True)
# Experiments are not the same, because one has abandoned arms info.
self.assertNotEqual(loaded_experiment, exp)
# Remove all abandoned arms and check that all else is equal as expected.
exp.trials.get(0)._abandoned_arms_metadata = {}
self.assertEqual(loaded_experiment, exp)
# Make sure that all relevant decoding functions were called with
# `reduced_state=True` and correct number of times.
self.assertTrue(_mock_exp_from_sqa.call_args[1].get("reduced_state"))
self.assertTrue(_mock_trial_from_sqa.call_args[1].get("reduced_state"))
# 2 generator runs + regular and status quo.
self.assertTrue(_mock_gr_from_sqa.call_args[1].get("reduced_state"))
_mock_exp_from_sqa.reset_mock()
_mock_trial_from_sqa.reset_mock()
_mock_gr_from_sqa.reset_mock()
for skip_runners_and_metrics in [False, True]:
# 1. No abandoned arms + no trials case, reduced state should be the
# same as non-reduced state.
exp = get_experiment_with_multi_objective()
save_experiment(exp)
loaded_experiment = load_experiment(
exp.name,
reduced_state=True,
skip_runners_and_metrics=skip_runners_and_metrics,
)
loaded_experiment.runner = exp.runner
self.assertEqual(loaded_experiment, exp)
# Make sure decoder function was called with `reduced_state=True`.
self.assertTrue(_mock_exp_from_sqa.call_args[1].get("reduced_state"))
_mock_exp_from_sqa.reset_mock()
delete_experiment(exp_name=exp.name)

# 3. Try case with model state and search space + opt.config on a
# generator run in the experiment.
gr = Models.SOBOL(experiment=exp).gen(1)
# Expecting model kwargs to have 6 fields (seed, deduplicate, init_position,
# scramble, generated_points, fallback_to_sample_polytope)
# and the rest of model-state info on generator run to have values too.
mkw = gr._model_kwargs
self.assertIsNotNone(mkw)
self.assertEqual(len(mkw), 6)
bkw = gr._bridge_kwargs
self.assertIsNotNone(bkw)
self.assertEqual(len(bkw), 9)
ms = gr._model_state_after_gen
self.assertIsNotNone(ms)
self.assertEqual(len(ms), 2)
gm = gr._gen_metadata
self.assertIsNotNone(gm)
self.assertEqual(len(gm), 0)
self.assertIsNotNone(gr._search_space, gr.optimization_config)
# 2. Try case with abandoned arms.
exp = get_experiment_with_batch_trial(constrain_search_space=False)
save_experiment(exp)
loaded_experiment = load_experiment(
exp.name,
reduced_state=True,
skip_runners_and_metrics=skip_runners_and_metrics,
)
# Experiments are not the same, because one has abandoned arms info.
self.assertNotEqual(loaded_experiment, exp)
# Remove all abandoned arms and check that all else is equal as expected.
t = checked_cast(BatchTrial, exp.trials[0])
t._abandoned_arms_metadata = {}
loaded_experiment.runner = exp.runner
loaded_experiment._trials[0]._runner = exp._trials[0]._runner
self.assertEqual(loaded_experiment, exp)
# Make sure that all relevant decoding functions were called with
# `reduced_state=True` and correct number of times.
self.assertTrue(_mock_exp_from_sqa.call_args[1].get("reduced_state"))
self.assertTrue(_mock_trial_from_sqa.call_args[1].get("reduced_state"))
# 2 generator runs + regular and status quo.
self.assertTrue(_mock_gr_from_sqa.call_args[1].get("reduced_state"))
_mock_exp_from_sqa.reset_mock()
_mock_trial_from_sqa.reset_mock()
_mock_gr_from_sqa.reset_mock()

# 3. Try case with model state and search space + opt.config on a
# generator run in the experiment.
gr = Models.SOBOL(experiment=exp).gen(1)
# Expecting model kwargs to have 6 fields (seed, deduplicate, init_position,
# scramble, generated_points, fallback_to_sample_polytope)
# and the rest of model-state info on generator run to have values too.
mkw = gr._model_kwargs
self.assertIsNotNone(mkw)
self.assertEqual(len(mkw), 6)
bkw = gr._bridge_kwargs
self.assertIsNotNone(bkw)
self.assertEqual(len(bkw), 9)
ms = gr._model_state_after_gen
self.assertIsNotNone(ms)
self.assertEqual(len(ms), 2)
gm = gr._gen_metadata
self.assertIsNotNone(gm)
self.assertEqual(len(gm), 0)
self.assertIsNotNone(gr._search_space, gr.optimization_config)
exp.new_trial(generator_run=gr)
save_experiment(exp)
# Make sure that all relevant decoding functions were called with
# `reduced_state=True` and correct number of times.
loaded_experiment = load_experiment(
exp.name,
reduced_state=True,
skip_runners_and_metrics=skip_runners_and_metrics,
)
loaded_experiment.runner = exp.runner
loaded_experiment._trials[0]._runner = exp._trials[0]._runner
self.assertTrue(_mock_exp_from_sqa.call_args[1].get("reduced_state"))
self.assertTrue(_mock_trial_from_sqa.call_args[1].get("reduced_state"))
# 2 generator runs from trial #0 + 1 from trial #1.
self.assertTrue(_mock_gr_from_sqa.call_args[1].get("reduced_state"))
self.assertNotEqual(loaded_experiment, exp)
# Remove all fields that are not part of the reduced state and
# check that everything else is equal as expected.
exp.trials.get(1).generator_run._model_kwargs = None
exp.trials.get(1).generator_run._bridge_kwargs = None
exp.trials.get(1).generator_run._gen_metadata = None
exp.trials.get(1).generator_run._model_state_after_gen = None
exp.trials.get(1).generator_run._search_space = None
exp.trials.get(1).generator_run._optimization_config = None
self.assertEqual(loaded_experiment, exp)
delete_experiment(exp_name=exp.name)

def test_ExperimentSaveAndLoadGRWithOptConfig(self) -> None:
exp = get_experiment_with_batch_trial(constrain_search_space=False)
gr = Models.SOBOL(experiment=exp).gen(
n=1, optimization_config=exp.optimization_config
)
exp.new_trial(generator_run=gr)
save_experiment(exp)
# Make sure that all relevant decoding functions were called with
# `reduced_state=True` and correct number of times.
loaded_experiment = load_experiment(exp.name, reduced_state=True)
self.assertTrue(_mock_exp_from_sqa.call_args[1].get("reduced_state"))
self.assertTrue(_mock_trial_from_sqa.call_args[1].get("reduced_state"))
# 2 generator runs from trial #0 + 1 from trial #1.
self.assertTrue(_mock_gr_from_sqa.call_args[1].get("reduced_state"))
self.assertNotEqual(loaded_experiment, exp)
# Remove all fields that are not part of the reduced state and
# check that everything else is equal as expected.
exp.trials.get(1).generator_run._model_kwargs = None
exp.trials.get(1).generator_run._bridge_kwargs = None
exp.trials.get(1).generator_run._gen_metadata = None
exp.trials.get(1).generator_run._model_state_after_gen = None
exp.trials.get(1).generator_run._search_space = None
exp.trials.get(1).generator_run._optimization_config = None
self.assertEqual(loaded_experiment, exp)
loaded_experiment = load_experiment(
exp.name,
reduced_state=False,
skip_runners_and_metrics=True,
)
self.assertEqual(loaded_experiment.trials[1], exp.trials[1])

def test_MTExperimentSaveAndLoad(self) -> None:
experiment = get_multi_type_experiment(add_trials=True)
Expand Down Expand Up @@ -1615,7 +1659,7 @@ def test_ImmutableSearchSpaceAndOptConfigLoading(
_mock_get_gs_sqa_imm_oc_ss,
_mock_gr_from_sqa,
) -> None:
experiment = get_experiment_with_batch_trial()
experiment = get_experiment_with_batch_trial(constrain_search_space=False)
experiment._properties = {Keys.IMMUTABLE_SEARCH_SPACE_AND_OPT_CONF: True}
save_experiment(experiment)

Expand Down
Loading