From c3a0002cb58504ae54e16934af653df516f9cb2e Mon Sep 17 00:00:00 2001 From: Agustinus Kristiadi Date: Mon, 2 Sep 2024 15:34:16 -0400 Subject: [PATCH 1/4] Transform `y` into 2D tensor when `y.ndim == 1` and `likelihood == REGRESSION` --- laplace/baselaplace.py | 9 +++++- tests/test_baselaplace.py | 42 +++++++++++++++++++++++++++ tests/test_functional_laplace_unit.py | 37 +++++++++++++++++++++++ 3 files changed, 87 insertions(+), 1 deletion(-) diff --git a/laplace/baselaplace.py b/laplace/baselaplace.py index c8909b1..321680a 100644 --- a/laplace/baselaplace.py +++ b/laplace/baselaplace.py @@ -846,6 +846,10 @@ def fit( else: X, y = data X, y = X.to(self._device), y.to(self._device) + + if self.likelihood == Likelihood.REGRESSION and y.ndim == 1: + y = y.unsqueeze(-1) + self.model.zero_grad() loss_batch, H_batch = self._curv_closure(X, y, N=N) self.loss += loss_batch @@ -2229,12 +2233,15 @@ def fit( X, y = data X, y = X.to(self._device), y.to(self._device) + if self.likelihood == Likelihood.REGRESSION and y.ndim == 1: + y = y.unsqueeze(1) + Js_batch, f_batch = self._jacobians(X, enable_backprop=False) with torch.no_grad(): loss_batch = self.backend.factor * self.backend.lossfunc(f_batch, y) - if self.likelihood == "regression": + if self.likelihood == Likelihood.REGRESSION: b, C = f_batch.shape lambdas_batch = torch.unsqueeze(torch.eye(C), 0).repeat(b, 1, 1) else: diff --git a/tests/test_baselaplace.py b/tests/test_baselaplace.py index 216dfba..f360ba9 100644 --- a/tests/test_baselaplace.py +++ b/tests/test_baselaplace.py @@ -43,6 +43,16 @@ def model(): return model +@pytest.fixture +def model_1d(): + model = torch.nn.Sequential(nn.Linear(3, 20), nn.Linear(20, 1)) + setattr(model, "output_size", 1) + model_params = list(model.parameters()) + setattr(model, "n_layers", len(model_params)) # number of parameter groups + setattr(model, "n_params", len(parameters_to_vector(model_params))) + return model + + @pytest.fixture def large_model(): model = wide_resnet50_2() @@ -113,6 +123,22 @@ def reg_loader(): return DataLoader(TensorDataset(X, y), batch_size=3) +@pytest.fixture +def reg_loader_1d(): + torch.manual_seed(9999) + X = torch.randn(10, 3) + y = torch.randn(10, 1) + return DataLoader(TensorDataset(X, y), batch_size=3) + + +@pytest.fixture +def reg_loader_1d_flat(): + torch.manual_seed(9999) + X = torch.randn(10, 3) + y = torch.randn((10,)) + return DataLoader(TensorDataset(X, y), batch_size=3) + + @pytest.fixture def custom_loader_clf(): data = [] @@ -818,3 +844,19 @@ def test_gridsearch(model, likelihood, prior_prec_type, reg_loader, class_loader # Should not raise an error lap.optimize_prior_precision(method="gridsearch", val_loader=dataloader, n_steps=10) + + +@pytest.mark.parametrize("laplace", flavors) +def test_prametric_fit_y_shape(model_1d, reg_loader_1d, reg_loader_1d_flat, laplace): + lap = laplace(model_1d, likelihood="regression") + lap.fit(reg_loader_1d) # OK + + lap2 = laplace(model_1d, likelihood="regression") + lap2.fit(reg_loader_1d_flat) # Also OK! 
+ + H1, H2 = lap.H, lap2.H + + if isinstance(H1, KronDecomposed) and isinstance(H2, KronDecomposed): + H1, H2 = H1.to_matrix(), H2.to_matrix() + + assert torch.allclose(H1, H2) diff --git a/tests/test_functional_laplace_unit.py b/tests/test_functional_laplace_unit.py index 296aebb..4c46e0c 100644 --- a/tests/test_functional_laplace_unit.py +++ b/tests/test_functional_laplace_unit.py @@ -14,6 +14,22 @@ def reg_loader(): return DataLoader(TensorDataset(X, y), batch_size=3) +@pytest.fixture +def reg_loader_1d(): + torch.manual_seed(9999) + X = torch.randn(10, 3) + y = torch.randn(10, 1) + return DataLoader(TensorDataset(X, y), batch_size=3) + + +@pytest.fixture +def reg_loader_1d_flat(): + torch.manual_seed(9999) + X = torch.randn(10, 3) + y = torch.randn((10,)) + return DataLoader(TensorDataset(X, y), batch_size=3) + + @pytest.fixture def model(): model = torch.nn.Sequential(nn.Linear(3, 20), nn.Linear(20, 2)) @@ -24,6 +40,16 @@ def model(): return model +@pytest.fixture +def model_1d(): + model = torch.nn.Sequential(nn.Linear(3, 20), nn.Linear(20, 1)) + setattr(model, "output_size", 1) + model_params = list(model.parameters()) + setattr(model, "n_layers", len(model_params)) # number of parameter groups + setattr(model, "n_params", len(parameters_to_vector(model_params))) + return model + + @pytest.fixture def reg_Xy(): torch.manual_seed(711) @@ -286,3 +312,14 @@ def mock_jacobians(self, x): expected_block_diagonal_kernel, block_diag_kernel.to(expected_block_diagonal_kernel.dtype), ) + + +def test_prametric_fit_y_shape(model_1d, reg_loader_1d, reg_loader_1d_flat): + la = FunctionalLaplace(model_1d, "regression", 10, independent_outputs=False) + la.fit(reg_loader_1d) + + la2 = FunctionalLaplace(model_1d, "regression", 10, independent_outputs=False) + la2.fit(reg_loader_1d_flat) + + assert torch.allclose(la.mu, la2.mu) + assert torch.allclose(la.L, la2.L) From d9fe1a6f0376940ec48f09c4a0e75e55353d971c Mon Sep 17 00:00:00 2001 From: Agustinus Kristiadi Date: Mon, 2 Sep 2024 15:41:35 -0400 Subject: [PATCH 2/4] Transform y while supporting multiple leading dimension --- laplace/baselaplace.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/laplace/baselaplace.py b/laplace/baselaplace.py index 321680a..a2cd672 100644 --- a/laplace/baselaplace.py +++ b/laplace/baselaplace.py @@ -847,7 +847,7 @@ def fit( X, y = data X, y = X.to(self._device), y.to(self._device) - if self.likelihood == Likelihood.REGRESSION and y.ndim == 1: + if self.likelihood == Likelihood.REGRESSION and y.ndim == out.ndim - 1: y = y.unsqueeze(-1) self.model.zero_grad() @@ -2233,11 +2233,11 @@ def fit( X, y = data X, y = X.to(self._device), y.to(self._device) - if self.likelihood == Likelihood.REGRESSION and y.ndim == 1: - y = y.unsqueeze(1) - Js_batch, f_batch = self._jacobians(X, enable_backprop=False) + if self.likelihood == Likelihood.REGRESSION and y.ndim == f_batch.ndim - 1: + y = y.unsqueeze(1) + with torch.no_grad(): loss_batch = self.backend.factor * self.backend.lossfunc(f_batch, y) From 0331f05395b3f7d3f997a7e233bc1b4e54edc674 Mon Sep 17 00:00:00 2001 From: Agustinus Kristiadi Date: Thu, 12 Sep 2024 17:35:17 -0400 Subject: [PATCH 3/4] Raise an error if regression and output dim is different than target dim --- laplace/baselaplace.py | 42 +++++++++++++++------------ tests/test_baselaplace.py | 13 +++------ tests/test_functional_laplace_unit.py | 7 ++--- tests/test_laplace.py | 2 +- tests/test_matrix.py | 2 +- tests/test_serialization.py | 2 +- tests/test_subnetlaplace.py | 2 +- 
tests/test_subset_params.py | 2 +- 8 files changed, 36 insertions(+), 36 deletions(-) diff --git a/laplace/baselaplace.py b/laplace/baselaplace.py index a2cd672..b9aa330 100644 --- a/laplace/baselaplace.py +++ b/laplace/baselaplace.py @@ -847,8 +847,11 @@ def fit( X, y = data X, y = X.to(self._device), y.to(self._device) - if self.likelihood == Likelihood.REGRESSION and y.ndim == out.ndim - 1: - y = y.unsqueeze(-1) + if self.likelihood == Likelihood.REGRESSION and y.ndim != out.ndim: + raise ValueError( + f"The model's output is of shape {tuple(out.shape)} but " + f"the target has shape {tuple(y.shape)}." + ) self.model.zero_grad() loss_batch, H_batch = self._curv_closure(X, y, N=N) @@ -1934,7 +1937,7 @@ class FunctionalLaplace(BaseLaplace): See [Improving predictions of Bayesian neural nets via local linearization (Immer et al., 2021)](https://arxiv.org/abs/2008.08400) for more details. - Note that for `likelihood='classification'`, we approximate \( L_{NN} \\) with a diagonal matrix + Note that for `likelihood='classification'`, we approximate \\( L_{NN} \\) with a diagonal matrix ( \\( L_{NN} \\) is a block-diagonal matrix, where blocks represent Hessians of per-data-point log-likelihood w.r.t. neural network output \\( f \\), See Appendix [A.2.1](https://arxiv.org/abs/2008.08400) for exact definition). We resort to such an approximation because of the (possible) errors found in Laplace approximation for @@ -2027,9 +2030,9 @@ def _check_prior_precision(prior_precision: float | torch.Tensor): def _init_K_MM(self): """Allocates memory for the kernel matrix evaluated at the subset of the training - data points. If the subset is of size \(M\) and the problem has \(C\) outputs, - this is a list of C \((M,M\)) tensors for diagonal kernel and \((M x C, M x C)\) - otherwise. + data points. If the subset is of size \\(M\\) and the problem has \\(C\\) outputs, + this is a list of C \\((M,M\\)) tensors for diagonal kernel and + \\((M \\times C, M \\times C)\\) otherwise. """ if self.independent_outputs: self.K_MM = [ @@ -2044,9 +2047,9 @@ def _init_K_MM(self): def _init_Sigma_inv(self): """Allocates memory for the cholesky decomposition of - \[ - K_{MM} + \Lambda_{MM}^{-1}. - \] + \\[ + K_{MM} + \\Lambda_{MM}^{-1}. + \\] See See [Improving predictions of Bayesian neural nets via local linearization (Immer et al., 2021)](https://arxiv.org/abs/2008.08400) Equation 15 for more information. """ @@ -2119,13 +2122,13 @@ class for more details. def _build_Sigma_inv(self): """Computes the cholesky decomposition of - \[ - K_{MM} + \Lambda_{MM}^{-1}. - \] + \\[ + K_{MM} + \\Lambda_{MM}^{-1}. + \\] See See [Improving predictions of Bayesian neural nets via local linearization (Immer et al., 2021)](https://arxiv.org/abs/2008.08400) Equation 15 for more information. - As the diagonal approximation is performed with \Lambda_{MM} (which is stored in self.L), + As the diagonal approximation is performed with \\(\\Lambda_{MM}\\) (which is stored in self.L), the code is greatly simplified. """ if self.independent_outputs: @@ -2235,8 +2238,11 @@ def fit( Js_batch, f_batch = self._jacobians(X, enable_backprop=False) - if self.likelihood == Likelihood.REGRESSION and y.ndim == f_batch.ndim - 1: - y = y.unsqueeze(1) + if self.likelihood == Likelihood.REGRESSION and y.ndim != out.ndim: + raise ValueError( + f"The model's output is of shape {tuple(out.shape)} but " + f"the target has shape {tuple(y.shape)}." 
+ ) with torch.no_grad(): loss_batch = self.backend.factor * self.backend.lossfunc(f_batch, y) @@ -2559,11 +2565,11 @@ def log_det_ratio(self) -> torch.Tensor: [GP book R&W 2006](http://www.gaussianprocess.org/gpml/chapters/) with (note that we always use diagonal approximation \\(D\\) of the Hessian of log likelihood w.r.t. \\(f\\)): - log determinant term := \\( \log | I + D^{1/2}K D^{1/2} | \\) + log determinant term := \\( \\log | I + D^{1/2}K D^{1/2} | \\) For `regression`, we use ["standard" GP marginal likelihood](https://stats.stackexchange.com/questions/280105/log-marginal-likelihood-for-gaussian-process): - log determinant term := \\( \log | K + \\sigma_2 I | \\) + log determinant term := \\( \\log | K + \\sigma_2 I | \\) """ if self.likelihood == Likelihood.REGRESSION: if self.independent_outputs: @@ -2603,7 +2609,7 @@ def scatter(self, eps: float = 0.00001) -> torch.Tensor: """Compute scatter term in GP log marginal likelihood. For `classification` we use eq. (3.44) from Chapter 3.5 from - [GP book R&W 2006](http://www.gaussianprocess.org/gpml/chapters/) with \\(\hat{f} = f \\): + [GP book R&W 2006](http://www.gaussianprocess.org/gpml/chapters/) with \\(\\hat{f} = f \\): scatter term := \\( f K^{-1} f^{T} \\) diff --git a/tests/test_baselaplace.py b/tests/test_baselaplace.py index f360ba9..5034cee 100644 --- a/tests/test_baselaplace.py +++ b/tests/test_baselaplace.py @@ -24,7 +24,7 @@ from tests.utils import ListDataset, dict_data_collator, jacobians_naive torch.manual_seed(240) -torch.set_default_tensor_type(torch.DoubleTensor) +torch.set_default_dtype(torch.double) flavors = [FullLaplace, KronLaplace, DiagLaplace] if find_spec("asdfghjkl") is not None: @@ -847,16 +847,11 @@ def test_gridsearch(model, likelihood, prior_prec_type, reg_loader, class_loader @pytest.mark.parametrize("laplace", flavors) -def test_prametric_fit_y_shape(model_1d, reg_loader_1d, reg_loader_1d_flat, laplace): +def test_parametric_fit_y_shape(model_1d, reg_loader_1d, reg_loader_1d_flat, laplace): lap = laplace(model_1d, likelihood="regression") lap.fit(reg_loader_1d) # OK lap2 = laplace(model_1d, likelihood="regression") - lap2.fit(reg_loader_1d_flat) # Also OK! 
- H1, H2 = lap.H, lap2.H - - if isinstance(H1, KronDecomposed) and isinstance(H2, KronDecomposed): - H1, H2 = H1.to_matrix(), H2.to_matrix() - - assert torch.allclose(H1, H2) + with pytest.raises(ValueError): + lap2.fit(reg_loader_1d_flat) diff --git a/tests/test_functional_laplace_unit.py b/tests/test_functional_laplace_unit.py index 4c46e0c..d950c59 100644 --- a/tests/test_functional_laplace_unit.py +++ b/tests/test_functional_laplace_unit.py @@ -314,12 +314,11 @@ def mock_jacobians(self, x): ) -def test_prametric_fit_y_shape(model_1d, reg_loader_1d, reg_loader_1d_flat): +def test_functional_fit_y_shape(model_1d, reg_loader_1d, reg_loader_1d_flat): la = FunctionalLaplace(model_1d, "regression", 10, independent_outputs=False) la.fit(reg_loader_1d) la2 = FunctionalLaplace(model_1d, "regression", 10, independent_outputs=False) - la2.fit(reg_loader_1d_flat) - assert torch.allclose(la.mu, la2.mu) - assert torch.allclose(la.L, la2.L) + with pytest.raises(ValueError): + la2.fit(reg_loader_1d_flat) diff --git a/tests/test_laplace.py b/tests/test_laplace.py index 1aed014..fe1d52e 100644 --- a/tests/test_laplace.py +++ b/tests/test_laplace.py @@ -8,7 +8,7 @@ from laplace.lllaplace import DiagLLLaplace, FullLLLaplace, KronLLLaplace torch.manual_seed(240) -torch.set_default_tensor_type(torch.DoubleTensor) +torch.set_default_dtype(torch.double) flavors = [ FullLaplace, KronLaplace, diff --git a/tests/test_matrix.py b/tests/test_matrix.py index e1701ae..0a4b5cb 100644 --- a/tests/test_matrix.py +++ b/tests/test_matrix.py @@ -9,7 +9,7 @@ from laplace.utils import kron as kron_prod from tests.utils import get_diag_psd_matrix, get_psd_matrix, jacobians_naive -torch.set_default_tensor_type(torch.DoubleTensor) +torch.set_default_dtype(torch.double) @pytest.fixture diff --git a/tests/test_serialization.py b/tests/test_serialization.py index 6151dbf..cbdee0f 100644 --- a/tests/test_serialization.py +++ b/tests/test_serialization.py @@ -26,7 +26,7 @@ ) torch.manual_seed(240) -torch.set_default_tensor_type(torch.DoubleTensor) +torch.set_default_dtype(torch.double) lrlaplace_param = pytest.param( LowRankLaplace, marks=pytest.mark.xfail(reason="Unimplemented in the new ASDL") diff --git a/tests/test_subnetlaplace.py b/tests/test_subnetlaplace.py index 31d15d4..9ad8c10 100644 --- a/tests/test_subnetlaplace.py +++ b/tests/test_subnetlaplace.py @@ -21,7 +21,7 @@ ) torch.manual_seed(240) -torch.set_default_tensor_type(torch.DoubleTensor) +torch.set_default_dtype(torch.double) score_based_subnet_masks = [ RandomSubnetMask, LargestMagnitudeSubnetMask, diff --git a/tests/test_subset_params.py b/tests/test_subset_params.py index ab8a496..2d4ea62 100644 --- a/tests/test_subset_params.py +++ b/tests/test_subset_params.py @@ -12,7 +12,7 @@ from laplace.curvature.curvlinops import CurvlinopsEF, CurvlinopsGGN, CurvlinopsHessian torch.manual_seed(240) -torch.set_default_tensor_type(torch.DoubleTensor) +torch.set_default_dtype(torch.double) flavors = [KronLaplace, DiagLaplace, FullLaplace] valid_backends = [CurvlinopsGGN, CurvlinopsEF, AsdlGGN, AsdlEF] From 99c1ebde76c913c2cc98654eea882f819bea1b57 Mon Sep 17 00:00:00 2001 From: Agustinus Kristiadi Date: Thu, 12 Sep 2024 17:53:23 -0400 Subject: [PATCH 4/4] Add output-dim check for low rank Laplace --- laplace/baselaplace.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/laplace/baselaplace.py b/laplace/baselaplace.py index b9aa330..5e7f366 100644 --- a/laplace/baselaplace.py +++ b/laplace/baselaplace.py @@ -849,8 +849,8 @@ def fit( if 
self.likelihood == Likelihood.REGRESSION and y.ndim != out.ndim: raise ValueError( - f"The model's output is of shape {tuple(out.shape)} but " - f"the target has shape {tuple(y.shape)}." + f"The model's output has {out.ndim} dims but " + f"the target has {y.ndim} dims." ) self.model.zero_grad() @@ -1768,12 +1768,19 @@ def fit( if not self.enable_backprop: self.mean = self.mean.detach() - X, _ = next(iter(train_loader)) + X, y = next(iter(train_loader)) with torch.no_grad(): try: out = self.model(X[:1].to(self._device)) except (TypeError, AttributeError): out = self.model(X.to(self._device)) + + if self.likelihood == Likelihood.REGRESSION and y.ndim != out.ndim: + raise ValueError( + f"The model's output has {out.ndim} dims but " + f"the target has {y.ndim} dims." + ) + self.n_outputs = out.shape[-1] setattr(self.model, "output_size", self.n_outputs) @@ -2240,8 +2247,8 @@ def fit( if self.likelihood == Likelihood.REGRESSION and y.ndim != out.ndim: raise ValueError( - f"The model's output is of shape {tuple(out.shape)} but " - f"the target has shape {tuple(y.shape)}." + f"The model's output has {out.ndim} dims but " + f"the target has {y.ndim} dims." ) with torch.no_grad():
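
Usage note (not part of the diffs above): the sketch below illustrates the net behavior of this series for regression targets. It is a minimal, hypothetical example that assumes laplace-torch with these patches applied; the model, loaders, and variable names are illustrative and simply mirror the new test fixtures.

    # Minimal sketch (assumption: laplace-torch with this patch series applied;
    # names below are illustrative, mirroring the new 1D test fixtures).
    import torch
    import torch.nn as nn
    from torch.utils.data import DataLoader, TensorDataset

    from laplace import FullLaplace

    model_1d = nn.Sequential(nn.Linear(3, 20), nn.Linear(20, 1))  # output shape (batch, 1)

    X = torch.randn(10, 3)
    y_2d = torch.randn(10, 1)   # target ndim matches the model-output ndim
    y_flat = torch.randn(10)    # flat targets: ndim == output ndim - 1

    lap = FullLaplace(model_1d, likelihood="regression")
    lap.fit(DataLoader(TensorDataset(X, y_2d), batch_size=3))  # OK

    lap_flat = FullLaplace(model_1d, likelihood="regression")
    try:
        lap_flat.fit(DataLoader(TensorDataset(X, y_flat), batch_size=3))
    except ValueError as err:
        # After PATCH 3/4 the dim mismatch is rejected instead of being silently
        # unsqueezed (the behavior of PATCH 1/2), with a message along the lines of
        # "The model's output has 2 dims but the target has 1 dims."
        print(err)

Design note: PATCH 1/2 reshaped flat targets on the user's behalf; PATCH 3/4 instead raise a ValueError and leave reshaping (e.g. y.unsqueeze(-1)) to the caller, which keeps the fitted Hessian independent of silent target transformations.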