From 22bd1b2740bb19ec6f8e6e85509baf4f5ca30a25 Mon Sep 17 00:00:00 2001 From: Sait Cakmak Date: Thu, 7 Nov 2024 06:06:24 -0800 Subject: [PATCH] Remove Models.ST_MTGP_LEGACY & _NEHVI (#3031) Summary: Pull Request resolved: https://github.com/facebook/Ax/pull/3031 Cleaning up legacy model entries and updating them to point to the MBM variant of `Models.ST_MTGP` to retain backwards compatibility. Also cleaned up a few references of `Models.FULLYBAYESIAN...` while at it. These all point to corresponding MBM models already. Reviewed By: Balandat Differential Revision: D65568608 fbshipit-source-id: e706e6cdf3b6865be7654f1b759b59c7a727ca9d --- ax/modelbridge/registry.py | 29 +++++++------- ax/modelbridge/tests/test_registry.py | 57 +-------------------------- 2 files changed, 16 insertions(+), 70 deletions(-) diff --git a/ax/modelbridge/registry.py b/ax/modelbridge/registry.py index 0ed0b14732c..9317ee82f7c 100644 --- a/ax/modelbridge/registry.py +++ b/ax/modelbridge/registry.py @@ -61,7 +61,6 @@ from ax.models.torch.botorch import BotorchModel from ax.models.torch.botorch_modular.model import BoTorchModel as ModularBoTorchModel from ax.models.torch.botorch_modular.surrogate import SurrogateSpec -from ax.models.torch.botorch_moo import MultiObjectiveBotorchModel from ax.models.torch.cbo_sac import SACBO from ax.utils.common.kwargs import ( consolidate_kwargs, @@ -188,12 +187,6 @@ class ModelSetup(NamedTuple): model_class=UniformGenerator, transforms=Cont_X_trans, ), - "ST_MTGP_LEGACY": ModelSetup( - bridge_class=TorchModelBridge, - model_class=BotorchModel, - transforms=ST_MTGP_trans, - standard_bridge_kwargs=STANDARD_TORCH_BRIDGE_KWARGS, - ), "ST_MTGP": ModelSetup( bridge_class=TorchModelBridge, model_class=ModularBoTorchModel, @@ -228,12 +221,6 @@ class ModelSetup(NamedTuple): }, standard_bridge_kwargs=STANDARD_TORCH_BRIDGE_KWARGS, ), - "ST_MTGP_NEHVI": ModelSetup( - bridge_class=TorchModelBridge, - model_class=MultiObjectiveBotorchModel, - transforms=ST_MTGP_trans, - standard_bridge_kwargs=STANDARD_TORCH_BRIDGE_KWARGS, - ), "Contextual_SACBO": ModelSetup( bridge_class=TorchModelBridge, model_class=SACBO, @@ -427,12 +414,24 @@ class Models(ModelRegistryBase): BOTORCH_MODULAR = "BoTorch" EMPIRICAL_BAYES_THOMPSON = "EB" UNIFORM = "Uniform" - ST_MTGP_LEGACY = "ST_MTGP_LEGACY" ST_MTGP = "ST_MTGP" BO_MIXED = "BO_MIXED" - ST_MTGP_NEHVI = "ST_MTGP_NEHVI" CONTEXT_SACBO = "Contextual_SACBO" + @classmethod + @property + def ST_MTGP_LEGACY(cls) -> Models: + return _deprecated_model_with_warning( + old_model_str="ST_MTGP_LEGACY", new_model=cls.ST_MTGP + ) + + @classmethod + @property + def ST_MTGP_NEHVI(cls) -> Models: + return _deprecated_model_with_warning( + old_model_str="ST_MTGP_NEHVI", new_model=cls.ST_MTGP + ) + @classmethod @property def MOO(cls) -> Models: diff --git a/ax/modelbridge/tests/test_registry.py b/ax/modelbridge/tests/test_registry.py index c685d11f5a0..8c30434d202 100644 --- a/ax/modelbridge/tests/test_registry.py +++ b/ax/modelbridge/tests/test_registry.py @@ -30,7 +30,6 @@ from ax.models.torch.botorch_modular.kernels import ScaleMaternKernel from ax.models.torch.botorch_modular.model import BoTorchModel from ax.models.torch.botorch_modular.surrogate import Surrogate, SurrogateSpec -from ax.models.torch.botorch_moo import MultiObjectiveBotorchModel from ax.utils.common.kwargs import get_function_argument_names from ax.utils.common.testutils import TestCase from ax.utils.testing.core_stubs import ( @@ -364,60 +363,6 @@ def test_ModelSetups_do_not_share_kwargs(self) -> None: # Intersection of two sets should be empty self.assertEqual(model_args & bridge_args, set()) - @mock_botorch_optimize - def test_ST_MTGP_LEGACY(self) -> None: - """Tests single type MTGP instantiation.""" - # Test Single-type MTGP - exp, status_quo_features = get_branin_experiment_with_status_quo_trials() - mtgp = Models.ST_MTGP_LEGACY( - experiment=exp, - data=exp.fetch_data(), - status_quo_features=status_quo_features, - ) - self.assertIsInstance(mtgp, TorchModelBridge) - # Test that it can generate. - mtgp_run = mtgp.gen( - n=1, - fixed_features=ObservationFeatures(parameters={}, trial_index=1), - ) - self.assertEqual(len(mtgp_run.arms), 1) - - exp, status_quo_features = get_branin_experiment_with_status_quo_trials( - num_sobol_trials=1 - ) - with self.assertRaisesRegex(ValueError, "TrialAsTask transform expects"): - Models.ST_MTGP_LEGACY( - experiment=exp, - data=exp.fetch_data(), - status_quo_features=status_quo_features, - ) - - @mock_botorch_optimize - def test_ST_MTGP_NEHVI(self) -> None: - """Tests single type MTGP NEHVI instantiation.""" - exp, status_quo_features = get_branin_experiment_with_status_quo_trials( - num_sobol_trials=2, multi_objective=True - ) - mtgp = Models.ST_MTGP_NEHVI( - experiment=exp, - data=exp.fetch_data(), - status_quo_features=status_quo_features, - optimization_config=exp.optimization_config, - ) - self.assertIsInstance(mtgp, TorchModelBridge) - self.assertIsInstance(mtgp.model, MultiObjectiveBotorchModel) - - # test it can generate - mtgp_run = mtgp.gen( - n=1, - fixed_features=ObservationFeatures(parameters={}, trial_index=1), - ) - self.assertEqual(len(mtgp_run.arms), 1) - # test a generated trial can be completed - t = exp.new_batch_trial().add_generator_run(mtgp_run) - t.set_status_quo_with_weight(status_quo=t.arms[0], weight=0.5) - t.run().mark_completed() - @mock_botorch_optimize def test_ST_MTGP(self, use_saas: bool = False) -> None: """Tests single type MTGP via Modular BoTorch instantiation @@ -505,6 +450,8 @@ def test_deprecated_models(self) -> None: same check in a couple different ways. """ for old_model_str, new_model in [ + ("ST_MTGP_NEHVI", Models.ST_MTGP), + ("ST_MTGP_LEGACY", Models.ST_MTGP), ("MOO", Models.BOTORCH_MODULAR), ("GPEI", Models.BOTORCH_MODULAR), ("FULLYBAYESIAN", Models.SAASBO),