Skip to content

Commit

Permalink
Tests for sequential=True as default for models in ModelBridge factory (#2324)
Browse files Browse the repository at this point in the history

Summary:

Ax #1782 removed usages of sequential=True where it was silently ignored. This didn't have an effect, or at least as of today it doesn't, because sequential is set to True by default. This PR adds tests to make sure "sequential" is being passed through appropriately. The exception is cases where optimize_acqf_list is used; optimize_acqf_list is inherently sequential and doesn't accept that argument.

Differential Revision: D55772746
  • Loading branch information
esantorella authored and facebook-github-bot committed Apr 4, 2024
1 parent 2b2f84a commit 95f4ff3
Showing 1 changed file with 44 additions and 18 deletions.
62 changes: 44 additions & 18 deletions ax/modelbridge/tests/test_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
)
from ax.utils.testing.mock import fast_botorch_optimize
from botorch.models.multitask import MultiTaskGP
from botorch.optim.optimize import optimize_acqf, optimize_acqf_list


# pyre-fixme[3]: Return type must be annotated.
Expand Down Expand Up @@ -191,19 +192,17 @@ def test_uniform(self) -> None:


class ModelBridgeFactoryTestMultiObjective(TestCase):
# pyre-fixme[3]: Return type must be annotated.
# pyre-fixme[2]: Parameter must be annotated.
def test_single_objective_error(self, factory_fn=get_MOO_RS):
    def test_single_objective_error(self, factory_fn=get_MOO_RS) -> None:
        """Verify that a multi-objective factory rejects a single-objective experiment.

        Args:
            factory_fn: The MOO factory under test. Defaults to ``get_MOO_RS``;
                sibling tests reuse this method by passing other factories
                (e.g. ``get_MOO_EHVI``).
        """
        # A plain Branin experiment has a single objective, so constructing a
        # multi-objective model bridge from it must fail.
        single_obj_exp = get_branin_experiment(with_batch=True)
        with self.assertRaises(ValueError):
            factory_fn(
                experiment=single_obj_exp,
                data=single_obj_exp.fetch_data(),
            )

# pyre-fixme[3]: Return type must be annotated.
# pyre-fixme[2]: Parameter must be annotated.
def test_data_error_and_get_multi_obj_exp(self, factory_fn=get_MOO_RS):
def test_data_error_and_get_multi_obj_exp(self, factory_fn=get_MOO_RS) -> None:
multi_obj_exp = get_branin_experiment_with_multi_objective(with_batch=True)
with self.assertRaises(ValueError):
factory_fn(experiment=multi_obj_exp, data=multi_obj_exp.fetch_data())
Expand All @@ -221,11 +220,16 @@ def test_MOO_RS(self) -> None:
{
"acquisition_function_kwargs": {
"random_scalarization": True,
}
},
},
moo_rs._default_model_gen_options,
)
moo_rs_run = moo_rs.gen(n=2)
with mock.patch(
"ax.models.torch.botorch_moo_defaults.optimize_acqf_list",
wraps=optimize_acqf_list,
) as mock_optimize_acqf_list:
moo_rs_run = moo_rs.gen(n=2)
mock_optimize_acqf_list.assert_called()
self.assertEqual(len(moo_rs_run.arms), 2)

@fast_botorch_optimize
Expand Down Expand Up @@ -255,8 +259,10 @@ def test_MOO_PAREGO(self) -> None:
def test_MOO_EHVI(self) -> None:
self.test_single_objective_error(get_MOO_EHVI)
multi_obj_exp, optimization_config = get_multi_obj_exp_and_opt_config()
# ValueError: MultiObjectiveOptimization requires non-empty data.
with self.assertRaises(ValueError):
with self.assertRaisesRegex(
ValueError,
"MultiObjectiveOptimization requires non-empty data.",
):
get_MOO_EHVI(
experiment=multi_obj_exp,
data=multi_obj_exp.fetch_data(),
Expand All @@ -270,8 +276,13 @@ def test_MOO_EHVI(self) -> None:
optimization_config=optimization_config,
)
self.assertIsInstance(moo_ehvi, TorchModelBridge)
moo_ehvi_run = moo_ehvi.gen(n=1)
with mock.patch(
"ax.models.torch.botorch_defaults.optimize_acqf", wraps=optimize_acqf
) as mock_optimize_acqf:
moo_ehvi_run = moo_ehvi.gen(n=1)
self.assertEqual(len(moo_ehvi_run.arms), 1)
mock_optimize_acqf.assert_called_once()
self.assertTrue(mock_optimize_acqf.call_args.kwargs["sequential"])

@fast_botorch_optimize
def test_MTGP_PAREGO(self) -> None:
Expand Down Expand Up @@ -303,11 +314,16 @@ def test_MTGP_PAREGO(self) -> None:
self.assertIsInstance(mt_ehvi.model.model.models[0], MultiTaskGP)
task_covar_factor = mt_ehvi.model.model.models[0].task_covar_module.covar_factor
self.assertEqual(task_covar_factor.shape, torch.Size([2, 2]))
mt_ehvi_run = mt_ehvi.gen(
n=1,
fixed_features=ObservationFeatures(parameters={}, trial_index=1),
)
with mock.patch(
"ax.models.torch.botorch_moo_defaults.optimize_acqf_list",
wraps=optimize_acqf_list,
) as mock_optimize_acqf_list:
mt_ehvi_run = mt_ehvi.gen(
n=1,
fixed_features=ObservationFeatures(parameters={}, trial_index=1),
)
self.assertEqual(len(mt_ehvi_run.arms), 1)
mock_optimize_acqf_list.assert_called_once()

# Bad index given
with self.assertRaises(ValueError):
Expand Down Expand Up @@ -347,8 +363,13 @@ def test_MOO_NEHVI(self) -> None:
optimization_config=optimization_config,
)
self.assertIsInstance(moo_ehvi, TorchModelBridge)
moo_ehvi_run = moo_ehvi.gen(n=1)
with mock.patch(
"ax.models.torch.botorch_defaults.optimize_acqf", wraps=optimize_acqf
) as mock_optimize_acqf:
moo_ehvi_run = moo_ehvi.gen(n=1)
self.assertEqual(len(moo_ehvi_run.arms), 1)
mock_optimize_acqf.assert_called_once()
self.assertTrue(mock_optimize_acqf.call_args.kwargs["sequential"])

@fast_botorch_optimize
def test_MOO_with_more_outcomes_than_thresholds(self) -> None:
Expand Down Expand Up @@ -443,10 +464,15 @@ def test_MTGP_NEHVI(self) -> None:
self.assertIsInstance(mt_ehvi.model.model.models[0], MultiTaskGP)
task_covar_factor = mt_ehvi.model.model.models[0].task_covar_module.covar_factor
self.assertEqual(task_covar_factor.shape, torch.Size([2, 2]))
mt_ehvi_run = mt_ehvi.gen(
n=1,
fixed_features=ObservationFeatures(parameters={}, trial_index=1),
)
with mock.patch(
"ax.models.torch.botorch_defaults.optimize_acqf", wraps=optimize_acqf
) as mock_optimize_acqf:
mt_ehvi_run = mt_ehvi.gen(
n=1,
fixed_features=ObservationFeatures(parameters={}, trial_index=1),
)
mock_optimize_acqf.assert_called_once()
self.assertTrue(mock_optimize_acqf.call_args.kwargs["sequential"])
self.assertEqual(len(mt_ehvi_run.arms), 1)

# Bad index given
Expand Down

0 comments on commit 95f4ff3

Please sign in to comment.