-
-
Notifications
You must be signed in to change notification settings - Fork 504
/
losses.py
345 lines (272 loc) · 12.9 KB
/
losses.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
import torch
import torch.nn.functional as F
from torch import nn as nn
from torch.nn import MSELoss, SmoothL1Loss, L1Loss
def compute_per_channel_dice(input, target, epsilon=1e-6, weight=None):
    """
    Per-channel Dice coefficient (https://arxiv.org/abs/1606.04797) for a multi-channel input and target.

    Uses the squared-sum (V-Net) denominator:
        dice_c = 2 * sum(p_c * t_c) / max(sum(p_c^2) + sum(t_c^2), epsilon)

    Args:
        input (torch.Tensor): NxCxSpatial tensor, assumed to already hold probabilities
            (e.g. the output of Sigmoid or Softmax)
        target (torch.Tensor): NxCxSpatial target tensor; it is cast to float
        epsilon (float): lower bound applied to the denominator to prevent division by zero
        weight (torch.Tensor): optional per-channel weights multiplying the numerator; they are
            broadcast against a tensor of shape (C,), so a 1-D tensor of length C is expected
    Returns:
        torch.Tensor: the Dice coefficient of every channel (shape (C,) for a 1-D or absent weight)
    """
    assert input.size() == target.size(), "'input' and 'target' must have the same shape"

    # (N, C, ...) -> (C, N * ...): one row per channel
    probs = flatten(input)
    truth = flatten(target).float()

    numerator = (probs * truth).sum(-1)
    if weight is not None:
        numerator = weight * numerator

    # V-Net style denominator; the plain Dice variant would be (probs + truth).sum(-1)
    denominator = (probs * probs).sum(-1) + (truth * truth).sum(-1)
    return 2 * (numerator / denominator.clamp(min=epsilon))
class _MaskingLossWrapper(nn.Module):
    """
    Wraps a loss so that positions whose target equals ``ignore_index`` are zeroed out.

    Both the input and the target are multiplied by a 0/1 mask before being handed to the
    wrapped loss, so the gradient w.r.t. the input is zero at the ignored positions.
    """

    def __init__(self, loss, ignore_index):
        super().__init__()
        assert ignore_index is not None, 'ignore_index cannot be None'
        self.loss = loss
        self.ignore_index = ignore_index

    def forward(self, input, target):
        # in-place comparison on a copy: 1 where target != ignore_index, 0 elsewhere,
        # stored in target's dtype
        keep = target.clone()
        keep.ne_(self.ignore_index)
        keep.requires_grad = False
        return self.loss(input * keep, target * keep)
class SkipLastTargetChannelWrapper(nn.Module):
    """
    Loss wrapper that drops the last channel of the target before calling the wrapped loss.

    Args:
        loss: the wrapped loss callable
        squeeze_channel (bool): if True, the channel axis (dim 1) of the trimmed target is squeezed
    """

    def __init__(self, loss, squeeze_channel=False):
        super().__init__()
        self.loss = loss
        self.squeeze_channel = squeeze_channel

    def forward(self, input, target, weight=None):
        assert target.size(1) > 1, 'Target tensor has a singleton channel dimension, cannot remove channel'

        trimmed = target[:, :-1, ...]
        if self.squeeze_channel:
            trimmed = torch.squeeze(trimmed, dim=1)

        # only forward the weight when one was given, so 2-argument losses keep working
        extra_args = () if weight is None else (weight,)
        return self.loss(input, trimmed, *extra_args)
