-
Notifications
You must be signed in to change notification settings - Fork 72
/
baselaplace.py
1232 lines (1049 loc) · 49.1 KB
/
baselaplace.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
from math import sqrt, pi, log
import numpy as np
import torch
from torch.nn.utils import parameters_to_vector, vector_to_parameters
from torch.distributions import MultivariateNormal
import warnings
from laplace.utils import (parameters_per_layer, invsqrt_precision,
get_nll, validate, Kron, normal_samples,
fix_prior_prec_structure)
from laplace.curvature import AsdlHessian, CurvlinopsGGN
__all__ = ['BaseLaplace', 'ParametricLaplace',
'FullLaplace', 'KronLaplace', 'DiagLaplace', 'LowRankLaplace']
class BaseLaplace:
    """Baseclass for all Laplace approximations in this library.

    Parameters
    ----------
    model : torch.nn.Module
    likelihood : {'classification', 'regression'}
        determines the log likelihood Hessian approximation
    sigma_noise : torch.Tensor or float, default=1
        observation noise for the regression setting; must be 1 for classification
    prior_precision : torch.Tensor or float, default=1
        prior precision of a Gaussian prior (= weight decay);
        can be scalar, per-layer, or diagonal in the most general case
    prior_mean : torch.Tensor or float, default=0
        prior mean of a Gaussian prior, useful for continual learning
    temperature : float, default=1
        temperature of the likelihood; lower temperature leads to more
        concentrated posterior and vice versa.
    enable_backprop: bool, default=False
        whether to enable backprop to the input `x` through the Laplace predictive.
        Useful for e.g. Bayesian optimization.
    backend : subclasses of `laplace.curvature.CurvatureInterface`
        backend for access to curvature/Hessian approximations. Defaults to CurvlinopsGGN if None.
    backend_kwargs : dict, default=None
        arguments passed to the backend on initialization, for example to
        set the number of MC samples for stochastic approximations.
    """

    def __init__(self, model, likelihood, sigma_noise=1., prior_precision=1.,
                 prior_mean=0., temperature=1., enable_backprop=False,
                 backend=None, backend_kwargs=None):
        if likelihood not in ['classification', 'regression']:
            raise ValueError(f'Invalid likelihood type {likelihood}')

        self.model = model
        self.n_params = len(parameters_to_vector(self.model.parameters()).detach())
        self.n_layers = len(list(self.model.parameters()))

        # Both assignments go through property setters that validate the
        # value and need `self.model`, `self.n_params` and `self.n_layers`.
        self.prior_precision = prior_precision
        self.prior_mean = prior_mean

        if sigma_noise != 1 and likelihood != 'regression':
            raise ValueError('Sigma noise != 1 only available for regression.')

        self.likelihood = likelihood
        self.sigma_noise = sigma_noise
        self.temperature = temperature
        self.enable_backprop = enable_backprop

        # The backend is instantiated lazily on first access of `self.backend`.
        self._backend = None
        self._backend_cls = backend if backend is not None else CurvlinopsGGN
        self._backend_kwargs = dict() if backend_kwargs is None else backend_kwargs

        # log likelihood = g(loss)
        self.loss = 0.
        self.n_outputs = None
        self.n_data = 0

    @property
    def _device(self):
        """Device of the first model parameter."""
        return next(self.model.parameters()).device

    @property
    def backend(self):
        """Curvature backend, constructed on first access and cached afterwards."""
        if self._backend is None:
            self._backend = self._backend_cls(self.model, self.likelihood,
                                              **self._backend_kwargs)
        return self._backend

    def _curv_closure(self, X, y, N):
        """Compute loss and curvature for one batch; implemented by subclasses."""
        raise NotImplementedError

    def fit(self, train_loader):
        """Fit the Laplace approximation; implemented by subclasses."""
        raise NotImplementedError

    def log_marginal_likelihood(self, prior_precision=None, sigma_noise=None):
        """Compute the log marginal likelihood; implemented by subclasses."""
        raise NotImplementedError

    @property
    def log_likelihood(self):
        """Compute log likelihood on the training data after `.fit()` has been called.
        The log likelihood is computed on-demand based on the loss and, for example,
        the observation noise which makes it differentiable in the latter for
        iterative updates.

        Returns
        -------
        log_likelihood : torch.Tensor
        """
        factor = - self._H_factor
        if self.likelihood == 'regression':
            # loss used is just MSE, need to add normalizer for gaussian likelihood
            c = self.n_data * self.n_outputs * torch.log(self.sigma_noise * sqrt(2 * pi))
            return factor * self.loss - c
        else:
            # for classification Xent == log Cat
            return factor * self.loss

    def __call__(self, x, pred_type, link_approx, n_samples):
        """Compute the posterior predictive; implemented by subclasses."""
        raise NotImplementedError

    def predictive(self, x, pred_type, link_approx, n_samples):
        """Alias for calling the instance directly with the same arguments."""
        return self(x, pred_type, link_approx, n_samples)

    def _check_jacobians(self, Js):
        """Validate type, device and parameter dimension of Jacobians.

        Parameters
        ----------
        Js : torch.Tensor
            three-dimensional tensor whose last dimension must equal `n_params`

        Raises
        ------
        ValueError
            if `Js` is not a tensor, lives on a different device than the
            model, or its last dimension differs from `n_params`.
        """
        if not isinstance(Js, torch.Tensor):
            raise ValueError('Jacobians have to be torch.Tensor.')
        if not Js.device == self._device:
            raise ValueError('Jacobians need to be on the same device as Laplace.')
        # Unpacking into three names also enforces that Js is three-dimensional.
        _, _, p = Js.size()
        if p != self.n_params:
            raise ValueError('Invalid Jacobians shape for Laplace posterior approx.')

    @property
    def prior_precision_diag(self):
        """Obtain the diagonal prior precision \\(p_0\\) constructed from either
        a scalar, layer-wise, or diagonal prior precision.

        Returns
        -------
        prior_precision_diag : torch.Tensor

        Raises
        ------
        ValueError
            if the length of the prior precision is neither 1, `n_params`,
            nor `n_layers`.
        """
        if len(self.prior_precision) == 1:  # scalar
            return self.prior_precision * torch.ones(self.n_params, device=self._device)

        elif len(self.prior_precision) == self.n_params:  # diagonal
            return self.prior_precision

        elif len(self.prior_precision) == self.n_layers:  # per layer
            n_params_per_layer = parameters_per_layer(self.model)
            return torch.cat([prior * torch.ones(n_params, device=self._device) for prior, n_params
                              in zip(self.prior_precision, n_params_per_layer)])

        else:
            raise ValueError('Mismatch of prior and model. Diagonal, scalar, or per-layer prior.')

    @property
    def prior_mean(self):
        """Prior mean as a tensor (set via the validating setter)."""
        return self._prior_mean

    @prior_mean.setter
    def prior_mean(self, prior_mean):
        if np.isscalar(prior_mean) and np.isreal(prior_mean):
            self._prior_mean = torch.tensor(prior_mean, device=self._device)
        elif torch.is_tensor(prior_mean):
            if prior_mean.ndim == 0:
                self._prior_mean = prior_mean.reshape(-1).to(self._device)
            elif prior_mean.ndim == 1:
                if len(prior_mean) not in [1, self.n_params]:
                    raise ValueError('Invalid length of prior mean.')
                # Move to the model device like every other branch (and like
                # the prior precision setter) so later arithmetic with the
                # posterior mean does not hit a device mismatch.
                self._prior_mean = prior_mean.to(self._device)
            else:
                raise ValueError('Prior mean has too many dimensions!')
        else:
            raise ValueError('Invalid argument type of prior mean.')

    @property
    def prior_precision(self):
        """Prior precision as a 1-D tensor (set via the validating setter)."""
        return self._prior_precision

    @prior_precision.setter
    def prior_precision(self, prior_precision):
        # Any cached posterior scale depends on the prior precision.
        self._posterior_scale = None
        if np.isscalar(prior_precision) and np.isreal(prior_precision):
            self._prior_precision = torch.tensor([prior_precision], device=self._device)
        elif torch.is_tensor(prior_precision):
            if prior_precision.ndim == 0:
                # make dimensional
                self._prior_precision = prior_precision.reshape(-1).to(self._device)
            elif prior_precision.ndim == 1:
                if len(prior_precision) not in [1, self.n_layers, self.n_params]:
                    raise ValueError('Length of prior precision does not align with architecture.')
                self._prior_precision = prior_precision.to(self._device)
            else:
                raise ValueError('Prior precision needs to be at most one-dimensional tensor.')
        else:
            raise ValueError('Prior precision either scalar or torch.Tensor up to 1-dim.')

    def optimize_prior_precision_base(
        self,
        pred_type,
        method='marglik',
        n_steps=100,
        lr=1e-1,
        init_prior_prec=1.,
        prior_structure='scalar',
        val_loader=None,
        loss=get_nll,
        log_prior_prec_min=-4,
        log_prior_prec_max=4,
        grid_size=100,
        link_approx='probit',
        n_samples=100,
        verbose=False,
        cv_loss_with_var=False,
    ):
        """Optimize the prior precision post-hoc using the `method`
        specified by the user.

        Parameters
        ----------
        pred_type : {'glm', 'nn', 'gp'}
            type of posterior predictive, linearized GLM predictive or neural
            network sampling predictive or Gaussian Process (GP) inference.
            The GLM predictive is consistent with the curvature approximations used here.
        method : {'marglik', 'gridsearch'}, default='marglik'
            specifies how the prior precision should be optimized.
        n_steps : int, default=100
            the number of gradient descent steps to take.
        lr : float, default=1e-1
            the learning rate to use for gradient descent.
        init_prior_prec : float or tensor, default=1.0
            initial prior precision before the first optimization step.
        prior_structure : {'scalar', 'layerwise', 'diag'}, default='scalar'
            if init_prior_prec is scalar, the prior precision is optimized with this structure.
            otherwise, the structure of init_prior_prec is maintained.
        val_loader : torch.utils.data.DataLoader, default=None
            DataLoader for the validation set; each iterate is a training batch (X, y).
        loss : callable, default=get_nll
            loss function to use for gridsearch.
        cv_loss_with_var: bool, default=False
            if true, `loss` takes three arguments `loss(output_mean, output_var, target)`,
            otherwise, `loss` takes two arguments `loss(output_mean, target)`
        log_prior_prec_min : float, default=-4
            lower bound of gridsearch interval.
        log_prior_prec_max : float, default=4
            upper bound of gridsearch interval.
        grid_size : int, default=100
            number of values to consider inside the gridsearch interval.
        link_approx : {'mc', 'probit', 'bridge'}, default='probit'
            how to approximate the classification link function for the `'glm'`.
            For `pred_type='nn'`, only `'mc'` is possible.
        n_samples : int, default=100
            number of samples for `link_approx='mc'`.
        verbose : bool, default=False
            if true, the optimized prior precision will be printed
            (can be a large tensor if the prior has a diagonal covariance).

        Raises
        ------
        ValueError
            if `method` is unknown, or if `method='gridsearch'` is used
            without a `val_loader`.
        """
        if method == 'marglik':
            self.prior_precision = init_prior_prec
            if len(self.prior_precision) == 1 and prior_structure != 'scalar':
                self.prior_precision = fix_prior_prec_structure(
                    self.prior_precision.item(), prior_structure,
                    self.n_layers, self.n_params, self._device
                )
            # Optimize in log space so the precision stays positive.
            log_prior_prec = self.prior_precision.log()
            log_prior_prec.requires_grad = True
            optimizer = torch.optim.Adam([log_prior_prec], lr=lr)
            for _ in range(n_steps):
                optimizer.zero_grad()
                prior_prec = log_prior_prec.exp()
                neg_log_marglik = -self.log_marginal_likelihood(prior_precision=prior_prec)
                neg_log_marglik.backward()
                optimizer.step()
            self.prior_precision = log_prior_prec.detach().exp()
        elif method == 'gridsearch':
            if val_loader is None:
                raise ValueError('gridsearch requires a validation set DataLoader')
            interval = torch.logspace(
                log_prior_prec_min, log_prior_prec_max, grid_size
            )
            self.prior_precision = self._gridsearch(
                loss, interval, val_loader, pred_type=pred_type,
                link_approx=link_approx, n_samples=n_samples, loss_with_var=cv_loss_with_var
            )
        else:
            raise ValueError('For now only marglik and gridsearch is implemented.')
        if verbose:
            print(f'Optimized prior precision is {self.prior_precision}.')

    def _gridsearch(self, loss, interval, val_loader, pred_type,
                    link_approx='probit', n_samples=100, loss_with_var=False):
        """Return the candidate in `interval` with the lowest validation loss.

        Each candidate is assigned to `self.prior_precision` before the
        validation pass, so the attribute is left at the last candidate
        tried; the caller assigns the returned winner. A candidate whose
        evaluation raises `RuntimeError` is scored as infinity instead of
        aborting the search.
        """
        results = list()
        prior_precs = list()
        for prior_prec in interval:
            self.prior_precision = prior_prec
            try:
                out_dist, targets = validate(
                    self, val_loader, pred_type=pred_type,
                    link_approx=link_approx, n_samples=n_samples
                )
                if self.likelihood == 'regression':
                    out_mean, out_var = out_dist
                    if loss_with_var:
                        result = loss(out_mean, out_var, targets).item()
                    else:
                        result = loss(out_mean, targets).item()
                else:
                    result = loss(out_dist, targets).item()
            except RuntimeError:
                result = np.inf
            results.append(result)
            prior_precs.append(prior_prec)
        return prior_precs[np.argmin(results)]

    @property
    def sigma_noise(self):
        """Observation noise as a 0-dim tensor (set via the validating setter)."""
        return self._sigma_noise

    @sigma_noise.setter
    def sigma_noise(self, sigma_noise):
        # Any cached posterior scale depends on the observation noise.
        self._posterior_scale = None
        if np.isscalar(sigma_noise) and np.isreal(sigma_noise):
            self._sigma_noise = torch.tensor(sigma_noise, device=self._device)
        elif torch.is_tensor(sigma_noise):
            if sigma_noise.ndim == 0:
                self._sigma_noise = sigma_noise.to(self._device)
            elif sigma_noise.ndim == 1:
                if len(sigma_noise) > 1:
                    raise ValueError('Only homoscedastic output noise supported.')
                self._sigma_noise = sigma_noise[0].to(self._device)
            else:
                raise ValueError('Sigma noise needs to be scalar or 1-dimensional.')
        else:
            raise ValueError('Invalid type: sigma noise needs to be torch.Tensor or scalar.')

    @property
    def _H_factor(self):
        """Scaling `1 / sigma_noise^2 / temperature` applied to loss and Hessian."""
        sigma2 = self.sigma_noise.square()
        return 1 / sigma2 / self.temperature
