Merge pull request #240 from aleximmer/target-dim
Transform `y` into 2D tensor when `y.ndim == 1` and `likelihood == REGRESSION`
wiseodd authored Sep 14, 2024
2 parents 26b1d19 + 99c1ebd commit 6b0618a
Showing 8 changed files with 115 additions and 22 deletions.
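
For orientation, a minimal sketch of the target-shape contract this change enforces; the tiny MLP and the `DiagLaplace` flavour are illustrative choices, not taken from the diff:

```python
# A minimal sketch of the contract this commit enforces, assuming a tiny MLP
# and the DiagLaplace flavour purely for illustration (not taken from the diff).
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

from laplace import DiagLaplace

model = nn.Sequential(nn.Linear(3, 20), nn.Linear(20, 1))
X = torch.randn(10, 3)

# 2D targets of shape (N, 1) match the model's 2D output: fit() succeeds.
la = DiagLaplace(model, likelihood="regression")
la.fit(DataLoader(TensorDataset(X, torch.randn(10, 1)), batch_size=3))

# Flat targets of shape (N,) no longer match out.ndim: fit() raises ValueError,
# so reshape them to (N, 1) before building the loader.
y_flat = torch.randn(10)
la2 = DiagLaplace(model, likelihood="regression")
try:
    la2.fit(DataLoader(TensorDataset(X, y_flat), batch_size=3))
except ValueError:
    y_fixed = y_flat.unsqueeze(-1)  # shape (10, 1) works again
```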
52 changes: 36 additions & 16 deletions laplace/baselaplace.py
@@ -846,6 +846,13 @@ def fit(
else:
X, y = data
X, y = X.to(self._device), y.to(self._device)

if self.likelihood == Likelihood.REGRESSION and y.ndim != out.ndim:
raise ValueError(
f"The model's output has {out.ndim} dims but "
f"the target has {y.ndim} dims."
)

self.model.zero_grad()
loss_batch, H_batch = self._curv_closure(X, y, N=N)
self.loss += loss_batch
@@ -1761,12 +1768,19 @@ def fit(
if not self.enable_backprop:
self.mean = self.mean.detach()

X, _ = next(iter(train_loader))
X, y = next(iter(train_loader))
with torch.no_grad():
try:
out = self.model(X[:1].to(self._device))
except (TypeError, AttributeError):
out = self.model(X.to(self._device))

if self.likelihood == Likelihood.REGRESSION and y.ndim != out.ndim:
raise ValueError(
f"The model's output has {out.ndim} dims but "
f"the target has {y.ndim} dims."
)

self.n_outputs = out.shape[-1]
setattr(self.model, "output_size", self.n_outputs)

@@ -1930,7 +1944,7 @@ class FunctionalLaplace(BaseLaplace):
See [Improving predictions of Bayesian neural nets via local linearization (Immer et al., 2021)](https://arxiv.org/abs/2008.08400)
for more details.
Note that for `likelihood='classification'`, we approximate \( L_{NN} \\) with a diagonal matrix
Note that for `likelihood='classification'`, we approximate \\( L_{NN} \\) with a diagonal matrix
( \\( L_{NN} \\) is a block-diagonal matrix, where blocks represent Hessians of per-data-point log-likelihood w.r.t.
neural network output \\( f \\), See Appendix [A.2.1](https://arxiv.org/abs/2008.08400) for exact definition). We
resort to such an approximation because of the (possible) errors found in Laplace approximation for
@@ -2023,9 +2037,9 @@ def _check_prior_precision(prior_precision: float | torch.Tensor):

def _init_K_MM(self):
"""Allocates memory for the kernel matrix evaluated at the subset of the training
data points. If the subset is of size \(M\) and the problem has \(C\) outputs,
this is a list of C \((M,M\)) tensors for diagonal kernel and \((M x C, M x C)\)
otherwise.
data points. If the subset is of size \\(M\\) and the problem has \\(C\\) outputs,
this is a list of C \\((M,M\\)) tensors for diagonal kernel and
\\((M \\times C, M \\times C)\\) otherwise.
"""
if self.independent_outputs:
self.K_MM = [
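
As a reading aid for the `_init_K_MM` docstring above, a hedged sketch of the two allocation layouts; `M`, `C`, and the plain `torch.empty` calls are illustrative stand-ins for the class's subset size, output dimension, and actual allocation:

```python
# A hedged sketch of the two allocation layouts the docstring above describes;
# M, C, and the plain torch.empty calls are illustrative stand-ins for the
# class's subset size, output dimension, and actual allocation.
import torch

M, C = 10, 2

# independent_outputs=True: one (M, M) kernel block per output dimension.
K_MM_independent = [torch.empty(M, M) for _ in range(C)]

# independent_outputs=False: a single dense (M*C, M*C) kernel matrix.
K_MM_joint = torch.empty(M * C, M * C)
```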
@@ -2040,9 +2054,9 @@ def _init_K_MM():

def _init_Sigma_inv(self):
"""Allocates memory for the cholesky decomposition of
\[
K_{MM} + \Lambda_{MM}^{-1}.
\]
\\[
K_{MM} + \\Lambda_{MM}^{-1}.
\\]
See See [Improving predictions of Bayesian neural nets via local linearization (Immer et al., 2021)](https://arxiv.org/abs/2008.08400)
Equation 15 for more information.
"""
@@ -2115,13 +2129,13 @@ class for more details.

def _build_Sigma_inv(self):
"""Computes the cholesky decomposition of
\[
K_{MM} + \Lambda_{MM}^{-1}.
\]
\\[
K_{MM} + \\Lambda_{MM}^{-1}.
\\]
See See [Improving predictions of Bayesian neural nets via local linearization (Immer et al., 2021)](https://arxiv.org/abs/2008.08400)
Equation 15 for more information.
As the diagonal approximation is performed with \Lambda_{MM} (which is stored in self.L),
As the diagonal approximation is performed with \\(\\Lambda_{MM}\\) (which is stored in self.L),
the code is greatly simplified.
"""
if self.independent_outputs:
@@ -2231,10 +2245,16 @@ def fit(

Js_batch, f_batch = self._jacobians(X, enable_backprop=False)

if self.likelihood == Likelihood.REGRESSION and y.ndim != out.ndim:
raise ValueError(
f"The model's output has {out.ndim} dims but "
f"the target has {y.ndim} dims."
)

with torch.no_grad():
loss_batch = self.backend.factor * self.backend.lossfunc(f_batch, y)

if self.likelihood == "regression":
if self.likelihood == Likelihood.REGRESSION:
b, C = f_batch.shape
lambdas_batch = torch.unsqueeze(torch.eye(C), 0).repeat(b, 1, 1)
else:
@@ -2552,11 +2572,11 @@ def log_det_ratio(self) -> torch.Tensor:
[GP book R&W 2006](http://www.gaussianprocess.org/gpml/chapters/) with
(note that we always use diagonal approximation \\(D\\) of the Hessian of log likelihood w.r.t. \\(f\\)):
log determinant term := \\( \log | I + D^{1/2}K D^{1/2} | \\)
log determinant term := \\( \\log | I + D^{1/2}K D^{1/2} | \\)
For `regression`, we use ["standard" GP marginal likelihood](https://stats.stackexchange.com/questions/280105/log-marginal-likelihood-for-gaussian-process):
log determinant term := \\( \log | K + \\sigma_2 I | \\)
log determinant term := \\( \\log | K + \\sigma_2 I | \\)
"""
if self.likelihood == Likelihood.REGRESSION:
if self.independent_outputs:
@@ -2596,7 +2616,7 @@ def scatter(self, eps: float = 0.00001) -> torch.Tensor:
"""Compute scatter term in GP log marginal likelihood.
For `classification` we use eq. (3.44) from Chapter 3.5 from
[GP book R&W 2006](http://www.gaussianprocess.org/gpml/chapters/) with \\(\hat{f} = f \\):
[GP book R&W 2006](http://www.gaussianprocess.org/gpml/chapters/) with \\(\\hat{f} = f \\):
scatter term := \\( f K^{-1} f^{T} \\)
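
To make the `log_det_ratio` and `scatter` docstrings above concrete, a hedged sketch with random stand-ins for the kernel, the diagonal Hessian approximation, the noise, and \\(f\\):

```python
# A hedged sketch of the two log-determinant terms and the scatter term that
# the log_det_ratio and scatter docstrings above spell out; K, D, sigma_noise,
# and f are random stand-ins, not the class's attributes.
import torch

n = 8
A = torch.randn(n, n)
K = A @ A.T + n * torch.eye(n)        # stand-in positive-definite kernel matrix
D = torch.diag(torch.rand(n) + 0.1)   # diagonal Hessian approximation (classification)
sigma_noise = 0.7
f = torch.randn(n)

# classification: log | I + D^{1/2} K D^{1/2} |
log_det_clf = torch.logdet(torch.eye(n) + D.sqrt() @ K @ D.sqrt())

# regression: log | K + sigma^2 I |
log_det_reg = torch.logdet(K + sigma_noise**2 * torch.eye(n))

# scatter term: f K^{-1} f^T
scatter = f @ torch.linalg.solve(K, f)
```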
39 changes: 38 additions & 1 deletion tests/test_baselaplace.py
@@ -24,7 +24,7 @@
from tests.utils import ListDataset, dict_data_collator, jacobians_naive

torch.manual_seed(240)
torch.set_default_tensor_type(torch.DoubleTensor)
torch.set_default_dtype(torch.double)

flavors = [FullLaplace, KronLaplace, DiagLaplace]
if find_spec("asdfghjkl") is not None:
@@ -43,6 +43,16 @@ def model():
return model


@pytest.fixture
def model_1d():
model = torch.nn.Sequential(nn.Linear(3, 20), nn.Linear(20, 1))
setattr(model, "output_size", 1)
model_params = list(model.parameters())
setattr(model, "n_layers", len(model_params)) # number of parameter groups
setattr(model, "n_params", len(parameters_to_vector(model_params)))
return model


@pytest.fixture
def large_model():
model = wide_resnet50_2()
@@ -113,6 +123,22 @@ def reg_loader():
return DataLoader(TensorDataset(X, y), batch_size=3)


@pytest.fixture
def reg_loader_1d():
torch.manual_seed(9999)
X = torch.randn(10, 3)
y = torch.randn(10, 1)
return DataLoader(TensorDataset(X, y), batch_size=3)


@pytest.fixture
def reg_loader_1d_flat():
torch.manual_seed(9999)
X = torch.randn(10, 3)
y = torch.randn((10,))
return DataLoader(TensorDataset(X, y), batch_size=3)


@pytest.fixture
def custom_loader_clf():
data = []
@@ -818,3 +844,14 @@ def test_gridsearch(model, likelihood, prior_prec_type, reg_loader, class_loader

# Should not raise an error
lap.optimize_prior_precision(method="gridsearch", val_loader=dataloader, n_steps=10)


@pytest.mark.parametrize("laplace", flavors)
def test_parametric_fit_y_shape(model_1d, reg_loader_1d, reg_loader_1d_flat, laplace):
lap = laplace(model_1d, likelihood="regression")
lap.fit(reg_loader_1d) # OK

lap2 = laplace(model_1d, likelihood="regression")

with pytest.raises(ValueError):
lap2.fit(reg_loader_1d_flat)
36 changes: 36 additions & 0 deletions tests/test_functional_laplace_unit.py
@@ -14,6 +14,22 @@ def reg_loader():
return DataLoader(TensorDataset(X, y), batch_size=3)


@pytest.fixture
def reg_loader_1d():
torch.manual_seed(9999)
X = torch.randn(10, 3)
y = torch.randn(10, 1)
return DataLoader(TensorDataset(X, y), batch_size=3)


@pytest.fixture
def reg_loader_1d_flat():
torch.manual_seed(9999)
X = torch.randn(10, 3)
y = torch.randn((10,))
return DataLoader(TensorDataset(X, y), batch_size=3)


@pytest.fixture
def model():
model = torch.nn.Sequential(nn.Linear(3, 20), nn.Linear(20, 2))
@@ -24,6 +40,16 @@ def model():
return model


@pytest.fixture
def model_1d():
model = torch.nn.Sequential(nn.Linear(3, 20), nn.Linear(20, 1))
setattr(model, "output_size", 1)
model_params = list(model.parameters())
setattr(model, "n_layers", len(model_params)) # number of parameter groups
setattr(model, "n_params", len(parameters_to_vector(model_params)))
return model


@pytest.fixture
def reg_Xy():
torch.manual_seed(711)
@@ -286,3 +312,13 @@ def mock_jacobians(self, x):
expected_block_diagonal_kernel,
block_diag_kernel.to(expected_block_diagonal_kernel.dtype),
)


def test_functional_fit_y_shape(model_1d, reg_loader_1d, reg_loader_1d_flat):
la = FunctionalLaplace(model_1d, "regression", 10, independent_outputs=False)
la.fit(reg_loader_1d)

la2 = FunctionalLaplace(model_1d, "regression", 10, independent_outputs=False)

with pytest.raises(ValueError):
la2.fit(reg_loader_1d_flat)
2 changes: 1 addition & 1 deletion tests/test_laplace.py
@@ -8,7 +8,7 @@
from laplace.lllaplace import DiagLLLaplace, FullLLLaplace, KronLLLaplace

torch.manual_seed(240)
torch.set_default_tensor_type(torch.DoubleTensor)
torch.set_default_dtype(torch.double)
flavors = [
FullLaplace,
KronLaplace,
2 changes: 1 addition & 1 deletion tests/test_matrix.py
@@ -9,7 +9,7 @@
from laplace.utils import kron as kron_prod
from tests.utils import get_diag_psd_matrix, get_psd_matrix, jacobians_naive

torch.set_default_tensor_type(torch.DoubleTensor)
torch.set_default_dtype(torch.double)


@pytest.fixture
2 changes: 1 addition & 1 deletion tests/test_serialization.py
@@ -26,7 +26,7 @@
)

torch.manual_seed(240)
torch.set_default_tensor_type(torch.DoubleTensor)
torch.set_default_dtype(torch.double)

lrlaplace_param = pytest.param(
LowRankLaplace, marks=pytest.mark.xfail(reason="Unimplemented in the new ASDL")
2 changes: 1 addition & 1 deletion tests/test_subnetlaplace.py
@@ -25,7 +25,7 @@
)

torch.manual_seed(240)
torch.set_default_tensor_type(torch.DoubleTensor)
torch.set_default_dtype(torch.double)
score_based_subnet_masks = [
RandomSubnetMask,
LargestMagnitudeSubnetMask,
2 changes: 1 addition & 1 deletion tests/test_subset_params.py
@@ -12,7 +12,7 @@
from laplace.curvature.curvlinops import CurvlinopsEF, CurvlinopsGGN, CurvlinopsHessian

torch.manual_seed(240)
torch.set_default_tensor_type(torch.DoubleTensor)
torch.set_default_dtype(torch.double)
flavors = [KronLaplace, DiagLaplace, FullLaplace]
valid_backends = [CurvlinopsGGN, CurvlinopsEF, AsdlGGN, AsdlEF]

