From 9f8e4bc2d0bff90bd9b66f35b4430a7cf325e3d4 Mon Sep 17 00:00:00 2001 From: Sam Daulton Date: Thu, 13 Jun 2024 16:39:21 -0700 Subject: [PATCH] add option for using posterior predictive in cross-validation (#2517) Summary: Pull Request resolved: https://github.com/facebook/Ax/pull/2517 see title. This change is particularly important for model selection using the NLL if we have noisy observations. Using the posterior over the true function and not the noisy observations gives quite misleading results about model calibration. I also think that predicted vs actual plots from LOOCV are insightful when using the posterior predictive when the observations are noisy. We may want to consider adding observation_noise to `predict`, but we can do that in a follow-up. Reviewed By: Balandat Differential Revision: D58227612 --- ax/modelbridge/base.py | 6 ++++ ax/modelbridge/cross_validation.py | 23 +++++++++++++-- ax/modelbridge/discrete.py | 7 ++++- ax/modelbridge/map_torch.py | 2 ++ ax/modelbridge/random.py | 1 + ax/modelbridge/tests/test_base_modelbridge.py | 15 ++++++++++ ax/modelbridge/tests/test_cross_validation.py | 28 ++++++++++++++++++- ax/modelbridge/torch.py | 2 ++ ax/models/discrete_base.py | 4 +++ ax/models/torch/botorch.py | 8 ++++-- ax/models/torch/botorch_modular/model.py | 14 ++++++++-- ax/models/torch/botorch_modular/surrogate.py | 11 ++++++-- ax/models/torch/randomforest.py | 1 + ax/models/torch/tests/test_model.py | 14 ++++++---- ax/models/torch/tests/test_surrogate.py | 15 ++++++++-- ax/models/torch/utils.py | 9 ++++-- ax/models/torch_base.py | 4 +++ 17 files changed, 142 insertions(+), 22 deletions(-) diff --git a/ax/modelbridge/base.py b/ax/modelbridge/base.py index b9bfa26aa62..43feb6de6b0 100644 --- a/ax/modelbridge/base.py +++ b/ax/modelbridge/base.py @@ -907,12 +907,16 @@ def cross_validate( self, cv_training_data: List[Observation], cv_test_points: List[ObservationFeatures], + use_posterior_predictive: bool = False, ) -> List[ObservationData]: """Make a set of cross-validation predictions. Args: cv_training_data: The training data to use for cross validation. cv_test_points: The test points at which predictions will be made. + use_posterior_predictive: A boolean indicating if the predictions + should be from the posterior predictive (i.e. including + observation noise). Returns: A list of predictions at the test points. @@ -936,6 +940,7 @@ def cross_validate( search_space=search_space, cv_training_data=cv_training_data, cv_test_points=cv_test_points, + use_posterior_predictive=use_posterior_predictive, ) # Apply reverse transforms, in reverse order cv_test_observations = [ @@ -952,6 +957,7 @@ def _cross_validate( search_space: SearchSpace, cv_training_data: List[Observation], cv_test_points: List[ObservationFeatures], + use_posterior_predictive: bool = False, ) -> List[ObservationData]: """Apply the terminal transform, make predictions on the test points, and reverse terminal transform on the results. diff --git a/ax/modelbridge/cross_validation.py b/ax/modelbridge/cross_validation.py index a07ec1475f7..b5f9bcc820e 100644 --- a/ax/modelbridge/cross_validation.py +++ b/ax/modelbridge/cross_validation.py @@ -78,6 +78,7 @@ def cross_validate( # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. test_selector: Optional[Callable] = None, untransform: bool = True, + use_posterior_predictive: bool = False, ) -> List[CVResult]: """Cross validation for model predictions. @@ -112,6 +113,12 @@ def cross_validate( of the original data in regions where outliers have been removed, we have found it to better reflect the how good the model used for candidate generation actually is. + use_posterior_predictive: A boolean indicating if the predictions + should be from the posterior predictive (i.e. including + observation noise). Note: we should reconsider how we compute + cross-validation and model fit metrics where there is non- + Gaussian noise. + Returns: A CVResult for each observation in the training data. """ @@ -162,7 +169,9 @@ def cross_validate( # Make the prediction if untransform: cv_test_predictions = model.cross_validate( - cv_training_data=cv_training_data, cv_test_points=cv_test_points + cv_training_data=cv_training_data, + cv_test_points=cv_test_points, + use_posterior_predictive=use_posterior_predictive, ) else: # Get test predictions in transformed space @@ -186,6 +195,7 @@ def cross_validate( search_space=search_space, cv_training_data=cv_training_data, cv_test_points=cv_test_points, + use_posterior_predictive=use_posterior_predictive, ) # Get test observations in transformed space cv_test_data = deepcopy(cv_test_data) @@ -197,7 +207,9 @@ def cross_validate( return result -def cross_validate_by_trial(model: ModelBridge, trial: int = -1) -> List[CVResult]: +def cross_validate_by_trial( + model: ModelBridge, trial: int = -1, use_posterior_predictive: bool = False +) -> List[CVResult]: """Cross validation for model predictions on a particular trial. Uses all of the data up until the specified trial to predict each of the @@ -206,6 +218,9 @@ def cross_validate_by_trial(model: ModelBridge, trial: int = -1) -> List[CVResul Args: model: Fitted model (ModelBridge) to cross validate. trial: Trial for which predictions are evaluated. + use_posterior_predictive: A boolean indicating if the predictions + should be from the posterior predictive (i.e. including + observation noise). Returns: A CVResult for each observation in the training data. @@ -241,7 +256,9 @@ def cross_validate_by_trial(model: ModelBridge, trial: int = -1) -> List[CVResul cv_test_data.append(obs) # Make the prediction cv_test_predictions = model.cross_validate( - cv_training_data=cv_training_data, cv_test_points=cv_test_points + cv_training_data=cv_training_data, + cv_test_points=cv_test_points, + use_posterior_predictive=use_posterior_predictive, ) # Form CVResult objects result = [ diff --git a/ax/modelbridge/discrete.py b/ax/modelbridge/discrete.py index 3e4436d8e66..e3015a87f6e 100644 --- a/ax/modelbridge/discrete.py +++ b/ax/modelbridge/discrete.py @@ -191,6 +191,7 @@ def _cross_validate( search_space: SearchSpace, cv_training_data: List[Observation], cv_test_points: List[ObservationFeatures], + use_posterior_predictive: bool = False, ) -> List[ObservationData]: """Make predictions at cv_test_points using only the data in obs_feats and obs_data. @@ -208,7 +209,11 @@ def _cross_validate( ] # Use the model to do the cross validation f_test, cov_test = self.model.cross_validate( - Xs_train=Xs_train, Ys_train=Ys_train, Yvars_train=Yvars_train, X_test=X_test + Xs_train=Xs_train, + Ys_train=Ys_train, + Yvars_train=Yvars_train, + X_test=X_test, + use_posterior_predictive=use_posterior_predictive, ) # Convert array back to ObservationData return array_to_observation_data(f=f_test, cov=cov_test, outcomes=self.outcomes) diff --git a/ax/modelbridge/map_torch.py b/ax/modelbridge/map_torch.py index 61cd56b9834..da51c5c18e0 100644 --- a/ax/modelbridge/map_torch.py +++ b/ax/modelbridge/map_torch.py @@ -279,6 +279,7 @@ def _cross_validate( cv_training_data: List[Observation], cv_test_points: List[ObservationFeatures], parameters: Optional[List[str]] = None, + use_posterior_predictive: bool = False, **kwargs: Any, ) -> List[ObservationData]: """Make predictions at cv_test_points using only the data in obs_feats @@ -294,6 +295,7 @@ def _cross_validate( cv_training_data=cv_training_data, cv_test_points=cv_test_points, parameters=parameters, # we pass the map_keys too by default + use_posterior_predictive=use_posterior_predictive, **kwargs, ) observation_features, observation_data = separate_observations(cv_training_data) diff --git a/ax/modelbridge/random.py b/ax/modelbridge/random.py index ca1dc4b7215..3d617047814 100644 --- a/ax/modelbridge/random.py +++ b/ax/modelbridge/random.py @@ -100,6 +100,7 @@ def _cross_validate( search_space: SearchSpace, cv_training_data: List[Observation], cv_test_points: List[ObservationFeatures], + use_posterior_predictive: bool = False, ) -> List[ObservationData]: raise NotImplementedError diff --git a/ax/modelbridge/tests/test_base_modelbridge.py b/ax/modelbridge/tests/test_base_modelbridge.py index fe7b50eb0de..7f858dd0581 100644 --- a/ax/modelbridge/tests/test_base_modelbridge.py +++ b/ax/modelbridge/tests/test_base_modelbridge.py @@ -267,9 +267,24 @@ def warn_and_return_mock_obs( search_space=SearchSpace([FixedParameter("x", ParameterType.FLOAT, 8.0)]), cv_training_data=[get_observation2trans()], cv_test_points=[get_observation1().features], # untransformed after + use_posterior_predictive=False, ) self.assertTrue(cv_predictions == [get_observation1().data]) + # Test use_posterior_predictive in CV + modelbridge.cross_validate( + cv_training_data=cv_training_data, + cv_test_points=cv_test_points, + use_posterior_predictive=True, + ) + + modelbridge._cross_validate.assert_called_with( + search_space=SearchSpace([FixedParameter("x", ParameterType.FLOAT, 8.0)]), + cv_training_data=[get_observation2trans()], + cv_test_points=[get_observation1().features], # untransformed after + use_posterior_predictive=True, + ) + # Test stored training data obs = modelbridge.get_training_data() self.assertTrue(obs == [get_observation1(), get_observation2()]) diff --git a/ax/modelbridge/tests/test_cross_validation.py b/ax/modelbridge/tests/test_cross_validation.py index 2d0ef3715fc..1484d48da20 100644 --- a/ax/modelbridge/tests/test_cross_validation.py +++ b/ax/modelbridge/tests/test_cross_validation.py @@ -201,7 +201,9 @@ def test_CrossValidate(self) -> None: # Test ModelBridge._cross_validate was called correctly. z = ma._cross_validate.mock_calls self.assertEqual(len(z), 3) - ma._cross_validate.assert_called_with(**self.transformed_cv_input_dict) + ma._cross_validate.assert_called_with( + **self.transformed_cv_input_dict, use_posterior_predictive=False + ) # Test selector @@ -219,6 +221,21 @@ def test_selector(obs: Observation) -> bool: ) self.assertTrue(np.array_equal(sorted(all_test), np.array([2.0, 2.0, 3.0]))) + # test observation noise + for untransform in (True, False): + result = cross_validate( + model=ma, + folds=-1, + use_posterior_predictive=True, + untransform=untransform, + ) + if untransform: + mock_cv = ma.cross_validate + else: + mock_cv = ma._cross_validate + call_kwargs = mock_cv.mock_calls[-1].kwargs + self.assertTrue(call_kwargs["use_posterior_predictive"]) + def test_CrossValidateByTrial(self) -> None: # With only 1 trial ma = mock.MagicMock() @@ -261,6 +278,15 @@ def test_CrossValidateByTrial(self) -> None: self.assertEqual(len(result), 1) self.assertEqual(result[0].observed.features.trial_index, 2) + mock_cv = ma.cross_validate + call_kwargs = mock_cv.mock_calls[-1].kwargs + self.assertFalse(call_kwargs["use_posterior_predictive"]) + + # test observation noise + result = cross_validate_by_trial(model=ma, use_posterior_predictive=True) + call_kwargs = mock_cv.mock_calls[-1].kwargs + self.assertTrue(call_kwargs["use_posterior_predictive"]) + def test_cross_validate_gives_a_useful_error_for_model_with_no_data(self) -> None: exp = get_branin_experiment() sobol = Models.SOBOL(experiment=exp, search_space=exp.search_space) diff --git a/ax/modelbridge/torch.py b/ax/modelbridge/torch.py index f72d75464f6..ccc7acb9f9c 100644 --- a/ax/modelbridge/torch.py +++ b/ax/modelbridge/torch.py @@ -428,6 +428,7 @@ def _cross_validate( cv_training_data: List[Observation], cv_test_points: List[ObservationFeatures], parameters: Optional[List[str]] = None, + use_posterior_predictive: bool = False, **kwargs: Any, ) -> List[ObservationData]: """Make predictions at cv_test_points using only the data in obs_feats @@ -453,6 +454,7 @@ def _cross_validate( datasets=datasets, X_test=torch.as_tensor(X_test, dtype=self.dtype, device=self.device), search_space_digest=search_space_digest, + use_posterior_predictive=use_posterior_predictive, **kwargs, ) # Convert array back to ObservationData diff --git a/ax/models/discrete_base.py b/ax/models/discrete_base.py index fd007e4972d..6cda9d46b81 100644 --- a/ax/models/discrete_base.py +++ b/ax/models/discrete_base.py @@ -102,6 +102,7 @@ def cross_validate( Ys_train: List[List[float]], Yvars_train: List[List[float]], X_test: List[TParamValueList], + use_posterior_predictive: bool = False, ) -> Tuple[np.ndarray, np.ndarray]: """Do cross validation with the given training and test sets. @@ -116,6 +117,9 @@ def cross_validate( each outcome. Yvars_train: The variances of each entry in Ys, same shape. X_test: List of the j parameterizations at which to make predictions. + use_posterior_predictive: A boolean indicating if the predictions + should be from the posterior predictive (i.e. including + observation noise). Returns: 2-element tuple containing diff --git a/ax/models/torch/botorch.py b/ax/models/torch/botorch.py index 5fc9113f488..11329ab8b40 100644 --- a/ax/models/torch/botorch.py +++ b/ax/models/torch/botorch.py @@ -64,7 +64,7 @@ ], Model, ] -TModelPredictor = Callable[[Model, Tensor], Tuple[Tensor, Tensor]] +TModelPredictor = Callable[[Model, Tensor, bool], Tuple[Tensor, Tensor]] # pyre-fixme[33]: Aliased annotation cannot contain `Any`. @@ -466,6 +466,7 @@ def cross_validate( # pyre-ignore [14]: `search_space_digest` arg not needed he self, datasets: List[SupervisedDataset], X_test: Tensor, + use_posterior_predictive: bool = False, **kwargs: Any, ) -> Tuple[Tensor, Tensor]: if self._model is None: @@ -488,7 +489,10 @@ def cross_validate( # pyre-ignore [14]: `search_space_digest` arg not needed he use_loocv_pseudo_likelihood=self.use_loocv_pseudo_likelihood, **self._kwargs, ) - return self.model_predictor(model=model, X=X_test) # pyre-ignore: [28] + # pyre-ignore: [28] + return self.model_predictor( + model=model, X=X_test, use_posterior_predictive=use_posterior_predictive + ) def feature_importances(self) -> np.ndarray: return get_feature_importances_from_botorch_model(model=self._model) diff --git a/ax/models/torch/botorch_modular/model.py b/ax/models/torch/botorch_modular/model.py index 2021c513363..992e2ff1964 100644 --- a/ax/models/torch/botorch_modular/model.py +++ b/ax/models/torch/botorch_modular/model.py @@ -396,10 +396,15 @@ def predict(self, X: Tensor) -> Tuple[Tensor, Tensor]: return f, cov def predict_from_surrogate( - self, surrogate_label: str, X: Tensor + self, + surrogate_label: str, + X: Tensor, + use_posterior_predictive: bool = False, ) -> Tuple[Tensor, Tensor]: """Predict from the Surrogate with the given label.""" - return self.surrogates[surrogate_label].predict(X=X) + return self.surrogates[surrogate_label].predict( + X=X, use_posterior_predictive=use_posterior_predictive + ) @copy_doc(TorchModel.gen) def gen( @@ -504,6 +509,7 @@ def cross_validate( datasets: Sequence[SupervisedDataset], X_test: Tensor, search_space_digest: SearchSpaceDigest, + use_posterior_predictive: bool = False, **additional_model_inputs: Any, ) -> Tuple[Tensor, Tensor]: # Will fail if metric_names exist across multiple models @@ -561,7 +567,9 @@ def cross_validate( **additional_model_inputs, ) X_test_prediction = self.predict_from_surrogate( - surrogate_label=surrogate_label, X=X_test + surrogate_label=surrogate_label, + X=X_test, + use_posterior_predictive=use_posterior_predictive, ) finally: # Reset the surrogates back to this model's surrogate, make diff --git a/ax/models/torch/botorch_modular/surrogate.py b/ax/models/torch/botorch_modular/surrogate.py index 83c7608f04a..568e99b2c11 100644 --- a/ax/models/torch/botorch_modular/surrogate.py +++ b/ax/models/torch/botorch_modular/surrogate.py @@ -594,17 +594,24 @@ def _discard_cached_model_and_data_if_search_space_digest_changed( self._last_datasets = {} self._last_search_space_digest = search_space_digest - def predict(self, X: Tensor) -> Tuple[Tensor, Tensor]: + def predict( + self, X: Tensor, use_posterior_predictive: bool = False + ) -> Tuple[Tensor, Tensor]: """Predicts outcomes given an input tensor. Args: X: A ``n x d`` tensor of input parameters. + use_posterior_predictive: A boolean indicating if the predictions + should be from the posterior predictive (i.e. including + observation noise). Returns: Tensor: The predicted posterior mean as an ``n x o``-dim tensor. Tensor: The predicted posterior covariance as a ``n x o x o``-dim tensor. """ - return predict_from_model(model=self.model, X=X) + return predict_from_model( + model=self.model, X=X, use_posterior_predictive=use_posterior_predictive + ) def best_in_sample_point( self, diff --git a/ax/models/torch/randomforest.py b/ax/models/torch/randomforest.py index 8b584712962..dd4e936c0fb 100644 --- a/ax/models/torch/randomforest.py +++ b/ax/models/torch/randomforest.py @@ -73,6 +73,7 @@ def cross_validate( # pyre-ignore [14]: not using metric_names or ssd self, datasets: List[SupervisedDataset], X_test: Tensor, + use_posterior_predictive: bool = False, ) -> Tuple[Tensor, Tensor]: Xs, Ys, Yvars = _datasets_to_legacy_inputs(datasets=datasets) cv_models: List[RandomForestRegressor] = [] diff --git a/ax/models/torch/tests/test_model.py b/ax/models/torch/tests/test_model.py index 74eee241416..3496c975c1f 100644 --- a/ax/models/torch/tests/test_model.py +++ b/ax/models/torch/tests/test_model.py @@ -9,6 +9,7 @@ import dataclasses from collections import OrderedDict from copy import deepcopy +from itertools import product from typing import Dict, Type from unittest import mock from unittest.mock import Mock @@ -397,11 +398,9 @@ def test_cross_validate(self, mock_fit: Mock) -> None: old_surrogate._model = mock.MagicMock() old_surrogate._model.state_dict.return_value = OrderedDict({"key": "val"}) - for refit_on_cv, warm_start_refit in [ - (True, True), - (True, False), - (False, True), - ]: + for refit_on_cv, warm_start_refit, use_posterior_predictive in product( + (True, False), (True, False), (True, False) + ): self.model.refit_on_cv = refit_on_cv self.model.warm_start_refit = warm_start_refit with mock.patch( @@ -412,6 +411,7 @@ def test_cross_validate(self, mock_fit: Mock) -> None: datasets=self.block_design_training_data, X_test=self.X_test, search_space_digest=self.mf_search_space_digest, + use_posterior_predictive=use_posterior_predictive, ) # Check that `predict` is called on the cloned surrogate, not # on the original one. @@ -425,6 +425,10 @@ def test_cross_validate(self, mock_fit: Mock) -> None: self.X_test, ), ) + self.assertIs( + mock_predict.call_args_list[-1][1]["use_posterior_predictive"], + use_posterior_predictive, + ) # Check that surrogate is reset back to `old_surrogate` at the # end of cross-validation. diff --git a/ax/models/torch/tests/test_surrogate.py b/ax/models/torch/tests/test_surrogate.py index 98fd4c73788..5689c550380 100644 --- a/ax/models/torch/tests/test_surrogate.py +++ b/ax/models/torch/tests/test_surrogate.py @@ -9,6 +9,7 @@ import dataclasses import math from collections import OrderedDict +from itertools import product from typing import Any, Dict, Tuple, Type from unittest.mock import MagicMock, Mock, patch @@ -520,14 +521,22 @@ def test_construct_custom_model(self) -> None: @fast_botorch_optimize @patch(f"{SURROGATE_PATH}.predict_from_model") def test_predict(self, mock_predict: Mock) -> None: - for botorch_model_class in [SaasFullyBayesianSingleTaskGP, SingleTaskGP]: + for botorch_model_class, use_posterior_predictive in product( + (SaasFullyBayesianSingleTaskGP, SingleTaskGP), (True, False) + ): surrogate, _ = self._get_surrogate(botorch_model_class=botorch_model_class) surrogate.fit( datasets=self.training_data, search_space_digest=self.search_space_digest, ) - surrogate.predict(X=self.Xs[0]) - mock_predict.assert_called_with(model=surrogate.model, X=self.Xs[0]) + surrogate.predict( + X=self.Xs[0], use_posterior_predictive=use_posterior_predictive + ) + mock_predict.assert_called_with( + model=surrogate.model, + X=self.Xs[0], + use_posterior_predictive=use_posterior_predictive, + ) @fast_botorch_optimize def test_best_in_sample_point(self) -> None: diff --git a/ax/models/torch/utils.py b/ax/models/torch/utils.py index 414539aaa87..c63c3c4180d 100644 --- a/ax/models/torch/utils.py +++ b/ax/models/torch/utils.py @@ -624,7 +624,9 @@ def pick_best_out_of_sample_point_acqf_class( return cast(Type[AcquisitionFunction], acqf_class), acqf_options -def predict_from_model(model: Model, X: Tensor) -> Tuple[Tensor, Tensor]: +def predict_from_model( + model: Model, X: Tensor, use_posterior_predictive: bool = False +) -> Tuple[Tensor, Tensor]: r"""Predicts outcomes given a model and input tensor. For a `GaussianMixturePosterior` we currently use a Gaussian approximation where we @@ -634,6 +636,9 @@ def predict_from_model(model: Model, X: Tensor) -> Tuple[Tensor, Tensor]: Args: model: A botorch Model. X: A `n x d` tensor of input parameters. + use_posterior_predictive: A boolean indicating if the predictions + should be from the posterior predictive (i.e. including + observation noise). Returns: Tensor: The predicted posterior mean as an `n x o`-dim tensor. @@ -641,7 +646,7 @@ def predict_from_model(model: Model, X: Tensor) -> Tuple[Tensor, Tensor]: """ with torch.no_grad(): # TODO: Allow Posterior to (optionally) return the full covariance matrix - posterior = model.posterior(X) + posterior = model.posterior(X, observation_noise=use_posterior_predictive) if isinstance(posterior, GaussianMixturePosterior): mean = posterior.mixture_mean.cpu().detach() var = posterior.mixture_variance.cpu().detach().clamp_min(0) diff --git a/ax/models/torch_base.py b/ax/models/torch_base.py index 7822a62d920..86969430129 100644 --- a/ax/models/torch_base.py +++ b/ax/models/torch_base.py @@ -200,6 +200,7 @@ def cross_validate( datasets: List[SupervisedDataset], X_test: Tensor, search_space_digest: SearchSpaceDigest, + use_posterior_predictive: bool = False, ) -> Tuple[Tensor, Tensor]: """Do cross validation with the given training and test sets. @@ -212,6 +213,9 @@ def cross_validate( X_test: (j x d) tensor of the j points at which to make predictions. search_space_digest: A SearchSpaceDigest object containing metadata on the features in X. + use_posterior_predictive: A boolean indicating if the predictions + should be from the posterior predictive (i.e. including + observation noise). Returns: 2-element tuple containing