class ParametricLaplace(BaseLaplace):
    """
    Parametric Laplace class.

    Subclasses need to specify how the Hessian approximation is initialized,
    how to add up curvature over training data, how to sample from the
    Laplace approximation, and how to compute the functional variance.

    A Laplace approximation is represented by a MAP which is given by the
    `model` parameter and a posterior precision or covariance specifying
    a Gaussian distribution \\(\\mathcal{N}(\\theta_{MAP}, P^{-1})\\).
    The goal of this class is to compute the posterior precision \\(P\\)
    which sums as
    \\[
        P = \\sum_{n=1}^N \\nabla^2_\\theta \\log p(\\mathcal{D}_n \\mid \\theta)
        \\vert_{\\theta_{MAP}} + \\nabla^2_\\theta \\log p(\\theta) \\vert_{\\theta_{MAP}}.
    \\]
    Every subclass implements different approximations to the log likelihood Hessians,
    for example, a diagonal one. The prior is assumed to be Gaussian and therefore we have
    a simple form for \\(\\nabla^2_\\theta \\log p(\\theta) \\vert_{\\theta_{MAP}} = P_0 \\).
    In particular, we assume a scalar, layer-wise, or diagonal prior precision so that in
    all cases \\(P_0 = \\textrm{diag}(p_0)\\) and the structure of \\(p_0\\) can be varied.
    """

    def __init__(self, model, likelihood, sigma_noise=1., prior_precision=1.,
                 prior_mean=0., temperature=1., enable_backprop=False,
                 backend=None, backend_kwargs=None):
        super().__init__(model, likelihood, sigma_noise, prior_precision,
                         prior_mean, temperature, enable_backprop, backend, backend_kwargs)
        # Only initialize H when nothing has set the attribute yet.
        if not hasattr(self, 'H'):
            self._init_H()
        # posterior mean/mode
        self.mean = self.prior_mean

    def _init_H(self):
        """Initialize the Hessian approximation `self.H`; implemented by subclasses."""
        raise NotImplementedError

    def _check_H_init(self):
        """Raise `AttributeError` if `self.H` is `None`."""
        if self.H is None:
            raise AttributeError('Laplace not fitted. Run fit() first.')

    def fit(self, train_loader, override=True):
        """Fit the local Laplace approximation at the parameters of the model.

        Parameters
        ----------
        train_loader : torch.utils.data.DataLoader
            each iterate is a training batch (X, y);
            `train_loader.dataset` needs to be set to access \\(N\\), size of the data set
        override : bool, default=True
            whether to initialize H, loss, and n_data again; setting to False is useful for
            online learning settings to accumulate a sequential posterior approximation.
        """
        if override:
            self._init_H()
            self.loss = 0
            self.n_data = 0

        self.model.eval()
        self.mean = parameters_to_vector(self.model.parameters())
        if not self.enable_backprop:
            self.mean = self.mean.detach()

        # Probe the model with the first batch to determine the output size.
        X, _ = next(iter(train_loader))
        with torch.no_grad():
            try:
                out = self.model(X[:1].to(self._device))
            except (TypeError, AttributeError):
                # Fall back to passing the batch object as-is when it cannot
                # be sliced / moved like a tensor.
                out = self.model(X.to(self._device))
        self.n_outputs = out.shape[-1]
        setattr(self.model, 'output_size', self.n_outputs)

        N = len(train_loader.dataset)
        for X, y in train_loader:
            self.model.zero_grad()
            X, y = X.to(self._device), y.to(self._device)
            loss_batch, H_batch = self._curv_closure(X, y, N)
            self.loss += loss_batch
            self.H += H_batch

        self.n_data += N

    @property
    def scatter(self):
        """Computes the _scatter_, a term of the log marginal likelihood that
        corresponds to L-2 regularization:
        `scatter` = \\((\\theta_{MAP} - \\mu_0)^{T} P_0 (\\theta_{MAP} - \\mu_0) \\).

        Returns
        -------
        scatter : torch.Tensor
        """
        delta = (self.mean - self.prior_mean)
        return (delta * self.prior_precision_diag) @ delta

    @property
    def log_det_prior_precision(self):
        """Compute log determinant of the prior precision
        \\(\\log \\det P_0\\)

        Returns
        -------
        log_det : torch.Tensor
        """
        return self.prior_precision_diag.log().sum()

    @property
    def log_det_posterior_precision(self):
        """Compute log determinant of the posterior precision
        \\(\\log \\det P\\) which depends on the subclasses structure
        used for the Hessian approximation.

        Returns
        -------
        log_det : torch.Tensor
        """
        raise NotImplementedError

    @property
    def log_det_ratio(self):
        """Compute the log determinant ratio, a part of the log marginal likelihood.
        \\[
            \\log \\frac{\\det P}{\\det P_0} = \\log \\det P - \\log \\det P_0
        \\]

        Returns
        -------
        log_det_ratio : torch.Tensor
        """
        return self.log_det_posterior_precision - self.log_det_prior_precision

    def square_norm(self, value):
        """Compute the square norm under post. Precision with `value-self.mean`
        as \\(\\Delta\\):
        \\[
            \\Delta^\\top P \\Delta
        \\]

        Returns
        -------
        square_form
        """
        raise NotImplementedError

    def log_prob(self, value, normalized=True):
        """Compute the log probability under the (current) Laplace approximation.

        Parameters
        ----------
        value : torch.Tensor
            point at which to evaluate; passed on to `square_norm`.
        normalized : bool, default=True
            whether to return log of a properly normalized Gaussian or just the
            terms that depend on `value`.

        Returns
        -------
        log_prob : torch.Tensor
        """
        if not normalized:
            return - self.square_norm(value) / 2
        log_prob = - self.n_params / 2 * log(2 * pi) + self.log_det_posterior_precision / 2
        log_prob -= self.square_norm(value) / 2
        return log_prob

    def log_marginal_likelihood(self, prior_precision=None, sigma_noise=None):
        """Compute the Laplace approximation to the log marginal likelihood subject
        to specific Hessian approximations that subclasses implement.
        Requires that the Laplace approximation has been fit before.
        The resulting torch.Tensor is differentiable in `prior_precision` and
        `sigma_noise` if these have gradients enabled.
        By passing `prior_precision` or `sigma_noise`, the current value is
        overwritten. This is useful for iterating on the log marginal likelihood.

        Parameters
        ----------
        prior_precision : torch.Tensor, optional
            prior precision if should be changed from current `prior_precision` value
        sigma_noise : torch.Tensor or float, optional
            observation noise standard deviation if should be changed

        Returns
        -------
        log_marglik : torch.Tensor

        Raises
        ------
        ValueError
            if `sigma_noise` is passed for a non-regression likelihood.
        """
        # update prior precision (useful when iterating on marglik)
        if prior_precision is not None:
            self.prior_precision = prior_precision

        # update sigma_noise (useful when iterating on marglik)
        if sigma_noise is not None:
            if self.likelihood != 'regression':
                raise ValueError('Can only change sigma_noise for regression.')
            self.sigma_noise = sigma_noise

        return self.log_likelihood - 0.5 * (self.log_det_ratio + self.scatter)

    def __call__(self, x, pred_type='glm', joint=False, link_approx='probit',
                 n_samples=100, diagonal_output=False, generator=None):
        """Compute the posterior predictive on input data `x`.

        Parameters
        ----------
        x : torch.Tensor
            `(batch_size, input_shape)`

        pred_type : {'glm', 'nn'}, default='glm'
            type of posterior predictive, linearized GLM predictive or neural
            network sampling predictive. The GLM predictive is consistent with
            the curvature approximations used here.

        link_approx : {'mc', 'probit', 'bridge', 'bridge_norm'}
            how to approximate the classification link function for the `'glm'`.
            For `pred_type='nn'`, only 'mc' is possible.

        joint : bool
            Whether to output a joint predictive distribution in regression with
            `pred_type='glm'`. If set to `True`, the predictive distribution
            has the same form as GP posterior, i.e. N([f(x1), ...,f(xm)], Cov[f(x1), ..., f(xm)]).
            If `False`, then only outputs the marginal predictive distribution.
            Only available for regression and GLM predictive.

        n_samples : int
            number of samples for `link_approx='mc'`.

        diagonal_output : bool
            whether to use a diagonalized posterior predictive on the outputs.
            Only works for `pred_type='glm'` and `link_approx='mc'`.

        generator : torch.Generator, optional
            random number generator to control the samples; forwarded to
            `predictive_samples` for `pred_type='glm'` with `link_approx='mc'`.

        Returns
        -------
        predictive: torch.Tensor or Tuple[torch.Tensor]
            For `likelihood='classification'`, a torch.Tensor is returned with
            a distribution over classes (similar to a Softmax).
            For `likelihood='regression'`, a tuple of torch.Tensor is returned
            with the mean and the predictive variance.
            For `likelihood='regression'` and `joint=True`, a tuple of torch.Tensor
            is returned with the mean and the predictive covariance.

        Raises
        ------
        ValueError
            for an unsupported `pred_type` / `link_approx` combination or an
            invalid `generator` (wrong type or device).
        """
        if pred_type not in ['glm', 'nn']:
            raise ValueError('Only glm and nn supported as prediction types.')

        if link_approx not in ['mc', 'probit', 'bridge', 'bridge_norm']:
            raise ValueError(f'Unsupported link approximation {link_approx}.')

        if pred_type == 'nn' and link_approx != 'mc':
            raise ValueError('Only mc link approximation is supported for nn prediction type.')

        if generator is not None:
            if not isinstance(generator, torch.Generator) or generator.device != x.device:
                raise ValueError('Invalid random generator (check type and device).')

        if pred_type == 'glm':
            f_mu, f_var = self._glm_predictive_distribution(
                x, joint=joint and self.likelihood == 'regression'
            )
            # regression
            if self.likelihood == 'regression':
                return f_mu, f_var
            # classification
            if link_approx == 'mc':
                # Forward the validated generator so seeded sampling is honored.
                return self.predictive_samples(x, pred_type='glm', n_samples=n_samples,
                                               diagonal_output=diagonal_output,
                                               generator=generator).mean(dim=0)
            elif link_approx == 'probit':
                kappa = 1 / torch.sqrt(1. + np.pi / 8 * f_var.diagonal(dim1=1, dim2=2))
                return torch.softmax(kappa * f_mu, dim=-1)
            elif 'bridge' in link_approx:
                # zero mean correction
                f_mu -= (f_var.sum(-1) * f_mu.sum(-1).reshape(-1, 1) /
                         f_var.sum(dim=(1, 2)).reshape(-1, 1))
                f_var -= (torch.einsum('bi,bj->bij', f_var.sum(-1), f_var.sum(-2)) /
                          f_var.sum(dim=(1, 2)).reshape(-1, 1, 1))
                # Laplace Bridge
                _, K = f_mu.size(0), f_mu.size(-1)
                f_var_diag = torch.diagonal(f_var, dim1=1, dim2=2)
                # optional: variance correction
                if link_approx == 'bridge_norm':
                    f_var_diag_mean = f_var_diag.mean(dim=1)
                    f_var_diag_mean /= torch.as_tensor([K/2], device=self._device).sqrt()
                    f_mu /= f_var_diag_mean.sqrt().unsqueeze(-1)
                    f_var_diag /= f_var_diag_mean.unsqueeze(-1)
                sum_exp = torch.exp(-f_mu).sum(dim=1).unsqueeze(-1)
                alpha = (1 - 2/K + f_mu.exp() / K**2 * sum_exp) / f_var_diag
                return torch.nan_to_num(alpha / alpha.sum(dim=1).unsqueeze(-1), nan=1.0)
        else:
            samples = self._nn_predictive_samples(x, n_samples)
            if self.likelihood == 'regression':
                return samples.mean(dim=0), samples.var(dim=0)
            return samples.mean(dim=0)

    def predictive_samples(self, x, pred_type='glm', n_samples=100,
                           diagonal_output=False, generator=None):
        """Sample from the posterior predictive on input data `x`.
        Can be used, for example, for Thompson sampling.

        Parameters
        ----------
        x : torch.Tensor
            input data `(batch_size, input_shape)`

        pred_type : {'glm', 'nn'}, default='glm'
            type of posterior predictive, linearized GLM predictive or neural
            network sampling predictive. The GLM predictive is consistent with
            the curvature approximations used here.

        n_samples : int
            number of samples

        diagonal_output : bool
            whether to use a diagonalized glm posterior predictive on the outputs.
            Only applies when `pred_type='glm'`.

        generator : torch.Generator, optional
            random number generator to control the samples (if sampling used)

        Returns
        -------
        samples : torch.Tensor
            samples `(n_samples, batch_size, output_shape)`
        """
        if pred_type not in ['glm', 'nn']:
            raise ValueError('Only glm and nn supported as prediction types.')

        if pred_type == 'glm':
            f_mu, f_var = self._glm_predictive_distribution(x)
            assert f_var.shape == torch.Size([f_mu.shape[0], f_mu.shape[1], f_mu.shape[1]])
            if diagonal_output:
                f_var = torch.diagonal(f_var, dim1=1, dim2=2)
            f_samples = normal_samples(f_mu, f_var, n_samples, generator)
            if self.likelihood == 'regression':
                return f_samples
            return torch.softmax(f_samples, dim=-1)

        else:  # 'nn'
            return self._nn_predictive_samples(x, n_samples)

    @torch.enable_grad()
    def _glm_predictive_distribution(self, X, joint=False):
        """Return mean and (co)variance of the linearized (GLM) predictive.

        With `joint=True` the mean is flattened and `functional_covariance`
        is used, otherwise `functional_variance`. Results are detached unless
        `enable_backprop` is set.
        """
        Js, f_mu = self.backend.jacobians(X, enable_backprop=self.enable_backprop)

        if joint:
            f_mu = f_mu.flatten()  # (batch*out)
            f_var = self.functional_covariance(Js)  # (batch*out, batch*out)
        else:
            f_var = self.functional_variance(Js)

        return (f_mu.detach(), f_var.detach()) if not self.enable_backprop else (f_mu, f_var)

    def _nn_predictive_samples(self, X, n_samples=100):
        """Evaluate the model under `n_samples` parameter samples.

        The sampled weights are written into the model one at a time; the
        weights in `self.mean` are written back afterwards, also when a
        forward pass raises. Outputs are stacked along a new leading
        dimension and passed through a softmax for classification.
        """
        fs = list()
        # Draw samples before touching the model so a failure here leaves
        # the model weights untouched.
        samples = self.sample(n_samples)
        try:
            for sample in samples:
                vector_to_parameters(sample, self.model.parameters())
                f = self.model(X.to(self._device))
                fs.append(f.detach() if not self.enable_backprop else f)
        finally:
            # Always restore the MAP weights, even if a forward pass failed.
            vector_to_parameters(self.mean, self.model.parameters())
        fs = torch.stack(fs)
        if self.likelihood == 'classification':
            fs = torch.softmax(fs, dim=-1)
        return fs

    def functional_variance(self, Jacs):
        """Compute functional variance for the `'glm'` predictive:
        `f_var[i] = Jacs[i] @ P.inv() @ Jacs[i].T`, which is a output x output
        predictive covariance matrix.
        Mathematically, we have for a single Jacobian
        \\(\\mathcal{J} = \\nabla_\\theta f(x;\\theta)\\vert_{\\theta_{MAP}}\\)
        the output covariance matrix
        \\( \\mathcal{J} P^{-1} \\mathcal{J}^T \\).

        Parameters
        ----------
        Jacs : torch.Tensor
            Jacobians of model output wrt parameters
            `(batch, outputs, parameters)`

        Returns
        -------
        f_var : torch.Tensor
            output covariance `(batch, outputs, outputs)`
        """
        raise NotImplementedError

    def functional_covariance(self, Jacs):
        """Compute functional covariance for the `'glm'` predictive:
        `f_cov = Jacs @ P.inv() @ Jacs.T`, which is a batch*output x batch*output
        predictive covariance matrix.

        This emulates the GP posterior covariance N([f(x1), ...,f(xm)], Cov[f(x1), ..., f(xm)]).
        Useful for joint predictions, such as in batched Bayesian optimization.

        Parameters
        ----------
        Jacs : torch.Tensor
            Jacobians of model output wrt parameters
            `(batch*outputs, parameters)`

        Returns
        -------
        f_cov : torch.Tensor
            output covariance `(batch*outputs, batch*outputs)`
        """
        raise NotImplementedError

    def sample(self, n_samples=100):
        """Sample from the Laplace posterior approximation, i.e.,
        \\( \\theta \\sim \\mathcal{N}(\\theta_{MAP}, P^{-1})\\).

        Parameters
        ----------
        n_samples : int, default=100
            number of samples
        """
        raise NotImplementedError

    def optimize_prior_precision(self, method='marglik', pred_type='glm', n_steps=100, lr=1e-1,
                                 init_prior_prec=1., prior_structure='scalar', val_loader=None,
                                 loss=get_nll, log_prior_prec_min=-4, log_prior_prec_max=4,
                                 grid_size=100, link_approx='probit', n_samples=100, verbose=False,
                                 cv_loss_with_var=False):
        """Optimize the prior precision post-hoc.

        Checks that `pred_type` is `'glm'` or `'nn'` and forwards all
        arguments to `optimize_prior_precision_base`.
        """
        assert pred_type in ['glm', 'nn']
        self.optimize_prior_precision_base(pred_type, method, n_steps, lr,
                                           init_prior_prec, prior_structure,
                                           val_loader, loss, log_prior_prec_min,
                                           log_prior_prec_max, grid_size,
                                           link_approx, n_samples, verbose,
                                           cv_loss_with_var)

    @property
    def posterior_precision(self):
        """Compute or return the posterior precision \\(P\\).

        Returns
        -------
        posterior_prec : torch.Tensor
        """
        raise NotImplementedError

    def state_dict(self) -> dict:
        """Collect the fitted state in a plain dict.

        Raises `AttributeError` (via `_check_H_init`) if `self.H` is `None`.
        The class name is stored under `'cls_name'` so that
        `load_state_dict` can reject a mismatching Laplace type.
        """
        self._check_H_init()
        state_dict = {
            'mean': self.mean,
            'H': self.H,
            'loss': self.loss,
            'prior_mean': self.prior_mean,
            'prior_precision': self.prior_precision,
            'sigma_noise': self.sigma_noise,
            'n_data': self.n_data,
            'n_outputs': self.n_outputs,
            'likelihood': self.likelihood,
            'temperature': self.temperature,
            'enable_backprop': self.enable_backprop,
            'cls_name': self.__class__.__name__
        }
        return state_dict

    def load_state_dict(self, state_dict: dict):
        """Restore the state produced by `state_dict`.

        Raises
        ------
        ValueError
            if the stored class name, number of parameters, or likelihood
            does not match this instance.

        Warns
        -----
        UserWarning
            if prior mean, temperature, or `enable_backprop` differ from the
            current configuration; loading proceeds regardless.
        """
        # Dealbreaker errors
        if self.__class__.__name__ != state_dict['cls_name']:
            raise ValueError(
                'Loading a wrong Laplace type. Make sure `subset_of_weights` and'
                + ' `hessian_structure` are correct!'
            )
        if self.n_params is not None and len(state_dict['mean']) != self.n_params:
            raise ValueError(
                'Attempting to load Laplace with different number of parameters than the model.'
                + ' Make sure that you use the same `subset_of_weights` value and the same `.requires_grad`'
                + ' switch on `model.parameters()`.'
            )
        if self.likelihood != state_dict['likelihood']:
            raise ValueError('Different likelihoods detected!')

        # Ignorable warnings
        if self.prior_mean is None and state_dict['prior_mean'] is not None:
            warnings.warn('Loading non-`None` prior mean into a `None` prior mean. You might get wrong results.')
        if self.temperature != state_dict['temperature']:
            warnings.warn('Different `temperature` parameters detected. Some calculation might be off!')
        if self.enable_backprop != state_dict['enable_backprop']:
            warnings.warn(
                'Different `enable_backprop` values. You might encounter error when differentiating'
                + ' the predictive mean and variance.'
            )

        self.mean = state_dict['mean']
        self.H = state_dict['H']
        self.loss = state_dict['loss']
        self.prior_mean = state_dict['prior_mean']
        self.prior_precision = state_dict['prior_precision']
        self.sigma_noise = state_dict['sigma_noise']
        self.n_data = state_dict['n_data']
        self.n_outputs = state_dict['n_outputs']
        setattr(self.model, 'output_size', self.n_outputs)
        self.likelihood = state_dict['likelihood']
        self.temperature = state_dict['temperature']
        self.enable_backprop = state_dict['enable_backprop']