class _AbstractDiceLoss(nn.Module):
    """
    Base class for different implementations of Dice loss.

    Subclasses implement ``dice(input, target, weight)`` returning Dice scores; ``forward``
    normalizes the raw network output, calls ``dice`` and returns ``1 - mean(scores)``.

    Args:
        weight (torch.Tensor): optional per-channel weights, registered as a buffer and passed to ``dice``
        normalization (str): one of 'sigmoid', 'softmax' (over dim 1) or 'none'
    """

    def __init__(self, weight=None, normalization='sigmoid'):
        super().__init__()
        # registered as a buffer so a tensor weight follows the module across .to()/.cuda()
        self.register_buffer('weight', weight)
        # The output from the network during training is assumed to be un-normalized probabilities and we would
        # like to normalize the logits. Since Dice (or soft Dice in this case) is usually used for binary data,
        # normalizing the channels with Sigmoid is the default choice even for multi-class segmentation problems.
        # However if one would like to apply Softmax in order to get the proper probability distribution from the
        # output, just specify `normalization='softmax'`
        assert normalization in ['sigmoid', 'softmax', 'none']
        if normalization == 'sigmoid':
            self.normalization = nn.Sigmoid()
        elif normalization == 'softmax':
            self.normalization = nn.Softmax(dim=1)
        else:
            # nn.Identity rather than a lambda: a lambda attribute makes the whole module
            # unpicklable (torch.save of the module, multiprocessing workers)
            self.normalization = nn.Identity()

    def dice(self, input, target, weight):
        # actual Dice score computation; to be implemented by the subclass
        raise NotImplementedError

    def forward(self, input, target):
        # get probabilities from logits
        input = self.normalization(input)
        # compute per channel Dice coefficient
        per_channel_dice = self.dice(input, target, weight=self.weight)
        # average Dice score across all channels/classes
        return 1. - torch.mean(per_channel_dice)
class DiceLoss(_AbstractDiceLoss):
    """Computes Dice Loss according to https://arxiv.org/abs/1606.04797.
    For multi-class segmentation `weight` parameter can be used to assign different weights per class.
    The input to the loss function is assumed to be a logit and is normalized by the Sigmoid function
    by default (see `normalization` for the alternatives).
    """

    def __init__(self, weight=None, normalization='sigmoid'):
        super().__init__(weight, normalization)

    def dice(self, input, target, weight):
        # Honor the weight handed in by the caller; the base forward() passes self.weight,
        # so behaviour through forward() is unchanged, but direct callers are no longer ignored.
        return compute_per_channel_dice(input, target, weight=weight)
class GeneralizedDiceLoss(_AbstractDiceLoss):
    """Computes Generalized Dice Loss (GDL) as described in https://arxiv.org/pdf/1707.03237.pdf.

    Each class l is weighted by 1 / (sum of its target values)^2 (clamped by ``epsilon``), and the
    score is 2 * sum_l(w_l * intersect_l) / sum_l(w_l * (sum(p_l) + sum(t_l))).
    """

    def __init__(self, normalization='sigmoid', epsilon=1e-6):
        super().__init__(weight=None, normalization=normalization)
        self.epsilon = epsilon

    def dice(self, input, target, weight):
        assert input.size() == target.size(), "'input' and 'target' must have the same shape"

        probs = flatten(input)
        labels = flatten(target).float()

        if probs.size(0) == 1:
            # GDL needs at least 2 channels (see https://arxiv.org/pdf/1707.03237.pdf):
            # add the background as a complementary channel
            probs = torch.cat((probs, 1 - probs), dim=0)
            labels = torch.cat((labels, 1 - labels), dim=0)

        # inverse squared label volume: rare classes get a larger weight
        volume = labels.sum(-1)
        class_weights = 1 / (volume * volume).clamp(min=self.epsilon)
        class_weights.requires_grad = False

        weighted_overlap = (probs * labels).sum(-1) * class_weights
        weighted_total = ((probs + labels).sum(-1) * class_weights).clamp(min=self.epsilon)
        return 2 * (weighted_overlap.sum() / weighted_total.sum())
class BCEDiceLoss(nn.Module):
    """Linear combination of BCE and Dice losses: alpha * BCEWithLogitsLoss + beta * DiceLoss.

    Both terms receive the raw logits; DiceLoss is built with its default arguments.
    """

    def __init__(self, alpha, beta):
        super().__init__()
        self.alpha = alpha
        self.bce = nn.BCEWithLogitsLoss()
        self.beta = beta
        self.dice = DiceLoss()

    def forward(self, input, target):
        bce_term = self.bce(input, target)
        dice_term = self.dice(input, target)
        return self.alpha * bce_term + self.beta * dice_term
