Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update the default SingleTaskGP prior #2610

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions ax/models/tests/test_botorch_defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

# pyre-strict

import math
from copy import deepcopy
from unittest import mock
from unittest.mock import Mock
Expand Down Expand Up @@ -66,9 +67,9 @@ def test_get_model(self) -> None:
self.assertIsInstance(model, SingleTaskGP)
self.assertIsInstance(model.likelihood, FixedNoiseGaussianLikelihood)
self.assertEqual(
model.covar_module.base_kernel.lengthscale_prior.concentration, 3.0
model.covar_module.lengthscale_prior.loc, math.log(2.0) / 2 + 2**0.5
)
self.assertEqual(model.covar_module.base_kernel.lengthscale_prior.rate, 6.0)
self.assertEqual(model.covar_module.lengthscale_prior.scale, 3**0.5)
model = _get_model(X=x, Y=y, Yvar=unknown_var, task_feature=1)
self.assertIs(type(model), MultiTaskGP) # Don't accept subclasses.
self.assertIsInstance(model.likelihood, GaussianLikelihood)
Expand Down
34 changes: 13 additions & 21 deletions ax/models/tests/test_botorch_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
from botorch.models.transforms.input import Warp
from botorch.utils.datasets import SupervisedDataset
from botorch.utils.objective import get_objective_weights_transform
from gpytorch.kernels.constant_kernel import ConstantKernel
from gpytorch.likelihoods import _GaussianLikelihoodBase
from gpytorch.likelihoods.gaussian_likelihood import FixedNoiseGaussianLikelihood
from gpytorch.mlls import ExactMarginalLogLikelihood, LeaveOneOutPseudoLikelihood
Expand Down Expand Up @@ -558,19 +559,12 @@ def test_BotorchModel(

# Test loading state dict
true_state_dict = {
"mean_module.raw_constant": 3.5004,
"covar_module.raw_outputscale": 2.2438,
"covar_module.base_kernel.raw_lengthscale": [
[-0.9274, -0.9274, -0.9274]
],
"covar_module.base_kernel.raw_lengthscale_constraint.lower_bound": 0.1,
"covar_module.base_kernel.raw_lengthscale_constraint.upper_bound": 2.5,
"covar_module.base_kernel.lengthscale_prior.concentration": 3.0,
"covar_module.base_kernel.lengthscale_prior.rate": 6.0,
"covar_module.raw_outputscale_constraint.lower_bound": 0.2,
"covar_module.raw_outputscale_constraint.upper_bound": 2.6,
"covar_module.outputscale_prior.concentration": 2.0,
"covar_module.outputscale_prior.rate": 0.15,
"mean_module.raw_constant": 1.0,
"covar_module.raw_lengthscale": [[0.3548, 0.3548, 0.3548]],
"covar_module.lengthscale_prior._transformed_loc": 1.9635,
"covar_module.lengthscale_prior._transformed_scale": 1.7321,
"covar_module.raw_lengthscale_constraint.lower_bound": 0.0250,
"covar_module.raw_lengthscale_constraint.upper_bound": float("inf"),
}
true_state_dict = {
key: torch.tensor(val, **tkwargs)
Expand All @@ -591,8 +585,7 @@ def test_BotorchModel(

# Test for some change in model parameters & buffer for refit_model=True
true_state_dict["mean_module.raw_constant"] += 0.1
true_state_dict["covar_module.raw_outputscale"] += 0.1
true_state_dict["covar_module.base_kernel.raw_lengthscale"] += 0.1
true_state_dict["covar_module.raw_lengthscale"] += 0.1
model = get_and_fit_model(
Xs=Xs1,
Ys=Ys1,
Expand Down Expand Up @@ -774,17 +767,16 @@ def test_get_feature_importances_from_botorch_model(self) -> None:
train_X = torch.rand(5, 3, **tkwargs)
train_Y = train_X.sum(dim=-1, keepdim=True)
simple_gp = SingleTaskGP(train_X=train_X, train_Y=train_Y)
simple_gp.covar_module.base_kernel.lengthscale = torch.tensor(
[1, 3, 5], **tkwargs
)
simple_gp.covar_module.lengthscale = torch.tensor([1, 3, 5], **tkwargs)
importances = get_feature_importances_from_botorch_model(simple_gp)
self.assertTrue(np.allclose(importances, np.array([15 / 23, 5 / 23, 3 / 23])))
self.assertEqual(importances.shape, (1, 1, 3))
# Model with no base kernel
simple_gp.covar_module.base_kernel = None
# Model with kernel that has no lengthscales
simple_gp.covar_module = ConstantKernel()
with self.assertRaisesRegex(
NotImplementedError,
"Failed to extract lengthscales from `m.covar_module.base_kernel`",
"Failed to extract lengthscales from `m.covar_module` and "
"`m.covar_module.base_kernel`",
):
get_feature_importances_from_botorch_model(simple_gp)

Expand Down
10 changes: 8 additions & 2 deletions ax/models/torch/botorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -562,15 +562,21 @@ def get_feature_importances_from_botorch_model(
lengthscales = []
for m in models:
try:
ls = m.covar_module.base_kernel.lengthscale
# this can be a ModelList of a SAAS and STGP, so this is a necessary way
# to get the lengthscale
if hasattr(m.covar_module, "base_kernel"):
ls = m.covar_module.base_kernel.lengthscale
else:
ls = m.covar_module.lengthscale
except AttributeError:
ls = None
if ls is None or ls.shape[-1] != m.train_inputs[0].shape[-1]:
# TODO: We could potentially set the feature importances to NaN in this
# case, but this require knowing the batch dimension of this model.
# Consider supporting in the future.
raise NotImplementedError(
"Failed to extract lengthscales from `m.covar_module.base_kernel`"
"Failed to extract lengthscales from `m.covar_module` "
"and `m.covar_module.base_kernel`"
)
if ls.ndim == 2:
ls = ls.unsqueeze(0)
Expand Down
9 changes: 5 additions & 4 deletions ax/models/torch/tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -634,8 +634,8 @@ def test_feature_importances(self) -> None:
self.assertEqual(importances.shape, (1, 1, 3))
saas_model = deepcopy(model.surrogate.model)
else:
model.surrogate.model.covar_module.base_kernel.lengthscale = (
torch.tensor([1, 2, 3], **self.tkwargs)
model.surrogate.model.covar_module.lengthscale = torch.tensor(
[1, 2, 3], **self.tkwargs
)
importances = model.feature_importances()
self.assertTrue(
Expand All @@ -658,11 +658,12 @@ def test_feature_importances(self) -> None:
)
self.assertEqual(importances.shape, (2, 1, 3))
# Add model we don't support
vanilla_model.covar_module.base_kernel = None
vanilla_model.covar_module = None
model.surrogate._model = vanilla_model # pyre-ignore
with self.assertRaisesRegex(
NotImplementedError,
"Failed to extract lengthscales from `m.covar_module.base_kernel`",
"Failed to extract lengthscales from `m.covar_module` "
"and `m.covar_module.base_kernel`",
):
model.feature_importances()
# Test model is None
Expand Down
5 changes: 4 additions & 1 deletion ax/plot/tests/test_feature_importances.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,10 @@ def get_sensitivity_values(ax_model: ModelBridge) -> Dict:

Returns map {'metric_name': {'parameter_name': sensitivity_value}}
"""
ls = ax_model.model.model.covar_module.base_kernel.lengthscale.squeeze()
if hasattr(ax_model.model.model.covar_module, "outputscale"):
ls = ax_model.model.model.covar_module.base_kernel.lengthscale.squeeze()
else:
ls = ax_model.model.model.covar_module.lengthscale.squeeze()
if len(ls.shape) > 1:
ls = ls.mean(dim=0)
# pyre-fixme[16]: `float` has no attribute `detach`.
Expand Down
16 changes: 12 additions & 4 deletions ax/utils/sensitivity/derivative_gp.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,12 @@ def get_KxX_dx(gp: Model, x: Tensor, kernel_type: str = "rbf") -> Tensor:
D = X.shape[1]
N = X.shape[0]
n = x.shape[0]
lengthscale = gp.covar_module.base_kernel.lengthscale.detach()
if hasattr(gp.covar_module, "outputscale"):
lengthscale = gp.covar_module.base_kernel.lengthscale.detach()
sigma_f = gp.covar_module.outputscale.detach()
else:
lengthscale = gp.covar_module.lengthscale.detach()
sigma_f = 1.0
if kernel_type == "rbf":
K_xX = gp.covar_module(x, X).evaluate()
part1 = -torch.eye(D, device=x.device, dtype=x.dtype) / lengthscale**2
Expand All @@ -52,7 +57,6 @@ def get_KxX_dx(gp: Model, x: Tensor, kernel_type: str = "rbf") -> Tensor:
constant_component = (-5.0 / 3.0) * distance - (5.0 * math.sqrt(5.0) / 3.0) * (
distance**2
)
sigma_f = gp.covar_module.outputscale.detach()
part1 = torch.eye(D, device=lengthscale.device) / lengthscale
part2 = (x1_.view(n, 1, D) - x2_.view(1, N, D)) / distance.unsqueeze(2)
total_k = sigma_f * constant_component * exp_component
Expand All @@ -70,8 +74,12 @@ def get_Kxx_dx2(gp: Model, kernel_type: str = "rbf") -> Tensor:
"""
X = gp.train_inputs[0]
D = X.shape[1]
lengthscale = gp.covar_module.base_kernel.lengthscale.detach()
sigma_f = gp.covar_module.outputscale.detach()
if hasattr(gp.covar_module, "outputscale"):
lengthscale = gp.covar_module.base_kernel.lengthscale.detach()
sigma_f = gp.covar_module.outputscale.detach()
else:
lengthscale = gp.covar_module.lengthscale.detach()
sigma_f = 1.0
res = (torch.eye(D, device=lengthscale.device) / lengthscale**2) * sigma_f
if kernel_type == "rbf":
return res
Expand Down
Loading