From 95f4ff3a176adc93c2f476126528903386f17231 Mon Sep 17 00:00:00 2001 From: Elizabeth Santorella Date: Thu, 4 Apr 2024 16:01:03 -0700 Subject: [PATCH] Tests for sequential=True as default for models in ModelBridge factory (#2324) Summary: Ax #1782 removed usages of sequential=True where it was silently ignored. This didn't have an effect, or at least as of today it doesn't, because sequential is set to True by default. This PR adds tests to make sure "sequential" is being passed through appropriately. The exception is cases where optimize_acqf_list is used; optimize_acqf_list is inherently sequential and doesn't accept that argument. Differential Revision: D55772746 --- ax/modelbridge/tests/test_factory.py | 62 ++++++++++++++++++++-------- 1 file changed, 44 insertions(+), 18 deletions(-) diff --git a/ax/modelbridge/tests/test_factory.py b/ax/modelbridge/tests/test_factory.py index 56024045ab6..737a340be5c 100644 --- a/ax/modelbridge/tests/test_factory.py +++ b/ax/modelbridge/tests/test_factory.py @@ -48,6 +48,7 @@ ) from ax.utils.testing.mock import fast_botorch_optimize from botorch.models.multitask import MultiTaskGP +from botorch.optim.optimize import optimize_acqf, optimize_acqf_list # pyre-fixme[3]: Return type must be annotated. @@ -191,9 +192,8 @@ def test_uniform(self) -> None: class ModelBridgeFactoryTestMultiObjective(TestCase): - # pyre-fixme[3]: Return type must be annotated. # pyre-fixme[2]: Parameter must be annotated. - def test_single_objective_error(self, factory_fn=get_MOO_RS): + def test_single_objective_error(self, factory_fn=get_MOO_RS) -> None: single_obj_exp = get_branin_experiment(with_batch=True) with self.assertRaises(ValueError): factory_fn( @@ -201,9 +201,8 @@ def test_single_objective_error(self, factory_fn=get_MOO_RS): data=single_obj_exp.fetch_data(), ) - # pyre-fixme[3]: Return type must be annotated. # pyre-fixme[2]: Parameter must be annotated. - def test_data_error_and_get_multi_obj_exp(self, factory_fn=get_MOO_RS): + def test_data_error_and_get_multi_obj_exp(self, factory_fn=get_MOO_RS) -> None: multi_obj_exp = get_branin_experiment_with_multi_objective(with_batch=True) with self.assertRaises(ValueError): factory_fn(experiment=multi_obj_exp, data=multi_obj_exp.fetch_data()) @@ -221,11 +220,16 @@ def test_MOO_RS(self) -> None: { "acquisition_function_kwargs": { "random_scalarization": True, - } + }, }, moo_rs._default_model_gen_options, ) - moo_rs_run = moo_rs.gen(n=2) + with mock.patch( + "ax.models.torch.botorch_moo_defaults.optimize_acqf_list", + wraps=optimize_acqf_list, + ) as mock_optimize_acqf_list: + moo_rs_run = moo_rs.gen(n=2) + mock_optimize_acqf_list.assert_called() self.assertEqual(len(moo_rs_run.arms), 2) @fast_botorch_optimize @@ -255,8 +259,10 @@ def test_MOO_PAREGO(self) -> None: def test_MOO_EHVI(self) -> None: self.test_single_objective_error(get_MOO_EHVI) multi_obj_exp, optimization_config = get_multi_obj_exp_and_opt_config() - # ValueError: MultiObjectiveOptimization requires non-empty data. - with self.assertRaises(ValueError): + with self.assertRaisesRegex( + ValueError, + "MultiObjectiveOptimization requires non-empty data.", + ): get_MOO_EHVI( experiment=multi_obj_exp, data=multi_obj_exp.fetch_data(), @@ -270,8 +276,13 @@ def test_MOO_EHVI(self) -> None: optimization_config=optimization_config, ) self.assertIsInstance(moo_ehvi, TorchModelBridge) - moo_ehvi_run = moo_ehvi.gen(n=1) + with mock.patch( + "ax.models.torch.botorch_defaults.optimize_acqf", wraps=optimize_acqf + ) as mock_optimize_acqf: + moo_ehvi_run = moo_ehvi.gen(n=1) self.assertEqual(len(moo_ehvi_run.arms), 1) + mock_optimize_acqf.assert_called_once() + self.assertTrue(mock_optimize_acqf.call_args.kwargs["sequential"]) @fast_botorch_optimize def test_MTGP_PAREGO(self) -> None: @@ -303,11 +314,16 @@ def test_MTGP_PAREGO(self) -> None: self.assertIsInstance(mt_ehvi.model.model.models[0], MultiTaskGP) task_covar_factor = mt_ehvi.model.model.models[0].task_covar_module.covar_factor self.assertEqual(task_covar_factor.shape, torch.Size([2, 2])) - mt_ehvi_run = mt_ehvi.gen( - n=1, - fixed_features=ObservationFeatures(parameters={}, trial_index=1), - ) + with mock.patch( + "ax.models.torch.botorch_moo_defaults.optimize_acqf_list", + wraps=optimize_acqf_list, + ) as mock_optimize_acqf_list: + mt_ehvi_run = mt_ehvi.gen( + n=1, + fixed_features=ObservationFeatures(parameters={}, trial_index=1), + ) self.assertEqual(len(mt_ehvi_run.arms), 1) + mock_optimize_acqf_list.assert_called_once() # Bad index given with self.assertRaises(ValueError): @@ -347,8 +363,13 @@ def test_MOO_NEHVI(self) -> None: optimization_config=optimization_config, ) self.assertIsInstance(moo_ehvi, TorchModelBridge) - moo_ehvi_run = moo_ehvi.gen(n=1) + with mock.patch( + "ax.models.torch.botorch_defaults.optimize_acqf", wraps=optimize_acqf + ) as mock_optimize_acqf: + moo_ehvi_run = moo_ehvi.gen(n=1) self.assertEqual(len(moo_ehvi_run.arms), 1) + mock_optimize_acqf.assert_called_once() + self.assertTrue(mock_optimize_acqf.call_args.kwargs["sequential"]) @fast_botorch_optimize def test_MOO_with_more_outcomes_than_thresholds(self) -> None: @@ -443,10 +464,15 @@ def test_MTGP_NEHVI(self) -> None: self.assertIsInstance(mt_ehvi.model.model.models[0], MultiTaskGP) task_covar_factor = mt_ehvi.model.model.models[0].task_covar_module.covar_factor self.assertEqual(task_covar_factor.shape, torch.Size([2, 2])) - mt_ehvi_run = mt_ehvi.gen( - n=1, - fixed_features=ObservationFeatures(parameters={}, trial_index=1), - ) + with mock.patch( + "ax.models.torch.botorch_defaults.optimize_acqf", wraps=optimize_acqf + ) as mock_optimize_acqf: + mt_ehvi_run = mt_ehvi.gen( + n=1, + fixed_features=ObservationFeatures(parameters={}, trial_index=1), + ) + mock_optimize_acqf.assert_called_once() + self.assertTrue(mock_optimize_acqf.call_args.kwargs["sequential"]) self.assertEqual(len(mt_ehvi_run.arms), 1) # Bad index given