Skip to content

Commit

Permalink
Remove unnecessary pyre-fixme
Browse files Browse the repository at this point in the history
Summary: The previous diff clarified some types, so now this pyre-ignore is unnecessary.

Differential Revision: D58417021
  • Loading branch information
saitcakmak authored and facebook-github-bot committed Jun 11, 2024
1 parent 9eaf705 commit c16bddb
Showing 1 changed file with 9 additions and 12 deletions.
21 changes: 9 additions & 12 deletions ax/models/tests/test_botorch_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
from unittest import mock

import numpy as np

import torch
from ax.core.search_space import SearchSpaceDigest
from ax.exceptions.core import DataRequiredError
Expand All @@ -29,6 +28,7 @@
from ax.models.torch.utils import sample_simplex
from ax.models.torch_base import TorchOptConfig
from ax.utils.common.testutils import TestCase
from ax.utils.common.typeutils import not_none
from ax.utils.testing.mock import fast_botorch_optimize
from ax.utils.testing.torch_stubs import get_torch_test_data
from botorch.acquisition.utils import get_infeasible_cost
Expand Down Expand Up @@ -120,7 +120,7 @@ def test_fixed_prior_BotorchModel(
Xs2, Ys2, Yvars2, _, _, _, _ = get_torch_test_data(
dtype=dtype, cuda=cuda, constant_noise=True
)
kwargs = {
kwargs: Dict[str, Any] = {
"prior": {
"covar_module_prior": {
"lengthscale_prior": GammaPrior(6.0, 3.0),
Expand All @@ -131,7 +131,7 @@ def test_fixed_prior_BotorchModel(
"eta": 0.6,
}
}
model = BotorchModel(**kwargs) # pyre-ignore [6]
model = BotorchModel(**kwargs)
datasets = [
SupervisedDataset(
X=Xs1[0],
Expand Down Expand Up @@ -329,8 +329,7 @@ def test_BotorchModel(

if not use_input_warping:
expected_train_inputs = expected_train_inputs.unsqueeze(0).expand(
torch.Size([2])
+ Xs1[0].shape # pyre-fixme[58]: Unsupported operand
2, *Xs1[0].shape
)
expected_train_targets = torch.cat(Ys1 + Ys2, dim=-1).permute(1, 0)
else:
Expand Down Expand Up @@ -475,10 +474,8 @@ def test_BotorchModel(
)

# test get_rounding_func
dummy_rounding = get_rounding_func(rounding_func=dummy_func)
dummy_rounding = not_none(get_rounding_func(rounding_func=dummy_func))
X_temp = torch.rand(1, 2, 3, 4)
# pyre-fixme[29]: `Optional[typing.Callable[[torch._tensor.Tensor],
# torch._tensor.Tensor]]` is not a function.
self.assertTrue(torch.equal(X_temp, dummy_rounding(X_temp)))

# Check best point selection
Expand Down Expand Up @@ -773,18 +770,18 @@ def test_botorchmodel_raises_when_no_data(self) -> None:
)

def test_get_feature_importances_from_botorch_model(self) -> None:
tkwargs = {"dtype": torch.double}
train_X = torch.rand(5, 3, **tkwargs) # pyre-ignore [6]
tkwargs: Dict[str, Any] = {"dtype": torch.double}
train_X = torch.rand(5, 3, **tkwargs)
train_Y = train_X.sum(dim=-1, keepdim=True)
simple_gp = SingleTaskGP(train_X=train_X, train_Y=train_Y)
simple_gp.covar_module.base_kernel.lengthscale = torch.tensor(
[1, 3, 5], **tkwargs # pyre-ignore [6]
[1, 3, 5], **tkwargs
)
importances = get_feature_importances_from_botorch_model(simple_gp)
self.assertTrue(np.allclose(importances, np.array([15 / 23, 5 / 23, 3 / 23])))
self.assertEqual(importances.shape, (1, 1, 3))
# Model with no base kernel
simple_gp.covar_module.base_kernel = None # pyre-ignore [16]
simple_gp.covar_module.base_kernel = None
with self.assertRaisesRegex(
NotImplementedError,
"Failed to extract lengthscales from `m.covar_module.base_kernel`",
Expand Down

0 comments on commit c16bddb

Please sign in to comment.