Skip to content

Commit

Permalink
Clean up is_fully_bayesian (facebook#1992)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: facebook#1992

X-link: pytorch/botorch#2108

This attempts to clean up the usage of `is_fully_bayesian` and also separately treat fully Bayesian models from ensemble models.

The main changes in diff are to:
- Add an `_is_fully_bayesian` attribute to `Model`. This is `True` for fully Bayesian models that rely on Pyro/NUTS to be fitted (they need some special handling for fitting and `state_dict` loading/saving.
- Add an `_is_ensemble` attribute to `Model`. This indicates whether the model is a collection of multiple models that are stored in an additional batch dimension. This is hopefully a better classification, but I'm open to a different name here.
- Rename `FullyBayesianPosterior` to `GaussianMixturePosterior` since that is more descriptive and plays better with the other changes.

Reviewed By: esantorella

Differential Revision: D50884342

fbshipit-source-id: 0ba603416c1823026c4fdf2e445cefdf8036cda8
  • Loading branch information
David Eriksson authored and facebook-github-bot committed Nov 16, 2023
1 parent b3ee6c9 commit 3985791
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 4 deletions.
4 changes: 2 additions & 2 deletions ax/models/torch/botorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
from botorch.models import ModelList
from botorch.models.model import Model
from botorch.utils.datasets import SupervisedDataset
from botorch.utils.transforms import is_fully_bayesian
from botorch.utils.transforms import is_ensemble
from torch import Tensor
from torch.nn import ModuleList # @manual

Expand Down Expand Up @@ -572,7 +572,7 @@ def get_feature_importances_from_botorch_model(
)
if ls.ndim == 2:
ls = ls.unsqueeze(0)
if is_fully_bayesian(m): # Take the median over the MCMC samples
if is_ensemble(m): # Take the median over the model batch dimension
ls = torch.quantile(ls, q=0.5, dim=0, keepdim=True)
lengthscales.append(ls)
lengthscales = torch.cat(lengthscales, dim=0)
Expand Down
4 changes: 2 additions & 2 deletions ax/models/torch/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@
from botorch.acquisition.utils import get_infeasible_cost
from botorch.models import ModelListGP, SingleTaskGP
from botorch.models.model import Model
from botorch.posteriors.fully_bayesian import FullyBayesianPosterior
from botorch.posteriors.fully_bayesian import GaussianMixturePosterior
from botorch.posteriors.gpytorch import GPyTorchPosterior
from botorch.posteriors.posterior_list import PosteriorList
from botorch.sampling.normal import IIDNormalSampler, SobolQMCNormalSampler
Expand Down Expand Up @@ -627,7 +627,7 @@ def predict_from_model(model: Model, X: Tensor) -> Tuple[Tensor, Tensor]:
with torch.no_grad():
# TODO: Allow Posterior to (optionally) return the full covariance matrix
posterior = model.posterior(X)
if isinstance(posterior, FullyBayesianPosterior):
if isinstance(posterior, GaussianMixturePosterior):
mean = posterior.mixture_mean.cpu().detach()
var = posterior.mixture_variance.cpu().detach().clamp_min(0)
elif isinstance(posterior, (GPyTorchPosterior, PosteriorList)):
Expand Down

0 comments on commit 3985791

Please sign in to comment.