Skip to content

Commit

Permalink
Allow batch trial to be constructed with a list of GRs (#1995)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #1995

This makes creating trials from `GS.gen_for_multiple_trial_from_multiple()` models easier in D51307866 and is closer to the way `BatchTrial`s actually work.

Reviewed By: lena-kashtelyan

Differential Revision: D51211147

fbshipit-source-id: eb57c2c77c06959bdeb3bbaa37bb0c98b8c85699
  • Loading branch information
Daniel Cohen authored and facebook-github-bot committed Nov 22, 2023
1 parent 46fef44 commit 50790cc
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 1 deletion.
14 changes: 13 additions & 1 deletion ax/core/batch_trial.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
TEvaluationOutcome,
validate_evaluation_outcome,
)
from ax.exceptions.core import AxError, UserInputError
from ax.exceptions.core import AxError, UnsupportedError, UserInputError
from ax.utils.common.base import SortableBase
from ax.utils.common.docutils import copy_doc
from ax.utils.common.equality import datetime_equals, equality_typechecker
Expand Down Expand Up @@ -117,6 +117,10 @@ class BatchTrial(BaseTrial):
generator_run: GeneratorRun, associated with this trial. This can a
also be set later through `add_arm` or `add_generator_run`, but a
trial's associated generator run is immutable once set.
generator_runs: GeneratorRuns, associated with this trial. This can a
also be set later through `add_arm` or `add_generator_run`, but a
trial's associated generator run is immutable once set. This cannot
be combined with the `generator_run` argument.
trial_type: Type of this trial, if used in MultiTypeExperiment.
optimize_for_power: Whether to optimize the weights of arms in this
trial such that the experiment's power to detect effects of
Expand All @@ -140,6 +144,7 @@ def __init__(
self,
experiment: core.experiment.Experiment,
generator_run: Optional[GeneratorRun] = None,
generator_runs: Optional[List[GeneratorRun]] = None,
trial_type: Optional[str] = None,
optimize_for_power: Optional[bool] = False,
ttl_seconds: Optional[int] = None,
Expand All @@ -158,7 +163,14 @@ def __init__(
self._status_quo: Optional[Arm] = None
self._status_quo_weight_override: Optional[float] = None
if generator_run is not None:
if generator_runs is not None:
raise UnsupportedError(
"Cannot specify both `generator_run` and `generator_runs`."
)
self.add_generator_run(generator_run=generator_run)
elif generator_runs is not None:
for gr in generator_runs:
self.add_generator_run(generator_run=gr)

self.optimize_for_power = optimize_for_power
status_quo = experiment.status_quo
Expand Down
6 changes: 6 additions & 0 deletions ax/core/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -1059,6 +1059,7 @@ def new_trial(
def new_batch_trial(
self,
generator_run: Optional[GeneratorRun] = None,
generator_runs: Optional[List[GeneratorRun]] = None,
trial_type: Optional[str] = None,
optimize_for_power: Optional[bool] = False,
ttl_seconds: Optional[int] = None,
Expand All @@ -1070,6 +1071,10 @@ def new_batch_trial(
generator_run: GeneratorRun, associated with this trial. This can a
also be set later through `add_arm` or `add_generator_run`, but a
trial's associated generator run is immutable once set.
generator_runs: GeneratorRuns, associated with this trial. This can a
also be set later through `add_arm` or `add_generator_run`, but a
trial's associated generator run is immutable once set. This cannot
be combined with the `generator_run` argument.
trial_type: Type of this trial, if used in MultiTypeExperiment.
optimize_for_power: Whether to optimize the weights of arms in this
trial such that the experiment's power to detect effects of
Expand All @@ -1090,6 +1095,7 @@ def new_batch_trial(
experiment=self,
trial_type=trial_type,
generator_run=generator_run,
generator_runs=generator_runs,
optimize_for_power=optimize_for_power,
ttl_seconds=ttl_seconds,
lifecycle_stage=lifecycle_stage,
Expand Down
36 changes: 36 additions & 0 deletions ax/core/tests/test_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
get_status_quo,
get_test_map_data_experiment,
)
from ax.utils.testing.mock import fast_botorch_optimize

DUMMY_RUN_METADATA_KEY = "test_run_metadata_key"
DUMMY_RUN_METADATA_VALUE = "test_run_metadata_value"
Expand Down Expand Up @@ -1254,3 +1255,38 @@ def test_WarmStartMapData(self) -> None:
old_df.drop(["arm_name", "trial_index"], axis=1),
new_df.drop(["arm_name", "trial_index"], axis=1),
)

@fast_botorch_optimize
def test_batch_with_multiple_generator_runs(self) -> None:
exp = get_branin_experiment()
sobol = Models.SOBOL(experiment=exp, search_space=exp.search_space)
exp.new_batch_trial(generator_runs=[sobol.gen(n=7)]).run().complete()

data = exp.fetch_data()
gp = Models.BOTORCH_MODULAR(
experiment=exp, search_space=exp.search_space, data=data
)
ts = Models.EMPIRICAL_BAYES_THOMPSON(
experiment=exp, search_space=exp.search_space, data=data
)
exp.new_batch_trial(generator_runs=[gp.gen(n=3), ts.gen(n=1)]).run().complete()

self.assertEqual(len(exp.trials), 2)
self.assertEqual(len(exp.trials[0].generator_runs), 1)
self.assertEqual(len(exp.trials[0].arms), 7)
self.assertEqual(len(exp.trials[1].generator_runs), 2)
self.assertEqual(len(exp.trials[1].arms), 4)

def test_it_does_not_take_both_single_and_multiple_gr_ars(self) -> None:
exp = get_branin_experiment()
sobol = Models.SOBOL(experiment=exp, search_space=exp.search_space)
gr1 = sobol.gen(n=7)
gr2 = sobol.gen(n=7)
with self.assertRaisesRegex(
UnsupportedError,
"Cannot specify both `generator_run` and `generator_runs`.",
):
exp.new_batch_trial(
generator_run=gr1,
generator_runs=[gr2],
)

0 comments on commit 50790cc

Please sign in to comment.