[Discussion] Last-layer Laplace for img2img problem #118
Beyond CIFAR images, though, the GGN computation will also be an issue; e.g. in ImageNet the output dimension is far larger. Thoughts?

```python
import torch
from torch import nn

from backpack import backpack, extend
from backpack.custom_module.slicing import Slicing
from backpack.extensions import DiagGGNExact

lastlayer = extend(model.last_layer)
lossfunc = extend(nn.MSELoss(reduction="sum"))

chunked_ggn = torch.zeros_like(model.last_layer[0].weight)

for x, _ in trainloader:
    x = x.to(DEVICE)
    # Reconstruction of the flattened input, shape [N, 784]
    reconstruction = lastlayer(model.feature_extractor(x))

    # Accumulate the diagonal GGN slice-by-slice, 28 columns at a time
    for i in range(28):
        slicing = (slice(None), slice(i * 28, (i + 1) * 28))
        slicing_module = extend(Slicing(slicing))
        sliced_reconstruction = slicing_module(reconstruction)
        sliced_loss = lossfunc(sliced_reconstruction, x.flatten(1)[slicing])

        with backpack(DiagGGNExact(), retain_graph=True):
            sliced_loss.backward(retain_graph=True)

        chunked_ggn += model.last_layer[0].weight.diag_ggn_exact
```
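As a sanity check on the slicing trick above: with `MSELoss(reduction='sum')`, the losses of disjoint output slices sum exactly to the full loss, which is why the per-slice diagonal GGNs can simply be accumulated into `chunked_ggn`. A minimal pure-Python sketch with toy values (not the real 784-dim MNIST shapes):

```python
# Toy flattened reconstruction and target (hypothetical values)
pred = [0.2, 0.9, 0.4, 0.7, 0.1, 0.8]
target = [0.0, 1.0, 0.5, 0.5, 0.0, 1.0]

# Full sum-of-squares loss over all outputs
full_loss = sum((p - t) ** 2 for p, t in zip(pred, target))

# Same loss accumulated over disjoint slices (width 2 here, 28 above)
chunk = 2
sliced_total = 0.0
for i in range(0, len(pred), chunk):
    sliced_total += sum(
        (p - t) ** 2
        for p, t in zip(pred[i:i + chunk], target[i:i + chunk])
    )

# The slice-wise losses (and hence their GGN contributions) add up exactly
assert abs(full_loss - sliced_total) < 1e-12
```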
For predictions/reconstructions, my proposal is to use https://github.com/f-dangel/unfoldNd. Using this, then:

```python
import unfoldNd

prec0 = 1

# Laplace posterior covariance (diagonal)
diag_Sigma = 1 / (diag_GGN + prec0)
diag_Sigma = diag_Sigma.transpose(0, 1).flatten(1)

# diag_Sigma.shape should be (c_out, c_in*k*k)
assert len(diag_Sigma.shape) == 2 and diag_Sigma.shape == (1, 100 * 3 * 3)

# Following the last layer of the model
unfold_transpose = unfoldNd.UnfoldTransposeNd(
    kernel_size=3, dilation=1, padding=1, stride=3
)

@torch.no_grad()
def reconstruct(x):
    phi = model.feature_extractor(x)

    # MAP prediction
    mean_pred = model.last_layer(phi).reshape(x.shape)

    # Variance via the last-layer Jacobian
    J_pred = unfold_transpose(phi)
    var_pred = torch.einsum(
        'bij,ki,bij->bkj', J_pred, diag_Sigma, J_pred
    ).reshape(mean_pred.shape)

    return mean_pred.cpu().numpy(), var_pred.cpu().numpy()

x_recons = []
for x, _ in testloader:
    x = x.cuda()
    x_recons.append(reconstruct(x))
```
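To see what the `einsum` is computing: with a diagonal covariance `diag(s)`, the predictive variance of each output is the diagonal of `J diag(s) Jᵀ`, i.e. a weighted sum of squared Jacobian entries. A pure-Python toy check (hypothetical 2×3 Jacobian, not the real unfolded shapes):

```python
# Toy per-output Jacobian rows and diagonal covariance (made-up values)
J = [[1.0, 2.0, -1.0],
     [0.5, 0.0, 3.0]]
s = [0.1, 0.2, 0.3]

# Shortcut used by the einsum: var[o] = sum_i J[o][i]^2 * s[i]
var = [sum(J[o][i] ** 2 * s[i] for i in range(3)) for o in range(2)]

# Reference: diagonal of the explicit product J @ diag(s) @ J^T
Sigma = [[s[i] if i == j else 0.0 for j in range(3)] for i in range(3)]
JS = [[sum(J[o][i] * Sigma[i][j] for i in range(3)) for j in range(3)]
      for o in range(2)]
full = [sum(JS[o][j] * J[o][j] for j in range(3)) for o in range(2)]

assert all(abs(v - f) < 1e-12 for v, f in zip(var, full))
```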
Full, self-contained prototype here: https://gist.github.com/wiseodd/b8d57fa029f876e00b336b7b3b5052bd
Hello @wiseodd, have there been any updates on this topic over the last year? I am currently working on Laplace approximations for segmentation tasks and would be very interested. Thank you!
Unfortunately, there's no update on this. Partly because the loss function usually used in image problems (BCELoss) is not supported by the Hessian backends, and partly because my research agenda is far from computer vision/graphics. In any case, I can point you in a good direction:

I hope the references above and the snippets in the previous posts are useful for you.
This issue should be easier to solve once #145 is merged. Will work on this after the release of milestone 0.2.
@aleximmer, @runame, @edaxberger: As you know, I'm currently working on last-layer Laplace for img2img tasks, e.g. autoencoders and image segmentation. We can't use the current implementation in this library mainly due to the fact that we hard-code the last-layer Jacobian to be the fully-connected Jacobian---see #111 for example. Note: GGN computation using BackPACK & ASDL doesn't seem to pose any problem (#111 for ASDL, below for BackPACK).
So, my current thinking is to simply generalize `last_layer_jacobians` in `laplace/curvature/curvature.py` using `functorch`; see the `predict` function below. I also propose to only support diagonal LLLA since it's too costly otherwise.
Let me know your thoughts and if I missed anything. Feel free to try out the self-contained script below.