-
Notifications
You must be signed in to change notification settings - Fork 1.7k
/
lovasz_loss.py
230 lines (195 loc) · 7.87 KB
/
lovasz_loss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Lovasz-Softmax and Jaccard hinge loss in PaddlePaddle"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import paddle
from paddle import nn
import paddle.nn.functional as F
from paddleseg.cvlibs import manager
@manager.LOSSES.add_component
class LovaszSoftmaxLoss(nn.Layer):
    """
    Multi-class Lovasz-Softmax loss.

    Args:
        ignore_index (int64): Label value whose pixels are dropped before the loss
            is computed, so they contribute nothing to the input gradient.
            Default ``255``.
        classes (str|list): Which classes to average over: 'all' for every class,
            'present' (default) for only the classes that occur in the labels, or
            an explicit list of class indices.
    """

    def __init__(self, ignore_index=255, classes='present'):
        super(LovaszSoftmaxLoss, self).__init__()
        self.classes = classes
        self.ignore_index = ignore_index

    def forward(self, logits, labels):
        r"""
        Forward computation.

        Args:
            logits (Tensor): Shape is [N, C, H, W], logits at each prediction
                (between -\infty and +\infty). A softmax over axis 1 turns them
                into class probabilities.
            labels (Tensor): Shape is [N, 1, H, W] or [N, H, W], ground truth
                labels (between 0 and C - 1).

        Returns:
            Tensor: the loss computed by ``lovasz_softmax_flat`` over the pixels
            whose label is not ``ignore_index``.
        """
        flat_probas, flat_labels = flatten_probas(
            F.softmax(logits, axis=1), labels, self.ignore_index)
        return lovasz_softmax_flat(
            flat_probas, flat_labels, classes=self.classes)
@manager.LOSSES.add_component
class LovaszHingeLoss(nn.Layer):
    """
    Binary Lovasz hinge loss.

    Args:
        ignore_index (int64): Label value whose pixels are dropped before the loss
            is computed, so they contribute nothing to the input gradient.
            Default ``255``.
    """

    def __init__(self, ignore_index=255):
        super(LovaszHingeLoss, self).__init__()
        self.ignore_index = ignore_index

    def forward(self, logits, labels):
        r"""
        Forward computation.

        Args:
            logits (Tensor): Shape is [N, 1, H, W] or [N, 2, H, W], logits at each
                pixel (between -\infty and +\infty). Two-channel logits are first
                collapsed to one channel by ``binary_channel_to_unary``.
            labels (Tensor): Shape is [N, 1, H, W] or [N, H, W], binary ground
                truth masks (0 or 1).

        Returns:
            Tensor: the loss computed by ``lovasz_hinge_flat`` over the pixels
            whose label is not ``ignore_index``.
        """
        unary_logits = (binary_channel_to_unary(logits)
                        if logits.shape[1] == 2 else logits)
        flat_logits, flat_labels = flatten_binary_scores(
            unary_logits, labels, self.ignore_index)
        return lovasz_hinge_flat(flat_logits, flat_labels)
def lovasz_grad(gt_sorted):
    """
    Computes gradient of the Lovasz extension w.r.t sorted errors.
    See Alg. 1 in paper.

    Args:
        gt_sorted (Tensor): 1-D ground-truth indicator values (0/1), ordered to
            match the errors sorted in descending order by the caller.

    Returns:
        Tensor: float32 tensor with the same length as ``gt_sorted``. Entry 0 is
        the prefix Jaccard value for the first element; every later entry is the
        difference between consecutive prefix Jaccard values.
    """
    total_fg = paddle.sum(gt_sorted)
    inter = total_fg - paddle.cumsum(gt_sorted, axis=0)
    uni = total_fg + paddle.cumsum(1 - gt_sorted, axis=0)
    jaccard = 1.0 - inter.cast('float32') / uni.cast('float32')
    if len(gt_sorted) <= 1:
        # cover 1-pixel case: nothing to difference
        return jaccard
    # keep element 0, replace every later element by jaccard[i] - jaccard[i - 1]
    return paddle.concat([jaccard[0:1], jaccard[1:] - jaccard[:-1]], axis=0)
