Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Fix] Add avg_non_ignore in cross entropy loss #1409

Merged
merged 21 commits into from
Mar 30, 2022
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions docs/en/tutorials/training_tricks.md
Original file line number Diff line number Diff line change
Expand Up @@ -68,3 +68,21 @@ model = dict(
In this way, `loss_weight` and `loss_name` will be weight and name in training log of corresponding loss, respectively.

Note: If you want this loss item to be included into the backward graph, `loss_` must be the prefix of the name.

## Ignore specified label index in loss calculation

For loss calculation, we support ignore index of certain label by `avg_non_ignore` and `ignore_index`. Here is an example config of training `unet` on `DRIVE` dataset: in loss calculation it would ignore label 0 which is background and loss average is only calculated on non-ignore labels:

```python
_base_ = './fcn_unet_s5-d16_64x64_40k_drive.py'
model = dict(
decode_head=dict(
ignore_index=0,
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0, avg_non_ignore=True),
auxiliary_head=dict(
ignore_index=0,
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0, avg_non_ignore=True)),
))
```
MengzhangLI marked this conversation as resolved.
Show resolved Hide resolved
23 changes: 23 additions & 0 deletions docs/zh_cn/tutorials/training_tricks.md
Original file line number Diff line number Diff line change
Expand Up @@ -68,3 +68,26 @@ model = dict(
通过这种方式,确定训练过程中损失函数的权重 `loss_weight` 和在训练日志里的名字 `loss_name`。

注意: `loss_name` 的名字必须带有 `loss_` 前缀,这样它才能被包括在反传的图里。

## 在损失函数中忽略特定的 label 类别

对于训练时损失函数的计算,我们目前支持使用 `avg_non_ignore` 和 `ignore_index` 来忽略 label 特定的类别。 以 `unet` 使用 `DRIVE` 数据集训练为例,
在计算损失函数时,忽略 label 为0的背景,并且仅在不被忽略的像素上计算均值。配置文件写为:

```python
_base_ = './fcn_unet_s5-d16_64x64_40k_drive.py'
model = dict(
decode_head=dict(
ignore_index=0,
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0, avg_non_ignore=True),
auxiliary_head=dict(
ignore_index=0,
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0, avg_non_ignore=True)),
))
```

通过这种方式,确定训练过程中损失函数的权重 `loss_weight` 和在训练日志里的名字 `loss_name`。

注意: `loss_name` 的名字必须带有 `loss_` 前缀,这样它才能被包括在反传的图里。
92 changes: 79 additions & 13 deletions mmseg/models/losses/cross_entropy_loss.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
import warnings

import torch
import torch.nn as nn
import torch.nn.functional as F
Expand All @@ -13,8 +15,30 @@ def cross_entropy(pred,
class_weight=None,
reduction='mean',
avg_factor=None,
ignore_index=-100):
"""The wrapper function for :func:`F.cross_entropy`"""
ignore_index=-100,
avg_non_ignore=False):
"""cross_entropy. The wrapper function for :func:`F.cross_entropy`

Args:
pred (torch.Tensor): The prediction with shape (N, 1).
label (torch.Tensor): The learning label of the prediction.
weight (torch.Tensor, optional): Sample-wise loss weight.
Default: None.
class_weight (list[float], optional): The weight for each class.
Default: None.
reduction (str, optional): The method used to reduce the loss.
Options are "none", "mean" and "sum". Default: 'mean'.
MengzhangLI marked this conversation as resolved.
Show resolved Hide resolved
avg_factor (int, optional): Average factor that is used to average
the loss. Default: None.
ignore_index (int): Specifies a target value that is ignored and
does not contribute to the input gradients. When
``avg_non_ignore `` is ``True``, and the ``reduction`` is
``''mean''``, the loss is averaged over non-ignored targets.
Defaults: -100.
avg_non_ignore (bool): The flag decides to whether the loss is
only averaged over non-ignored targets. Default: False.
MengzhangLI marked this conversation as resolved.
Show resolved Hide resolved
"""
MengzhangLI marked this conversation as resolved.
Show resolved Hide resolved

# class_weight is a manual rescaling weight given to each class.
# If given, has to be a Tensor of size C element-wise losses
loss = F.cross_entropy(
Expand All @@ -25,6 +49,11 @@ def cross_entropy(pred,
ignore_index=ignore_index)

# apply weights and do the reduction
# average loss over non-ignored elements
# pytorch's official cross_entropy average loss over non-ignored elements
# refer to https://github.com/pytorch/pytorch/blob/56b43f4fec1f76953f15a627694d4bba34588969/torch/nn/functional.py#L2660 # noqa
if (avg_factor is None) and avg_non_ignore and reduction == 'mean':
avg_factor = label.numel() - (label == ignore_index).sum().item()
if weight is not None:
weight = weight.float()
loss = weight_reduce_loss(
Expand All @@ -46,13 +75,14 @@ def _expand_onehot_labels(labels, label_weights, target_shape, ignore_index):
bin_labels[inds[0], labels[valid_mask]] = 1

valid_mask = valid_mask.unsqueeze(1).expand(target_shape).float()

if label_weights is None:
bin_label_weights = valid_mask
else:
bin_label_weights = label_weights.unsqueeze(1).expand(target_shape)
bin_label_weights *= valid_mask

return bin_labels, bin_label_weights
return bin_labels, bin_label_weights, valid_mask


def binary_cross_entropy(pred,
Expand All @@ -61,19 +91,25 @@ def binary_cross_entropy(pred,
reduction='mean',
avg_factor=None,
class_weight=None,
ignore_index=255):
ignore_index=-100,
avg_non_ignore=False,
MengzhangLI marked this conversation as resolved.
Show resolved Hide resolved
**kwargs):
"""Calculate the binary CrossEntropy loss.

Args:
pred (torch.Tensor): The prediction with shape (N, 1).
label (torch.Tensor): The learning label of the prediction.
Note: In bce loss, label < 0 is invalid.
weight (torch.Tensor, optional): Sample-wise loss weight.
reduction (str, optional): The method used to reduce the loss.
Options are "none", "mean" and "sum".
avg_factor (int, optional): Average factor that is used to average
the loss. Defaults to None.
class_weight (list[float], optional): The weight for each class.
ignore_index (int | None): The label index to be ignored. Default: 255
ignore_index (int | None): The label index to be ignored.
Default: -100.
MengzhangLI marked this conversation as resolved.
Show resolved Hide resolved
avg_non_ignore (bool): The flag decides to whether the loss is
MengzhangLI marked this conversation as resolved.
Show resolved Hide resolved
only averaged over non-ignored targets. Default: False.

Returns:
torch.Tensor: The calculated loss
Expand All @@ -83,12 +119,21 @@ def binary_cross_entropy(pred,
pred.dim() == 4 and label.dim() == 3), \
'Only pred shape [N, C], label shape [N] or pred shape [N, C, ' \
'H, W], label shape [N, H, W] are supported'
label, weight = _expand_onehot_labels(label, weight, pred.shape,
ignore_index)
# `weight` returned from `_expand_onehot_labels`
# has been treated for valid (non-ignore) pixels
label, weight, valid_mask = _expand_onehot_labels(
label, weight, pred.shape, ignore_index)
else:
MengzhangLI marked this conversation as resolved.
Show resolved Hide resolved
# should mask out the ignored elements
valid_mask = ((label >= 0) & (label != ignore_index)).float()
if weight is not None:
weight *= valid_mask
else:
weight = valid_mask
# average loss over non-ignored elements
MengzhangLI marked this conversation as resolved.
Show resolved Hide resolved
if reduction == 'mean' and avg_factor is None and avg_non_ignore:
avg_factor = valid_mask.sum().item()

# weighted element-wise losses
if weight is not None:
weight = weight.float()
loss = F.binary_cross_entropy_with_logits(
pred, label.float(), pos_weight=class_weight, reduction='none')
# do the reduction for the weighted loss
Expand All @@ -104,7 +149,8 @@ def mask_cross_entropy(pred,
reduction='mean',
avg_factor=None,
class_weight=None,
ignore_index=None):
ignore_index=-100,
MengzhangLI marked this conversation as resolved.
Show resolved Hide resolved
**kwargs):
"""Calculate the CrossEntropy loss for masks.

Args:
Expand All @@ -121,7 +167,7 @@ def mask_cross_entropy(pred,
the loss. Defaults to None.
class_weight (list[float], optional): The weight for each class.
ignore_index (None): Placeholder, to be consistent with other loss.
Default: None.
Default: -100.

Returns:
torch.Tensor: The calculated loss
Expand Down Expand Up @@ -153,6 +199,8 @@ class CrossEntropyLoss(nn.Module):
loss_name (str, optional): Name of the loss item. If you want this loss
item to be included into the backward graph, `loss_` must be the
prefix of the name. Defaults to 'loss_ce'.
avg_non_ignore (bool): The flag decides to whether the loss is
only averaged over non-ignored targets. Default: False.
MengzhangLI marked this conversation as resolved.
Show resolved Hide resolved
"""

def __init__(self,
Expand All @@ -161,14 +209,22 @@ def __init__(self,
reduction='mean',
class_weight=None,
loss_weight=1.0,
loss_name='loss_ce'):
loss_name='loss_ce',
avg_non_ignore=False):
super(CrossEntropyLoss, self).__init__()
assert (use_sigmoid is False) or (use_mask is False)
self.use_sigmoid = use_sigmoid
self.use_mask = use_mask
self.reduction = reduction
self.loss_weight = loss_weight
self.class_weight = get_class_weight(class_weight)
self.avg_non_ignore = avg_non_ignore
if not self.avg_non_ignore and self.reduction == 'mean':
warnings.warn(
'Default ``avg_non_ignore`` is False, if you would like to '
'ignore the certain label and average loss over non-ignore '
'labels, which is the same with PyTorch official '
'cross_entropy, set ``avg_non_ignore=True``.')

if self.use_sigmoid:
self.cls_criterion = binary_cross_entropy
Expand All @@ -178,12 +234,18 @@ def __init__(self,
self.cls_criterion = cross_entropy
self._loss_name = loss_name

def extra_repr(self):
"""Extra repr."""
s = f'avg_non_ignore={self.avg_non_ignore}'
return s

def forward(self,
cls_score,
label,
weight=None,
avg_factor=None,
reduction_override=None,
ignore_index=-100,
**kwargs):
"""Forward function."""
assert reduction_override in (None, 'none', 'mean', 'sum')
Expand All @@ -193,13 +255,16 @@ def forward(self,
class_weight = cls_score.new_tensor(self.class_weight)
else:
class_weight = None
# Note: In cls_criterion, label < 0 is invalid.
MengzhangLI marked this conversation as resolved.
Show resolved Hide resolved
loss_cls = self.loss_weight * self.cls_criterion(
cls_score,
label,
weight,
class_weight=class_weight,
reduction=reduction,
avg_factor=avg_factor,
avg_non_ignore=self.avg_non_ignore,
ignore_index=ignore_index,
**kwargs)
MengzhangLI marked this conversation as resolved.
Show resolved Hide resolved
return loss_cls

Expand All @@ -212,6 +277,7 @@ def loss_name(self):
by simple sum operation. In addition, if you want this loss item to be
included into the backward graph, `loss_` must be the prefix of the
name.

Returns:
str: The name of this loss item.
"""
Expand Down
6 changes: 5 additions & 1 deletion mmseg/models/losses/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import mmcv
import numpy as np
import torch
import torch.nn.functional as F


Expand Down Expand Up @@ -69,7 +70,10 @@ def weight_reduce_loss(loss, weight=None, reduction='mean', avg_factor=None):
else:
# if reduction is mean, then average the loss by avg_factor
if reduction == 'mean':
loss = loss.sum() / avg_factor
# Avoid causing ZeroDivisionError when avg_factor is 0.0,
# i.e., all labels of an image belong to ignore index.
eps = torch.finfo(torch.float32).eps
loss = loss.sum() / (avg_factor + eps)
# if reduction is 'none', then do nothing, otherwise raise an error
elif reduction != 'none':
raise ValueError('avg_factor can not be used with reduction="sum"')
Expand Down
Loading