diff --git a/ax/modelbridge/base.py b/ax/modelbridge/base.py
index b9bfa26aa62..43feb6de6b0 100644
--- a/ax/modelbridge/base.py
+++ b/ax/modelbridge/base.py
@@ -907,12 +907,16 @@ def cross_validate(
         self,
         cv_training_data: List[Observation],
         cv_test_points: List[ObservationFeatures],
+        use_posterior_predictive: bool = False,
     ) -> List[ObservationData]:
         """Make a set of cross-validation predictions.
 
         Args:
             cv_training_data: The training data to use for cross validation.
             cv_test_points: The test points at which predictions will be made.
+            use_posterior_predictive: A boolean indicating if the predictions
+                should be from the posterior predictive (i.e. including
+                observation noise).
 
         Returns:
             A list of predictions at the test points.
@@ -936,6 +940,7 @@ def cross_validate(
             search_space=search_space,
             cv_training_data=cv_training_data,
             cv_test_points=cv_test_points,
+            use_posterior_predictive=use_posterior_predictive,
         )
         # Apply reverse transforms, in reverse order
         cv_test_observations = [
@@ -952,6 +957,7 @@ def _cross_validate(
         search_space: SearchSpace,
         cv_training_data: List[Observation],
         cv_test_points: List[ObservationFeatures],
+        use_posterior_predictive: bool = False,
     ) -> List[ObservationData]:
         """Apply the terminal transform, make predictions on the test points,
         and reverse terminal transform on the results.
diff --git a/ax/modelbridge/cross_validation.py b/ax/modelbridge/cross_validation.py
index a07ec1475f7..b5f9bcc820e 100644
--- a/ax/modelbridge/cross_validation.py
+++ b/ax/modelbridge/cross_validation.py
@@ -78,6 +78,7 @@ def cross_validate(
     # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
     test_selector: Optional[Callable] = None,
     untransform: bool = True,
+    use_posterior_predictive: bool = False,
 ) -> List[CVResult]:
     """Cross validation for model predictions.
 
@@ -112,6 +113,12 @@ def cross_validate(
             of the original data in regions where outliers have been removed,
             we have found it to better reflect how good the model used for
             candidate generation actually is.
+        use_posterior_predictive: A boolean indicating if the predictions
+            should be from the posterior predictive (i.e. including
+            observation noise). Note: we should reconsider how we compute
+            cross-validation and model fit metrics where there is non-
+            Gaussian noise.
+
     Returns:
         A CVResult for each observation in the training data.
     """
@@ -162,7 +169,9 @@ def cross_validate(
     # Make the prediction
     if untransform:
         cv_test_predictions = model.cross_validate(
-            cv_training_data=cv_training_data, cv_test_points=cv_test_points
+            cv_training_data=cv_training_data,
+            cv_test_points=cv_test_points,
+            use_posterior_predictive=use_posterior_predictive,
         )
     else:
         # Get test predictions in transformed space
@@ -186,6 +195,7 @@ def cross_validate(
             search_space=search_space,
             cv_training_data=cv_training_data,
             cv_test_points=cv_test_points,
+            use_posterior_predictive=use_posterior_predictive,
         )
         # Get test observations in transformed space
         cv_test_data = deepcopy(cv_test_data)
@@ -197,7 +207,9 @@ def cross_validate(
     return result
 
 
-def cross_validate_by_trial(model: ModelBridge, trial: int = -1) -> List[CVResult]:
+def cross_validate_by_trial(
+    model: ModelBridge, trial: int = -1, use_posterior_predictive: bool = False
+) -> List[CVResult]:
     """Cross validation for model predictions on a particular trial.
 
     Uses all of the data up until the specified trial to predict each of the
@@ -206,6 +218,9 @@ def cross_validate_by_trial(model: ModelBridge, trial: int = -1) -> List[CVResul
     Args:
         model: Fitted model (ModelBridge) to cross validate.
         trial: Trial for which predictions are evaluated.
+        use_posterior_predictive: A boolean indicating if the predictions
+            should be from the posterior predictive (i.e. including
+            observation noise).
 
     Returns:
         A CVResult for each observation in the training data.
@@ -241,7 +256,9 @@ def cross_validate_by_trial(model: ModelBridge, trial: int = -1) -> List[CVResul
             cv_test_data.append(obs)
     # Make the prediction
     cv_test_predictions = model.cross_validate(
-        cv_training_data=cv_training_data, cv_test_points=cv_test_points
+        cv_training_data=cv_training_data,
+        cv_test_points=cv_test_points,
+        use_posterior_predictive=use_posterior_predictive,
     )
     # Form CVResult objects
     result = [
diff --git a/ax/modelbridge/discrete.py b/ax/modelbridge/discrete.py
index 3e4436d8e66..e3015a87f6e 100644
--- a/ax/modelbridge/discrete.py
+++ b/ax/modelbridge/discrete.py
@@ -191,6 +191,7 @@ def _cross_validate(
         search_space: SearchSpace,
         cv_training_data: List[Observation],
         cv_test_points: List[ObservationFeatures],
+        use_posterior_predictive: bool = False,
     ) -> List[ObservationData]:
         """Make predictions at cv_test_points using only the data in obs_feats
         and obs_data.
@@ -208,7 +209,11 @@ def _cross_validate(
         ]
         # Use the model to do the cross validation
         f_test, cov_test = self.model.cross_validate(
-            Xs_train=Xs_train, Ys_train=Ys_train, Yvars_train=Yvars_train, X_test=X_test
+            Xs_train=Xs_train,
+            Ys_train=Ys_train,
+            Yvars_train=Yvars_train,
+            X_test=X_test,
+            use_posterior_predictive=use_posterior_predictive,
         )
         # Convert array back to ObservationData
         return array_to_observation_data(f=f_test, cov=cov_test, outcomes=self.outcomes)
diff --git a/ax/modelbridge/map_torch.py b/ax/modelbridge/map_torch.py
index 61cd56b9834..da51c5c18e0 100644
--- a/ax/modelbridge/map_torch.py
+++ b/ax/modelbridge/map_torch.py
@@ -279,6 +279,7 @@ def _cross_validate(
         cv_training_data: List[Observation],
         cv_test_points: List[ObservationFeatures],
         parameters: Optional[List[str]] = None,
+        use_posterior_predictive: bool = False,
         **kwargs: Any,
     ) -> List[ObservationData]:
         """Make predictions at cv_test_points using only the data in obs_feats
@@ -294,6 +295,7 @@ def _cross_validate(
             cv_training_data=cv_training_data,
             cv_test_points=cv_test_points,
             parameters=parameters,  # we pass the map_keys too by default
+            use_posterior_predictive=use_posterior_predictive,
             **kwargs,
         )
         observation_features, observation_data = separate_observations(cv_training_data)
diff --git a/ax/modelbridge/random.py b/ax/modelbridge/random.py
index ca1dc4b7215..3d617047814 100644
--- a/ax/modelbridge/random.py
+++ b/ax/modelbridge/random.py
@@ -100,6 +100,7 @@ def _cross_validate(
         search_space: SearchSpace,
         cv_training_data: List[Observation],
         cv_test_points: List[ObservationFeatures],
+        use_posterior_predictive: bool = False,
     ) -> List[ObservationData]:
         raise NotImplementedError
 
diff --git a/ax/modelbridge/tests/test_base_modelbridge.py b/ax/modelbridge/tests/test_base_modelbridge.py
index fe7b50eb0de..7f858dd0581 100644
--- a/ax/modelbridge/tests/test_base_modelbridge.py
+++ b/ax/modelbridge/tests/test_base_modelbridge.py
@@ -267,9 +267,24 @@ def warn_and_return_mock_obs(
             search_space=SearchSpace([FixedParameter("x", ParameterType.FLOAT, 8.0)]),
             cv_training_data=[get_observation2trans()],
             cv_test_points=[get_observation1().features],  # untransformed after
+            use_posterior_predictive=False,
         )
         self.assertTrue(cv_predictions == [get_observation1().data])
 
+        # Test use_posterior_predictive in CV
+        modelbridge.cross_validate(
+            cv_training_data=cv_training_data,
+            cv_test_points=cv_test_points,
+            use_posterior_predictive=True,
+        )
+
+        modelbridge._cross_validate.assert_called_with(
+            search_space=SearchSpace([FixedParameter("x", ParameterType.FLOAT, 8.0)]),
+            cv_training_data=[get_observation2trans()],
+            cv_test_points=[get_observation1().features],  # untransformed after
+            use_posterior_predictive=True,
+        )
+
         # Test stored training data
         obs = modelbridge.get_training_data()
         self.assertTrue(obs == [get_observation1(), get_observation2()])
diff --git a/ax/modelbridge/tests/test_cross_validation.py b/ax/modelbridge/tests/test_cross_validation.py
index 2d0ef3715fc..1484d48da20 100644
--- a/ax/modelbridge/tests/test_cross_validation.py
+++ b/ax/modelbridge/tests/test_cross_validation.py
@@ -201,7 +201,9 @@ def test_CrossValidate(self) -> None:
         # Test ModelBridge._cross_validate was called correctly.
         z = ma._cross_validate.mock_calls
         self.assertEqual(len(z), 3)
-        ma._cross_validate.assert_called_with(**self.transformed_cv_input_dict)
+        ma._cross_validate.assert_called_with(
+            **self.transformed_cv_input_dict, use_posterior_predictive=False
+        )
 
         # Test selector
 
@@ -219,6 +221,21 @@ def test_selector(obs: Observation) -> bool:
         )
         self.assertTrue(np.array_equal(sorted(all_test), np.array([2.0, 2.0, 3.0])))
 
+        # test observation noise
+        for untransform in (True, False):
+            result = cross_validate(
+                model=ma,
+                folds=-1,
+                use_posterior_predictive=True,
+                untransform=untransform,
+            )
+            if untransform:
+                mock_cv = ma.cross_validate
+            else:
+                mock_cv = ma._cross_validate
+            call_kwargs = mock_cv.mock_calls[-1].kwargs
+            self.assertTrue(call_kwargs["use_posterior_predictive"])
+
     def test_CrossValidateByTrial(self) -> None:
         # With only 1 trial
         ma = mock.MagicMock()
@@ -261,6 +278,15 @@ def test_CrossValidateByTrial(self) -> None:
         self.assertEqual(len(result), 1)
         self.assertEqual(result[0].observed.features.trial_index, 2)
 
+        mock_cv = ma.cross_validate
+        call_kwargs = mock_cv.mock_calls[-1].kwargs
+        self.assertFalse(call_kwargs["use_posterior_predictive"])
+
+        # test observation noise
+        result = cross_validate_by_trial(model=ma, use_posterior_predictive=True)
+        call_kwargs = mock_cv.mock_calls[-1].kwargs
+        self.assertTrue(call_kwargs["use_posterior_predictive"])
+
     def test_cross_validate_gives_a_useful_error_for_model_with_no_data(self) -> None:
         exp = get_branin_experiment()
         sobol = Models.SOBOL(experiment=exp, search_space=exp.search_space)
diff --git a/ax/modelbridge/torch.py b/ax/modelbridge/torch.py
index f72d75464f6..ccc7acb9f9c 100644
--- a/ax/modelbridge/torch.py
+++ b/ax/modelbridge/torch.py
@@ -428,6 +428,7 @@ def _cross_validate(
         cv_training_data: List[Observation],
         cv_test_points: List[ObservationFeatures],
         parameters: Optional[List[str]] = None,
+        use_posterior_predictive: bool = False,
         **kwargs: Any,
     ) -> List[ObservationData]:
         """Make predictions at cv_test_points using only the data in obs_feats
@@ -453,6 +454,7 @@ def _cross_validate(
             datasets=datasets,
             X_test=torch.as_tensor(X_test, dtype=self.dtype, device=self.device),
             search_space_digest=search_space_digest,
+            use_posterior_predictive=use_posterior_predictive,
             **kwargs,
         )
         # Convert array back to ObservationData
diff --git a/ax/models/discrete_base.py b/ax/models/discrete_base.py
index fd007e4972d..6cda9d46b81 100644
--- a/ax/models/discrete_base.py
+++ b/ax/models/discrete_base.py
@@ -102,6 +102,7 @@ def cross_validate(
         Ys_train: List[List[float]],
         Yvars_train: List[List[float]],
         X_test: List[TParamValueList],
+        use_posterior_predictive: bool = False,
     ) -> Tuple[np.ndarray, np.ndarray]:
         """Do cross validation with the given training and test sets.
 
@@ -116,6 +117,9 @@ def cross_validate(
                 each outcome.
             Yvars_train: The variances of each entry in Ys, same shape.
            X_test: List of the j parameterizations at which to make predictions.
+            use_posterior_predictive: A boolean indicating if the predictions
+                should be from the posterior predictive (i.e. including
+                observation noise).
 
         Returns:
             2-element tuple containing
diff --git a/ax/models/torch/botorch.py b/ax/models/torch/botorch.py
index 5fc9113f488..11329ab8b40 100644
--- a/ax/models/torch/botorch.py
+++ b/ax/models/torch/botorch.py
@@ -64,7 +64,7 @@
     ],
     Model,
 ]
-TModelPredictor = Callable[[Model, Tensor], Tuple[Tensor, Tensor]]
+TModelPredictor = Callable[[Model, Tensor, bool], Tuple[Tensor, Tensor]]
 
 
 # pyre-fixme[33]: Aliased annotation cannot contain `Any`.
@@ -466,6 +466,7 @@ def cross_validate(  # pyre-ignore [14]: `search_space_digest` arg not needed he
         self,
         datasets: List[SupervisedDataset],
         X_test: Tensor,
+        use_posterior_predictive: bool = False,
         **kwargs: Any,
     ) -> Tuple[Tensor, Tensor]:
         if self._model is None:
@@ -488,7 +489,10 @@ def cross_validate(  # pyre-ignore [14]: `search_space_digest` arg not needed he
             use_loocv_pseudo_likelihood=self.use_loocv_pseudo_likelihood,
             **self._kwargs,
         )
-        return self.model_predictor(model=model, X=X_test)  # pyre-ignore: [28]
+        # pyre-ignore: [28]
+        return self.model_predictor(
+            model=model, X=X_test, use_posterior_predictive=use_posterior_predictive
+        )
 
     def feature_importances(self) -> np.ndarray:
         return get_feature_importances_from_botorch_model(model=self._model)
diff --git a/ax/models/torch/botorch_modular/model.py b/ax/models/torch/botorch_modular/model.py
index 2021c513363..992e2ff1964 100644
--- a/ax/models/torch/botorch_modular/model.py
+++ b/ax/models/torch/botorch_modular/model.py
@@ -396,10 +396,15 @@ def predict(self, X: Tensor) -> Tuple[Tensor, Tensor]:
         return f, cov
 
     def predict_from_surrogate(
-        self, surrogate_label: str, X: Tensor
+        self,
+        surrogate_label: str,
+        X: Tensor,
+        use_posterior_predictive: bool = False,
     ) -> Tuple[Tensor, Tensor]:
         """Predict from the Surrogate with the given label."""
-        return self.surrogates[surrogate_label].predict(X=X)
+        return self.surrogates[surrogate_label].predict(
+            X=X, use_posterior_predictive=use_posterior_predictive
+        )
 
     @copy_doc(TorchModel.gen)
     def gen(
@@ -504,6 +509,7 @@ def cross_validate(
         datasets: Sequence[SupervisedDataset],
         X_test: Tensor,
         search_space_digest: SearchSpaceDigest,
+        use_posterior_predictive: bool = False,
         **additional_model_inputs: Any,
     ) -> Tuple[Tensor, Tensor]:
         # Will fail if metric_names exist across multiple models
@@ -561,7 +567,9 @@ def cross_validate(
                 **additional_model_inputs,
             )
             X_test_prediction = self.predict_from_surrogate(
-                surrogate_label=surrogate_label, X=X_test
+                surrogate_label=surrogate_label,
+                X=X_test,
+                use_posterior_predictive=use_posterior_predictive,
             )
         finally:
             # Reset the surrogates back to this model's surrogate, make
diff --git a/ax/models/torch/botorch_modular/surrogate.py b/ax/models/torch/botorch_modular/surrogate.py
index 83c7608f04a..568e99b2c11 100644
--- a/ax/models/torch/botorch_modular/surrogate.py
+++ b/ax/models/torch/botorch_modular/surrogate.py
@@ -594,17 +594,24 @@ def _discard_cached_model_and_data_if_search_space_digest_changed(
         self._last_datasets = {}
         self._last_search_space_digest = search_space_digest
 
-    def predict(self, X: Tensor) -> Tuple[Tensor, Tensor]:
+    def predict(
+        self, X: Tensor, use_posterior_predictive: bool = False
+    ) -> Tuple[Tensor, Tensor]:
         """Predicts outcomes given an input tensor.
 
         Args:
             X: A ``n x d`` tensor of input parameters.
+            use_posterior_predictive: A boolean indicating if the predictions
+                should be from the posterior predictive (i.e. including
+                observation noise).
 
         Returns:
             Tensor: The predicted posterior mean as an ``n x o``-dim tensor.
             Tensor: The predicted posterior covariance as a ``n x o x o``-dim
                 tensor.
         """
-        return predict_from_model(model=self.model, X=X)
+        return predict_from_model(
+            model=self.model, X=X, use_posterior_predictive=use_posterior_predictive
+        )
 
     def best_in_sample_point(
         self,
diff --git a/ax/models/torch/randomforest.py b/ax/models/torch/randomforest.py
index 8b584712962..dd4e936c0fb 100644
--- a/ax/models/torch/randomforest.py
+++ b/ax/models/torch/randomforest.py
@@ -73,6 +73,7 @@ def cross_validate(  # pyre-ignore [14]: not using metric_names or ssd
         self,
         datasets: List[SupervisedDataset],
         X_test: Tensor,
+        use_posterior_predictive: bool = False,
     ) -> Tuple[Tensor, Tensor]:
         Xs, Ys, Yvars = _datasets_to_legacy_inputs(datasets=datasets)
         cv_models: List[RandomForestRegressor] = []
diff --git a/ax/models/torch/tests/test_model.py b/ax/models/torch/tests/test_model.py
index 74eee241416..3496c975c1f 100644
--- a/ax/models/torch/tests/test_model.py
+++ b/ax/models/torch/tests/test_model.py
@@ -9,6 +9,7 @@
 import dataclasses
 from collections import OrderedDict
 from copy import deepcopy
+from itertools import product
 from typing import Dict, Type
 from unittest import mock
 from unittest.mock import Mock
@@ -397,11 +398,9 @@ def test_cross_validate(self, mock_fit: Mock) -> None:
         old_surrogate._model = mock.MagicMock()
         old_surrogate._model.state_dict.return_value = OrderedDict({"key": "val"})
 
-        for refit_on_cv, warm_start_refit in [
-            (True, True),
-            (True, False),
-            (False, True),
-        ]:
+        for refit_on_cv, warm_start_refit, use_posterior_predictive in product(
+            (True, False), (True, False), (True, False)
+        ):
             self.model.refit_on_cv = refit_on_cv
             self.model.warm_start_refit = warm_start_refit
             with mock.patch(
@@ -412,6 +411,7 @@ def test_cross_validate(self, mock_fit: Mock) -> None:
                     datasets=self.block_design_training_data,
                     X_test=self.X_test,
                     search_space_digest=self.mf_search_space_digest,
+                    use_posterior_predictive=use_posterior_predictive,
                 )
             # Check that `predict` is called on the cloned surrogate, not
             # on the original one.
@@ -425,6 +425,10 @@ def test_cross_validate(self, mock_fit: Mock) -> None:
                     self.X_test,
                 ),
             )
+            self.assertIs(
+                mock_predict.call_args_list[-1][1]["use_posterior_predictive"],
+                use_posterior_predictive,
+            )
             # Check that surrogate is reset back to `old_surrogate` at the
             # end of cross-validation.
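
Illustrative sketch, not part of the patch. The hunks above route use_posterior_predictive from Surrogate.predict into predict_from_model, and the ax/models/torch/utils.py hunk below forwards it to BoTorch as observation_noise. The snippet shows that effect directly on a plain BoTorch model (the toy data and model choice here are assumptions for illustration): with observation_noise=True the posterior also includes the likelihood's noise, so its variance is never smaller than that of the noiseless latent posterior.

# Illustrative sketch only; not part of this diff.
import torch
from botorch.fit import fit_gpytorch_mll
from botorch.models import SingleTaskGP
from gpytorch.mlls import ExactMarginalLogLikelihood

# Toy training data (assumed for illustration).
train_X = torch.rand(10, 2, dtype=torch.double)
train_Y = train_X.sin().sum(dim=-1, keepdim=True)
gp = SingleTaskGP(train_X, train_Y)
fit_gpytorch_mll(ExactMarginalLogLikelihood(gp.likelihood, gp))

test_X = torch.rand(4, 2, dtype=torch.double)
with torch.no_grad():
    latent = gp.posterior(test_X)  # posterior over the latent function f
    predictive = gp.posterior(test_X, observation_noise=True)  # f plus observation noise
# The posterior predictive variance is at least as large as the latent variance.
assert torch.all(predictive.variance >= latent.variance)
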
diff --git a/ax/models/torch/tests/test_surrogate.py b/ax/models/torch/tests/test_surrogate.py
index 98fd4c73788..5689c550380 100644
--- a/ax/models/torch/tests/test_surrogate.py
+++ b/ax/models/torch/tests/test_surrogate.py
@@ -9,6 +9,7 @@
 import dataclasses
 import math
 from collections import OrderedDict
+from itertools import product
 from typing import Any, Dict, Tuple, Type
 from unittest.mock import MagicMock, Mock, patch
 
@@ -520,14 +521,22 @@ def test_construct_custom_model(self) -> None:
     @fast_botorch_optimize
     @patch(f"{SURROGATE_PATH}.predict_from_model")
     def test_predict(self, mock_predict: Mock) -> None:
-        for botorch_model_class in [SaasFullyBayesianSingleTaskGP, SingleTaskGP]:
+        for botorch_model_class, use_posterior_predictive in product(
+            (SaasFullyBayesianSingleTaskGP, SingleTaskGP), (True, False)
+        ):
             surrogate, _ = self._get_surrogate(botorch_model_class=botorch_model_class)
             surrogate.fit(
                 datasets=self.training_data,
                 search_space_digest=self.search_space_digest,
             )
-            surrogate.predict(X=self.Xs[0])
-            mock_predict.assert_called_with(model=surrogate.model, X=self.Xs[0])
+            surrogate.predict(
+                X=self.Xs[0], use_posterior_predictive=use_posterior_predictive
+            )
+            mock_predict.assert_called_with(
+                model=surrogate.model,
+                X=self.Xs[0],
+                use_posterior_predictive=use_posterior_predictive,
+            )
 
     @fast_botorch_optimize
     def test_best_in_sample_point(self) -> None:
diff --git a/ax/models/torch/utils.py b/ax/models/torch/utils.py
index 414539aaa87..c63c3c4180d 100644
--- a/ax/models/torch/utils.py
+++ b/ax/models/torch/utils.py
@@ -624,7 +624,9 @@ def pick_best_out_of_sample_point_acqf_class(
     return cast(Type[AcquisitionFunction], acqf_class), acqf_options
 
 
-def predict_from_model(model: Model, X: Tensor) -> Tuple[Tensor, Tensor]:
+def predict_from_model(
+    model: Model, X: Tensor, use_posterior_predictive: bool = False
+) -> Tuple[Tensor, Tensor]:
     r"""Predicts outcomes given a model and input tensor.
 
     For a `GaussianMixturePosterior` we currently use a Gaussian approximation where we
@@ -634,6 +636,9 @@ def predict_from_model(model: Model, X: Tensor) -> Tuple[Tensor, Tensor]:
     Args:
         model: A botorch Model.
        X: A `n x d` tensor of input parameters.
+        use_posterior_predictive: A boolean indicating if the predictions
+            should be from the posterior predictive (i.e. including
+            observation noise).
 
     Returns:
         Tensor: The predicted posterior mean as an `n x o`-dim tensor.
@@ -641,7 +646,7 @@ def predict_from_model(model: Model, X: Tensor) -> Tuple[Tensor, Tensor]:
     """
     with torch.no_grad():
         # TODO: Allow Posterior to (optionally) return the full covariance matrix
-        posterior = model.posterior(X)
+        posterior = model.posterior(X, observation_noise=use_posterior_predictive)
         if isinstance(posterior, GaussianMixturePosterior):
             mean = posterior.mixture_mean.cpu().detach()
             var = posterior.mixture_variance.cpu().detach().clamp_min(0)
diff --git a/ax/models/torch_base.py b/ax/models/torch_base.py
index 7822a62d920..86969430129 100644
--- a/ax/models/torch_base.py
+++ b/ax/models/torch_base.py
@@ -200,6 +200,7 @@ def cross_validate(
         datasets: List[SupervisedDataset],
         X_test: Tensor,
         search_space_digest: SearchSpaceDigest,
+        use_posterior_predictive: bool = False,
     ) -> Tuple[Tensor, Tensor]:
         """Do cross validation with the given training and test sets.
 
@@ -212,6 +213,9 @@ def cross_validate(
             X_test: (j x d) tensor of the j points at which to make predictions.
             search_space_digest: A SearchSpaceDigest object containing metadata on
                 the features in X.
+            use_posterior_predictive: A boolean indicating if the predictions
+                should be from the posterior predictive (i.e. including
+                observation noise).
 
         Returns:
             2-element tuple containing
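
Illustrative sketch, not part of the patch: a rough end-to-end view of how the new argument is exercised from the ModelBridge layer. The experiment setup below is an assumption for illustration, reusing the get_branin_experiment stub and the Models.SOBOL(...) pattern that appear in the tests above; only the use_posterior_predictive keyword itself comes from this diff.

# Illustrative sketch only; not part of this diff.
from ax.modelbridge.cross_validation import cross_validate, cross_validate_by_trial
from ax.modelbridge.registry import Models
from ax.utils.testing.core_stubs import get_branin_experiment

# Build a small Branin experiment with a few completed Sobol trials (assumed setup).
exp = get_branin_experiment()
sobol = Models.SOBOL(experiment=exp, search_space=exp.search_space)
for _ in range(5):
    exp.new_trial(generator_run=sobol.gen(n=1)).run().mark_completed()

model_bridge = Models.BOTORCH_MODULAR(experiment=exp, data=exp.fetch_data())

# Default behavior: CV predictions come from the (noiseless) posterior over f.
cv_results = cross_validate(model=model_bridge)

# With the new flag, predictions come from the posterior predictive, i.e. they
# include observation noise; the same flag is available for per-trial CV.
cv_results_noisy = cross_validate(model=model_bridge, use_posterior_predictive=True)
cv_by_trial_noisy = cross_validate_by_trial(
    model=model_bridge, use_posterior_predictive=True
)
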