class WeightedCrossEntropyLoss(nn.Module):
    """WeightedCrossEntropyLoss (WCE) as described in https://arxiv.org/pdf/1707.03237.pdf

    Class weights are recomputed on every call from the softmax of the input:
    w_c = sum(1 - p_c) / sum(p_c) over the batch and all spatial positions. They are
    detached, so no gradient flows through them.
    """

    def __init__(self, ignore_index=-1):
        super().__init__()
        self.ignore_index = ignore_index

    def forward(self, input, target):
        class_weights = self._class_weights(input)
        return F.cross_entropy(input, target, weight=class_weights, ignore_index=self.ignore_index)

    @staticmethod
    def _class_weights(input):
        # per-channel probabilities laid out as (C, N * spatial)
        probabilities = flatten(F.softmax(input, dim=1))
        # ratio of "not this class" mass to "this class" mass for every channel
        weights = (1. - probabilities).sum(-1) / probabilities.sum(-1)
        return weights.detach()
class PixelWiseCrossEntropyLoss(nn.Module):
    """Cross-entropy where every pixel/voxel is scaled by its own weight.

    The integer target is one-hot encoded with as many classes as ``input`` has channels and
    combined with the per-pixel ``weights`` and the log-softmax of ``input``; the result is
    averaged over all elements.

    Args:
        ignore_index (int): optional label whose pixels contribute nothing to the loss
    """

    def __init__(self, ignore_index=None):
        super().__init__()
        self.ignore_index = ignore_index
        self.log_softmax = nn.LogSoftmax(dim=1)

    def forward(self, input, target, weights):
        """
        Args:
            input (torch.Tensor): NxCxSpatial logits
            target (torch.Tensor): NxSpatial class labels
            weights (torch.Tensor): per-pixel weights, same shape as ``target``
        Returns:
            torch.Tensor: scalar mean of ``-weights * one_hot(target) * log_softmax(input)``
        """
        assert target.size() == weights.size()
        # normalize the input
        log_probabilities = self.log_softmax(input)

        if self.ignore_index is not None:
            ignored = target == self.ignore_index
            # masked_fill returns a new tensor, so the caller's target is not modified;
            # ignored labels become class 0 only so one_hot accepts them - the mask removes them below
            target = target.masked_fill(ignored, 0)
        else:
            ignored = torch.zeros_like(target, dtype=torch.bool)
        # add the channel axis so the mask broadcasts over classes; `~` is required because
        # `1 - mask` is not supported for bool tensors
        keep = (~ignored).unsqueeze(1)

        # fix the class count to the channel count of input: letting one_hot infer it from
        # target.max() gives too few channels when the highest class is absent from the batch
        one_hot = F.one_hot(target.long(), num_classes=input.size(1))
        # move the class axis from last to position 1: (N, *spatial, C) -> (N, C, *spatial)
        axis_order = (0, one_hot.dim() - 1) + tuple(range(1, one_hot.dim() - 1))
        one_hot = one_hot.permute(axis_order).contiguous()
        # apply the mask on the target
        one_hot = one_hot * keep

        # add channel dimension to the weights
        weights = weights.unsqueeze(1)
        # compute the losses
        result = -weights * one_hot * log_probabilities
        return result.mean()
class WeightedSmoothL1Loss(nn.SmoothL1Loss):
    """Smooth-L1 loss whose element-wise values are scaled by ``weight`` on one side of a threshold.

    Elements whose *target* is below ``threshold`` (or at/above it when ``apply_below_threshold``
    is False) are multiplied by ``weight``; the result is averaged over all elements.
    """

    def __init__(self, threshold, initial_weight, apply_below_threshold=True):
        super().__init__(reduction="none")
        self.threshold = threshold
        self.apply_below_threshold = apply_below_threshold
        self.weight = initial_weight

    def forward(self, input, target):
        per_element = super().forward(input, target)

        if self.apply_below_threshold:
            selected = target < self.threshold
        else:
            selected = target >= self.threshold

        per_element[selected] = self.weight * per_element[selected]
        return per_element.mean()
def flatten(tensor):
    """Reorder ``tensor`` so the channel axis (dim 1) comes first, then merge all other axes.

    (N, C, D, H, W) -> (C, N * D * H * W); any tensor with at least two dimensions works,
    e.g. (N, C, H, W) -> (C, N * H * W).
    """
    channels = tensor.size(1)
    # swapping dims 0 and 1 is the permutation (1, 0, 2, 3, ...)
    channels_first = tensor.transpose(0, 1)
    return channels_first.contiguous().view(channels, -1)
