
Help for Running Laplace on Image Segmentation Tasks #111

Open
SouLeo opened this issue Aug 20, 2022 · 4 comments
SouLeo commented Aug 20, 2022

Hello,

I am using a U-Net variant (specifically LadderNet: https://github.com/juntang-zhuang/LadderNet) to perform hand segmentation: I classify each pixel of an image into one of five classes (no hand, my right hand, my left hand, your right hand, your left hand).

This gives my probs tensor (at fisher.py:446) the shape [batch_size, img_h, img_w, n_classes] = [8, 32, 32, 5] (snippet below):

def __fisher_exact(loss_and_backward, model, probs):
    _, n_classes = probs.shape  # expects 2-D probs [n_examples, n_classes]; fails on a 4-D segmentation tensor

Because of this dimensionality, this line of code fails: unpacking probs.shape into two values only works when probs is [n_examples, n_classes]. I assume the backend expects classification-style data as in the CIFAR example (https://github.com/AlexImmer/Laplace/blob/main/examples/calibration_example.py), where the dataset object returns (img_as_tensor: [3, 32, 32], label_as_int: 0, 1, 2, ...) tuples.

I could always reshape my training data into tuples of that format, but because I am classifying per pixel, the associated label cannot be a single integer; it has to have the same spatial shape as the image, [32, 32].
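
For concreteness, a minimal sketch of the data layout I mean (the class and tensor names here are made up for illustration, not my actual code):

import torch
from torch.utils.data import Dataset

class HandSegmentationDataset(Dataset):
    """Illustrative dataset: each item is (image, mask), where the label
    is a [32, 32] tensor of per-pixel class indices in {0, ..., 4}
    rather than a single integer as in the CIFAR example."""

    def __init__(self, images, masks):
        self.images = images  # float tensor [N, 1, 32, 32]
        self.masks = masks    # long tensor [N, 32, 32]

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        return self.images[idx], self.masks[idx]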

So I'm asking the community whether anyone has attempted this kind of segmentation task with the laplace-torch package, before I try to force a solution inside the asdfghjkl/fisher.py file.

SouLeo (Author) commented Aug 22, 2022

Update: I have gotten the Laplace calls below (after the model printout) to run by modifying my LadderNet model to the following architecture:

LadderNetv6(
  (initial_block): Initial_LadderBlock(
    (inconv): Conv2d(1, 10, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (down_module_list): ModuleList(
      (0): BasicBlock(
        (conv1): Conv2d(10, 10, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (relu): ReLU()
        (drop): Dropout2d(p=0.25, inplace=False)
      )
      (1): BasicBlock(
        (conv1): Conv2d(20, 20, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (relu): ReLU()
        (drop): Dropout2d(p=0.25, inplace=False)
      )
      (2): BasicBlock(
        (conv1): Conv2d(40, 40, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (relu): ReLU()
        (drop): Dropout2d(p=0.25, inplace=False)
      )
      (3): BasicBlock(
        (conv1): Conv2d(80, 80, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (relu): ReLU()
        (drop): Dropout2d(p=0.25, inplace=False)
      )
    )
    (down_conv_list): ModuleList(
      (0): Conv2d(10, 20, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (1): Conv2d(20, 40, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (2): Conv2d(40, 80, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (3): Conv2d(80, 160, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    )
    (bottom): BasicBlock(
      (conv1): Conv2d(160, 160, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (relu): ReLU()
      (drop): Dropout2d(p=0.25, inplace=False)
    )
    (up_conv_list): ModuleList(
      (0): ConvTranspose2d(160, 80, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), output_padding=(1, 1))
      (1): ConvTranspose2d(80, 40, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), output_padding=(1, 1))
      (2): ConvTranspose2d(40, 20, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), output_padding=(1, 1))
      (3): ConvTranspose2d(20, 10, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), output_padding=(1, 1))
    )
    (up_dense_list): ModuleList(
      (0): BasicBlock(
        (conv1): Conv2d(80, 80, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (relu): ReLU()
        (drop): Dropout2d(p=0.25, inplace=False)
      )
      (1): BasicBlock(
        (conv1): Conv2d(40, 40, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (relu): ReLU()
        (drop): Dropout2d(p=0.25, inplace=False)
      )
      (2): BasicBlock(
        (conv1): Conv2d(20, 20, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (relu): ReLU()
        (drop): Dropout2d(p=0.25, inplace=False)
      )
      (3): BasicBlock(
        (conv1): Conv2d(10, 10, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (relu): ReLU()
        (drop): Dropout2d(p=0.25, inplace=False)
      )
    )
  )
  (final_block): Final_LadderBlock(
    (block): LadderBlock(
      (inconv): BasicBlock(
        (conv1): Conv2d(10, 10, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (relu): ReLU()
        (drop): Dropout2d(p=0.25, inplace=False)
      )
      (down_module_list): ModuleList(
        (0): BasicBlock(
          (conv1): Conv2d(10, 10, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (relu): ReLU()
          (drop): Dropout2d(p=0.25, inplace=False)
        )
        (1): BasicBlock(
          (conv1): Conv2d(20, 20, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (relu): ReLU()
          (drop): Dropout2d(p=0.25, inplace=False)
        )
        (2): BasicBlock(
          (conv1): Conv2d(40, 40, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (relu): ReLU()
          (drop): Dropout2d(p=0.25, inplace=False)
        )
        (3): BasicBlock(
          (conv1): Conv2d(80, 80, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (relu): ReLU()
          (drop): Dropout2d(p=0.25, inplace=False)
        )
      )
      (down_conv_list): ModuleList(
        (0): Conv2d(10, 20, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
        (1): Conv2d(20, 40, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
        (2): Conv2d(40, 80, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
        (3): Conv2d(80, 160, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      )
      (bottom): BasicBlock(
        (conv1): Conv2d(160, 160, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (relu): ReLU()
        (drop): Dropout2d(p=0.25, inplace=False)
      )
      (up_conv_list): ModuleList(
        (0): ConvTranspose2d(160, 80, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), output_padding=(1, 1))
        (1): ConvTranspose2d(80, 40, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), output_padding=(1, 1))
        (2): ConvTranspose2d(40, 20, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), output_padding=(1, 1))
        (3): ConvTranspose2d(20, 10, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), output_padding=(1, 1))
      )
      (up_dense_list): ModuleList(
        (0): BasicBlock(
          (conv1): Conv2d(80, 80, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (relu): ReLU()
          (drop): Dropout2d(p=0.25, inplace=False)
        )
        (1): BasicBlock(
          (conv1): Conv2d(40, 40, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (relu): ReLU()
          (drop): Dropout2d(p=0.25, inplace=False)
        )
        (2): BasicBlock(
          (conv1): Conv2d(20, 20, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (relu): ReLU()
          (drop): Dropout2d(p=0.25, inplace=False)
        )
        (3): BasicBlock(
          (conv1): Conv2d(10, 10, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (relu): ReLU()
          (drop): Dropout2d(p=0.25, inplace=False)
        )
      )
    )
  )
  (final_fc): Final_Layer(
    (layer): Linear(in_features=10, out_features=5, bias=False)
  )
)
la = Laplace(net, 'classification', subset_of_weights='last_layer', hessian_structure='kron', backend=AsdlGGN)
la.fit(train_loader)
la.optimize_prior_precision(method='marglik')

However, despite successfully training my model, running plain predictions with it, and fitting post-hoc Laplace via .fit(), I cannot run the following code without error:

@torch.no_grad()
def predict(dataloader, model, laplace=False):
    """This code was taken from calibration_example.py."""
    py = []
    for x, _ in dataloader:
        if laplace:
            py.append(model(x.cuda()))
        else:
            py.append(torch.softmax(model(x.cuda()), dim=-1))
    return torch.cat(py).cpu()

probs_laplace = predict(test_loader, la, laplace=True)  # this line fails

The following is the trace when running the predict() code:

RuntimeError                              Traceback (most recent call last)
Input In [12], in <cell line: 65>()
     50 # # TODO: specify val_loader
     51 # # From API docs page
     52 # # post-hoc update:
   (...)
     61 
     62 # From GitHub CIFAR example:
     63 la.optimize_prior_precision(method='marglik') #, val_loader=test_loader_copy)
---> 65 probs_laplace = predict(test_loader_copy, la, laplace=True) # in future, replace w/test set: test_loader
     67 acc_laplace = (probs_laplace.argmax(-1) == targets).float().mean()
     69 # ece_laplace = ECE(bins=15).measure(probs_laplace.numpy(), targets.numpy())
     70 
     71 # nll_laplace = -dists.Categorical(probs_laplace).log_prob(targets).mean()
     72 
     73 
     74 # print(f'[Laplace] Acc.: {acc_laplace:.1%}; ECE: {ece_laplace:.1%}; NLL: {nll_laplace:.3}')')

File ~/.conda/envs/hrc-laddernet/lib/python3.10/site-packages/torch/autograd/grad_mode.py:27, in _DecoratorContextManager.__call__.<locals>.decorate_context(*args, **kwargs)
     24 @functools.wraps(func)
     25 def decorate_context(*args, **kwargs):
     26     with self.clone():
---> 27         return func(*args, **kwargs)

Input In [12], in predict(dataloader, model, laplace)
     12 print(x.shape)
     13 if laplace:
---> 14     py.append(model(x.cuda()))
     15 else:
     16     py.append(torch.softmax(model(x.cuda()), dim=-1))

File ~/.conda/envs/hrc-laddernet/lib/python3.10/site-packages/laplace/baselaplace.py:536, in ParametricLaplace.__call__(self, x, pred_type, link_approx, n_samples)
    533     raise ValueError(f'Unsupported link approximation {link_approx}.')
    535 if pred_type == 'glm':
--> 536     f_mu, f_var = self._glm_predictive_distribution(x)
    537     # regression
    538     if self.likelihood == 'regression':

File ~/.conda/envs/hrc-laddernet/lib/python3.10/site-packages/laplace/lllaplace.py:124, in LLLaplace._glm_predictive_distribution(self, X)
    122 print(Js.shape)
    123 print(f_mu.shape)
--> 124 f_var = self.functional_variance(Js)
    125 print('shape of f_var, which is variance(Js)')
    126 print(f_var.shape)

File ~/.conda/envs/hrc-laddernet/lib/python3.10/site-packages/laplace/baselaplace.py:841, in KronLaplace.functional_variance(self, Js)
    840 def functional_variance(self, Js):
--> 841     return self.posterior_precision.inv_square_form(Js)

File ~/.conda/envs/hrc-laddernet/lib/python3.10/site-packages/laplace/utils/matrix.py:411, in KronDecomposed.inv_square_form(self, W)
    409 print('from laplace/utils inv_square_form')
    410 print(W.shape)
--> 411 SW = self._bmm(W, exponent=-1)
    412 return torch.bmm(W, SW.transpose(1, 2))

File ~/.conda/envs/hrc-laddernet/lib/python3.10/site-packages/laplace/utils/matrix.py:404, in KronDecomposed._bmm(self, W, exponent)
    402 print('length of SW')
    403 print(len(SW))
--> 404 SW = torch.cat(SW, dim=1).reshape(B, K, P)
    405 return SW

RuntimeError: shape '[1024, 32, 320]' is invalid for input of size 1638400

I'm somewhat at a loss, because I did not expect prediction to fail when .fit() completed without error. Any help would be greatly appreciated.

wiseodd (Collaborator) commented Aug 22, 2022

Hi @SouLeo, multi-output models are indeed still in our backlog. For now, I think this paper (https://arxiv.org/abs/2206.15078) along with its code (https://github.com/FrederikWarburg/LaplaceAE) can be very useful for you.
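
In the meantime, a shape-level workaround you could experiment with (an untested sketch; the wrapper below is made up and not part of laplace-torch): flatten the spatial dimensions so that every pixel becomes its own classification example, giving the library the 2-D [n_examples, n_classes] outputs it currently expects.

import torch.nn as nn

class PixelwiseClassifierWrapper(nn.Module):
    """Hypothetical wrapper: reshapes [B, n_classes, H, W] segmentation
    logits to [B * H * W, n_classes], so each pixel looks like an
    independent classification example. Assumes channels-first logits;
    adjust the permute if your net emits [B, H, W, n_classes]."""

    def __init__(self, seg_net):
        super().__init__()
        self.seg_net = seg_net

    def forward(self, x):
        logits = self.seg_net(x)  # [B, n_classes, H, W]
        b, c, h, w = logits.shape
        return logits.permute(0, 2, 3, 1).reshape(b * h * w, c)

The targets would need the matching flatten (e.g. mask.reshape(-1)) when building the loaders, and note that per-pixel Jacobians can be very memory-hungry.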

SouLeo (Author) commented Aug 22, 2022

Oh, I see. So this framework is not yet suited to multiple class labels per image (i.e., per-pixel classification)?

I'll review the items you have linked. Thank you very much!

I am still somewhat confused that I was able to run the following lines of code without error:

la = Laplace(net, 'classification', subset_of_weights='last_layer', hessian_structure='kron', backend=AsdlGGN)
la.fit(train_loader)
la.optimize_prior_precision(method='marglik')

but cannot run the model prediction. Do you have any thoughts on this?
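
My current guess, sketched with illustrative shapes (I may well be wrong): fitting never calls the predictive-side reshape, but prediction does. The GLM predictive builds per-example Jacobians Js of shape [batch, n_outputs, n_params], and KronDecomposed._bmm reshapes them as (B, K, P) with K equal to the number of classes; a segmentation head emits H * W * n_classes outputs per image, so the element count no longer matches:

import torch

# Illustrative shapes only. final_fc above is Linear(10, 5, bias=False),
# i.e. P = 50 last-layer parameters, and my images are 32x32 with 5 classes.
B, K, P = 8, 5, 50
Js = torch.randn(B, K, P)  # per-example, per-class Jacobians: what the Kron predictive reshape expects

H, W, C = 32, 32, 5
Js_seg = torch.randn(B, H * W * C, P)  # segmentation: H*W times more output rows per image
# Js_seg.reshape(B, K, P) fails with the same kind of RuntimeError as in my trace.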

jerofad commented Apr 10, 2023

@SouLeo Were you able to use this library successfully for image segmentation?

@wiseodd wiseodd added this to the 0.3 milestone Jul 8, 2024