Skip to content

Commit

Permalink
Even more fixes to unused kwargs (pytorch#1985)
Browse files Browse the repository at this point in the history
Summary:
X-link: facebook/Ax#1782

Pull Request resolved: pytorch#1985

See previous diff

Reviewed By: lena-kashtelyan

Differential Revision: D48338443

fbshipit-source-id: 7cacc33ead6d5471855364eff910031640bc4708
  • Loading branch information
esantorella authored and facebook-github-bot committed Aug 16, 2023
1 parent e7974cc commit 0b04b96
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 37 deletions.
59 changes: 31 additions & 28 deletions botorch/acquisition/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,16 @@ def get_acquisition_function(
eta: Optional[Union[Tensor, float]] = 1e-3,
mc_samples: int = 512,
seed: Optional[int] = None,
**kwargs,
*,
# optional parameters that are only needed for certain acquisition functions
tau: float = 1e-3,
prune_baseline: bool = True,
marginalize_dim: Optional[int] = None,
cache_root: bool = True,
beta: Optional[float] = None,
ref_point: Union[None, List[float], Tensor] = None,
Y: Optional[Tensor] = None,
alpha: float = 0.0,
) -> monte_carlo.MCAcquisitionFunction:
r"""Convenience function for initializing botorch acquisition functions.
Expand Down Expand Up @@ -149,7 +158,7 @@ def get_acquisition_function(
objective=objective,
posterior_transform=posterior_transform,
X_pending=X_pending,
tau=kwargs.get("tau", 1e-3),
tau=tau,
constraints=constraints,
eta=eta,
)
Expand All @@ -161,9 +170,9 @@ def get_acquisition_function(
objective=objective,
posterior_transform=posterior_transform,
X_pending=X_pending,
prune_baseline=kwargs.get("prune_baseline", True),
marginalize_dim=kwargs.get("marginalize_dim"),
cache_root=kwargs.get("cache_root", True),
prune_baseline=prune_baseline,
marginalize_dim=marginalize_dim,
cache_root=cache_root,
constraints=constraints,
eta=eta,
)
Expand All @@ -177,9 +186,9 @@ def get_acquisition_function(
objective=objective,
posterior_transform=posterior_transform,
X_pending=X_pending,
prune_baseline=kwargs.get("prune_baseline", True),
marginalize_dim=kwargs.get("marginalize_dim"),
cache_root=kwargs.get("cache_root", True),
prune_baseline=prune_baseline,
marginalize_dim=marginalize_dim,
cache_root=cache_root,
constraints=constraints,
eta=eta,
)
Expand All @@ -192,32 +201,26 @@ def get_acquisition_function(
X_pending=X_pending,
)
elif acquisition_function_name == "qUCB":
if "beta" not in kwargs:
raise ValueError("`beta` must be specified in kwargs for qUCB.")
if beta is None:
            raise ValueError("`beta` must not be None for qUCB.")
return monte_carlo.qUpperConfidenceBound(
model=model,
beta=kwargs["beta"],
beta=beta,
sampler=sampler,
objective=objective,
posterior_transform=posterior_transform,
X_pending=X_pending,
)
elif acquisition_function_name == "qEHVI":
# pyre-fixme [16]: `Model` has no attribute `train_targets`
try:
ref_point = kwargs["ref_point"]
except KeyError:
raise ValueError("`ref_point` must be specified in kwargs for qEHVI")
try:
Y = kwargs["Y"]
except KeyError:
raise ValueError("`Y` must be specified in kwargs for qEHVI")
if Y is None:
raise ValueError("`Y` must not be None for qEHVI")
if ref_point is None:
raise ValueError("`ref_point` must not be None for qEHVI")
# get feasible points
if constraints is not None:
feas = torch.stack([c(Y) <= 0 for c in constraints], dim=-1).all(dim=-1)
Y = Y[feas]
obj = objective(Y)
alpha = kwargs.get("alpha", 0.0)
if alpha > 0:
partitioning = NondominatedPartitioning(
ref_point=torch.as_tensor(ref_point, dtype=Y.dtype, device=Y.device),
Expand All @@ -240,21 +243,21 @@ def get_acquisition_function(
X_pending=X_pending,
)
elif acquisition_function_name == "qNEHVI":
if "ref_point" not in kwargs:
raise ValueError("`ref_point` must be specified in kwargs for qNEHVI")
if ref_point is None:
raise ValueError("`ref_point` must not be None for qNEHVI")
return moo_monte_carlo.qNoisyExpectedHypervolumeImprovement(
model=model,
ref_point=kwargs["ref_point"],
ref_point=ref_point,
X_baseline=X_observed,
sampler=sampler,
objective=objective,
constraints=constraints,
eta=eta,
prune_baseline=kwargs.get("prune_baseline", True),
alpha=kwargs.get("alpha", 0.0),
prune_baseline=prune_baseline,
alpha=alpha,
X_pending=X_pending,
marginalize_dim=kwargs.get("marginalize_dim"),
cache_root=kwargs.get("cache_root", True),
marginalize_dim=marginalize_dim,
cache_root=cache_root,
)
raise NotImplementedError(
f"Unknown acquisition function {acquisition_function_name}"
Expand Down
12 changes: 3 additions & 9 deletions test/acquisition/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,6 @@
MCMultiOutputObjective,
monte_carlo as moo_monte_carlo,
)
from botorch.acquisition.multi_objective.monte_carlo import (
qExpectedHypervolumeImprovement,
)
from botorch.acquisition.objective import (
GenericMCObjective,
MCAcquisitionObjective,
Expand Down Expand Up @@ -643,11 +640,8 @@ def test_GetQUCB(self, mock_acqf):
self.assertEqual(sampler.seed, 2)
self.assertTrue(torch.equal(kwargs["X_pending"], self.X_pending))

@mock.patch(
f"{moo_monte_carlo.__name__}.qExpectedHypervolumeImprovement",
wraps=qExpectedHypervolumeImprovement,
)
def test_GetQEHVI(self, mock_acqf) -> None:
@mock.patch(f"{moo_monte_carlo.__name__}.qExpectedHypervolumeImprovement")
def test_GetQEHVI(self, mock_acqf):
# make sure ref_point is specified
with self.assertRaises(ValueError):
acqf = get_acquisition_function(
Expand Down Expand Up @@ -696,7 +690,7 @@ def test_GetQEHVI(self, mock_acqf) -> None:
ref_point=self.ref_point,
Y=self.Y,
)
self.assertIsInstance(acqf, qExpectedHypervolumeImprovement)
self.assertEqual(acqf, mock_acqf.return_value)
mock_acqf.assert_called_once_with(
constraints=None,
eta=1e-3,
Expand Down

0 comments on commit 0b04b96

Please sign in to comment.