Skip to content

Commit

Permalink
Load Experiment without runners and metrics in the case where search …
Browse files Browse the repository at this point in the history
…space and optimization config are immutable (facebook#1656)

Summary:
Pull Request resolved: facebook#1656

`load_experiment` was previously failing when `skip_runners_and_metrics=True` for experiments with immutable search space and optimization config. See [Lena's comment](https://www.internalfb.com/diff/D46595953?dst_version_fbid=263061986396665&transaction_fbid=639853311380999) for more detail.

Differential Revision: D46595953

fbshipit-source-id: 89618fbeb19c47392f30709856730721d01ffbd9
  • Loading branch information
Bernie Beckerman authored and facebook-github-bot committed Jun 13, 2023
1 parent d9ec05b commit 062b975
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 27 deletions.
12 changes: 8 additions & 4 deletions ax/storage/sqa_store/load.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from ax.utils.common.constants import Keys
from ax.utils.common.typeutils import not_none
from sqlalchemy.orm import defaultload, lazyload, noload
from sqlalchemy.orm.exc import DetachedInstanceError


# ---------------------------- Loading `Experiment`. ---------------------------
Expand Down Expand Up @@ -135,10 +136,13 @@ def _load_experiment(
for sqa_metric in experiment_sqa.metrics:
sqa_metric.metric_type = 0

for sqa_trial in experiment_sqa.trials:
for sqa_generator_run in sqa_trial.generator_runs:
for sqa_metric in sqa_generator_run.metrics:
sqa_metric.metric_type = 0
try:
for sqa_trial in experiment_sqa.trials:
for sqa_generator_run in sqa_trial.generator_runs:
for sqa_metric in sqa_generator_run.metrics:
sqa_metric.metric_type = 0
except DetachedInstanceError:
pass

return decoder.experiment_from_sqa(
experiment_sqa=experiment_sqa,
Expand Down
56 changes: 33 additions & 23 deletions ax/storage/sqa_store/tests/test_sqa_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,30 +237,40 @@ def testLoadExperimentSkipMetricsAndRunners(self) -> None:
runner_registry=runner_registry,
)

# Save the experiment to db using the updated registries.
save_experiment(experiment, config=sqa_config)

# At this point try to load the experiment back without specifying
# updated registries. Confirm that this attempt fails.
with self.assertRaises(SQADecodeError):
loaded_experiment = load_experiment(self.experiment.name)

# Now load it with the skip_runners_and_metrics argument set.
# The experiment should load (i.e. no exceptions raised)
loaded_experiment = load_experiment(
self.experiment.name, skip_runners_and_metrics=True
)
for immutable in [True, False]:
experiment = get_experiment_with_custom_runner_and_metric(
constrain_search_space=False
)
if immutable:
experiment._properties = {
Keys.IMMUTABLE_SEARCH_SPACE_AND_OPT_CONF: True
}

# Save the experiment to db using the updated registries.
save_experiment(experiment, config=sqa_config)

# At this point try to load the experiment back without specifying
# updated registries. Confirm that this attempt fails.
with self.assertRaises(SQADecodeError):
loaded_experiment = load_experiment(experiment.name)

# Now load it with the skip_runners_and_metrics argument set.
# The experiment should load (i.e. no exceptions raised)
loaded_experiment = load_experiment(
experiment.name, skip_runners_and_metrics=True
)

# Validate that:
# - the runner is not loaded
# - the metric is loaded as a base Metric class (i.e. not CustomTestMetric)
self.assertIs(loaded_experiment.runner, None)
self.assertTrue("custom_test_metric" in loaded_experiment.metrics)
self.assertEqual(
loaded_experiment.metrics["custom_test_metric"].__class__, Metric
)
self.assertEqual(len(loaded_experiment.trials), 1)
self.assertIs(loaded_experiment.trials[0].runner, None)
# Validate that:
# - the runner is not loaded
# - the metric is loaded as a base Metric class, not CustomTestMetric
self.assertIs(loaded_experiment.runner, None)
self.assertTrue("custom_test_metric" in loaded_experiment.metrics)
self.assertEqual(
loaded_experiment.metrics["custom_test_metric"].__class__, Metric
)
self.assertEqual(len(loaded_experiment.trials), 1)
self.assertIs(loaded_experiment.trials[0].runner, None)
delete_experiment(exp_name=experiment.name)

@patch(
f"{Decoder.__module__}.Decoder.generator_run_from_sqa",
Expand Down

0 comments on commit 062b975

Please sign in to comment.