def get_loss_criterion(config):
    """
    Returns the loss function based on provided configuration.

    :param config: (dict) a top level configuration object containing the 'loss' key. ``config['loss']``
        must contain 'name' and may contain 'ignore_index', 'skip_last_target', 'weight', 'pos_weight',
        'squeeze_channel' and loss-specific options; it is not modified
    :return: an instance of the loss function (moved to the GPU when CUDA is available)
    """
    assert 'loss' in config, 'Could not find loss function configuration'
    # Work on a shallow copy: the keys below are popped, and popping from the caller's dict
    # would make a second call with the same config fail with KeyError('name').
    loss_config = dict(config['loss'])
    name = loss_config.pop('name')

    ignore_index = loss_config.pop('ignore_index', None)
    skip_last_target = loss_config.pop('skip_last_target', False)
    weight = loss_config.pop('weight', None)
    if weight is not None:
        weight = torch.tensor(weight)
    pos_weight = loss_config.pop('pos_weight', None)
    if pos_weight is not None:
        pos_weight = torch.tensor(pos_weight)

    loss = _create_loss(name, loss_config, weight, ignore_index, pos_weight)

    # these losses take 'ignore_index' directly; PixelWiseCrossEntropyLoss must not be wrapped
    # in any case, since the 2-argument masking wrapper cannot forward its required 'weights'
    native_ignore = ['CrossEntropyLoss', 'WeightedCrossEntropyLoss', 'PixelWiseCrossEntropyLoss']
    if not (ignore_index is None or name in native_ignore):
        loss = _MaskingLossWrapper(loss, ignore_index)

    if skip_last_target:
        loss = SkipLastTargetChannelWrapper(loss, loss_config.get('squeeze_channel', False))

    if torch.cuda.is_available():
        loss = loss.cuda()

    return loss
#######################################################################################################################
def _create_loss(name, loss_config, weight, ignore_index, pos_weight):
    """
    Instantiates the loss called ``name``.

    :param name: (str) loss name, e.g. 'DiceLoss' or 'CrossEntropyLoss'
    :param loss_config: (dict) remaining loss options (alpha, beta, normalization, threshold, ...)
    :param weight: (torch.Tensor or None) class weights for CrossEntropyLoss / DiceLoss
    :param ignore_index: (int or None) ignored label; the two cross-entropy losses fall back to -100 when None
    :param pos_weight: (torch.Tensor or None) positive-class weight for BCEWithLogitsLoss
    :return: an instance of the requested loss
    :raises RuntimeError: if ``name`` is not a supported loss
    """
    # use the default 'ignore_index' as defined in the CrossEntropyLoss
    ce_ignore_index = -100 if ignore_index is None else ignore_index

    # Factories are only invoked for the requested name, so options such as
    # loss_config['threshold'] are looked up only when that loss is actually built.
    factories = {
        'BCEWithLogitsLoss': lambda: nn.BCEWithLogitsLoss(pos_weight=pos_weight),
        'BCEDiceLoss': lambda: BCEDiceLoss(loss_config.get('alpha', 1.), loss_config.get('beta', 1.)),
        'CrossEntropyLoss': lambda: nn.CrossEntropyLoss(weight=weight, ignore_index=ce_ignore_index),
        'WeightedCrossEntropyLoss': lambda: WeightedCrossEntropyLoss(ignore_index=ce_ignore_index),
        'PixelWiseCrossEntropyLoss': lambda: PixelWiseCrossEntropyLoss(ignore_index=ignore_index),
        'GeneralizedDiceLoss': lambda: GeneralizedDiceLoss(
            normalization=loss_config.get('normalization', 'sigmoid')),
        'DiceLoss': lambda: DiceLoss(weight=weight,
                                     normalization=loss_config.get('normalization', 'sigmoid')),
        'MSELoss': MSELoss,
        'SmoothL1Loss': SmoothL1Loss,
        'L1Loss': L1Loss,
        'WeightedSmoothL1Loss': lambda: WeightedSmoothL1Loss(
            threshold=loss_config['threshold'],
            initial_weight=loss_config['initial_weight'],
            apply_below_threshold=loss_config.get('apply_below_threshold', True)),
    }

    try:
        factory = factories[name]
    except (KeyError, TypeError):
        raise RuntimeError(f"Unsupported loss function: '{name}'") from None
    return factory()