From 0b04b96a3cf902198b9bd266a0594b8ab8dae802 Mon Sep 17 00:00:00 2001
From: Elizabeth Santorella
Date: Wed, 16 Aug 2023 14:43:28 -0700
Subject: [PATCH] Even more fixes to unused kwargs (#1985)

Summary:
X-link: https://github.com/facebook/Ax/pull/1782

Pull Request resolved: https://github.com/pytorch/botorch/pull/1985

See previous diff

Reviewed By: lena-kashtelyan

Differential Revision: D48338443

fbshipit-source-id: 7cacc33ead6d5471855364eff910031640bc4708
---
 botorch/acquisition/utils.py   | 59 ++++++++++++++++++----------------
 test/acquisition/test_utils.py | 12 ++-----
 2 files changed, 34 insertions(+), 37 deletions(-)

diff --git a/botorch/acquisition/utils.py b/botorch/acquisition/utils.py
index 5364ef5fd5..0dd34ce931 100644
--- a/botorch/acquisition/utils.py
+++ b/botorch/acquisition/utils.py
@@ -49,7 +49,16 @@ def get_acquisition_function(
     eta: Optional[Union[Tensor, float]] = 1e-3,
     mc_samples: int = 512,
     seed: Optional[int] = None,
-    **kwargs,
+    *,
+    # optional parameters that are only needed for certain acquisition functions
+    tau: float = 1e-3,
+    prune_baseline: bool = True,
+    marginalize_dim: Optional[int] = None,
+    cache_root: bool = True,
+    beta: Optional[float] = None,
+    ref_point: Union[None, List[float], Tensor] = None,
+    Y: Optional[Tensor] = None,
+    alpha: float = 0.0,
 ) -> monte_carlo.MCAcquisitionFunction:
     r"""Convenience function for initializing botorch acquisition functions.
 
@@ -149,7 +158,7 @@
             objective=objective,
             posterior_transform=posterior_transform,
             X_pending=X_pending,
-            tau=kwargs.get("tau", 1e-3),
+            tau=tau,
             constraints=constraints,
             eta=eta,
         )
@@ -161,9 +170,9 @@
             objective=objective,
             posterior_transform=posterior_transform,
             X_pending=X_pending,
-            prune_baseline=kwargs.get("prune_baseline", True),
-            marginalize_dim=kwargs.get("marginalize_dim"),
-            cache_root=kwargs.get("cache_root", True),
+            prune_baseline=prune_baseline,
+            marginalize_dim=marginalize_dim,
+            cache_root=cache_root,
             constraints=constraints,
             eta=eta,
         )
@@ -177,9 +186,9 @@
             objective=objective,
             posterior_transform=posterior_transform,
             X_pending=X_pending,
-            prune_baseline=kwargs.get("prune_baseline", True),
-            marginalize_dim=kwargs.get("marginalize_dim"),
-            cache_root=kwargs.get("cache_root", True),
+            prune_baseline=prune_baseline,
+            marginalize_dim=marginalize_dim,
+            cache_root=cache_root,
             constraints=constraints,
             eta=eta,
         )
@@ -192,32 +201,26 @@
             X_pending=X_pending,
         )
     elif acquisition_function_name == "qUCB":
-        if "beta" not in kwargs:
-            raise ValueError("`beta` must be specified in kwargs for qUCB.")
+        if beta is None:
+            raise ValueError("`beta` must not be None for qUCB.")
         return monte_carlo.qUpperConfidenceBound(
             model=model,
-            beta=kwargs["beta"],
+            beta=beta,
             sampler=sampler,
             objective=objective,
             posterior_transform=posterior_transform,
             X_pending=X_pending,
         )
     elif acquisition_function_name == "qEHVI":
-        # pyre-fixme [16]: `Model` has no attribute `train_targets`
-        try:
-            ref_point = kwargs["ref_point"]
-        except KeyError:
-            raise ValueError("`ref_point` must be specified in kwargs for qEHVI")
-        try:
-            Y = kwargs["Y"]
-        except KeyError:
-            raise ValueError("`Y` must be specified in kwargs for qEHVI")
+        if Y is None:
+            raise ValueError("`Y` must not be None for qEHVI")
+        if ref_point is None:
+            raise ValueError("`ref_point` must not be None for qEHVI")
         # get feasible points
         if constraints is not None:
             feas = torch.stack([c(Y) <= 0 for c in constraints], dim=-1).all(dim=-1)
             Y = Y[feas]
         obj = objective(Y)
-        alpha = kwargs.get("alpha", 0.0)
         if alpha > 0:
             partitioning = NondominatedPartitioning(
                 ref_point=torch.as_tensor(ref_point, dtype=Y.dtype, device=Y.device),
@@ -240,21 +243,21 @@
             X_pending=X_pending,
         )
     elif acquisition_function_name == "qNEHVI":
-        if "ref_point" not in kwargs:
-            raise ValueError("`ref_point` must be specified in kwargs for qNEHVI")
+        if ref_point is None:
+            raise ValueError("`ref_point` must not be None for qNEHVI")
         return moo_monte_carlo.qNoisyExpectedHypervolumeImprovement(
             model=model,
-            ref_point=kwargs["ref_point"],
+            ref_point=ref_point,
             X_baseline=X_observed,
             sampler=sampler,
             objective=objective,
             constraints=constraints,
             eta=eta,
-            prune_baseline=kwargs.get("prune_baseline", True),
-            alpha=kwargs.get("alpha", 0.0),
+            prune_baseline=prune_baseline,
+            alpha=alpha,
             X_pending=X_pending,
-            marginalize_dim=kwargs.get("marginalize_dim"),
-            cache_root=kwargs.get("cache_root", True),
+            marginalize_dim=marginalize_dim,
+            cache_root=cache_root,
         )
     raise NotImplementedError(
         f"Unknown acquisition function {acquisition_function_name}"
diff --git a/test/acquisition/test_utils.py b/test/acquisition/test_utils.py
index eea92ad050..0ec4b748ee 100644
--- a/test/acquisition/test_utils.py
+++ b/test/acquisition/test_utils.py
@@ -14,9 +14,6 @@
     MCMultiOutputObjective,
     monte_carlo as moo_monte_carlo,
 )
-from botorch.acquisition.multi_objective.monte_carlo import (
-    qExpectedHypervolumeImprovement,
-)
 from botorch.acquisition.objective import (
     GenericMCObjective,
     MCAcquisitionObjective,
@@ -643,11 +640,8 @@ def test_GetQUCB(self, mock_acqf):
         self.assertEqual(sampler.seed, 2)
         self.assertTrue(torch.equal(kwargs["X_pending"], self.X_pending))
 
-    @mock.patch(
-        f"{moo_monte_carlo.__name__}.qExpectedHypervolumeImprovement",
-        wraps=qExpectedHypervolumeImprovement,
-    )
-    def test_GetQEHVI(self, mock_acqf) -> None:
+    @mock.patch(f"{moo_monte_carlo.__name__}.qExpectedHypervolumeImprovement")
+    def test_GetQEHVI(self, mock_acqf):
         # make sure ref_point is specified
         with self.assertRaises(ValueError):
             acqf = get_acquisition_function(
@@ -696,7 +690,7 @@ def test_GetQEHVI(self, mock_acqf) -> None:
                 ref_point=self.ref_point,
                 Y=self.Y,
             )
-            self.assertIsInstance(acqf, qExpectedHypervolumeImprovement)
+            self.assertEqual(acqf, mock_acqf.return_value)
             mock_acqf.assert_called_once_with(
                 constraints=None,
                 eta=1e-3,
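
Usage note (not part of the patch): a minimal sketch of how a caller might exercise the rewritten signature, in which acquisition-specific options such as `beta`, `ref_point`, `Y`, and `tau` are explicit keyword-only parameters rather than entries in `**kwargs`. The toy model, data, and hyperparameter values below are illustrative assumptions, not code from this diff; the point is that a misspelled or unsupported keyword now fails loudly with a `TypeError` instead of being silently dropped, which is what the "unused kwargs" fixes are after.

```python
# Illustrative only -- assumes a small single-output SingleTaskGP; data and
# hyperparameters below are made up for the example.
import torch

from botorch.acquisition.objective import GenericMCObjective
from botorch.acquisition.utils import get_acquisition_function
from botorch.fit import fit_gpytorch_mll
from botorch.models import SingleTaskGP
from gpytorch.mlls import ExactMarginalLogLikelihood

# Toy training data: 8 points in 2D with one outcome.
train_X = torch.rand(8, 2, dtype=torch.double)
train_Y = train_X.sin().sum(dim=-1, keepdim=True)

model = SingleTaskGP(train_X, train_Y)
fit_gpytorch_mll(ExactMarginalLogLikelihood(model.likelihood, model))

# Scalarize the (single) outcome for the MC acquisition function.
objective = GenericMCObjective(lambda samples, X=None: samples[..., 0])

# `beta` is now an explicit keyword-only argument; passing it (or mistyping it)
# is checked by Python's call machinery rather than hidden inside **kwargs.
acqf = get_acquisition_function(
    acquisition_function_name="qUCB",
    model=model,
    objective=objective,
    X_observed=train_X,
    mc_samples=128,
    beta=0.2,
)

# Evaluate on a batch of two q=1 candidate sets, shape (2, 1, 2).
values = acqf(train_X[:2].unsqueeze(1))
```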