[Discussion] Last-layer Laplace for img2img problem #118
Open
wiseodd opened this issue Nov 11, 2022 · 6 comments

wiseodd commented Nov 11, 2022

@aleximmer, @runame, @edaxberger: As you know, I'm currently working on last-layer Laplace for img2img tasks, e.g. autoencoders and image segmentation. We can't use the current implementation in this library, mainly because we hard-code the last-layer Jacobian to be that of a fully-connected layer (see #111 for an example). Note: the GGN computation itself doesn't seem to pose any problem with either BackPACK or ASDL (#111 for ASDL, below for BackPACK).

So, my current thinking is to simply generalize last_layer_jacobians in laplace/curvature/curvature.py using functorch; see the predict function below.
I also propose supporting only diagonal LLLA, since storing a full covariance over the last-layer parameters, and contracting it with the large img2img output Jacobians, is too costly otherwise.

Let me know your thoughts and if I missed anything. Feel free to try out the self-contained script below.

import torch
import torch.nn.functional as F
import torchvision as tv
import torchvision.transforms as transforms
from torch import nn, optim
from functorch import jacrev
from backpack import backpack, extend
from backpack.extensions import DiagGGNExact


DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

transform = transforms.Compose([transforms.ToTensor()])
train_batch_size = 32
test_batch_size = 10

# TODO Replace root path
trainset = tv.datasets.CIFAR10(
    root='~/Datasets', train=True, transform=transform, download=False
)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=train_batch_size, shuffle=True)

# TODO Replace root path
testset = tv.datasets.CIFAR10(
    root='~/Datasets', train=False, transform=transform, download=False
)
testloader = torch.utils.data.DataLoader(testset, batch_size=test_batch_size, shuffle=False)


class Model(nn.Module):
    def __init__(self):
        super().__init__()

        self.feature_extractor = nn.Sequential(
            nn.Conv2d(3, 100, kernel_size=5),
            nn.Sigmoid(),
        )
        self.last_layer = nn.Sequential(
            nn.ConvTranspose2d(100, 3, kernel_size=5, bias=False),
            nn.Flatten(1)
        )

    def forward(self, x):
        x = self.feature_extractor(x)
        return self.last_layer(x)


model = Model().to(DEVICE)

for p in model.feature_extractor.parameters():
    p.requires_grad = False

lastlayer = extend(model.last_layer)
lossfunc = extend(nn.MSELoss(reduction='sum'))

for x, _ in trainloader:
    x = x.to(DEVICE)

    # (n_data, n_channel*width*height)
    reconstruction = lastlayer(model.feature_extractor(x))

    loss = lossfunc(reconstruction, x.flatten(1))
    with backpack(DiagGGNExact()):
        loss.backward()

    # Diagonal GGN; same shape as the last-layer weight, i.e. (100, 3, 5, 5).
    # Note: for brevity this keeps only the last minibatch's GGN; a proper
    # run would accumulate over the whole loader.
    GGN = model.last_layer[0].weight.diag_ggn_exact

# Diagonal posterior covariance: elementwise inverse of the precision.
# (torch.linalg.inv requires a square matrix and isn't needed for diagonal LLLA.)
prec0 = 1
Sigma = 1 / (prec0 + GGN.flatten())


@torch.no_grad()
def predict(x):
    phi = model.feature_extractor(x)

    # MAP prediction
    mean_pred = model.last_layer(phi).reshape(x.shape)

    # Variance
    def f(feat, w):
        """ w is vectorized """
        SHAPE = (100, 3, 5, 5)
        return F.conv_transpose2d(feat, w.reshape(SHAPE))

    jac = jacrev(f, argnums=1)

    # (n_data, n_channel, width, height, n_params)
    J_pred = jac(phi, model.last_layer[0].weight.flatten())

    # (n_data, n_channel, width, height)
    var_pred = torch.einsum('nabci,i,nabci->nabc', J_pred, Sigma, J_pred)

    return mean_pred.cpu().numpy(), var_pred.cpu().numpy()


for x, _ in testloader:
    predict(x.to(DEVICE))
    break
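
For context on what the library-side generalization could look like: below is a rough sketch (illustrative names, not the actual laplace API) of a last_layer_jacobians that handles arbitrary last-layer modules via functorch's make_functional and jacrev:

import torch
from functorch import jacrev, make_functional

def last_layer_jacobians_general(last_layer, phi):
    """Jacobian of any last layer's output w.r.t. its vectorized parameters.

    `phi` are features from the (frozen) feature extractor. Returns a tensor
    of shape (n_data, *output_shape, n_ll_params).
    """
    fmodel, params = make_functional(last_layer)
    shapes = [p.shape for p in params]
    numels = [p.numel() for p in params]
    flat_w = torch.cat([p.flatten() for p in params])

    def f(w):
        ws = [chunk.reshape(s) for chunk, s in zip(w.split(numels), shapes)]
        return fmodel(ws, phi)

    # Materializing the full Jacobian is only feasible for small outputs; for
    # diagonal LLLA one would contract it with the diagonal covariance on the fly.
    return jacrev(f)(flat_w)
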
wiseodd commented Nov 11, 2022

Beyond CIFAR images, though, the GGN computation itself will also be an issue. E.g., for ImageNet the output dim is 224*224*3 = 150528, much larger than CIFAR's 32*32*3 = 3072. I talked to Felix about this, and one solution is to exploit the per-pixel nature of the loss and compute the minibatch GGN in chunks along the output dimension; see the MNIST example below.

Thoughts?

from backpack import backpack, extend
from backpack.custom_module.slicing import Slicing
from backpack.extensions import DiagGGNExact

# Reuses model, trainloader, DEVICE, etc. from an MNIST variant of the setup above
lastlayer = extend(model.last_layer)
lossfunc = extend(nn.MSELoss(reduction='sum'))

chunked_ggn = torch.zeros_like(model.last_layer[0].weight)

for x, _ in trainloader:
    x = x.to(DEVICE)

    # [N, 784]
    reconstruction = lastlayer(model.feature_extractor(x))

    # Accumulate the diagonal GGN one slice of 28 output pixels at a time
    for i in range(28):
        slicing = (slice(None), slice(i * 28, (i + 1) * 28))
        slicing_module = extend(Slicing(slicing))

        sliced_reconstruction = slicing_module(reconstruction)
        sliced_loss = lossfunc(sliced_reconstruction, x.flatten(1)[slicing])

        with backpack(DiagGGNExact(), retain_graph=True):
            sliced_loss.backward(retain_graph=True)
            chunked_ggn += model.last_layer[0].weight.diag_ggn_exact
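
As a quick check that the chunking is exact rather than an approximation (a sketch on a single minibatch; the sum-reduced MSE decomposes over output pixels, so the per-slice GGNs sum to the full one):

# Sanity check on one minibatch: the per-slice diagonal GGNs must sum to the
# diagonal GGN of the full loss, since sum-reduced MSE decomposes per pixel.
x, _ = next(iter(trainloader))
x = x.to(DEVICE)
reconstruction = lastlayer(model.feature_extractor(x))

ggn_from_chunks = torch.zeros_like(model.last_layer[0].weight)
for i in range(28):
    slicing = (slice(None), slice(i * 28, (i + 1) * 28))
    sliced_loss = lossfunc(extend(Slicing(slicing))(reconstruction), x.flatten(1)[slicing])
    with backpack(DiagGGNExact(), retain_graph=True):
        sliced_loss.backward(retain_graph=True)
    ggn_from_chunks += model.last_layer[0].weight.diag_ggn_exact

full_loss = lossfunc(lastlayer(model.feature_extractor(x)), x.flatten(1))
with backpack(DiagGGNExact()):
    full_loss.backward()

assert torch.allclose(ggn_from_chunks, model.last_layer[0].weight.diag_ggn_exact,
                      rtol=1e-4, atol=1e-6)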

wiseodd commented Nov 23, 2022

For predictions/reconstructions, my proposal is to use https://github.com/f-dangel/unfoldNd. With it, conv_transpose2d becomes just a matrix multiplication with the original weights/filters, which means we can easily obtain $p(f(x))$.

import unfoldNd

prec0 = 1

# Laplace cov
# diag_GGN: the chunked diagonal GGN of the last-layer weight from above,
# with the ConvTranspose2d weight shape (c_in, c_out, k, k)
diag_Sigma = 1 / (diag_GGN + prec0)
diag_Sigma = diag_Sigma.transpose(0, 1).flatten(1)

# diag_Sigma.shape should be (c_out, c_in*k*k)
assert diag_Sigma.shape == (1, 100*3*3)

# Following the last layer of the model
unfold_transpose = unfoldNd.UnfoldTransposeNd(
    kernel_size=3, dilation=1, padding=1, stride=3
)

@torch.no_grad()
def reconstruct(x):
    phi = model.feature_extractor(x)
    
    # MAP prediction
    mean_pred = model.last_layer(phi).reshape(x.shape)

    # Variance
    J_pred = unfold_transpose(phi)
    var_pred = torch.einsum('bij,ki,bij->bkj', J_pred, diag_Sigma, J_pred).reshape(mean_pred.shape)

    return mean_pred.cpu().numpy(), var_pred.cpu().numpy()

x_recons = []

for x, _ in testloader:
    x = x.to(DEVICE)
    x_recons.append(reconstruct(x))
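
Why this gives the right linearization: the unfolded features are exactly the matrix that multiplies the (reshaped) filter in a transposed convolution, so the MAP output itself can be recovered by a plain matmul. A quick check, assuming (as in the scripts above) a last layer of a bias-free ConvTranspose2d followed by Flatten:

# Sanity check that the unfolded features implement the same linear map as
# the transposed convolution.
with torch.no_grad():
    x, _ = next(iter(testloader))
    phi = model.feature_extractor(x.to(DEVICE))

    # (c_out, c_in*k*k), the same reshape applied to diag_Sigma above
    W_mat = model.last_layer[0].weight.transpose(0, 1).flatten(1)
    lin_pred = torch.einsum('ki,bij->bkj', W_mat, unfold_transpose(phi))

    assert torch.allclose(lin_pred.flatten(1), model.last_layer(phi), atol=1e-5)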

wiseodd commented Nov 23, 2022

Full, self-contained prototype here: https://gist.github.com/wiseodd/b8d57fa029f876e00b336b7b3b5052bd

JRopes commented Jan 30, 2024

Hello @wiseodd, have there been any updates on this topic over the last year? I am currently working on Laplace approximations for segmentation tasks and would be very interested. Thank you!

wiseodd commented Jan 30, 2024

Unfortunately, there's no update on this, partly because the loss function usually used in image problems (BCELoss) is not supported by the Hessian backends, and partly because my research agenda is far from computer vision/graphics.

In any case, I can point you in a good direction:

I hope those references and the snippets in the previous posts are useful for you.
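
For what it's worth, a rough sketch of a backend-free workaround (not library code; last_layer_fn is a hypothetical helper mirroring f in the first post, returning flattened logits): for a sigmoid output trained with BCE, the Hessian of the loss w.r.t. the logits is diagonal with entries s(1 - s) where s = sigmoid(logits), so the diagonal GGN needs only the last-layer Jacobian:

import torch
from functorch import jacrev

def diag_ggn_bce(last_layer_fn, w, phi):
    """Diagonal GGN for sigmoid + BCE, without a Hessian backend.

    `last_layer_fn(phi, w) -> logits` of shape (n_data, d_out), with
    vectorized last-layer parameters `w` (hypothetical helper).
    """
    s = torch.sigmoid(last_layer_fn(phi, w))       # (n_data, d_out)
    lam = s * (1 - s)                              # Hessian of BCE w.r.t. logits

    J = jacrev(last_layer_fn, argnums=1)(phi, w)   # (n_data, d_out, n_params)
    return torch.einsum('no,noi->i', lam, J ** 2)  # (n_params,)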

wiseodd commented Mar 12, 2024

This issue should be easier to solve once #145 is merged. Will work on this after the release of milestone 0.2.

wiseodd added this to the 0.3 milestone on Jul 8, 2024