Skip to content

Commit

Permalink
Remove Models.MOO (facebook#3030)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: facebook#3030

Deprecates `Models.MOO` and points it to `Models.BOTORCH_MODULAR` for backwards compatibility.

Reviewed By: Balandat

Differential Revision: D65550955

fbshipit-source-id: 4b806ffd48992f965267cdcb8798ee6f5de6cd03
  • Loading branch information
saitcakmak authored and facebook-github-bot committed Nov 7, 2024
1 parent 6831902 commit 4047204
Show file tree
Hide file tree
Showing 5 changed files with 18 additions and 23 deletions.
14 changes: 7 additions & 7 deletions ax/modelbridge/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,12 +188,6 @@ class ModelSetup(NamedTuple):
model_class=UniformGenerator,
transforms=Cont_X_trans,
),
"MOO": ModelSetup(
bridge_class=TorchModelBridge,
model_class=MultiObjectiveBotorchModel,
transforms=Cont_X_trans + Y_trans,
standard_bridge_kwargs=STANDARD_TORCH_BRIDGE_KWARGS,
),
"ST_MTGP_LEGACY": ModelSetup(
bridge_class=TorchModelBridge,
model_class=BotorchModel,
Expand Down Expand Up @@ -433,13 +427,19 @@ class Models(ModelRegistryBase):
BOTORCH_MODULAR = "BoTorch"
EMPIRICAL_BAYES_THOMPSON = "EB"
UNIFORM = "Uniform"
MOO = "MOO"
ST_MTGP_LEGACY = "ST_MTGP_LEGACY"
ST_MTGP = "ST_MTGP"
BO_MIXED = "BO_MIXED"
ST_MTGP_NEHVI = "ST_MTGP_NEHVI"
CONTEXT_SACBO = "Contextual_SACBO"

@classmethod
@property
def MOO(cls) -> Models:
return _deprecated_model_with_warning(
old_model_str="MOO", new_model=cls.BOTORCH_MODULAR
)

@classmethod
@property
def GPEI(cls) -> Models:
Expand Down
1 change: 1 addition & 0 deletions ax/modelbridge/tests/test_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -505,6 +505,7 @@ def test_deprecated_models(self) -> None:
same check in a couple different ways.
"""
for old_model_str, new_model in [
("MOO", Models.BOTORCH_MODULAR),
("GPEI", Models.BOTORCH_MODULAR),
("FULLYBAYESIAN", Models.SAASBO),
("FULLYBAYESIANMOO", Models.SAASBO),
Expand Down
2 changes: 1 addition & 1 deletion ax/service/tests/test_ax_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,7 @@ def get_client_with_simple_discrete_moo_problem(
gs = GenerationStrategy(
steps=[
GenerationStep(model=Models.SOBOL, num_trials=3),
GenerationStep(model=Models.MOO, num_trials=-1),
GenerationStep(model=Models.BOTORCH_MODULAR, num_trials=-1),
]
)

Expand Down
22 changes: 8 additions & 14 deletions ax/service/tests/test_report_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -442,29 +442,21 @@ def test_get_standard_plots_moo(self) -> None:
# https://bugs.python.org/issue41943 for more information.
with self.assertLogs(logger="ax", level=INFO) as log:
plots = get_standard_plots(
experiment=exp, model=Models.MOO(experiment=exp, data=exp.fetch_data())
experiment=exp,
model=Models.BOTORCH_MODULAR(experiment=exp, data=exp.fetch_data()),
)
self.assertEqual(len(log.output), 5)
self.assertEqual(len(log.output), 3)
self.assertIn(
"Pareto plotting not supported for experiments with relative objective "
"thresholds.",
log.output[0],
)
self.assertIn(
"Failed to compute signed global feature sensitivities",
log.output[1],
)
self.assertIn(
"Failed to compute unsigned feature sensitivities:",
log.output[2],
)
created_plots_logs = set(log.output[2:])
for metric_suffix in ("a", "b"):
expected_msg = (
"Created contour plots for metric branin_"
f"{metric_suffix} and parameters ['x2', 'x1']"
)
self.assertTrue(any(expected_msg in msg for msg in created_plots_logs))
self.assertTrue(any(expected_msg in msg for msg in log.output[1:]))
self.assertEqual(len(plots), 6)

@mock_botorch_optimize
Expand All @@ -489,7 +481,8 @@ def test_get_standard_plots_moo_relative_constraints(self) -> None:
)._objective_thresholds:
ot.relative = False
plots = get_standard_plots(
experiment=exp, model=Models.MOO(experiment=exp, data=exp.fetch_data())
experiment=exp,
model=Models.BOTORCH_MODULAR(experiment=exp, data=exp.fetch_data()),
)
self.assertEqual(len(plots), 8)

Expand All @@ -500,7 +493,8 @@ def test_get_standard_plots_moo_no_objective_thresholds(self) -> None:
exp.optimization_config.objective.objectives[1].minimize = True
exp.trials[0].run()
plots = get_standard_plots(
experiment=exp, model=Models.MOO(experiment=exp, data=exp.fetch_data())
experiment=exp,
model=Models.BOTORCH_MODULAR(experiment=exp, data=exp.fetch_data()),
)
self.assertEqual(len(plots), 8)

Expand Down
2 changes: 1 addition & 1 deletion ax/service/utils/best_point.py
Original file line number Diff line number Diff line change
Expand Up @@ -558,7 +558,7 @@ def get_pareto_optimal_parameters(
if is_moo_modelbridge:
generation_strategy._fit_current_model(data=None)
else:
modelbridge = Models.MOO(
modelbridge = Models.BOTORCH_MODULAR(
experiment=experiment,
data=checked_cast(
Data, experiment.lookup_data(trial_indices=trial_indices)
Expand Down

0 comments on commit 4047204

Please sign in to comment.