Remove unnecessary pyre-fixme #2513

Closed
wants to merge 1 commit into from
21 changes: 9 additions & 12 deletions ax/models/tests/test_botorch_model.py
@@ -12,7 +12,6 @@
 from unittest import mock
 
 import numpy as np
-
 import torch
 from ax.core.search_space import SearchSpaceDigest
 from ax.exceptions.core import DataRequiredError
@@ -29,6 +28,7 @@
 from ax.models.torch.utils import sample_simplex
 from ax.models.torch_base import TorchOptConfig
 from ax.utils.common.testutils import TestCase
+from ax.utils.common.typeutils import not_none
 from ax.utils.testing.mock import fast_botorch_optimize
 from ax.utils.testing.torch_stubs import get_torch_test_data
 from botorch.acquisition.utils import get_infeasible_cost
@@ -120,7 +120,7 @@ def test_fixed_prior_BotorchModel(
         Xs2, Ys2, Yvars2, _, _, _, _ = get_torch_test_data(
             dtype=dtype, cuda=cuda, constant_noise=True
         )
-        kwargs = {
+        kwargs: Dict[str, Any] = {
             "prior": {
                 "covar_module_prior": {
                     "lengthscale_prior": GammaPrior(6.0, 3.0),
@@ -131,7 +131,7 @@
                 "eta": 0.6,
             }
         }
-        model = BotorchModel(**kwargs)  # pyre-ignore [6]
+        model = BotorchModel(**kwargs)
         datasets = [
             SupervisedDataset(
                 X=Xs1[0],
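A side note on the change above (not part of the diff): pyre's error [6] came from the inferred value type of the un-annotated dict literal; declaring it as Dict[str, Any] makes the **kwargs expansion type-check without a suppression. A minimal sketch with a hypothetical stand-in constructor, not the Ax API:

from typing import Any, Dict


def make_model(prior: Dict[str, Any], refit_on_cv: bool = False) -> None:
    """Hypothetical constructor standing in for a typed signature like BotorchModel's."""


kwargs: Dict[str, Any] = {"prior": {"eta": 0.6}, "refit_on_cv": True}
make_model(**kwargs)  # values are typed as Any, so the expansion is accepted

untyped = {"prior": {"eta": 0.6}, "refit_on_cv": True}
make_model(**untyped)  # pyre may infer a narrower value type here and report [6]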
@@ -329,8 +329,7 @@ def test_BotorchModel(
 
         if not use_input_warping:
             expected_train_inputs = expected_train_inputs.unsqueeze(0).expand(
-                torch.Size([2])
-                + Xs1[0].shape  # pyre-fixme[58]: Unsupported operand
+                2, *Xs1[0].shape
             )
             expected_train_targets = torch.cat(Ys1 + Ys2, dim=-1).permute(1, 0)
         else:
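The two expand calls are equivalent at runtime; passing the sizes as unpacked ints just avoids the torch.Size + torch.Size expression that pyre flagged as an unsupported operand. A small self-contained check (the tensor shape here is illustrative, not taken from the test fixtures):

import torch

x = torch.rand(4, 3)
old_form = x.unsqueeze(0).expand(torch.Size([2]) + x.shape)  # Size concatenation
new_form = x.unsqueeze(0).expand(2, *x.shape)                # unpacked ints
assert old_form.shape == new_form.shape == torch.Size([2, 4, 3])
assert torch.equal(old_form, new_form)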
@@ -475,10 +474,8 @@ def test_BotorchModel(
         )
 
         # test get_rounding_func
-        dummy_rounding = get_rounding_func(rounding_func=dummy_func)
+        dummy_rounding = not_none(get_rounding_func(rounding_func=dummy_func))
         X_temp = torch.rand(1, 2, 3, 4)
-        # pyre-fixme[29]: `Optional[typing.Callable[[torch._tensor.Tensor],
-        # torch._tensor.Tensor]]` is not a function.
         self.assertTrue(torch.equal(X_temp, dummy_rounding(X_temp)))
 
         # Check best point selection
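Context for the wrapper: get_rounding_func returns an Optional callable, which is why pyre refused to call the result directly (error [29]). not_none, imported above from ax.utils.common.typeutils, narrows that Optional. The sketch below is an illustrative stand-in for the narrowing pattern, not the actual Ax helper:

from typing import Callable, Optional, TypeVar

import torch

T = TypeVar("T")


def not_none_sketch(value: Optional[T]) -> T:
    # Fail fast on None instead of letting the Optional leak into call sites.
    if value is None:
        raise ValueError("Expected a non-None value.")
    return value


def identity(x: torch.Tensor) -> torch.Tensor:
    return x


maybe_round: Optional[Callable[[torch.Tensor], torch.Tensor]] = identity
rounding = not_none_sketch(maybe_round)  # typed as a plain Callable after narrowing
X_temp = torch.rand(1, 2, 3, 4)
assert torch.equal(X_temp, rounding(X_temp))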
@@ -773,18 +770,18 @@ def test_botorchmodel_raises_when_no_data(self) -> None:
         )
 
     def test_get_feature_importances_from_botorch_model(self) -> None:
-        tkwargs = {"dtype": torch.double}
-        train_X = torch.rand(5, 3, **tkwargs)  # pyre-ignore [6]
+        tkwargs: Dict[str, Any] = {"dtype": torch.double}
+        train_X = torch.rand(5, 3, **tkwargs)
         train_Y = train_X.sum(dim=-1, keepdim=True)
         simple_gp = SingleTaskGP(train_X=train_X, train_Y=train_Y)
         simple_gp.covar_module.base_kernel.lengthscale = torch.tensor(
-            [1, 3, 5], **tkwargs  # pyre-ignore [6]
+            [1, 3, 5], **tkwargs
         )
         importances = get_feature_importances_from_botorch_model(simple_gp)
         self.assertTrue(np.allclose(importances, np.array([15 / 23, 5 / 23, 3 / 23])))
         self.assertEqual(importances.shape, (1, 1, 3))
         # Model with no base kernel
-        simple_gp.covar_module.base_kernel = None  # pyre-ignore [16]
+        simple_gp.covar_module.base_kernel = None
         with self.assertRaisesRegex(
             NotImplementedError,
             "Failed to extract lengthscales from `m.covar_module.base_kernel`",