diff --git a/ax/models/torch/botorch.py b/ax/models/torch/botorch.py index 054c8a7abe8..7cc8327cb1a 100644 --- a/ax/models/torch/botorch.py +++ b/ax/models/torch/botorch.py @@ -41,7 +41,7 @@ from botorch.models import ModelList from botorch.models.model import Model from botorch.utils.datasets import SupervisedDataset -from botorch.utils.transforms import is_fully_bayesian +from botorch.utils.transforms import is_ensemble from torch import Tensor from torch.nn import ModuleList # @manual @@ -572,7 +572,7 @@ def get_feature_importances_from_botorch_model( ) if ls.ndim == 2: ls = ls.unsqueeze(0) - if is_fully_bayesian(m): # Take the median over the MCMC samples + if is_ensemble(m): # Take the median over the model batch dimension ls = torch.quantile(ls, q=0.5, dim=0, keepdim=True) lengthscales.append(ls) lengthscales = torch.cat(lengthscales, dim=0) diff --git a/ax/models/torch/utils.py b/ax/models/torch/utils.py index 5e0af0038e3..a8a58f38e46 100644 --- a/ax/models/torch/utils.py +++ b/ax/models/torch/utils.py @@ -51,7 +51,7 @@ from botorch.acquisition.utils import get_infeasible_cost from botorch.models import ModelListGP, SingleTaskGP from botorch.models.model import Model -from botorch.posteriors.fully_bayesian import FullyBayesianPosterior +from botorch.posteriors.fully_bayesian import GaussianMixturePosterior from botorch.posteriors.gpytorch import GPyTorchPosterior from botorch.posteriors.posterior_list import PosteriorList from botorch.sampling.normal import IIDNormalSampler, SobolQMCNormalSampler @@ -627,7 +627,7 @@ def predict_from_model(model: Model, X: Tensor) -> Tuple[Tensor, Tensor]: with torch.no_grad(): # TODO: Allow Posterior to (optionally) return the full covariance matrix posterior = model.posterior(X) - if isinstance(posterior, FullyBayesianPosterior): + if isinstance(posterior, GaussianMixturePosterior): mean = posterior.mixture_mean.cpu().detach() var = posterior.mixture_variance.cpu().detach().clamp_min(0) elif isinstance(posterior, (GPyTorchPosterior, PosteriorList)):