Add fast computation of functional_variance for DiagLLLaplace and KronLLLaplace #145

Merged · 10 commits · Jun 30, 2024
16 changes: 13 additions & 3 deletions laplace/baselaplace.py
@@ -854,7 +854,10 @@ def __call__(

diagonal_output : bool
whether to use a diagonalized posterior predictive on the outputs.
Only works for `pred_type='glm'` and `link_approx='mc'`.
Only works for `pred_type='glm'` when `joint=False` in regression.
In the case of last-layer Laplace with a diagonal or Kron Hessian,
setting this to `True` makes computation much faster for a large
number of outputs.

generator : torch.Generator, optional
random number generator to control the samples (if sampling used).
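
A minimal usage sketch of the new flag, assuming the standard `Laplace` factory interface; the model, data, and shapes below are illustrative only:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from laplace import Laplace

# Illustrative regression model and data; any last-layer-compatible model works.
model = torch.nn.Sequential(torch.nn.Linear(3, 20), torch.nn.Tanh(), torch.nn.Linear(20, 2))
X, y = torch.randn(64, 3), torch.randn(64, 2)
loader = DataLoader(TensorDataset(X, y), batch_size=16)

la = Laplace(model, "regression", subset_of_weights="last_layer", hessian_structure="diag")
la.fit(loader)

# With diagonal_output=True the GLM predictive returns per-output variances of
# shape (batch, out) instead of full covariances of shape (batch, out, out).
f_mu, f_var = la(X, pred_type="glm", diagonal_output=True)
```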
@@ -898,7 +901,10 @@ def __call__(

if pred_type == PredType.GLM:
f_mu, f_var = self._glm_predictive_distribution(
x, joint=joint and likelihood == Likelihood.REGRESSION
x,
joint=joint and likelihood == Likelihood.REGRESSION,
diagonal_output=diagonal_output
and self.likelihood == Likelihood.REGRESSION,
)

if likelihood == Likelihood.REGRESSION:
@@ -1015,6 +1021,7 @@ def _glm_predictive_distribution(
self,
X: torch.Tensor | MutableMapping[str, torch.Tensor | Any],
joint: bool = False,
diagonal_output: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
backend_name = self._backend_cls.__name__.lower()
if self.enable_backprop and (
@@ -1031,7 +1038,10 @@ def _glm_predictive_distribution(
f_mu = f_mu.flatten() # (batch*out)
f_var = self.functional_covariance(Js) # (batch*out, batch*out)
else:
f_var = self.functional_variance(Js)
f_var = self.functional_variance(Js) # (batch, out, out)

if diagonal_output:
f_var = torch.diagonal(f_var, dim1=-2, dim2=-1)

return (
(f_mu.detach(), f_var.detach())
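For context, a small sketch of what the `torch.diagonal(f_var, dim1=-2, dim2=-1)` call above does to the batched covariance (shapes are illustrative):

```python
import torch

f_var = torch.randn(8, 5, 5)                       # (batch, out, out) covariance per input
f_var = f_var @ f_var.transpose(-2, -1)            # make it a valid (PSD) covariance
f_var_diag = torch.diagonal(f_var, dim1=-2, dim2=-1)
print(f_var_diag.shape)                            # torch.Size([8, 5]) -> (batch, out)
```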
93 changes: 90 additions & 3 deletions laplace/lllaplace.py
@@ -198,14 +198,24 @@ def fit(
self.mean = self.mean.detach()

def _glm_predictive_distribution(
self, X: torch.Tensor | MutableMapping, joint: bool = False
self,
X: torch.Tensor | MutableMapping,
joint: bool = False,
diagonal_output: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
Js, f_mu = self.backend.last_layer_jacobians(X)

if joint:
Js, f_mu = self.backend.last_layer_jacobians(X)
f_mu = f_mu.flatten() # (batch*out)
f_var = self.functional_covariance(Js) # (batch*out, batch*out)
elif diagonal_output:
try:
f_mu, f_var = self.functional_variance_fast(X)
except NotImplementedError:
# WARN: fall back to the naive Jacobian-based computation if no fast path is implemented
Js, f_mu = self.backend.last_layer_jacobians(X)
f_var = self.functional_variance(Js).diagonal(dim1=-2, dim2=-1)
else:
Js, f_mu = self.backend.last_layer_jacobians(X)
f_var = self.functional_variance(Js)

return (
Expand All @@ -214,6 +224,24 @@ def _glm_predictive_distribution(
else (f_mu, f_var)
)

def functional_variance_fast(self, X):
"""
Should be overridden if there exists a trick to make this fast!

Parameters
----------
X: torch.Tensor of shape (batch_size, input_dim)

Returns
-------
f_mu: torch.Tensor of shape (batch_size, num_outputs)
Predictive mean of the outputs
f_var_diag: torch.Tensor of shape (batch_size, num_outputs)
Diagonal of the covariance matrix of the outputs
"""
Js, f_mu = self.backend.last_layer_jacobians(X)
f_cov = self.functional_variance(Js) # No trick possible for Full Laplace
f_var = torch.diagonal(f_cov, dim1=-2, dim2=-1)
return f_mu, f_var

def _nn_predictive_samples(
self,
X: torch.Tensor | MutableMapping[str, torch.Tensor | Any],
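
The base-class fallback above computes the full output covariance from the last-layer Jacobians and only then takes its diagonal. A self-contained sketch of the quantity the fast overrides have to reproduce, with made-up shapes and a made-up dense posterior covariance:

```python
import torch

B, K, P = 4, 3, 10                                  # batch, outputs, last-layer params (illustrative)
Js = torch.randn(B, K, P)                           # last-layer Jacobians
A = torch.randn(P, P)
Sigma = A @ A.T + torch.eye(P)                      # dense, PSD posterior covariance

# Naive: build the (K, K) output covariance per input, then keep its diagonal.
f_cov = torch.einsum("bkp,pq,blq->bkl", Js, Sigma, Js)          # (B, K, K)
f_var_naive = torch.diagonal(f_cov, dim1=-2, dim2=-1)           # (B, K)

# Equivalent direct computation of the diagonal only.
f_var_direct = torch.einsum("bkp,pq,bkq->bk", Js, Sigma, Js)    # (B, K)

assert torch.allclose(f_var_naive, f_var_direct, atol=1e-5)
```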
@@ -395,6 +423,46 @@ def __init__(
def _init_H(self) -> None:
self.H = Kron.init_from_model(self.model.last_layer, self._device)

def functional_variance_fast(self, X):
raise NotImplementedError

# TODO: @Alex wants to revise this implementation
f_mu, phi = self.model.forward_with_features(X)
num_classes = f_mu.shape[-1]

# Contribution from the weights
# -----------------------------
eig_U, eig_V = self.posterior_precision.eigenvalues[0]
vec_U, vec_V = self.posterior_precision.eigenvectors[0]
delta = self.posterior_precision.deltas[0].sqrt()
inv_U_eig, inv_V_eig = (
torch.pow(eig_U + delta, -1),
torch.pow(eig_V + delta, -1),
)

# Matrix form of the kron factors
U = torch.einsum("ik,k,jk->ij", vec_U, inv_U_eig, vec_U)
V = torch.einsum("ik,k,jk->ij", vec_V, inv_V_eig, vec_V)

# Using the identity of the Matrix Gaussian distribution
# phi is (batch_size, embd_dim), V is (embd_dim, embd_dim), U is (num_classes, num_classes)
# phiVphi is (batch_size,)
phiVphi = torch.einsum("bi,ij,bj->b", phi, V, phi)
f_var = torch.einsum("b,ii->bi", phiVphi, U) # (batch_size, num_classes)

if self.model.last_layer.bias is not None:
# Contribution from the biases
# ----------------------------
eig = self.posterior_precision.eigenvalues[1][0]
vec = self.posterior_precision.eigenvectors[1][0]
delta = self.posterior_precision.deltas[1].sqrt()
inv_eig = torch.pow(eig + delta, -1)

Sigma_bias = torch.einsum("ik,k,ik->i", vec, inv_eig, vec) # (num_classes)
f_var += Sigma_bias.reshape(1, num_classes)

return f_mu, f_var


class DiagLLLaplace(LLLaplace, DiagLaplace):
"""Last-layer Laplace approximation with diagonal log likelihood Hessian approximation
@@ -405,3 +473,22 @@ class DiagLLLaplace(LLLaplace, DiagLaplace):

# key to map to correct subclass of BaseLaplace, (subset of weights, Hessian structure)
_key = ("last_layer", "diag")

def functional_variance_fast(self, X):
f_mu, phi = self.model.forward_with_features(X)
k = f_mu.shape[-1] # num_classes
b, d = phi.shape # batch_size, embd_dim

# We exploit the fact that J Sigma J.T is diagonal for each input when Sigma
# is diagonal: the parameter variances are [vars_weight, vars_biases], and
# each functional variance is phi^2 * var_weight (summed over features) + var_bias.
f_var = torch.einsum(
"bd,kd,bd->bk", phi, self.posterior_variance[: d * k].reshape(k, d), phi
)

if self.model.last_layer.bias is not None:
# Add the last num_classes variances, corresponding to the biases' variances
# (b,k) + (1,k) = (b,k)
f_var += self.posterior_variance[-k:].reshape(1, k)

return f_mu, f_var
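
A quick self-contained check of the diagonal trick above: with a diagonal posterior over [weights, biases], the per-output variance is phi^2 times the weight variances summed over features, plus the bias variance, which matches the naive diag(J Sigma J^T). Sizes and variances below are made up:

```python
import torch

B, D, K = 4, 5, 3                              # batch, features, outputs (illustrative)
phi = torch.rand(B, D)                         # last-layer features
var_w = torch.rand(K * D) + 0.1                # diagonal weight variances, row-major (K, D)
var_b = torch.rand(K) + 0.1                    # diagonal bias variances

# Fast path: Var(f_bk) = sum_d phi_bd^2 * var_w[k, d] + var_b[k]
f_var_fast = torch.einsum("bd,kd,bd->bk", phi, var_w.reshape(K, D), phi) + var_b

# Naive path: per-input Jacobian J_b = [kron(I_K, phi_b^T), I_K] against diag([var_w, var_b]).
post_var = torch.cat([var_w, var_b])
J = torch.stack(
    [torch.cat([torch.kron(torch.eye(K), p.unsqueeze(0)), torch.eye(K)], dim=1) for p in phi]
)                                              # (B, K, K*D + K)
f_var_naive = torch.einsum("bkp,p,bkp->bk", J, post_var, J)

assert torch.allclose(f_var_fast, f_var_naive, atol=1e-5)
```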
2 changes: 1 addition & 1 deletion laplace/utils/utils.py
@@ -70,7 +70,7 @@ def validate(
fitting=True,
)

if type(out) == tuple:
if type(out) is tuple:
if is_offline:
output_means.append(out[0])
output_vars.append(out[1])
42 changes: 42 additions & 0 deletions tests/test_baselaplace.py
@@ -741,6 +741,48 @@ def test_backprop_nn(laplace, model, reg_loader, backend):
assert False


@pytest.mark.parametrize("laplace", [FullLaplace, KronLaplace, DiagLaplace])
def test_reg_glm_predictive_correct_behavior(laplace, model, reg_loader):
X, y = reg_loader.dataset.tensors
n_batch = X.shape[0]
n_outputs = y.shape[-1]

lap = laplace(model, "regression")
lap.fit(reg_loader)

# Joint predictive ignores diagonal_output
f_mean, f_var = lap(X, pred_type="glm", joint=True, diagonal_output=True)
assert f_var.shape == (n_batch * n_outputs, n_batch * n_outputs)

f_mean, f_var = lap(X, pred_type="glm", joint=True, diagonal_output=False)
assert f_var.shape == (n_batch * n_outputs, n_batch * n_outputs)

# diagonal_output affects non-joint
f_mean, f_var = lap(X, pred_type="glm", joint=False, diagonal_output=True)
assert f_var.shape == (n_batch, n_outputs)

f_mean, f_var = lap(X, pred_type="glm", joint=False, diagonal_output=False)
assert f_var.shape == (n_batch, n_outputs, n_outputs)


@pytest.mark.parametrize(
"likelihood,custom_loader",
[
("classification", "custom_loader_clf"),
("regression", "custom_loader_reg"),
("reward_modeling", "custom_loader_clf"),
],
)
def test_dict_data_diagEF_curvlinops_fails(
custom_model, custom_loader, likelihood, request
):
custom_loader = request.getfixturevalue(custom_loader)
lap = DiagLaplace(custom_model, likelihood=likelihood, backend=CurvlinopsEF)

with pytest.raises(ValueError):
lap.fit(custom_loader)


@pytest.mark.parametrize(
"likelihood", ["classification", "regression", "reward_modeling"]
)
62 changes: 62 additions & 0 deletions tests/test_lllaplace.py
@@ -31,6 +31,13 @@ def model():
return model


@pytest.fixture
def model_no_output_bias():
model = torch.nn.Sequential(nn.Linear(3, 20), nn.Linear(20, 2, bias=False))
setattr(model, "output_size", 2)
return model


@pytest.fixture
def model_with_reduction():
class Model(nn.Module):
@@ -187,6 +194,7 @@ def test_laplace_init_precision(laplace, model):
setattr(model, "n_params", len(parameters_to_vector(model_params)))
# float
precision = 10.6

laplace(
model, likelihood="regression", prior_precision=precision, last_layer_name="1"
)
@@ -563,6 +571,36 @@ def test_classification_predictive_samples(laplace, model, class_loader):
assert np.allclose(fsamples.sum().item(), len(f) * 100) # sum up to 1


@pytest.mark.parametrize("laplace", [FullLLLaplace, DiagLLLaplace, KronLLLaplace])
def test_functional_variance_fast(laplace, model, reg_loader):
if laplace == KronLLLaplace:
# TODO: functional_variance_fast is not yet implemented for KronLLLaplace
return

X, y = reg_loader.dataset.tensors
X.requires_grad = True

lap = laplace(model, "regression", enable_backprop=True)
lap.fit(reg_loader)
f_mu, f_var = lap.functional_variance_fast(X)

assert f_mu.shape == (X.shape[0], y.shape[-1])
assert f_var.shape == (X.shape[0], y.shape[-1])

Js, f_naive = lap.backend.last_layer_jacobians(X)

if laplace == DiagLLLaplace:
f_var_naive = torch.einsum("ncp,p,ncp->nc", Js, lap.posterior_variance, Js)
elif laplace == KronLLLaplace:
f_var_naive = lap.posterior_precision.inv_square_form(Js)
f_var_naive = torch.diagonal(f_var_naive, dim1=-2, dim2=-1)
else:  # FullLLLaplace
f_var_naive = torch.einsum("ncp,pq,ncq->nc", Js, lap.posterior_covariance, Js)

assert torch.allclose(f_mu, f_naive)
assert torch.allclose(f_var, f_var_naive)


@pytest.mark.parametrize("laplace", flavors)
def test_backprop_glm(laplace, model, reg_loader):
X, y = reg_loader.dataset.tensors
@@ -637,3 +675,27 @@ def test_backprop_nn(laplace, model, reg_loader):
assert grad_X_var.shape == X.shape
except ValueError:
assert False


@pytest.mark.parametrize("laplace", [FullLLLaplace, KronLLLaplace, DiagLLLaplace])
def test_reg_glm_predictive_correct_behavior(laplace, model, reg_loader):
X, y = reg_loader.dataset.tensors
n_batch = X.shape[0]
n_outputs = y.shape[-1]

lap = laplace(model, "regression")
lap.fit(reg_loader)

# Joint predictive ignores diagonal_output
f_mean, f_var = lap(X, pred_type="glm", joint=True, diagonal_output=True)
assert f_var.shape == (n_batch * n_outputs, n_batch * n_outputs)

f_mean, f_var = lap(X, pred_type="glm", joint=True, diagonal_output=False)
assert f_var.shape == (n_batch * n_outputs, n_batch * n_outputs)

# diagonal_output affects non-joint
f_mean, f_var = lap(X, pred_type="glm", joint=False, diagonal_output=True)
assert f_var.shape == (n_batch, n_outputs)

f_mean, f_var = lap(X, pred_type="glm", joint=False, diagonal_output=False)
assert f_var.shape == (n_batch, n_outputs, n_outputs)