diff --git a/ax/modelbridge/registry.py b/ax/modelbridge/registry.py index 27f0ea653a8..0ed0b14732c 100644 --- a/ax/modelbridge/registry.py +++ b/ax/modelbridge/registry.py @@ -188,12 +188,6 @@ class ModelSetup(NamedTuple): model_class=UniformGenerator, transforms=Cont_X_trans, ), - "MOO": ModelSetup( - bridge_class=TorchModelBridge, - model_class=MultiObjectiveBotorchModel, - transforms=Cont_X_trans + Y_trans, - standard_bridge_kwargs=STANDARD_TORCH_BRIDGE_KWARGS, - ), "ST_MTGP_LEGACY": ModelSetup( bridge_class=TorchModelBridge, model_class=BotorchModel, @@ -433,13 +427,19 @@ class Models(ModelRegistryBase): BOTORCH_MODULAR = "BoTorch" EMPIRICAL_BAYES_THOMPSON = "EB" UNIFORM = "Uniform" - MOO = "MOO" ST_MTGP_LEGACY = "ST_MTGP_LEGACY" ST_MTGP = "ST_MTGP" BO_MIXED = "BO_MIXED" ST_MTGP_NEHVI = "ST_MTGP_NEHVI" CONTEXT_SACBO = "Contextual_SACBO" + @classmethod + @property + def MOO(cls) -> Models: + return _deprecated_model_with_warning( + old_model_str="MOO", new_model=cls.BOTORCH_MODULAR + ) + @classmethod @property def GPEI(cls) -> Models: diff --git a/ax/modelbridge/tests/test_registry.py b/ax/modelbridge/tests/test_registry.py index 20f17d7ce70..c685d11f5a0 100644 --- a/ax/modelbridge/tests/test_registry.py +++ b/ax/modelbridge/tests/test_registry.py @@ -505,6 +505,7 @@ def test_deprecated_models(self) -> None: same check in a couple different ways. """ for old_model_str, new_model in [ + ("MOO", Models.BOTORCH_MODULAR), ("GPEI", Models.BOTORCH_MODULAR), ("FULLYBAYESIAN", Models.SAASBO), ("FULLYBAYESIANMOO", Models.SAASBO), diff --git a/ax/service/tests/test_ax_client.py b/ax/service/tests/test_ax_client.py index 133ad0903e4..2e9bdfff2c5 100644 --- a/ax/service/tests/test_ax_client.py +++ b/ax/service/tests/test_ax_client.py @@ -220,7 +220,7 @@ def get_client_with_simple_discrete_moo_problem( gs = GenerationStrategy( steps=[ GenerationStep(model=Models.SOBOL, num_trials=3), - GenerationStep(model=Models.MOO, num_trials=-1), + GenerationStep(model=Models.BOTORCH_MODULAR, num_trials=-1), ] ) diff --git a/ax/service/tests/test_report_utils.py b/ax/service/tests/test_report_utils.py index c6ffa3bebff..109ef1476fe 100644 --- a/ax/service/tests/test_report_utils.py +++ b/ax/service/tests/test_report_utils.py @@ -442,29 +442,21 @@ def test_get_standard_plots_moo(self) -> None: # https://bugs.python.org/issue41943 for more information. with self.assertLogs(logger="ax", level=INFO) as log: plots = get_standard_plots( - experiment=exp, model=Models.MOO(experiment=exp, data=exp.fetch_data()) + experiment=exp, + model=Models.BOTORCH_MODULAR(experiment=exp, data=exp.fetch_data()), ) - self.assertEqual(len(log.output), 5) + self.assertEqual(len(log.output), 3) self.assertIn( "Pareto plotting not supported for experiments with relative objective " "thresholds.", log.output[0], ) - self.assertIn( - "Failed to compute signed global feature sensitivities", - log.output[1], - ) - self.assertIn( - "Failed to compute unsigned feature sensitivities:", - log.output[2], - ) - created_plots_logs = set(log.output[2:]) for metric_suffix in ("a", "b"): expected_msg = ( "Created contour plots for metric branin_" f"{metric_suffix} and parameters ['x2', 'x1']" ) - self.assertTrue(any(expected_msg in msg for msg in created_plots_logs)) + self.assertTrue(any(expected_msg in msg for msg in log.output[1:])) self.assertEqual(len(plots), 6) @mock_botorch_optimize @@ -489,7 +481,8 @@ def test_get_standard_plots_moo_relative_constraints(self) -> None: )._objective_thresholds: ot.relative = False plots = get_standard_plots( - experiment=exp, model=Models.MOO(experiment=exp, data=exp.fetch_data()) + experiment=exp, + model=Models.BOTORCH_MODULAR(experiment=exp, data=exp.fetch_data()), ) self.assertEqual(len(plots), 8) @@ -500,7 +493,8 @@ def test_get_standard_plots_moo_no_objective_thresholds(self) -> None: exp.optimization_config.objective.objectives[1].minimize = True exp.trials[0].run() plots = get_standard_plots( - experiment=exp, model=Models.MOO(experiment=exp, data=exp.fetch_data()) + experiment=exp, + model=Models.BOTORCH_MODULAR(experiment=exp, data=exp.fetch_data()), ) self.assertEqual(len(plots), 8) diff --git a/ax/service/utils/best_point.py b/ax/service/utils/best_point.py index 7dd4d9fcb93..4b8f4858cb8 100644 --- a/ax/service/utils/best_point.py +++ b/ax/service/utils/best_point.py @@ -558,7 +558,7 @@ def get_pareto_optimal_parameters( if is_moo_modelbridge: generation_strategy._fit_current_model(data=None) else: - modelbridge = Models.MOO( + modelbridge = Models.BOTORCH_MODULAR( experiment=experiment, data=checked_cast( Data, experiment.lookup_data(trial_indices=trial_indices)