Skip to content

Commit

Permalink
Remove Models.ST_MTGP_LEGACY & _NEHVI (facebook#3031)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: facebook#3031

Cleaning up legacy model entries and updating them to point to the MBM variant of `Models.ST_MTGP` to retain backwards compatibility.

Also cleaned up a few references of `Models.FULLYBAYESIAN...` while at it. These all point to corresponding MBM models already.

Reviewed By: Balandat

Differential Revision: D65568608

fbshipit-source-id: e706e6cdf3b6865be7654f1b759b59c7a727ca9d
  • Loading branch information
saitcakmak authored and facebook-github-bot committed Nov 7, 2024
1 parent 4047204 commit 22bd1b2
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 70 deletions.
29 changes: 14 additions & 15 deletions ax/modelbridge/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,6 @@
from ax.models.torch.botorch import BotorchModel
from ax.models.torch.botorch_modular.model import BoTorchModel as ModularBoTorchModel
from ax.models.torch.botorch_modular.surrogate import SurrogateSpec
from ax.models.torch.botorch_moo import MultiObjectiveBotorchModel
from ax.models.torch.cbo_sac import SACBO
from ax.utils.common.kwargs import (
consolidate_kwargs,
Expand Down Expand Up @@ -188,12 +187,6 @@ class ModelSetup(NamedTuple):
model_class=UniformGenerator,
transforms=Cont_X_trans,
),
"ST_MTGP_LEGACY": ModelSetup(
bridge_class=TorchModelBridge,
model_class=BotorchModel,
transforms=ST_MTGP_trans,
standard_bridge_kwargs=STANDARD_TORCH_BRIDGE_KWARGS,
),
"ST_MTGP": ModelSetup(
bridge_class=TorchModelBridge,
model_class=ModularBoTorchModel,
Expand Down Expand Up @@ -228,12 +221,6 @@ class ModelSetup(NamedTuple):
},
standard_bridge_kwargs=STANDARD_TORCH_BRIDGE_KWARGS,
),
"ST_MTGP_NEHVI": ModelSetup(
bridge_class=TorchModelBridge,
model_class=MultiObjectiveBotorchModel,
transforms=ST_MTGP_trans,
standard_bridge_kwargs=STANDARD_TORCH_BRIDGE_KWARGS,
),
"Contextual_SACBO": ModelSetup(
bridge_class=TorchModelBridge,
model_class=SACBO,
Expand Down Expand Up @@ -427,12 +414,24 @@ class Models(ModelRegistryBase):
BOTORCH_MODULAR = "BoTorch"
EMPIRICAL_BAYES_THOMPSON = "EB"
UNIFORM = "Uniform"
ST_MTGP_LEGACY = "ST_MTGP_LEGACY"
ST_MTGP = "ST_MTGP"
BO_MIXED = "BO_MIXED"
ST_MTGP_NEHVI = "ST_MTGP_NEHVI"
CONTEXT_SACBO = "Contextual_SACBO"

@classmethod
@property
def ST_MTGP_LEGACY(cls) -> Models:
return _deprecated_model_with_warning(
old_model_str="ST_MTGP_LEGACY", new_model=cls.ST_MTGP
)

@classmethod
@property
def ST_MTGP_NEHVI(cls) -> Models:
return _deprecated_model_with_warning(
old_model_str="ST_MTGP_NEHVI", new_model=cls.ST_MTGP
)

@classmethod
@property
def MOO(cls) -> Models:
Expand Down
57 changes: 2 additions & 55 deletions ax/modelbridge/tests/test_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
from ax.models.torch.botorch_modular.kernels import ScaleMaternKernel
from ax.models.torch.botorch_modular.model import BoTorchModel
from ax.models.torch.botorch_modular.surrogate import Surrogate, SurrogateSpec
from ax.models.torch.botorch_moo import MultiObjectiveBotorchModel
from ax.utils.common.kwargs import get_function_argument_names
from ax.utils.common.testutils import TestCase
from ax.utils.testing.core_stubs import (
Expand Down Expand Up @@ -364,60 +363,6 @@ def test_ModelSetups_do_not_share_kwargs(self) -> None:
# Intersection of two sets should be empty
self.assertEqual(model_args & bridge_args, set())

@mock_botorch_optimize
def test_ST_MTGP_LEGACY(self) -> None:
"""Tests single type MTGP instantiation."""
# Test Single-type MTGP
exp, status_quo_features = get_branin_experiment_with_status_quo_trials()
mtgp = Models.ST_MTGP_LEGACY(
experiment=exp,
data=exp.fetch_data(),
status_quo_features=status_quo_features,
)
self.assertIsInstance(mtgp, TorchModelBridge)
# Test that it can generate.
mtgp_run = mtgp.gen(
n=1,
fixed_features=ObservationFeatures(parameters={}, trial_index=1),
)
self.assertEqual(len(mtgp_run.arms), 1)

exp, status_quo_features = get_branin_experiment_with_status_quo_trials(
num_sobol_trials=1
)
with self.assertRaisesRegex(ValueError, "TrialAsTask transform expects"):
Models.ST_MTGP_LEGACY(
experiment=exp,
data=exp.fetch_data(),
status_quo_features=status_quo_features,
)

@mock_botorch_optimize
def test_ST_MTGP_NEHVI(self) -> None:
"""Tests single type MTGP NEHVI instantiation."""
exp, status_quo_features = get_branin_experiment_with_status_quo_trials(
num_sobol_trials=2, multi_objective=True
)
mtgp = Models.ST_MTGP_NEHVI(
experiment=exp,
data=exp.fetch_data(),
status_quo_features=status_quo_features,
optimization_config=exp.optimization_config,
)
self.assertIsInstance(mtgp, TorchModelBridge)
self.assertIsInstance(mtgp.model, MultiObjectiveBotorchModel)

# test it can generate
mtgp_run = mtgp.gen(
n=1,
fixed_features=ObservationFeatures(parameters={}, trial_index=1),
)
self.assertEqual(len(mtgp_run.arms), 1)
# test a generated trial can be completed
t = exp.new_batch_trial().add_generator_run(mtgp_run)
t.set_status_quo_with_weight(status_quo=t.arms[0], weight=0.5)
t.run().mark_completed()

@mock_botorch_optimize
def test_ST_MTGP(self, use_saas: bool = False) -> None:
"""Tests single type MTGP via Modular BoTorch instantiation
Expand Down Expand Up @@ -505,6 +450,8 @@ def test_deprecated_models(self) -> None:
same check in a couple different ways.
"""
for old_model_str, new_model in [
("ST_MTGP_NEHVI", Models.ST_MTGP),
("ST_MTGP_LEGACY", Models.ST_MTGP),
("MOO", Models.BOTORCH_MODULAR),
("GPEI", Models.BOTORCH_MODULAR),
("FULLYBAYESIAN", Models.SAASBO),
Expand Down

0 comments on commit 22bd1b2

Please sign in to comment.