diff --git a/laplace/baselaplace.py b/laplace/baselaplace.py
index 25fdf12..47f8a9f 100644
--- a/laplace/baselaplace.py
+++ b/laplace/baselaplace.py
@@ -2468,6 +2468,48 @@ def __call__(
             x, likelihood, joint, link_approx, n_samples, diagonal_output
         )
 
+    def functional_samples(
+        self,
+        x: torch.Tensor | MutableMapping[str, torch.Tensor | Any],
+        pred_type: PredType | str = PredType.GLM,
+        n_samples: int = 100,
+        diagonal_output: bool = False,
+        generator: torch.Generator | None = None,
+    ) -> torch.Tensor:
+        """Sample from the functional posterior on input data `x`.
+        Can be used, for example, for Thompson sampling.
+
+        Parameters
+        ----------
+        x : torch.Tensor or MutableMapping
+            input data `(batch_size, input_shape)`
+
+        pred_type : {'glm', 'nn'}, default='glm'
+            type of posterior predictive; both options use the linearized GLM predictive.
+
+        n_samples : int
+            number of samples
+
+        diagonal_output : bool
+            whether to use a diagonalized posterior predictive on the outputs,
+            i.e., to sample each output dimension independently.
+
+        generator : torch.Generator, optional
+            random number generator to control the samples
+
+        Returns
+        -------
+        samples : torch.Tensor
+            samples `(n_samples, batch_size, output_shape)`
+        """
+        if pred_type not in PredType.__members__.values():
+            raise ValueError("Invalid pred_type; must be a valid PredType.")
+
+        f_mu, f_var = self._glm_predictive_distribution(x)
+        return self._glm_functional_samples(
+            f_mu, f_var, n_samples, diagonal_output, generator
+        )
+
     def predictive_samples(
         self,
         x: torch.Tensor | MutableMapping[str, torch.Tensor | Any],
@@ -2477,7 +2519,8 @@ def predictive_samples(
         generator: torch.Generator | None = None,
     ) -> torch.Tensor:
         """Sample from the posterior predictive on input data `x`.
-        Can be used, for example, for Thompson sampling.
+        I.e., the likelihood's inverse-link function is applied on top of the
+        functional samples. Can be used, for example, for Thompson sampling.
 
         Parameters
         ----------
diff --git a/tests/test_functional_laplace_unit.py b/tests/test_functional_laplace_unit.py
index 296aebb..c8f3586 100644
--- a/tests/test_functional_laplace_unit.py
+++ b/tests/test_functional_laplace_unit.py
@@ -286,3 +286,37 @@ def mock_jacobians(self, x):
         expected_block_diagonal_kernel,
         block_diag_kernel.to(expected_block_diagonal_kernel.dtype),
     )
+
+
+def test_functional_samples(model, reg_loader):
+    lap = FunctionalLaplace(model, "regression", n_subset=5)
+    lap.fit(reg_loader)
+    X, y = reg_loader.dataset.tensors
+    f = model(X)
+
+    generator = torch.Generator()
+
+    fsamples_reg_glm = lap.functional_samples(
+        X, pred_type="glm", n_samples=100, generator=generator.manual_seed(123)
+    )
+    assert fsamples_reg_glm.shape == torch.Size([100, f.shape[0], f.shape[1]])
+
+    fsamples_reg_nn = lap.functional_samples(
+        X, pred_type="nn", n_samples=100, generator=generator.manual_seed(123)
+    )
+    assert fsamples_reg_nn.shape == torch.Size([100, f.shape[0], f.shape[1]])
+
+    # The samples should not be affected by the likelihood
+    lap.likelihood = "classification"
+
+    fsamples_clf_glm = lap.functional_samples(
+        X, pred_type="glm", n_samples=100, generator=generator.manual_seed(123)
+    )
+    assert fsamples_clf_glm.shape == torch.Size([100, f.shape[0], f.shape[1]])
+    assert torch.allclose(fsamples_clf_glm, fsamples_reg_glm)
+
+    fsamples_clf_nn = lap.functional_samples(
+        X, pred_type="nn", n_samples=100, generator=generator.manual_seed(123)
+    )
+    assert fsamples_clf_nn.shape == torch.Size([100, f.shape[0], f.shape[1]])
+    assert torch.allclose(fsamples_clf_nn, fsamples_reg_nn)
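For reference, a minimal usage sketch of the new API. It is not part of the diff: the toy model, dataset, and sizes below are made up for illustration, and it assumes `FunctionalLaplace` is importable from the package's top level; only `functional_samples`/`predictive_samples` and their signatures come from the change above.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

from laplace import FunctionalLaplace  # assumes the package's public export

torch.manual_seed(0)

# Toy regression data and model, hypothetical stand-ins for the test fixtures.
X = torch.randn(32, 3)
y = torch.randn(32, 2)
loader = DataLoader(TensorDataset(X, y), batch_size=8)
model = torch.nn.Sequential(
    torch.nn.Linear(3, 20), torch.nn.Tanh(), torch.nn.Linear(20, 2)
)

la = FunctionalLaplace(model, "regression", n_subset=16)
la.fit(loader)

# Functional samples: draws from the posterior over f(x), before any
# inverse-link is applied. Seeding the generator makes runs reproducible.
gen = torch.Generator().manual_seed(123)
f_samples = la.functional_samples(X, pred_type="glm", n_samples=10, generator=gen)
print(f_samples.shape)  # (n_samples, batch_size, output_shape) -> [10, 32, 2]

# Predictive samples apply the likelihood's inverse-link on top; for a
# regression likelihood the link is the identity, so with the same seed the
# two should coincide.
gen = torch.Generator().manual_seed(123)
p_samples = la.predictive_samples(X, pred_type="glm", n_samples=10, generator=gen)
print(torch.allclose(f_samples, p_samples))
```

The seeded `torch.Generator` is also what the new test leans on: resetting the seed before each call makes the `'glm'` and `'nn'` draws directly comparable with `torch.allclose`.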