Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Trainers: support binary, multiclass, and multilabel tasks #2219

Draft
wants to merge 4 commits into
base: main
Choose a base branch
from

Conversation

adamjstewart
Copy link
Collaborator

@adamjstewart adamjstewart commented Aug 12, 2024

Instead of having separate trainers for binary, multiclass, and multilabel, let's create a single trainer that can handle all 3.

This applies to both Classification and Semantic Segmentation but not to our other trainers.

Closes #2205 @robmarkcole
Closes #245 @calebrob6

@github-actions github-actions bot added testing Continuous integration testing trainers PyTorch Lightning trainers labels Aug 12, 2024
num_classes: int = 1000,
task: str = 'multiclass',
num_classes: int | None = None,
num_labels: int | None = None,
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The defaults here match torchmetrics. If task='multiclass', only num_classes is used. If task='multilabel', only num_labels is used. If task='binary', both are ignored. Honestly, we could have a single num_classes if we want and simply use it for both.

@@ -266,147 +262,3 @@ def predict_step(
x = batch['image']
y_hat: Tensor = self(x).softmax(dim=-1)
return y_hat


class MultiLabelClassificationTask(ClassificationTask):
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The right thing to do would be to deprecate this first and remove it in 0.7.0. Not sure how widely used it is. Deprecation is kind of annoying because you need to change all tests to acknowledge the warning message.

@adamjstewart adamjstewart added this to the 0.7.0 milestone Aug 27, 2024
@adamjstewart
Copy link
Collaborator Author

I think we first need to add a multilabel semantic segmentation dataset to properly test this.

@adamjstewart
Copy link
Collaborator Author

Alternatively, skip multilabel semantic segmentation and only support multilabel classification.

@adamjstewart adamjstewart added the backwards-incompatible Changes that are not backwards compatible label Oct 3, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
backwards-incompatible Changes that are not backwards compatible testing Continuous integration testing trainers PyTorch Lightning trainers
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Multiclass Classification: assert num_classes >=2 Add BinarySemanticSegmentationTask to properly compute IoU
1 participant