class FullLaplace(ParametricLaplace):
    """Laplace approximation whose log likelihood Hessian approximation, and
    hence posterior precision, is a full (dense) matrix. Which approximation
    is used depends on the chosen `backend` parameter; it can be, for example,
    a generalized Gauss-Newton matrix.

    Mathematically, we have \\(P \\in \\mathbb{R}^{P \\times P}\\).
    See `BaseLaplace` for the full interface.
    """

    # key to map to correct subclass of BaseLaplace, (subset of weights, Hessian structure)
    _key = ('all', 'full')

    def __init__(self, model, likelihood, sigma_noise=1., prior_precision=1.,
                 prior_mean=0., temperature=1., enable_backprop=False, backend=None, backend_kwargs=None):
        parent_args = (model, likelihood, sigma_noise, prior_precision, prior_mean,
                       temperature, enable_backprop, backend, backend_kwargs)
        super().__init__(*parent_args)
        # cache for the posterior scale; filled lazily by `posterior_scale`
        self._posterior_scale = None

    def _init_H(self):
        """Set `self.H` to a square zero matrix of side `n_params` on `self._device`."""
        shape = (self.n_params, self.n_params)
        self.H = torch.zeros(shape, device=self._device)

    def _curv_closure(self, X, y, N):
        """Delegate the curvature computation for `(X, y)` to `self.backend.full`."""
        full_curvature = self.backend.full
        return full_curvature(X, y, N=N)

    def fit(self, train_loader, override=True):
        """Clear the cached posterior scale, then run the parent `fit` and return its result."""
        self._posterior_scale = None
        outcome = super().fit(train_loader, override=override)
        return outcome

    def _compute_scale(self):
        """Compute the posterior scale from the posterior precision and cache it."""
        precision = self.posterior_precision
        self._posterior_scale = invsqrt_precision(precision)

    @property
    def posterior_scale(self):
        """Posterior scale (square root of the covariance), i.e.,
        \\(P^{-\\frac{1}{2}}\\). Computed on first access and cached afterwards.

        Returns
        -------
        scale : torch.tensor
            `(parameters, parameters)`
        """
        cached = self._posterior_scale
        if cached is not None:
            return cached
        self._compute_scale()
        return self._posterior_scale

    @property
    def posterior_covariance(self):
        """Posterior covariance, i.e., \\(P^{-1}\\), formed as the product of
        the posterior scale with its own transpose.

        Returns
        -------
        covariance : torch.tensor
            `(parameters, parameters)`
        """
        L = self.posterior_scale
        return L @ L.T

    @property
    def posterior_precision(self):
        """Posterior precision \\(P\\): the Hessian scaled by `_H_factor` plus
        the diagonal matrix built from `prior_precision_diag`.

        Returns
        -------
        precision : torch.tensor
            `(parameters, parameters)`
        """
        self._check_H_init()
        scaled_hessian = self._H_factor * self.H
        prior_term = torch.diag(self.prior_precision_diag)
        return scaled_hessian + prior_term

    @property
    def log_det_posterior_precision(self):
        """Log-determinant of the posterior precision."""
        precision = self.posterior_precision
        return precision.logdet()

    def square_norm(self, value):
        """Quadratic form of `value - mean` under the posterior precision."""
        diff = value - self.mean
        projected = diff @ self.posterior_precision
        return projected @ diff

    def functional_variance(self, Js):
        """Contract `Js` with the posterior covariance on both sides, per batch element.

        `Js` is indexed as `(n, c, p)` and the result as `(n, c, k)`, where `c`
        and `k` both run over the second axis of `Js`.
        """
        cov = self.posterior_covariance
        return torch.einsum('ncp,pq,nkq->nck', Js, cov, Js)

    def functional_covariance(self, Js):
        """Contract `Js` with the posterior covariance after merging its first two axes.

        `Js` must be three-dimensional; the result is square with side
        `Js.shape[0] * Js.shape[1]`.
        """
        n_batch, n_outs, n_params = Js.shape
        flat_Js = Js.reshape(n_batch * n_outs, n_params)
        cov = self.posterior_covariance
        return torch.einsum('np,pq,mq->nm', flat_Js, cov, flat_Js)

    def sample(self, n_samples=100):
        """Draw `n_samples` samples from a multivariate normal with location
        `self.mean` and `scale_tril=self.posterior_scale`.
        """
        posterior = MultivariateNormal(loc=self.mean, scale_tril=self.posterior_scale)
        return posterior.sample((n_samples,))
