Merge pull request #145 from aleximmer/speedup_llla
Add fast computation of functional_variance for DiagLLLaplace and KronLLLaplace
wiseodd authored Jun 30, 2024
2 parents 392fec0 + 1030ce5 commit e3ca2c6
Showing 5 changed files with 208 additions and 7 deletions.
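For context, a minimal usage sketch of the new `diagonal_output` flag in the GLM predictive; the toy model, data, and import below mirror the test fixtures in this commit and are illustrative assumptions, not part of the change itself:

import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
from laplace import DiagLLLaplace  # assumed public export, as used in the tests below

# Toy regression setup mirroring the test fixtures.
torch.manual_seed(0)
model = nn.Sequential(nn.Linear(3, 20), nn.Linear(20, 2))
model.output_size = 2
X, y = torch.randn(64, 3), torch.randn(64, 2)
loader = DataLoader(TensorDataset(X, y), batch_size=16)

la = DiagLLLaplace(model, likelihood="regression")
la.fit(loader)

# New in this commit: with joint=False, diagonal_output=True returns per-output
# variances of shape (batch, outputs) instead of (batch, outputs, outputs),
# taking the fast last-layer path where one is implemented.
f_mu, f_var = la(X, pred_type="glm", joint=False, diagonal_output=True)
assert f_var.shape == (64, 2)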
16 changes: 13 additions & 3 deletions laplace/baselaplace.py
@@ -854,7 +854,10 @@ def __call__(
diagonal_output : bool
whether to use a diagonalized posterior predictive on the outputs.
Only works for `pred_type='glm'` and `link_approx='mc'`.
Only works for `pred_type='glm'` when `joint=False` in regression.
In the case of last-layer Laplace with a diagonal or Kron Hessian,
setting this to `True` makes the computation much faster for a large
number of outputs.
generator : torch.Generator, optional
random number generator to control the samples (if sampling used).
@@ -898,7 +901,10 @@ def __call__

if pred_type == PredType.GLM:
f_mu, f_var = self._glm_predictive_distribution(
x, joint=joint and likelihood == Likelihood.REGRESSION
x,
joint=joint and likelihood == Likelihood.REGRESSION,
diagonal_output=diagonal_output
and self.likelihood == Likelihood.REGRESSION,
)

if likelihood == Likelihood.REGRESSION:
@@ -1015,6 +1021,7 @@ def _glm_predictive_distribution(
self,
X: torch.Tensor | MutableMapping[str, torch.Tensor | Any],
joint: bool = False,
diagonal_output: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
backend_name = self._backend_cls.__name__.lower()
if self.enable_backprop and (
@@ -1031,7 +1038,10 @@ def _glm_predictive_distribution(
f_mu = f_mu.flatten() # (batch*out)
f_var = self.functional_covariance(Js) # (batch*out, batch*out)
else:
f_var = self.functional_variance(Js)
f_var = self.functional_variance(Js) # (batch, out, out)

if diagonal_output:
f_var = torch.diagonal(f_var, dim1=-2, dim2=-1)

return (
(f_mu.detach(), f_var.detach())
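To illustrate what the added `torch.diagonal` call does (the tensor below is a made-up stand-in for the output of `functional_variance`): it reduces each per-input (out, out) covariance block to its per-output variances.

import torch

A = torch.randn(8, 3, 3)
f_var = A @ A.transpose(-2, -1)  # (batch, out, out): one covariance block per input

f_var_diag = torch.diagonal(f_var, dim1=-2, dim2=-1)  # (batch, out)
assert f_var_diag.shape == (8, 3)
assert torch.allclose(f_var_diag[:, 1], f_var[:, 1, 1])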
93 changes: 90 additions & 3 deletions laplace/lllaplace.py
@@ -198,14 +198,24 @@ def fit(
self.mean = self.mean.detach()

def _glm_predictive_distribution(
self, X: torch.Tensor | MutableMapping, joint: bool = False
self,
X: torch.Tensor | MutableMapping,
joint: bool = False,
diagonal_output: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
Js, f_mu = self.backend.last_layer_jacobians(X)

if joint:
Js, f_mu = self.backend.last_layer_jacobians(X)
f_mu = f_mu.flatten() # (batch*out)
f_var = self.functional_covariance(Js) # (batch*out, batch*out)
elif diagonal_output:
try:
f_mu, f_var = self.functional_variance_fast(X)
except NotImplementedError:
# WARN: Fallback if not implemented
Js, f_mu = self.backend.last_layer_jacobians(X)
f_var = self.functional_variance(Js).diagonal(dim1=-2, dim2=-1)
else:
Js, f_mu = self.backend.last_layer_jacobians(X)
f_var = self.functional_variance(Js)

return (
@@ -214,6 +224,24 @@ def _glm_predictive_distribution(
else (f_mu, f_var)
)

def functional_variance_fast(self, X):
"""
Should be overridden if there exists a trick to make this fast!
Parameters
----------
X: torch.Tensor of shape (batch_size, input_dim)
Returns
-------
f_var_diag: torch.Tensor of shape (batch_size, num_outputs)
Corresponding to the diagonal of the covariance matrix of the outputs
"""
Js, f_mu = self.backend.last_layer_jacobians(X)
f_cov = self.functional_variance(Js) # No trick possible for Full Laplace
f_var = torch.diagonal(f_cov, dim1=-2, dim2=-1)
return f_mu, f_var

def _nn_predictive_samples(
self,
X: torch.Tensor | MutableMapping[str, torch.Tensor | Any],
@@ -395,6 +423,46 @@ def __init__(
def _init_H(self) -> None:
self.H = Kron.init_from_model(self.model.last_layer, self._device)

def functional_variance_fast(self, X):
raise NotImplementedError

# TODO: @Alex wants to revise this implementation
f_mu, phi = self.model.forward_with_features(X)
num_classes = f_mu.shape[-1]

# Contribution from the weights
# -----------------------------
eig_U, eig_V = self.posterior_precision.eigenvalues[0]
vec_U, vec_V = self.posterior_precision.eigenvectors[0]
delta = self.posterior_precision.deltas[0].sqrt()
inv_U_eig, inv_V_eig = (
torch.pow(eig_U + delta, -1),
torch.pow(eig_V + delta, -1),
)

# Matrix form of the kron factors
U = torch.einsum("ik,k,jk->ij", vec_U, inv_U_eig, vec_U)
V = torch.einsum("ik,k,jk->ij", vec_V, inv_V_eig, vec_V)

# Using the identity of the Matrix Gaussian distribution
# phi is (batch_size, embd_dim), V is (embd_dim, embd_dim), U is (num_classes, num_classes)
# phiVphi is (batch_size,)
phiVphi = torch.einsum("bi,ij,bj->b", phi, V, phi)
f_var = torch.einsum("b,ii->bi", phiVphi, U) # (batch_size, num_classes)

if self.model.last_layer.bias is not None:
# Contribution from the biases
# ----------------------------
eig = self.posterior_precision.eigenvalues[1][0]
vec = self.posterior_precision.eigenvectors[1][0]
delta = self.posterior_precision.deltas[1].sqrt()
inv_eig = torch.pow(eig + delta, -1)

Sigma_bias = torch.einsum("ik,k,ik->i", vec, inv_eig, vec) # (num_classes)
f_var += Sigma_bias.reshape(1, num_classes)

return f_mu, f_var
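The draft below the early `raise` relies on the matrix-Gaussian identity Cov(W phi) = (phi^T V phi) * U when vec(W) has covariance V ⊗ U. A small standalone check of that identity, with illustrative sizes and a column-stacked Kronecker ordering assumed for the sketch:

import torch

k, d = 3, 5                                  # num_classes, embd_dim (illustrative)
A, B = torch.randn(k, k), torch.randn(d, d)
U, V = A @ A.T, B @ B.T                      # SPD Kronecker factors: U over outputs, V over features
phi = torch.randn(d)                         # one feature vector

Sigma = torch.kron(V, U)                          # covariance of vec(W), column-stacked
J = torch.kron(phi.unsqueeze(0), torch.eye(k))    # Jacobian of W @ phi w.r.t. vec(W)
full = J @ Sigma @ J.T                            # (k, k) output covariance
fast = (phi @ V @ phi) * U                        # matrix-Gaussian shortcut the draft builds on

assert torch.allclose(full, fast, atol=1e-4)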


class DiagLLLaplace(LLLaplace, DiagLaplace):
"""Last-layer Laplace approximation with diagonal log likelihood Hessian approximation
@@ -405,3 +473,22 @@ class DiagLLLaplace(LLLaplace, DiagLaplace):

# key to map to correct subclass of BaseLaplace, (subset of weights, Hessian structure)
_key = ("last_layer", "diag")

def functional_variance_fast(self, X):
f_mu, phi = self.model.forward_with_features(X)
k = f_mu.shape[-1] # num_classes
b, d = phi.shape # batch_size, embd_dim

# Here we exploit the fact that the posterior variance is diagonal, ordered as
# [vars_weight, vars_bias], so each diagonal entry of J Sigma J.T is simply
# phi^2 * var_weight summed over features, plus var_bias.
f_var = torch.einsum(
"bd,kd,bd->bk", phi, self.posterior_variance[: d * k].reshape(k, d), phi
)

if self.model.last_layer.bias is not None:
# Add the last num_classes variances, corresponding to the biases' variances
# (b,k) + (1,k) = (b,k)
f_var += self.posterior_variance[-k:].reshape(1, k)

return f_mu, f_var
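A library-free sanity check of the diagonal trick above; the weights-then-bias parameter ordering and the row-major (k, d) reshape are assumptions mirroring the code, and all sizes are illustrative:

import torch

b, d, k = 4, 5, 3                     # batch_size, embd_dim, num_classes
phi = torch.randn(b, d)               # last-layer features
var_w = torch.rand(k, d)              # diagonal posterior variances of the weight matrix
var_b = torch.rand(k)                 # diagonal posterior variances of the bias

# Fast path: phi_d^2 * var_w[i, d] summed over d, plus the bias variance.
f_var_fast = torch.einsum("bd,kd,bd->bk", phi, var_w, phi) + var_b

# Naive path: explicit last-layer Jacobian of f = W @ phi + bias w.r.t. (vec(W), bias),
# then the diagonal of J @ diag([vec(var_w), var_b]) @ J.T.
J = torch.zeros(b, k, k * d + k)
for i in range(k):
    J[:, i, i * d : (i + 1) * d] = phi    # d f_i / d W[i, :]
    J[:, i, k * d + i] = 1.0              # d f_i / d bias[i]
sigma = torch.cat([var_w.reshape(-1), var_b])
f_var_naive = torch.einsum("bkp,p,bkp->bk", J, sigma, J)

assert torch.allclose(f_var_fast, f_var_naive, atol=1e-6)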
2 changes: 1 addition & 1 deletion laplace/utils/utils.py
@@ -70,7 +70,7 @@ def validate(
fitting=True,
)

if type(out) == tuple:
if type(out) is tuple:
if is_offline:
output_means.append(out[0])
output_vars.append(out[1])
42 changes: 42 additions & 0 deletions tests/test_baselaplace.py
@@ -741,6 +741,48 @@ def test_backprop_nn(laplace, model, reg_loader, backend):
assert False


@pytest.mark.parametrize("laplace", [FullLaplace, KronLaplace, DiagLaplace])
def test_reg_glm_predictive_correct_behavior(laplace, model, reg_loader):
X, y = reg_loader.dataset.tensors
n_batch = X.shape[0]
n_outputs = y.shape[-1]

lap = laplace(model, "regression")
lap.fit(reg_loader)

# Joint predictive ignores diagonal_output
f_mean, f_var = lap(X, pred_type="glm", joint=True, diagonal_output=True)
assert f_var.shape == (n_batch * n_outputs, n_batch * n_outputs)

f_mean, f_var = lap(X, pred_type="glm", joint=True, diagonal_output=False)
assert f_var.shape == (n_batch * n_outputs, n_batch * n_outputs)

# diagonal_output affects non-joint
f_mean, f_var = lap(X, pred_type="glm", joint=False, diagonal_output=True)
assert f_var.shape == (n_batch, n_outputs)

f_mean, f_var = lap(X, pred_type="glm", joint=False, diagonal_output=False)
assert f_var.shape == (n_batch, n_outputs, n_outputs)


@pytest.mark.parametrize(
"likelihood,custom_loader",
[
("classification", "custom_loader_clf"),
("regression", "custom_loader_reg"),
("reward_modeling", "custom_loader_clf"),
],
)
def test_dict_data_diagEF_curvlinops_fails(
custom_model, custom_loader, likelihood, request
):
custom_loader = request.getfixturevalue(custom_loader)
lap = DiagLaplace(custom_model, likelihood=likelihood, backend=CurvlinopsEF)

with pytest.raises(ValueError):
lap.fit(custom_loader)


@pytest.mark.parametrize(
"likelihood", ["classification", "regression", "reward_modeling"]
)
62 changes: 62 additions & 0 deletions tests/test_lllaplace.py
@@ -31,6 +31,13 @@ def model():
return model


@pytest.fixture
def model_no_output_bias():
model = torch.nn.Sequential(nn.Linear(3, 20), nn.Linear(20, 2, bias=False))
setattr(model, "output_size", 2)
return model


@pytest.fixture
def model_with_reduction():
class Model(nn.Module):
@@ -187,6 +194,7 @@ def test_laplace_init_precision(laplace, model):
setattr(model, "n_params", len(parameters_to_vector(model_params)))
# float
precision = 10.6

laplace(
model, likelihood="regression", prior_precision=precision, last_layer_name="1"
)
@@ -563,6 +571,36 @@ def test_classification_predictive_samples(laplace, model, class_loader):
assert np.allclose(fsamples.sum().item(), len(f) * 100) # sum up to 1


@pytest.mark.parametrize("laplace", [FullLLLaplace, DiagLLLaplace, KronLLLaplace])
def test_functional_variance_fast(laplace, model, reg_loader):
if laplace == KronLLLaplace:
# TODO: functional_variance_fast is not implemented for KronLLLaplace yet
return

X, y = reg_loader.dataset.tensors
X.requires_grad = True

lap = laplace(model, "regression", enable_backprop=True)
lap.fit(reg_loader)
f_mu, f_var = lap.functional_variance_fast(X)

assert f_mu.shape == (X.shape[0], y.shape[-1])
assert f_var.shape == (X.shape[0], y.shape[-1])

Js, f_naive = lap.backend.last_layer_jacobians(X)

if laplace == DiagLLLaplace:
f_var_naive = torch.einsum("ncp,p,ncp->nc", Js, lap.posterior_variance, Js)
elif laplace == KronLLLaplace:
f_var_naive = lap.posterior_precision.inv_square_form(Js)
f_var_naive = torch.diagonal(f_var_naive, dim1=-2, dim2=-1)
else:  # FullLLLaplace
f_var_naive = torch.einsum("ncp,pq,ncq->nc", Js, lap.posterior_covariance, Js)

assert torch.allclose(f_mu, f_naive)
assert torch.allclose(f_var, f_var_naive)


@pytest.mark.parametrize("laplace", flavors)
def test_backprop_glm(laplace, model, reg_loader):
X, y = reg_loader.dataset.tensors
@@ -637,3 +675,27 @@ def test_backprop_nn(laplace, model, reg_loader):
assert grad_X_var.shape == X.shape
except ValueError:
assert False


@pytest.mark.parametrize("laplace", [FullLLLaplace, KronLLLaplace, DiagLLLaplace])
def test_reg_glm_predictive_correct_behavior(laplace, model, reg_loader):
X, y = reg_loader.dataset.tensors
n_batch = X.shape[0]
n_outputs = y.shape[-1]

lap = laplace(model, "regression")
lap.fit(reg_loader)

# Joint predictive ignores diagonal_output
f_mean, f_var = lap(X, pred_type="glm", joint=True, diagonal_output=True)
assert f_var.shape == (n_batch * n_outputs, n_batch * n_outputs)

f_mean, f_var = lap(X, pred_type="glm", joint=True, diagonal_output=False)
assert f_var.shape == (n_batch * n_outputs, n_batch * n_outputs)

# diagonal_output affects non-joint
f_mean, f_var = lap(X, pred_type="glm", joint=False, diagonal_output=True)
assert f_var.shape == (n_batch, n_outputs)

f_mean, f_var = lap(X, pred_type="glm", joint=False, diagonal_output=False)
assert f_var.shape == (n_batch, n_outputs, n_outputs)
