Add fast computation of functional_variance for DiagLLLaplace and KronLLLaplace #145
Conversation
laplace/lllaplace.py (Outdated)

@@ -201,6 +208,40 @@ def __init__(self, model, likelihood, sigma_noise=1., prior_precision=1.,
    def _init_H(self):
        self.H = Kron.init_from_model(self.model.last_layer, self._device)

    def _functional_variance_fast(self, X):
@aleximmer here's the initial implementation for the KronLLLaplace. The test comparing this f_var to f_var = la.posterior_precision.inv_square_form(Js) fails, though. Could you please check this? Feel free to propose a more elegant solution, since you know more about the implementation of Kron.
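For reference, a minimal sketch of what the naive check computes, assuming inv_square_form evaluates J P⁻¹ Jᵀ for the posterior precision P (the function name and Sigma argument here are illustrative, not the library API):

```python
import torch

def naive_functional_variance(Js, Sigma):
    """Naive f_var: J_i @ Sigma @ J_i^T per input.

    Js:    Jacobians, shape (batch_size, num_classes, num_params)
    Sigma: posterior covariance P^{-1}, shape (num_params, num_params)
    Returns shape (batch_size, num_classes, num_classes).
    """
    return torch.einsum('ncp,pq,nkq->nck', Js, Sigma, Js)
```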
What's the advantage you are trying to achieve with this? Is it faster because you do the damping formulation of the posterior update (eigenvalues + sqrt(delta))? I suppose that approximation is potentially what makes the test fail.
The idea is to use this identity of the matrix-Normal distribution (see https://arxiv.org/pdf/2002.10118.pdf, Appendix B.1): if vec(W) ~ N(vec(M), V ⊗ U) with column-major vec, then W phi ~ N(M phi, (phiᵀ V phi) U). This is much faster than the naive functional_variance, since we don't need to compute the Jacobian, which has shape (batch_size, num_classes, num_params); we only need to multiply the inverse Kronecker factors with the last-layer features. Let me know your thoughts on the best way to achieve this.
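A minimal sketch of that fast path, assuming the posterior covariance of vec(W) factorizes as V_cov ⊗ U_cov, with U_cov over the output dimension and V_cov over the feature dimension; phi are the last-layer features, and all names are illustrative rather than the library API:

```python
import torch

def functional_variance_fast(phi, U_cov, V_cov):
    """f_var via the matrix-Normal identity: Var[W phi] = (phi^T V_cov phi) * U_cov.

    phi:   last-layer features, shape (batch_size, num_features)
    U_cov: covariance factor over outputs, shape (num_classes, num_classes)
    V_cov: covariance factor over features, shape (num_features, num_features)
    Returns shape (batch_size, num_classes, num_classes).
    """
    quad = torch.einsum('nd,de,ne->n', phi, V_cov, phi)  # phi_i^T V_cov phi_i per input
    return quad[:, None, None] * U_cov
```

This only touches objects of size num_features² and num_classes², so the (batch_size, num_classes, num_params) Jacobian never has to be materialized.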
PS: the rationale for the sqrt damping: I follow KFAC-Laplace (https://openreview.net/pdf?id=Skdvd2xAZ). Let me know what you think.
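For concreteness, a sketch of the sqrt-damping approximation from that paper, under the assumption that the posterior precision is V ⊗ U + delta * I with prior precision delta (function name illustrative):

```python
import torch

def damped_inverse_factors(U, V, delta):
    """KFAC-Laplace sqrt damping: approximate (V ⊗ U + delta * I)^{-1}
    by (V + sqrt(delta) I)^{-1} ⊗ (U + sqrt(delta) I)^{-1},
    so the inverse stays in Kronecker-factored form."""
    U_damped = U + delta**0.5 * torch.eye(U.shape[0], device=U.device)
    V_damped = V + delta**0.5 * torch.eye(V.shape[0], device=V.device)
    return torch.linalg.inv(U_damped), torch.linalg.inv(V_damped)
```

Note that expanding the product gives V ⊗ U + sqrt(delta)(V ⊗ I + I ⊗ U) + delta * I, i.e. extra cross-terms relative to the exact V ⊗ U + delta * I, which would explain the failing comparison against the eigendecomposition-based inverse.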
I looked into this a bit more now. It's a very important change to have in, since it's significantly faster. The way it should probably be implemented is via the damping=True/False flag. The exact inversion with a prior has to be done using an eigendecomposition, which is what we do right now; this can be avoided in the fast predictive by using the damping formulation instead. That would also avoid recomputing U and V from the eigendecomposition, and we could add the method to matrix.py. I could look into how best to do this. One thing I am wondering: do you know if it is possible to do this for the joint posterior predictive as well?
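For contrast, a sketch of the exact inversion via eigendecomposition mentioned above (illustrative only; it materializes the full Kronecker product, which the factored fast path avoids):

```python
import torch

def exact_damped_inverse(U, V, delta):
    """Exact (V ⊗ U + delta * I)^{-1}: with U = Q_u diag(l_u) Q_u^T and
    V = Q_v diag(l_v) Q_v^T, the eigenvectors are kron(Q_v, Q_u) and the
    eigenvalues are kron(l_v, l_u) + delta."""
    l_u, Q_u = torch.linalg.eigh(U)
    l_v, Q_v = torch.linalg.eigh(V)
    Q = torch.kron(Q_v, Q_u)
    inv_eigs = 1.0 / (torch.kron(l_v, l_u) + delta)
    return Q @ torch.diag(inv_eigs) @ Q.T
```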
Ok, great. I'll leave it to you then to do this implicitly via damping. The corresponding test case is in tests/test_lllaplace.py and can be selected with -k "test_functional_variance_fast[KronLLLaplace]".
I think the joint predictive can also benefit from this (see the code for the naive version). However, it needs more thought, so let's do it in a separate PR.
PR for #138. Very useful for LLMs, diffusion models, or any model with many outputs.
TODO: Implementation for KronLLLaplace. I'd like input from @aleximmer, who's the author of matrix.py.