class KronLaplace(ParametricLaplace):
    """Laplace approximation with Kronecker factored log likelihood Hessian approximation
    and hence posterior precision.
    Mathematically, we have for each parameter group, e.g., torch.nn.Module,
    that \\(P \\approx Q \\otimes H\\).
    See `BaseLaplace` for the full interface and see
    `laplace.utils.matrix.Kron` and `laplace.utils.matrix.KronDecomposed` for the structure of
    the Kronecker factors. `Kron` is used to aggregate factors by summing up and
    `KronDecomposed` is used to add the prior, a Hessian factor (e.g. temperature),
    and computing posterior covariances, marginal likelihood, etc.
    Damping can be enabled by setting `damping=True`.
    """
    # key to map to correct subclass of BaseLaplace, (subset of weights, Hessian structure)
    _key = ('all', 'kron')
def __init__(self, model, likelihood, sigma_noise=1., prior_precision=1.,
             prior_mean=0., temperature=1., enable_backprop=False, backend=None,
             damping=False, **backend_kwargs):
    """Store the Kronecker-specific settings, then delegate to the parent constructor.

    Parameters
    ----------
    damping : bool, default=False
        Stored as `self.damping`.
    **backend_kwargs
        Forwarded unchanged, as keyword arguments, to the parent constructor.

    All remaining arguments are passed positionally to the parent constructor.
    """
    # Both attributes are assigned before the parent constructor runs.
    self.damping = damping
    self.H_facs = None
    parent_args = (model, likelihood, sigma_noise, prior_precision,
                   prior_mean, temperature, enable_backprop, backend)
    super().__init__(*parent_args, **backend_kwargs)
