diff --git a/docs/en/tutorials/training_tricks.md b/docs/en/tutorials/training_tricks.md
index 1c8fe06b94..6ff2c4249d 100644
--- a/docs/en/tutorials/training_tricks.md
+++ b/docs/en/tutorials/training_tricks.md
@@ -68,3 +68,23 @@ model = dict(
 In this way, `loss_weight` and `loss_name` will be weight and name in training log of corresponding loss, respectively.
 
 Note: If you want this loss item to be included into the backward graph, `loss_` must be the prefix of the name.
+
+## Ignore specified label index in loss calculation
+
+By default, `avg_non_ignore=False`, which means every pixel counts in the loss calculation, even pixels whose labels belong to the ignore index.
+
+For loss calculation, we support ignoring a certain label index via `avg_non_ignore` and `ignore_index`. In this way, the average loss is calculated only over non-ignored labels, which may achieve better performance; see the [reference](https://github.com/open-mmlab/mmsegmentation/pull/1409). Here is an example config for training `unet` on the `Cityscapes` dataset: the loss calculation ignores label 0, which is the background, and the loss average is calculated only over non-ignored labels:
+
+```python
+_base_ = './fcn_unet_s5-d16_4x4_512x1024_160k_cityscapes.py'
+model = dict(
+    decode_head=dict(
+        ignore_index=0,
+        loss_decode=dict(
+            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0, avg_non_ignore=True)),
+    auxiliary_head=dict(
+        ignore_index=0,
+        loss_decode=dict(
+            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0, avg_non_ignore=True)),
+)
+```
diff --git a/docs/zh_cn/tutorials/training_tricks.md b/docs/zh_cn/tutorials/training_tricks.md
index be9112cabd..2efbdf177c 100644
--- a/docs/zh_cn/tutorials/training_tricks.md
+++ b/docs/zh_cn/tutorials/training_tricks.md
@@ -68,3 +68,24 @@ model = dict(
 In this way, the loss weight `loss_weight` and the name recorded in the training log `loss_name` are determined.
 
 Note: `loss_name` must have the `loss_` prefix so that the loss item is included in the backward graph.
+
+## Ignore specified label index in loss calculation
+
+The default setting is `avg_non_ignore=False`, i.e. every pixel counts in the loss calculation, even though some of them belong to ignore-index labels.
+
+For loss calculation during training, we currently support ignoring a certain label index with `avg_non_ignore` and `ignore_index`. In this way, the loss is averaged only over non-ignored pixels, which may achieve better performance; here is the [related PR](https://github.com/open-mmlab/mmsegmentation/pull/1409). Taking `unet` trained on the `Cityscapes` dataset as an example,
+the loss calculation ignores label 0, the background, and the loss average is computed only over non-ignored pixels. The config is written as:
+
+```python
+_base_ = './fcn_unet_s5-d16_4x4_512x1024_160k_cityscapes.py'
+model = dict(
+    decode_head=dict(
+        ignore_index=0,
+        loss_decode=dict(
+            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0, avg_non_ignore=True)),
+    auxiliary_head=dict(
+        ignore_index=0,
+        loss_decode=dict(
+            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0, avg_non_ignore=True)),
+)
+```
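To make the effect of the flag concrete, here is a minimal sketch (not part of this patch; the toy shapes and the ignore value 255 are illustrative assumptions): with `avg_non_ignore=False` the summed loss is divided by all pixels, while with `avg_non_ignore=True` it is divided by the number of non-ignored pixels only.

```python
import torch

from mmseg.models import build_loss

# Toy batch: 4 pixels, one of which carries the ignore label 255.
pred = torch.randn(1, 2, 2, 2)  # (N, C, H, W) logits
label = torch.tensor([[[0, 1], [1, 255]]])

loss_all = build_loss(dict(type='CrossEntropyLoss', avg_non_ignore=False))
loss_valid = build_loss(dict(type='CrossEntropyLoss', avg_non_ignore=True))

# False: summed loss / 4 pixels; True: summed loss / 3 non-ignored pixels.
print(loss_all(pred, label, ignore_index=255))
print(loss_valid(pred, label, ignore_index=255))
```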
diff --git a/mmseg/models/losses/cross_entropy_loss.py b/mmseg/models/losses/cross_entropy_loss.py
index ee489a888f..7c2158f832 100644
--- a/mmseg/models/losses/cross_entropy_loss.py
+++ b/mmseg/models/losses/cross_entropy_loss.py
@@ -1,4 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+import warnings
+
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -13,8 +15,31 @@ def cross_entropy(pred,
                   class_weight=None,
                   reduction='mean',
                   avg_factor=None,
-                  ignore_index=-100):
-    """The wrapper function for :func:`F.cross_entropy`"""
+                  ignore_index=-100,
+                  avg_non_ignore=False):
+    """The wrapper function for :func:`F.cross_entropy`.
+
+    Args:
+        pred (torch.Tensor): The prediction with shape (N, C), where C is
+            the number of classes.
+        label (torch.Tensor): The learning label of the prediction.
+        weight (torch.Tensor, optional): Sample-wise loss weight.
+            Default: None.
+        class_weight (list[float], optional): The weight for each class.
+            Default: None.
+        reduction (str, optional): The method used to reduce the loss.
+            Options are 'none', 'mean' and 'sum'. Default: 'mean'.
+        avg_factor (int, optional): Average factor that is used to average
+            the loss. Default: None.
+        ignore_index (int): Specifies a target value that is ignored and
+            does not contribute to the input gradients. When
+            ``avg_non_ignore`` is ``True`` and the ``reduction`` is
+            ``'mean'``, the loss is averaged over non-ignored targets.
+            Default: -100.
+        avg_non_ignore (bool): Whether the loss is only averaged over
+            non-ignored targets. Default: False.
+            `New in version 0.23.0.`
+    """
+
     # class_weight is a manual rescaling weight given to each class.
     # If given, has to be a Tensor of size C element-wise losses
     loss = F.cross_entropy(
@@ -25,6 +50,11 @@ def cross_entropy(pred,
         ignore_index=ignore_index)
 
     # apply weights and do the reduction
+    # average loss over non-ignored elements
+    # pytorch's official cross_entropy averages loss over non-ignored elements
+    # refer to https://github.com/pytorch/pytorch/blob/56b43f4fec1f76953f15a627694d4bba34588969/torch/nn/functional.py#L2660  # noqa
+    if (avg_factor is None) and avg_non_ignore and reduction == 'mean':
+        avg_factor = label.numel() - (label == ignore_index).sum().item()
     if weight is not None:
         weight = weight.float()
     loss = weight_reduce_loss(
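The denominator computed by the new branch above can be checked by hand; a small sketch (the 2x3 label map and the ignore value 255 are illustrative assumptions):

```python
import torch

# One of six labels is ignored, so the mean is taken over the other five:
# avg_factor = label.numel() - (label == ignore_index).sum().item()
label = torch.tensor([[0, 1, 255], [1, 0, 1]])
print(label.numel() - (label == 255).sum().item())  # 5
```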
@@ -46,13 +76,14 @@ def _expand_onehot_labels(labels, label_weights, target_shape, ignore_index):
     bin_labels[inds[0], labels[valid_mask]] = 1
 
     valid_mask = valid_mask.unsqueeze(1).expand(target_shape).float()
+
     if label_weights is None:
         bin_label_weights = valid_mask
     else:
         bin_label_weights = label_weights.unsqueeze(1).expand(target_shape)
         bin_label_weights *= valid_mask
 
-    return bin_labels, bin_label_weights
+    return bin_labels, bin_label_weights, valid_mask
 
 
 def binary_cross_entropy(pred,
@@ -61,19 +92,25 @@ def binary_cross_entropy(pred,
                          reduction='mean',
                          avg_factor=None,
                          class_weight=None,
-                         ignore_index=255):
+                         ignore_index=-100,
+                         avg_non_ignore=False,
+                         **kwargs):
     """Calculate the binary CrossEntropy loss.
 
     Args:
         pred (torch.Tensor): The prediction with shape (N, 1).
         label (torch.Tensor): The learning label of the prediction.
+            Note: In bce loss, label < 0 is invalid.
         weight (torch.Tensor, optional): Sample-wise loss weight.
         reduction (str, optional): The method used to reduce the loss.
             Options are "none", "mean" and "sum".
         avg_factor (int, optional): Average factor that is used to average
             the loss. Defaults to None.
         class_weight (list[float], optional): The weight for each class.
-        ignore_index (int | None): The label index to be ignored. Default: 255
+        ignore_index (int): The label index to be ignored. Default: -100.
+        avg_non_ignore (bool): Whether the loss is only averaged over
+            non-ignored targets. Default: False.
+            `New in version 0.23.0.`
 
     Returns:
         torch.Tensor: The calculated loss
@@ -83,12 +120,21 @@ def binary_cross_entropy(pred,
                 pred.dim() == 4 and label.dim() == 3), \
             'Only pred shape [N, C], label shape [N] or pred shape [N, C, ' \
             'H, W], label shape [N, H, W] are supported'
-        label, weight = _expand_onehot_labels(label, weight, pred.shape,
-                                              ignore_index)
+        # `weight` returned from `_expand_onehot_labels`
+        # has been treated for valid (non-ignore) pixels
+        label, weight, valid_mask = _expand_onehot_labels(
+            label, weight, pred.shape, ignore_index)
+    else:
+        # should mask out the ignored elements
+        valid_mask = ((label >= 0) & (label != ignore_index)).float()
+        if weight is not None:
+            weight *= valid_mask
+        else:
+            weight = valid_mask
+    # average loss over non-ignored and valid elements
+    if reduction == 'mean' and avg_factor is None and avg_non_ignore:
+        avg_factor = valid_mask.sum().item()
 
-    # weighted element-wise losses
-    if weight is not None:
-        weight = weight.float()
     loss = F.binary_cross_entropy_with_logits(
         pred, label.float(), pos_weight=class_weight, reduction='none')
     # do the reduction for the weighted loss
@@ -104,7 +150,8 @@ def mask_cross_entropy(pred,
                        reduction='mean',
                        avg_factor=None,
                        class_weight=None,
-                       ignore_index=None):
+                       ignore_index=None,
+                       **kwargs):
     """Calculate the CrossEntropy loss for masks.
 
     Args:
@@ -153,6 +200,9 @@ class CrossEntropyLoss(nn.Module):
         loss_name (str, optional): Name of the loss item. If you want
             this loss item to be included into the backward graph, `loss_`
             must be the prefix of the name. Defaults to 'loss_ce'.
+        avg_non_ignore (bool): Whether the loss is only averaged over
+            non-ignored targets. Default: False.
+            `New in version 0.23.0.`
     """
 
     def __init__(self,
@@ -161,7 +211,8 @@ def __init__(self,
                  reduction='mean',
                  class_weight=None,
                  loss_weight=1.0,
-                 loss_name='loss_ce'):
+                 loss_name='loss_ce',
+                 avg_non_ignore=False):
         super(CrossEntropyLoss, self).__init__()
         assert (use_sigmoid is False) or (use_mask is False)
         self.use_sigmoid = use_sigmoid
@@ -169,6 +220,13 @@ def __init__(self,
         self.reduction = reduction
         self.loss_weight = loss_weight
         self.class_weight = get_class_weight(class_weight)
+        self.avg_non_ignore = avg_non_ignore
+        if not self.avg_non_ignore and self.reduction == 'mean':
+            warnings.warn(
+                'Default ``avg_non_ignore`` is False, if you would like to '
+                'ignore certain labels and average the loss over '
+                'non-ignored labels, which is the same as the official '
+                'PyTorch cross_entropy, set ``avg_non_ignore=True``.')
 
         if self.use_sigmoid:
             self.cls_criterion = binary_cross_entropy
@@ -178,12 +236,18 @@ def __init__(self,
             self.cls_criterion = cross_entropy
         self._loss_name = loss_name
 
+    def extra_repr(self):
+        """Extra repr."""
+        s = f'avg_non_ignore={self.avg_non_ignore}'
+        return s
+
     def forward(self,
                 cls_score,
                 label,
                 weight=None,
                 avg_factor=None,
                 reduction_override=None,
+                ignore_index=-100,
                 **kwargs):
         """Forward function."""
         assert reduction_override in (None, 'none', 'mean', 'sum')
@@ -193,6 +257,7 @@ def forward(self,
             class_weight = cls_score.new_tensor(self.class_weight)
         else:
             class_weight = None
+        # Note: for BCE loss, label < 0 is invalid.
         loss_cls = self.loss_weight * self.cls_criterion(
             cls_score,
             label,
@@ -200,6 +265,8 @@ def forward(self,
             class_weight=class_weight,
             reduction=reduction,
             avg_factor=avg_factor,
+            avg_non_ignore=self.avg_non_ignore,
+            ignore_index=ignore_index,
             **kwargs)
         return loss_cls
 
@@ -212,6 +279,7 @@ def loss_name(self):
         by simple sum operation. In addition, if you want this loss item
         to be included into the backward graph, `loss_` must be the prefix
         of the name.
+
         Returns:
             str: The name of this loss item.
         """
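As a quick aside, the new `extra_repr` makes the flag visible whenever the module is printed, which the test further below asserts; an illustrative check (not part of the patch):

```python
from mmseg.models import build_loss

# nn.Module's repr includes the extra_repr string defined above.
loss_cls = build_loss(dict(type='CrossEntropyLoss', avg_non_ignore=True))
print(loss_cls)  # CrossEntropyLoss(avg_non_ignore=True)
```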
diff --git a/mmseg/models/losses/utils.py b/mmseg/models/losses/utils.py
index c37875fadb..621f57c746 100644
--- a/mmseg/models/losses/utils.py
+++ b/mmseg/models/losses/utils.py
@@ -3,6 +3,7 @@
 
 import mmcv
 import numpy as np
+import torch
 import torch.nn.functional as F
 
@@ -69,7 +70,10 @@ def weight_reduce_loss(loss, weight=None, reduction='mean', avg_factor=None):
     else:
         # if reduction is mean, then average the loss by avg_factor
        if reduction == 'mean':
-            loss = loss.sum() / avg_factor
+            # Avoid dividing by zero (which yields NaN) when avg_factor is
+            # 0.0, i.e., all labels of an image belong to the ignore index.
+            eps = torch.finfo(torch.float32).eps
+            loss = loss.sum() / (avg_factor + eps)
         # if reduction is 'none', then do nothing, otherwise raise an error
         elif reduction != 'none':
             raise ValueError('avg_factor can not be used with reduction="sum"')
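The failure mode this epsilon guards against is an image whose labels are all ignored: every per-pixel loss is masked to zero and `avg_factor` is 0, so the old `loss.sum() / avg_factor` produced NaN. A minimal sketch (the all-zero loss vector is an illustrative assumption):

```python
import torch

from mmseg.models.losses.utils import weight_reduce_loss

# All pixels ignored: masked per-pixel losses sum to 0 and avg_factor is 0.
# With the epsilon, 0 / (0 + eps) yields tensor(0.) instead of NaN.
loss = torch.zeros(4)
print(weight_reduce_loss(loss, weight=None, reduction='mean', avg_factor=0.0))
```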
diff --git a/tests/test_models/test_losses/test_ce_loss.py b/tests/test_models/test_losses/test_ce_loss.py
index 2fe5c2eb49..6fd8d25a9c 100644
--- a/tests/test_models/test_losses/test_ce_loss.py
+++ b/tests/test_models/test_losses/test_ce_loss.py
@@ -2,8 +2,14 @@
 import pytest
 import torch
 
+from mmseg.models.losses.cross_entropy_loss import _expand_onehot_labels
 
-def test_ce_loss():
+
+@pytest.mark.parametrize('use_sigmoid', [True, False])
+@pytest.mark.parametrize('reduction', ('mean', 'sum', 'none'))
+@pytest.mark.parametrize('avg_non_ignore', [True, False])
+@pytest.mark.parametrize('bce_input_same_dim', [True, False])
+def test_ce_loss(use_sigmoid, reduction, avg_non_ignore, bce_input_same_dim):
     from mmseg.models import build_loss
 
     # use_mask and use_sigmoid cannot be true at the same time
@@ -15,19 +21,73 @@ def test_ce_loss():
             loss_weight=1.0)
         build_loss(loss_cfg)
 
-    # test loss with class weights
+    # test loss with simple case for ce/bce
+    fake_pred = torch.Tensor([[100, -100]])
+    fake_label = torch.Tensor([1]).long()
     loss_cls_cfg = dict(
         type='CrossEntropyLoss',
-        use_sigmoid=False,
-        class_weight=[0.8, 0.2],
+        use_sigmoid=use_sigmoid,
         loss_weight=1.0,
+        avg_non_ignore=avg_non_ignore,
         loss_name='loss_ce')
     loss_cls = build_loss(loss_cls_cfg)
-    fake_pred = torch.Tensor([[100, -100]])
-    fake_label = torch.Tensor([1]).long()
-    assert torch.allclose(loss_cls(fake_pred, fake_label), torch.tensor(40.))
+    if use_sigmoid:
+        assert torch.allclose(
+            loss_cls(fake_pred, fake_label), torch.tensor(100.))
+    else:
+        assert torch.allclose(
+            loss_cls(fake_pred, fake_label), torch.tensor(200.))
+
+    # test loss with complicated case for ce/bce
+    # when avg_non_ignore is False, `avg_factor` would not be calculated
+    fake_pred = torch.full(size=(2, 21, 8, 8), fill_value=0.5)
+    fake_label = torch.ones(2, 8, 8).long()
+    fake_label[:, 0, 0] = 255
+    fake_weight = None
+    # extra test bce loss when pred.shape == label.shape
+    if use_sigmoid and bce_input_same_dim:
+        fake_pred = torch.randn(2, 10).float()
+        fake_label = torch.rand(2, 10).float()
+        fake_weight = torch.rand(2, 10)  # set weight in forward function
+        fake_label[0, [1, 2, 5, 7]] = 255  # set ignore_index
+        fake_label[1, [0, 5, 8, 9]] = 255
+    loss_cls = build_loss(loss_cls_cfg)
+    loss = loss_cls(
+        fake_pred, fake_label, weight=fake_weight, ignore_index=255)
+    if use_sigmoid:
+        if fake_pred.dim() != fake_label.dim():
+            fake_label, weight, valid_mask = _expand_onehot_labels(
+                labels=fake_label,
+                label_weights=None,
+                target_shape=fake_pred.shape,
+                ignore_index=255)
+        else:
+            # should mask out the ignored elements
+            valid_mask = ((fake_label >= 0) & (fake_label != 255)).float()
+            weight = valid_mask
+        torch_loss = torch.nn.functional.binary_cross_entropy_with_logits(
+            fake_pred,
+            fake_label.float(),
+            reduction='none',
+            weight=fake_weight)
+        if avg_non_ignore:
+            avg_factor = valid_mask.sum().item()
+            torch_loss = (torch_loss * weight).sum() / avg_factor
+        else:
+            torch_loss = (torch_loss * weight).mean()
+    else:
+        if avg_non_ignore:
+            torch_loss = torch.nn.functional.cross_entropy(
+                fake_pred, fake_label, reduction='mean', ignore_index=255)
+        else:
+            torch_loss = torch.nn.functional.cross_entropy(
+                fake_pred, fake_label, reduction='sum',
+                ignore_index=255) / fake_label.numel()
+    assert torch.allclose(loss, torch_loss)
 
     # test loss with class weights from file
+    fake_pred = torch.Tensor([[100, -100]])
+    fake_label = torch.Tensor([1]).long()
     import os
     import tempfile
@@ -63,27 +123,103 @@ def test_ce_loss():
     loss_cls = build_loss(loss_cls_cfg)
     assert torch.allclose(loss_cls(fake_pred, fake_label), torch.tensor(200.))
 
-    loss_cls_cfg = dict(
-        type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0)
-    loss_cls = build_loss(loss_cls_cfg)
-    assert torch.allclose(loss_cls(fake_pred, fake_label), torch.tensor(100.))
-
-    fake_pred = torch.full(size=(2, 21, 8, 8), fill_value=0.5)
-    fake_label = torch.ones(2, 8, 8).long()
-    assert torch.allclose(
-        loss_cls(fake_pred, fake_label), torch.tensor(0.9503), atol=1e-4)
-    fake_label[:, 0, 0] = 255
+    # test `avg_non_ignore` without ignore index would not affect ce/bce loss
+    # when reduction='sum'/'none'/'mean'
+    loss_cls_cfg1 = dict(
+        type='CrossEntropyLoss',
+        use_sigmoid=use_sigmoid,
+        reduction=reduction,
+        loss_weight=1.0,
+        avg_non_ignore=True)
+    loss_cls1 = build_loss(loss_cls_cfg1)
+    loss_cls_cfg2 = dict(
+        type='CrossEntropyLoss',
+        use_sigmoid=use_sigmoid,
+        reduction=reduction,
+        loss_weight=1.0,
+        avg_non_ignore=False)
+    loss_cls2 = build_loss(loss_cls_cfg2)
     assert torch.allclose(
-        loss_cls(fake_pred, fake_label, ignore_index=255),
-        torch.tensor(0.9354),
+        loss_cls1(fake_pred, fake_label, ignore_index=255) / fake_pred.numel(),
+        loss_cls2(fake_pred, fake_label, ignore_index=255) / fake_pred.numel(),
         atol=1e-4)
 
-    # test cross entropy loss has name `loss_ce`
+    # test ce/bce loss with ignore index and class weight
+    # in 5-way classification
+    if use_sigmoid:
+        # test bce loss when pred.shape == or != label.shape
+        if bce_input_same_dim:
+            fake_pred = torch.randn(2, 10).float()
+            fake_label = torch.rand(2, 10).float()
+            class_weight = torch.rand(2, 10)
+        else:
+            fake_pred = torch.full(size=(2, 21, 8, 8), fill_value=0.5)
+            fake_label = torch.ones(2, 8, 8).long()
+            class_weight = torch.randn(2, 21, 8, 8)
+            fake_label, weight, valid_mask = _expand_onehot_labels(
+                labels=fake_label,
+                label_weights=None,
+                target_shape=fake_pred.shape,
+                ignore_index=-100)
+        torch_loss = torch.nn.functional.binary_cross_entropy_with_logits(
+            fake_pred,
+            fake_label.float(),
+            reduction='mean',
+            pos_weight=class_weight)
+    else:
+        fake_pred = torch.randn(2, 5, 10).float()  # 5-way classification
+        fake_label = torch.randint(0, 5, (2, 10)).long()
+        class_weight = torch.rand(5)
+        class_weight /= class_weight.sum()
+        torch_loss = torch.nn.functional.cross_entropy(
+            fake_pred, fake_label, reduction='sum',
+            weight=class_weight) / fake_label.numel()
     loss_cls_cfg = dict(
         type='CrossEntropyLoss',
-        use_sigmoid=False,
+        use_sigmoid=use_sigmoid,
+        reduction='mean',
+        class_weight=class_weight,
         loss_weight=1.0,
-        loss_name='loss_ce')
+        avg_non_ignore=avg_non_ignore)
     loss_cls = build_loss(loss_cls_cfg)
+
+    # test cross entropy loss has name `loss_ce`
     assert loss_cls.loss_name == 'loss_ce'
-    # TODO test use_mask
+    # test avg_non_ignore is in extra_repr
+    assert loss_cls.extra_repr() == f'avg_non_ignore={avg_non_ignore}'
+
+    loss = loss_cls(fake_pred, fake_label)
+    assert torch.allclose(loss, torch_loss)
+
+    fake_label[0, [1, 2, 5, 7]] = 10  # set ignore_index
+    fake_label[1, [0, 5, 8, 9]] = 10
+    loss = loss_cls(fake_pred, fake_label, ignore_index=10)
+    if use_sigmoid:
+        if avg_non_ignore:
+            torch_loss = torch.nn.functional.binary_cross_entropy_with_logits(
+                fake_pred[fake_label != 10],
+                fake_label[fake_label != 10].float(),
+                pos_weight=class_weight[fake_label != 10],
+                reduction='mean')
+        else:
+            torch_loss = torch.nn.functional.binary_cross_entropy_with_logits(
+                fake_pred[fake_label != 10],
+                fake_label[fake_label != 10].float(),
+                pos_weight=class_weight[fake_label != 10],
+                reduction='sum') / fake_label.numel()
+    else:
+        if avg_non_ignore:
+            torch_loss = torch.nn.functional.cross_entropy(
+                fake_pred,
+                fake_label,
+                ignore_index=10,
+                reduction='sum',
+                weight=class_weight) / fake_label[fake_label != 10].numel()
+        else:
+            torch_loss = torch.nn.functional.cross_entropy(
+                fake_pred,
+                fake_label,
+                ignore_index=10,
+                reduction='sum',
+                weight=class_weight) / fake_label.numel()
+    assert torch.allclose(loss, torch_loss)
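For completeness, the equivalence the test relies on can also be exercised standalone; a sketch mirroring the `use_sigmoid=False`, `avg_non_ignore=True` case (the shapes and the ignore value 255 are illustrative assumptions):

```python
import torch
import torch.nn.functional as F

from mmseg.models import build_loss

pred = torch.randn(2, 5, 10)  # 5-way classification logits
label = torch.randint(0, 5, (2, 10))
label[0, :4] = 255  # mark a few targets as ignored

loss_cls = build_loss(
    dict(type='CrossEntropyLoss', use_sigmoid=False, avg_non_ignore=True))
# PyTorch's own 'mean' reduction already averages over non-ignored targets,
# so the two results should match (up to the epsilon guard above).
expected = F.cross_entropy(pred, label, reduction='mean', ignore_index=255)
assert torch.allclose(loss_cls(pred, label, ignore_index=255), expected)
```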