
Distinction between multi_class_dice_loss & avg_dice_loss #27

Open
gattia opened this issue Apr 25, 2022 · 3 comments

gattia commented Apr 25, 2022

I was looking through the code and thought, after a quick look, that I should probably use multi_class_dice_loss with softmax, but that wasn't an option, so I added it. Then I wondered why it wasn't an option, so I compared multi_class_dice_loss and avg_dice_loss more closely. Below I summarize what I interpret the main difference between them to be.

My general take is that I'd be inclined to use avg_dice_loss because it is less likely to wash out a few bad results on segments with only a few pixels (e.g., if one image in a batch has thick healthy patellar cartilage and another has barely any, multi_class would be skewed more toward the healthy one, I think?). I'm curious about the original rationale/distinction between them and whether I might be missing something.

My take is:
multi_class_dice_loss

  • Flattens the batch of images into shape [batch_size * product(image_dims), n_classes].
  • Calculates a dice loss per class, which effectively treats all images in the batch as one image.
    • Shape after dice loss = [n_classes,]
  • Averages the dice losses across classes to get a single value.

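As a rough numpy sketch of the steps above (the function name and the eps smoothing term are illustrative, not medsegpy's actual implementation):

```python
import numpy as np

def multi_class_dice_loss(y_true, y_pred, eps=1e-7):
    """Dice loss that pools the whole batch as if it were one image.

    y_true, y_pred: arrays of shape [batch, *image_dims, n_classes]
    (one-hot ground truth and predicted probabilities).
    """
    n_classes = y_true.shape[-1]
    # Flatten to [batch_size * product(image_dims), n_classes].
    t = y_true.reshape(-1, n_classes)
    p = y_pred.reshape(-1, n_classes)
    # One dice score per class, pooled over every pixel in the batch.
    intersection = (t * p).sum(axis=0)      # shape [n_classes]
    denom = t.sum(axis=0) + p.sum(axis=0)   # shape [n_classes]
    dice = (2 * intersection + eps) / (denom + eps)
    # Average across classes to get a single scalar loss.
    return 1 - dice.mean()
```

Because the class-wise sums run over the entire batch, an image with very few pixels of some class contributes little to that class's dice score, which is the washing-out effect described above.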
avg_dice_loss

  • Reshapes the batch of images into shape [batch_size, product(image_dims), n_classes].
  • Calculates a dice loss per image and per class.
    • Shape after dice loss = [batch_size, n_classes]
  • Averages the dice losses.
ad12 commented Apr 25, 2022

A bunch of these were legacy names that we preserved because a few different projects were using them. I'd stick to using medsegpy.loss.DiceLoss, which defaults to the same functionality as avg_dice_loss but will be more stable long term.


gattia commented Apr 25, 2022

Perfect. For anyone else reading this, medsegpy.loss.DiceLoss is currently specified as LOSS: ("avg_dice_no_reduce", "sigmoid") in the yaml here and here.

More versions of DiceLoss (different reductions, weightings, etc.) can be added in the code near those references.

I want to try working with softmax, so I will add a softmax option that works with DiceLoss and make a pull request eventually.


ad12 commented Apr 26, 2022

Sounds good. I think long term we should have a way of configuring this dynamically; maybe in the config the user can specify the loss class they want to instantiate and with what parameters.
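One possible shape for that idea is a small registry plus a config dict naming the loss class and its parameters. Everything below (the registry, the NAME/PARAMS keys, and the DiceLoss constructor arguments) is a hypothetical sketch, not medsegpy's actual config schema:

```python
# Hypothetical registry-based loss configuration; names are illustrative.
LOSS_REGISTRY = {}

def register_loss(cls):
    """Register a loss class under its class name."""
    LOSS_REGISTRY[cls.__name__] = cls
    return cls

@register_loss
class DiceLoss:
    # Constructor parameters here are assumed for illustration.
    def __init__(self, reduction="mean", activation="sigmoid"):
        self.reduction = reduction
        self.activation = activation

def build_loss(cfg):
    """Instantiate the loss named in the config with its parameters.

    cfg is a plain dict, e.g. {"NAME": "DiceLoss", "PARAMS": {...}},
    as might be parsed from a yaml config file.
    """
    cls = LOSS_REGISTRY[cfg["NAME"]]
    return cls(**cfg.get("PARAMS", {}))

# Example: a user-selected softmax variant, driven entirely by config.
loss = build_loss({"NAME": "DiceLoss", "PARAMS": {"activation": "softmax"}})
```

A registry like this lets new losses (e.g. a softmax DiceLoss) be added without touching the config-parsing code.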