def _init_H(self):
    """Reset `self.H` to the object returned by `Kron.init_from_model` for the model and device."""
    fresh_factors = Kron.init_from_model(self.model, self._device)
    self.H = fresh_factors
def _curv_closure(self, X, y, N):
return self.backend.kron(X, y, N=N)
@staticmethod
def _rescale_factors(kron, factor):
for F in kron.kfacs:
if len(F) == 2:
F[1] *= factor
return kron
def fit(self, train_loader, override=True):
    """Fit the Kronecker-factored approximation, optionally accumulating onto
    the factors kept from an earlier fit.

    With `override=True` the stored factors `H_facs` are discarded first.
    If `H_facs` is still set afterwards, the old factors are discounted by
    `n_old / (n_old + n_new)`, `H` is re-initialised, the parent `fit` runs,
    and the new factors are discounted by `n_new / (n_new + n_old)` and added
    to `H_facs`. In every case `self.H` ends up as
    `H_facs.decompose(damping=self.damping)`.

    Parameters
    ----------
    train_loader : iterable
        Passed on to the parent `fit`. When continuing from earlier factors it
        must expose a `dataset` attribute supporting `len`
        (presumably a torch `DataLoader` -- confirm against the parent class).
    override : bool, default=True
        Whether to discard the factors kept from a previous fit.
    """
    if override:
        self.H_facs = None

    if self.H_facs is not None:
        # read before the parent fit runs (which presumably updates `n_data` -- confirm)
        n_data_old = self.n_data
        n_data_new = len(train_loader.dataset)
        self._init_H()  # re-init H non-decomposed
        # discount previous Kronecker factors to sum up properly together with new ones
        self.H_facs = self._rescale_factors(self.H_facs, n_data_old / (n_data_old + n_data_new))

    super().fit(train_loader, override=override)

    if self.H_facs is None:
        self.H_facs = self.H
    else:
        # NOTE(review): `n_data_new` / `n_data_old` are only bound in the branch above;
        # this relies on the parent `fit` not changing `H_facs` -- confirm.
        # discount new factors that were computed assuming N = n_data_new
        self.H = self._rescale_factors(self.H, n_data_new / (n_data_new + n_data_old))
        self.H_facs += self.H
    # Decompose to self.H for all required quantities but keep H_facs for further inference
    self.H = self.H_facs.decompose(damping=self.damping)
@property
def posterior_precision(self):
    """Kronecker factored posterior precision \\(P\\): `H` multiplied by
    `_H_factor`, plus the prior precision.

    Returns
    -------
    precision : `laplace.utils.matrix.KronDecomposed`
    """
    self._check_H_init()
    scaled_H = self.H * self._H_factor
    return scaled_H + self.prior_precision
@property
def log_det_posterior_precision(self):
    """Log-determinant of the posterior precision.

    While `self.H` is exactly of type `Kron`, only the prior is used: the
    result is the sum of the logs of `prior_precision_diag`.
    """
    if type(self.H) is not Kron:
        return self.posterior_precision.logdet()
    # Fall back to diag prior
    return self.prior_precision_diag.log().sum()
def square_norm(self, value):
delta = value - self.mean
if type(self.H) is Kron: # fall back to prior
return (delta * self.prior_precision_diag) @ delta