diff --git a/ax/models/tests/test_botorch_model.py b/ax/models/tests/test_botorch_model.py index f5a75551598..9e42efd6d14 100644 --- a/ax/models/tests/test_botorch_model.py +++ b/ax/models/tests/test_botorch_model.py @@ -12,7 +12,6 @@ from unittest import mock import numpy as np - import torch from ax.core.search_space import SearchSpaceDigest from ax.exceptions.core import DataRequiredError @@ -29,6 +28,7 @@ from ax.models.torch.utils import sample_simplex from ax.models.torch_base import TorchOptConfig from ax.utils.common.testutils import TestCase +from ax.utils.common.typeutils import not_none from ax.utils.testing.mock import fast_botorch_optimize from ax.utils.testing.torch_stubs import get_torch_test_data from botorch.acquisition.utils import get_infeasible_cost @@ -120,7 +120,7 @@ def test_fixed_prior_BotorchModel( Xs2, Ys2, Yvars2, _, _, _, _ = get_torch_test_data( dtype=dtype, cuda=cuda, constant_noise=True ) - kwargs = { + kwargs: Dict[str, Any] = { "prior": { "covar_module_prior": { "lengthscale_prior": GammaPrior(6.0, 3.0), @@ -131,7 +131,7 @@ def test_fixed_prior_BotorchModel( "eta": 0.6, } } - model = BotorchModel(**kwargs) # pyre-ignore [6] + model = BotorchModel(**kwargs) datasets = [ SupervisedDataset( X=Xs1[0], @@ -329,8 +329,7 @@ def test_BotorchModel( if not use_input_warping: expected_train_inputs = expected_train_inputs.unsqueeze(0).expand( - torch.Size([2]) - + Xs1[0].shape # pyre-fixme[58]: Unsupported operand + 2, *Xs1[0].shape ) expected_train_targets = torch.cat(Ys1 + Ys2, dim=-1).permute(1, 0) else: @@ -475,10 +474,8 @@ def test_BotorchModel( ) # test get_rounding_func - dummy_rounding = get_rounding_func(rounding_func=dummy_func) + dummy_rounding = not_none(get_rounding_func(rounding_func=dummy_func)) X_temp = torch.rand(1, 2, 3, 4) - # pyre-fixme[29]: `Optional[typing.Callable[[torch._tensor.Tensor], - # torch._tensor.Tensor]]` is not a function. self.assertTrue(torch.equal(X_temp, dummy_rounding(X_temp))) # Check best point selection @@ -773,18 +770,18 @@ def test_botorchmodel_raises_when_no_data(self) -> None: ) def test_get_feature_importances_from_botorch_model(self) -> None: - tkwargs = {"dtype": torch.double} - train_X = torch.rand(5, 3, **tkwargs) # pyre-ignore [6] + tkwargs: Dict[str, Any] = {"dtype": torch.double} + train_X = torch.rand(5, 3, **tkwargs) train_Y = train_X.sum(dim=-1, keepdim=True) simple_gp = SingleTaskGP(train_X=train_X, train_Y=train_Y) simple_gp.covar_module.base_kernel.lengthscale = torch.tensor( - [1, 3, 5], **tkwargs # pyre-ignore [6] + [1, 3, 5], **tkwargs ) importances = get_feature_importances_from_botorch_model(simple_gp) self.assertTrue(np.allclose(importances, np.array([15 / 23, 5 / 23, 3 / 23]))) self.assertEqual(importances.shape, (1, 1, 3)) # Model with no base kernel - simple_gp.covar_module.base_kernel = None # pyre-ignore [16] + simple_gp.covar_module.base_kernel = None with self.assertRaisesRegex( NotImplementedError, "Failed to extract lengthscales from `m.covar_module.base_kernel`",