Logging a class-wise metric with ClasswiseWrapper #2756

Closed
patrontheo opened this issue Sep 18, 2024 Discussed in #1462 · 5 comments
Labels
bug / fix Something isn't working

patrontheo commented Sep 18, 2024

Discussed in #1462

Originally posted by blazejdolicki January 26, 2023
Hey everyone, I am trying to log the metric Dice(average="none") to Neptune using torchmetrics and PyTorch Lightning. Because this is a class-wise metric, it returns a tensor of shape (num_classes,) instead of a scalar, which makes logging a bit complicated. I found this issue and PR, which were supposed to solve the problem, but when I try to use them I still get an error. I would appreciate any hints as to whether I am doing something wrong in my code or whether this is just not supported :)

My code looks something like this:

class MyModule(pl.LightningModule):
    def __init__(self):
        (...)
        
        class_labels = ["class_0", "class_1", "class_2"]
        metrics_dict = {"dice_macro": torchmetrics.Dice(average="macro"),
                        "dice_per_class": torchmetrics.ClasswiseWrapper(torchmetrics.Dice(average="none"),
                                                                        labels=class_labels)}
        self.val_metrics = torchmetrics.MetricCollection(metrics_dict)
        
    def validation_step(self, val_batch, batch_idx):
        x, y, w = val_batch
        outputs = self.forward(x)
        loss = self.loss_fn(outputs, y, w)
        self.log("val_loss", loss, on_epoch=True, on_step=False)
        self.val_metrics(outputs, y)
        self.log_dict(self.val_metrics, on_epoch=True, on_step=False)

The error:

Traceback (most recent call last):
    trainer.fit(model, train_loader, val_loader, ckpt_path=checkpoint_dir)
  File "C:\Users\BlazejDolicki\anaconda3\envs\data_science\lib\site-packages\pytorch_lightning\trainer\trainer.py", line 603, in fit
    call._call_and_handle_interrupt(
  File "C:\Users\BlazejDolicki\anaconda3\envs\data_science\lib\site-packages\pytorch_lightning\trainer\call.py", line 38, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
  File "C:\Users\BlazejDolicki\anaconda3\envs\data_science\lib\site-packages\pytorch_lightning\trainer\trainer.py", line 645, in _fit_impl
    self._run(model, ckpt_path=self.ckpt_path)
  File "C:\Users\BlazejDolicki\anaconda3\envs\data_science\lib\site-packages\pytorch_lightning\trainer\trainer.py", line 1098, in _run
    results = self._run_stage()
  File "C:\Users\BlazejDolicki\anaconda3\envs\data_science\lib\site-packages\pytorch_lightning\trainer\trainer.py", line 1177, in _run_stage
    self._run_train()
  File "C:\Users\BlazejDolicki\anaconda3\envs\data_science\lib\site-packages\pytorch_lightning\trainer\trainer.py", line 1190, in _run_train
    self._run_sanity_check()
  File "C:\Users\BlazejDolicki\anaconda3\envs\data_science\lib\site-packages\pytorch_lightning\trainer\trainer.py", line 1262, in _run_sanity_check
    val_loop.run()
  File "C:\Users\BlazejDolicki\anaconda3\envs\data_science\lib\site-packages\pytorch_lightning\loops\loop.py", line 206, in run
    output = self.on_run_end()
  File "C:\Users\BlazejDolicki\anaconda3\envs\data_science\lib\site-packages\pytorch_lightning\loops\dataloader\evaluation_loop.py", line 184, in on_run_end
    self._on_evaluation_epoch_end()
  File "C:\Users\BlazejDolicki\anaconda3\envs\data_science\lib\site-packages\pytorch_lightning\loops\dataloader\evaluation_loop.py", line 296, in _on_evaluation_epoch_end
    self.trainer._logger_connector.on_epoch_end()
  File "C:\Users\BlazejDolicki\anaconda3\envs\data_science\lib\site-packages\pytorch_lightning\trainer\connectors\logger_connector\logger_connector.py", line 179, in on_epoch_end
    metrics = self.metrics
  File "C:\Users\BlazejDolicki\anaconda3\envs\data_science\lib\site-packages\pytorch_lightning\trainer\connectors\logger_connector\logger_connector.py", line 224, in metrics
    return self.trainer._results.metrics(on_step)
  File "C:\Users\BlazejDolicki\anaconda3\envs\data_science\lib\site-packages\pytorch_lightning\trainer\connectors\logger_connector\result.py", line 584, in metrics
    value = apply_to_collection(result_metric, _ResultMetric, self._get_cache, on_step, include_none=False)
  File "C:\Users\BlazejDolicki\anaconda3\envs\data_science\lib\site-packages\lightning_utilities\core\apply_func.py", line 47, in apply_to_collection
    return function(data, *args, **kwargs)
  File "C:\Users\BlazejDolicki\anaconda3\envs\data_science\lib\site-packages\pytorch_lightning\trainer\connectors\logger_connector\result.py", line 550, in _get_cache
    raise ValueError(
ValueError: The `.compute()` return of the metric logged as 'val_dice_per_class' must be a tensor. Found {'dice_class_0': tensor(0., device='cuda:0'), 'dice_class_1': tensor(0., device='cuda:0'), 'dice_class_2': tensor(0., device='cuda:0')
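
The message points at the root cause: every key handed to the logger must resolve to a single tensor, while the class-wise metric resolves to a dict with one entry per class. A minimal, library-free sketch of the flattening that has to happen before logging follows. `flatten_metrics` is a hypothetical helper (not part of torchmetrics or Lightning), and plain floats stand in for the 0-dim tensors a real metric would return.

```python
from typing import Any, Dict


def flatten_metrics(results: Dict[str, Any], prefix: str = "") -> Dict[str, Any]:
    """Flatten one level of nesting so that every key maps to a single value.

    Hypothetical helper for illustration only; plain floats stand in for
    the 0-dim tensors a real metric would return.
    """
    flat: Dict[str, Any] = {}
    for name, value in results.items():
        if isinstance(value, dict):
            # A class-wise result: promote each per-class entry to the top level.
            for sub_name, sub_value in value.items():
                flat[f"{prefix}{sub_name}"] = sub_value
        else:
            # An already-scalar result: keep it under its own name.
            flat[f"{prefix}{name}"] = value
    return flat


# Shape of the result that triggered the error: one scalar plus one nested dict.
computed = {
    "dice_macro": 0.5,
    "dice_per_class": {"dice_class_0": 0.25, "dice_class_1": 0.5, "dice_class_2": 0.75},
}
flat = flatten_metrics(computed, prefix="val_")
print(flat)
```

Every value in the flattened dict is a single scalar, which is the shape `self.log_dict` can accept.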

Hi! Thanks for your contribution! Great first issue!

Borda added the "bug / fix" label on Sep 23, 2024

SkafteNicki (Member) commented

This was fixed in PR #2720.
Here is the relevant integration test:

def test_collection_classwise_lightning_integration(tmpdir):
    """Check the integration of ClasswiseWrapper, MetricCollection and LightningModule.

    See issue: https://github.com/Lightning-AI/torchmetrics/issues/2683
    """

    class TestModel(BoringModel):
        def __init__(self) -> None:
            super().__init__()
            self.train_metrics = MetricCollection(
                {
                    "macro_accuracy": MulticlassAccuracy(num_classes=5, average="macro"),
                    "classwise_accuracy": ClasswiseWrapper(MulticlassAccuracy(num_classes=5, average=None)),
                },
                prefix="train_",
            )
            self.val_metrics = MetricCollection(
                {
                    "macro_accuracy": MulticlassAccuracy(num_classes=5, average="macro"),
                    "classwise_accuracy": ClasswiseWrapper(MulticlassAccuracy(num_classes=5, average=None)),
                },
                prefix="val_",
            )

        def training_step(self, batch, batch_idx):
            loss = self(batch).sum()
            preds = torch.randint(0, 5, (100,), device=batch.device)
            target = torch.randint(0, 5, (100,), device=batch.device)
            self.train_metrics.update(preds, target)
            batch_values = self.train_metrics.compute()
            self.log_dict(batch_values, on_step=True, on_epoch=False)
            return {"loss": loss}

        def validation_step(self, batch, batch_idx):
            preds = torch.randint(0, 5, (100,), device=batch.device)
            target = torch.randint(0, 5, (100,), device=batch.device)
            self.val_metrics.update(preds, target)

        def on_validation_epoch_end(self):
            self.log_dict(self.val_metrics.compute(), on_step=False, on_epoch=True)

    model = TestModel()
    trainer = Trainer(
        default_root_dir=tmpdir,
        limit_train_batches=2,
        limit_val_batches=2,
        max_epochs=1,
        log_every_n_steps=1,
    )
    trainer.fit(model)

    logged = trainer.logged_metrics

    # check that all metrics are logged
    assert "train_macro_accuracy" in logged
    assert "val_macro_accuracy" in logged
    for i in range(5):
        assert f"train_multiclassaccuracy_{i}" in logged
        assert f"val_multiclassaccuracy_{i}" in logged

Please wait for the fix to ship in the next torchmetrics release, or install directly from master:

pip install https://github.com/Lightning-AI/torchmetrics/archive/master.zip

Closing issue.

patrontheo (Author) commented Oct 9, 2024

@SkafteNicki Thank you!
I was hoping it could be integrated with automatic logging, but you made it clear that was complicated.
I ended up logging manually, as in your integration test.
Quick question about your integration test.
I have the same workflow for validation_step and on_validation_epoch_end, but not exactly the same for training_step.
You are doing:

self.train_metrics.update(preds, target)
batch_values = self.train_metrics.compute()
self.log_dict(batch_values, on_step=True, on_epoch=False)

I'm not sure how compute works here: does it compute the metric value on the current batch only, or on all the data seen since the beginning of the epoch?
Also, shouldn't you reset the metric in on_train_epoch_end?

On my side, I am doing:

def training_step(self, batch, batch_idx):
    logits, y, loss = self.step(batch)

    self.log("train/loss", loss, on_step=True, on_epoch=True)
    metric_dict = self.train_metrics(logits, y)
    metric_dict = {f"{key}_step": value for key, value in metric_dict.items()}
    self.log_dict(metric_dict)

def on_train_epoch_end(self):
    train_metrics = self.train_metrics.compute()
    train_metrics = {f"{key}_epoch": value for key, value in train_metrics.items()}
    self.log_dict(train_metrics)
    self.train_metrics.reset()

Am I doing things correctly here? Should I specify on_epoch=True/False and on_step=True/False? What would that change here?

SkafteNicki (Member) commented

Hi @patrontheo, you are completely right, I made a mistake in the PR. The training_step should not use the compute method but the forward function, as you are doing (this will be fixed in PR #2775). So overall you are doing it correctly. There is no need to specify on_step and on_epoch in your self.log_dict calls because you are already logging correctly.
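
The distinction in this exchange can be sketched without any library. In torchmetrics, calling the metric (its forward) updates the accumulated state and returns the value for the current batch, whereas compute() returns the value over everything accumulated since the last reset. `ToyAccuracy` below is a hypothetical plain-Python stand-in that mirrors this contract for illustration; it is not how torchmetrics implements it.

```python
from typing import Sequence


class ToyAccuracy:
    """Toy stand-in for a stateful metric (not torchmetrics' implementation)."""

    def __init__(self) -> None:
        self.correct = 0
        self.total = 0

    def update(self, preds: Sequence[int], target: Sequence[int]) -> None:
        # Accumulate state across batches.
        self.correct += sum(int(p == t) for p, t in zip(preds, target))
        self.total += len(target)

    def compute(self) -> float:
        # Value over everything accumulated since the last reset.
        return self.correct / self.total

    def forward(self, preds: Sequence[int], target: Sequence[int]) -> float:
        # Accumulate, but return the value for this batch only.
        self.update(preds, target)
        return sum(int(p == t) for p, t in zip(preds, target)) / len(target)

    def reset(self) -> None:
        # Clear the accumulated state, e.g. at the end of an epoch.
        self.correct = 0
        self.total = 0


metric = ToyAccuracy()
step_1 = metric.forward([0, 1, 1, 0], [0, 1, 0, 1])  # 2 of 4 correct in this batch
step_2 = metric.forward([1, 1, 1, 1], [1, 1, 1, 1])  # 4 of 4 correct in this batch
epoch = metric.compute()  # 6 of 8 correct over both batches
metric.reset()
print(step_1, step_2, epoch)
```

With these two batches, calling update followed by compute in every training step, as in the quoted test, would log the running values 0.5 and 0.75 rather than the per-batch values 0.5 and 1.0.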

patrontheo (Author) commented

@SkafteNicki Thanks a lot for the clarification! :)