def binary_channel_to_unary(logits, eps=1e-9):
    """
    Converts binary channel logits to unary channel logits for lovasz hinge loss.

    A softmax over the channel axis gives the foreground probability p (channel 1).
    That probability is mapped back to a single logit log((p + eps) / (1 - p + eps)).

    Args:
        logits (Tensor): Shape is [N, 2, H, W]; channel 1 is treated as foreground.
        eps (float): Small constant added to both numerator and denominator so
            neither the ratio nor the log hits zero. Default ``1e-9``.

    Returns:
        Tensor: Shape is [N, 1, H, W], the foreground logit at each pixel.
    """
    probas = F.softmax(logits, axis=1)
    fg_probas = probas[:, 1, :, :]
    # The ratio must be parenthesized as a whole. Writing
    # `p + eps / (1 - p + eps)` divides only eps, which yields log(p + tiny)
    # instead of the logit of p.
    unary_logits = paddle.log((fg_probas + eps) / (1 - fg_probas + eps))
    return unary_logits.unsqueeze(1)
def lovasz_hinge_flat(logits, labels):
    r"""
    Binary Lovasz hinge loss.

    Args:
        logits (Tensor): Shape is [P], logits at each prediction (between -\infty and +\infty).
        labels (Tensor): Shape is [P], binary ground truth labels (0 or 1).

    Returns:
        Tensor: sum over the P predictions of relu(sorted hinge error) times
        the Lovasz gradient. If ``labels`` is empty (only ignored pixels), the
        result is ``logits.sum() * 0.``, so backward yields zero gradients.
    """
    if len(labels) == 0:
        # only void pixels, the gradients should be 0
        return logits.sum() * 0.
    signs = 2. * labels - 1.
    signs.stop_gradient = True
    errors = 1. - logits * signs
    # Sort the hinge errors in descending order with the public API.
    # paddle._legacy_C_ops / paddle._C_ops.argsort are internal, and their
    # positional signatures differ between Paddle releases. Gathering through
    # the permutation keeps errors_sorted differentiable w.r.t. logits.
    perm = paddle.argsort(errors, axis=0, descending=True)
    perm.stop_gradient = True
    errors_sorted = paddle.gather(errors, perm)
    gt_sorted = paddle.gather(labels, perm)
    grad = lovasz_grad(gt_sorted)
    grad.stop_gradient = True
    loss = paddle.sum(F.relu(errors_sorted) * grad)
    return loss
def flatten_binary_scores(scores, labels, ignore=None):
    """
    Flattens predictions in the batch (binary case).
    Remove labels according to 'ignore'.

    Args:
        scores (Tensor): Per-pixel scores of any shape; flattened to 1-D.
        labels (Tensor): Labels, expected to hold as many elements as ``scores``;
            flattened to 1-D and marked ``stop_gradient``.
        ignore (int|None): Label value to drop. ``None`` keeps every element.

    Returns:
        tuple(Tensor, Tensor): The flattened scores and labels, restricted to the
        positions whose label differs from ``ignore``.
    """
    flat_scores = paddle.reshape(scores, [-1])
    flat_labels = paddle.reshape(labels, [-1])
    flat_labels.stop_gradient = True
    if ignore is not None:
        keep = paddle.nonzero(paddle.reshape(flat_labels != ignore, (-1, 1)))
        keep.stop_gradient = True
        keep_rows = keep[:, 0]
        return (paddle.gather(flat_scores, keep_rows),
                paddle.gather(flat_labels, keep_rows))
    return flat_scores, flat_labels
