From c653e1f4a43ce9bee09f3a5fc5c9345606b6da87 Mon Sep 17 00:00:00 2001 From: Agustinus Kristiadi Date: Sat, 27 Apr 2024 11:19:47 -0400 Subject: [PATCH 1/9] Reformat --- examples/bayesopt_example.py | 101 ++++--- examples/calibration_example.py | 13 +- examples/helper/dataloaders.py | 24 +- examples/helper/util.py | 23 +- examples/helper/wideresnet.py | 56 +++- examples/regression_example.py | 53 +++- laplace/__init__.py | 35 ++- laplace/lllaplace.py | 2 +- laplace/subnetlaplace.py | 76 +++-- laplace/utils/__init__.py | 64 +++- laplace/utils/feature_extractor.py | 17 +- laplace/utils/metrics.py | 7 +- laplace/utils/subnetmask.py | 119 ++++++-- laplace/utils/swag.py | 24 +- laplace/utils/utils.py | 70 +++-- setup.py | 2 +- tests/test_curv_backends_backpack.py | 4 +- tests/test_feature_extractor.py | 50 ++-- tests/test_laplace.py | 33 ++- tests/test_lllaplace.py | 205 +++++++++---- tests/test_matrix.py | 48 +-- tests/test_metrics.py | 10 +- tests/test_serialization.py | 56 ++-- tests/test_subnetlaplace.py | 423 ++++++++++++++++++++------- tests/test_subset_params.py | 13 +- tests/test_utils.py | 22 +- 26 files changed, 1101 insertions(+), 449 deletions(-) diff --git a/examples/bayesopt_example.py b/examples/bayesopt_example.py index efe5eecb..2559128c 100644 --- a/examples/bayesopt_example.py +++ b/examples/bayesopt_example.py @@ -1,4 +1,5 @@ import warnings + warnings.filterwarnings('ignore') import numpy as np @@ -10,7 +11,7 @@ from gpytorch import distributions as gdists import torch.utils.data as data_utils from botorch.models.model import Model -from botorch.posteriors.gpytorch import GPyTorchPosterior +from botorch.posteriors.gpytorch import GPyTorchPosterior import tqdm from laplace import Laplace @@ -45,13 +46,13 @@ class LaplaceBNN(Model): """ def __init__( - self, - train_X: torch.Tensor, - train_Y: torch.Tensor, - bnn: Laplace = None, - likelihood: str = 'regression', - batch_size: int = 1024): - + self, + train_X: torch.Tensor, + train_Y: torch.Tensor, + bnn: Laplace = None, + likelihood: str = 'regression', + batch_size: int = 1024, + ): super().__init__() self.train_X = train_X @@ -63,14 +64,13 @@ def __init__( nn.ReLU(), nn.Linear(50, 50), nn.ReLU(), - nn.Linear(50, train_Y.shape[-1]) + nn.Linear(50, train_Y.shape[-1]), ) self.bnn = bnn if self.bnn is None: self._train_model(self._get_train_loader()) - def posterior( self, X: torch.Tensor, @@ -84,7 +84,7 @@ def posterior( """ # Transform to `(batch_shape*q, d)` B, Q, D = X.shape - X = X.reshape(B*Q, D) + X = X.reshape(B * Q, D) # Posterior predictive distribution # mean_y is (batch_shape*q, k); cov_y is (batch_shape*q*k, batch_shape*q*k) @@ -92,13 +92,13 @@ def posterior( # Mean in `(batch_shape, q*k)` K = self.num_outputs - mean_y = mean_y.reshape(B, Q*K) + mean_y = mean_y.reshape(B, Q * K) # Cov is `(batch_shape, q*k, q*k)` - cov_y += 1e-4*torch.eye(B*Q*K) + cov_y += 1e-4 * torch.eye(B * Q * K) cov_y = cov_y.reshape(B, Q, K, B, Q, K) cov_y = torch.einsum('bqkbrl->bqkrl', cov_y) # (B, Q, K, Q, K) - cov_y = cov_y.reshape(B, Q*K, Q*K) + cov_y = cov_y.reshape(B, Q * K, Q * K) dist = gdists.MultivariateNormal(mean_y, covariance_matrix=cov_y) post_pred = GPyTorchPosterior(dist) @@ -108,8 +108,9 @@ def posterior( return post_pred - - def condition_on_observations(self, X: torch.Tensor, Y: torch.Tensor, **kwargs: Any) -> Model: + def condition_on_observations( + self, X: torch.Tensor, Y: torch.Tensor, **kwargs: Any + ) -> Model: self.train_X = torch.cat([self.train_X, X], dim=0) self.train_Y = torch.cat([self.train_Y, Y], dim=0) @@ 
-118,8 +119,11 @@ def condition_on_observations(self, X: torch.Tensor, Y: torch.Tensor, **kwargs: return LaplaceBNN( # Added dataset & retrained BNN - self.train_X, self.train_Y, self.bnn, - self.likelihood, self.batch_size, + self.train_X, + self.train_Y, + self.bnn, + self.likelihood, + self.batch_size, ) @property @@ -133,7 +137,9 @@ def _train_model(self, train_loader): """ n_epochs = 1000 optimizer = optim.Adam(self.nn.parameters(), lr=1e-1, weight_decay=1e-3) - scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, n_epochs*len(train_loader)) + scheduler = optim.lr_scheduler.CosineAnnealingLR( + optimizer, n_epochs * len(train_loader) + ) loss_func = nn.MSELoss() for i in range(n_epochs): @@ -148,14 +154,15 @@ def _train_model(self, train_loader): self.nn.eval() self.bnn = Laplace( - self.nn, self.likelihood, - subset_of_weights='all', hessian_structure='kron', - enable_backprop=True + self.nn, + self.likelihood, + subset_of_weights='all', + hessian_structure='kron', + enable_backprop=True, ) self.bnn.fit(train_loader) self.bnn.optimize_prior_precision(n_steps=50) - def _get_prediction(self, test_x: torch.Tensor, joint=True, use_test_loader=False): """ Batched Laplace prediction. @@ -175,7 +182,7 @@ def _get_prediction(self, test_x: torch.Tensor, joint=True, use_test_loader=Fals else: test_loader = data_utils.DataLoader( data_utils.TensorDataset(test_x, torch.zeros_like(test_x)), - batch_size=256 + batch_size=256, ) mean_y, cov_y = [], [] @@ -190,13 +197,14 @@ def _get_prediction(self, test_x: torch.Tensor, joint=True, use_test_loader=Fals return mean_y, cov_y - def _get_train_loader(self): return data_utils.DataLoader( data_utils.TensorDataset(self.train_X, self.train_Y), - batch_size=self.batch_size, shuffle=True + batch_size=self.batch_size, + shuffle=True, ) + if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument('--test_func', choices=['ackley', 'branin'], default='branin') @@ -227,25 +235,30 @@ def _get_train_loader(self): train_data_points = 20 - train_x = torch.cat([ - dists.Uniform(*bounds.T[i]).sample((train_data_points, 1)) - for i in range(2) # for each dimension - ], dim=1) + train_x = torch.cat( + [ + dists.Uniform(*bounds.T[i]).sample((train_data_points, 1)) + for i in range(2) # for each dimension + ], + dim=1, + ) train_y = true_f(train_x).reshape(-1, 1) - test_x = torch.cat([ - dists.Uniform(*bounds.T[i]).sample((10000, 1)) - for i in range(2) # for each dimension - ], dim=1) + test_x = torch.cat( + [ + dists.Uniform(*bounds.T[i]).sample((10000, 1)) + for i in range(2) # for each dimension + ], + dim=1, + ) test_y = true_f(test_x) models = { - 'RandomSearch': None, + 'RandomSearch': None, 'BNN-LA': LaplaceBNN(train_x, train_y), 'GP': SingleTaskGP(train_x, train_y), } - def evaluate_model(model_name, model): if model_name == 'GP': pred = model.posterior(test_x).mean.squeeze() @@ -256,7 +269,6 @@ def evaluate_model(model_name, model): return F.mse_loss(pred, test_y).squeeze().item() - for model_name, model in models.items(): np.random.seed(args.randseed) torch.set_default_dtype(torch.float64) @@ -271,10 +283,13 @@ def evaluate_model(model_name, model): for i in pbar: if model_name == 'RandomSearch': - new_x = torch.cat([ - dists.Uniform(*bounds.T[i]).sample((1, 1)) - for i in range(len(bounds)) # for each dimension - ], dim=1).squeeze() + new_x = torch.cat( + [ + dists.Uniform(*bounds.T[i]).sample((1, 1)) + for i in range(len(bounds)) # for each dimension + ], + dim=1, + ).squeeze() else: if args.acqf == 'EI': acq_f = 
ExpectedImprovement(model, best_f=best_y, maximize=False) @@ -292,12 +307,12 @@ def evaluate_model(model_name, model): bounds=bounds, q=1 if args.acqf not in ['qEI'] else 5, num_restarts=10, - raw_samples=20 + raw_samples=20, ) if len(new_x.shape) == 1: new_x = new_x.unsqueeze(0) - + # Evaluate the objective on the proposed x new_y = true_f(new_x).unsqueeze(-1) # (q, 1) diff --git a/examples/calibration_example.py b/examples/calibration_example.py index 36f6cd85..3fa7cb80 100644 --- a/examples/calibration_example.py +++ b/examples/calibration_example.py @@ -1,5 +1,6 @@ import warnings -warnings.simplefilter("ignore", UserWarning) + +warnings.simplefilter('ignore', UserWarning) import torch import torch.distributions as dists @@ -50,9 +51,9 @@ def predict(dataloader, model, laplace=False): print(f'[MAP] Acc.: {acc_map:.1%}; ECE: {ece_map:.1%}; NLL: {nll_map:.3}') # Laplace -la = Laplace(model, 'classification', - subset_of_weights='last_layer', - hessian_structure='kron') +la = Laplace( + model, 'classification', subset_of_weights='last_layer', hessian_structure='kron' +) la.fit(train_loader) la.optimize_prior_precision(method='marglik') @@ -61,4 +62,6 @@ def predict(dataloader, model, laplace=False): ece_laplace = ECE(bins=15).measure(probs_laplace.numpy(), targets.numpy()) nll_laplace = -dists.Categorical(probs_laplace).log_prob(targets).mean() -print(f'[Laplace] Acc.: {acc_laplace:.1%}; ECE: {ece_laplace:.1%}; NLL: {nll_laplace:.3}') +print( + f'[Laplace] Acc.: {acc_laplace:.1%}; ECE: {ece_laplace:.1%}; NLL: {nll_laplace:.3}' +) diff --git a/examples/helper/dataloaders.py b/examples/helper/dataloaders.py index 0b008fa9..e14b7c5d 100644 --- a/examples/helper/dataloaders.py +++ b/examples/helper/dataloaders.py @@ -13,22 +13,26 @@ def CIFAR10(train=True, batch_size=None, augm_flag=True): if batch_size == None: if train: - batch_size=train_batch_size + batch_size = train_batch_size else: - batch_size=test_batch_size + batch_size = test_batch_size transform_base = [transforms.ToTensor()] - transform_train = transforms.Compose([ - transforms.RandomHorizontalFlip(), - transforms.RandomCrop(32, padding=4, padding_mode='reflect'), - ] + transform_base) + transform_train = transforms.Compose( + [ + transforms.RandomHorizontalFlip(), + transforms.RandomCrop(32, padding=4, padding_mode='reflect'), + ] + + transform_base + ) transform_test = transforms.Compose(transform_base) transform_train = transforms.RandomChoice([transform_train, transform_test]) transform = transform_train if (augm_flag and train) else transform_test dataset = datasets.CIFAR10(path, train=train, transform=transform, download=True) - loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, - shuffle=train, num_workers=4) + loader = torch.utils.data.DataLoader( + dataset, batch_size=batch_size, shuffle=train, num_workers=4 + ) return loader @@ -38,9 +42,7 @@ def get_sinusoid_example(n_data=150, sigma_noise=0.3, batch_size=150): X_train = (torch.rand(n_data) * 8).unsqueeze(-1) y_train = torch.sin(X_train) + torch.randn_like(X_train) * sigma_noise train_loader = data_utils.DataLoader( - data_utils.TensorDataset(X_train, y_train), - batch_size=batch_size + data_utils.TensorDataset(X_train, y_train), batch_size=batch_size ) X_test = torch.linspace(-5, 13, 500).unsqueeze(-1) return X_train, y_train, train_loader, X_test - \ No newline at end of file diff --git a/examples/helper/util.py b/examples/helper/util.py index f156ce4b..431cc35e 100644 --- a/examples/helper/util.py +++ b/examples/helper/util.py @@ -9,13 +9,16 @@ def 
download_pretrained_model(): if not os.path.exists('./temp'): os.makedirs('./temp') - urllib.request.urlretrieve('https://nc.mlcloud.uni-tuebingen.de/index.php/s/2PBDYDsiotN76mq/download', './temp/CIFAR10_plain.pt') + urllib.request.urlretrieve( + 'https://nc.mlcloud.uni-tuebingen.de/index.php/s/2PBDYDsiotN76mq/download', + './temp/CIFAR10_plain.pt', + ) -def plot_regression(X_train, y_train, X_test, f_test, y_std, plot=True, - file_name='regression_example'): - fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2, sharey=True, - figsize=(4.5, 2.8)) +def plot_regression( + X_train, y_train, X_test, f_test, y_std, plot=True, file_name='regression_example' +): + fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2, sharey=True, figsize=(4.5, 2.8)) ax1.set_title('MAP') ax1.scatter(X_train.flatten(), y_train.flatten(), alpha=0.3, color='tab:orange') ax1.plot(X_test, f_test, color='black', label='$f_{MAP}$') @@ -24,8 +27,14 @@ def plot_regression(X_train, y_train, X_test, f_test, y_std, plot=True, ax2.set_title('LA') ax2.scatter(X_train.flatten(), y_train.flatten(), alpha=0.3, color='tab:orange') ax2.plot(X_test, f_test, label='$\mathbb{E}[f]$') - ax2.fill_between(X_test, f_test-y_std*2, f_test+y_std*2, - alpha=0.3, color='tab:blue', label='$2\sqrt{\mathbb{V}\,[y]}$') + ax2.fill_between( + X_test, + f_test - y_std * 2, + f_test + y_std * 2, + alpha=0.3, + color='tab:blue', + label='$2\sqrt{\mathbb{V}\,[y]}$', + ) ax2.legend() ax1.set_ylim([-4, 6]) ax1.set_xlim([X_test.min(), X_test.max()]) diff --git a/examples/helper/wideresnet.py b/examples/helper/wideresnet.py index 387d596e..a3946c5a 100644 --- a/examples/helper/wideresnet.py +++ b/examples/helper/wideresnet.py @@ -10,14 +10,29 @@ def __init__(self, in_planes, out_planes, stride, dropRate=0.0): self.bn1 = nn.BatchNorm2d(in_planes) self.relu1 = nn.ReLU(inplace=True) - self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False) + self.conv1 = nn.Conv2d( + in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False + ) self.bn2 = nn.BatchNorm2d(out_planes) self.relu2 = nn.ReLU(inplace=True) - self.conv2 = nn.Conv2d(out_planes, out_planes, kernel_size=3, stride=1, padding=1, bias=False) + self.conv2 = nn.Conv2d( + out_planes, out_planes, kernel_size=3, stride=1, padding=1, bias=False + ) self.droprate = dropRate - self.equalInOut = (in_planes == out_planes) - self.convShortcut = (not self.equalInOut) and nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, padding=0, bias=False) or None + self.equalInOut = in_planes == out_planes + self.convShortcut = ( + (not self.equalInOut) + and nn.Conv2d( + in_planes, + out_planes, + kernel_size=1, + stride=stride, + padding=0, + bias=False, + ) + or None + ) def forward(self, x): if not self.equalInOut: @@ -44,13 +59,22 @@ def forward(self, x): class NetworkBlock(nn.Module): def __init__(self, nb_layers, in_planes, out_planes, block, stride, dropRate=0.0): super(NetworkBlock, self).__init__() - self.layer = self._make_layer(block, in_planes, out_planes, nb_layers, stride, dropRate) + self.layer = self._make_layer( + block, in_planes, out_planes, nb_layers, stride, dropRate + ) def _make_layer(self, block, in_planes, out_planes, nb_layers, stride, dropRate): layers = [] for i in range(nb_layers): - layers.append(block(i == 0 and in_planes or out_planes, out_planes, i == 0 and stride or 1, dropRate)) + layers.append( + block( + i == 0 and in_planes or out_planes, + out_planes, + i == 0 and stride or 1, + dropRate, + ) + ) return nn.Sequential(*layers) @@ 
-59,17 +83,26 @@ def forward(self, x): class WideResNet(nn.Module): - - def __init__(self, depth, widen_factor, num_classes, num_channel=3, dropRate=0.3, feature_extractor=False): + def __init__( + self, + depth, + widen_factor, + num_classes, + num_channel=3, + dropRate=0.3, + feature_extractor=False, + ): super(WideResNet, self).__init__() nChannels = [16, 16 * widen_factor, 32 * widen_factor, 64 * widen_factor] - assert ((depth - 4) % 6 == 0) + assert (depth - 4) % 6 == 0 n = (depth - 4) // 6 block = BasicBlock # 1st conv before any network block - self.conv1 = nn.Conv2d(num_channel, nChannels[0], kernel_size=3, stride=1, padding=1, bias=False) + self.conv1 = nn.Conv2d( + num_channel, nChannels[0], kernel_size=3, stride=1, padding=1, bias=False + ) self.block1 = NetworkBlock(n, nChannels[0], nChannels[1], block, 1, dropRate) self.block2 = NetworkBlock(n, nChannels[1], nChannels[2], block, 2, dropRate) self.block3 = NetworkBlock(n, nChannels[2], nChannels[3], block, 2, dropRate) @@ -84,7 +117,7 @@ def __init__(self, depth, widen_factor, num_classes, num_channel=3, dropRate=0.3 for m in self.modules(): if isinstance(m, nn.Conv2d): n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels - m.weight.data.normal_(0, math.sqrt(2. / n)) + m.weight.data.normal_(0, math.sqrt(2.0 / n)) elif isinstance(m, nn.BatchNorm2d): m.weight.data.fill_(1) m.bias.data.zero_() @@ -99,7 +132,6 @@ def forward(self, x): return self.fc(out) - def features(self, x): out = self.conv1(x) out = self.block1(out) diff --git a/examples/regression_example.py b/examples/regression_example.py index 2eb5e13e..e1932366 100644 --- a/examples/regression_example.py +++ b/examples/regression_example.py @@ -15,12 +15,15 @@ # create toy regression data X_train, y_train, train_loader, X_test = get_sinusoid_example(sigma_noise=0.3) + # construct single layer neural network def get_model(): torch.manual_seed(711) return torch.nn.Sequential( torch.nn.Linear(1, 50), torch.nn.Tanh(), torch.nn.Linear(50, 1) ) + + model = get_model() # train MAP @@ -35,11 +38,14 @@ def get_model(): la = Laplace(model, 'regression', subset_of_weights='all', hessian_structure='full') la.fit(train_loader) -log_prior, log_sigma = torch.ones(1, requires_grad=True), torch.ones(1, requires_grad=True) +log_prior, log_sigma = ( + torch.ones(1, requires_grad=True), + torch.ones(1, requires_grad=True), +) hyper_optimizer = torch.optim.Adam([log_prior, log_sigma], lr=1e-1) for i in range(n_epochs): hyper_optimizer.zero_grad() - neg_marglik = - la.log_marginal_likelihood(log_prior.exp(), log_sigma.exp()) + neg_marglik = -la.log_marginal_likelihood(log_prior.exp(), log_sigma.exp()) neg_marglik.backward() hyper_optimizer.step() @@ -51,8 +57,10 @@ def get_model(): # Load serialized, fitted quantities la.load_state_dict(torch.load('state_dict.bin')) -print(f'sigma={la.sigma_noise.item():.2f}', - f'prior precision={la.prior_precision.item():.2f}') +print( + f'sigma={la.sigma_noise.item():.2f}', + f'prior precision={la.prior_precision.item():.2f}', +) x = X_test.flatten().cpu().numpy() @@ -71,25 +79,40 @@ def get_model(): f_mu = f_mu.squeeze().detach().cpu().numpy() f_sigma = f_var.squeeze().detach().sqrt().cpu().numpy() -pred_std = np.sqrt(f_sigma**2 + la.sigma_noise.item()**2) +pred_std = np.sqrt(f_sigma**2 + la.sigma_noise.item() ** 2) -plot_regression(X_train, y_train, x, f_mu, pred_std, - file_name='regression_example', plot=True) +plot_regression( + X_train, y_train, x, f_mu, pred_std, file_name='regression_example', plot=True +) # alternatively, optimize parameters 
and hyperparameters of the prior jointly model = get_model() la, model, margliks, losses = marglik_training( - model=model, train_loader=train_loader, likelihood='regression', - hessian_structure='full', backend=BackPackGGN, n_epochs=n_epochs, - optimizer_kwargs={'lr': 1e-2}, prior_structure='scalar' + model=model, + train_loader=train_loader, + likelihood='regression', + hessian_structure='full', + backend=BackPackGGN, + n_epochs=n_epochs, + optimizer_kwargs={'lr': 1e-2}, + prior_structure='scalar', ) -print(f'sigma={la.sigma_noise.item():.2f}', - f'prior precision={la.prior_precision.numpy()}') +print( + f'sigma={la.sigma_noise.item():.2f}', + f'prior precision={la.prior_precision.numpy()}', +) f_mu, f_var = la(X_test) f_mu = f_mu.squeeze().detach().cpu().numpy() f_sigma = f_var.squeeze().sqrt().cpu().numpy() -pred_std = np.sqrt(f_sigma**2 + la.sigma_noise.item()**2) -plot_regression(X_train, y_train, x, f_mu, pred_std, - file_name='regression_example_online', plot=False) +pred_std = np.sqrt(f_sigma**2 + la.sigma_noise.item() ** 2) +plot_regression( + X_train, + y_train, + x, + f_mu, + pred_std, + file_name='regression_example_online', + plot=False, +) diff --git a/laplace/__init__.py b/laplace/__init__.py index f7149e24..0d338f91 100644 --- a/laplace/__init__.py +++ b/laplace/__init__.py @@ -4,20 +4,37 @@ .. include:: ../examples/regression_example.md .. include:: ../examples/calibration_example.md """ + REGRESSION = 'regression' CLASSIFICATION = 'classification' -from laplace.baselaplace import BaseLaplace, ParametricLaplace, FullLaplace, KronLaplace, DiagLaplace, LowRankLaplace +from laplace.baselaplace import ( + BaseLaplace, + ParametricLaplace, + FullLaplace, + KronLaplace, + DiagLaplace, + LowRankLaplace, +) from laplace.lllaplace import LLLaplace, FullLLLaplace, KronLLLaplace, DiagLLLaplace from laplace.subnetlaplace import SubnetLaplace, FullSubnetLaplace, DiagSubnetLaplace from laplace.laplace import Laplace from laplace.marglik_training import marglik_training -__all__ = ['Laplace', # direct access to all Laplace classes via unified interface - 'BaseLaplace', 'ParametricLaplace', # base-class and its (first-level) subclasses - 'FullLaplace', 'KronLaplace', 'DiagLaplace', 'LowRankLaplace', # all-weights - 'LLLaplace', # base-class last-layer - 'FullLLLaplace', 'KronLLLaplace', 'DiagLLLaplace', # last-layer - 'SubnetLaplace', # base-class subnetwork - 'FullSubnetLaplace', 'DiagSubnetLaplace', # subnetwork - 'marglik_training'] # methods +__all__ = [ + 'Laplace', # direct access to all Laplace classes via unified interface + 'BaseLaplace', + 'ParametricLaplace', # base-class and its (first-level) subclasses + 'FullLaplace', + 'KronLaplace', + 'DiagLaplace', + 'LowRankLaplace', # all-weights + 'LLLaplace', # base-class last-layer + 'FullLLLaplace', + 'KronLLLaplace', + 'DiagLLLaplace', # last-layer + 'SubnetLaplace', # base-class subnetwork + 'FullSubnetLaplace', + 'DiagSubnetLaplace', # subnetwork + 'marglik_training', +] # methods diff --git a/laplace/lllaplace.py b/laplace/lllaplace.py index c08e0bb5..54842c13 100644 --- a/laplace/lllaplace.py +++ b/laplace/lllaplace.py @@ -279,7 +279,7 @@ def __init__( backend=None, last_layer_name=None, damping=False, - **backend_kwargs + **backend_kwargs, ): self.damping = damping super().__init__( diff --git a/laplace/subnetlaplace.py b/laplace/subnetlaplace.py index d4e61bd7..f92a571a 100644 --- a/laplace/subnetlaplace.py +++ b/laplace/subnetlaplace.py @@ -66,16 +66,37 @@ class SubnetLaplace(ParametricLaplace): arguments passed to the 
backend on initialization, for example to set the number of MC samples for stochastic approximations. """ - def __init__(self, model, likelihood, subnetwork_indices, sigma_noise=1., prior_precision=1., - prior_mean=0., temperature=1., backend=None, backend_kwargs=None, asdl_fisher_kwargs=None): + + def __init__( + self, + model, + likelihood, + subnetwork_indices, + sigma_noise=1.0, + prior_precision=1.0, + prior_mean=0.0, + temperature=1.0, + backend=None, + backend_kwargs=None, + asdl_fisher_kwargs=None, + ): if asdl_fisher_kwargs is not None: raise ValueError('Subnetwork Laplace does not support asdl_fisher_kwargs.') self.H = None - super().__init__(model, likelihood, sigma_noise=sigma_noise, - prior_precision=prior_precision, prior_mean=prior_mean, - temperature=temperature, backend=backend, backend_kwargs=backend_kwargs) + super().__init__( + model, + likelihood, + sigma_noise=sigma_noise, + prior_precision=prior_precision, + prior_mean=prior_mean, + temperature=temperature, + backend=backend, + backend_kwargs=backend_kwargs, + ) if backend is not None: - if not isinstance(backend, GGNInterface) and not isinstance(backend, EFInterface): + if not isinstance(backend, GGNInterface) and not isinstance( + backend, EFInterface + ): raise ValueError('SubnetLaplace can only be used with GGN and EF.') # check validity of subnetwork indices and pass them to backend self._check_subnetwork_indices(subnetwork_indices) @@ -85,17 +106,28 @@ def __init__(self, model, likelihood, subnetwork_indices, sigma_noise=1., prior_ def _check_subnetwork_indices(self, subnetwork_indices): """Check that subnetwork indices are valid indices of the vectorized model parameters - (i.e. `torch.nn.utils.parameters_to_vector(model.parameters())`). + (i.e. `torch.nn.utils.parameters_to_vector(model.parameters())`). """ if subnetwork_indices is None: raise ValueError('Subnetwork indices cannot be None.') - elif not ((isinstance(subnetwork_indices, torch.LongTensor) or - isinstance(subnetwork_indices, torch.cuda.LongTensor)) and - subnetwork_indices.numel() > 0 and len(subnetwork_indices.shape) == 1): - raise ValueError('Subnetwork indices must be non-empty 1-dimensional torch.LongTensor.') - elif not (len(subnetwork_indices[subnetwork_indices < 0]) == 0 and - len(subnetwork_indices[subnetwork_indices >= self.n_params]) == 0): - raise ValueError(f'Subnetwork indices must lie between 0 and n_params={self.n_params}.') + elif not ( + ( + isinstance(subnetwork_indices, torch.LongTensor) + or isinstance(subnetwork_indices, torch.cuda.LongTensor) + ) + and subnetwork_indices.numel() > 0 + and len(subnetwork_indices.shape) == 1 + ): + raise ValueError( + 'Subnetwork indices must be non-empty 1-dimensional torch.LongTensor.' + ) + elif not ( + len(subnetwork_indices[subnetwork_indices < 0]) == 0 + and len(subnetwork_indices[subnetwork_indices >= self.n_params]) == 0 + ): + raise ValueError( + f'Subnetwork indices must lie between 0 and n_params={self.n_params}.' 
+ ) elif not (len(subnetwork_indices.unique()) == len(subnetwork_indices)): raise ValueError('Subnetwork indices must not contain duplicate entries.') @@ -109,7 +141,9 @@ def prior_precision_diag(self): prior_precision_diag : torch.Tensor """ if len(self.prior_precision) == 1: # scalar - return self.prior_precision * torch.ones(self.n_params_subnet, device=self._device) + return self.prior_precision * torch.ones( + self.n_params_subnet, device=self._device + ) elif len(self.prior_precision) == self.n_params_subnet: # diagonal return self.prior_precision @@ -123,7 +157,7 @@ def mean_subnet(self): @property def scatter(self): - delta = (self.mean_subnet - self.prior_mean) + delta = self.mean_subnet - self.prior_mean return (delta * self.prior_precision_diag) @ delta def assemble_full_samples(self, subnet_samples): @@ -139,11 +173,14 @@ class FullSubnetLaplace(SubnetLaplace, FullLaplace): Mathematically, we have \\(P \\in \\mathbb{R}^{P \\times P}\\). See `FullLaplace`, `SubnetLaplace`, and `BaseLaplace` for the full interface. """ + # key to map to correct subclass of BaseLaplace, (subset of weights, Hessian structure) _key = ('subnetwork', 'full') def _init_H(self): - self.H = torch.zeros(self.n_params_subnet, self.n_params_subnet, device=self._device) + self.H = torch.zeros( + self.n_params_subnet, self.n_params_subnet, device=self._device + ) def sample(self, n_samples=100, generator=None): # sample only subnetwork parameters and set all other parameters to their MAP estimates @@ -158,6 +195,7 @@ class DiagSubnetLaplace(SubnetLaplace, DiagLaplace): Mathematically, we have \\(P \\approx \\textrm{diag}(P)\\). See `DiagLaplace`, `SubnetLaplace`, and `BaseLaplace` for the full interface. """ + # key to map to correct subclass of BaseLaplace, (subset of weights, Hessian structure) _key = ('subnetwork', 'diag') @@ -175,7 +213,9 @@ def _check_jacobians(self, Js): def sample(self, n_samples=100, generator=None): # sample only subnetwork parameters and set all other parameters to their MAP estimates - samples = torch.randn(n_samples, self.n_params_subnet, device=self._device, generator=generator) + samples = torch.randn( + n_samples, self.n_params_subnet, device=self._device, generator=generator + ) samples = samples * self.posterior_scale.reshape(1, self.n_params_subnet) subnet_samples = self.mean_subnet.reshape(1, self.n_params_subnet) + samples return self.assemble_full_samples(subnet_samples) diff --git a/laplace/utils/__init__.py b/laplace/utils/__init__.py index 9e979119..fbe0d8ed 100644 --- a/laplace/utils/__init__.py +++ b/laplace/utils/__init__.py @@ -1,17 +1,59 @@ -from laplace.utils.utils import get_nll, validate, parameters_per_layer, invsqrt_precision, _is_batchnorm, _is_valid_scalar, kron, diagonal_add_scalar, symeig, block_diag, expand_prior_precision, normal_samples, fix_prior_prec_structure +from laplace.utils.utils import ( + get_nll, + validate, + parameters_per_layer, + invsqrt_precision, + _is_batchnorm, + _is_valid_scalar, + kron, + diagonal_add_scalar, + symeig, + block_diag, + expand_prior_precision, + normal_samples, + fix_prior_prec_structure, +) from laplace.utils.feature_extractor import FeatureExtractor from laplace.utils.matrix import Kron, KronDecomposed from laplace.utils.swag import fit_diagonal_swag_var -from laplace.utils.subnetmask import SubnetMask, RandomSubnetMask, LargestMagnitudeSubnetMask, LargestVarianceDiagLaplaceSubnetMask, LargestVarianceSWAGSubnetMask, ParamNameSubnetMask, ModuleNameSubnetMask, LastLayerSubnetMask +from laplace.utils.subnetmask import 
( + SubnetMask, + RandomSubnetMask, + LargestMagnitudeSubnetMask, + LargestVarianceDiagLaplaceSubnetMask, + LargestVarianceSWAGSubnetMask, + ParamNameSubnetMask, + ModuleNameSubnetMask, + LastLayerSubnetMask, +) from laplace.utils.metrics import RunningNLLMetric -__all__ = ['get_nll', 'validate', 'parameters_per_layer', 'invsqrt_precision', 'kron', - 'diagonal_add_scalar', 'symeig', 'block_diag', 'normal_samples', - '_is_batchnorm', '_is_valid_scalar', - 'expand_prior_precision', 'fix_prior_prec_structure', - 'FeatureExtractor', - 'Kron', 'KronDecomposed', - 'fit_diagonal_swag_var', - 'SubnetMask', 'RandomSubnetMask', 'LargestMagnitudeSubnetMask', 'LargestVarianceDiagLaplaceSubnetMask', - 'LargestVarianceSWAGSubnetMask', 'ParamNameSubnetMask', 'ModuleNameSubnetMask', 'LastLayerSubnetMask', 'RunningNLLMetric'] +__all__ = [ + 'get_nll', + 'validate', + 'parameters_per_layer', + 'invsqrt_precision', + 'kron', + 'diagonal_add_scalar', + 'symeig', + 'block_diag', + 'normal_samples', + '_is_batchnorm', + '_is_valid_scalar', + 'expand_prior_precision', + 'fix_prior_prec_structure', + 'FeatureExtractor', + 'Kron', + 'KronDecomposed', + 'fit_diagonal_swag_var', + 'SubnetMask', + 'RandomSubnetMask', + 'LargestMagnitudeSubnetMask', + 'LargestVarianceDiagLaplaceSubnetMask', + 'LargestVarianceSWAGSubnetMask', + 'ParamNameSubnetMask', + 'ModuleNameSubnetMask', + 'LastLayerSubnetMask', + 'RunningNLLMetric', +] diff --git a/laplace/utils/feature_extractor.py b/laplace/utils/feature_extractor.py index 95a176fc..e93c9118 100644 --- a/laplace/utils/feature_extractor.py +++ b/laplace/utils/feature_extractor.py @@ -24,9 +24,13 @@ class FeatureExtractor(nn.Module): if the name of the last layer is already known, otherwise it will be determined automatically. """ + def __init__( - self, model: nn.Module, last_layer_name: Optional[str] = None, - enable_backprop: bool = False) -> None: + self, + model: nn.Module, + last_layer_name: Optional[str] = None, + enable_backprop: bool = False, + ) -> None: super().__init__() self.model = model self._features = dict() @@ -54,7 +58,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: out = self.model(x) return out - def forward_with_features(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + def forward_with_features( + self, x: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: """Forward pass which returns the output of the penultimate layer along with the output of the last layer. If the last layer is not known yet, it will be determined when this function is called for the first time. 
@@ -90,9 +96,10 @@ def _get_hook(self, name: str) -> Callable: def hook(_, input, __): # only accepts one input (expects linear layer) self._features[name] = input[0] - + if not self.enable_backprop: self._features[name] = self._features[name].detach() + return hook def find_last_layer(self, x: torch.Tensor) -> torch.Tensor: @@ -112,6 +119,7 @@ def find_last_layer(self, x: torch.Tensor) -> torch.Tensor: raise ValueError('Last layer is already known.') act_out = dict() + def get_act_hook(name): def act_hook(_, input, __): # only accepts one input (expects linear layer) @@ -121,6 +129,7 @@ def act_hook(_, input, __): act_out[name] = None # remove hook handles[name].remove() + return act_hook # set hooks for all modules diff --git a/laplace/utils/metrics.py b/laplace/utils/metrics.py index 6a29d3e7..d327243f 100644 --- a/laplace/utils/metrics.py +++ b/laplace/utils/metrics.py @@ -12,10 +12,13 @@ class RunningNLLMetric(Metric): ignore_index: int, default = -100 which class label to ignore when computing the NLL loss """ + def __init__(self, ignore_index=-100): super().__init__() - self.add_state('nll_sum', default=torch.tensor(0.), dist_reduce_fx='sum') - self.add_state('n_valid_labels', default=torch.tensor(0.), dist_reduce_fx='sum') + self.add_state('nll_sum', default=torch.tensor(0.0), dist_reduce_fx='sum') + self.add_state( + 'n_valid_labels', default=torch.tensor(0.0), dist_reduce_fx='sum' + ) self.ignore_index = ignore_index def update(self, probs: torch.Tensor, targets: torch.Tensor) -> None: diff --git a/laplace/utils/subnetmask.py b/laplace/utils/subnetmask.py index e03f4891..db211cca 100644 --- a/laplace/utils/subnetmask.py +++ b/laplace/utils/subnetmask.py @@ -7,9 +7,16 @@ from laplace.utils import FeatureExtractor, fit_diagonal_swag_var -__all__ = ['SubnetMask', 'RandomSubnetMask', 'LargestMagnitudeSubnetMask', - 'LargestVarianceDiagLaplaceSubnetMask', 'LargestVarianceSWAGSubnetMask', - 'ParamNameSubnetMask', 'ModuleNameSubnetMask', 'LastLayerSubnetMask'] +__all__ = [ + 'SubnetMask', + 'RandomSubnetMask', + 'LargestMagnitudeSubnetMask', + 'LargestVarianceDiagLaplaceSubnetMask', + 'LargestVarianceSWAGSubnetMask', + 'ParamNameSubnetMask', + 'ModuleNameSubnetMask', + 'LastLayerSubnetMask', +] class SubnetMask: @@ -19,6 +26,7 @@ class SubnetMask: ---------- model : torch.nn.Module """ + def __init__(self, model): self.model = model self.parameter_vector = parameters_to_vector(self.model.parameters()).detach() @@ -65,22 +73,38 @@ def convert_subnet_mask_to_indices(self, subnet_mask): """ if not isinstance(subnet_mask, torch.Tensor): raise ValueError('Subnetwork mask needs to be torch.Tensor!') - elif subnet_mask.dtype not in [torch.int64, torch.int32, torch.int16, torch.int8, - torch.uint8, torch.bool] or len(subnet_mask.shape) != 1: + elif ( + subnet_mask.dtype + not in [ + torch.int64, + torch.int32, + torch.int16, + torch.int8, + torch.uint8, + torch.bool, + ] + or len(subnet_mask.shape) != 1 + ): + raise ValueError( + 'Subnetwork mask needs to be 1-dimensional integral or boolean tensor!' 
+ ) + elif ( + len(subnet_mask) != self._n_params + or len(subnet_mask[subnet_mask == 0]) + len(subnet_mask[subnet_mask == 1]) + != self._n_params + ): raise ValueError( - 'Subnetwork mask needs to be 1-dimensional integral or boolean tensor!') - elif (len(subnet_mask) != self._n_params or len(subnet_mask[subnet_mask == 0]) - + len(subnet_mask[subnet_mask == 1]) != self._n_params): - raise ValueError('Subnetwork mask needs to be a binary vector of' - 'size (n_params) where 1s locate the subnetwork' - 'parameters within the vectorized model parameters' - '(i.e. `torch.nn.utils.parameters_to_vector(model.parameters())`)!') + 'Subnetwork mask needs to be a binary vector of' + 'size (n_params) where 1s locate the subnetwork' + 'parameters within the vectorized model parameters' + '(i.e. `torch.nn.utils.parameters_to_vector(model.parameters())`)!' + ) subnet_mask_indices = subnet_mask.nonzero(as_tuple=True)[0] return subnet_mask_indices def select(self, train_loader=None): - """ Select the subnetwork mask. + """Select the subnetwork mask. Parameters ---------- @@ -103,7 +127,7 @@ def select(self, train_loader=None): return self._indices def get_subnet_mask(self, train_loader): - """ Get the subnetwork mask. + """Get the subnetwork mask. Parameters ---------- @@ -131,15 +155,18 @@ class ScoreBasedSubnetMask(SubnetMask): n_params_subnet : int number of parameters in the subnetwork (i.e. number of top-scoring parameters to select) """ + def __init__(self, model, n_params_subnet): super().__init__(model) if n_params_subnet is None: raise ValueError( - 'Need to pass number of subnetwork parameters when using subnetwork Laplace.') + 'Need to pass number of subnetwork parameters when using subnetwork Laplace.' + ) if n_params_subnet > self._n_params: raise ValueError( - f'Subnetwork ({n_params_subnet}) cannot be larger than model ({self._n_params}).') + f'Subnetwork ({n_params_subnet}) cannot be larger than model ({self._n_params}).' + ) self._n_params_subnet = n_params_subnet self._param_scores = None @@ -148,16 +175,20 @@ def compute_param_scores(self, train_loader): def _check_param_scores(self): if self._param_scores.shape != self.parameter_vector.shape: - raise ValueError('Parameter scores need to be of same shape as parameter vector.') + raise ValueError( + 'Parameter scores need to be of same shape as parameter vector.' + ) def get_subnet_mask(self, train_loader): - """ Get the subnetwork mask by (descendingly) ranking parameters based on their scores.""" + """Get the subnetwork mask by (descendingly) ranking parameters based on their scores.""" if self._param_scores is None: self._param_scores = self.compute_param_scores(train_loader) self._check_param_scores() - idx = torch.argsort(self._param_scores, descending=True)[:self._n_params_subnet] + idx = torch.argsort(self._param_scores, descending=True)[ + : self._n_params_subnet + ] idx = idx.sort()[0] subnet_mask = torch.zeros_like(self.parameter_vector).bool() subnet_mask[idx] = 1 @@ -166,12 +197,14 @@ def get_subnet_mask(self, train_loader): class RandomSubnetMask(ScoreBasedSubnetMask): """Subnetwork mask of parameters sampled uniformly at random.""" + def compute_param_scores(self, train_loader): return torch.rand_like(self.parameter_vector) class LargestMagnitudeSubnetMask(ScoreBasedSubnetMask): - """Subnetwork mask identifying the parameters with the largest magnitude. 
""" + """Subnetwork mask identifying the parameters with the largest magnitude.""" + def compute_param_scores(self, train_loader): return self.parameter_vector.abs() @@ -188,6 +221,7 @@ class LargestVarianceDiagLaplaceSubnetMask(ScoreBasedSubnetMask): diag_laplace_model : `laplace.baselaplace.DiagLaplace` diagonal Laplace model to use for variance estimation """ + def __init__(self, model, n_params_subnet, diag_laplace_model): super().__init__(model, n_params_subnet) self.diag_laplace_model = diag_laplace_model @@ -218,8 +252,16 @@ class LargestVarianceSWAGSubnetMask(ScoreBasedSubnetMask): swag_lr : float learning rate for SWAG snapshot collection """ - def __init__(self, model, n_params_subnet, likelihood='classification', - swag_n_snapshots=40, swag_snapshot_freq=1, swag_lr=0.01): + + def __init__( + self, + model, + n_params_subnet, + likelihood='classification', + swag_n_snapshots=40, + swag_snapshot_freq=1, + swag_lr=0.01, + ): super().__init__(model, n_params_subnet) self.likelihood = likelihood self.swag_n_snapshots = swag_n_snapshots @@ -234,10 +276,14 @@ def compute_param_scores(self, train_loader): criterion = CrossEntropyLoss(reduction='mean') elif self.likelihood == 'regression': criterion = MSELoss(reduction='mean') - param_variances = fit_diagonal_swag_var(self.model, train_loader, criterion, - n_snapshots_total=self.swag_n_snapshots, - snapshot_freq=self.swag_snapshot_freq, - lr=self.swag_lr) + param_variances = fit_diagonal_swag_var( + self.model, + train_loader, + criterion, + n_snapshots_total=self.swag_n_snapshots, + snapshot_freq=self.swag_snapshot_freq, + lr=self.swag_lr, + ) return param_variances @@ -251,6 +297,7 @@ class ParamNameSubnetMask(SubnetMask): list of names of the parameters (as in `model.named_parameters()`) that define the subnetwork """ + def __init__(self, model, parameter_names): super().__init__(model) self._parameter_names = parameter_names @@ -268,7 +315,7 @@ def _check_param_names(self): raise ValueError(f'Parameters {param_names} do not exist in model.') def get_subnet_mask(self, train_loader): - """ Get the subnetwork mask identifying the specified parameters.""" + """Get the subnetwork mask identifying the specified parameters.""" self._check_param_names() @@ -293,6 +340,7 @@ class ModuleNameSubnetMask(SubnetMask): list of names of the modules (as in `model.named_modules()`) that define the subnetwork; the modules cannot have children, i.e. need to be leaf modules """ + def __init__(self, model, module_names): super().__init__(model) self._module_names = module_names @@ -306,7 +354,9 @@ def _check_module_names(self): for name, module in self.model.named_modules(): if name in module_names: if len(list(module.children())) > 0: - raise ValueError(f'Module "{name}" has children, which is not supported.') + raise ValueError( + f'Module "{name}" has children, which is not supported.' 
+ ) elif len(list(module.parameters())) == 0: raise ValueError(f'Module "{name}" does not have any parameters.') else: @@ -315,7 +365,7 @@ def _check_module_names(self): raise ValueError(f'Modules {module_names} do not exist in model.') def get_subnet_mask(self, train_loader): - """ Get the subnetwork mask identifying the specified modules.""" + """Get the subnetwork mask identifying the specified modules.""" self._check_module_names() @@ -327,7 +377,9 @@ def get_subnet_mask(self, train_loader): mask_method = torch.ones_like else: mask_method = torch.zeros_like - subnet_mask_list.append(mask_method(parameters_to_vector(module.parameters()))) + subnet_mask_list.append( + mask_method(parameters_to_vector(module.parameters())) + ) subnet_mask = torch.cat(subnet_mask_list).bool() return subnet_mask @@ -341,13 +393,16 @@ class LastLayerSubnetMask(ModuleNameSubnetMask): last_layer_name: str, default=None name of the model's last layer, if None it will be determined automatically """ + def __init__(self, model, last_layer_name=None): super().__init__(model, None) - self._feature_extractor = FeatureExtractor(self.model, last_layer_name=last_layer_name) + self._feature_extractor = FeatureExtractor( + self.model, last_layer_name=last_layer_name + ) self._n_params_subnet = None def get_subnet_mask(self, train_loader): - """ Get the subnetwork mask identifying the last layer.""" + """Get the subnetwork mask identifying the last layer.""" if train_loader is None: raise ValueError('Need to pass train loader for subnet selection.') diff --git a/laplace/utils/swag.py b/laplace/utils/swag.py index a6aba701..ade02a6e 100644 --- a/laplace/utils/swag.py +++ b/laplace/utils/swag.py @@ -11,20 +11,29 @@ def _param_vector(model): return parameters_to_vector(model.parameters()).detach() -def fit_diagonal_swag_var(model, train_loader, criterion, n_snapshots_total=40, snapshot_freq=1, - lr=0.01, momentum=0.9, weight_decay=3e-4, min_var=1e-30): +def fit_diagonal_swag_var( + model, + train_loader, + criterion, + n_snapshots_total=40, + snapshot_freq=1, + lr=0.01, + momentum=0.9, + weight_decay=3e-4, + min_var=1e-30, +): """ Fit diagonal SWAG [1], which estimates marginal variances of model parameters by computing the first and second moment of SGD iterates with a large learning rate. - + Implementation partly adapted from: - https://github.com/wjmaddox/swa_gaussian/blob/master/swag/posteriors/swag.py - https://github.com/wjmaddox/swa_gaussian/blob/master/experiments/train/run_swag.py References ---------- - [1] Maddox, W., Garipov, T., Izmailov, P., Vetrov, D., Wilson, AG. - [*A Simple Baseline for Bayesian Uncertainty in Deep Learning*](https://arxiv.org/abs/1902.02476). + [1] Maddox, W., Garipov, T., Izmailov, P., Vetrov, D., Wilson, AG. + [*A Simple Baseline for Bayesian Uncertainty in Deep Learning*](https://arxiv.org/abs/1902.02476). NeurIPS 2019. 
Parameters @@ -65,7 +74,8 @@ def fit_diagonal_swag_var(model, train_loader, criterion, n_snapshots_total=40, # run SGD to collect model snapshots optimizer = torch.optim.SGD( - _model.parameters(), lr=lr, momentum=momentum, weight_decay=weight_decay) + _model.parameters(), lr=lr, momentum=momentum, weight_decay=weight_decay + ) n_epochs = snapshot_freq * n_snapshots_total for epoch in range(n_epochs): for inputs, targets in train_loader: @@ -83,5 +93,5 @@ def fit_diagonal_swag_var(model, train_loader, criterion, n_snapshots_total=40, n_snapshots += 1 # compute marginal parameter variances, Var[P] = E[P^2] - E[P]^2 - param_variances = torch.clamp(sq_mean - mean ** 2, min_var) + param_variances = torch.clamp(sq_mean - mean**2, min_var) return param_variances diff --git a/laplace/utils/utils.py b/laplace/utils/utils.py index 315784fb..5fe4d749 100644 --- a/laplace/utils/utils.py +++ b/laplace/utils/utils.py @@ -11,8 +11,17 @@ import math -__all__ = ['get_nll', 'validate', 'parameters_per_layer', 'invsqrt_precision', 'kron', - 'diagonal_add_scalar', 'symeig', 'block_diag', 'expand_prior_precision'] +__all__ = [ + 'get_nll', + 'validate', + 'parameters_per_layer', + 'invsqrt_precision', + 'kron', + 'diagonal_add_scalar', + 'symeig', + 'block_diag', + 'expand_prior_precision', +] def get_nll(out_dist, targets): @@ -20,7 +29,15 @@ def get_nll(out_dist, targets): @torch.no_grad() -def validate(laplace, val_loader, loss, pred_type='glm', link_approx='probit', n_samples=100, loss_with_var=False) -> float: +def validate( + laplace, + val_loader, + loss, + pred_type='glm', + link_approx='probit', + n_samples=100, + loss_with_var=False, +) -> float: laplace.model.eval() assert callable(loss) or isinstance(loss, Metric) is_offline = not isinstance(loss, Metric) @@ -38,9 +55,8 @@ def validate(laplace, val_loader, loss, pred_type='glm', link_approx='probit', n X = X.to(laplace._device) y = y.to(laplace._device) out = laplace( - X, pred_type=pred_type, - link_approx=link_approx, - n_samples=n_samples) + X, pred_type=pred_type, link_approx=link_approx, n_samples=n_samples + ) if type(out) == tuple: if is_offline: @@ -97,9 +113,11 @@ def invsqrt_precision(M): def _is_batchnorm(module): - if isinstance(module, BatchNorm1d) or \ - isinstance(module, BatchNorm2d) or \ - isinstance(module, BatchNorm3d): + if ( + isinstance(module, BatchNorm1d) + or isinstance(module, BatchNorm2d) + or isinstance(module, BatchNorm3d) + ): return True return False @@ -134,9 +152,9 @@ def kron(t1, t2): tiled_t2 = t2.repeat(t1_height, t1_width) expanded_t1 = ( t1.unsqueeze(2) - .unsqueeze(3) - .repeat(1, t2_height, t2_width, 1) - .view(out_height, out_width) + .unsqueeze(3) + .repeat(1, t2_height, t2_width, 1) + .view(out_height, out_width) ) return expanded_t1 * tiled_t2 @@ -185,7 +203,7 @@ def symeig(M): M = M + torch.eye(M.shape[0], device=M.device) try: L, W = torch.linalg.eigh(M, UPLO='U') - L -= 1. 
+ L -= 1.0 except RuntimeError: stats = f'diag: {M.diagonal()}, max: {M.abs().max()}, ' stats = stats + f'min: {M.abs().min()}, mean: {M.abs().mean()}' @@ -214,7 +232,7 @@ def block_diag(blocks): p_cur = 0 for block in blocks: p_block = block.shape[0] - M[p_cur:p_cur+p_block, p_cur:p_cur+p_block] = block + M[p_cur : p_cur + p_block, p_cur : p_cur + p_block] = block p_cur += p_block return M @@ -243,11 +261,17 @@ def expand_prior_precision(prior_prec, model): elif len(prior_prec) == P: # full diagonal return prior_prec.to(device) else: - return torch.cat([delta * torch.ones_like(m).flatten() for delta, m - in zip(prior_prec, trainable_params)]) + return torch.cat( + [ + delta * torch.ones_like(m).flatten() + for delta, m in zip(prior_prec, trainable_params) + ] + ) -def fix_prior_prec_structure(prior_prec_init, prior_structure, n_layers, n_params, device): +def fix_prior_prec_structure( + prior_prec_init, prior_structure, n_layers, n_params, device +): if prior_structure == 'scalar': prior_prec_init = torch.full((1,), prior_prec_init, device=device) elif prior_structure == 'layerwise': @@ -275,8 +299,12 @@ def normal_samples(mean, var, n_samples, generator=None): """ assert mean.ndim == 2, 'Invalid input shape of mean, should be 2-dimensional.' _, output_dim = mean.shape - randn_samples = torch.randn((output_dim, n_samples), device=mean.device, - dtype=mean.dtype, generator=generator) + randn_samples = torch.randn( + (output_dim, n_samples), + device=mean.device, + dtype=mean.dtype, + generator=generator, + ) if mean.shape == var.shape: # diagonal covariance @@ -285,7 +313,9 @@ def normal_samples(mean, var, n_samples, generator=None): elif mean.shape == var.shape[:2] and var.shape[-1] == mean.shape[1]: # full covariance scale = torch.linalg.cholesky(var) - scaled_samples = torch.matmul(scale, randn_samples.unsqueeze(0)) # expand batch dim + scaled_samples = torch.matmul( + scale, randn_samples.unsqueeze(0) + ) # expand batch dim return (mean.unsqueeze(-1) + scaled_samples).permute((2, 0, 1)) else: raise ValueError('Invalid input shapes.') diff --git a/setup.py b/setup.py index 1abbd068..4356bfc5 100644 --- a/setup.py +++ b/setup.py @@ -1,4 +1,4 @@ import setuptools -if __name__ == "__main__": +if __name__ == '__main__': setuptools.setup() diff --git a/tests/test_curv_backends_backpack.py b/tests/test_curv_backends_backpack.py index cd254a10..fb1d7e70 100644 --- a/tests/test_curv_backends_backpack.py +++ b/tests/test_curv_backends_backpack.py @@ -226,7 +226,7 @@ def test_kron_normalization_reg(reg_Xy, model): loss_true = 7 * loss X = torch.repeat_interleave(xi, 7, 0) y = torch.repeat_interleave(yi, 7, 0) - loss_test, kron_test = backend.kron(X, y, N=7) + loss_test, kron_test = backend.kron(X, y, N=7) assert torch.allclose(kron_true.diag(), kron_test.diag()) assert torch.allclose(loss_true, loss_test) @@ -240,6 +240,6 @@ def test_kron_normalization_class(class_Xy, model): loss_true = 7 * loss X = torch.repeat_interleave(xi, 7, 0) y = torch.repeat_interleave(yi, 7, 0) - loss_test, kron_test = backend.kron(X, y, N=7) + loss_test, kron_test = backend.kron(X, y, N=7) assert torch.allclose(kron_true.diag(), kron_test.diag()) assert torch.allclose(loss_true, loss_test) diff --git a/tests/test_feature_extractor.py b/tests/test_feature_extractor.py index d3b95ad5..4765955e 100644 --- a/tests/test_feature_extractor.py +++ b/tests/test_feature_extractor.py @@ -11,42 +11,31 @@ def __init__(self, num_classes): self.conv1 = nn.Sequential( # Input shape (3, 64, 64) nn.Conv2d( - in_channels=3, - 
out_channels=6, - kernel_size=5, - stride=1, - padding=2 + in_channels=3, out_channels=6, kernel_size=5, stride=1, padding=2 ), # Output shape (6, 60, 60) nn.ReLU(), # Output shape (6, 30, 30) - nn.MaxPool2d(kernel_size=2) + nn.MaxPool2d(kernel_size=2), ) self.fc = nn.Sequential( - nn.Linear(in_features=16 * 16 * 16, - out_features=300), + nn.Linear(in_features=16 * 16 * 16, out_features=300), nn.ReLU(), - nn.Linear(in_features=300, - out_features=84), + nn.Linear(in_features=300, out_features=84), nn.ReLU(), - nn.Linear(in_features=84, - out_features=num_classes) + nn.Linear(in_features=84, out_features=num_classes), ) self.conv2 = nn.Sequential( # Input shape (6, 30, 30) nn.Conv2d( - in_channels=6, - out_channels=16, - kernel_size=5, - stride=1, - padding=2 + in_channels=6, out_channels=16, kernel_size=5, stride=1, padding=2 ), # Output shape (16, 26, 26) nn.ReLU(), # Output shape (16, 13, 13) - nn.MaxPool2d(kernel_size=2) + nn.MaxPool2d(kernel_size=2), ) def forward(self, x): @@ -94,9 +83,9 @@ def get_model(model_name): nn.Conv2d(3, 6, 3, 1, 1), nn.ReLU(), nn.Flatten(), - nn.Linear(6*64*64, 10), + nn.Linear(6 * 64 * 64, 10), nn.ReLU(), - nn.Linear(10, 10) + nn.Linear(10, 10), ) else: raise ValueError(f'{model_name} is not supported.') @@ -107,10 +96,23 @@ def get_model(model_name): def test_feature_extractor(): # all torchvision classifcation models but 'squeezenet' (no linear last layer) # + model where modules are initilaized in wrong order + nn.Sequential model - model_names = ['resnet18', 'alexnet', 'vgg16', 'densenet', 'inception', - 'googlenet', 'shufflenet', 'mobilenet_v2', 'mobilenet_v3_large', - 'mobilenet_v3_small', 'resnext50_32x4d', 'wide_resnet50_2', - 'mnasnet', 'switchedCNN', 'sequential'] + model_names = [ + 'resnet18', + 'alexnet', + 'vgg16', + 'densenet', + 'inception', + 'googlenet', + 'shufflenet', + 'mobilenet_v2', + 'mobilenet_v3_large', + 'mobilenet_v3_small', + 'resnext50_32x4d', + 'wide_resnet50_2', + 'mnasnet', + 'switchedCNN', + 'sequential', + ] # to test the last_layer_name argument # last_layer_names = ['fc', 'classifier.6', 'classifier.6', 'classifier', 'fc', diff --git a/tests/test_laplace.py b/tests/test_laplace.py index c40be165..23aa1465 100644 --- a/tests/test_laplace.py +++ b/tests/test_laplace.py @@ -10,10 +10,22 @@ torch.manual_seed(240) torch.set_default_tensor_type(torch.DoubleTensor) -flavors = [FullLaplace, KronLaplace, DiagLaplace, - FullLLLaplace, KronLLLaplace, DiagLLLaplace] -all_keys = [('all', 'full'), ('all', 'kron'), ('all', 'diag'), - ('last_layer', 'full'), ('last_layer', 'kron'), ('last_layer', 'diag')] +flavors = [ + FullLaplace, + KronLaplace, + DiagLaplace, + FullLLLaplace, + KronLLLaplace, + DiagLLLaplace, +] +all_keys = [ + ('all', 'full'), + ('all', 'kron'), + ('all', 'diag'), + ('last_layer', 'full'), + ('last_layer', 'kron'), + ('last_layer', 'diag'), +] @pytest.fixture @@ -42,8 +54,15 @@ def test_opt_keywords(key, model, likelihood='classification'): # test if optional keywords are correctly passed on w, s = key prior_mean = torch.zeros_like(parameters_to_vector(model.parameters())) - lap = Laplace(model, likelihood, subset_of_weights=w, hessian_structure=s, - prior_precision=0.01, prior_mean=prior_mean, temperature=10.) + lap = Laplace( + model, + likelihood, + subset_of_weights=w, + hessian_structure=s, + prior_precision=0.01, + prior_mean=prior_mean, + temperature=10.0, + ) assert torch.allclose(lap.prior_mean, prior_mean) assert lap.prior_precision == 0.01 - assert lap.temperature == 10. 
+ assert lap.temperature == 10.0 diff --git a/tests/test_lllaplace.py b/tests/test_lllaplace.py index 97131a55..0da88dbf 100644 --- a/tests/test_lllaplace.py +++ b/tests/test_lllaplace.py @@ -12,11 +12,13 @@ from laplace.utils import FeatureExtractor from tests.utils import jacobians_naive + @pytest.fixture(autouse=True) def run_around_tests(): torch.set_default_dtype(torch.float64) yield + flavors = [FullLLLaplace, KronLLLaplace, DiagLLLaplace] @@ -103,34 +105,43 @@ def test_laplace_invalid_likelihood(laplace, model): def test_laplace_init_noise(laplace, model): # float sigma_noise = 1.2 - lap = laplace(model, likelihood='regression', sigma_noise=sigma_noise, - last_layer_name='1') + lap = laplace( + model, likelihood='regression', sigma_noise=sigma_noise, last_layer_name='1' + ) # torch.tensor 0-dim sigma_noise = torch.tensor(1.2) - lap = laplace(model, likelihood='regression', sigma_noise=sigma_noise, - last_layer_name='1') + lap = laplace( + model, likelihood='regression', sigma_noise=sigma_noise, last_layer_name='1' + ) # torch.tensor 1-dim sigma_noise = torch.tensor(1.2).reshape(-1) - lap = laplace(model, likelihood='regression', sigma_noise=sigma_noise, - last_layer_name='1') + lap = laplace( + model, likelihood='regression', sigma_noise=sigma_noise, last_layer_name='1' + ) # for classification should fail sigma_noise = 1.2 with pytest.raises(ValueError): - lap = laplace(model, likelihood='classification', - sigma_noise=sigma_noise, last_layer_name='1') + lap = laplace( + model, + likelihood='classification', + sigma_noise=sigma_noise, + last_layer_name='1', + ) # other than that should fail # higher dim sigma_noise = torch.tensor(1.2).reshape(1, 1) with pytest.raises(ValueError): - lap = laplace(model, likelihood='regression', sigma_noise=sigma_noise, - last_layer_name='1') + lap = laplace( + model, likelihood='regression', sigma_noise=sigma_noise, last_layer_name='1' + ) # other datatype, only reals supported sigma_noise = '1.2' with pytest.raises(ValueError): - lap = laplace(model, likelihood='regression', sigma_noise=sigma_noise, - last_layer_name='1') + lap = laplace( + model, likelihood='regression', sigma_noise=sigma_noise, last_layer_name='1' + ) @pytest.mark.parametrize('laplace', flavors) @@ -141,63 +152,107 @@ def test_laplace_init_precision(laplace, model): setattr(model, 'n_params', len(parameters_to_vector(model_params))) # float precision = 10.6 - lap = laplace(model, likelihood='regression', prior_precision=precision, - last_layer_name='1') + lap = laplace( + model, likelihood='regression', prior_precision=precision, last_layer_name='1' + ) # torch.tensor 0-dim precision = torch.tensor(10.6) - lap = laplace(model, likelihood='regression', prior_precision=precision, - last_layer_name='1') + lap = laplace( + model, likelihood='regression', prior_precision=precision, last_layer_name='1' + ) # torch.tensor 1-dim precision = torch.tensor(10.7).reshape(-1) - lap = laplace(model, likelihood='regression', prior_precision=precision, - last_layer_name='1') + lap = laplace( + model, likelihood='regression', prior_precision=precision, last_layer_name='1' + ) # torch.tensor 1-dim param-shape if laplace == KronLLLaplace: # kron only supports per layer with pytest.raises(ValueError): precision = torch.tensor(10.7).reshape(-1).repeat(model.n_params) - lap = laplace(model, likelihood='regression', prior_precision=precision, - last_layer_name='1') + lap = laplace( + model, + likelihood='regression', + prior_precision=precision, + last_layer_name='1', + ) else: precision = 
torch.tensor(10.7).reshape(-1).repeat(model.n_params) - lap = laplace(model, likelihood='regression', prior_precision=precision, - last_layer_name='1') + lap = laplace( + model, + likelihood='regression', + prior_precision=precision, + last_layer_name='1', + ) # torch.tensor 1-dim layer-shape precision = torch.tensor(10.7).reshape(-1).repeat(model.n_layers) - lap = laplace(model, likelihood='regression', prior_precision=precision, - last_layer_name='1') + lap = laplace( + model, likelihood='regression', prior_precision=precision, last_layer_name='1' + ) # other than that should fail # higher dim precision = torch.tensor(10.6).reshape(1, 1) with pytest.raises(ValueError): - lap = laplace(model, likelihood='regression', prior_precision=precision, - last_layer_name='1') + lap = laplace( + model, + likelihood='regression', + prior_precision=precision, + last_layer_name='1', + ) # unmatched dim precision = torch.tensor(10.6).reshape(-1).repeat(17) with pytest.raises(ValueError): - lap = laplace(model, likelihood='regression', prior_precision=precision, - last_layer_name='1') + lap = laplace( + model, + likelihood='regression', + prior_precision=precision, + last_layer_name='1', + ) # other datatype, only reals supported precision = '1.5' with pytest.raises(ValueError): - lap = laplace(model, likelihood='regression', prior_precision=precision, - last_layer_name='1') + lap = laplace( + model, + likelihood='regression', + prior_precision=precision, + last_layer_name='1', + ) @pytest.mark.parametrize('laplace', flavors) def test_laplace_init_prior_mean_and_scatter(laplace, model, class_loader): - lap_scalar_mean = laplace(model, 'classification', last_layer_name='1', - prior_precision=1e-2, prior_mean=1.) - assert torch.allclose(lap_scalar_mean.prior_mean, torch.tensor([1.])) - lap_tensor_mean = laplace(model, 'classification', last_layer_name='1', - prior_precision=1e-2, prior_mean=torch.ones(1)) - assert torch.allclose(lap_tensor_mean.prior_mean, torch.tensor([1.])) - lap_tensor_scalar_mean = laplace(model, 'classification', last_layer_name='1', - prior_precision=1e-2, prior_mean=torch.ones(1)[0]) - assert torch.allclose(lap_tensor_scalar_mean.prior_mean, torch.tensor(1.)) - lap_tensor_full_mean = laplace(model, 'classification', last_layer_name='1', - prior_precision=1e-2, prior_mean=torch.ones(20*2+2)) - assert torch.allclose(lap_tensor_full_mean.prior_mean, torch.ones(20*2+2)) + lap_scalar_mean = laplace( + model, + 'classification', + last_layer_name='1', + prior_precision=1e-2, + prior_mean=1.0, + ) + assert torch.allclose(lap_scalar_mean.prior_mean, torch.tensor([1.0])) + lap_tensor_mean = laplace( + model, + 'classification', + last_layer_name='1', + prior_precision=1e-2, + prior_mean=torch.ones(1), + ) + assert torch.allclose(lap_tensor_mean.prior_mean, torch.tensor([1.0])) + lap_tensor_scalar_mean = laplace( + model, + 'classification', + last_layer_name='1', + prior_precision=1e-2, + prior_mean=torch.ones(1)[0], + ) + assert torch.allclose(lap_tensor_scalar_mean.prior_mean, torch.tensor(1.0)) + lap_tensor_full_mean = laplace( + model, + 'classification', + last_layer_name='1', + prior_precision=1e-2, + prior_mean=torch.ones(20 * 2 + 2), + ) + assert torch.allclose(lap_tensor_full_mean.prior_mean, torch.ones(20 * 2 + 2)) lap_scalar_mean.fit(class_loader) lap_tensor_mean.fit(class_loader) @@ -214,36 +269,54 @@ def test_laplace_init_prior_mean_and_scatter(laplace, model, class_loader): # too many dims with pytest.raises(ValueError): - prior_mean = torch.ones(20*2+2).unsqueeze(-1) - 
laplace(model, 'classification', last_layer_name='1', - prior_precision=1e-2, prior_mean=prior_mean) + prior_mean = torch.ones(20 * 2 + 2).unsqueeze(-1) + laplace( + model, + 'classification', + last_layer_name='1', + prior_precision=1e-2, + prior_mean=prior_mean, + ) # unmatched dim with pytest.raises(ValueError): - prior_mean = torch.ones(20*2-3) - laplace(model, 'classification', last_layer_name='1', - prior_precision=1e-2, prior_mean=prior_mean) + prior_mean = torch.ones(20 * 2 - 3) + laplace( + model, + 'classification', + last_layer_name='1', + prior_precision=1e-2, + prior_mean=prior_mean, + ) # invalid argument type with pytest.raises(ValueError): - laplace(model, 'classification', last_layer_name='1', - prior_precision=1e-2, prior_mean='72') + laplace( + model, + 'classification', + last_layer_name='1', + prior_precision=1e-2, + prior_mean='72', + ) @pytest.mark.parametrize('laplace', flavors) def test_laplace_init_temperature(laplace, model): # valid float T = 1.1 - lap = laplace(model, likelihood='classification', temperature=T, - last_layer_name='1') + lap = laplace( + model, likelihood='classification', temperature=T, last_layer_name='1' + ) assert lap.temperature == T -@pytest.mark.parametrize('laplace,lh', product(flavors, ['classification', 'regression'])) +@pytest.mark.parametrize( + 'laplace,lh', product(flavors, ['classification', 'regression']) +) def test_laplace_functionality(laplace, lh, model, reg_loader, class_loader): if lh == 'classification': loader = class_loader - sigma_noise = 1. + sigma_noise = 1.0 else: loader = reg_loader sigma_noise = 0.3 @@ -281,13 +354,13 @@ def test_laplace_functionality(laplace, lh, model, reg_loader, class_loader): assert torch.allclose(theta, lap.mean) prior_prec = torch.diag(lap.prior_precision_diag) assert prior_prec.shape == torch.Size([len(theta), len(theta)]) - lml = lml - 1/2 * theta @ prior_prec @ theta + lml = lml - 1 / 2 * theta @ prior_prec @ theta Sigma_0 = torch.inverse(prior_prec) if laplace == DiagLLLaplace: log_det_post_prec = lap.posterior_precision.log().sum() else: log_det_post_prec = lap.posterior_precision.logdet() - lml = lml + 1/2 * (prior_prec.logdet() - log_det_post_prec) + lml = lml + 1 / 2 * (prior_prec.logdet() - log_det_post_prec) assert torch.allclose(lml, lap.log_marginal_likelihood()) # test sampling @@ -356,21 +429,31 @@ def test_classification_predictive(laplace, model, class_loader): # GLM predictive f_pred = lap(X, pred_type='glm', link_approx='mc', n_samples=100) assert f_pred.shape == f.shape - assert torch.allclose(f_pred.sum(), torch.tensor(len(f_pred), dtype=torch.double)) # sum up to 1 + assert torch.allclose( + f_pred.sum(), torch.tensor(len(f_pred), dtype=torch.double) + ) # sum up to 1 f_pred = lap(X, pred_type='glm', link_approx='probit') assert f_pred.shape == f.shape - assert torch.allclose(f_pred.sum(), torch.tensor(len(f_pred), dtype=torch.double)) # sum up to 1 + assert torch.allclose( + f_pred.sum(), torch.tensor(len(f_pred), dtype=torch.double) + ) # sum up to 1 f_pred = lap(X, pred_type='glm', link_approx='bridge') assert f_pred.shape == f.shape - assert torch.allclose(f_pred.sum(), torch.tensor(len(f_pred), dtype=torch.double)) # sum up to 1 + assert torch.allclose( + f_pred.sum(), torch.tensor(len(f_pred), dtype=torch.double) + ) # sum up to 1 f_pred = lap(X, pred_type='glm', link_approx='bridge_norm') assert f_pred.shape == f.shape - assert torch.allclose(f_pred.sum(), torch.tensor(len(f_pred), dtype=torch.double)) # sum up to 1 + assert torch.allclose( + f_pred.sum(), 
torch.tensor(len(f_pred), dtype=torch.double) + ) # sum up to 1 # NN predictive f_pred = lap(X, pred_type='nn', link_approx='mc', n_samples=100) assert f_pred.shape == f.shape - assert torch.allclose(f_pred.sum(), torch.tensor(len(f_pred), dtype=torch.double)) # sum up to 1 + assert torch.allclose( + f_pred.sum(), torch.tensor(len(f_pred), dtype=torch.double) + ) # sum up to 1 @pytest.mark.parametrize('laplace', flavors) diff --git a/tests/test_matrix.py b/tests/test_matrix.py index 757130c7..b035c3a3 100644 --- a/tests/test_matrix.py +++ b/tests/test_matrix.py @@ -32,7 +32,7 @@ def small_model(): def test_init_from_model(model): kron = Kron.init_from_model(model, 'cpu') - expected_sizes = [[20*20, 3*3], [20*20], [2*2, 20*20], [2*2]] + expected_sizes = [[20 * 20, 3 * 3], [20 * 20], [2 * 2, 20 * 20], [2 * 2]] for facs, exp_facs in zip(kron.kfacs, expected_sizes): for fi, exp_fi in zip(facs, exp_facs): assert torch.all(fi == 0) @@ -41,7 +41,7 @@ def test_init_from_model(model): def test_init_from_iterable(model): kron = Kron.init_from_model(model.parameters(), 'cpu') - expected_sizes = [[20*20, 3*3], [20*20], [2*2, 20*20], [2*2]] + expected_sizes = [[20 * 20, 3 * 3], [20 * 20], [2 * 2, 20 * 20], [2 * 2]] for facs, exp_facs in zip(kron.kfacs, expected_sizes): for fi, exp_fi in zip(facs, exp_facs): assert torch.all(fi == 0) @@ -81,7 +81,9 @@ def test_decompose(): kfacs = [[get_psd_matrix(i) for i in sizes] for sizes in expected_sizes] kron = Kron(kfacs) kron_decomp = kron.decompose() - for facs, Qs, ls in zip(kron.kfacs, kron_decomp.eigenvectors, kron_decomp.eigenvalues): + for facs, Qs, ls in zip( + kron.kfacs, kron_decomp.eigenvectors, kron_decomp.eigenvalues + ): if len(facs) == 1: H, Q, l = facs[0], Qs[0], ls[0] reconstructed = Q @ torch.diag(l) @ Q.T @@ -100,7 +102,9 @@ def test_decompose(): diag_kfacs = [[get_diag_psd_matrix(i) for i in sizes] for sizes in expected_sizes] kron = Kron(diag_kfacs) kron_decomp = kron.decompose() - for facs, Qs, ls in zip(kron.kfacs, kron_decomp.eigenvectors, kron_decomp.eigenvalues): + for facs, Qs, ls in zip( + kron.kfacs, kron_decomp.eigenvectors, kron_decomp.eigenvalues + ): if len(facs) == 1: H, Q, l = facs[0], Qs[0], ls[0] reconstructed = (Q @ torch.diag(l) @ Q.T).diag() @@ -154,29 +158,29 @@ def test_bmm_dense(small_model): # test J @ Kron_decomp @ Jt (square form) JS = kron_decomp.bmm(Js, exponent=1) JS_true = Js @ S - JSJ_true = torch.bmm(JS_true, Js.transpose(1,2)) - JSJ = torch.bmm(JS, Js.transpose(1,2)) + JSJ_true = torch.bmm(JS_true, Js.transpose(1, 2)) + JSJ = torch.bmm(JS, Js.transpose(1, 2)) assert torch.allclose(JSJ, JSJ_true) assert torch.allclose(JS, JS_true) # test J @ Kron @ Jt (square form) JS_nodecomp = kron.bmm(Js) - JSJ_nodecomp = torch.bmm(JS_nodecomp, Js.transpose(1,2)) + JSJ_nodecomp = torch.bmm(JS_nodecomp, Js.transpose(1, 2)) assert torch.allclose(JSJ_nodecomp, JSJ) assert torch.allclose(JS_nodecomp, JS) # test J @ S_inv @ J (funcitonal variance) JSJ = kron_decomp.inv_square_form(Js) S_inv = S.inverse() - JSJ_true = torch.bmm(Js @ S_inv, Js.transpose(1,2)) + JSJ_true = torch.bmm(Js @ S_inv, Js.transpose(1, 2)) assert torch.allclose(JSJ, JSJ_true) # test J @ S^-1/2 (sampling) - JS = kron_decomp.bmm(Js, exponent=-1/2) - JSJ = torch.bmm(JS, Js.transpose(1,2)) + JS = kron_decomp.bmm(Js, exponent=-1 / 2) + JSJ = torch.bmm(JS, Js.transpose(1, 2)) l, Q = torch.linalg.eigh(S_inv, UPLO='U') JS_true = Js @ Q @ torch.diag(torch.sqrt(l)) @ Q.T - JSJ_true = torch.bmm(JS_true, Js.transpose(1,2)) + JSJ_true = torch.bmm(JS_true, Js.transpose(1, 
2)) assert torch.allclose(JS, JS_true) assert torch.allclose(JSJ, JSJ_true) @@ -222,29 +226,29 @@ def test_bmm_diag(small_model): # test J @ Kron_decomp @ Jt (square form) JS = kron_decomp.bmm(Js, exponent=1) JS_true = Js @ S - JSJ_true = torch.bmm(JS_true, Js.transpose(1,2)) - JSJ = torch.bmm(JS, Js.transpose(1,2)) + JSJ_true = torch.bmm(JS_true, Js.transpose(1, 2)) + JSJ = torch.bmm(JS, Js.transpose(1, 2)) assert torch.allclose(JSJ, JSJ_true) assert torch.allclose(JS, JS_true) # test J @ Kron @ Jt (square form) JS_nodecomp = kron.bmm(Js) - JSJ_nodecomp = torch.bmm(JS_nodecomp, Js.transpose(1,2)) + JSJ_nodecomp = torch.bmm(JS_nodecomp, Js.transpose(1, 2)) assert torch.allclose(JSJ_nodecomp, JSJ) assert torch.allclose(JS_nodecomp, JS) # test J @ S_inv @ J (funcitonal variance) JSJ = kron_decomp.inv_square_form(Js) S_inv = S.inverse() - JSJ_true = torch.bmm(Js @ S_inv, Js.transpose(1,2)) + JSJ_true = torch.bmm(Js @ S_inv, Js.transpose(1, 2)) assert torch.allclose(JSJ, JSJ_true) # test J @ S^-1/2 (sampling) - JS = kron_decomp.bmm(Js, exponent=-1/2) - JSJ = torch.bmm(JS, Js.transpose(1,2)) + JS = kron_decomp.bmm(Js, exponent=-1 / 2) + JSJ = torch.bmm(JS, Js.transpose(1, 2)) l, Q = torch.linalg.eigh(S_inv, UPLO='U') JS_true = Js @ Q @ torch.diag(torch.sqrt(l)) @ Q.T - JSJ_true = torch.bmm(JS_true, Js.transpose(1,2)) + JSJ_true = torch.bmm(JS_true, Js.transpose(1, 2)) assert torch.allclose(JS, JS_true) assert torch.allclose(JSJ, JSJ_true) @@ -271,7 +275,9 @@ def test_matrix_consistent(): kron = Kron(kfacs) kron_decomp = kron.decompose() assert torch.allclose(kron.to_matrix(), kron_decomp.to_matrix(exponent=1)) - assert torch.allclose(kron.to_matrix().inverse(), kron_decomp.to_matrix(exponent=-1)) + assert torch.allclose( + kron.to_matrix().inverse(), kron_decomp.to_matrix(exponent=-1) + ) M_true = kron.to_matrix() M_true.diagonal().add_(3.4) kron_decomp += torch.tensor(3.4) @@ -281,7 +287,9 @@ def test_matrix_consistent(): kron = Kron(diag_kfacs) kron_decomp = kron.decompose() assert torch.allclose(kron.to_matrix(), kron_decomp.to_matrix(exponent=1)) - assert torch.allclose(kron.to_matrix().inverse(), kron_decomp.to_matrix(exponent=-1)) + assert torch.allclose( + kron.to_matrix().inverse(), kron_decomp.to_matrix(exponent=-1) + ) M_true = kron.to_matrix() M_true.diagonal().add_(3.4) kron_decomp += torch.tensor(3.4) diff --git a/tests/test_metrics.py b/tests/test_metrics.py index 2edb72e0..faf3122a 100644 --- a/tests/test_metrics.py +++ b/tests/test_metrics.py @@ -18,7 +18,9 @@ def test_running_nll_metric(): all_probs, all_targets = torch.cat(all_probs, 0), torch.cat(all_targets, 0) nll_running = metric.compute().item() - nll_offline = F.nll_loss(all_probs.log().flatten(end_dim=-2), all_targets.flatten()).item() + nll_offline = F.nll_loss( + all_probs.log().flatten(end_dim=-2), all_targets.flatten() + ).item() assert math.isclose(nll_running, nll_offline, rel_tol=1e-7) @@ -40,7 +42,11 @@ def test_running_nll_metric_ignore_idx(): all_probs, all_targets = torch.cat(all_probs, 0), torch.cat(all_targets, 0) nll_running = metric.compute().item() - nll_offline = F.nll_loss(all_probs.log().flatten(end_dim=-2), all_targets.flatten(), ignore_index=ignore_idx).item() + nll_offline = F.nll_loss( + all_probs.log().flatten(end_dim=-2), + all_targets.flatten(), + ignore_index=ignore_idx, + ).item() print(nll_running, nll_offline) diff --git a/tests/test_serialization.py b/tests/test_serialization.py index 612a433c..2a6e022b 100644 --- a/tests/test_serialization.py +++ b/tests/test_serialization.py @@ -23,11 
+23,20 @@ ) - torch.manual_seed(240) torch.set_default_tensor_type(torch.DoubleTensor) -lrlaplace_param = pytest.param(LowRankLaplace, marks=pytest.mark.xfail(reason='Unimplemented in the new ASDL')) -flavors = [FullLaplace, KronLaplace, DiagLaplace, lrlaplace_param, FullLLLaplace, KronLLLaplace, DiagLLLaplace] +lrlaplace_param = pytest.param( + LowRankLaplace, marks=pytest.mark.xfail(reason='Unimplemented in the new ASDL') +) +flavors = [ + FullLaplace, + KronLaplace, + DiagLaplace, + lrlaplace_param, + FullLLLaplace, + KronLLLaplace, + DiagLLLaplace, +] flavors_no_llla = [FullLaplace, KronLaplace, DiagLaplace, lrlaplace_param] flavors_llla = [FullLLLaplace, KronLLLaplace, DiagLLLaplace] flavors_subnet = [DiagSubnetLaplace, FullSubnetLaplace] @@ -57,10 +66,7 @@ def model2(): @pytest.fixture def model3(): model = torch.nn.Sequential( - OrderedDict([ - ('fc1', nn.Linear(3, 20)), - ('clf', nn.Linear(20, 2)) - ]) + OrderedDict([('fc1', nn.Linear(3, 20)), ('clf', nn.Linear(20, 2))]) ) setattr(model, 'output_size', 2) model_params = list(model.parameters()) @@ -180,20 +186,26 @@ def test_serialize_fail_different_hess_structures(model, reg_loader): la.sigma_noise = 1231 torch.save(la.state_dict(), 'state_dict.bin') - la2 = Laplace(model, 'regression', subset_of_weights='all', hessian_structure='diag') + la2 = Laplace( + model, 'regression', subset_of_weights='all', hessian_structure='diag' + ) with pytest.raises(ValueError): la2.load_state_dict(torch.load('state_dict.bin')) def test_serialize_fail_different_subset_of_weights(model, reg_loader): - la = Laplace(model, 'regression', subset_of_weights='last_layer', hessian_structure='diag') + la = Laplace( + model, 'regression', subset_of_weights='last_layer', hessian_structure='diag' + ) la.fit(reg_loader) la.optimize_prior_precision() la.sigma_noise = 1231 torch.save(la.state_dict(), 'state_dict.bin') - la2 = Laplace(model, 'regression', subset_of_weights='all', hessian_structure='diag') + la2 = Laplace( + model, 'regression', subset_of_weights='all', hessian_structure='diag' + ) with pytest.raises(ValueError): la2.load_state_dict(torch.load('state_dict.bin')) @@ -214,7 +226,9 @@ def test_serialize_fail_different_liks(laplace, model, reg_loader): @pytest.mark.parametrize('laplace', flavors_llla) -def test_serialize_fail_llla_different_last_layer_name(laplace, model, model3, reg_loader): +def test_serialize_fail_llla_different_last_layer_name( + laplace, model, model3, reg_loader +): print([n for n, _ in model.named_parameters()]) la = laplace(model, 'regression', last_layer_name='1') la.fit(reg_loader) @@ -244,15 +258,11 @@ def test_map_location( # AttributeError: Can't pickle local object 'FeatureExtractor._get_hook..hook' if issubclass(laplace, LLLaplace): if find_spec('dill') is None: - pytest.skip( - reason='dill package not found but needed for this test' - ) + pytest.skip(reason='dill package not found but needed for this test') else: import dill - torch_save = lambda obj, fn: torch.save( - obj, fn, pickle_module=dill - ) + torch_save = lambda obj, fn: torch.save(obj, fn, pickle_module=dill) else: # Use default pickle_module=pickle, but no need to import pickle here # just to set pickle_module=pickle. 
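The serialization tests in this file all follow the same state-dict round trip. A minimal sketch of that workflow, assuming `model` and `reg_loader` are the fixtures defined earlier in the file; the restoring instance must be configured with the same likelihood, `subset_of_weights`, and `hessian_structure`, otherwise `load_state_dict` raises a `ValueError`:

import torch
from laplace import Laplace

la = Laplace(model, 'regression', subset_of_weights='all', hessian_structure='diag')
la.fit(reg_loader)
la.optimize_prior_precision()
torch.save(la.state_dict(), 'state_dict.bin')

# Restore into an identically configured Laplace object.
la2 = Laplace(model, 'regression', subset_of_weights='all', hessian_structure='diag')
la2.load_state_dict(torch.load('state_dict.bin'))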
@@ -287,12 +297,8 @@ def test_map_location( 'prior_precision', 'sigma_noise', ]: - assert ( - getattr(la, name).device.type == device.type - ), f'la.{name} failed' - assert ( - getattr(la2, name).device.type == map_location - ), f'la2.{name} failed' + assert getattr(la, name).device.type == device.type, f'la.{name} failed' + assert getattr(la2, name).device.type == map_location, f'la2.{name} failed' # Test tensor attrs. for name, obj in vars(la).items(): @@ -306,9 +312,7 @@ def test_map_location( continue if isinstance(obj, torch.Tensor): assert obj.device.type == device.type, f'la.{name} failed' - assert ( - getattr(la2, name).device.type == map_location - ), f'la2.{name} failed' + assert getattr(la2, name).device.type == map_location, f'la2.{name} failed' assert la.sigma_noise == la2.sigma_noise X, _ = next(iter(reg_loader)) diff --git a/tests/test_subnetlaplace.py b/tests/test_subnetlaplace.py index 92c888b3..cfb1ed2f 100644 --- a/tests/test_subnetlaplace.py +++ b/tests/test_subnetlaplace.py @@ -9,15 +9,26 @@ from laplace import Laplace, SubnetLaplace, FullSubnetLaplace, DiagSubnetLaplace from laplace.baselaplace import DiagLaplace -from laplace.utils import (SubnetMask, RandomSubnetMask, LargestMagnitudeSubnetMask, - LargestVarianceDiagLaplaceSubnetMask, LargestVarianceSWAGSubnetMask, - ParamNameSubnetMask, ModuleNameSubnetMask, LastLayerSubnetMask) +from laplace.utils import ( + SubnetMask, + RandomSubnetMask, + LargestMagnitudeSubnetMask, + LargestVarianceDiagLaplaceSubnetMask, + LargestVarianceSWAGSubnetMask, + ParamNameSubnetMask, + ModuleNameSubnetMask, + LastLayerSubnetMask, +) torch.manual_seed(240) torch.set_default_tensor_type(torch.DoubleTensor) -score_based_subnet_masks = [RandomSubnetMask, LargestMagnitudeSubnetMask, - LargestVarianceDiagLaplaceSubnetMask, LargestVarianceSWAGSubnetMask] +score_based_subnet_masks = [ + RandomSubnetMask, + LargestMagnitudeSubnetMask, + LargestVarianceDiagLaplaceSubnetMask, + LargestVarianceSWAGSubnetMask, +] layer_subnet_masks = [ParamNameSubnetMask, ModuleNameSubnetMask, LastLayerSubnetMask] all_subnet_masks = score_based_subnet_masks + layer_subnet_masks likelihoods = ['classification', 'regression'] @@ -62,33 +73,59 @@ def test_subnet_laplace_init(model, likelihood): # subnet Laplace with full Hessian should work hessian_structure = 'full' - lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', - subnetwork_indices=subnetmask.indices, hessian_structure=hessian_structure) + lap = Laplace( + model, + likelihood=likelihood, + subset_of_weights='subnetwork', + subnetwork_indices=subnetmask.indices, + hessian_structure=hessian_structure, + ) assert isinstance(lap, FullSubnetLaplace) # subnet Laplace with diagonal Hessian should work hessian_structure = 'diag' - lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', - subnetwork_indices=subnetmask.indices, hessian_structure=hessian_structure) + lap = Laplace( + model, + likelihood=likelihood, + subset_of_weights='subnetwork', + subnetwork_indices=subnetmask.indices, + hessian_structure=hessian_structure, + ) assert isinstance(lap, DiagSubnetLaplace) # subnet Laplace without specifying subnetwork indices should raise an error with pytest.raises(TypeError): - lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', - hessian_structure=hessian_structure) + lap = Laplace( + model, + likelihood=likelihood, + subset_of_weights='subnetwork', + hessian_structure=hessian_structure, + ) # subnet Laplace with kron or lowrank Hessians should 
raise errors hessian_structure = 'kron' with pytest.raises(ValueError): - lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', - subnetwork_indices=subnetmask.indices, hessian_structure=hessian_structure) + lap = Laplace( + model, + likelihood=likelihood, + subset_of_weights='subnetwork', + subnetwork_indices=subnetmask.indices, + hessian_structure=hessian_structure, + ) hessian_structure = 'lowrank' with pytest.raises(ValueError): - lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', - subnetwork_indices=subnetmask.indices, hessian_structure=hessian_structure) - - -@pytest.mark.parametrize('likelihood,hessian_structure', product(likelihoods, hessian_structures)) + lap = Laplace( + model, + likelihood=likelihood, + subset_of_weights='subnetwork', + subnetwork_indices=subnetmask.indices, + hessian_structure=hessian_structure, + ) + + +@pytest.mark.parametrize( + 'likelihood,hessian_structure', product(likelihoods, hessian_structures) +) def test_subnet_laplace_large_init(large_model, likelihood, hessian_structure): # use random subnet mask for this test subnetwork_mask = RandomSubnetMask @@ -97,8 +134,13 @@ def test_subnet_laplace_large_init(large_model, likelihood, hessian_structure): subnetmask = subnetwork_mask(**subnetmask_kwargs) subnetmask.select() - lap = Laplace(large_model, likelihood=likelihood, subset_of_weights='subnetwork', - subnetwork_indices=subnetmask.indices, hessian_structure=hessian_structure) + lap = Laplace( + large_model, + likelihood=likelihood, + subset_of_weights='subnetwork', + subnetwork_indices=subnetmask.indices, + hessian_structure=hessian_structure, + ) assert lap.n_params_subnet == n_param_subnet if hessian_structure == 'full': assert lap.H.shape == (lap.n_params_subnet, lap.n_params_subnet) @@ -109,104 +151,188 @@ def test_subnet_laplace_large_init(large_model, likelihood, hessian_structure): assert torch.allclose(H, lap.H) -@pytest.mark.parametrize('likelihood,hessian_structure', product(likelihoods, hessian_structures)) -def test_custom_subnetwork_indices(model, likelihood, class_loader, reg_loader, hessian_structure): +@pytest.mark.parametrize( + 'likelihood,hessian_structure', product(likelihoods, hessian_structures) +) +def test_custom_subnetwork_indices( + model, likelihood, class_loader, reg_loader, hessian_structure +): loader = class_loader if likelihood == 'classification' else reg_loader # subnetwork indices that are None should raise an error subnetwork_indices = None with pytest.raises(ValueError): - lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', - subnetwork_indices=subnetwork_indices, hessian_structure=hessian_structure) + lap = Laplace( + model, + likelihood=likelihood, + subset_of_weights='subnetwork', + subnetwork_indices=subnetwork_indices, + hessian_structure=hessian_structure, + ) # subnetwork indices that are not PyTorch tensors should raise an error subnetwork_indices = [0, 5, 11, 42] with pytest.raises(ValueError): - lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', - subnetwork_indices=subnetwork_indices, hessian_structure=hessian_structure) + lap = Laplace( + model, + likelihood=likelihood, + subset_of_weights='subnetwork', + subnetwork_indices=subnetwork_indices, + hessian_structure=hessian_structure, + ) # subnetwork indices that are empty tensors should raise an error subnetwork_indices = torch.LongTensor([]) with pytest.raises(ValueError): - lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', - 
subnetwork_indices=subnetwork_indices, hessian_structure=hessian_structure) + lap = Laplace( + model, + likelihood=likelihood, + subset_of_weights='subnetwork', + subnetwork_indices=subnetwork_indices, + hessian_structure=hessian_structure, + ) # subnetwork indices that are scalar tensors should raise an error subnetwork_indices = torch.LongTensor(11) with pytest.raises(ValueError): - lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', - subnetwork_indices=subnetwork_indices, hessian_structure=hessian_structure) + lap = Laplace( + model, + likelihood=likelihood, + subset_of_weights='subnetwork', + subnetwork_indices=subnetwork_indices, + hessian_structure=hessian_structure, + ) # subnetwork indices that are not 1D PyTorch tensors should raise an error subnetwork_indices = torch.LongTensor([[0, 5], [11, 42]]) with pytest.raises(ValueError): - lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', - subnetwork_indices=subnetwork_indices, hessian_structure=hessian_structure) + lap = Laplace( + model, + likelihood=likelihood, + subset_of_weights='subnetwork', + subnetwork_indices=subnetwork_indices, + hessian_structure=hessian_structure, + ) # subnetwork indices that are double tensors should raise an error subnetwork_indices = torch.DoubleTensor([0.0, 5.0, 11.0, 42.0]) with pytest.raises(ValueError): - lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', - subnetwork_indices=subnetwork_indices, hessian_structure=hessian_structure) + lap = Laplace( + model, + likelihood=likelihood, + subset_of_weights='subnetwork', + subnetwork_indices=subnetwork_indices, + hessian_structure=hessian_structure, + ) # subnetwork indices that are float tensors should raise an error subnetwork_indices = torch.FloatTensor([0.0, 5.0, 11.0, 42.0]) with pytest.raises(ValueError): - lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', - subnetwork_indices=subnetwork_indices, hessian_structure=hessian_structure) + lap = Laplace( + model, + likelihood=likelihood, + subset_of_weights='subnetwork', + subnetwork_indices=subnetwork_indices, + hessian_structure=hessian_structure, + ) # subnetwork indices that are half tensors should raise an error subnetwork_indices = torch.HalfTensor([0.0, 5.0, 11.0, 42.0]) with pytest.raises(ValueError): - lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', - subnetwork_indices=subnetwork_indices, hessian_structure=hessian_structure) + lap = Laplace( + model, + likelihood=likelihood, + subset_of_weights='subnetwork', + subnetwork_indices=subnetwork_indices, + hessian_structure=hessian_structure, + ) # subnetwork indices that are int tensors should raise an error subnetwork_indices = torch.IntTensor([0, 5, 11, 42]) with pytest.raises(ValueError): - lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', - subnetwork_indices=subnetwork_indices, hessian_structure=hessian_structure) + lap = Laplace( + model, + likelihood=likelihood, + subset_of_weights='subnetwork', + subnetwork_indices=subnetwork_indices, + hessian_structure=hessian_structure, + ) # subnetwork indices that are short tensors should raise an error subnetwork_indices = torch.ShortTensor([0, 5, 11, 42]) with pytest.raises(ValueError): - lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', - subnetwork_indices=subnetwork_indices, hessian_structure=hessian_structure) + lap = Laplace( + model, + likelihood=likelihood, + subset_of_weights='subnetwork', + 
subnetwork_indices=subnetwork_indices, + hessian_structure=hessian_structure, + ) # subnetwork indices that are char tensors should raise an error subnetwork_indices = torch.CharTensor([0, 5, 11, 42]) with pytest.raises(ValueError): - lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', - subnetwork_indices=subnetwork_indices, hessian_structure=hessian_structure) + lap = Laplace( + model, + likelihood=likelihood, + subset_of_weights='subnetwork', + subnetwork_indices=subnetwork_indices, + hessian_structure=hessian_structure, + ) # subnetwork indices that are bool tensors should raise an error subnetwork_indices = torch.BoolTensor([0, 5, 11, 42]) with pytest.raises(ValueError): - lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', - subnetwork_indices=subnetwork_indices, hessian_structure=hessian_structure) + lap = Laplace( + model, + likelihood=likelihood, + subset_of_weights='subnetwork', + subnetwork_indices=subnetwork_indices, + hessian_structure=hessian_structure, + ) # subnetwork indices that contain elements smaller than zero should raise an error subnetwork_indices = torch.LongTensor([0, -1, -11]) with pytest.raises(ValueError): - lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', - subnetwork_indices=subnetwork_indices, hessian_structure=hessian_structure) + lap = Laplace( + model, + likelihood=likelihood, + subset_of_weights='subnetwork', + subnetwork_indices=subnetwork_indices, + hessian_structure=hessian_structure, + ) # subnetwork indices that contain elements larger than n_params should raise an error subnetwork_indices = torch.LongTensor([model.n_params + 1, model.n_params + 42]) with pytest.raises(ValueError): - lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', - subnetwork_indices=subnetwork_indices, hessian_structure=hessian_structure) + lap = Laplace( + model, + likelihood=likelihood, + subset_of_weights='subnetwork', + subnetwork_indices=subnetwork_indices, + hessian_structure=hessian_structure, + ) # subnetwork indices that contain duplicate entries should raise an error subnetwork_indices = torch.LongTensor([0, 0, 5, 11, 11, 42]) with pytest.raises(ValueError): - lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', - subnetwork_indices=subnetwork_indices, hessian_structure=hessian_structure) + lap = Laplace( + model, + likelihood=likelihood, + subset_of_weights='subnetwork', + subnetwork_indices=subnetwork_indices, + hessian_structure=hessian_structure, + ) # Non-empty, 1-dimensional torch.LongTensor with valid entries should work subnetwork_indices = torch.LongTensor([0, 5, 11, 42]) - lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', - subnetwork_indices=subnetwork_indices, hessian_structure=hessian_structure) + lap = Laplace( + model, + likelihood=likelihood, + subset_of_weights='subnetwork', + subnetwork_indices=subnetwork_indices, + hessian_structure=hessian_structure, + ) lap.fit(loader) assert isinstance(lap, SubnetLaplace) assert lap.n_params_subnet == 4 @@ -217,9 +343,13 @@ def test_custom_subnetwork_indices(model, likelihood, class_loader, reg_loader, assert lap.backend.subnetwork_indices.equal(subnetwork_indices) -@pytest.mark.parametrize('subnetwork_mask,likelihood,hessian_structure', - product(score_based_subnet_masks, likelihoods, hessian_structures)) -def test_score_based_subnet_masks(model, likelihood, subnetwork_mask, class_loader, reg_loader, hessian_structure): +@pytest.mark.parametrize( + 
'subnetwork_mask,likelihood,hessian_structure', + product(score_based_subnet_masks, likelihoods, hessian_structures), +) +def test_score_based_subnet_masks( + model, likelihood, subnetwork_mask, class_loader, reg_loader, hessian_structure +): loader = class_loader if likelihood == 'classification' else reg_loader model_params = parameters_to_vector(model.parameters()) @@ -266,8 +396,13 @@ def test_score_based_subnet_masks(model, likelihood, subnetwork_mask, class_load subnetmask.select(loader) # define valid subnet Laplace model - lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', - subnetwork_indices=subnetmask.indices, hessian_structure=hessian_structure) + lap = Laplace( + model, + likelihood=likelihood, + subset_of_weights='subnetwork', + subnetwork_indices=subnetmask.indices, + hessian_structure=hessian_structure, + ) assert isinstance(lap, SubnetLaplace) # fit Laplace model @@ -287,15 +422,23 @@ def test_score_based_subnet_masks(model, likelihood, subnetwork_mask, class_load assert lap.prior_precision_diag.shape == (n_params_subnet,) -@pytest.mark.parametrize('subnetwork_mask,likelihood,hessian_structure', - product(layer_subnet_masks, likelihoods, hessian_structures)) -def test_layer_subnet_masks(model, likelihood, subnetwork_mask, class_loader, reg_loader, hessian_structure): +@pytest.mark.parametrize( + 'subnetwork_mask,likelihood,hessian_structure', + product(layer_subnet_masks, likelihoods, hessian_structures), +) +def test_layer_subnet_masks( + model, likelihood, subnetwork_mask, class_loader, reg_loader, hessian_structure +): loader = class_loader if likelihood == 'classification' else reg_loader subnetmask_kwargs = dict(model=model) # fit last-layer Laplace model - lllap = Laplace(model, likelihood=likelihood, subset_of_weights='last_layer', - hessian_structure=hessian_structure) + lllap = Laplace( + model, + likelihood=likelihood, + subset_of_weights='last_layer', + hessian_structure=hessian_structure, + ) lllap.fit(loader) # should raise error if we pass number of subnet parameters @@ -329,8 +472,13 @@ def test_layer_subnet_masks(model, likelihood, subnetwork_mask, class_loader, re subnetmask_kwargs.update(parameter_names=['1.weight', '1.bias']) subnetmask = subnetwork_mask(**subnetmask_kwargs) subnetmask.select(loader) - lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', - subnetwork_indices=subnetmask.indices, hessian_structure=hessian_structure) + lap = Laplace( + model, + likelihood=likelihood, + subset_of_weights='subnetwork', + subnetwork_indices=subnetmask.indices, + hessian_structure=hessian_structure, + ) lap.fit(loader) assert torch.allclose(lllap.H, lap.H, rtol=1e-3) @@ -345,8 +493,13 @@ def test_layer_subnet_masks(model, likelihood, subnetwork_mask, class_loader, re # select subnet mask and fit Laplace model subnetmask.select(loader) - lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', - subnetwork_indices=subnetmask.indices, hessian_structure=hessian_structure) + lap = Laplace( + model, + likelihood=likelihood, + subset_of_weights='subnetwork', + subnetwork_indices=subnetmask.indices, + hessian_structure=hessian_structure, + ) lap.fit(loader) assert isinstance(lap, SubnetLaplace) @@ -374,8 +527,13 @@ def test_layer_subnet_masks(model, likelihood, subnetwork_mask, class_loader, re subnetmask_kwargs.update(module_names=['1']) subnetmask = subnetwork_mask(**subnetmask_kwargs) subnetmask.select(loader) - lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', - 
subnetwork_indices=subnetmask.indices, hessian_structure=hessian_structure) + lap = Laplace( + model, + likelihood=likelihood, + subset_of_weights='subnetwork', + subnetwork_indices=subnetmask.indices, + hessian_structure=hessian_structure, + ) lap.fit(loader) assert torch.allclose(lllap.H, lap.H, rtol=1e-3) @@ -390,8 +548,13 @@ def test_layer_subnet_masks(model, likelihood, subnetwork_mask, class_loader, re # select subnet mask and fit Laplace model subnetmask.select(loader) - lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', - subnetwork_indices=subnetmask.indices, hessian_structure=hessian_structure) + lap = Laplace( + model, + likelihood=likelihood, + subset_of_weights='subnetwork', + subnetwork_indices=subnetmask.indices, + hessian_structure=hessian_structure, + ) lap.fit(loader) assert isinstance(lap, SubnetLaplace) @@ -412,8 +575,13 @@ def test_layer_subnet_masks(model, likelihood, subnetwork_mask, class_loader, re # select subnet mask and fit Laplace model subnetmask.select(loader) - lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', - subnetwork_indices=subnetmask.indices, hessian_structure=hessian_structure) + lap = Laplace( + model, + likelihood=likelihood, + subset_of_weights='subnetwork', + subnetwork_indices=subnetmask.indices, + hessian_structure=hessian_structure, + ) lap.fit(loader) assert isinstance(lap, SubnetLaplace) @@ -431,8 +599,13 @@ def test_layer_subnet_masks(model, likelihood, subnetwork_mask, class_loader, re # select subnet mask and fit Laplace model subnetmask.select(loader) - lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', - subnetwork_indices=subnetmask.indices, hessian_structure=hessian_structure) + lap = Laplace( + model, + likelihood=likelihood, + subset_of_weights='subnetwork', + subnetwork_indices=subnetmask.indices, + hessian_structure=hessian_structure, + ) lap.fit(loader) assert isinstance(lap, SubnetLaplace) @@ -452,8 +625,12 @@ def test_layer_subnet_masks(model, likelihood, subnetwork_mask, class_loader, re assert lap.prior_precision_diag.shape == (n_params_subnet,) -@pytest.mark.parametrize('likelihood,hessian_structure', product(likelihoods, hessian_structures)) -def test_full_subnet_mask(model, likelihood, class_loader, reg_loader, hessian_structure): +@pytest.mark.parametrize( + 'likelihood,hessian_structure', product(likelihoods, hessian_structures) +) +def test_full_subnet_mask( + model, likelihood, class_loader, reg_loader, hessian_structure +): loader = class_loader if likelihood == 'classification' else reg_loader # define full model 'subnet' mask class (i.e. 
where all parameters are part of the subnet) @@ -465,8 +642,13 @@ def get_subnet_mask(self, train_loader): subnetwork_mask = FullSubnetMask subnetmask = subnetwork_mask(model=model) subnetmask.select(loader) - lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', - subnetwork_indices=subnetmask.indices, hessian_structure=hessian_structure) + lap = Laplace( + model, + likelihood=likelihood, + subset_of_weights='subnetwork', + subnetwork_indices=subnetmask.indices, + hessian_structure=hessian_structure, + ) lap.fit(loader) assert isinstance(lap, SubnetLaplace) @@ -476,13 +658,19 @@ def get_subnet_mask(self, train_loader): assert lap.n_params_subnet == model.n_params # check that the Hessian is identical to that of an all-weights Laplace model - full_lap = Laplace(model, likelihood=likelihood, subset_of_weights='all', - hessian_structure=hessian_structure) + full_lap = Laplace( + model, + likelihood=likelihood, + subset_of_weights='all', + hessian_structure=hessian_structure, + ) full_lap.fit(loader) assert torch.allclose(full_lap.H, lap.H, rtol=1e-3) -@pytest.mark.parametrize('subnetwork_mask,hessian_structure', product(all_subnet_masks, hessian_structures)) +@pytest.mark.parametrize( + 'subnetwork_mask,hessian_structure', product(all_subnet_masks, hessian_structures) +) def test_regression_predictive(model, reg_loader, subnetwork_mask, hessian_structure): subnetmask_kwargs = dict(model=model) if subnetwork_mask in score_based_subnet_masks: @@ -499,8 +687,13 @@ def test_regression_predictive(model, reg_loader, subnetwork_mask, hessian_struc subnetmask = subnetwork_mask(**subnetmask_kwargs) subnetmask.select(reg_loader) - lap = Laplace(model, likelihood='regression', subset_of_weights='subnetwork', - subnetwork_indices=subnetmask.indices, hessian_structure=hessian_structure) + lap = Laplace( + model, + likelihood='regression', + subset_of_weights='subnetwork', + subnetwork_indices=subnetmask.indices, + hessian_structure=hessian_structure, + ) assert isinstance(lap, SubnetLaplace) lap.fit(reg_loader) @@ -524,8 +717,12 @@ def test_regression_predictive(model, reg_loader, subnetwork_mask, hessian_struc assert len(f_mu) == len(X) -@pytest.mark.parametrize('subnetwork_mask,hessian_structure', product(all_subnet_masks, hessian_structures)) -def test_classification_predictive(model, class_loader, subnetwork_mask, hessian_structure): +@pytest.mark.parametrize( + 'subnetwork_mask,hessian_structure', product(all_subnet_masks, hessian_structures) +) +def test_classification_predictive( + model, class_loader, subnetwork_mask, hessian_structure +): subnetmask_kwargs = dict(model=model) if subnetwork_mask in score_based_subnet_masks: subnetmask_kwargs.update(n_params_subnet=32) @@ -541,8 +738,13 @@ def test_classification_predictive(model, class_loader, subnetwork_mask, hessian subnetmask = subnetwork_mask(**subnetmask_kwargs) subnetmask.select(class_loader) - lap = Laplace(model, likelihood='classification', subset_of_weights='subnetwork', - subnetwork_indices=subnetmask.indices, hessian_structure=hessian_structure) + lap = Laplace( + model, + likelihood='classification', + subset_of_weights='subnetwork', + subnetwork_indices=subnetmask.indices, + hessian_structure=hessian_structure, + ) assert isinstance(lap, SubnetLaplace) lap.fit(class_loader) @@ -556,26 +758,40 @@ def test_classification_predictive(model, class_loader, subnetwork_mask, hessian # GLM predictive f_pred = lap(X, pred_type='glm', link_approx='mc', n_samples=100) assert f_pred.shape == f.shape - assert 
torch.allclose(f_pred.sum(), torch.tensor(len(f_pred), dtype=torch.double)) # sum up to 1 + assert torch.allclose( + f_pred.sum(), torch.tensor(len(f_pred), dtype=torch.double) + ) # sum up to 1 f_pred = lap(X, pred_type='glm', link_approx='probit') assert f_pred.shape == f.shape - assert torch.allclose(f_pred.sum(), torch.tensor(len(f_pred), dtype=torch.double)) # sum up to 1 + assert torch.allclose( + f_pred.sum(), torch.tensor(len(f_pred), dtype=torch.double) + ) # sum up to 1 f_pred = lap(X, pred_type='glm', link_approx='bridge') assert f_pred.shape == f.shape - assert torch.allclose(f_pred.sum(), torch.tensor(len(f_pred), dtype=torch.double)) # sum up to 1 + assert torch.allclose( + f_pred.sum(), torch.tensor(len(f_pred), dtype=torch.double) + ) # sum up to 1 f_pred = lap(X, pred_type='glm', link_approx='bridge_norm') assert f_pred.shape == f.shape - assert torch.allclose(f_pred.sum(), torch.tensor(len(f_pred), dtype=torch.double)) # sum up to 1 + assert torch.allclose( + f_pred.sum(), torch.tensor(len(f_pred), dtype=torch.double) + ) # sum up to 1 # NN predictive f_pred = lap(X, pred_type='nn', link_approx='mc', n_samples=100) assert f_pred.shape == f.shape - assert torch.allclose(f_pred.sum(), torch.tensor(len(f_pred), dtype=torch.double)) # sum up to 1 - - -@pytest.mark.parametrize('subnetwork_mask,likelihood,hessian_structure', - product(all_subnet_masks, likelihoods, hessian_structures)) -def test_subnet_marginal_likelihood(model, subnetwork_mask, likelihood, hessian_structure, class_loader, reg_loader): + assert torch.allclose( + f_pred.sum(), torch.tensor(len(f_pred), dtype=torch.double) + ) # sum up to 1 + + +@pytest.mark.parametrize( + 'subnetwork_mask,likelihood,hessian_structure', + product(all_subnet_masks, likelihoods, hessian_structures), +) +def test_subnet_marginal_likelihood( + model, subnetwork_mask, likelihood, hessian_structure, class_loader, reg_loader +): subnetmask_kwargs = dict(model=model) if subnetwork_mask in score_based_subnet_masks: subnetmask_kwargs.update(n_params_subnet=32) @@ -592,8 +808,13 @@ def test_subnet_marginal_likelihood(model, subnetwork_mask, likelihood, hessian_ subnetmask = subnetwork_mask(**subnetmask_kwargs) loader = class_loader if likelihood == 'classification' else reg_loader subnetmask.select(loader) - lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', - subnetwork_indices=subnetmask.indices, hessian_structure=hessian_structure) + lap = Laplace( + model, + likelihood=likelihood, + subset_of_weights='subnetwork', + subnetwork_indices=subnetmask.indices, + hessian_structure=hessian_structure, + ) assert isinstance(lap, SubnetLaplace) lap.fit(loader) diff --git a/tests/test_subset_params.py b/tests/test_subset_params.py index 01053273..24e4a78a 100644 --- a/tests/test_subset_params.py +++ b/tests/test_subset_params.py @@ -52,14 +52,18 @@ def reg_loader(): return DataLoader(TensorDataset(X, y), batch_size=3) -@pytest.mark.parametrize('laplace,lh', product(flavors, ['classification', 'regression'])) +@pytest.mark.parametrize( + 'laplace,lh', product(flavors, ['classification', 'regression']) +) def test_incompatible_backend(laplace, lh, model): lap = laplace(model, lh, backend=AsdlEF) lap = laplace(model, lh, backend=AsdlGGN) lap = laplace(model, lh, backend=AsdlHessian) -@pytest.mark.parametrize('laplace,lh', product(flavors, ['classification', 'regression'])) +@pytest.mark.parametrize( + 'laplace,lh', product(flavors, ['classification', 'regression']) +) def test_incompatible_backend(laplace, lh, model): with 
pytest.raises(ValueError): lap = laplace(model, lh, backend=BackPackGGN) @@ -116,5 +120,6 @@ def test_marglik(laplace, model, class_loader): def test_marglik(laplace, model, class_loader): lap = laplace(model, 'classification') lap.fit(class_loader) - lap.optimize_prior_precision(method='gridsearch', val_loader=class_loader, - pred_type='nn', link_approx='mc') + lap.optimize_prior_precision( + method='gridsearch', val_loader=class_loader, pred_type='nn', link_approx='mc' + ) diff --git a/tests/test_utils.py b/tests/test_utils.py index d6f9165c..29c142bc 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,7 +1,15 @@ import torch from torch.utils.data import TensorDataset, DataLoader from laplace import Laplace -from laplace.utils import invsqrt_precision, diagonal_add_scalar, symeig, normal_samples, validate, get_nll, RunningNLLMetric +from laplace.utils import ( + invsqrt_precision, + diagonal_add_scalar, + symeig, + normal_samples, + validate, + get_nll, + RunningNLLMetric, +) import math @@ -71,7 +79,9 @@ def test_validate(): y = torch.randint(3, size=(50,)) dataloader = DataLoader(TensorDataset(X, y), batch_size=10) - model = torch.nn.Sequential(torch.nn.Linear(10, 20), torch.nn.ReLU(), torch.nn.Linear(20, 3)) + model = torch.nn.Sequential( + torch.nn.Linear(10, 20), torch.nn.ReLU(), torch.nn.Linear(20, 3) + ) la = Laplace(model, 'classification', 'all') la.fit(dataloader) @@ -83,9 +93,13 @@ def test_validate(): assert res > 0 res = validate( - la, dataloader, RunningNLLMetric(), pred_type='nn', link_approx='mc', n_samples=10 + la, + dataloader, + RunningNLLMetric(), + pred_type='nn', + link_approx='mc', + n_samples=10, ) assert res != math.nan assert isinstance(res, float) assert res > 0 - From 1eea27cb7759602d2e484a0efa3efaa88f8e6a48 Mon Sep 17 00:00:00 2001 From: Agustinus Kristiadi Date: Sat, 27 Apr 2024 13:16:51 -0400 Subject: [PATCH 2/9] Typehints for `BaseLaplace` and `ParametricLaplace` --- laplace/baselaplace.py | 506 ++++++++++++++++++++++++++--------------- 1 file changed, 321 insertions(+), 185 deletions(-) diff --git a/laplace/baselaplace.py b/laplace/baselaplace.py index 3bb26a54..37ba1eda 100644 --- a/laplace/baselaplace.py +++ b/laplace/baselaplace.py @@ -1,10 +1,18 @@ +from __future__ import annotations + +from enum import Enum from math import sqrt, pi, log +from typing import Callable, List, Tuple, Type import numpy as np import torch +from torch import nn from torch.nn.utils import parameters_to_vector, vector_to_parameters +from torch.utils.data import DataLoader, TensorDataset +import torchmetrics import tqdm from collections.abc import MutableMapping from laplace.curvature.asdfghjkl import AsdfghjklHessian +from laplace.curvature.curvature import CurvatureInterface from laplace.curvature.curvlinops import CurvlinopsEF import warnings from torchmetrics import MeanSquaredError @@ -17,7 +25,7 @@ fix_prior_prec_structure, RunningNLLMetric, ) -from laplace.curvature import AsdlHessian, CurvlinopsGGN +from laplace.curvature import CurvlinopsGGN __all__ = [ @@ -27,16 +35,60 @@ 'KronLaplace', 'DiagLaplace', 'LowRankLaplace', + 'Likelihood', + 'PredType', + 'LinkApprox', + 'TuningMethod', + 'PriorStructure', ] +class Likelihood(str, Enum): + """Choices of likelihoods supported by Laplace""" + + REGRESSION = 'regression' + CLASSIFICATION = 'classification' + REWARD_MODELING = 'reward_modeling' + + +class PredType(str, Enum): + """Choices of predictive types. 
To obtain the p(f(x) | x, D)""" + + GLM = 'glm' + NN = 'nn' + + +class LinkApprox(str, Enum): + """Choices of the inverse link function p(f(x) | x, D) -> p(y | f(x), D)""" + + MC = 'mc' + PROBIT = 'probit' + BRIDGE = 'bridge' + BRIDGE_NORM = 'bridge_norm' + + +class TuningMethod(str, Enum): + """Choices of prior precision tuning methods""" + + MARGLIK = 'marglik' + GRIDSEARCH = 'gridsearch' + + +class PriorStructure(str, Enum): + """Choices of prior precision structures""" + + SCALAR = 'scalar' + DIAG = 'diag' + LAYERWISE = 'layerwise' + + class BaseLaplace: """Baseclass for all Laplace approximations in this library. Parameters ---------- model : torch.nn.Module - likelihood : {'classification', 'regression', 'reward_modeling'} + likelihood : Likelihood or str in {'classification', 'regression', 'reward_modeling'} determines the log likelihood Hessian approximation. In the case of 'reward_modeling', it fits Laplace in using the classification likelihood, then do prediction as in regression likelihood. The model needs to be defined accordingly: @@ -67,48 +119,48 @@ class BaseLaplace: def __init__( self, - model, - likelihood, - sigma_noise=1.0, - prior_precision=1.0, - prior_mean=0.0, - temperature=1.0, - enable_backprop=False, - backend=None, - backend_kwargs=None, - asdl_fisher_kwargs=None, - ): - if likelihood not in ['classification', 'regression', 'reward_modeling']: + model: nn.Module, + likelihood: Likelihood | str, + sigma_noise: float | torch.Tensor = 1.0, + prior_precision: float | torch.Tensor = 1.0, + prior_mean: float | torch.Tensor = 0.0, + temperature: float = 1.0, + enable_backprop: bool = False, + backend: Type[CurvatureInterface] | None = None, + backend_kwargs: dict | None = None, + asdl_fisher_kwargs: dict | None = None, + ) -> None: + if likelihood not in [lik.value for lik in Likelihood]: raise ValueError(f'Invalid likelihood type {likelihood}') - self.model = model + self.model: nn.Module = model # Only do Laplace on params that require grad - self.params = [] - self.is_subset_params = False + self.params: List[torch.Tensor] = [] + self.is_subset_params: bool = False for p in model.parameters(): if p.requires_grad: self.params.append(p) else: self.is_subset_params = True - self.n_params = sum(p.numel() for p in self.params) - self.n_layers = len(self.params) - self.prior_precision = prior_precision - self.prior_mean = prior_mean + self.n_params: int = sum(p.numel() for p in self.params) + self.n_layers: int = len(self.params) + self.prior_precision: float | torch.Tensor = prior_precision + self.prior_mean: float | torch.Tensor = prior_mean if sigma_noise != 1 and likelihood != 'regression': raise ValueError('Sigma noise != 1 only available for regression.') - self.reward_modeling = likelihood == 'reward_modeling' + self.reward_modeling: bool = likelihood == 'reward_modeling' if self.reward_modeling: # For fitting only. After it's done, self.likelihood = 'regression', see self.fit() self.likelihood = 'classification' else: self.likelihood = likelihood - self.sigma_noise = sigma_noise - self.temperature = temperature - self.enable_backprop = enable_backprop + self.sigma_noise: float | torch.Tensor = sigma_noise + self.temperature: float = temperature + self.enable_backprop: bool = enable_backprop if backend is None: backend = CurvlinopsGGN @@ -118,41 +170,49 @@ def __init__( 'If some grad are switched off, the BackPACK backend is not supported.' 
) - self._backend = None - self._backend_cls = backend - self._backend_kwargs = dict() if backend_kwargs is None else backend_kwargs - self._asdl_fisher_kwargs = ( + self._backend: CurvatureInterface | None = None + self._backend_cls: Type[CurvatureInterface] = backend + self._backend_kwargs: dict = ( + dict() if backend_kwargs is None else backend_kwargs + ) + self._asdl_fisher_kwargs: dict = ( dict() if asdl_fisher_kwargs is None else asdl_fisher_kwargs ) # log likelihood = g(loss) - self.loss = 0.0 - self.n_outputs = None - self.n_data = 0 + self.loss: float = 0.0 + self.n_outputs: int = 0 + self.n_data: int = 0 @property - def _device(self): + def _device(self) -> torch.device: return next(self.model.parameters()).device @property - def backend(self): + def backend(self) -> CurvatureInterface: if self._backend is None: self._backend = self._backend_cls( self.model, self.likelihood, **self._backend_kwargs ) return self._backend - def _curv_closure(self, X, y, N): + def _curv_closure( + self, X: torch.Tensor | MutableMapping, y: torch.Tensor, N: int + ) -> Tuple[float, torch.Tensor]: raise NotImplementedError - def fit(self, train_loader): + def fit(self, train_loader: DataLoader) -> None: raise NotImplementedError - def log_marginal_likelihood(self, prior_precision=None, sigma_noise=None): + def log_marginal_likelihood( + self, + prior_precision: torch.Tensor | None = None, + sigma_noise: torch.Tensor | None = None, + ) -> torch.Tensor: raise NotImplementedError @property - def log_likelihood(self): + def log_likelihood(self) -> torch.Tensor: """Compute log likelihood on the training data after `.fit()` has been called. The log likelihood is computed on-demand based on the loss and, for example, the observation noise which makes it differentiable in the latter for @@ -168,20 +228,32 @@ def log_likelihood(self): c = ( self.n_data * self.n_outputs - * torch.log(self.sigma_noise * sqrt(2 * pi)) + * torch.log(torch.tensor(self.sigma_noise) * sqrt(2 * pi)) ) return factor * self.loss - c else: # for classification Xent == log Cat return factor * self.loss - def __call__(self, x, pred_type, link_approx, n_samples): + def __call__( + self, + x: torch.Tensor | MutableMapping, + pred_type: PredType | str, + link_approx: LinkApprox | str, + n_samples: int, + ) -> torch.Tensor | Tuple[torch.Tensor, torch.Tensor]: raise NotImplementedError - def predictive(self, x, pred_type, link_approx, n_samples): + def predictive( + self, + x: torch.Tensor, + pred_type: PredType | str, + link_approx: LinkApprox | str, + n_samples: int, + ) -> torch.Tensor | Tuple[torch.Tensor, torch.Tensor]: return self(x, pred_type, link_approx, n_samples) - def _check_jacobians(self, Js): + def _check_jacobians(self, Js: torch.Tensor) -> None: if not isinstance(Js, torch.Tensor): raise ValueError('Jacobians have to be torch.Tensor.') if not Js.device == self._device: @@ -191,7 +263,7 @@ def _check_jacobians(self, Js): raise ValueError('Invalid Jacobians shape for Laplace posterior approx.') @property - def prior_precision_diag(self): + def prior_precision_diag(self) -> torch.Tensor: """Obtain the diagonal prior precision \\(p_0\\) constructed from either a scalar, layer-wise, or diagonal prior precision. 
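The scalar, per-layer, and diagonal layouts expanded by `prior_precision_diag` mirror what the `prior_precision` setter further down accepts. A minimal sketch of the three options, using a small illustrative `nn.Sequential` (names and sizes here are only for illustration):

import torch
from torch import nn
from laplace.baselaplace import DiagLaplace

model = nn.Sequential(nn.Linear(3, 20), nn.Linear(20, 2))
n_params = sum(p.numel() for p in model.parameters())  # 3*20 + 20 + 20*2 + 2 = 122
n_layers = len(list(model.parameters()))  # 4 parameter tensors (2 weights, 2 biases)

la = DiagLaplace(model, 'regression', prior_precision=1.0)  # scalar
la = DiagLaplace(model, 'regression', prior_precision=torch.full((n_layers,), 1.0))  # per layer
la = DiagLaplace(model, 'regression', prior_precision=torch.full((n_params,), 1.0))  # full diagonal
# Any other shape, e.g. a 2-D tensor or a length matching neither, raises ValueError.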
@@ -199,35 +271,38 @@ def prior_precision_diag(self): ------- prior_precision_diag : torch.Tensor """ - if len(self.prior_precision) == 1: # scalar - return self.prior_precision * torch.ones(self.n_params, device=self._device) - - elif len(self.prior_precision) == self.n_params: # diagonal - return self.prior_precision + prior_prec: torch.Tensor = ( + self.prior_precision + if isinstance(self.prior_precision, torch.Tensor) + else torch.tensor(self.prior_precision) + ) - elif len(self.prior_precision) == self.n_layers: # per layer + if prior_prec.ndim == 0 or len(prior_prec) == 1: # scalar + return self.prior_precision * torch.ones(self.n_params, device=self._device) + elif len(prior_prec) == self.n_params: # diagonal + return prior_prec + elif len(prior_prec) == self.n_layers: # per layer n_params_per_layer = [p.numel() for p in self.params] return torch.cat( [ prior * torch.ones(n_params, device=self._device) - for prior, n_params in zip(self.prior_precision, n_params_per_layer) + for prior, n_params in zip(prior_prec, n_params_per_layer) ] ) - else: raise ValueError( 'Mismatch of prior and model. Diagonal, scalar, or per-layer prior.' ) @property - def prior_mean(self): + def prior_mean(self) -> torch.Tensor: return self._prior_mean @prior_mean.setter - def prior_mean(self, prior_mean): + def prior_mean(self, prior_mean: float | torch.Tensor) -> None: if np.isscalar(prior_mean) and np.isreal(prior_mean): self._prior_mean = torch.tensor(prior_mean, device=self._device) - elif torch.is_tensor(prior_mean): + elif isinstance(prior_mean, torch.Tensor): if prior_mean.ndim == 0: self._prior_mean = prior_mean.reshape(-1).to(self._device) elif prior_mean.ndim == 1: @@ -240,15 +315,16 @@ def prior_mean(self, prior_mean): raise ValueError('Invalid argument type of prior mean.') @property - def prior_precision(self): + def prior_precision(self) -> torch.Tensor: return self._prior_precision @prior_precision.setter - def prior_precision(self, prior_precision): + def prior_precision(self, prior_precision: float | torch.Tensor): self._posterior_scale = None + if np.isscalar(prior_precision) and np.isreal(prior_precision): self._prior_precision = torch.tensor([prior_precision], device=self._device) - elif torch.is_tensor(prior_precision): + elif isinstance(prior_precision, torch.Tensor): if prior_precision.ndim == 0: # make dimensional self._prior_precision = prior_precision.reshape(-1).to(self._device) @@ -269,33 +345,33 @@ def prior_precision(self, prior_precision): def optimize_prior_precision_base( self, - pred_type, - method='marglik', - n_steps=100, - lr=1e-1, - init_prior_prec=1.0, - prior_structure='scalar', - val_loader=None, - loss=None, - log_prior_prec_min=-4, - log_prior_prec_max=4, - grid_size=100, - link_approx='probit', - n_samples=100, - verbose=False, - cv_loss_with_var=False, - progress_bar=False, + pred_type: PredType | str, + method: TuningMethod | str = TuningMethod.MARGLIK, + n_steps: int = 100, + lr: float = 1e-1, + init_prior_prec: float | torch.Tensor = 1.0, + prior_structure: PriorStructure | str = PriorStructure.SCALAR, + val_loader: DataLoader | None = None, + loss: torchmetrics.Metric | Callable | None = None, + log_prior_prec_min: float = -4, + log_prior_prec_max: float = 4, + grid_size: int = 100, + link_approx: LinkApprox | str = LinkApprox.PROBIT, + n_samples: int = 100, + verbose: bool = False, + cv_loss_with_var: bool = False, + progress_bar: bool = False, ): """Optimize the prior precision post-hoc using the `method` specified by the user. 
Parameters ---------- - pred_type : {'glm', 'nn', 'gp'}, default='glm' + pred_type : PredType or str in {'glm', 'nn'} type of posterior predictive, linearized GLM predictive or neural - network sampling predictive or Gaussian Process (GP) inference. - The GLM predictive is consistent with the curvature approximations used here. - method : {'marglik', 'gridsearch'}, default='marglik' + network sampling predictiv. The GLM predictive is consistent with the + curvature approximations used here. + method : TuningMethod or str in {'marglik', 'gridsearch'}, default=PredType.MARGLIK specifies how the prior precision should be optimized. n_steps : int, default=100 the number of gradient descent steps to take. @@ -303,7 +379,7 @@ def optimize_prior_precision_base( the learning rate to use for gradient descent. init_prior_prec : float or tensor, default=1.0 initial prior precision before the first optimization step. - prior_structure : {'scalar', 'layerwise', 'diag'}, default='scalar' + prior_structure : PriorStructure or str in {'scalar', 'layerwise', 'diag'}, default=PriorStructure.SCALAR if init_prior_prec is scalar, the prior precision is optimized with this structure. otherwise, the structure of init_prior_prec is maintained. val_loader : torch.data.utils.DataLoader, default=None @@ -322,7 +398,7 @@ def optimize_prior_precision_base( upper bound of gridsearch interval. grid_size : int, default=100 number of values to consider inside the gridsearch interval. - link_approx : {'mc', 'probit', 'bridge'}, default='probit' + link_approx : LinkApprox or str in {'mc', 'probit', 'bridge'}, default=LinkApprox.PROBIT how to approximate the classification link function for the `'glm'`. For `pred_type='nn'`, only `'mc'` is possible. n_samples : int, default=100 @@ -334,9 +410,17 @@ def optimize_prior_precision_base( whether to show a progress bar; updated at every batch-Hessian computation. Useful for very large model and large amount of data, esp. when `subset_of_weights='all'`. 
""" - if method == 'marglik': - self.prior_precision = init_prior_prec - if len(self.prior_precision) == 1 and prior_structure != 'scalar': + if method == TuningMethod.MARGLIK: + self.prior_precision = ( + init_prior_prec + if isinstance(init_prior_prec, torch.Tensor) + else torch.tensor(init_prior_prec) + ) + + if ( + len(self.prior_precision) == 1 + and prior_structure != PriorStructure.SCALAR + ): self.prior_precision = fix_prior_prec_structure( self.prior_precision.item(), prior_structure, @@ -344,6 +428,7 @@ def optimize_prior_precision_base( self.n_params, self._device, ) + log_prior_prec = self.prior_precision.log() log_prior_prec.requires_grad = True optimizer = torch.optim.Adam([log_prior_prec], lr=lr) @@ -362,8 +447,9 @@ def optimize_prior_precision_base( ) neg_log_marglik.backward() optimizer.step() + self.prior_precision = log_prior_prec.detach().exp() - elif method == 'gridsearch': + elif method == TuningMethod.GRIDSEARCH: if val_loader is None: raise ValueError('gridsearch requires a validation set DataLoader') @@ -388,25 +474,27 @@ def optimize_prior_precision_base( ) else: raise ValueError('For now only marglik and gridsearch is implemented.') + if verbose: print(f'Optimized prior precision is {self.prior_precision}.') def _gridsearch( self, - loss, - interval, - val_loader, - pred_type, - link_approx='probit', - n_samples=100, - loss_with_var=False, - progress_bar=False, + loss: torchmetrics.Metric | Callable, + interval: torch.Tensor, + val_loader: DataLoader, + pred_type: PredType | str, + link_approx: LinkApprox | str = LinkApprox.PROBIT, + n_samples: int = 100, + loss_with_var: bool = False, + progress_bar: bool = False, ): - assert callable(loss) or isinstance(loss, tm.Metric) + assert callable(loss) or isinstance(loss, torchmetrics.Metric) results = list() prior_precs = list() - pbar = tqdm.tqdm(interval) if progress_bar else interval + pbar = tqdm.tqdm(interval, disable=not progress_bar) + for prior_prec in pbar: self.prior_precision = prior_prec try: @@ -429,18 +517,20 @@ def _gridsearch( results.append(result) prior_precs.append(prior_prec) + return prior_precs[np.argmin(results)] @property - def sigma_noise(self): + def sigma_noise(self) -> torch.Tensor: return self._sigma_noise @sigma_noise.setter - def sigma_noise(self, sigma_noise): + def sigma_noise(self, sigma_noise: float | torch.Tensor) -> None: self._posterior_scale = None + if np.isscalar(sigma_noise) and np.isreal(sigma_noise): self._sigma_noise = torch.tensor(sigma_noise, device=self._device) - elif torch.is_tensor(sigma_noise): + elif isinstance(sigma_noise, torch.Tensor): if sigma_noise.ndim == 0: self._sigma_noise = sigma_noise.to(self._device) elif sigma_noise.ndim == 1: @@ -455,7 +545,7 @@ def sigma_noise(self, sigma_noise): ) @property - def _H_factor(self): + def _H_factor(self) -> torch.Tensor: sigma2 = self.sigma_noise.square() return 1 / sigma2 / self.temperature @@ -486,16 +576,16 @@ class ParametricLaplace(BaseLaplace): def __init__( self, - model, - likelihood, - sigma_noise=1.0, - prior_precision=1.0, - prior_mean=0.0, - temperature=1.0, - enable_backprop=False, - backend=None, - backend_kwargs=None, - asdl_fisher_kwargs=None, + model: nn.Module, + likelihood: Likelihood | str, + sigma_noise: float | torch.Tensor = 1.0, + prior_precision: float | torch.Tensor = 1.0, + prior_mean: float | torch.Tensor = 0.0, + temperature: float = 1.0, + enable_backprop: bool = False, + backend: Type[CurvatureInterface] | None = None, + backend_kwargs: dict | None = None, + asdl_fisher_kwargs: dict | None 
= None, ): super().__init__( model, @@ -512,16 +602,21 @@ def __init__( if not hasattr(self, 'H'): self._init_H() # posterior mean/mode - self.mean = self.prior_mean + self.mean: float | torch.Tensor = self.prior_mean - def _init_H(self): + def _init_H(self) -> None: raise NotImplementedError - def _check_H_init(self): + def _check_H_init(self) -> None: if self.H is None: raise AttributeError('Laplace not fitted. Run fit() first.') - def fit(self, train_loader, override=True, progress_bar=False): + def fit( + self, + train_loader: DataLoader, + override: bool = True, + progress_bar: bool = False, + ) -> None: """Fit the local Laplace approximation at the parameters of the model. Parameters @@ -538,16 +633,19 @@ def fit(self, train_loader, override=True, progress_bar=False): """ if override: self._init_H() - self.loss = 0 - self.n_data = 0 + self.loss: float | torch.Tensor = 0 + self.n_data: int = 0 self.model.eval() - self.mean = parameters_to_vector(self.params) + self.mean: torch.Tensor = parameters_to_vector(self.params) if not self.enable_backprop: self.mean = self.mean.detach() - data = next(iter(train_loader)) + data: Tuple[torch.Tensor, torch.Tensor] | MutableMapping = next( + iter(train_loader) + ) + with torch.no_grad(): if isinstance(data, MutableMapping): # To support Huggingface dataset if isinstance(self, DiagLaplace) and self._backend_cls == CurvlinopsEF: @@ -569,11 +667,9 @@ def fit(self, train_loader, override=True, progress_bar=False): setattr(self.model, 'output_size', self.n_outputs) N = len(train_loader.dataset) - if progress_bar: - pbar = tqdm.tqdm(train_loader) - pbar.set_description('[Computing Hessian]') - else: - pbar = train_loader + + pbar = tqdm.tqdm(train_loader, disable=not progress_bar) + pbar.set_description('[Computing Hessian]') for data in pbar: if isinstance(data, MutableMapping): # To support Huggingface dataset @@ -589,21 +685,20 @@ def fit(self, train_loader, override=True, progress_bar=False): self.n_data += N @property - def scatter(self): + def scatter(self) -> torch.Tensor: """Computes the _scatter_, a term of the log marginal likelihood that corresponds to L-2 regularization: `scatter` = \\((\\theta_{MAP} - \\mu_0)^{T} P_0 (\\theta_{MAP} - \\mu_0) \\). Returns ------- - [type] - [description] + scatter: torch.Tensor """ delta = self.mean - self.prior_mean return (delta * self.prior_precision_diag) @ delta @property - def log_det_prior_precision(self): + def log_det_prior_precision(self) -> torch.Tensor: """Compute log determinant of the prior precision \\(\\log \\det P_0\\) @@ -614,7 +709,7 @@ def log_det_prior_precision(self): return self.prior_precision_diag.log().sum() @property - def log_det_posterior_precision(self): + def log_det_posterior_precision(self) -> torch.Tensor: """Compute log determinant of the posterior precision \\(\\log \\det P\\) which depends on the subclasses structure used for the Hessian approximation. @@ -626,7 +721,7 @@ def log_det_posterior_precision(self): raise NotImplementedError @property - def log_det_ratio(self): + def log_det_ratio(self) -> torch.Tensor: """Compute the log determinant ratio, a part of the log marginal likelihood. \\[ \\log \\frac{\\det P}{\\det P_0} = \\log \\det P - \\log \\det P_0 @@ -638,7 +733,7 @@ def log_det_ratio(self): """ return self.log_det_posterior_precision - self.log_det_prior_precision - def square_norm(self, value): + def square_norm(self, value) -> torch.Tensor: """Compute the square norm under post. 
Precision with `value-self.mean` as 𝛥: \\[ \\Delta^\top P \\Delta @@ -649,11 +744,12 @@ def square_norm(self, value): """ raise NotImplementedError - def log_prob(self, value, normalized=True): + def log_prob(self, value: torch.Tensor, normalized: bool = True) -> torch.Tensor: """Compute the log probability under the (current) Laplace approximation. Parameters ---------- + value: torch.Tensor normalized : bool, default=True whether to return log of a properly normalized Gaussian or just the terms that depend on `value`. @@ -670,7 +766,11 @@ def log_prob(self, value, normalized=True): log_prob -= self.square_norm(value) / 2 return log_prob - def log_marginal_likelihood(self, prior_precision=None, sigma_noise=None): + def log_marginal_likelihood( + self, + prior_precision: torch.Tensor | None = None, + sigma_noise: torch.Tensor | None = None, + ) -> torch.Tensor: """Compute the Laplace approximation to the log marginal likelihood subject to specific Hessian approximations that subclasses implement. Requires that the Laplace approximation has been fit before. @@ -683,7 +783,7 @@ def log_marginal_likelihood(self, prior_precision=None, sigma_noise=None): ---------- prior_precision : torch.Tensor, optional prior precision if should be changed from current `prior_precision` value - sigma_noise : [type], optional + sigma_noise : torch.Tensor, optional observation noise standard deviation if should be changed Returns @@ -704,15 +804,15 @@ def log_marginal_likelihood(self, prior_precision=None, sigma_noise=None): def __call__( self, - x, - pred_type='glm', - joint=False, - link_approx='probit', - n_samples=100, - diagonal_output=False, - generator=None, + x: torch.Tensor | MutableMapping, + pred_type: PredType | str = PredType.GLM, + joint: bool = False, + link_approx: LinkApprox | str = LinkApprox.PROBIT, + n_samples: int = 100, + diagonal_output: bool = False, + generator: torch.Generator | None = None, **model_kwargs, - ): + ) -> torch.Tensor | Tuple[torch.Tensor, torch.Tensor]: """Compute the posterior predictive on input data `x`. Parameters @@ -759,13 +859,13 @@ def __call__( For `likelihood='regression'` and `joint=True`, a tuple of torch.Tensor is returned with the mean and the predictive covariance. """ - if pred_type not in ['glm', 'nn']: + if pred_type not in [pred for pred in PredType]: raise ValueError('Only glm and nn supported as prediction types.') - if link_approx not in ['mc', 'probit', 'bridge', 'bridge_norm']: + if link_approx not in [la for la in LinkApprox]: raise ValueError(f'Unsupported link approximation {link_approx}.') - if pred_type == 'nn' and link_approx != 'mc': + if pred_type == PredType.NN and link_approx != LinkApprox.MC: raise ValueError( 'Only mc link approximation is supported for nn prediction type.' 
) @@ -773,31 +873,31 @@ def __call__( if generator is not None: if ( not isinstance(generator, torch.Generator) - or generator.device != x.device + or generator.device != self._device ): raise ValueError('Invalid random generator (check type and device).') # For reward modeling, replace the likelihood to regression and override model state - if self.reward_modeling and self.likelihood == 'classification': + if self.reward_modeling and self.likelihood == Likelihood.CLASSIFICATION: self.likelihood = 'regression' - self.model.output_size = 1 + setattr(self.model, 'output_size', 1) - if pred_type == 'glm': + if pred_type == PredType.GLM: f_mu, f_var = self._glm_predictive_distribution( x, joint=joint and self.likelihood == 'regression' ) - # regression - if self.likelihood == 'regression': + + if self.likelihood == Likelihood.REGRESSION: return f_mu, f_var - # classification - if link_approx == 'mc': + + if link_approx == LinkApprox.MC: return self.predictive_samples( x, pred_type='glm', n_samples=n_samples, diagonal_output=diagonal_output, ).mean(dim=0) - elif link_approx == 'probit': + elif link_approx == LinkApprox.PROBIT: kappa = 1 / torch.sqrt(1.0 + np.pi / 8 * f_var.diagonal(dim1=1, dim2=2)) return torch.softmax(kappa * f_mu, dim=-1) elif 'bridge' in link_approx: @@ -810,36 +910,48 @@ def __call__( f_var -= torch.einsum( 'bi,bj->bij', f_var.sum(-1), f_var.sum(-2) ) / f_var.sum(dim=(1, 2)).reshape(-1, 1, 1) + # Laplace Bridge _, K = f_mu.size(0), f_mu.size(-1) f_var_diag = torch.diagonal(f_var, dim1=1, dim2=2) + # optional: variance correction - if link_approx == 'bridge_norm': + if link_approx == LinkApprox.BRIDGE_NORM: f_var_diag_mean = f_var_diag.mean(dim=1) f_var_diag_mean /= torch.as_tensor( [K / 2], device=self._device ).sqrt() f_mu /= f_var_diag_mean.sqrt().unsqueeze(-1) f_var_diag /= f_var_diag_mean.unsqueeze(-1) + sum_exp = torch.exp(-f_mu).sum(dim=1).unsqueeze(-1) alpha = (1 - 2 / K + f_mu.exp() / K**2 * sum_exp) / f_var_diag return torch.nan_to_num(alpha / alpha.sum(dim=1).unsqueeze(-1), nan=1.0) + else: + raise ValueError( + 'Prediction path invalid. Check the likelihood, pred_type, link_approx combination!' + ) else: - if self.likelihood == 'regression': + if self.likelihood == Likelihood.REGRESSION: samples = self._nn_predictive_samples(x, n_samples, **model_kwargs) return samples.mean(dim=0), samples.var(dim=0) else: # classification; the average is computed online return self._nn_predictive_classification(x, n_samples, **model_kwargs) def predictive_samples( - self, x, pred_type='glm', n_samples=100, diagonal_output=False, generator=None - ): + self, + x: torch.Tensor | MutableMapping, + pred_type: PredType | str = PredType.GLM, + n_samples: int = 100, + diagonal_output: bool = False, + generator: torch.Generator | None = None, + ) -> torch.Tensor: """Sample from the posterior predictive on input data `x`. Can be used, for example, for Thompson sampling. 
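As a concrete use of the sampling path just mentioned, one Thompson-sampling step over a candidate set can be sketched as follows; `la` is assumed to be a fitted Laplace object with a regression likelihood and `X_cand` a hypothetical tensor of candidate inputs of shape `(n_candidates, d)`:

# draw one function sample over all candidates and act greedily on it
fs = la.predictive_samples(X_cand, pred_type='glm', n_samples=1)  # (1, n_candidates, 1)
best_idx = fs.squeeze().argmax()
x_next = X_cand[best_idx]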
Parameters ---------- - x : torch.Tensor + x : torch.Tensor or MutableMapping input data `(batch_size, input_shape)` pred_type : {'glm', 'nn'}, default='glm' @@ -862,18 +974,21 @@ def predictive_samples( samples : torch.Tensor samples `(n_samples, batch_size, output_shape)` """ - if pred_type not in ['glm', 'nn']: + if pred_type not in PredType.__members__.values(): raise ValueError('Only glm and nn supported as prediction types.') - if pred_type == 'glm': + if pred_type == PredType.GLM: f_mu, f_var = self._glm_predictive_distribution(x) assert f_var.shape == torch.Size( [f_mu.shape[0], f_mu.shape[1], f_mu.shape[1]] ) + if diagonal_output: f_var = torch.diagonal(f_var, dim1=1, dim2=2) + f_samples = normal_samples(f_mu, f_var, n_samples, generator) - if self.likelihood == 'regression': + + if self.likelihood == Likelihood.REGRESSION: return f_samples else: return torch.softmax(f_samples, dim=-1) @@ -882,7 +997,9 @@ def predictive_samples( return self._nn_predictive_samples(x, n_samples, generator) @torch.enable_grad() - def _glm_predictive_distribution(self, X, joint=False): + def _glm_predictive_distribution( + self, X: torch.Tensor | MutableMapping, joint: bool = False + ) -> Tuple[torch.Tensor, torch.Tensor]: if 'backpack' in self._backend_cls.__name__.lower(): # BackPACK supports backprop through Jacobians, but it interferes with functorch Js, f_mu = self.backend.jacobians(X, enable_backprop=self.enable_backprop) @@ -904,7 +1021,13 @@ def _glm_predictive_distribution(self, X, joint=False): else (f_mu, f_var) ) - def _nn_predictive_samples(self, X, n_samples=100, generator=None, **model_kwargs): + def _nn_predictive_samples( + self, + X: torch.Tensor | MutableMapping, + n_samples: int = 100, + generator: torch.Generator | None = None, + **model_kwargs, + ) -> torch.Tensor: fs = list() for sample in self.sample(n_samples, generator): vector_to_parameters(sample, self.params) @@ -912,24 +1035,31 @@ def _nn_predictive_samples(self, X, n_samples=100, generator=None, **model_kwarg X.to(self._device) if isinstance(X, torch.Tensor) else X, **model_kwargs ) fs.append(logits.detach() if not self.enable_backprop else logits) + vector_to_parameters(self.mean, self.params) fs = torch.stack(fs) - if self.likelihood == 'classification': + + if self.likelihood == Likelihood.CLASSIFICATION: fs = torch.softmax(fs, dim=-1) + return fs - def _nn_predictive_classification(self, X, n_samples=100, **model_kwargs): - py = 0 + def _nn_predictive_classification( + self, X: torch.Tensor | MutableMapping, n_samples: int = 100, **model_kwargs + ) -> torch.Tensor: + py = torch.Tensor(0.0) for sample in self.sample(n_samples): vector_to_parameters(sample, self.params) logits = self.model( X.to(self._device) if isinstance(X, torch.Tensor) else X, **model_kwargs ).detach() py += torch.softmax(logits, dim=-1) / n_samples + vector_to_parameters(self.mean, self.params) + return py - def functional_variance(self, Jacs): + def functional_variance(self, Jacs: torch.Tensor) -> torch.Tensor: """Compute functional variance for the `'glm'` predictive: `f_var[i] = Jacs[i] @ P.inv() @ Jacs[i].T`, which is a output x output predictive covariance matrix. @@ -951,7 +1081,7 @@ def functional_variance(self, Jacs): """ raise NotImplementedError - def functional_covariance(self, Jacs): + def functional_covariance(self, Jacs: torch.Tensor) -> torch.Tensor: """Compute functional covariance for the `'glm'` predictive: `f_cov = Jacs @ P.inv() @ Jacs.T`, which is a batch*output x batch*output predictive covariance matrix. 
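The two methods return different views of the same quantity: `functional_variance` gives the per-input output-by-output covariance \(J_i P^{-1} J_i^\top\), while `functional_covariance` stacks all inputs and outputs into one joint matrix. A small consistency sketch, assuming `la` is a fitted `FullLaplace` instance; the Jacobians here are random placeholders, only their shapes matter:

import torch

batch, outputs = 4, 3
Js = torch.randn(batch, outputs, la.n_params, device=la.mean.device)

f_var = la.functional_variance(Js)    # (batch, outputs, outputs)
f_cov = la.functional_covariance(Js)  # (batch*outputs, batch*outputs)

# the block diagonal of the joint covariance reproduces the per-input variances
blocks = f_cov.reshape(batch, outputs, batch, outputs)
assert torch.allclose(blocks[0, :, 0, :], f_var[0], rtol=1e-4)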
@@ -972,7 +1102,9 @@ def functional_covariance(self, Jacs): """ raise NotImplementedError - def sample(self, n_samples=100, generator=None): + def sample( + self, n_samples: int = 100, generator: torch.Generator | None = None + ) -> torch.Tensor: """Sample from the Laplace posterior approximation, i.e., \\( \\theta \\sim \\mathcal{N}(\\theta_{MAP}, P^{-1})\\). @@ -983,29 +1115,33 @@ def sample(self, n_samples=100, generator=None): generator : torch.Generator, optional random number generator to control the samples + + Returns + ------- + samples: torch.Tensor """ raise NotImplementedError def optimize_prior_precision( self, - method='marglik', - pred_type='glm', - n_steps=100, - lr=1e-1, - init_prior_prec=1.0, - prior_structure='scalar', - val_loader=None, - loss=None, - log_prior_prec_min=-4, - log_prior_prec_max=4, - grid_size=100, - link_approx='probit', - n_samples=100, - verbose=False, - cv_loss_with_var=False, - progress_bar=False, - ): - assert pred_type in ['glm', 'nn'] + method: TuningMethod | str = TuningMethod.MARGLIK, + pred_type: PredType | str = PredType.GLM, + n_steps: int = 100, + lr: float = 1e-1, + init_prior_prec: float | torch.Tensor = 1.0, + prior_structure: PriorStructure | str = PriorStructure.SCALAR, + val_loader: DataLoader | None = None, + loss: torchmetrics.Metric | Callable | None = None, + log_prior_prec_min: float = -4, + log_prior_prec_max: float = 4, + grid_size: int = 100, + link_approx: LinkApprox | str = LinkApprox.PROBIT, + n_samples: int = 100, + verbose: bool = False, + cv_loss_with_var: bool = False, + progress_bar: bool = False, + ) -> None: + assert pred_type in PredType.__members__.values() self.optimize_prior_precision_base( pred_type, method, @@ -1026,7 +1162,7 @@ def optimize_prior_precision( ) @property - def posterior_precision(self): + def posterior_precision(self) -> torch.Tensor: """Compute or return the posterior precision \\(P\\). 
Returns From b71b07a5028e31a9b45e229c58be823a861b9516 Mon Sep 17 00:00:00 2001 From: Agustinus Kristiadi Date: Sat, 27 Apr 2024 14:44:08 -0400 Subject: [PATCH 3/9] Finish typehinting `baselaplace.py` --- laplace/baselaplace.py | 249 ++++++++++++++++++++++++----------------- 1 file changed, 145 insertions(+), 104 deletions(-) diff --git a/laplace/baselaplace.py b/laplace/baselaplace.py index 37ba1eda..acdaab05 100644 --- a/laplace/baselaplace.py +++ b/laplace/baselaplace.py @@ -7,7 +7,7 @@ import torch from torch import nn from torch.nn.utils import parameters_to_vector, vector_to_parameters -from torch.utils.data import DataLoader, TensorDataset +from torch.utils.data import DataLoader import torchmetrics import tqdm from collections.abc import MutableMapping @@ -26,6 +26,7 @@ RunningNLLMetric, ) from laplace.curvature import CurvlinopsGGN +from laplace.utils.matrix import KronDecomposed __all__ = [ @@ -198,7 +199,7 @@ def backend(self) -> CurvatureInterface: def _curv_closure( self, X: torch.Tensor | MutableMapping, y: torch.Tensor, N: int - ) -> Tuple[float, torch.Tensor]: + ) -> Tuple[torch.Tensor, torch.Tensor]: raise NotImplementedError def fit(self, train_loader: DataLoader) -> None: @@ -1189,7 +1190,7 @@ def state_dict(self) -> dict: } return state_dict - def load_state_dict(self, state_dict: dict): + def load_state_dict(self, state_dict: dict) -> None: # Dealbreaker errors if self.__class__.__name__ != state_dict['cls_name']: raise ValueError( @@ -1247,15 +1248,15 @@ class FullLaplace(ParametricLaplace): def __init__( self, - model, - likelihood, - sigma_noise=1.0, - prior_precision=1.0, - prior_mean=0.0, - temperature=1.0, - enable_backprop=False, - backend=None, - backend_kwargs=None, + model: nn.Module, + likelihood: Likelihood | str, + sigma_noise: float | torch.Tensor = 1.0, + prior_precision: float | torch.Tensor = 1.0, + prior_mean: float | torch.Tensor = 0.0, + temperature: float = 1.0, + enable_backprop: bool = False, + backend: Type[CurvatureInterface] | None = None, + backend_kwargs: dict | None = None, ): super().__init__( model, @@ -1268,23 +1269,32 @@ def __init__( backend, backend_kwargs, ) - self._posterior_scale = None + self._posterior_scale: torch.Tensor | None = None - def _init_H(self): - self.H = torch.zeros(self.n_params, self.n_params, device=self._device) + def _init_H(self) -> None: + self.H: torch.Tensor = torch.zeros( + self.n_params, self.n_params, device=self._device + ) - def _curv_closure(self, X, y, N): + def _curv_closure( + self, X: torch.Tensor | MutableMapping, y: torch.Tensor, N: int + ) -> Tuple[torch.Tensor, torch.Tensor]: return self.backend.full(X, y, N=N) - def fit(self, train_loader, override=True, progress_bar=False): + def fit( + self, + train_loader: DataLoader, + override: bool = True, + progress_bar: bool = False, + ) -> None: self._posterior_scale = None - return super().fit(train_loader, override=override, progress_bar=progress_bar) + super().fit(train_loader, override=override, progress_bar=progress_bar) - def _compute_scale(self): + def _compute_scale(self) -> None: self._posterior_scale = invsqrt_precision(self.posterior_precision) @property - def posterior_scale(self): + def posterior_scale(self) -> torch.Tensor: """Posterior scale (square root of the covariance), i.e., \\(P^{-\\frac{1}{2}}\\). @@ -1298,7 +1308,7 @@ def posterior_scale(self): return self._posterior_scale @property - def posterior_covariance(self): + def posterior_covariance(self) -> torch.Tensor: """Posterior covariance, i.e., \\(P^{-1}\\). 
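The scale and the covariance are tied together by \(P^{-1} = S S^\top\), and `sample` below draws parameters by shifting standard normal noise with that scale. A short sketch, assuming `la` is a fitted `FullLaplace`:

import torch

S = la.posterior_scale            # square root of the covariance
Sigma = la.posterior_covariance   # equals S @ S.T by construction
assert torch.allclose(Sigma, S @ S.T)

eps = torch.randn(10, la.n_params, device=la.mean.device)
theta = la.mean + eps @ S         # mirrors la.sample(n_samples=10)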
Returns @@ -1310,7 +1320,7 @@ def posterior_covariance(self): return scale @ scale.T @property - def posterior_precision(self): + def posterior_precision(self) -> torch.Tensor: """Posterior precision \\(P\\). Returns @@ -1322,22 +1332,24 @@ def posterior_precision(self): return self._H_factor * self.H + torch.diag(self.prior_precision_diag) @property - def log_det_posterior_precision(self): + def log_det_posterior_precision(self) -> torch.Tensor: return self.posterior_precision.logdet() - def square_norm(self, value): + def square_norm(self, value: torch.Tensor) -> torch.Tensor: delta = value - self.mean return delta @ self.posterior_precision @ delta - def functional_variance(self, Js): - return torch.einsum('ncp,pq,nkq->nck', Js, self.posterior_covariance, Js) + def functional_variance(self, Jacs: torch.Tensor) -> torch.Tensor: + return torch.einsum('ncp,pq,nkq->nck', Jacs, self.posterior_covariance, Jacs) - def functional_covariance(self, Js): - n_batch, n_outs, n_params = Js.shape - Js = Js.reshape(n_batch * n_outs, n_params) - return torch.einsum('np,pq,mq->nm', Js, self.posterior_covariance, Js) + def functional_covariance(self, Jacs: torch.Tensor) -> torch.Tensor: + n_batch, n_outs, n_params = Jacs.shape + Jacs = Jacs.reshape(n_batch * n_outs, n_params) + return torch.einsum('np,pq,mq->nm', Jacs, self.posterior_covariance, Jacs) - def sample(self, n_samples=100, generator=None): + def sample( + self, n_samples: int = 100, generator: torch.Generator | None = None + ) -> torch.Tensor: samples = torch.randn( n_samples, self.n_params, device=self._device, generator=generator ) @@ -1364,20 +1376,20 @@ class KronLaplace(ParametricLaplace): def __init__( self, - model, - likelihood, - sigma_noise=1.0, - prior_precision=1.0, - prior_mean=0.0, - temperature=1.0, - enable_backprop=False, - backend=None, - damping=False, - backend_kwargs=None, - asdl_fisher_kwargs=None, + model: nn.Module, + likelihood: Likelihood | str, + sigma_noise: float | torch.Tensor = 1.0, + prior_precision: float | torch.Tensor = 1.0, + prior_mean: float | torch.Tensor = 0.0, + temperature: float = 1.0, + enable_backprop: bool = False, + backend: Type[CurvatureInterface] | None = None, + damping: bool = False, + backend_kwargs: dict | None = None, + asdl_fisher_kwargs: dict | None = None, ): - self.damping = damping - self.H_facs = None + self.damping: bool = damping + self.H_facs: Kron | None = None super().__init__( model, likelihood, @@ -1391,26 +1403,35 @@ def __init__( asdl_fisher_kwargs, ) - def _init_H(self): - self.H = Kron.init_from_model(self.params, self._device) + def _init_H(self) -> None: + self.H: Kron | KronDecomposed | None = Kron.init_from_model( + self.params, self._device + ) - def _curv_closure(self, X, y, N): + def _curv_closure( + self, X: torch.Tensor | MutableMapping, y: torch.Tensor, N: int + ) -> Tuple[torch.Tensor, torch.Tensor]: return self.backend.kron(X, y, N=N, **self._asdl_fisher_kwargs) @staticmethod - def _rescale_factors(kron, factor): + def _rescale_factors(kron: Kron, factor: float) -> Kron: for F in kron.kfacs: if len(F) == 2: F[1] *= factor return kron - def fit(self, train_loader, override=True, progress_bar=False): + def fit( + self, + train_loader: DataLoader, + override: bool = True, + progress_bar: bool = False, + ) -> None: if override: self.H_facs = None if self.H_facs is not None: - n_data_old = self.n_data - n_data_new = len(train_loader.dataset) + n_data_old: int = self.n_data + n_data_new: int = len(train_loader.dataset) self._init_H() # re-init H non-decomposed # 
discount previous Kronecker factors to sum up properly together with new ones self.H_facs = self._rescale_factors( @@ -1427,11 +1448,12 @@ def fit(self, train_loader, override=True, progress_bar=False): self.H, n_data_new / (n_data_new + n_data_old) ) self.H_facs += self.H + # Decompose to self.H for all required quantities but keep H_facs for further inference self.H = self.H_facs.decompose(damping=self.damping) @property - def posterior_precision(self): + def posterior_precision(self) -> KronDecomposed: """Kronecker factored Posterior precision \\(P\\). Returns @@ -1442,29 +1464,31 @@ def posterior_precision(self): return self.H * self._H_factor + self.prior_precision @property - def log_det_posterior_precision(self): + def log_det_posterior_precision(self) -> torch.Tensor: if type(self.H) is Kron: # Fall back to diag prior return self.prior_precision_diag.log().sum() return self.posterior_precision.logdet() - def square_norm(self, value): + def square_norm(self, value: torch.Tensor) -> torch.Tensor: delta = value - self.mean if type(self.H) is Kron: # fall back to prior return (delta * self.prior_precision_diag) @ delta return delta @ self.posterior_precision.bmm(delta, exponent=1) - def functional_variance(self, Js): - return self.posterior_precision.inv_square_form(Js) + def functional_variance(self, Jacs: torch.Tensor) -> torch.Tensor: + return self.posterior_precision.inv_square_form(Jacs) - def functional_covariance(self, Js): - self._check_jacobians(Js) - n_batch, n_outs, n_params = Js.shape - Js = Js.reshape(n_batch * n_outs, n_params).unsqueeze(0) - cov = self.posterior_precision.inv_square_form(Js).squeeze(0) + def functional_covariance(self, Jacs: torch.Tensor) -> torch.Tensor: + self._check_jacobians(Jacs) + n_batch, n_outs, n_params = Jacs.shape + Jacs = Jacs.reshape(n_batch * n_outs, n_params).unsqueeze(0) + cov = self.posterior_precision.inv_square_form(Jacs).squeeze(0) assert cov.shape == (n_batch * n_outs, n_batch * n_outs) return cov - def sample(self, n_samples=100, generator=None): + def sample( + self, n_samples: int = 100, generator: torch.Generator | None = None + ) -> torch.Tensor: samples = torch.randn( n_samples, self.n_params, device=self._device, generator=generator ) @@ -1474,7 +1498,7 @@ def sample(self, n_samples=100, generator=None): ) @BaseLaplace.prior_precision.setter - def prior_precision(self, prior_precision): + def prior_precision(self, prior_precision: torch.Tensor) -> None: # Extend setter from Laplace to restrict prior precision structure. 
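# Note on the check enforced right below: with a Kronecker-factored posterior the
# prior precision can only be added per layer, so `KronLaplace` accepts a scalar
# or a per-layer value but not a full per-parameter diagonal. A hypothetical
# fitted instance `la` would behave as:
#     la.prior_precision = 0.7                      # scalar, shared by all layers
#     la.prior_precision = torch.ones(la.n_layers)  # one value per layer
#     la.prior_precision = torch.ones(la.n_params)  # raises ValueError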
super(KronLaplace, type(self)).prior_precision.fset(self, prior_precision) if len(self.prior_precision) not in [1, self.n_layers]: @@ -1482,12 +1506,14 @@ def prior_precision(self, prior_precision): def state_dict(self) -> dict: state_dict = super().state_dict() + assert isinstance(self.H_facs, Kron) state_dict['H'] = self.H_facs.kfacs return state_dict def load_state_dict(self, state_dict: dict): super().load_state_dict(state_dict) self._init_H() + assert isinstance(self.H, Kron) self.H_facs = self.H self.H_facs.kfacs = state_dict['H'] self.H = self.H_facs.decompose(damping=self.damping) @@ -1511,15 +1537,14 @@ class LowRankLaplace(ParametricLaplace): def __init__( self, - model, - likelihood, - sigma_noise=1, - prior_precision=1, - prior_mean=0, - temperature=1, - enable_backprop=False, - backend=AsdfghjklHessian, - backend_kwargs=None, + model: nn.Module, + likelihood: Likelihood | str, + sigma_noise: float | torch.Tensor = 1, + prior_precision: float | torch.Tensor = 1, + prior_mean: float | torch.Tensor = 0, + temperature: float = 1, + enable_backprop: bool = False, + backend_kwargs: dict | None = None, ): super().__init__( model, @@ -1529,24 +1554,30 @@ def __init__( prior_mean=prior_mean, temperature=temperature, enable_backprop=enable_backprop, - backend=backend, + backend=AsdfghjklHessian, backend_kwargs=backend_kwargs, ) + self.backend: AsdfghjklHessian def _init_H(self): - self.H = None + self.H: Tuple[torch.Tensor, torch.Tensor] | None = None @property - def V(self): - (U, l), prior_prec_diag = self.posterior_precision + def V(self) -> torch.Tensor: + (U, eigvals), prior_prec_diag = self.posterior_precision return U / prior_prec_diag.reshape(-1, 1) @property - def Kinv(self): - (U, l), _ = self.posterior_precision - return torch.inverse(torch.diag(1 / l) + U.T @ self.V) + def Kinv(self) -> torch.Tensor: + (U, eigvals), _ = self.posterior_precision + return torch.inverse(torch.diag(1 / eigvals) + U.T @ self.V) - def fit(self, train_loader, override=True): + def fit( + self, + train_loader: DataLoader, + override: bool = True, + progress_bar: bool = False, + ) -> None: # override fit since output of eighessian not additive across batch if not override: # LowRankLA cannot be updated since eigenvalue representation not additive @@ -1574,7 +1605,9 @@ def fit(self, train_loader, override=True): self.n_data = len(train_loader.dataset) @property - def posterior_precision(self): + def posterior_precision( + self, + ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], torch.Tensor]: """Return correctly scaled posterior precision that would be constructed as H[0] @ diag(H[1]) @ H[0].T + self.prior_precision_diag. 
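The helper properties `V` and `Kinv` defined earlier implement a Woodbury-style inverse of exactly this low-rank-plus-diagonal structure, which is what `functional_variance` and `functional_covariance` below rely on. A self-contained numerical sketch with random factors (not taken from a fitted model):

import torch

n, k = 20, 3
U = torch.randn(n, k, dtype=torch.double)
l = torch.rand(k, dtype=torch.double) + 0.1   # positive low-rank eigenvalues
d = torch.rand(n, dtype=torch.double) + 0.1   # positive diagonal prior precision

P = U @ torch.diag(l) @ U.T + torch.diag(d)

V = U / d.reshape(-1, 1)                           # D^{-1} U, as in the `V` property
Kinv = torch.inverse(torch.diag(1 / l) + U.T @ V)  # as in the `Kinv` property
Sigma = torch.diag(1 / d) - V @ Kinv @ V.T         # Woodbury identity for P^{-1}

assert torch.allclose(Sigma, torch.inverse(P), atol=1e-8)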
@@ -1588,13 +1621,13 @@ def posterior_precision(self): self._check_H_init() return (self.H[0], self._H_factor * self.H[1]), self.prior_precision_diag - def functional_variance(self, Jacs): + def functional_variance(self, Jacs: torch.Tensor) -> torch.Tensor: prior_var = torch.einsum('ncp,nkp->nck', Jacs / self.prior_precision_diag, Jacs) Jacs_V = torch.einsum('ncp,pl->ncl', Jacs, self.V) info_gain = torch.einsum('ncl,nkl->nck', Jacs_V @ self.Kinv, Jacs_V) return prior_var - info_gain - def functional_covariance(self, Jacs): + def functional_covariance(self, Jacs: torch.Tensor) -> torch.Tensor: n_batch, n_outs, n_params = Jacs.shape Jacs = Jacs.reshape(n_batch * n_outs, n_params) prior_cov = torch.einsum('np,mp->nm', Jacs / self.prior_precision_diag, Jacs) @@ -1604,7 +1637,9 @@ def functional_covariance(self, Jacs): assert cov.shape == (n_batch * n_outs, n_batch * n_outs) return cov - def sample(self, n_samples, generator=None): + def sample( + self, n_samples: int = 100, generator: torch.Generator | None = None + ) -> torch.Tensor: samples = torch.randn(self.n_params, n_samples, generator=generator) d = self.prior_precision_diag Vs = self.V * d.sqrt().reshape(-1, 1) @@ -1621,9 +1656,11 @@ def sample(self, n_samples, generator=None): return self.mean + (prior_sample - gain_sample).T @property - def log_det_posterior_precision(self): - (U, l), prior_prec_diag = self.posterior_precision - return l.log().sum() + prior_prec_diag.log().sum() - torch.logdet(self.Kinv) + def log_det_posterior_precision(self) -> torch.Tensor: + (U, eigvals), prior_prec_diag = self.posterior_precision + return ( + eigvals.log().sum() + prior_prec_diag.log().sum() - torch.logdet(self.Kinv) + ) class DiagLaplace(ParametricLaplace): @@ -1636,14 +1673,16 @@ class DiagLaplace(ParametricLaplace): # key to map to correct subclass of BaseLaplace, (subset of weights, Hessian structure) _key = ('all', 'diag') - def _init_H(self): - self.H = torch.zeros(self.n_params, device=self._device) + def _init_H(self) -> None: + self.H: torch.Tensor = torch.zeros(self.n_params, device=self._device) - def _curv_closure(self, X, y, N): + def _curv_closure( + self, X: torch.Tensor | MutableMapping, y: torch.Tensor, N: int + ) -> Tuple[torch.Tensor, torch.Tensor]: return self.backend.diag(X, y, N=N, **self._asdl_fisher_kwargs) @property - def posterior_precision(self): + def posterior_precision(self) -> torch.Tensor: """Diagonal posterior precision \\(p\\). Returns @@ -1655,7 +1694,7 @@ def posterior_precision(self): return self._H_factor * self.H + self.prior_precision_diag @property - def posterior_scale(self): + def posterior_scale(self) -> torch.Tensor: """Diagonal posterior scale \\(\\sqrt{p^{-1}}\\). Returns @@ -1666,7 +1705,7 @@ def posterior_scale(self): return 1 / self.posterior_precision.sqrt() @property - def posterior_variance(self): + def posterior_variance(self) -> torch.Tensor: """Diagonal posterior variance \\(p^{-1}\\). 
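These three diagonal quantities are elementwise transforms of one another, and `sample` below uses the scale directly. A short sketch, assuming `la` is a fitted `DiagLaplace`:

import torch

p = la.posterior_precision                      # (n_params,)
assert torch.allclose(la.posterior_variance, 1 / p)
assert torch.allclose(la.posterior_scale, 1 / p.sqrt())

eps = torch.randn(16, la.n_params, device=la.mean.device)
theta = la.mean + eps * la.posterior_scale      # mirrors la.sample(n_samples=16)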
Returns @@ -1677,25 +1716,27 @@ def posterior_variance(self): return 1 / self.posterior_precision @property - def log_det_posterior_precision(self): + def log_det_posterior_precision(self) -> torch.Tensor: return self.posterior_precision.log().sum() - def square_norm(self, value): + def square_norm(self, value: torch.Tensor) -> torch.Tensor: delta = value - self.mean return delta @ (delta * self.posterior_precision) - def functional_variance(self, Js: torch.Tensor) -> torch.Tensor: - self._check_jacobians(Js) - return torch.einsum('ncp,p,nkp->nck', Js, self.posterior_variance, Js) + def functional_variance(self, Jacs: torch.Tensor) -> torch.Tensor: + self._check_jacobians(Jacs) + return torch.einsum('ncp,p,nkp->nck', Jacs, self.posterior_variance, Jacs) - def functional_covariance(self, Js): - self._check_jacobians(Js) - n_batch, n_outs, n_params = Js.shape - Js = Js.reshape(n_batch * n_outs, n_params) - cov = torch.einsum('np,p,mp->nm', Js, self.posterior_variance, Js) + def functional_covariance(self, Jacs: torch.Tensor) -> torch.Tensor: + self._check_jacobians(Jacs) + n_batch, n_outs, n_params = Jacs.shape + Jacs = Jacs.reshape(n_batch * n_outs, n_params) + cov = torch.einsum('np,p,mp->nm', Jacs, self.posterior_variance, Jacs) return cov - def sample(self, n_samples=100, generator=None): + def sample( + self, n_samples: int = 100, generator: torch.Generator | None = None + ) -> torch.Tensor: samples = torch.randn( n_samples, self.n_params, device=self._device, generator=generator ) From dace2a358796cc7b8f4d3428c9c6da59f9540809 Mon Sep 17 00:00:00 2001 From: Agustinus Kristiadi Date: Sat, 27 Apr 2024 16:44:34 -0400 Subject: [PATCH 4/9] Typehinting `laplace.py` and `marglik_training.py` --- laplace/__init__.py | 20 ++++++++-- laplace/baselaplace.py | 26 +++++++------ laplace/laplace.py | 27 +++++++------ laplace/marglik_training.py | 76 ++++++++++++++++++++----------------- tests/test_baselaplace.py | 3 +- tests/test_serialization.py | 2 +- tests/test_subset_params.py | 4 +- 7 files changed, 91 insertions(+), 67 deletions(-) diff --git a/laplace/__init__.py b/laplace/__init__.py index 0d338f91..d8393df0 100644 --- a/laplace/__init__.py +++ b/laplace/__init__.py @@ -5,9 +5,6 @@ .. 
include:: ../examples/calibration_example.md """ -REGRESSION = 'regression' -CLASSIFICATION = 'classification' - from laplace.baselaplace import ( BaseLaplace, ParametricLaplace, @@ -15,6 +12,13 @@ KronLaplace, DiagLaplace, LowRankLaplace, + SubsetOfWeights, + HessianStructure, + Likelihood, + PredType, + LinkApprox, + TuningMethod, + PriorStructure, ) from laplace.lllaplace import LLLaplace, FullLLLaplace, KronLLLaplace, DiagLLLaplace from laplace.subnetlaplace import SubnetLaplace, FullSubnetLaplace, DiagSubnetLaplace @@ -37,4 +41,12 @@ 'FullSubnetLaplace', 'DiagSubnetLaplace', # subnetwork 'marglik_training', -] # methods + # Enums + 'SubsetOfWeights', + 'HessianStructure', + 'Likelihood', + 'PredType', + 'LinkApprox', + 'TuningMethod', + 'PriorStructure', +] diff --git a/laplace/baselaplace.py b/laplace/baselaplace.py index acdaab05..18b9c8fe 100644 --- a/laplace/baselaplace.py +++ b/laplace/baselaplace.py @@ -41,27 +41,35 @@ 'LinkApprox', 'TuningMethod', 'PriorStructure', + 'SubsetOfWeights', ] -class Likelihood(str, Enum): - """Choices of likelihoods supported by Laplace""" +class SubsetOfWeights(str, Enum): + ALL = 'all' + LAST_LAYER = 'last_layer' + SUBNETWORK = 'subnetwork' + + +class HessianStructure(str, Enum): + FULL = 'full' + KRON = 'kron' + DIAG = 'diag' + LOWRANK = 'lowrank' + +class Likelihood(str, Enum): REGRESSION = 'regression' CLASSIFICATION = 'classification' REWARD_MODELING = 'reward_modeling' class PredType(str, Enum): - """Choices of predictive types. To obtain the p(f(x) | x, D)""" - GLM = 'glm' NN = 'nn' class LinkApprox(str, Enum): - """Choices of the inverse link function p(f(x) | x, D) -> p(y | f(x), D)""" - MC = 'mc' PROBIT = 'probit' BRIDGE = 'bridge' @@ -69,15 +77,11 @@ class LinkApprox(str, Enum): class TuningMethod(str, Enum): - """Choices of prior precision tuning methods""" - MARGLIK = 'marglik' GRIDSEARCH = 'gridsearch' class PriorStructure(str, Enum): - """Choices of prior precision structures""" - SCALAR = 'scalar' DIAG = 'diag' LAYERWISE = 'layerwise' @@ -1048,7 +1052,7 @@ def _nn_predictive_samples( def _nn_predictive_classification( self, X: torch.Tensor | MutableMapping, n_samples: int = 100, **model_kwargs ) -> torch.Tensor: - py = torch.Tensor(0.0) + py = 0.0 for sample in self.sample(n_samples): vector_to_parameters(sample, self.params) logits = self.model( diff --git a/laplace/laplace.py b/laplace/laplace.py index f8c45778..9469c725 100644 --- a/laplace/laplace.py +++ b/laplace/laplace.py @@ -1,24 +1,29 @@ -from laplace.baselaplace import ParametricLaplace -from laplace import * +from laplace.baselaplace import ( + SubsetOfWeights, + HessianStructure, + Likelihood, + ParametricLaplace, +) +import torch def Laplace( - model, - likelihood, - subset_of_weights='last_layer', - hessian_structure='kron', + model: torch.nn.Module, + likelihood: Likelihood | str, + subset_of_weights: SubsetOfWeights | str = SubsetOfWeights.LAST_LAYER, + hessian_structure: HessianStructure | str = HessianStructure.KRON, *args, **kwargs, -): +) -> ParametricLaplace: """Simplified Laplace access using strings instead of different classes. 
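In practice the factory is called with either plain strings or the string-valued enums re-exported above; both spellings are accepted. A minimal sketch, assuming `model` and `train_loader` exist:

from laplace import Laplace, Likelihood, SubsetOfWeights, HessianStructure

# string spelling
la = Laplace(model, 'classification',
             subset_of_weights='last_layer', hessian_structure='kron')

# equivalent spelling with the exported enums
la = Laplace(model, Likelihood.CLASSIFICATION,
             subset_of_weights=SubsetOfWeights.LAST_LAYER,
             hessian_structure=HessianStructure.KRON)
la.fit(train_loader)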
Parameters ---------- model : torch.nn.Module - likelihood : {'classification', 'regression'} - subset_of_weights : {'last_layer', 'subnetwork', 'all'}, default='last_layer' + likelihood : Likelihood or str in {'classification', 'regression'} + subset_of_weights : SubsetofWeights or {'last_layer', 'subnetwork', 'all'}, default=SubsetOfWeights.LAST_LAYER subset of weights to consider for inference - hessian_structure : {'diag', 'kron', 'full', 'lowrank'}, default='kron' + hessian_structure : HessianStructure or str in {'diag', 'kron', 'full', 'lowrank'}, default=HessianStructure.KRON structure of the Hessian approximation Returns @@ -40,7 +45,7 @@ def Laplace( return laplace_class(model, likelihood, *args, **kwargs) -def _all_subclasses(cls): +def _all_subclasses(cls) -> set: return set(cls.__subclasses__()).union( [s for c in cls.__subclasses__() for s in _all_subclasses(c)] ) diff --git a/laplace/marglik_training.py b/laplace/marglik_training.py index b57ac520..5a7b3d4c 100644 --- a/laplace/marglik_training.py +++ b/laplace/marglik_training.py @@ -1,41 +1,47 @@ from copy import deepcopy import numpy as np import torch -from torch.optim import Adam +from torch.optim import Adam, Optimizer +from torch.optim.lr_scheduler import LRScheduler from torch.nn import CrossEntropyLoss, MSELoss from torch.nn.utils import parameters_to_vector +from torch.utils.data import DataLoader import warnings import logging from collections import UserDict import tqdm -from laplace import Laplace +from laplace import Laplace, Likelihood, HessianStructure, PriorStructure +from laplace.baselaplace import SubsetOfWeights from laplace.curvature import AsdlGGN +from laplace.curvature.curvature import CurvatureInterface from laplace.utils import expand_prior_precision, fix_prior_prec_structure +from typing import Type + def marglik_training( - model, - train_loader, - likelihood='classification', - hessian_structure='kron', - backend=AsdlGGN, - optimizer_cls=Adam, - optimizer_kwargs=None, - scheduler_cls=None, - scheduler_kwargs=None, - n_epochs=300, - lr_hyp=1e-1, - prior_structure='layerwise', - n_epochs_burnin=0, - n_hypersteps=10, - marglik_frequency=1, - prior_prec_init=1.0, - sigma_noise_init=1.0, - temperature=1.0, - fix_sigma_noise=False, - progress_bar=False, - enable_backprop=False, + model: torch.nn.Module, + train_loader: DataLoader, + likelihood: Likelihood | str = Likelihood.CLASSIFICATION, + hessian_structure: HessianStructure | str = HessianStructure.KRON, + backend: Type[CurvatureInterface] = AsdlGGN, + optimizer_cls: Type[Optimizer] = Adam, + optimizer_kwargs: dict | None = None, + scheduler_cls: Type[LRScheduler] | None = None, + scheduler_kwargs: dict | None = None, + n_epochs: int = 300, + lr_hyp: float = 1e-1, + prior_structure: PriorStructure | str = PriorStructure.LAYERWISE, + n_epochs_burnin: int = 0, + n_hypersteps: int = 10, + marglik_frequency: int = 1, + prior_prec_init: float = 1.0, + sigma_noise_init: float = 1.0, + temperature: float = 1.0, + fix_sigma_noise: bool = False, + progress_bar: bool = False, + enable_backprop: bool = False, ): """Marginal-likelihood based training (Algorithm 1 in [1]). Optimize model parameters and hyperparameters jointly. @@ -48,7 +54,7 @@ def marglik_training( The settings of standard training can be controlled by passing `train_loader`, `optimizer_cls`, `optimizer_kwargs`, `scheduler_cls`, `scheduler_kwargs`, and `n_epochs`. The `model` should return logits, i.e., no softmax should be applied. 
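A minimal invocation sketch, assuming `model` (returning logits) and `train_loader` are defined elsewhere; the function returns the fitted Laplace approximation, the trained model, and the histories of marginal likelihoods and losses:

from torch.optim import Adam
from torch.optim.lr_scheduler import CosineAnnealingLR
from laplace import marglik_training, Likelihood, HessianStructure

la, model, margliks, losses = marglik_training(
    model,
    train_loader,
    likelihood=Likelihood.CLASSIFICATION,
    hessian_structure=HessianStructure.KRON,
    optimizer_cls=Adam,
    optimizer_kwargs={'lr': 1e-3},
    scheduler_cls=CosineAnnealingLR,
    scheduler_kwargs={'T_max': 300},
    n_epochs=300,
    prior_structure='layerwise',
    marglik_frequency=5,
)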
- With `likelihood='classification'` or `'regression'`, one can choose between + With `likelihood=Likelihood.CLASSIFICATION` or `Likelihood.REGRESSION`, one can choose between categorical likelihood (CrossEntropyLoss) and Gaussian likelihood (MSELoss). As in [1], we optimize prior precision and, for regression, observation noise @@ -73,8 +79,8 @@ def marglik_training( torch neural network model (needs to comply with Backend choice) train_loader : DataLoader pytorch dataloader that implements `len(train_loader.dataset)` to obtain number of data points - likelihood : str, default='classification' - 'classification' or 'regression' + likelihood : str, default=Likelihood.CLASSIFICATION + Likelihood.CLASSIFICATION or Likelihood.REGRESSION hessian_structure : {'diag', 'kron', 'full'}, default='kron' structure of the Hessian approximation backend : Backend, default=AsdlGGN @@ -127,7 +133,7 @@ def marglik_training( losses : list list of losses (log joints) obtained during training (to monitor convergence) """ - if 'weight_decay' in optimizer_kwargs: + if optimizer_kwargs is not None and 'weight_decay' in optimizer_kwargs: warnings.warn('Weight decay is handled and optimized. Will be set to 0.') optimizer_kwargs['weight_decay'] = 0.0 @@ -149,10 +155,10 @@ def marglik_training( hyperparameters.append(log_prior_prec) # set up loss (and observation noise hyperparam) - if likelihood == 'classification': + if likelihood == Likelihood.CLASSIFICATION: criterion = CrossEntropyLoss(reduction='mean') sigma_noise = 1.0 - elif likelihood == 'regression': + elif likelihood == Likelihood.REGRESSION: criterion = MSELoss(reduction='mean') log_sigma_noise_init = np.log(sigma_noise_init) log_sigma_noise = log_sigma_noise_init * torch.ones(1, device=device) @@ -201,7 +207,7 @@ def marglik_training( X, y = data X, y = X.to(device, non_blocking=True), y.to(device, non_blocking=True) optimizer.zero_grad() - if likelihood == 'regression': + if likelihood == Likelihood.REGRESSION: sigma_noise = ( torch.exp(log_sigma_noise).detach() if not fix_sigma_noise @@ -220,7 +226,7 @@ def marglik_training( loss.backward() optimizer.step() epoch_loss += loss.cpu().item() * len(y) - if likelihood == 'regression': + if likelihood == Likelihood.REGRESSION: epoch_perf += (f.detach() - y).square().sum() else: epoch_perf += torch.sum(torch.argmax(f.detach(), dim=-1) == y).item() @@ -241,7 +247,7 @@ def marglik_training( # optimizer hyperparameters by differentiating marglik # 1. fit laplace approximation - if likelihood == 'classification': + if likelihood == Likelihood.CLASSIFICATION: sigma_noise = 1 else: sigma_noise = ( @@ -263,7 +269,7 @@ def marglik_training( # 2. differentiate wrt. 
hyperparameters for n_hypersteps for _ in range(n_hypersteps): hyper_optimizer.zero_grad() - if likelihood == 'classification' or fix_sigma_noise: + if likelihood == Likelihood.CLASSIFICATION or fix_sigma_noise: sigma_noise = None else: sigma_noise = torch.exp(log_sigma_noise) @@ -277,7 +283,7 @@ def marglik_training( if margliks[-1] < best_marglik: best_model_dict = deepcopy(model.state_dict()) best_precision = deepcopy(prior_prec.detach()) - if likelihood == 'classification': + if likelihood == Likelihood.CLASSIFICATION: best_sigma = 1 else: best_sigma = ( @@ -309,7 +315,7 @@ def marglik_training( prior_precision=prior_prec, temperature=temperature, backend=backend, - subset_of_weights='all', + subset_of_weights=SubsetOfWeights.ALL, enable_backprop=enable_backprop, ) lap.fit(train_loader) diff --git a/tests/test_baselaplace.py b/tests/test_baselaplace.py index 862b3d67..0d9e37ac 100644 --- a/tests/test_baselaplace.py +++ b/tests/test_baselaplace.py @@ -10,11 +10,10 @@ from torch.nn.utils import parameters_to_vector from torch.utils.data import DataLoader, TensorDataset from torch.distributions import Normal, Categorical -from laplace.curvature.asdfghjkl import AsdfghjklGGN, AsdfghjklEF from laplace.curvature.curvlinops import CurvlinopsEF, CurvlinopsGGN from torchvision.models import wide_resnet50_2 -from laplace.laplace import FullLaplace, KronLaplace, DiagLaplace, LowRankLaplace +from laplace import FullLaplace, KronLaplace, DiagLaplace, LowRankLaplace from laplace.utils import KronDecomposed from laplace.curvature import AsdlGGN, BackPackGGN, AsdlEF from tests.utils import ListDataset, dict_data_collator, jacobians_naive diff --git a/tests/test_serialization.py b/tests/test_serialization.py index 2a6e022b..1ad007d7 100644 --- a/tests/test_serialization.py +++ b/tests/test_serialization.py @@ -8,7 +8,7 @@ from collections import OrderedDict from laplace import Laplace -from laplace.laplace import ( +from laplace import ( FullLaplace, KronLaplace, DiagLaplace, diff --git a/tests/test_subset_params.py b/tests/test_subset_params.py index 24e4a78a..92f3190f 100644 --- a/tests/test_subset_params.py +++ b/tests/test_subset_params.py @@ -11,10 +11,8 @@ from torch.distributions import Normal, Categorical from torchvision.models import wide_resnet50_2 -from laplace.laplace import FullLaplace, KronLaplace, DiagLaplace, LowRankLaplace -from laplace.utils import KronDecomposed +from laplace import FullLaplace, KronLaplace, DiagLaplace from laplace.curvature import AsdlGGN, AsdlHessian, AsdlEF, BackPackEF, BackPackGGN -from tests.utils import jacobians_naive torch.manual_seed(240) From 8f2503c04dc96c99cc7bcbf2681e5045621f5e40 Mon Sep 17 00:00:00 2001 From: Agustinus Kristiadi Date: Sat, 27 Apr 2024 18:41:36 -0400 Subject: [PATCH 5/9] Typehinting `SubnetLaplace` & `LLLaplace`. 
Finer-grained typehint on `MutableMapping` --- laplace/baselaplace.py | 45 ++++++---- laplace/lllaplace.py | 160 +++++++++++++++++++++++------------- laplace/marglik_training.py | 12 ++- laplace/subnetlaplace.py | 63 ++++++++------ 4 files changed, 182 insertions(+), 98 deletions(-) diff --git a/laplace/baselaplace.py b/laplace/baselaplace.py index 18b9c8fe..6608d34a 100644 --- a/laplace/baselaplace.py +++ b/laplace/baselaplace.py @@ -2,7 +2,7 @@ from enum import Enum from math import sqrt, pi, log -from typing import Callable, List, Tuple, Type +from typing import Callable, List, Tuple, Type, Any import numpy as np import torch from torch import nn @@ -202,7 +202,10 @@ def backend(self) -> CurvatureInterface: return self._backend def _curv_closure( - self, X: torch.Tensor | MutableMapping, y: torch.Tensor, N: int + self, + X: torch.Tensor | MutableMapping[str, torch.Tensor | Any], + y: torch.Tensor, + N: int, ) -> Tuple[torch.Tensor, torch.Tensor]: raise NotImplementedError @@ -242,7 +245,7 @@ def log_likelihood(self) -> torch.Tensor: def __call__( self, - x: torch.Tensor | MutableMapping, + x: torch.Tensor | MutableMapping[str, torch.Tensor | Any], pred_type: PredType | str, link_approx: LinkApprox | str, n_samples: int, @@ -647,9 +650,9 @@ def fit( if not self.enable_backprop: self.mean = self.mean.detach() - data: Tuple[torch.Tensor, torch.Tensor] | MutableMapping = next( - iter(train_loader) - ) + data: ( + Tuple[torch.Tensor, torch.Tensor] | MutableMapping[str, torch.Tensor | Any] + ) = next(iter(train_loader)) with torch.no_grad(): if isinstance(data, MutableMapping): # To support Huggingface dataset @@ -809,7 +812,7 @@ def log_marginal_likelihood( def __call__( self, - x: torch.Tensor | MutableMapping, + x: torch.Tensor | MutableMapping[str, torch.Tensor | Any], pred_type: PredType | str = PredType.GLM, joint: bool = False, link_approx: LinkApprox | str = LinkApprox.PROBIT, @@ -945,7 +948,7 @@ def __call__( def predictive_samples( self, - x: torch.Tensor | MutableMapping, + x: torch.Tensor | MutableMapping[str, torch.Tensor | Any], pred_type: PredType | str = PredType.GLM, n_samples: int = 100, diagonal_output: bool = False, @@ -1003,7 +1006,9 @@ def predictive_samples( @torch.enable_grad() def _glm_predictive_distribution( - self, X: torch.Tensor | MutableMapping, joint: bool = False + self, + X: torch.Tensor | MutableMapping[str, torch.Tensor | Any], + joint: bool = False, ) -> Tuple[torch.Tensor, torch.Tensor]: if 'backpack' in self._backend_cls.__name__.lower(): # BackPACK supports backprop through Jacobians, but it interferes with functorch @@ -1028,7 +1033,7 @@ def _glm_predictive_distribution( def _nn_predictive_samples( self, - X: torch.Tensor | MutableMapping, + X: torch.Tensor | MutableMapping[str, torch.Tensor | Any], n_samples: int = 100, generator: torch.Generator | None = None, **model_kwargs, @@ -1050,7 +1055,10 @@ def _nn_predictive_samples( return fs def _nn_predictive_classification( - self, X: torch.Tensor | MutableMapping, n_samples: int = 100, **model_kwargs + self, + X: torch.Tensor | MutableMapping[str, torch.Tensor | Any], + n_samples: int = 100, + **model_kwargs, ) -> torch.Tensor: py = 0.0 for sample in self.sample(n_samples): @@ -1281,7 +1289,10 @@ def _init_H(self) -> None: ) def _curv_closure( - self, X: torch.Tensor | MutableMapping, y: torch.Tensor, N: int + self, + X: torch.Tensor | MutableMapping[str, torch.Tensor | Any], + y: torch.Tensor, + N: int, ) -> Tuple[torch.Tensor, torch.Tensor]: return self.backend.full(X, y, N=N) @@ -1413,7 +1424,10 @@ 
def _init_H(self) -> None: ) def _curv_closure( - self, X: torch.Tensor | MutableMapping, y: torch.Tensor, N: int + self, + X: torch.Tensor | MutableMapping[str, torch.Tensor | Any], + y: torch.Tensor, + N: int, ) -> Tuple[torch.Tensor, torch.Tensor]: return self.backend.kron(X, y, N=N, **self._asdl_fisher_kwargs) @@ -1681,7 +1695,10 @@ def _init_H(self) -> None: self.H: torch.Tensor = torch.zeros(self.n_params, device=self._device) def _curv_closure( - self, X: torch.Tensor | MutableMapping, y: torch.Tensor, N: int + self, + X: torch.Tensor | MutableMapping[str, torch.Tensor | Any], + y: torch.Tensor, + N: int, ) -> Tuple[torch.Tensor, torch.Tensor]: return self.backend.diag(X, y, N=N, **self._asdl_fisher_kwargs) diff --git a/laplace/lllaplace.py b/laplace/lllaplace.py index 54842c13..059e0e32 100644 --- a/laplace/lllaplace.py +++ b/laplace/lllaplace.py @@ -1,12 +1,21 @@ from copy import deepcopy import torch +from torch import nn from torch.nn.utils import parameters_to_vector, vector_to_parameters - -from laplace.baselaplace import ParametricLaplace, FullLaplace, KronLaplace, DiagLaplace +from torch.utils.data import DataLoader + +from laplace.baselaplace import ( + Likelihood, + ParametricLaplace, + FullLaplace, + KronLaplace, + DiagLaplace, +) +from laplace.curvature.curvature import CurvatureInterface from laplace.utils import FeatureExtractor, Kron from collections.abc import MutableMapping -from typing import Union +from typing import Tuple, Type, Any __all__ = ['LLLaplace', 'FullLLLaplace', 'KronLLLaplace', 'DiagLLLaplace'] @@ -37,7 +46,7 @@ class LLLaplace(ParametricLaplace): Parameters ---------- model : torch.nn.Module or `laplace.utils.feature_extractor.FeatureExtractor` - likelihood : {'classification', 'regression'} + likelihood : Likelihood or {'classification', 'regression'} determines the log likelihood Hessian approximation sigma_noise : torch.Tensor or float, default=1 observation noise for the regression setting; must be 1 for classification @@ -63,20 +72,21 @@ class LLLaplace(ParametricLaplace): def __init__( self, - model, - likelihood, - sigma_noise=1.0, - prior_precision=1.0, - prior_mean=0.0, - temperature=1.0, - enable_backprop=False, - backend=None, - last_layer_name=None, - backend_kwargs=None, - asdl_fisher_kwargs=None, + model: nn.Module, + likelihood: Likelihood | str, + sigma_noise: float | torch.Tensor = 1.0, + prior_precision: float | torch.Tensor = 1.0, + prior_mean: float | torch.Tensor = 0.0, + temperature: float = 1.0, + enable_backprop: bool = False, + backend: Type[CurvatureInterface] | None = None, + last_layer_name: str | None = None, + backend_kwargs: dict | None = None, + asdl_fisher_kwargs: dict | None = None, ): if asdl_fisher_kwargs is not None: raise ValueError('Last-layer Laplace does not support asdl_fisher_kwargs.') + self.H = None super().__init__( model, @@ -94,26 +104,33 @@ def __init__( last_layer_name=last_layer_name, enable_backprop=enable_backprop, ) + if self.model.last_layer is None: - self.mean = None - self.n_params = None - self.n_layers = None + self.mean: torch.Tensor | None = None + self.n_params: int | None = None + self.n_layers: int | None = None # ignore checks of prior mean setter temporarily, check on .fit() - self._prior_precision = prior_precision - self._prior_mean = prior_mean + self._prior_precision: float | torch.Tensor = prior_precision + self._prior_mean: float | torch.Tensor = prior_mean else: - self.n_params = len( + self.n_params: int = len( parameters_to_vector(self.model.last_layer.parameters()) ) - 
self.n_layers = len(list(self.model.last_layer.parameters())) - self.prior_precision = prior_precision - self.prior_mean = prior_mean - self.mean = self.prior_mean + self.n_layers: int | None = len(list(self.model.last_layer.parameters())) + self.prior_precision: float | torch.Tensor = prior_precision + self.prior_mean: float | torch.Tensor = prior_mean + self.mean: float | torch.Tensor = self.prior_mean self._init_H() + self._backend_kwargs['last_layer'] = True - self._last_layer_name = last_layer_name + self._last_layer_name: str | None = last_layer_name - def fit(self, train_loader, override=True): + def fit( + self, + train_loader: DataLoader, + override: bool = True, + progress_bar: bool = False, + ) -> None: """Fit the local Laplace approximation at the parameters of the model. Parameters @@ -124,6 +141,7 @@ def fit(self, train_loader, override=True): override : bool, default=True whether to initialize H, loss, and n_data again; setting to False is useful for online learning settings to accumulate a sequential posterior approximation. + progress_bar: bool, default=False """ if not override: raise ValueError( @@ -133,23 +151,31 @@ def fit(self, train_loader, override=True): self.model.eval() if self.model.last_layer is None: - self.data = next(iter(train_loader)) + self.data: Tuple[torch.Tensor, torch.Tensor] | MutableMapping = next( + iter(train_loader) + ) self._find_last_layer(self.data) - params = parameters_to_vector(self.model.last_layer.parameters()).detach() - self.n_params = len(params) - self.n_layers = len(list(self.model.last_layer.parameters())) + params: torch.Tensor = parameters_to_vector( + self.model.last_layer.parameters() + ).detach() + self.n_params: int = len(params) + self.n_layers: int = len(list(self.model.last_layer.parameters())) # here, check the already set prior precision again - self.prior_precision = self._prior_precision - self.prior_mean = self._prior_mean + self.prior_precision: float | torch.Tensor = self._prior_precision + self.prior_mean: float | torch.Tensor = self._prior_mean self._init_H() super().fit(train_loader, override=override) - self.mean = parameters_to_vector(self.model.last_layer.parameters()) + self.mean: torch.Tensor = parameters_to_vector( + self.model.last_layer.parameters() + ) if not self.enable_backprop: self.mean = self.mean.detach() - def _glm_predictive_distribution(self, X, joint=False): + def _glm_predictive_distribution( + self, X: torch.Tensor | MutableMapping, joint: bool = False + ) -> Tuple[torch.Tensor, torch.Tensor]: Js, f_mu = self.backend.last_layer_jacobians(X) if joint: @@ -164,32 +190,49 @@ def _glm_predictive_distribution(self, X, joint=False): else (f_mu, f_var) ) - def _nn_predictive_samples(self, X, n_samples=100, generator=None, **model_kwargs): + def _nn_predictive_samples( + self, + X: torch.Tensor | MutableMapping[str, torch.Tensor | Any], + n_samples: int = 100, + generator: torch.Generator | None = None, + **model_kwargs, + ) -> torch.Tensor: fs = list() + for sample in self.sample(n_samples, generator): vector_to_parameters(sample, self.model.last_layer.parameters()) f = self.model(X.to(self._device), **model_kwargs) fs.append(f.detach() if not self.enable_backprop else f) + vector_to_parameters(self.mean, self.model.last_layer.parameters()) fs = torch.stack(fs) + if self.likelihood == 'classification': fs = torch.softmax(fs, dim=-1) + return fs def _nn_predictive_classification( - self, X, n_samples=100, generator=None, **model_kwargs - ): - py = 0 + self, + X: torch.Tensor | MutableMapping, + 
n_samples: int = 100, + generator: torch.Generator | None = None, + **model_kwargs, + ) -> torch.Tensor: + py = 0.0 + for sample in self.sample(n_samples, generator): vector_to_parameters(sample, self.model.last_layer.parameters()) # TODO: Implement with a single forward pass until last layer. logits = self.model(X.to(self._device), **model_kwargs).detach() py += torch.softmax(logits, dim=-1) / n_samples + vector_to_parameters(self.mean, self.model.last_layer.parameters()) + return py @property - def prior_precision_diag(self): + def prior_precision_diag(self) -> torch.Tensor: """Obtain the diagonal prior precision \\(p_0\\) constructed from either a scalar or diagonal prior precision. @@ -197,12 +240,12 @@ def prior_precision_diag(self): ------- prior_precision_diag : torch.Tensor """ - if len(self.prior_precision) == 1: # scalar + if ( + isinstance(self.prior_precision, float) or len(self.prior_precision) == 1 + ): # scalar return self.prior_precision * torch.ones_like(self.mean) - elif len(self.prior_precision) == self.n_params: # diagonal return self.prior_precision - else: raise ValueError('Mismatch of prior and model. Diagonal or scalar prior.') @@ -212,7 +255,7 @@ def state_dict(self) -> dict: state_dict['_last_layer_name'] = self._last_layer_name return state_dict - def load_state_dict(self, state_dict: dict): + def load_state_dict(self, state_dict: dict) -> None: if self._last_layer_name != state_dict['_last_layer_name']: raise ValueError('Different `last_layer_name` detected!') @@ -227,12 +270,15 @@ def load_state_dict(self, state_dict: dict): self.n_layers = len(list(self.model.last_layer.parameters())) @torch.no_grad() - def _find_last_layer(self, data: Union[torch.Tensor, MutableMapping]) -> None: + def _find_last_layer( + self, data: torch.Tensor | MutableMapping[str, torch.Tensor | Any] + ) -> None: # To support Huggingface dataset if isinstance(data, MutableMapping): self.model.find_last_layer(data) else: X = data[0] + try: self.model.find_last_layer(X[:1].to(self._device)) except (TypeError, AttributeError): @@ -269,17 +315,18 @@ class KronLLLaplace(LLLaplace, KronLaplace): def __init__( self, - model, - likelihood, - sigma_noise=1.0, - prior_precision=1.0, - prior_mean=0.0, - temperature=1.0, - enable_backprop=False, - backend=None, - last_layer_name=None, - damping=False, - **backend_kwargs, + model: nn.Module, + likelihood: Likelihood | str, + sigma_noise: float | torch.Tensor = 1.0, + prior_precision: float | torch.Tensor = 1.0, + prior_mean: float | torch.Tensor = 0.0, + temperature: float = 1.0, + enable_backprop: bool = False, + backend: Type[CurvatureInterface] | None = None, + last_layer_name: str | None = None, + damping: bool = False, + backend_kwargs: dict | None = None, + asdl_fisher_kwargs: dict | None = None, ): self.damping = damping super().__init__( @@ -293,6 +340,7 @@ def __init__( backend, last_layer_name, backend_kwargs, + asdl_fisher_kwargs, ) def _init_H(self): diff --git a/laplace/marglik_training.py b/laplace/marglik_training.py index 5a7b3d4c..ac8778c1 100644 --- a/laplace/marglik_training.py +++ b/laplace/marglik_training.py @@ -171,6 +171,7 @@ def marglik_training( optimizer = optimizer_cls(model.parameters(), **optimizer_kwargs) # set up learning rate scheduler + scheduler = None if scheduler_cls is not None: if scheduler_kwargs is None: scheduler_kwargs = dict() @@ -194,6 +195,7 @@ def marglik_training( desc='[Training]', colour='blue', ) + for epoch in pbar: epoch_loss = 0 epoch_perf = 0 @@ -206,7 +208,9 @@ def marglik_training( else: X, y = 
data X, y = X.to(device, non_blocking=True), y.to(device, non_blocking=True) + optimizer.zero_grad() + if likelihood == Likelihood.REGRESSION: sigma_noise = ( torch.exp(log_sigma_noise).detach() @@ -216,21 +220,25 @@ def marglik_training( crit_factor = temperature / (2 * sigma_noise**2) else: crit_factor = temperature + prior_prec = torch.exp(log_prior_prec).detach() theta = parameters_to_vector( [p for p in model.parameters() if p.requires_grad] ) delta = expand_prior_precision(prior_prec, model) + f = model(X) loss = criterion(f, y) + (0.5 * (delta * theta) @ theta) / N / crit_factor loss.backward() optimizer.step() epoch_loss += loss.cpu().item() * len(y) + if likelihood == Likelihood.REGRESSION: epoch_perf += (f.detach() - y).square().sum() else: epoch_perf += torch.sum(torch.argmax(f.detach(), dim=-1) == y).item() - if scheduler_cls is not None: + + if scheduler is not None: scheduler.step() losses.append(epoch_loss / N) @@ -303,10 +311,12 @@ def marglik_training( ) logging.info('MARGLIK: finished training. Recover best model and fit Laplace.') + if best_model_dict is not None: model.load_state_dict(best_model_dict) sigma_noise = best_sigma prior_prec = best_precision + lap = Laplace( model, likelihood, diff --git a/laplace/subnetlaplace.py b/laplace/subnetlaplace.py index f92a571a..f6e80cf5 100644 --- a/laplace/subnetlaplace.py +++ b/laplace/subnetlaplace.py @@ -1,8 +1,11 @@ +from typing import Type import torch +from torch import nn from torch.distributions import MultivariateNormal -from laplace.baselaplace import ParametricLaplace, FullLaplace, DiagLaplace +from laplace.baselaplace import Likelihood, ParametricLaplace, FullLaplace, DiagLaplace from laplace.curvature import GGNInterface, EFInterface +from laplace.curvature.curvature import CurvatureInterface __all__ = ['SubnetLaplace', 'FullSubnetLaplace', 'DiagSubnetLaplace'] @@ -69,19 +72,20 @@ class SubnetLaplace(ParametricLaplace): def __init__( self, - model, - likelihood, - subnetwork_indices, - sigma_noise=1.0, - prior_precision=1.0, - prior_mean=0.0, - temperature=1.0, - backend=None, - backend_kwargs=None, - asdl_fisher_kwargs=None, - ): + model: nn.Module, + likelihood: Likelihood | str, + subnetwork_indices: torch.LongTensor, + sigma_noise: float | torch.Tensor = 1.0, + prior_precision: float | torch.Tensor = 1.0, + prior_mean: float | torch.Tensor = 0.0, + temperature: float = 1.0, + backend: Type[CurvatureInterface] | None = None, + backend_kwargs: dict | None = None, + asdl_fisher_kwargs: dict | None = None, + ) -> None: if asdl_fisher_kwargs is not None: raise ValueError('Subnetwork Laplace does not support asdl_fisher_kwargs.') + self.H = None super().__init__( model, @@ -93,28 +97,29 @@ def __init__( backend=backend, backend_kwargs=backend_kwargs, ) + if backend is not None: if not isinstance(backend, GGNInterface) and not isinstance( backend, EFInterface ): raise ValueError('SubnetLaplace can only be used with GGN and EF.') + # check validity of subnetwork indices and pass them to backend self._check_subnetwork_indices(subnetwork_indices) self.backend.subnetwork_indices = subnetwork_indices self.n_params_subnet = len(subnetwork_indices) self._init_H() - def _check_subnetwork_indices(self, subnetwork_indices): + def _check_subnetwork_indices( + self, subnetwork_indices: torch.LongTensor | None + ) -> None: """Check that subnetwork indices are valid indices of the vectorized model parameters (i.e. `torch.nn.utils.parameters_to_vector(model.parameters())`). 
""" if subnetwork_indices is None: raise ValueError('Subnetwork indices cannot be None.') elif not ( - ( - isinstance(subnetwork_indices, torch.LongTensor) - or isinstance(subnetwork_indices, torch.cuda.LongTensor) - ) + isinstance(subnetwork_indices, torch.LongTensor) and subnetwork_indices.numel() > 0 and len(subnetwork_indices.shape) == 1 ): @@ -132,7 +137,7 @@ def _check_subnetwork_indices(self, subnetwork_indices): raise ValueError('Subnetwork indices must not contain duplicate entries.') @property - def prior_precision_diag(self): + def prior_precision_diag(self) -> torch.Tensor: """Obtain the diagonal prior precision \\(p_0\\) constructed from either a scalar or diagonal prior precision. @@ -140,27 +145,26 @@ def prior_precision_diag(self): ------- prior_precision_diag : torch.Tensor """ - if len(self.prior_precision) == 1: # scalar + # scalar + if isinstance(self.prior_precision, float) or len(self.prior_precision) == 1: return self.prior_precision * torch.ones( self.n_params_subnet, device=self._device ) - elif len(self.prior_precision) == self.n_params_subnet: # diagonal return self.prior_precision - else: raise ValueError('Mismatch of prior and model. Diagonal or scalar prior.') @property - def mean_subnet(self): + def mean_subnet(self) -> torch.Tensor: return self.mean[self.backend.subnetwork_indices] @property - def scatter(self): + def scatter(self) -> torch.Tensor: delta = self.mean_subnet - self.prior_mean return (delta * self.prior_precision_diag) @ delta - def assemble_full_samples(self, subnet_samples): + def assemble_full_samples(self, subnet_samples) -> torch.Tensor: full_samples = self.mean.repeat(subnet_samples.shape[0], 1) full_samples[:, self.backend.subnetwork_indices] = subnet_samples return full_samples @@ -182,7 +186,9 @@ def _init_H(self): self.n_params_subnet, self.n_params_subnet, device=self._device ) - def sample(self, n_samples=100, generator=None): + def sample( + self, n_samples: int = 100, generator: torch.Generator | None = None + ) -> torch.Tensor: # sample only subnetwork parameters and set all other parameters to their MAP estimates dist = MultivariateNormal(loc=self.mean_subnet, scale_tril=self.posterior_scale) subnet_samples = dist.sample((n_samples,)) @@ -202,16 +208,19 @@ class DiagSubnetLaplace(SubnetLaplace, DiagLaplace): def _init_H(self): self.H = torch.zeros(self.n_params_subnet, device=self._device) - def _check_jacobians(self, Js): + def _check_jacobians(self, Js: torch.Tensor) -> None: if not isinstance(Js, torch.Tensor): raise ValueError('Jacobians have to be torch.Tensor.') if not Js.device == self._device: raise ValueError('Jacobians need to be on the same device as Laplace.') + m, k, p = Js.size() if p != self.n_params_subnet: raise ValueError('Invalid Jacobians shape for Laplace posterior approx.') - def sample(self, n_samples=100, generator=None): + def sample( + self, n_samples: int = 100, generator: torch.Generator | None = None + ) -> torch.Tensor: # sample only subnetwork parameters and set all other parameters to their MAP estimates samples = torch.randn( n_samples, self.n_params_subnet, device=self._device, generator=generator From 4ffab984320c2d2f8e0cfdbd344d202f36fbcaaa Mon Sep 17 00:00:00 2001 From: Agustinus Kristiadi Date: Sat, 27 Apr 2024 18:57:03 -0400 Subject: [PATCH 6/9] [WIP] Typehinting `CurvatureInterface` --- laplace/curvature/curvature.py | 49 +++++++++++++++++++++++----------- 1 file changed, 34 insertions(+), 15 deletions(-) diff --git a/laplace/curvature/curvature.py b/laplace/curvature/curvature.py index 
0dde26d3..77868cc7 100644 --- a/laplace/curvature/curvature.py +++ b/laplace/curvature/curvature.py @@ -1,6 +1,9 @@ +from typing import List, Tuple import torch +from torch import nn from torch.nn import MSELoss, CrossEntropyLoss -from torch.nn.utils import parameters_to_vector, vector_to_parameters + +from laplace.baselaplace import Likelihood class CurvatureInterface: @@ -17,7 +20,7 @@ class CurvatureInterface: likelihood : {'classification', 'regression'} last_layer : bool, default=False only consider curvature of last layer - subnetwork_indices : torch.Tensor, default=None + subnetwork_indices : torch.LongTensor, default=None indices of the vectorized model parameters that define the subnetwork to apply the Laplace approximation over @@ -29,29 +32,43 @@ class CurvatureInterface: For example, \\(\\frac{1}{2}\\) to get to \\(\\mathcal{N}(f, 1)\\) from MSELoss. """ - def __init__(self, model, likelihood, last_layer=False, subnetwork_indices=None): - assert likelihood in ['regression', 'classification'] - self.likelihood = likelihood - self.model = model - self.last_layer = last_layer - self.subnetwork_indices = subnetwork_indices + def __init__( + self, + model: nn.Module, + likelihood: Likelihood | str, + last_layer: bool = False, + subnetwork_indices: torch.LongTensor | None = None, + ): + assert likelihood in [Likelihood.REGRESSION, Likelihood.CLASSIFICATION] + self.likelihood: Likelihood | str = likelihood + self.model: nn.Module = model + self.last_layer: bool = last_layer + self.subnetwork_indices: torch.LongTensor | None = subnetwork_indices + if likelihood == 'regression': self.lossfunc = MSELoss(reduction='sum') self.factor = 0.5 else: self.lossfunc = CrossEntropyLoss(reduction='sum') self.factor = 1.0 - self.params = [p for p in self._model.parameters() if p.requires_grad] - self.params_dict = { + + self.params: List[torch.Tensor] = [ + p for p in self._model.parameters() if p.requires_grad + ] + self.params_dict: dict[str, torch.Tensor] = { k: v for k, v in self._model.named_parameters() if v.requires_grad } - self.buffers_dict = {k: v for k, v in self.model.named_buffers()} + self.buffers_dict: dict[str, torch.Tensor] = { + k: v for k, v in self.model.named_buffers() + } @property - def _model(self): + def _model(self) -> nn.Module: return self.model.last_layer if self.last_layer else self.model - def jacobians(self, x, enable_backprop=False): + def jacobians( + self, x: torch.Tensor, enable_backprop: bool = False + ) -> Tuple[torch.Tensor, torch.Tensor]: """Compute Jacobians \\(\\nabla_{\\theta} f(x;\\theta)\\) at current parameter \\(\\theta\\), via torch.func. @@ -90,7 +107,9 @@ def model_fn_params_only(params_dict, buffers_dict): return (Js, f) if enable_backprop else (Js.detach(), f.detach()) - def functorch_jacobians(self, x, enable_backprop=False): + def functorch_jacobians( + self, x: torch.Tensor, enable_backprop: bool = False + ) -> Tuple[torch.Tensor, torch.Tensor]: """Compute Jacobians \\(\\nabla_\\theta f(x;\\theta)\\) at current parameter \\(\\theta\\). Parameters @@ -130,7 +149,7 @@ def model_fn_params_only(params_dict): return (Js, f) if enable_backprop else (Js.detach(), f.detach()) - def last_layer_jacobians(self, x, enable_backprop=False): + def last_layer_jacobians(self, x: torch.Tensor, enable_backprop=False): """Compute Jacobians \\(\\nabla_{\\theta_\\textrm{last}} f(x;\\theta_\\textrm{last})\\) only at current last-layer parameter \\(\\theta_{\\textrm{last}}\\). 
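A minimal sketch of the call contract that the annotations in the hunk above describe, assuming the laplace package is installed; the toy model, data, and shape comments below are illustrative assumptions, not code taken from this repository or this patch series.

    import torch
    from torch import nn

    from laplace.curvature import CurvatureInterface

    # Toy regression model; any nn.Module with differentiable parameters works here.
    model = nn.Sequential(nn.Linear(3, 20), nn.ReLU(), nn.Linear(20, 2))
    backend = CurvatureInterface(model, 'regression')

    X, y = torch.randn(8, 3), torch.randn(8, 2)

    # Full-network Jacobians via torch.func:
    # Js has shape (batch, outputs, n_params), f has shape (batch, outputs).
    Js, f = backend.jacobians(X)

    # Per-sample gradients of the summed loss: Gs has shape (batch, n_params),
    # loss is a scalar.
    Gs, loss = backend.gradients(X, y)
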
From 662116ac016d402d8bf68e70ce1b44e73fdddbc3 Mon Sep 17 00:00:00 2001 From: Agustinus Kristiadi Date: Sun, 28 Apr 2024 13:36:53 -0400 Subject: [PATCH 7/9] Typehinting curvatures --- laplace/__init__.py | 11 ++- laplace/baselaplace.py | 170 +++++++++++++------------------- laplace/curvature/asdfghjkl.py | 104 ++++++++++++------- laplace/curvature/asdl.py | 124 ++++++++++++++++------- laplace/curvature/backpack.py | 83 +++++++++++----- laplace/curvature/curvature.py | 116 +++++++++++++++------- laplace/curvature/curvlinops.py | 87 +++++++++------- laplace/laplace.py | 4 +- laplace/lllaplace.py | 24 ++--- laplace/marglik_training.py | 12 ++- laplace/subnetlaplace.py | 2 +- laplace/utils/__init__.py | 16 +++ laplace/utils/enums.py | 43 ++++++++ 13 files changed, 513 insertions(+), 283 deletions(-) create mode 100644 laplace/utils/enums.py diff --git a/laplace/__init__.py b/laplace/__init__.py index d8393df0..0459edd9 100644 --- a/laplace/__init__.py +++ b/laplace/__init__.py @@ -12,6 +12,13 @@ KronLaplace, DiagLaplace, LowRankLaplace, +) +from laplace.lllaplace import LLLaplace, FullLLLaplace, KronLLLaplace, DiagLLLaplace +from laplace.subnetlaplace import SubnetLaplace, FullSubnetLaplace, DiagSubnetLaplace +from laplace.laplace import Laplace +from laplace.marglik_training import marglik_training + +from laplace.utils.enums import ( SubsetOfWeights, HessianStructure, Likelihood, @@ -20,10 +27,6 @@ TuningMethod, PriorStructure, ) -from laplace.lllaplace import LLLaplace, FullLLLaplace, KronLLLaplace, DiagLLLaplace -from laplace.subnetlaplace import SubnetLaplace, FullSubnetLaplace, DiagSubnetLaplace -from laplace.laplace import Laplace -from laplace.marglik_training import marglik_training __all__ = [ 'Laplace', # direct access to all Laplace classes via unified interface diff --git a/laplace/baselaplace.py b/laplace/baselaplace.py index 6608d34a..a8f8fd37 100644 --- a/laplace/baselaplace.py +++ b/laplace/baselaplace.py @@ -2,7 +2,7 @@ from enum import Enum from math import sqrt, pi, log -from typing import Callable, List, Tuple, Type, Any +from typing import Callable, Any import numpy as np import torch from torch import nn @@ -11,9 +11,7 @@ import torchmetrics import tqdm from collections.abc import MutableMapping -from laplace.curvature.asdfghjkl import AsdfghjklHessian -from laplace.curvature.curvature import CurvatureInterface -from laplace.curvature.curvlinops import CurvlinopsEF + import warnings from torchmetrics import MeanSquaredError @@ -25,9 +23,17 @@ fix_prior_prec_structure, RunningNLLMetric, ) -from laplace.curvature import CurvlinopsGGN +from laplace.curvature.curvature import CurvatureInterface +from laplace.curvature.asdfghjkl import AsdfghjklHessian +from laplace.curvature.curvlinops import CurvlinopsGGN, CurvlinopsEF from laplace.utils.matrix import KronDecomposed - +from laplace.utils.enums import ( + Likelihood, + PredType, + LinkApprox, + TuningMethod, + PriorStructure, +) __all__ = [ 'BaseLaplace', @@ -36,57 +42,9 @@ 'KronLaplace', 'DiagLaplace', 'LowRankLaplace', - 'Likelihood', - 'PredType', - 'LinkApprox', - 'TuningMethod', - 'PriorStructure', - 'SubsetOfWeights', ] -class SubsetOfWeights(str, Enum): - ALL = 'all' - LAST_LAYER = 'last_layer' - SUBNETWORK = 'subnetwork' - - -class HessianStructure(str, Enum): - FULL = 'full' - KRON = 'kron' - DIAG = 'diag' - LOWRANK = 'lowrank' - - -class Likelihood(str, Enum): - REGRESSION = 'regression' - CLASSIFICATION = 'classification' - REWARD_MODELING = 'reward_modeling' - - -class PredType(str, Enum): - GLM = 'glm' - NN = 
'nn' - - -class LinkApprox(str, Enum): - MC = 'mc' - PROBIT = 'probit' - BRIDGE = 'bridge' - BRIDGE_NORM = 'bridge_norm' - - -class TuningMethod(str, Enum): - MARGLIK = 'marglik' - GRIDSEARCH = 'gridsearch' - - -class PriorStructure(str, Enum): - SCALAR = 'scalar' - DIAG = 'diag' - LAYERWISE = 'layerwise' - - class BaseLaplace: """Baseclass for all Laplace approximations in this library. @@ -131,9 +89,9 @@ def __init__( prior_mean: float | torch.Tensor = 0.0, temperature: float = 1.0, enable_backprop: bool = False, - backend: Type[CurvatureInterface] | None = None, - backend_kwargs: dict | None = None, - asdl_fisher_kwargs: dict | None = None, + backend: type[CurvatureInterface] | None = None, + backend_kwargs: dict[str, Any] | None = None, + asdl_fisher_kwargs: dict[str, Any] | None = None, ) -> None: if likelihood not in [lik.value for lik in Likelihood]: raise ValueError(f'Invalid likelihood type {likelihood}') @@ -141,7 +99,7 @@ def __init__( self.model: nn.Module = model # Only do Laplace on params that require grad - self.params: List[torch.Tensor] = [] + self.params: list[torch.Tensor] = [] self.is_subset_params: bool = False for p in model.parameters(): if p.requires_grad: @@ -176,11 +134,11 @@ def __init__( ) self._backend: CurvatureInterface | None = None - self._backend_cls: Type[CurvatureInterface] = backend - self._backend_kwargs: dict = ( + self._backend_cls: type[CurvatureInterface] = backend + self._backend_kwargs: dict[str, Any] = ( dict() if backend_kwargs is None else backend_kwargs ) - self._asdl_fisher_kwargs: dict = ( + self._asdl_fisher_kwargs: dict[str, Any] = ( dict() if asdl_fisher_kwargs is None else asdl_fisher_kwargs ) @@ -189,6 +147,12 @@ def __init__( self.n_outputs: int = 0 self.n_data: int = 0 + # Declare attributes + self._prior_mean: torch.Tensor + self._prior_precision: torch.Tensor + self._sigma_noise: torch.Tensor + self._posterior_scale: torch.Tensor | None + @property def _device(self) -> torch.device: return next(self.model.parameters()).device @@ -206,7 +170,7 @@ def _curv_closure( X: torch.Tensor | MutableMapping[str, torch.Tensor | Any], y: torch.Tensor, N: int, - ) -> Tuple[torch.Tensor, torch.Tensor]: + ) -> tuple[torch.Tensor, torch.Tensor]: raise NotImplementedError def fit(self, train_loader: DataLoader) -> None: @@ -249,7 +213,7 @@ def __call__( pred_type: PredType | str, link_approx: LinkApprox | str, n_samples: int, - ) -> torch.Tensor | Tuple[torch.Tensor, torch.Tensor]: + ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: raise NotImplementedError def predictive( @@ -258,7 +222,7 @@ def predictive( pred_type: PredType | str, link_approx: LinkApprox | str, n_samples: int, - ) -> torch.Tensor | Tuple[torch.Tensor, torch.Tensor]: + ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: return self(x, pred_type, link_approx, n_samples) def _check_jacobians(self, Js: torch.Tensor) -> None: @@ -351,7 +315,7 @@ def prior_precision(self, prior_precision: float | torch.Tensor): 'Prior precision either scalar or torch.Tensor up to 1-dim.' 
) - def optimize_prior_precision_base( + def optimize_prior_precision( self, pred_type: PredType | str, method: TuningMethod | str = TuningMethod.MARGLIK, @@ -360,7 +324,9 @@ def optimize_prior_precision_base( init_prior_prec: float | torch.Tensor = 1.0, prior_structure: PriorStructure | str = PriorStructure.SCALAR, val_loader: DataLoader | None = None, - loss: torchmetrics.Metric | Callable | None = None, + loss: torchmetrics.Metric + | Callable[[torch.Tensor], torch.Tensor | float] + | None = None, log_prior_prec_min: float = -4, log_prior_prec_max: float = 4, grid_size: int = 100, @@ -488,7 +454,7 @@ def optimize_prior_precision_base( def _gridsearch( self, - loss: torchmetrics.Metric | Callable, + loss: torchmetrics.Metric | Callable[[torch.Tensor], torch.Tensor | float], interval: torch.Tensor, val_loader: DataLoader, pred_type: PredType | str, @@ -496,11 +462,11 @@ def _gridsearch( n_samples: int = 100, loss_with_var: bool = False, progress_bar: bool = False, - ): + ) -> torch.Tensor: assert callable(loss) or isinstance(loss, torchmetrics.Metric) - results = list() - prior_precs = list() + results: list[float] = list() + prior_precs: list[torch.Tensor] = list() pbar = tqdm.tqdm(interval, disable=not progress_bar) for prior_prec in pbar: @@ -591,9 +557,9 @@ def __init__( prior_mean: float | torch.Tensor = 0.0, temperature: float = 1.0, enable_backprop: bool = False, - backend: Type[CurvatureInterface] | None = None, - backend_kwargs: dict | None = None, - asdl_fisher_kwargs: dict | None = None, + backend: type[CurvatureInterface] | None = None, + backend_kwargs: dict[str, Any] | None = None, + asdl_fisher_kwargs: dict[str, Any] | None = None, ): super().__init__( model, @@ -651,7 +617,7 @@ def fit( self.mean = self.mean.detach() data: ( - Tuple[torch.Tensor, torch.Tensor] | MutableMapping[str, torch.Tensor | Any] + tuple[torch.Tensor, torch.Tensor] | MutableMapping[str, torch.Tensor | Any] ) = next(iter(train_loader)) with torch.no_grad(): @@ -686,7 +652,7 @@ def fit( X, y = data X, y = X.to(self._device), y.to(self._device) self.model.zero_grad() - loss_batch, H_batch = self._curv_closure(X, y, N) + loss_batch, H_batch = self._curv_closure(X, y, N=N) self.loss += loss_batch self.H += H_batch @@ -804,7 +770,7 @@ def log_marginal_likelihood( # update sigma_noise (useful when iterating on marglik) if sigma_noise is not None: - if self.likelihood != 'regression': + if self.likelihood != Likelihood.REGRESSION: raise ValueError('Can only change sigma_noise for regression.') self.sigma_noise = sigma_noise @@ -819,8 +785,8 @@ def __call__( n_samples: int = 100, diagonal_output: bool = False, generator: torch.Generator | None = None, - **model_kwargs, - ) -> torch.Tensor | Tuple[torch.Tensor, torch.Tensor]: + **model_kwargs: dict[str, Any], + ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: """Compute the posterior predictive on input data `x`. Parameters @@ -859,7 +825,7 @@ def __call__( Returns ------- - predictive: torch.Tensor or Tuple[torch.Tensor] + predictive: torch.Tensor or tuple[torch.Tensor] For `likelihood='classification'`, a torch.Tensor is returned with a distribution over classes (similar to a Softmax). 
For `likelihood='regression'`, a tuple of torch.Tensor is returned @@ -887,7 +853,7 @@ def __call__( # For reward modeling, replace the likelihood to regression and override model state if self.reward_modeling and self.likelihood == Likelihood.CLASSIFICATION: - self.likelihood = 'regression' + self.likelihood = Likelihood.REGRESSION setattr(self.model, 'output_size', 1) if pred_type == PredType.GLM: @@ -1009,7 +975,7 @@ def _glm_predictive_distribution( self, X: torch.Tensor | MutableMapping[str, torch.Tensor | Any], joint: bool = False, - ) -> Tuple[torch.Tensor, torch.Tensor]: + ) -> tuple[torch.Tensor, torch.Tensor]: if 'backpack' in self._backend_cls.__name__.lower(): # BackPACK supports backprop through Jacobians, but it interferes with functorch Js, f_mu = self.backend.jacobians(X, enable_backprop=self.enable_backprop) @@ -1036,7 +1002,7 @@ def _nn_predictive_samples( X: torch.Tensor | MutableMapping[str, torch.Tensor | Any], n_samples: int = 100, generator: torch.Generator | None = None, - **model_kwargs, + **model_kwargs: dict[str, Any], ) -> torch.Tensor: fs = list() for sample in self.sample(n_samples, generator): @@ -1058,7 +1024,7 @@ def _nn_predictive_classification( self, X: torch.Tensor | MutableMapping[str, torch.Tensor | Any], n_samples: int = 100, - **model_kwargs, + **model_kwargs: dict[str, Any], ) -> torch.Tensor: py = 0.0 for sample in self.sample(n_samples): @@ -1144,7 +1110,9 @@ def optimize_prior_precision( init_prior_prec: float | torch.Tensor = 1.0, prior_structure: PriorStructure | str = PriorStructure.SCALAR, val_loader: DataLoader | None = None, - loss: torchmetrics.Metric | Callable | None = None, + loss: torchmetrics.Metric + | Callable[[torch.Tensor], torch.Tensor | float] + | None = None, log_prior_prec_min: float = -4, log_prior_prec_max: float = 4, grid_size: int = 100, @@ -1155,7 +1123,7 @@ def optimize_prior_precision( progress_bar: bool = False, ) -> None: assert pred_type in PredType.__members__.values() - self.optimize_prior_precision_base( + super().optimize_prior_precision( pred_type, method, n_steps, @@ -1184,7 +1152,7 @@ def posterior_precision(self) -> torch.Tensor: """ raise NotImplementedError - def state_dict(self) -> dict: + def state_dict(self) -> dict[str, Any]: self._check_H_init() state_dict = { 'mean': self.mean, @@ -1202,7 +1170,7 @@ def state_dict(self) -> dict: } return state_dict - def load_state_dict(self, state_dict: dict) -> None: + def load_state_dict(self, state_dict: dict[str, Any]) -> None: # Dealbreaker errors if self.__class__.__name__ != state_dict['cls_name']: raise ValueError( @@ -1267,8 +1235,8 @@ def __init__( prior_mean: float | torch.Tensor = 0.0, temperature: float = 1.0, enable_backprop: bool = False, - backend: Type[CurvatureInterface] | None = None, - backend_kwargs: dict | None = None, + backend: type[CurvatureInterface] | None = None, + backend_kwargs: dict[str, Any] | None = None, ): super().__init__( model, @@ -1293,7 +1261,7 @@ def _curv_closure( X: torch.Tensor | MutableMapping[str, torch.Tensor | Any], y: torch.Tensor, N: int, - ) -> Tuple[torch.Tensor, torch.Tensor]: + ) -> tuple[torch.Tensor, torch.Tensor]: return self.backend.full(X, y, N=N) def fit( @@ -1398,10 +1366,10 @@ def __init__( prior_mean: float | torch.Tensor = 0.0, temperature: float = 1.0, enable_backprop: bool = False, - backend: Type[CurvatureInterface] | None = None, + backend: type[CurvatureInterface] | None = None, damping: bool = False, - backend_kwargs: dict | None = None, - asdl_fisher_kwargs: dict | None = None, + backend_kwargs: 
dict[str, Any] | None = None, + asdl_fisher_kwargs: dict[str, Any] | None = None, ): self.damping: bool = damping self.H_facs: Kron | None = None @@ -1428,7 +1396,7 @@ def _curv_closure( X: torch.Tensor | MutableMapping[str, torch.Tensor | Any], y: torch.Tensor, N: int, - ) -> Tuple[torch.Tensor, torch.Tensor]: + ) -> tuple[torch.Tensor, torch.Tensor]: return self.backend.kron(X, y, N=N, **self._asdl_fisher_kwargs) @staticmethod @@ -1522,13 +1490,13 @@ def prior_precision(self, prior_precision: torch.Tensor) -> None: if len(self.prior_precision) not in [1, self.n_layers]: raise ValueError('Prior precision for Kron either scalar or per-layer.') - def state_dict(self) -> dict: + def state_dict(self) -> dict[str, Any]: state_dict = super().state_dict() assert isinstance(self.H_facs, Kron) state_dict['H'] = self.H_facs.kfacs return state_dict - def load_state_dict(self, state_dict: dict): + def load_state_dict(self, state_dict: dict[str, Any]): super().load_state_dict(state_dict) self._init_H() assert isinstance(self.H, Kron) @@ -1562,7 +1530,7 @@ def __init__( prior_mean: float | torch.Tensor = 0, temperature: float = 1, enable_backprop: bool = False, - backend_kwargs: dict | None = None, + backend_kwargs: dict[str, Any] | None = None, ): super().__init__( model, @@ -1578,7 +1546,7 @@ def __init__( self.backend: AsdfghjklHessian def _init_H(self): - self.H: Tuple[torch.Tensor, torch.Tensor] | None = None + self.H: tuple[torch.Tensor, torch.Tensor] | None = None @property def V(self) -> torch.Tensor: @@ -1625,7 +1593,7 @@ def fit( @property def posterior_precision( self, - ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], torch.Tensor]: + ) -> tuple[tuple[torch.Tensor, torch.Tensor], torch.Tensor]: """Return correctly scaled posterior precision that would be constructed as H[0] @ diag(H[1]) @ H[0].T + self.prior_precision_diag. 
@@ -1675,7 +1643,7 @@ def sample( @property def log_det_posterior_precision(self) -> torch.Tensor: - (U, eigvals), prior_prec_diag = self.posterior_precision + (_, eigvals), prior_prec_diag = self.posterior_precision return ( eigvals.log().sum() + prior_prec_diag.log().sum() - torch.logdet(self.Kinv) ) @@ -1699,7 +1667,7 @@ def _curv_closure( X: torch.Tensor | MutableMapping[str, torch.Tensor | Any], y: torch.Tensor, N: int, - ) -> Tuple[torch.Tensor, torch.Tensor]: + ) -> tuple[torch.Tensor, torch.Tensor]: return self.backend.diag(X, y, N=N, **self._asdl_fisher_kwargs) @property diff --git a/laplace/curvature/asdfghjkl.py b/laplace/curvature/asdfghjkl.py index 7b0d9a4b..b1d43d78 100644 --- a/laplace/curvature/asdfghjkl.py +++ b/laplace/curvature/asdfghjkl.py @@ -1,15 +1,19 @@ +from collections.abc import MutableMapping +from typing import Any import warnings import numpy as np import torch +from torch import nn from asdfghjkl import FISHER_EXACT, FISHER_MC, COV from asdfghjkl import SHAPE_KRON, SHAPE_DIAG, SHAPE_FULL from asdfghjkl import fisher_for_cross_entropy from asdfghjkl.hessian import hessian_eigenvalues, hessian_for_loss from asdfghjkl.gradient import batch_gradient +from torch.utils.data import DataLoader from laplace.curvature import CurvatureInterface, GGNInterface, EFInterface -from laplace.utils import Kron, _is_batchnorm +from laplace.utils import Kron, _is_batchnorm, Likelihood EPS = 1e-6 @@ -17,7 +21,11 @@ class AsdfghjklInterface(CurvatureInterface): """Interface for asdfghjkl backend.""" - def jacobians(self, x, enable_backprop=False): + def jacobians( + self, + x: torch.Tensor | MutableMapping[str, torch.Tensor | Any], + enable_backprop: bool = False, + ) -> tuple[torch.Tensor, torch.Tensor]: """Compute Jacobians \\(\\nabla_\\theta f(x;\\theta)\\) at current parameter \\(\\theta\\) using asdfghjkl's gradient per output dimension. @@ -49,7 +57,11 @@ def loss_fn(outputs, targets): Js = torch.stack(Js, dim=1) return Js, f - def gradients(self, x, y): + def gradients( + self, + x: torch.Tensor | MutableMapping[str, torch.Tensor | Any], + y: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: """Compute gradients \\(\\nabla_\\theta \\ell(f(x;\\theta, y)\\) at current parameter \\(\\theta\\) using asdfghjkl's backend. 
@@ -73,10 +85,10 @@ def gradients(self, x, y): return Gs, loss @property - def _ggn_type(self): + def _ggn_type(self) -> str: raise NotImplementedError - def _get_kron_factors(self, curv, M): + def _get_kron_factors(self, curv, M: int) -> Kron: kfacs = list() for module in curv._model.modules(): if _is_batchnorm(module): @@ -101,36 +113,47 @@ def _get_kron_factors(self, curv, M): return Kron(kfacs) @staticmethod - def _rescale_kron_factors(kron, N): + def _rescale_kron_factors(kron: Kron, N: int) -> Kron: for F in kron.kfacs: if len(F) == 2: F[1] *= 1 / N return kron - def diag(self, X, y, **kwargs): + def diag( + self, + x: torch.Tensor | MutableMapping[str, torch.Tensor | Any], + y: torch.Tensor, + **kwargs: dict[str, Any], + ) -> tuple[torch.Tensor, torch.Tensor]: with torch.no_grad(): if self.last_layer: - f, X = self.model.forward_with_features(X) + f, x = self.model.forward_with_features(x) else: - f = self.model(X) + f = self.model(x) loss = self.lossfunc(f, y) curv = fisher_for_cross_entropy( - self._model, self._ggn_type, SHAPE_DIAG, inputs=X, targets=y + self._model, self._ggn_type, SHAPE_DIAG, inputs=x, targets=y ) diag_ggn = curv.matrices_to_vector(None) if self.subnetwork_indices is not None: diag_ggn = diag_ggn[self.subnetwork_indices] return self.factor * loss, self.factor * diag_ggn - def kron(self, X, y, N, **kwargs): + def kron( + self, + x: torch.Tensor | MutableMapping[str, torch.Tensor | Any], + y: torch.Tensor, + N: int, + **kwargs: dict[str, Any], + ) -> tuple[torch.Tensor, Kron]: with torch.no_grad(): if self.last_layer: - f, X = self.model.forward_with_features(X) + f, x = self.model.forward_with_features(x) else: - f = self.model(X) + f = self.model(x) loss = self.lossfunc(f, y) curv = fisher_for_cross_entropy( - self._model, self._ggn_type, SHAPE_KRON, inputs=X, targets=y + self._model, self._ggn_type, SHAPE_KRON, inputs=x, targets=y ) M = len(y) kron = self._get_kron_factors(curv, M) @@ -139,21 +162,34 @@ def kron(self, X, y, N, **kwargs): class AsdfghjklHessian(AsdfghjklInterface): - def __init__(self, model, likelihood, last_layer=False, low_rank=10): + def __init__( + self, + model: nn.Module, + likelihood: Likelihood | str, + last_layer: bool = False, + low_rank: int = 10, + ) -> None: super().__init__(model, likelihood, last_layer) - self.low_rank = low_rank + self.low_rank: int = low_rank @property - def _ggn_type(self): + def _ggn_type(self) -> str: raise NotImplementedError() - def full(self, x, y, **kwargs): + def full( + self, + x: torch.Tensor | MutableMapping[str, torch.Tensor | Any], + y: torch.Tensor, + **kwargs: dict[str, Any], + ) -> tuple[torch.Tensor, torch.Tensor]: hessian_for_loss(self.model, self.lossfunc, SHAPE_FULL, x, y) H = self._model.hessian.data loss = self.lossfunc(self.model(x), y).detach() return self.factor * loss, self.factor * H - def eig_lowrank(self, data_loader): + def eig_lowrank( + self, data_loader: DataLoader + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: # compute truncated eigendecomposition of the Hessian, only keep eigvals > EPS eigvals, eigvecs = hessian_eigenvalues( self.model, @@ -183,43 +219,45 @@ class AsdfghjklGGN(AsdfghjklInterface, GGNInterface): def __init__( self, - model, - likelihood, - last_layer=False, - subnetwork_indices=None, - stochastic=False, - ): - if likelihood != 'classification': + model: nn.Module, + likelihood: Likelihood | str, + last_layer: bool = False, + subnetwork_indices: torch.LongTensor | None = None, + stochastic: bool = False, + ) -> None: + if likelihood != 
Likelihood.CLASSIFICATION: raise ValueError('This backend only supports classification currently.') super().__init__(model, likelihood, last_layer, subnetwork_indices) - self.stochastic = stochastic + self.stochastic: bool = stochastic @property - def _ggn_type(self): + def _ggn_type(self) -> str: return FISHER_MC if self.stochastic else FISHER_EXACT class AsdfghjklEF(AsdfghjklInterface, EFInterface): """Implementation of the `EFInterface` using asdfghjkl.""" - def __init__(self, model, likelihood, last_layer=False): - if likelihood != 'classification': + def __init__( + self, model: nn.Module, likelihood: Likelihood | None, last_layer: bool = False + ) -> None: + if likelihood != Likelihood.CLASSIFICATION: raise ValueError('This backend only supports classification currently.') super().__init__(model, likelihood, last_layer) @property - def _ggn_type(self): + def _ggn_type(self) -> str: return COV -def _flatten_after_batch(tensor: torch.Tensor): +def _flatten_after_batch(tensor: torch.Tensor) -> torch.Tensor: if tensor.ndim == 1: return tensor.unsqueeze(-1) else: return tensor.flatten(start_dim=1) -def _get_batch_grad(model): +def _get_batch_grad(model: nn.Module) -> torch.Tensor: batch_grads = list() for module in model.modules(): if hasattr(module, 'op_results'): diff --git a/laplace/curvature/asdl.py b/laplace/curvature/asdl.py index f0a88397..8ec4c081 100644 --- a/laplace/curvature/asdl.py +++ b/laplace/curvature/asdl.py @@ -1,7 +1,8 @@ +from typing import Any import warnings -import numpy as np import torch +from torch import nn from asdl.matrices import ( FISHER_EXACT, @@ -17,9 +18,9 @@ from asdl.gradient import batch_gradient from laplace.curvature import CurvatureInterface, GGNInterface, EFInterface -from laplace.utils import Kron, _is_batchnorm +from laplace.utils import Kron, _is_batchnorm, Likelihood -from collections import UserDict +from collections.abc import MutableMapping EPS = 1e-6 @@ -27,23 +28,34 @@ class AsdlInterface(CurvatureInterface): """Interface for asdfghjkl backend.""" - def __init__(self, model, likelihood, last_layer=False, subnetwork_indices=None): + def __init__( + self, + model: nn.Module, + likelihood: Likelihood | str, + last_layer: bool = False, + subnetwork_indices: torch.LongTensor | None = None, + ): super().__init__(model, likelihood, last_layer, subnetwork_indices) @property - def loss_type(self): - return LOSS_MSE if self.likelihood == 'regression' else LOSS_CROSS_ENTROPY + def loss_type(self) -> str: + return ( + LOSS_MSE if self.likelihood == Likelihood.REGRESSION else LOSS_CROSS_ENTROPY + ) - def jacobians(self, x, enable_backprop=False): + def jacobians( + self, + x: torch.Tensor | MutableMapping[str, torch.Tensor | Any], + enable_backprop: bool = False, + ) -> tuple[torch.Tensor, torch.Tensor]: """Compute Jacobians \\(\\nabla_\\theta f(x;\\theta)\\) at current parameter \\(\\theta\\) using asdfghjkl's gradient per output dimension. Parameters ---------- - x : torch.Tensor or UserDict + x : torch.Tensor or MutableMapping input data `(batch, input_shape)` on compatible device with model if torch.Tensor. - If UserDict, then at least contains key ['input_ids'] or ['input_ids_0', 'input_ids_1']. - The latter is specific for reward modeling. + Must contain the said tensor if dict-like. enable_backprop : bool, default = False whether to enable backprop through the Js and f w.r.t. 
x @@ -78,7 +90,9 @@ def closure(): Js = torch.stack(Js, dim=1) return Js, f - def gradients(self, x, y): + def gradients( + self, x: torch.Tensor | MutableMapping[str, torch.Tensor | Any], y: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: """Compute gradients \\(\\nabla_\\theta \\ell(f(x;\\theta, y)\\) at current parameter \\(\\theta\\) using asdfghjkl's backend. @@ -109,10 +123,10 @@ def closure(): return Gs, loss @property - def _ggn_type(self): + def _ggn_type(self) -> str: raise NotImplementedError - def _get_kron_factors(self, M): + def _get_kron_factors(self, M: int) -> Kron: kfacs = list() for module in self.model.modules(): if _is_batchnorm(module): @@ -138,16 +152,24 @@ def _get_kron_factors(self, M): return Kron(kfacs) @staticmethod - def _rescale_kron_factors(kron, N): + def _rescale_kron_factors(kron: Kron, N: int) -> Kron: for F in kron.kfacs: if len(F) == 2: F[1] *= 1 / N return kron - def diag(self, X, y, N=None, **kwargs): - del N + def diag( + self, + x: torch.Tensor | MutableMapping[str, torch.Tensor | Any], + y: torch.Tensor, + **kwargs: dict[str, Any], + ) -> tuple[torch.Tensor, torch.Tensor]: + if 'N' in kwargs: + del kwargs['N'] + if self.last_layer: - _, X = self.model.forward_with_features(X) + _, x = self.model.forward_with_features(x) + cfg = FisherConfig( fisher_type=self._ggn_type, loss_type=self.loss_type, @@ -158,13 +180,13 @@ def diag(self, X, y, N=None, **kwargs): fisher_maker = get_fisher_maker(self.model, cfg) y = y if self.loss_type == LOSS_MSE else y.view(-1) if 'emp' in self._ggn_type: - dummy = fisher_maker.setup_model_call(self._model, X) + dummy = fisher_maker.setup_model_call(self._model, x) dummy = ( dummy if self.loss_type == LOSS_MSE else dummy.view(-1, dummy.size(-1)) ) fisher_maker.setup_loss_call(self.lossfunc, dummy, y) else: - fisher_maker.setup_model_call(self._model, X) + fisher_maker.setup_model_call(self._model, x) f, _ = fisher_maker.forward_and_backward() # Assumes that the last dimension of f is of size outputs. f = f if self.loss_type == LOSS_MSE else f.view(-1, f.size(-1)) @@ -184,9 +206,15 @@ def diag(self, X, y, N=None, **kwargs): curv_factor = 1.0 # ASDL uses proper 1/2 * MSELoss return self.factor * loss, curv_factor * diag_ggn - def kron(self, X, y, N, **kwargs): + def kron( + self, + x: torch.Tensor | MutableMapping[str, torch.Tensor | Any], + y: torch.Tensor, + N: int, + **kwargs: dict[str, Any], + ) -> tuple[torch.Tensor, Kron]: if self.last_layer: - _, X = self.model.forward_with_features(X) + _, x = self.model.forward_with_features(x) cfg = FisherConfig( fisher_type=self._ggn_type, loss_type=self.loss_type, @@ -197,13 +225,13 @@ def kron(self, X, y, N, **kwargs): fisher_maker = get_fisher_maker(self.model, cfg) y = y if self.loss_type == LOSS_MSE else y.view(-1) if 'emp' in self._ggn_type: - dummy = fisher_maker.setup_model_call(self._model, X) + dummy = fisher_maker.setup_model_call(self._model, x) dummy = ( dummy if self.loss_type == LOSS_MSE else dummy.view(-1, dummy.size(-1)) ) fisher_maker.setup_loss_call(self.lossfunc, dummy, y) else: - fisher_maker.setup_model_call(self._model, X) + fisher_maker.setup_model_call(self._model, x) f, _ = fisher_maker.forward_and_backward() # Assumes that the last dimension of f is of size outputs. 
f = f if self.loss_type == LOSS_MSE else f.view(-1, f.size(-1)) @@ -218,14 +246,16 @@ def kron(self, X, y, N, **kwargs): return self.factor * loss, curv_factor * kron @staticmethod - def _get_batch_size(x): + def _get_batch_size( + x: torch.Tensor | MutableMapping[str, torch.Tensor | Any], + ) -> int | None: """ ASDL assumes that all leading dimensions are the batch size by default (batch_size = None). Here, we want to specify that only the first dimension is the actual batch size. This is the case for LLMs. """ # If x is UserDict, then it has weight-sharing dimension (from Huggingface datasets) - if isinstance(x, UserDict) or isinstance(x, dict): + if isinstance(x, MutableMapping): try: return x['input_ids'].shape[0] except KeyError: @@ -236,29 +266,45 @@ def _get_batch_size(x): class AsdlHessian(AsdlInterface): - def __init__(self, model, likelihood, last_layer=False, low_rank=10): + def __init__( + self, + model: nn.Module, + likelihood: Likelihood | str, + last_layer: bool = False, + low_rank: int = 10, + ) -> None: super().__init__(model, likelihood, last_layer) - self.low_rank = low_rank + self.low_rank: int = low_rank @property - def _ggn_type(self): + def _ggn_type(self) -> str: raise NotImplementedError() - def full(self, x, y, **kwargs): + def full( + self, + x: torch.Tensor | MutableMapping[str, torch.Tensor | Any], + y: torch.Tensor, + **kwargs: dict[str, Any], + ) -> tuple[torch.Tensor, torch.Tensor]: if self.last_layer: _, x = self.model.forward_with_features(x) + cfg = HessianConfig(hessian_shapes=[SHAPE_FULL]) hess_maker = HessianMaker(self.model, cfg) + dummy = hess_maker.setup_model_call(self._model, x) dummy = dummy if self.loss_type == LOSS_MSE else dummy.view(-1, dummy.size(-1)) y = y if self.loss_type == LOSS_MSE else y.view(-1) + hess_maker.setup_loss_call(self.lossfunc, dummy, y) hess_maker.forward_and_backward() + H = self._model.hessian.data f = self.model(x).detach() # Assumes that the last dimension of f is of size outputs. 
f = f if self.loss_type == LOSS_MSE else f.view(-1, f.size(-1)) loss = self.lossfunc(f, y) + return self.factor * loss, self.factor * H @@ -267,26 +313,28 @@ class AsdlGGN(AsdlInterface, GGNInterface): def __init__( self, - model, - likelihood, - last_layer=False, - subnetwork_indices=None, - stochastic=False, + model: nn.Module, + likelihood: Likelihood | str, + last_layer: bool = False, + subnetwork_indices: torch.LongTensor | None = None, + stochastic: bool = False, ): super().__init__(model, likelihood, last_layer, subnetwork_indices) - self.stochastic = stochastic + self.stochastic: bool = stochastic @property - def _ggn_type(self): + def _ggn_type(self) -> str: return FISHER_MC if self.stochastic else FISHER_EXACT class AsdlEF(AsdlInterface, EFInterface): """Implementation of the `EFInterface` using asdfghjkl.""" - def __init__(self, model, likelihood, last_layer=False): + def __init__( + self, model: nn.Module, likelihood: Likelihood | str, last_layer: bool = False + ): super().__init__(model, likelihood, last_layer) @property - def _ggn_type(self): + def _ggn_type(self) -> str: return FISHER_EMP diff --git a/laplace/curvature/backpack.py b/laplace/curvature/backpack.py index ee919417..7173f253 100644 --- a/laplace/curvature/backpack.py +++ b/laplace/curvature/backpack.py @@ -1,5 +1,7 @@ -from typing import Tuple +from collections.abc import MutableMapping +from typing import Any import torch +from torch import nn from backpack import backpack, extend, memory_cleanup from backpack.extensions import ( @@ -13,18 +15,28 @@ from backpack.context import CTX from laplace.curvature import CurvatureInterface, GGNInterface, EFInterface -from laplace.utils import Kron +from laplace.utils import Kron, Likelihood class BackPackInterface(CurvatureInterface): """Interface for Backpack backend.""" - def __init__(self, model, likelihood, last_layer=False, subnetwork_indices=None): + def __init__( + self, + model: nn.Module, + likelihood: Likelihood | str, + last_layer: bool = False, + subnetwork_indices: torch.LongTensor | None = None, + ) -> None: super().__init__(model, likelihood, last_layer, subnetwork_indices) extend(self._model) extend(self.lossfunc) - def jacobians(self, x, enable_backprop=False): + def jacobians( + self, + x: torch.Tensor | MutableMapping[str, torch.Tensor | Any], + enable_backprop: bool = False, + ) -> tuple[torch.Tensor, torch.Tensor]: """Compute Jacobians \\(\\nabla_{\\theta} f(x;\\theta)\\) at current parameter \\(\\theta\\) using backpack's BatchGrad per output dimension. Note that BackPACK doesn't play well with torch.func, so this method has to be overridden. @@ -43,6 +55,9 @@ def jacobians(self, x, enable_backprop=False): f : torch.Tensor output function `(batch, outputs)` """ + if isinstance(x, MutableMapping): + raise ValueError('BackPACK backend does not support dict-like inputs!') + model = extend(self.model) to_stack = [] for i in range(model.output_size): @@ -76,7 +91,9 @@ def jacobians(self, x, enable_backprop=False): else: return Jk.unsqueeze(-1).transpose(1, 2), f - def gradients(self, x, y): + def gradients( + self, x: torch.Tensor | MutableMapping[str, torch.Tensor | Any], y: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: """Compute gradients \\(\\nabla_\\theta \\ell(f(x;\\theta, y)\\) at current parameter \\(\\theta\\) using Backpack's BatchGrad. Note that BackPACK doesn't play well with torch.func, so this method has to be overridden. 
@@ -111,16 +128,16 @@ class BackPackGGN(BackPackInterface, GGNInterface): def __init__( self, - model, - likelihood, - last_layer=False, - subnetwork_indices=None, - stochastic=False, + model: nn.Module, + likelihood: Likelihood | str, + last_layer: bool = False, + subnetwork_indices: torch.LongTensor | None = None, + stochastic: bool = False, ): super().__init__(model, likelihood, last_layer, subnetwork_indices) - self.stochastic = stochastic + self.stochastic: bool = stochastic - def _get_diag_ggn(self): + def _get_diag_ggn(self) -> torch.Tensor: if self.stochastic: return torch.cat( [p.diag_ggn_mc.data.flatten() for p in self._model.parameters()] @@ -130,14 +147,14 @@ def _get_diag_ggn(self): [p.diag_ggn_exact.data.flatten() for p in self._model.parameters()] ) - def _get_kron_factors(self): + def _get_kron_factors(self) -> Kron: if self.stochastic: return Kron([p.kfac for p in self._model.parameters()]) else: return Kron([p.kflr for p in self._model.parameters()]) @staticmethod - def _rescale_kron_factors(kron, M, N): + def _rescale_kron_factors(kron: Kron, M: int, N: int) -> Kron: # Renormalize Kronecker factor to sum up correctly over N data points with batches of M # for M=N (full-batch) just M/N=1 for F in kron.kfacs: @@ -145,9 +162,14 @@ def _rescale_kron_factors(kron, M, N): F[1] *= M / N return kron - def diag(self, X, y, **kwargs): + def diag( + self, + x: torch.Tensor | MutableMapping[str, torch.Tensor | Any], + y: torch.Tensor, + **kwargs: dict[str, Any], + ) -> tuple[torch.Tensor, torch.Tensor]: context = DiagGGNMC if self.stochastic else DiagGGNExact - f = self.model(X) + f = self.model(x) # Assumes that the last dimension of f is of size outputs. f = f if self.likelihood == 'regression' else f.view(-1, f.size(-1)) y = y if self.likelihood == 'regression' else y.view(-1) @@ -160,9 +182,15 @@ def diag(self, X, y, **kwargs): return self.factor * loss.detach(), self.factor * dggn - def kron(self, X, y, N, **kwargs) -> Tuple[torch.Tensor, Kron]: + def kron( + self, + x: torch.Tensor | MutableMapping[str, torch.Tensor | Any], + y: torch.Tensor, + N: int, + **kwargs: dict[str, Any], + ) -> tuple[torch.Tensor, Kron]: context = KFAC if self.stochastic else KFLR - f = self.model(X) + f = self.model(x) # Assumes that the last dimension of f is of size outputs. f = f if self.likelihood == 'regression' else f.view(-1, f.size(-1)) y = y if self.likelihood == 'regression' else y.view(-1) @@ -178,8 +206,13 @@ def kron(self, X, y, N, **kwargs) -> Tuple[torch.Tensor, Kron]: class BackPackEF(BackPackInterface, EFInterface): """Implementation of `EFInterface` using Backpack.""" - def diag(self, X, y, **kwargs): - f = self.model(X) + def diag( + self, + x: torch.Tensor | MutableMapping[str, torch.Tensor | Any], + y: torch.Tensor, + **kwargs: dict[str, Any], + ) -> tuple[torch.Tensor, torch.Tensor]: + f = self.model(x) # Assumes that the last dimension of f is of size outputs. 
f = f if self.likelihood == 'regression' else f.view(-1, f.size(-1)) y = y if self.likelihood == 'regression' else y.view(-1) @@ -194,11 +227,17 @@ def diag(self, X, y, **kwargs): return self.factor * loss.detach(), self.factor * diag_EF - def kron(self, X, y, **kwargs): + def kron( + self, + x: torch.Tensor | MutableMapping[str, torch.Tensor | Any], + y: torch.Tensor, + N: int, + **kwargs: dict[str, Any], + ) -> tuple[torch.Tensor, Kron]: raise NotImplementedError('Unavailable through Backpack.') -def _cleanup(module): +def _cleanup(module: nn.Module) -> None: for child in module.children(): _cleanup(child) diff --git a/laplace/curvature/curvature.py b/laplace/curvature/curvature.py index 77868cc7..106f22a3 100644 --- a/laplace/curvature/curvature.py +++ b/laplace/curvature/curvature.py @@ -1,9 +1,9 @@ -from typing import List, Tuple +from typing import Callable, MutableMapping, Any import torch from torch import nn from torch.nn import MSELoss, CrossEntropyLoss -from laplace.baselaplace import Likelihood +from laplace.utils import Kron, Likelihood class CurvatureInterface: @@ -46,16 +46,20 @@ def __init__( self.subnetwork_indices: torch.LongTensor | None = subnetwork_indices if likelihood == 'regression': - self.lossfunc = MSELoss(reduction='sum') - self.factor = 0.5 + self.lossfunc: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = ( + MSELoss(reduction='sum') + ) + self.factor: float = 0.5 else: - self.lossfunc = CrossEntropyLoss(reduction='sum') - self.factor = 1.0 + self.lossfunc: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = ( + CrossEntropyLoss(reduction='sum') + ) + self.factor: float = 1.0 - self.params: List[torch.Tensor] = [ + self.params: list[nn.Parameter] = [ p for p in self._model.parameters() if p.requires_grad ] - self.params_dict: dict[str, torch.Tensor] = { + self.params_dict: dict[str, nn.Parameter] = { k: v for k, v in self._model.named_parameters() if v.requires_grad } self.buffers_dict: dict[str, torch.Tensor] = { @@ -67,8 +71,10 @@ def _model(self) -> nn.Module: return self.model.last_layer if self.last_layer else self.model def jacobians( - self, x: torch.Tensor, enable_backprop: bool = False - ) -> Tuple[torch.Tensor, torch.Tensor]: + self, + x: torch.Tensor | MutableMapping[str, torch.Tensor | Any], + enable_backprop: bool = False, + ) -> tuple[torch.Tensor, torch.Tensor]: """Compute Jacobians \\(\\nabla_{\\theta} f(x;\\theta)\\) at current parameter \\(\\theta\\), via torch.func. @@ -109,7 +115,7 @@ def model_fn_params_only(params_dict, buffers_dict): def functorch_jacobians( self, x: torch.Tensor, enable_backprop: bool = False - ) -> Tuple[torch.Tensor, torch.Tensor]: + ) -> tuple[torch.Tensor, torch.Tensor]: """Compute Jacobians \\(\\nabla_\\theta f(x;\\theta)\\) at current parameter \\(\\theta\\). Parameters @@ -149,7 +155,11 @@ def model_fn_params_only(params_dict): return (Js, f) if enable_backprop else (Js.detach(), f.detach()) - def last_layer_jacobians(self, x: torch.Tensor, enable_backprop=False): + def last_layer_jacobians( + self, + x: torch.Tensor | MutableMapping[str, torch.Tensor | Any], + enable_backprop: bool = False, + ) -> tuple[torch.Tensor, torch.Tensor]: """Compute Jacobians \\(\\nabla_{\\theta_\\textrm{last}} f(x;\\theta_\\textrm{last})\\) only at current last-layer parameter \\(\\theta_{\\textrm{last}}\\). 
@@ -182,7 +192,9 @@ def last_layer_jacobians(self, x: torch.Tensor, enable_backprop=False): return Js, f - def gradients(self, x, y): + def gradients( + self, x: torch.Tensor | MutableMapping[str, torch.Tensor | Any], y: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: """Compute batch gradients \\(\\nabla_\\theta \\ell(f(x;\\theta, y)\\) at current parameter \\(\\theta\\). @@ -223,7 +235,12 @@ def loss_single(x, y, params_dict, buffers_dict): return Gs, loss - def full(self, x, y, **kwargs): + def full( + self, + x: torch.Tensor | MutableMapping[str, torch.Tensor | Any], + y: torch.Tensor, + **kwargs: dict[str, Any], + ): """Compute a dense curvature (approximation) in the form of a \\(P \\times P\\) matrix \\(H\\) with respect to parameters \\(\\theta \\in \\mathbb{R}^P\\). @@ -242,7 +259,13 @@ def full(self, x, y, **kwargs): """ raise NotImplementedError - def kron(self, x, y, **kwargs): + def kron( + self, + x: torch.Tensor | MutableMapping[str, torch.Tensor | Any], + y: torch.Tensor, + N: int, + **kwargs: dict[str, Any], + ) -> tuple[torch.Tensor, Kron]: """Compute a Kronecker factored curvature approximation (such as KFAC). The approximation to \\(H\\) takes the form of two Kronecker factors \\(Q, H\\), i.e., \\(H \\approx Q \\otimes H\\) for each Module in the neural network permitting @@ -256,6 +279,8 @@ def kron(self, x, y, **kwargs): input data `(batch, input_shape)` y : torch.Tensor labels `(batch, label_shape)` + N : int + total number of data points Returns ------- @@ -265,7 +290,12 @@ def kron(self, x, y, **kwargs): """ raise NotImplementedError - def diag(self, x, y, **kwargs): + def diag( + self, + x: torch.Tensor | MutableMapping[str, torch.Tensor | Any], + y: torch.Tensor, + **kwargs: dict[str, Any], + ): """Compute a diagonal Hessian approximation to \\(H\\) and is represented as a vector of the dimensionality of parameters \\(\\theta\\). @@ -308,18 +338,18 @@ class GGNInterface(CurvatureInterface): def __init__( self, - model, - likelihood, - last_layer=False, - subnetwork_indices=None, - stochastic=False, - num_samples=1, - ): - self.stochastic = stochastic - self.num_samples = num_samples + model: nn.Module, + likelihood: Likelihood | str, + last_layer: bool = False, + subnetwork_indices: torch.LongTensor | None = None, + stochastic: bool = False, + num_samples: int = 1, + ) -> None: + self.stochastic: bool = stochastic + self.num_samples: int = num_samples super().__init__(model, likelihood, last_layer, subnetwork_indices) - def _get_mc_functional_fisher(self, f): + def _get_mc_functional_fisher(self, f: torch.Tensor) -> torch.Tensor: """Approximate the Fisher's middle matrix (expected outer product of the functional gradient) using MC integral with `self.num_samples` many samples. """ @@ -343,7 +373,7 @@ def _get_mc_functional_fisher(self, f): return F - def _get_functional_hessian(self, f): + def _get_functional_hessian(self, f: torch.Tensor) -> torch.Tensor | None: if self.likelihood == 'regression': return None else: @@ -352,7 +382,12 @@ def _get_functional_hessian(self, f): G = torch.diag_embed(ps) - torch.einsum('mk,mc->mck', ps, ps) return G - def full(self, x, y, **kwargs): + def full( + self, + x: torch.Tensor | MutableMapping[str, torch.Tensor | Any], + y: torch.Tensor, + **kwargs: dict[str, Any], + ) -> tuple[torch.Tensor, torch.Tensor]: """Compute the full GGN \\(P \\times P\\) matrix as Hessian approximation \\(H_{ggn}\\) with respect to parameters \\(\\theta \\in \\mathbb{R}^P\\). 
For last-layer, reduced to \\(\\theta_{last}\\) @@ -385,8 +420,13 @@ def full(self, x, y, **kwargs): return loss.detach(), H.detach() - def diag(self, X, y, **kwargs): - Js, f = self.last_layer_jacobians(X) if self.last_layer else self.jacobians(X) + def diag( + self, + x: torch.Tensor | MutableMapping[str, torch.Tensor | Any], + y: torch.Tensor, + **kwargs: dict[str, Any], + ) -> tuple[torch.Tensor, torch.Tensor]: + Js, f = self.last_layer_jacobians(x) if self.last_layer else self.jacobians(x) loss = self.factor * self.lossfunc(f, y) H_lik = ( @@ -426,7 +466,12 @@ class EFInterface(CurvatureInterface): For example, \\(\\frac{1}{2}\\) to get to \\(\\mathcal{N}(f, 1)\\) from MSELoss. """ - def full(self, x, y, **kwargs): + def full( + self, + x: torch.Tensor | MutableMapping[str, torch.Tensor | Any], + y: torch.Tensor, + **kwargs: dict[str, Any], + ) -> tuple[torch.Tensor, torch.Tensor]: """Compute the full EF \\(P \\times P\\) matrix as Hessian approximation \\(H_{ef}\\) with respect to parameters \\(\\theta \\in \\mathbb{R}^P\\). For last-layer, reduced to \\(\\theta_{last}\\) @@ -448,8 +493,13 @@ def full(self, x, y, **kwargs): H_ef = torch.einsum('bp,bq->pq', Gs, Gs) return self.factor * loss.detach(), self.factor * H_ef - def diag(self, X, y, **kwargs): + def diag( + self, + x: torch.Tensor | MutableMapping[str, torch.Tensor | Any], + y: torch.Tensor, + **kwargs: dict[str, Any], + ) -> tuple[torch.Tensor, torch.Tensor]: # Gs is (batchsize, n_params) - Gs, loss = self.gradients(X, y) + Gs, loss = self.gradients(x, y) diag_ef = torch.einsum('bp,bp->p', Gs, Gs) return self.factor * loss.detach(), self.factor * diag_ef diff --git a/laplace/curvature/curvlinops.py b/laplace/curvature/curvlinops.py index 20b1e49e..b72bc2ec 100644 --- a/laplace/curvature/curvlinops.py +++ b/laplace/curvature/curvlinops.py @@ -1,6 +1,9 @@ +from typing import Any import torch -import numpy as np +from torch import nn +from collections.abc import MutableMapping +from curvlinops._base import _LinearOperator from curvlinops import ( HessianLinearOperator, GGNLinearOperator, @@ -10,35 +13,39 @@ ) from laplace.curvature import CurvatureInterface, GGNInterface, EFInterface -from laplace.utils import Kron - -from collections import UserDict +from laplace.utils import Kron, Likelihood class CurvlinopsInterface(CurvatureInterface): """Interface for Curvlinops backend. """ - def __init__(self, model, likelihood, last_layer=False, subnetwork_indices=None): + def __init__( + self, + model: nn.Module, + likelihood: Likelihood | str, + last_layer: bool = False, + subnetwork_indices: torch.LongTensor | None = None, + ) -> None: super().__init__(model, likelihood, last_layer, subnetwork_indices) @property - def _kron_fisher_type(self): + def _kron_fisher_type(self) -> str: raise NotImplementedError @property - def _linop_context(self): + def _linop_context(self) -> type[_LinearOperator]: raise NotImplementedError @staticmethod - def _rescale_kron_factors(kron, M, N): - # Renormalize Kronecker factor to sum up correctly over N data points with batches of M - # for M=N (full-batch) just M/N=1 + def _rescale_kron_factors(kron: Kron, M: int, N: int) -> Kron: + # Renormalize Kronecker factor to sum up correctly over N data points with + # batches of M. 
For M=N (full-batch) just M/N=1 for F in kron.kfacs: if len(F) == 2: F[1] *= M / N return kron - def _get_kron_factors(self, linop): + def _get_kron_factors(self, linop: KFACLinearOperator) -> Kron: kfacs = list() for name, module in self.model.named_modules(): if name not in linop._mapping.keys(): @@ -60,14 +67,21 @@ def _get_kron_factors(self, linop): raise ValueError(f'Whats happening with {module}?') return Kron(kfacs) - def kron(self, X, y, N, **kwargs): - if isinstance(X, (dict, UserDict)): - kwargs['batch_size_fn'] = lambda x: x['input_ids'].shape[0] + def kron( + self, + x: torch.Tensor | MutableMapping[str, torch.Tensor | Any], + y: torch.Tensor, + N: int, + **kwargs: dict[str, Any], + ) -> tuple[torch.Tensor, Kron]: + if isinstance(x, MutableMapping): + kwargs['batch_size_fn'] = lambda _x: _x['input_ids'].shape[0] + linop = KFACLinearOperator( self.model, self.lossfunc, self.params, - [(X, y)], + [(x, y)], fisher_type=self._kron_fisher_type, loss_average=None, # Since self.lossfunc is sum separate_weight_and_bias=True, @@ -83,24 +97,29 @@ def kron(self, X, y, N, **kwargs): kron = self._rescale_kron_factors(kron, len(y), N) kron *= self.factor - loss = self.lossfunc(self.model(X), y) + loss = self.lossfunc(self.model(x), y) return self.factor * loss.detach(), kron - def full(self, X, y, **kwargs): + def full( + self, + x: torch.Tensor | MutableMapping[str, torch.Tensor | Any], + y: torch.Tensor, + **kwargs: dict[str, Any], + ) -> tuple[torch.Tensor, torch.Tensor]: # Fallback to torch.func backend for SubnetLaplace if self.subnetwork_indices is not None: - return super().full(X, y, **kwargs) + return super().full(x, y, **kwargs) curvlinops_kwargs = {k: v for k, v in kwargs.items() if k != 'N'} - if isinstance(X, (dict, UserDict)): - curvlinops_kwargs['batch_size_fn'] = lambda x: x['input_ids'].shape[0] + if isinstance(x, MutableMapping): + curvlinops_kwargs['batch_size_fn'] = lambda _x: _x['input_ids'].shape[0] linop = self._linop_context( self.model, self.lossfunc, self.params, - [(X, y)], + [(x, y)], check_deterministic=False, **curvlinops_kwargs, ) @@ -109,7 +128,7 @@ def full(self, X, y, **kwargs): device=next(self.model.parameters()).device, ) - f = self.model(X) + f = self.model(x) loss = self.lossfunc(f, y) return self.factor * loss.detach(), self.factor * H @@ -120,21 +139,21 @@ class CurvlinopsGGN(CurvlinopsInterface, GGNInterface): def __init__( self, - model, - likelihood, - last_layer=False, - subnetwork_indices=None, - stochastic=False, - ): + model: nn.Module, + likelihood: Likelihood | str, + last_layer: bool = False, + subnetwork_indices: torch.LongTensor | None = None, + stochastic: bool = False, + ) -> None: super().__init__(model, likelihood, last_layer, subnetwork_indices) - self.stochastic = stochastic + self.stochastic: bool = stochastic @property - def _kron_fisher_type(self): + def _kron_fisher_type(self) -> str: return 'mc' if self.stochastic else 'type-2' @property - def _linop_context(self): + def _linop_context(self) -> type[_LinearOperator]: return FisherMCLinearOperator if self.stochastic else GGNLinearOperator @@ -142,11 +161,11 @@ class CurvlinopsEF(CurvlinopsInterface, EFInterface): """Implementation of `EFInterface` using Curvlinops.""" @property - def _kron_fisher_type(self): + def _kron_fisher_type(self) -> str: return 'empirical' @property - def _linop_context(self): + def _linop_context(self) -> type[_LinearOperator]: return EFLinearOperator @@ -154,5 +173,5 @@ class CurvlinopsHessian(CurvlinopsInterface): """Implementation of the full Hessian 
using Curvlinops.""" @property - def _linop_context(self): + def _linop_context(self) -> type[_LinearOperator]: return HessianLinearOperator diff --git a/laplace/laplace.py b/laplace/laplace.py index 9469c725..1dc8da9f 100644 --- a/laplace/laplace.py +++ b/laplace/laplace.py @@ -1,9 +1,9 @@ -from laplace.baselaplace import ( +from laplace.utils.enums import ( SubsetOfWeights, HessianStructure, Likelihood, - ParametricLaplace, ) +from laplace.baselaplace import ParametricLaplace import torch diff --git a/laplace/lllaplace.py b/laplace/lllaplace.py index 059e0e32..be064d8f 100644 --- a/laplace/lllaplace.py +++ b/laplace/lllaplace.py @@ -15,7 +15,7 @@ from laplace.utils import FeatureExtractor, Kron from collections.abc import MutableMapping -from typing import Tuple, Type, Any +from typing import Any __all__ = ['LLLaplace', 'FullLLLaplace', 'KronLLLaplace', 'DiagLLLaplace'] @@ -79,10 +79,10 @@ def __init__( prior_mean: float | torch.Tensor = 0.0, temperature: float = 1.0, enable_backprop: bool = False, - backend: Type[CurvatureInterface] | None = None, + backend: type[CurvatureInterface] | None = None, last_layer_name: str | None = None, - backend_kwargs: dict | None = None, - asdl_fisher_kwargs: dict | None = None, + backend_kwargs: dict[str, Any] | None = None, + asdl_fisher_kwargs: dict[str, Any] | None = None, ): if asdl_fisher_kwargs is not None: raise ValueError('Last-layer Laplace does not support asdl_fisher_kwargs.') @@ -151,7 +151,7 @@ def fit( self.model.eval() if self.model.last_layer is None: - self.data: Tuple[torch.Tensor, torch.Tensor] | MutableMapping = next( + self.data: tuple[torch.Tensor, torch.Tensor] | MutableMapping = next( iter(train_loader) ) self._find_last_layer(self.data) @@ -175,7 +175,7 @@ def fit( def _glm_predictive_distribution( self, X: torch.Tensor | MutableMapping, joint: bool = False - ) -> Tuple[torch.Tensor, torch.Tensor]: + ) -> tuple[torch.Tensor, torch.Tensor]: Js, f_mu = self.backend.last_layer_jacobians(X) if joint: @@ -249,13 +249,13 @@ def prior_precision_diag(self) -> torch.Tensor: else: raise ValueError('Mismatch of prior and model. 
Diagonal or scalar prior.') - def state_dict(self) -> dict: + def state_dict(self) -> dict[str, Any]: state_dict = super().state_dict() state_dict['data'] = getattr(self, 'data', None) # None if not present state_dict['_last_layer_name'] = self._last_layer_name return state_dict - def load_state_dict(self, state_dict: dict) -> None: + def load_state_dict(self, state_dict: dict[str, Any]) -> None: if self._last_layer_name != state_dict['_last_layer_name']: raise ValueError('Different `last_layer_name` detected!') @@ -322,11 +322,11 @@ def __init__( prior_mean: float | torch.Tensor = 0.0, temperature: float = 1.0, enable_backprop: bool = False, - backend: Type[CurvatureInterface] | None = None, + backend: type[CurvatureInterface] | None = None, last_layer_name: str | None = None, damping: bool = False, - backend_kwargs: dict | None = None, - asdl_fisher_kwargs: dict | None = None, + backend_kwargs: dict[str, Any] | None = None, + asdl_fisher_kwargs: dict[str, Any] | None = None, ): self.damping = damping super().__init__( @@ -343,7 +343,7 @@ def __init__( asdl_fisher_kwargs, ) - def _init_H(self): + def _init_H(self) -> None: self.H = Kron.init_from_model(self.model.last_layer, self._device) diff --git a/laplace/marglik_training.py b/laplace/marglik_training.py index ac8778c1..651b4430 100644 --- a/laplace/marglik_training.py +++ b/laplace/marglik_training.py @@ -11,11 +11,17 @@ from collections import UserDict import tqdm -from laplace import Laplace, Likelihood, HessianStructure, PriorStructure -from laplace.baselaplace import SubsetOfWeights +from laplace import Laplace from laplace.curvature import AsdlGGN from laplace.curvature.curvature import CurvatureInterface -from laplace.utils import expand_prior_precision, fix_prior_prec_structure +from laplace.utils import ( + expand_prior_precision, + fix_prior_prec_structure, + Likelihood, + SubsetOfWeights, + HessianStructure, + PriorStructure, +) from typing import Type diff --git a/laplace/subnetlaplace.py b/laplace/subnetlaplace.py index f6e80cf5..c351dfa4 100644 --- a/laplace/subnetlaplace.py +++ b/laplace/subnetlaplace.py @@ -181,7 +181,7 @@ class FullSubnetLaplace(SubnetLaplace, FullLaplace): # key to map to correct subclass of BaseLaplace, (subset of weights, Hessian structure) _key = ('subnetwork', 'full') - def _init_H(self): + def _init_H(self) -> None: self.H = torch.zeros( self.n_params_subnet, self.n_params_subnet, device=self._device ) diff --git a/laplace/utils/__init__.py b/laplace/utils/__init__.py index fbe0d8ed..aace6733 100644 --- a/laplace/utils/__init__.py +++ b/laplace/utils/__init__.py @@ -27,6 +27,15 @@ LastLayerSubnetMask, ) from laplace.utils.metrics import RunningNLLMetric +from laplace.utils.enums import ( + SubsetOfWeights, + HessianStructure, + Likelihood, + PredType, + LinkApprox, + TuningMethod, + PriorStructure, +) __all__ = [ @@ -56,4 +65,11 @@ 'ModuleNameSubnetMask', 'LastLayerSubnetMask', 'RunningNLLMetric', + 'SubsetOfWeights', + 'HessianStructure', + 'Likelihood', + 'PredType', + 'LinkApprox', + 'TuningMethod', + 'PriorStructure', ] diff --git a/laplace/utils/enums.py b/laplace/utils/enums.py new file mode 100644 index 00000000..630c8af9 --- /dev/null +++ b/laplace/utils/enums.py @@ -0,0 +1,43 @@ +from enum import Enum + + +class SubsetOfWeights(str, Enum): + ALL = 'all' + LAST_LAYER = 'last_layer' + SUBNETWORK = 'subnetwork' + + +class HessianStructure(str, Enum): + FULL = 'full' + KRON = 'kron' + DIAG = 'diag' + LOWRANK = 'lowrank' + + +class Likelihood(str, Enum): + REGRESSION = 'regression' + 
CLASSIFICATION = 'classification' + REWARD_MODELING = 'reward_modeling' + + +class PredType(str, Enum): + GLM = 'glm' + NN = 'nn' + + +class LinkApprox(str, Enum): + MC = 'mc' + PROBIT = 'probit' + BRIDGE = 'bridge' + BRIDGE_NORM = 'bridge_norm' + + +class TuningMethod(str, Enum): + MARGLIK = 'marglik' + GRIDSEARCH = 'gridsearch' + + +class PriorStructure(str, Enum): + SCALAR = 'scalar' + DIAG = 'diag' + LAYERWISE = 'layerwise' From 746f0a824970a61f493d07381f8232ed300ac0d4 Mon Sep 17 00:00:00 2001 From: Agustinus Kristiadi Date: Sun, 28 Apr 2024 15:28:15 -0400 Subject: [PATCH 8/9] Typehinting utils --- laplace/baselaplace.py | 14 ++-- laplace/utils/feature_extractor.py | 26 +++--- laplace/utils/matrix.py | 76 +++++++++-------- laplace/utils/metrics.py | 6 +- laplace/utils/subnetmask.py | 127 +++++++++++++++++------------ laplace/utils/swag.py | 42 +++++----- laplace/utils/utils.py | 89 +++++++++++++------- 7 files changed, 224 insertions(+), 156 deletions(-) diff --git a/laplace/baselaplace.py b/laplace/baselaplace.py index a8f8fd37..baad8e49 100644 --- a/laplace/baselaplace.py +++ b/laplace/baselaplace.py @@ -1,6 +1,5 @@ from __future__ import annotations -from enum import Enum from math import sqrt, pi, log from typing import Callable, Any import numpy as np @@ -15,18 +14,17 @@ import warnings from torchmetrics import MeanSquaredError -from laplace.utils import ( +from laplace.curvature.curvature import CurvatureInterface +from laplace.curvature.asdfghjkl import AsdfghjklHessian +from laplace.curvature.curvlinops import CurvlinopsGGN, CurvlinopsEF +from laplace.utils.matrix import Kron, KronDecomposed +from laplace.utils.metrics import RunningNLLMetric +from laplace.utils.utils import ( invsqrt_precision, validate, - Kron, normal_samples, fix_prior_prec_structure, - RunningNLLMetric, ) -from laplace.curvature.curvature import CurvatureInterface -from laplace.curvature.asdfghjkl import AsdfghjklHessian -from laplace.curvature.curvlinops import CurvlinopsGGN, CurvlinopsEF -from laplace.utils.matrix import KronDecomposed from laplace.utils.enums import ( Likelihood, PredType, diff --git a/laplace/utils/feature_extractor.py b/laplace/utils/feature_extractor.py index e93c9118..cbf3734a 100644 --- a/laplace/utils/feature_extractor.py +++ b/laplace/utils/feature_extractor.py @@ -1,6 +1,7 @@ +from collections.abc import MutableMapping import torch import torch.nn as nn -from typing import Tuple, Callable, Optional +from typing import Any, Tuple, Callable, Optional __all__ = ['FeatureExtractor'] @@ -32,22 +33,25 @@ def __init__( enable_backprop: bool = False, ) -> None: super().__init__() - self.model = model - self._features = dict() - self.enable_backprop = enable_backprop + self.model: nn.Module = model + self._features: dict[str, torch.Tensor] = dict() + self.enable_backprop: bool = enable_backprop + self.last_layer: nn.Module | None if last_layer_name is None: self.last_layer = None else: self.set_last_layer(last_layer_name) - def forward(self, x: torch.Tensor) -> torch.Tensor: + def forward( + self, x: torch.Tensor | MutableMapping[str, torch.Tensor | Any] + ) -> torch.Tensor: """Forward pass. If the last layer is not known yet, it will be determined when this function is called for the first time. 
Parameters ---------- - x : torch.Tensor + x : torch.Tensor or a dict-like object containing the input tensors one batch of data to use as input for the forward pass """ if self.last_layer is None: @@ -59,7 +63,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return out def forward_with_features( - self, x: torch.Tensor + self, x: torch.Tensor | MutableMapping[str, torch.Tensor | Any] ) -> Tuple[torch.Tensor, torch.Tensor]: """Forward pass which returns the output of the penultimate layer along with the output of the last layer. If the last layer is not known yet, @@ -67,7 +71,7 @@ def forward_with_features( Parameters ---------- - x : torch.Tensor + x : torch.Tensor or a dict-like object containing the input tensors one batch of data to use as input for the forward pass """ out = self.forward(x) @@ -102,7 +106,9 @@ def hook(_, input, __): return hook - def find_last_layer(self, x: torch.Tensor) -> torch.Tensor: + def find_last_layer( + self, x: torch.Tensor | MutableMapping[str, torch.Tensor | Any] + ) -> torch.Tensor: """Automatically determines the last layer of the model with one forward pass. It assumes that the last layer is the same for every forward pass and that it is an instance of `torch.nn.Linear`. @@ -112,7 +118,7 @@ def find_last_layer(self, x: torch.Tensor) -> torch.Tensor: Parameters ---------- - x : torch.Tensor + x : torch.Tensor or dict-like object containing the input tensors one batch of data to use as input for the forward pass """ if self.last_layer is not None: diff --git a/laplace/utils/matrix.py b/laplace/utils/matrix.py index e58d8058..3c0ab9d1 100644 --- a/laplace/utils/matrix.py +++ b/laplace/utils/matrix.py @@ -1,7 +1,9 @@ +from __future__ import annotations from math import pow import torch +from torch import nn import numpy as np -from typing import Union +from typing import Iterable import opt_einsum as oe from laplace.utils import _is_valid_scalar, symeig, kron, block_diag @@ -19,16 +21,18 @@ class Kron: Parameters ---------- - kfacs : list[Tuple] - each element in the list is a Tuple of two Kronecker factors Q, H + kfacs : list[Iterable[torch.Tensor] | torch.Tensor] + each element in the list is a tuple of two Kronecker factors Q, H or a single matrix approximating the Hessian (in case of bias, for example) """ - def __init__(self, kfacs): - self.kfacs = kfacs + def __init__(self, kfacs: list[tuple[torch.Tensor] | torch.Tensor]) -> None: + self.kfacs: list[tuple[torch.Tensor] | torch.Tensor] = kfacs @classmethod - def init_from_model(cls, model, device): + def init_from_model( + cls, model: nn.Module | Iterable[nn.Parameter], device: torch.device + ) -> Kron: """Initialize Kronecker factors based on a models architecture. Parameters @@ -53,7 +57,7 @@ def init_from_model(cls, model, device): elif 4 >= p.ndim >= 2: # fully connected or conv if p.ndim == 2: # fully connected P_in, P_out = p.size() - elif p.ndim > 2: + else: P_in, P_out = p.shape[0], np.prod(p.shape[1:]) kfacs.append( @@ -66,7 +70,7 @@ def init_from_model(cls, model, device): raise ValueError('Invalid parameter shape in network.') return cls(kfacs) - def __add__(self, other): + def __add__(self, other: Kron) -> Kron: """Add up Kronecker factors `self` and `other`. Parameters @@ -87,7 +91,7 @@ def __add__(self, other): return Kron(kfacs) - def __mul__(self, scalar: Union[float, torch.Tensor]): + def __mul__(self, scalar: float | torch.Tensor) -> Kron: """Multiply all Kronecker factors by scalar. 
The multiplication is distributed across the number of factors using `pow(scalar, 1 / len(F))`. `len(F)` is either `1` or `2`. @@ -107,10 +111,10 @@ def __mul__(self, scalar: Union[float, torch.Tensor]): kfacs = [[pow(scalar, 1 / len(F)) * Hi for Hi in F] for F in self.kfacs] return Kron(kfacs) - def __len__(self): + def __len__(self) -> int: return len(self.kfacs) - def decompose(self, damping=False): + def decompose(self, damping: bool = False) -> KronDecomposed: """Eigendecompose Kronecker factors and turn into `KronDecomposed`. Parameters ---------- @@ -127,14 +131,14 @@ def decompose(self, damping=False): for Hi in F: if Hi.ndim > 1: # Dense Kronecker factor. - l, Q = symeig(Hi) + eigval, Q = symeig(Hi) else: # Diagonal Kronecker factor. - l = Hi + eigval = Hi # This might be too memory intensive since len(Hi) can be large. Q = torch.eye(len(Hi), dtype=Hi.dtype, device=Hi.device) Qs.append(Q) - ls.append(l) + ls.append(eigval) eigvecs.append(Qs) eigvals.append(ls) return KronDecomposed(eigvecs, eigvals, damping=damping) @@ -291,22 +295,28 @@ class KronDecomposed: use dampen approximation mixing prior and Kron partially multiplicatively """ - def __init__(self, eigenvectors, eigenvalues, deltas=None, damping=False): - self.eigenvectors = eigenvectors - self.eigenvalues = eigenvalues - device = eigenvectors[0][0].device + def __init__( + self, + eigenvectors: list[tuple[torch.Tensor]], + eigenvalues: list[tuple[torch.Tensor]], + deltas: torch.Tensor | None = None, + damping: bool = False, + ): + self.eigenvectors: list[tuple[torch.Tensor]] = eigenvectors + self.eigenvalues: list[tuple[torch.Tensor]] = eigenvalues + device: torch.device = eigenvectors[0][0].device if deltas is None: - self.deltas = torch.zeros(len(self), device=device) + self.deltas: torch.Tensor = torch.zeros(len(self), device=device) else: self._check_deltas(deltas) - self.deltas = deltas - self.damping = damping + self.deltas: torch.Tensor = deltas + self.damping: bool = damping - def detach(self): + def detach(self) -> KronDecomposed: self.deltas = self.deltas.detach() return self - def _check_deltas(self, deltas: torch.Tensor): + def _check_deltas(self, deltas: torch.Tensor) -> None: if not isinstance(deltas, torch.Tensor): raise ValueError('Can only add torch.Tensor to KronDecomposed.') @@ -318,7 +328,7 @@ def _check_deltas(self, deltas: torch.Tensor): else: raise ValueError('Invalid shape of delta added to KronDecomposed.') - def __add__(self, deltas: torch.Tensor): + def __add__(self, deltas: torch.Tensor) -> KronDecomposed: """Add scalar per layer or only scalar to Kronecker factors. Parameters @@ -333,7 +343,7 @@ def __add__(self, deltas: torch.Tensor): self._check_deltas(deltas) return KronDecomposed(self.eigenvectors, self.eigenvalues, self.deltas + deltas) - def __mul__(self, scalar): + def __mul__(self, scalar: torch.Tensor | float) -> KronDecomposed: """Multiply by a scalar by changing the eigenvalues. Same as for the case of `Kron`. 
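The scalar multiplication above spreads the factor over the Kronecker factors via `pow(scalar, 1 / len(F))`, which relies on the identity \\((\\sqrt{c}\\,Q) \\otimes (\\sqrt{c}\\,H) = c\\,(Q \\otimes H)\\). A quick numerical check of that identity with plain torch (not using the `Kron` class itself):

import torch

# Scaling each Kronecker factor by sqrt(c) scales the Kronecker product by c,
# which is why __mul__ uses pow(scalar, 1 / len(F)) per factor.
Q = torch.randn(3, 3)
H = torch.randn(4, 4)
c = 2.5

lhs = torch.kron(c**0.5 * Q, c**0.5 * H)
rhs = c * torch.kron(Q, H)
assert torch.allclose(lhs, rhs, atol=1e-6)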
@@ -490,12 +500,12 @@ def diag(self, exponent: float = 1) -> torch.Tensor: l1, l2 = ls if self.damping: delta_sqrt = torch.sqrt(delta) - l = torch.pow( + eigval = torch.pow( torch.outer(l1 + delta_sqrt, l2 + delta_sqrt), exponent ) else: - l = torch.pow(torch.outer(l1, l2) + delta, exponent) - d = oe.contract('mp,nq,pq,mp,nq->mn', Q1, Q2, l, Q1, Q2).flatten() + eigval = torch.pow(torch.outer(l1, l2) + delta, exponent) + d = oe.contract('mp,nq,pq,mp,nq->mn', Q1, Q2, eigval, Q1, Q2).flatten() diags.append(d) return torch.cat(diags) @@ -516,20 +526,20 @@ def to_matrix(self, exponent: float = 1) -> torch.Tensor: blocks = list() for Qs, ls, delta in zip(self.eigenvectors, self.eigenvalues, self.deltas): if len(ls) == 1: - Q, l = Qs[0], ls[0] - blocks.append(Q @ torch.diag(torch.pow(l + delta, exponent)) @ Q.T) + Q, eigval = Qs[0], ls[0] + blocks.append(Q @ torch.diag(torch.pow(eigval + delta, exponent)) @ Q.T) else: Q1, Q2 = Qs l1, l2 = ls Q = kron(Q1, Q2) if self.damping: delta_sqrt = torch.sqrt(delta) - l = torch.pow( + eigval = torch.pow( torch.outer(l1 + delta_sqrt, l2 + delta_sqrt), exponent ) else: - l = torch.pow(torch.outer(l1, l2) + delta, exponent) - L = torch.diag(l.flatten()) + eigval = torch.pow(torch.outer(l1, l2) + delta, exponent) + L = torch.diag(eigval.flatten()) blocks.append(Q @ L @ Q.T) return block_diag(blocks) diff --git a/laplace/utils/metrics.py b/laplace/utils/metrics.py index d327243f..3a3cda27 100644 --- a/laplace/utils/metrics.py +++ b/laplace/utils/metrics.py @@ -13,13 +13,13 @@ class RunningNLLMetric(Metric): which class label to ignore when computing the NLL loss """ - def __init__(self, ignore_index=-100): + def __init__(self, ignore_index: int = -100) -> None: super().__init__() self.add_state('nll_sum', default=torch.tensor(0.0), dist_reduce_fx='sum') self.add_state( 'n_valid_labels', default=torch.tensor(0.0), dist_reduce_fx='sum' ) - self.ignore_index = ignore_index + self.ignore_index: int = ignore_index def update(self, probs: torch.Tensor, targets: torch.Tensor) -> None: """ @@ -39,5 +39,5 @@ def update(self, probs: torch.Tensor, targets: torch.Tensor) -> None: ) self.n_valid_labels += (targets != self.ignore_index).sum() - def compute(self): + def compute(self) -> torch.Tensor: return self.nll_sum / self.n_valid_labels diff --git a/laplace/utils/subnetmask.py b/laplace/utils/subnetmask.py index db211cca..ac8a5d89 100644 --- a/laplace/utils/subnetmask.py +++ b/laplace/utils/subnetmask.py @@ -1,10 +1,15 @@ +from __future__ import annotations from copy import deepcopy import torch +from torch import nn from torch.nn import CrossEntropyLoss, MSELoss from torch.nn.utils import parameters_to_vector +from torch.utils.data import DataLoader +import laplace.baselaplace from laplace.utils import FeatureExtractor, fit_diagonal_swag_var +from laplace.utils.enums import Likelihood __all__ = [ @@ -27,34 +32,38 @@ class SubnetMask: model : torch.nn.Module """ - def __init__(self, model): - self.model = model - self.parameter_vector = parameters_to_vector(self.model.parameters()).detach() - self._n_params = len(self.parameter_vector) - self._indices = None - self._n_params_subnet = None + def __init__(self, model: nn.Module) -> None: + self.model: nn.Module = model + self.parameter_vector: torch.Tensor = parameters_to_vector( + self.model.parameters() + ).detach() + self._n_params: int = len(self.parameter_vector) + self._indices: torch.LongTensor | None = None + self._n_params_subnet: int | None = None - def _check_select(self): + def _check_select(self) -> None: if 
self._indices is None: raise AttributeError('Subnetwork mask not selected. Run select() first.') @property - def _device(self): + def _device(self) -> torch.device: return next(self.model.parameters()).device @property - def indices(self): + def indices(self) -> torch.LongTensor: self._check_select() return self._indices @property - def n_params_subnet(self): + def n_params_subnet(self) -> int: if self._n_params_subnet is None: self._check_select() self._n_params_subnet = len(self._indices) return self._n_params_subnet - def convert_subnet_mask_to_indices(self, subnet_mask): + def convert_subnet_mask_to_indices( + self, subnet_mask: torch.Tensor + ) -> torch.LongTensor: """Converts a subnetwork mask into subnetwork indices. Parameters @@ -103,7 +112,7 @@ def convert_subnet_mask_to_indices(self, subnet_mask): subnet_mask_indices = subnet_mask.nonzero(as_tuple=True)[0] return subnet_mask_indices - def select(self, train_loader=None): + def select(self, train_loader: DataLoader | None = None) -> torch.LongTensor: """Select the subnetwork mask. Parameters @@ -126,7 +135,7 @@ def select(self, train_loader=None): self._indices = self.convert_subnet_mask_to_indices(subnet_mask) return self._indices - def get_subnet_mask(self, train_loader): + def get_subnet_mask(self, train_loader: DataLoader) -> torch.Tensor: """Get the subnetwork mask. Parameters @@ -156,7 +165,7 @@ class ScoreBasedSubnetMask(SubnetMask): number of parameters in the subnetwork (i.e. number of top-scoring parameters to select) """ - def __init__(self, model, n_params_subnet): + def __init__(self, model: nn.Module, n_params_subnet: int) -> None: super().__init__(model) if n_params_subnet is None: @@ -168,20 +177,20 @@ def __init__(self, model, n_params_subnet): f'Subnetwork ({n_params_subnet}) cannot be larger than model ({self._n_params}).' ) self._n_params_subnet = n_params_subnet - self._param_scores = None + self._param_scores: torch.Tensor | None = None - def compute_param_scores(self, train_loader): + def compute_param_scores(self, train_loader: DataLoader) -> torch.Tensor: raise NotImplementedError - def _check_param_scores(self): + def _check_param_scores(self) -> None: + assert self._param_scores is not None if self._param_scores.shape != self.parameter_vector.shape: raise ValueError( 'Parameter scores need to be of same shape as parameter vector.' 
) - def get_subnet_mask(self, train_loader): + def get_subnet_mask(self, train_loader: DataLoader) -> torch.Tensor: """Get the subnetwork mask by (descendingly) ranking parameters based on their scores.""" - if self._param_scores is None: self._param_scores = self.compute_param_scores(train_loader) self._check_param_scores() @@ -198,14 +207,14 @@ def get_subnet_mask(self, train_loader): class RandomSubnetMask(ScoreBasedSubnetMask): """Subnetwork mask of parameters sampled uniformly at random.""" - def compute_param_scores(self, train_loader): + def compute_param_scores(self, train_loader: DataLoader) -> torch.Tensor: return torch.rand_like(self.parameter_vector) class LargestMagnitudeSubnetMask(ScoreBasedSubnetMask): """Subnetwork mask identifying the parameters with the largest magnitude.""" - def compute_param_scores(self, train_loader): + def compute_param_scores(self, train_loader: DataLoader) -> torch.Tensor: return self.parameter_vector.abs() @@ -222,11 +231,16 @@ class LargestVarianceDiagLaplaceSubnetMask(ScoreBasedSubnetMask): diagonal Laplace model to use for variance estimation """ - def __init__(self, model, n_params_subnet, diag_laplace_model): + def __init__( + self, + model: nn.Module, + n_params_subnet: int, + diag_laplace_model: laplace.baselaplace.DiagLaplace, + ): super().__init__(model, n_params_subnet) - self.diag_laplace_model = diag_laplace_model + self.diag_laplace_model: laplace.baselaplace.DiagLaplace = diag_laplace_model - def compute_param_scores(self, train_loader): + def compute_param_scores(self, train_loader: DataLoader) -> torch.Tensor: if train_loader is None: raise ValueError('Need to pass train loader for subnet selection.') @@ -255,27 +269,32 @@ class LargestVarianceSWAGSubnetMask(ScoreBasedSubnetMask): def __init__( self, - model, - n_params_subnet, - likelihood='classification', - swag_n_snapshots=40, - swag_snapshot_freq=1, - swag_lr=0.01, + model: nn.Module, + n_params_subnet: int, + likelihood: Likelihood | str = Likelihood.CLASSIFICATION, + swag_n_snapshots: int = 40, + swag_snapshot_freq: int = 1, + swag_lr: float = 0.01, ): + if likelihood not in [Likelihood.CLASSIFICATION, Likelihood.REGRESSION]: + raise ValueError('Only available for classification and regression!') + super().__init__(model, n_params_subnet) - self.likelihood = likelihood - self.swag_n_snapshots = swag_n_snapshots - self.swag_snapshot_freq = swag_snapshot_freq - self.swag_lr = swag_lr - def compute_param_scores(self, train_loader): + self.likelihood: Likelihood | str = likelihood + self.swag_n_snapshots: int = swag_n_snapshots + self.swag_snapshot_freq: int = swag_snapshot_freq + self.swag_lr: float = swag_lr + + def compute_param_scores(self, train_loader: DataLoader) -> torch.Tensor: if train_loader is None: raise ValueError('Need to pass train loader for subnet selection.') - if self.likelihood == 'classification': + if self.likelihood == Likelihood.CLASSIFICATION: criterion = CrossEntropyLoss(reduction='mean') - elif self.likelihood == 'regression': + else: criterion = MSELoss(reduction='mean') + param_variances = fit_diagonal_swag_var( self.model, train_loader, @@ -298,15 +317,15 @@ class ParamNameSubnetMask(SubnetMask): that define the subnetwork """ - def __init__(self, model, parameter_names): + def __init__(self, model: nn.Module, parameter_names: list[str]) -> None: super().__init__(model) - self._parameter_names = parameter_names - self._n_params_subnet = None + self._parameter_names: list[str] = parameter_names + self._n_params_subnet: int | None = None - def 
_check_param_names(self): + def _check_param_names(self) -> None: param_names = deepcopy(self._parameter_names) if len(param_names) == 0: - raise ValueError(f'Parameter name list cannot be empty.') + raise ValueError('Parameter name list cannot be empty.') for name, _ in self.model.named_parameters(): if name in param_names: @@ -314,7 +333,7 @@ def _check_param_names(self): if len(param_names) > 0: raise ValueError(f'Parameters {param_names} do not exist in model.') - def get_subnet_mask(self, train_loader): + def get_subnet_mask(self, train_loader: DataLoader) -> torch.Tensor: """Get the subnetwork mask identifying the specified parameters.""" self._check_param_names() @@ -341,15 +360,15 @@ class ModuleNameSubnetMask(SubnetMask): the modules cannot have children, i.e. need to be leaf modules """ - def __init__(self, model, module_names): + def __init__(self, model: nn.Module, module_names: list[str]): super().__init__(model) - self._module_names = module_names - self._n_params_subnet = None + self._module_names: list[str] = module_names + self._n_params_subnet: int | None = None - def _check_module_names(self): + def _check_module_names(self) -> None: module_names = deepcopy(self._module_names) if len(module_names) == 0: - raise ValueError(f'Module name list cannot be empty.') + raise ValueError('Module name list cannot be empty.') for name, module in self.model.named_modules(): if name in module_names: @@ -364,7 +383,7 @@ def _check_module_names(self): if len(module_names) > 0: raise ValueError(f'Modules {module_names} do not exist in model.') - def get_subnet_mask(self, train_loader): + def get_subnet_mask(self, train_loader: DataLoader) -> torch.Tensor: """Get the subnetwork mask identifying the specified modules.""" self._check_module_names() @@ -394,14 +413,14 @@ class LastLayerSubnetMask(ModuleNameSubnetMask): name of the model's last layer, if None it will be determined automatically """ - def __init__(self, model, last_layer_name=None): - super().__init__(model, None) - self._feature_extractor = FeatureExtractor( + def __init__(self, model: nn.Module, last_layer_name: str | None = None): + super().__init__(model, []) + self._feature_extractor: FeatureExtractor = FeatureExtractor( self.model, last_layer_name=last_layer_name ) - self._n_params_subnet = None + self._n_params_subnet: int | None = None - def get_subnet_mask(self, train_loader): + def get_subnet_mask(self, train_loader: DataLoader) -> torch.Tensor: """Get the subnetwork mask identifying the last layer.""" if train_loader is None: diff --git a/laplace/utils/swag.py b/laplace/utils/swag.py index ade02a6e..2581b938 100644 --- a/laplace/utils/swag.py +++ b/laplace/utils/swag.py @@ -1,27 +1,30 @@ from copy import deepcopy import torch +from torch import nn from torch.nn.utils import parameters_to_vector +from torch.optim import Optimizer +from torch.utils.data import DataLoader __all__ = ['fit_diagonal_swag_var'] -def _param_vector(model): +def _param_vector(model: nn.Module) -> torch.Tensor: return parameters_to_vector(model.parameters()).detach() def fit_diagonal_swag_var( - model, - train_loader, - criterion, - n_snapshots_total=40, - snapshot_freq=1, - lr=0.01, - momentum=0.9, - weight_decay=3e-4, - min_var=1e-30, -): + model: nn.Module, + train_loader: DataLoader, + criterion: nn.CrossEntropyLoss | nn.MSELoss, + n_snapshots_total: int = 40, + snapshot_freq: int = 1, + lr: float = 0.01, + momentum: float = 0.9, + weight_decay: float = 3e-4, + min_var: float = 1e-30, +) -> torch.Tensor: """ Fit diagonal SWAG [1], which 
estimates marginal variances of model parameters by computing the first and second moment of SGD iterates with a large learning rate. @@ -63,20 +66,21 @@ def fit_diagonal_swag_var( """ # create a copy of the model to avoid undesired changes to the original model parameters - _model = deepcopy(model) + _model: nn.Module = deepcopy(model) _model.train() - device = next(_model.parameters()).device + device: torch.device = next(_model.parameters()).device # initialize running estimates of first and second moment of model parameters - mean = torch.zeros_like(_param_vector(_model)) - sq_mean = torch.zeros_like(_param_vector(_model)) - n_snapshots = 0 + mean: torch.Tensor = torch.zeros_like(_param_vector(_model)) + sq_mean: torch.Tensor = torch.zeros_like(_param_vector(_model)) + n_snapshots: int = 0 # run SGD to collect model snapshots - optimizer = torch.optim.SGD( + optimizer: Optimizer = torch.optim.SGD( _model.parameters(), lr=lr, momentum=momentum, weight_decay=weight_decay ) - n_epochs = snapshot_freq * n_snapshots_total + n_epochs: int = snapshot_freq * n_snapshots_total + for epoch in range(n_epochs): for inputs, targets in train_loader: inputs, targets = inputs.to(device), targets.to(device) @@ -93,5 +97,5 @@ def fit_diagonal_swag_var( n_snapshots += 1 # compute marginal parameter variances, Var[P] = E[P^2] - E[P]^2 - param_variances = torch.clamp(sq_mean - mean**2, min_var) + param_variances: torch.Tensor = torch.clamp(sq_mean - mean**2, min_var) return param_variances diff --git a/laplace/utils/utils.py b/laplace/utils/utils.py index 5fe4d749..c852226e 100644 --- a/laplace/utils/utils.py +++ b/laplace/utils/utils.py @@ -1,15 +1,21 @@ +from __future__ import annotations import logging -from typing import Union +from typing import Callable, Union import numpy as np import torch +from torch import nn import torch.nn.functional as F from torch.nn.utils import parameters_to_vector from torch.nn import BatchNorm1d, BatchNorm2d, BatchNorm3d from torch.distributions.multivariate_normal import _precision_to_scale_tril +from torch.utils.data import DataLoader from torchmetrics import Metric from collections import UserDict -import math +import torchmetrics + +import laplace.baselaplace +from laplace.utils.enums import LinkApprox, PredType, PriorStructure __all__ = [ 'get_nll', @@ -24,19 +30,21 @@ ] -def get_nll(out_dist, targets): +def get_nll(out_dist: torch.Tensor, targets: torch.Tensor) -> torch.Tensor: return F.nll_loss(torch.log(out_dist), targets) @torch.no_grad() def validate( - laplace, - val_loader, - loss, - pred_type='glm', - link_approx='probit', - n_samples=100, - loss_with_var=False, + laplace: laplace.baselaplace.BaseLaplace, + val_loader: DataLoader, + loss: torchmetrics.Metric + | Callable[[torch.Tensor, torch.Tensor], torch.Tensor] + | Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor], + pred_type: PredType | str = PredType.GLM, + link_approx: LinkApprox | str = LinkApprox.PROBIT, + n_samples: int = 100, + loss_with_var: int = False, ) -> float: laplace.model.eval() assert callable(loss) or isinstance(loss, Metric) @@ -84,7 +92,7 @@ def validate( return loss.compute().item() -def parameters_per_layer(model): +def parameters_per_layer(model: nn.Module) -> list[int]: """Get number of parameters per layer. Parameters @@ -98,7 +106,7 @@ def parameters_per_layer(model): return [np.prod(p.shape) for p in model.parameters()] -def invsqrt_precision(M): +def invsqrt_precision(M: torch.Tensor) -> torch.Tensor: """Compute ``M^{-0.5}`` as a tridiagonal matrix. 
Parameters @@ -112,7 +120,7 @@ def invsqrt_precision(M): return _precision_to_scale_tril(M) -def _is_batchnorm(module): +def _is_batchnorm(module: nn.Module) -> bool: if ( isinstance(module, BatchNorm1d) or isinstance(module, BatchNorm2d) @@ -122,7 +130,7 @@ def _is_batchnorm(module): return False -def _is_valid_scalar(scalar: Union[float, int, torch.Tensor]) -> bool: +def _is_valid_scalar(scalar: float | int | torch.Tensor) -> bool: if np.isscalar(scalar) and np.isreal(scalar): return True elif torch.is_tensor(scalar) and scalar.ndim <= 1: @@ -132,7 +140,7 @@ def _is_valid_scalar(scalar: Union[float, int, torch.Tensor]) -> bool: return False -def kron(t1, t2): +def kron(t1: torch.Tensor, t2: torch.Tensor) -> torch.Tensor: """Computes the Kronecker product between two tensors. Parameters @@ -160,7 +168,7 @@ def kron(t1, t2): return expanded_t1 * tiled_t2 -def diagonal_add_scalar(X, value): +def diagonal_add_scalar(X: torch.Tensor, value: torch.Tensor) -> torch.Tensor: """Add scalar value `value` to diagonal of `X`. Parameters @@ -172,15 +180,12 @@ def diagonal_add_scalar(X, value): ------- X_add_scalar : torch.Tensor """ - if not X.device == torch.device('cpu'): - indices = torch.cuda.LongTensor([[i, i] for i in range(X.shape[0])]) - else: - indices = torch.LongTensor([[i, i] for i in range(X.shape[0])]) + indices = torch.LongTensor([[i, i] for i in range(X.shape[0])], device=X.device) values = X.new_ones(X.shape[0]).mul(value) return X.index_put(tuple(indices.t()), values, accumulate=True) -def symeig(M): +def symeig(M: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: """Symetric eigendecomposition avoiding failure cases by adding and removing jitter to the diagonal. @@ -216,7 +221,7 @@ def symeig(M): return L, W -def block_diag(blocks): +def block_diag(blocks: list[torch.Tensor]) -> torch.Tensor: """Compose block-diagonal matrix of individual blocks. Parameters @@ -237,7 +242,7 @@ def block_diag(blocks): return M -def expand_prior_precision(prior_prec, model): +def expand_prior_precision(prior_prec: torch.Tensor, model: nn.Module) -> torch.Tensor: """Expand prior precision to match the shape of the model parameters. Parameters @@ -270,20 +275,46 @@ def expand_prior_precision(prior_prec, model): def fix_prior_prec_structure( - prior_prec_init, prior_structure, n_layers, n_params, device -): - if prior_structure == 'scalar': + prior_prec_init: torch.Tensor, + prior_structure: PriorStructure | str, + n_layers: int, + n_params: int, + device: torch.device, +) -> torch.Tensor: + """Create a tensor of prior precision with the correct shape, depending on the + choice of the prior structure type. 
+ + Parameters + ---------- + prior_prec_init: torch.Tensor + the initial prior precision tensor (could be scalar) + prior_structure: PriorStructure | str + the choice of the prior structure type + n_layers: int + n_params: int + device: torch.device + + Returns + ------- + correct_prior_precision: torch.Tensor + """ + if prior_structure == PriorStructure.SCALAR: prior_prec_init = torch.full((1,), prior_prec_init, device=device) - elif prior_structure == 'layerwise': + elif prior_structure == PriorStructure.LAYERWISE: prior_prec_init = torch.full((n_layers,), prior_prec_init, device=device) - elif prior_structure == 'diag': + elif prior_structure == PriorStructure.DIAG: prior_prec_init = torch.full((n_params,), prior_prec_init, device=device) else: raise ValueError(f'Invalid prior structure {prior_structure}.') return prior_prec_init -def normal_samples(mean, var, n_samples, generator=None): +def normal_samples( + mean: torch.Tensor, + var: torch.Tensor, + n_samples: int, + generator: torch.Generator | None = None, +) -> torch.Tensor: """Produce samples from a batch of Normal distributions either parameterized by a diagonal or full covariance given by `var`. From 158cd2c09604204b1c9c3df37bdc9fcf30639f47 Mon Sep 17 00:00:00 2001 From: Agustinus Kristiadi Date: Thu, 13 Jun 2024 16:34:36 -0400 Subject: [PATCH 9/9] Formatting and linting --- examples/bayesopt_example.py | 2 +- laplace/baselaplace.py | 2 +- laplace/curvature/asdfghjkl.py | 1 + laplace/utils/enums.py | 42 +++++++++++++++++----------------- setup.py | 2 +- 5 files changed, 25 insertions(+), 24 deletions(-) diff --git a/examples/bayesopt_example.py b/examples/bayesopt_example.py index 43989a22..bb907430 100644 --- a/examples/bayesopt_example.py +++ b/examples/bayesopt_example.py @@ -24,7 +24,7 @@ from torch import nn, optim from torch.nn import functional as F -from laplace import BaseLaplace +from laplace import BaseLaplace, Laplace class LaplaceBNN(Model): diff --git a/laplace/baselaplace.py b/laplace/baselaplace.py index 11fdeaff..11279348 100644 --- a/laplace/baselaplace.py +++ b/laplace/baselaplace.py @@ -451,7 +451,7 @@ def optimize_prior_precision( if loss is None: loss = ( - tm.MeanSquaredError(num_outputs=self.n_outputs) + torchmetrics.MeanSquaredError(num_outputs=self.n_outputs) if self.likelihood == "regression" else RunningNLLMetric() ) diff --git a/laplace/curvature/asdfghjkl.py b/laplace/curvature/asdfghjkl.py index 2994ac1a..ad7dacf0 100644 --- a/laplace/curvature/asdfghjkl.py +++ b/laplace/curvature/asdfghjkl.py @@ -18,6 +18,7 @@ from asdfghjkl.gradient import batch_gradient from asdfghjkl.hessian import hessian_eigenvalues, hessian_for_loss from torch import nn +from torch.utils.data import DataLoader from laplace.curvature import CurvatureInterface, EFInterface, GGNInterface from laplace.utils import Kron, _is_batchnorm diff --git a/laplace/utils/enums.py b/laplace/utils/enums.py index 630c8af9..90dabf4d 100644 --- a/laplace/utils/enums.py +++ b/laplace/utils/enums.py @@ -2,42 +2,42 @@ class SubsetOfWeights(str, Enum): - ALL = 'all' - LAST_LAYER = 'last_layer' - SUBNETWORK = 'subnetwork' + ALL = "all" + LAST_LAYER = "last_layer" + SUBNETWORK = "subnetwork" class HessianStructure(str, Enum): - FULL = 'full' - KRON = 'kron' - DIAG = 'diag' - LOWRANK = 'lowrank' + FULL = "full" + KRON = "kron" + DIAG = "diag" + LOWRANK = "lowrank" class Likelihood(str, Enum): - REGRESSION = 'regression' - CLASSIFICATION = 'classification' - REWARD_MODELING = 'reward_modeling' + REGRESSION = "regression" + CLASSIFICATION = 
"classification" + REWARD_MODELING = "reward_modeling" class PredType(str, Enum): - GLM = 'glm' - NN = 'nn' + GLM = "glm" + NN = "nn" class LinkApprox(str, Enum): - MC = 'mc' - PROBIT = 'probit' - BRIDGE = 'bridge' - BRIDGE_NORM = 'bridge_norm' + MC = "mc" + PROBIT = "probit" + BRIDGE = "bridge" + BRIDGE_NORM = "bridge_norm" class TuningMethod(str, Enum): - MARGLIK = 'marglik' - GRIDSEARCH = 'gridsearch' + MARGLIK = "marglik" + GRIDSEARCH = "gridsearch" class PriorStructure(str, Enum): - SCALAR = 'scalar' - DIAG = 'diag' - LAYERWISE = 'layerwise' + SCALAR = "scalar" + DIAG = "diag" + LAYERWISE = "layerwise" diff --git a/setup.py b/setup.py index 4356bfc5..1abbd068 100644 --- a/setup.py +++ b/setup.py @@ -1,4 +1,4 @@ import setuptools -if __name__ == '__main__': +if __name__ == "__main__": setuptools.setup()