Add functional_samples method to FunctionalLaplace
wiseodd committed Sep 14, 2024
1 parent 1294dec · commit ed84592
Showing 2 changed files with 78 additions and 1 deletion.
45 changes: 44 additions & 1 deletion laplace/baselaplace.py
@@ -2468,6 +2468,48 @@ def __call__(
            x, likelihood, joint, link_approx, n_samples, diagonal_output
        )

    def functional_samples(
        self,
        x: torch.Tensor | MutableMapping[str, torch.Tensor | Any],
        pred_type: PredType | str = PredType.GLM,
        n_samples: int = 100,
        diagonal_output: bool = False,
        generator: torch.Generator | None = None,
    ) -> torch.Tensor:
        """Sample from the functional posterior on input data `x`.
        Can be used, for example, for Thompson sampling.

        Parameters
        ----------
        x : torch.Tensor or MutableMapping
            input data `(batch_size, input_shape)`
        pred_type : {'glm'}, default='glm'
            type of posterior predictive; only the linearized GLM predictive is
            supported.
        n_samples : int
            number of samples
        diagonal_output : bool
            whether to use a diagonalized GLM posterior predictive on the outputs.
            Only applies when `pred_type='glm'`.
        generator : torch.Generator, optional
            random number generator to control the samples (if sampling is used)

        Returns
        -------
        samples : torch.Tensor
            samples `(n_samples, batch_size, output_shape)`
        """
        if pred_type not in PredType.__members__.values():
            raise ValueError("Only glm supported as prediction type.")

        f_mu, f_var = self._glm_predictive_distribution(x)
        return self._glm_functional_samples(
            f_mu, f_var, n_samples, diagonal_output, generator
        )

    def predictive_samples(
        self,
        x: torch.Tensor | MutableMapping[str, torch.Tensor | Any],
@@ -2477,7 +2519,8 @@ def predictive_samples(
        generator: torch.Generator | None = None,
    ) -> torch.Tensor:
        """Sample from the posterior predictive on input data `x`.
-       Can be used, for example, for Thompson sampling.
+       I.e., the corresponding inverse-link function is applied on top of the
+       functional sample. Can be used, for example, for Thompson sampling.

        Parameters
        ----------
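Taken together, the two methods split sampling into a latent stage and a link stage: `functional_samples` draws from the posterior over latent function outputs, while `predictive_samples` additionally applies the inverse link (identity for regression, softmax for classification). A minimal usage sketch, not part of the commit — `model` and `train_loader` are assumed placeholders, and the constructor arguments mirror the test below:

# Usage sketch under assumptions: `model` and `train_loader` exist;
# import path and constructor arguments follow the diff above.
import torch
from laplace.baselaplace import FunctionalLaplace

la = FunctionalLaplace(model, "regression", n_subset=5)
la.fit(train_loader)

X = next(iter(train_loader))[0]
gen = torch.Generator()

# Latent-function draws: no inverse-link function is applied.
f_samples = la.functional_samples(
    X, pred_type="glm", n_samples=100, generator=gen.manual_seed(123)
)

# Posterior-predictive draws: the inverse link is applied on top
# of the functional samples (identity here, since likelihood="regression").
y_samples = la.predictive_samples(
    X, pred_type="glm", n_samples=100, generator=gen.manual_seed(123)
)

assert f_samples.shape == y_samples.shape  # (n_samples, batch_size, output_shape)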
34 changes: 34 additions & 0 deletions tests/test_functional_laplace_unit.py
@@ -286,3 +286,37 @@ def mock_jacobians(self, x):
        expected_block_diagonal_kernel,
        block_diag_kernel.to(expected_block_diagonal_kernel.dtype),
    )


def test_functional_samples(model, reg_loader):
    lap = FunctionalLaplace(model, "regression", n_subset=5)
    lap.fit(reg_loader)
    X, y = reg_loader.dataset.tensors
    f = model(X)

    generator = torch.Generator()

    fsamples_reg_glm = lap.functional_samples(
        X, pred_type="glm", n_samples=100, generator=generator.manual_seed(123)
    )
    assert fsamples_reg_glm.shape == torch.Size([100, f.shape[0], f.shape[1]])

    fsamples_reg_nn = lap.functional_samples(
        X, pred_type="nn", n_samples=100, generator=generator.manual_seed(123)
    )
    assert fsamples_reg_nn.shape == torch.Size([100, f.shape[0], f.shape[1]])

    # The samples should not be affected by the likelihood
    lap.likelihood = "classification"

    fsamples_clf_glm = lap.functional_samples(
        X, pred_type="glm", n_samples=100, generator=generator.manual_seed(123)
    )
    assert fsamples_clf_glm.shape == torch.Size([100, f.shape[0], f.shape[1]])
    assert torch.allclose(fsamples_clf_glm, fsamples_reg_glm)

    fsamples_clf_nn = lap.functional_samples(
        X, pred_type="nn", n_samples=100, generator=generator.manual_seed(123)
    )
    assert fsamples_clf_nn.shape == torch.Size([100, f.shape[0], f.shape[1]])
    assert torch.allclose(fsamples_clf_nn, fsamples_reg_nn)