def lovasz_softmax_flat(probas, labels, classes='present'):
    """
    Multi-class Lovasz-Softmax loss.

    Args:
        probas (Tensor): Shape is [P, C], class probabilities at each prediction (between 0 and 1).
        labels (Tensor): Shape is [P], ground truth labels (between 0 and C - 1).
        classes (str|list): 'all' for all, 'present' for classes present in labels, or a list of classes to average.

    Returns:
        Tensor: the mean of the per-class Lovasz losses. When exactly one class
        is requested, that class's loss is returned as-is. If there are no
        predictions, or ``classes='present'`` skips every class, the result is
        ``probas.sum() * 0.``, so backward yields zero gradients.

    Raises:
        ValueError: If ``probas`` has a single channel (sigmoid output) but more
            than one class is requested.
    """
    if probas.numel() == 0:
        # Only void pixels. Return a zero tied to the graph (as lovasz_hinge_flat
        # does) instead of `probas * 0.`, which is an empty tensor and cannot
        # serve as a loss value.
        return probas.sum() * 0.
    C = probas.shape[1]
    classes_to_sum = list(range(C)) if classes in ['all', 'present'
                                                   ] else classes
    # loop-invariant check, previously repeated on every iteration
    if C == 1 and len(classes_to_sum) > 1:
        raise ValueError('Sigmoid output possible only with 1 class')
    losses = []
    for c in classes_to_sum:
        fg = paddle.cast(labels == c, probas.dtype)  # foreground for class c
        if classes == 'present' and fg.sum() == 0:
            continue
        fg.stop_gradient = True
        # a single (sigmoid) channel always lives in column 0
        class_pred = probas[:, 0] if C == 1 else probas[:, c]
        errors = paddle.abs(fg - class_pred)
        # Sort errors in descending order with the public API instead of the
        # internal paddle._legacy_C_ops / paddle._C_ops argsort, whose
        # signatures change across Paddle releases. Gathering through the
        # permutation keeps errors_sorted differentiable.
        perm = paddle.argsort(errors, axis=0, descending=True)
        perm.stop_gradient = True
        errors_sorted = paddle.gather(errors, perm)
        fg_sorted = paddle.gather(fg, perm)
        fg_sorted.stop_gradient = True
        grad = lovasz_grad(fg_sorted)
        grad.stop_gradient = True
        losses.append(paddle.sum(errors_sorted * grad))
    if not losses:
        # 'present' skipped every class (e.g. C == 1 and no label equals 0).
        # Without this guard, losses[0] or paddle.stack([]) would fail.
        return probas.sum() * 0.
    if len(classes_to_sum) == 1:
        return losses[0]
    return paddle.mean(paddle.stack(losses))
def flatten_probas(probas, labels, ignore=None):
    """
    Flattens predictions in the batch.

    Args:
        probas (Tensor): Shape is [N, C, H, W]. A 3-D input [N, H, W] is treated
            as a single channel, i.e. [N, 1, H, W].
        labels (Tensor): Labels, expected to hold N * H * W elements (e.g.
            [N, H, W] or [N, 1, H, W]); flattened to 1-D.
        ignore (int|None): Label value to drop. ``None`` keeps every pixel.

    Returns:
        tuple(Tensor, Tensor): Probabilities of shape [P, C] (one row per pixel)
        and labels of shape [P], restricted to pixels whose label differs from
        ``ignore``.
    """
    if len(probas.shape) == 3:
        # no channel axis: add one so the layout below is uniform
        probas = paddle.unsqueeze(probas, axis=1)
    num_channels = probas.shape[1]
    # NCHW -> NHWC, so each row of the reshaped tensor is one pixel's channels
    pixel_probas = paddle.reshape(
        paddle.transpose(probas, [0, 2, 3, 1]), [-1, num_channels])
    pixel_labels = paddle.reshape(labels, [-1])
    if ignore is not None:
        keep = paddle.nonzero(paddle.reshape(pixel_labels != ignore, [-1, 1]))
        keep.stop_gradient = True
        rows = keep[:, 0]
        return paddle.gather(pixel_probas, rows), paddle.gather(pixel_labels, rows)
    return pixel_probas, pixel_labels